diff --git a/include/sp/common/DataBuffer.h b/include/sp/common/DataBuffer.h index d9cf312..33e5697 100644 --- a/include/sp/common/DataBuffer.h +++ b/include/sp/common/DataBuffer.h @@ -35,6 +35,7 @@ class DataBuffer { typedef Data::difference_type difference_type; DataBuffer(); + DataBuffer(std::size_t a_InitialSize); DataBuffer(const DataBuffer& other); DataBuffer(const DataBuffer& other, difference_type offset); DataBuffer(DataBuffer&& other); diff --git a/include/sp/common/NonCopyable.h b/include/sp/common/NonCopyable.h new file mode 100644 index 0000000..85b155e --- /dev/null +++ b/include/sp/common/NonCopyable.h @@ -0,0 +1,25 @@ +#pragma once + +/** + * \file NonCopyable.h + * \brief File containing the sp::NonCopyable class + */ + +namespace sp { + +/** + * \class NonCopyable + * \brief Class used to make a class non copyable + * \note Inherit from this class privately to make a class non copyable + */ +class NonCopyable { + public: + NonCopyable(const NonCopyable&) = delete; + NonCopyable& operator=(const NonCopyable&) = delete; + + protected: + NonCopyable() {} + ~NonCopyable() {} +}; + +} // namespace sp diff --git a/include/sp/extensions/Extensions.h b/include/sp/extensions/Extensions.h index 68c30f4..8fd8c6a 100644 --- a/include/sp/extensions/Extensions.h +++ b/include/sp/extensions/Extensions.h @@ -2,4 +2,8 @@ #if __has_include() #include +#endif + +#if __has_include() + #include #endif \ No newline at end of file diff --git a/include/sp/extensions/Tcp.h b/include/sp/extensions/Tcp.h new file mode 100644 index 0000000..5315211 --- /dev/null +++ b/include/sp/extensions/Tcp.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include \ No newline at end of file diff --git a/include/sp/extensions/tcp/TcpListener.h b/include/sp/extensions/tcp/TcpListener.h new file mode 100644 index 0000000..98ed8bb --- /dev/null +++ b/include/sp/extensions/tcp/TcpListener.h @@ -0,0 +1,65 @@ +#pragma once + +#include + +namespace sp { +namespace io { + +/** + * \class TcpListener + */ +class TcpListener : private NonCopyable { + public: + /** + * \brief Starts listening for guests to connect + * \param port The port to listen to + * \param maxConnexions The maximum amount of connexion that can happen at the same time. \n + * Every other guests will be kicked if this amount is reached. + * \return Whether this action was succesfull + */ + TcpListener(std::uint16_t a_Port, int a_MaxConnexions); + + /** + * \brief Default destructor + */ + ~TcpListener(); + + /** + * \brief Tries to accept an incoming request to connect + * \return the new socket if a new connexion was accepted or nullptr + */ + std::unique_ptr Accept(); + + /** + * \brief Closes the socket + */ + void Close(); + + /** + * \brief Allows to set the socket in non blocking/blocking mode + * \param a_Blocking If set to true, every call to Read will wait until the socket receives something + * \return true if the operation was successful + */ + bool SetBlocking(bool a_Blocking); + + /** + * \brief Getter of the m_Port member + * \return The port which the socket listen to + */ + std::uint16_t GetListeningPort() const; + + /** + * \brief Getter of the m_MaxConnections member + * \return The maximum amount of connexions that can happen at the same time. + */ + int GetMaximumConnections() const; + + private: + SocketHandle m_Handle; + std::uint16_t m_Port; + int m_MaxConnections; +}; + + +} // namespace io +} // namespace sp diff --git a/include/sp/extensions/tcp/TcpSocket.h b/include/sp/extensions/tcp/TcpSocket.h new file mode 100644 index 0000000..1b14f3f --- /dev/null +++ b/include/sp/extensions/tcp/TcpSocket.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include + +namespace sp { +namespace io { + +using SocketHandle = int; + +struct TcpTag {}; + +class SocketError : public std::exception { + private: + std::string m_Error; + + public: + SocketError(std::string&& a_Msg) : m_Error(std::move(a_Msg)) {} + + virtual const char* what() const noexcept override { + return m_Error.c_str(); + } +}; + +template <> +class IOInterface : private NonCopyable { + public: + /** + * \enum Status + * \brief Describes the state of a socket + */ + enum class Status { + /** The socket is connected */ + Connected, + /** The socket is not connected */ + Disconnected, + /** Something bad happened */ + Error, + }; + + IOInterface(); + IOInterface(const std::string& a_Host, std::uint16_t a_Port); + IOInterface(IOInterface&& a_Other); + IOInterface& operator=(IOInterface&& a_Other); + virtual ~IOInterface(); + + DataBuffer Read(std::size_t a_Amount); + void Write(const sp::DataBuffer& a_Data); + + /** + * \brief Allows to set the socket in non blocking/blocking mode + * \param a_Block If set to true, every call to Read will wait until the socket receives something + * \return true if the operation was successful + */ + bool SetBlocking(bool a_Block); + + /** + * \brief Getter of the m_Status member + * \return The TcpSocket::Status of this socket + */ + Status GetStatus() const; + + /** + * \brief Disconnects the socket from the remote + * \note Does nothing if the socket is not connected. \n + * This function is also called by the destructor. + */ + void Disconnect(); + + + private: + SocketHandle m_Handle; + Status m_Status; + + void Connect(const std::string& a_Host, std::uint16_t a_Port); + + friend class TcpListener; +}; + +/** + * \typedef TcpSocket + */ +using TcpSocket = IOInterface; + +} // namespace io +} // namespace sp diff --git a/include/sp/protocol/MessageDispatcher.h b/include/sp/protocol/MessageDispatcher.h index cd6b5cd..1cdba2c 100644 --- a/include/sp/protocol/MessageDispatcher.h +++ b/include/sp/protocol/MessageDispatcher.h @@ -17,7 +17,7 @@ namespace sp { template class MessageDispatcher { private: - std::map>> m_Handlers; + std::map> m_Handlers; public: using MessageBaseType = MessageBase; @@ -38,18 +38,20 @@ class MessageDispatcher { * \param type The packet type * \param handler The packet handler */ - void RegisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler); + void RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); + /** * \brief Unregister a packet handler * \param type The packet type * \param handler The packet handler */ - void UnregisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler); + void UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); + /** * \brief Unregister a packet handler * \param handler The packet handler */ - void UnregisterHandler(const std::shared_ptr& a_Handler); + void UnregisterHandler(MessageHandler* a_Handler); }; #include diff --git a/include/sp/protocol/message/MessageDispatcherImpl.inl b/include/sp/protocol/message/MessageDispatcherImpl.inl index 3531452..7319fbf 100644 --- a/include/sp/protocol/message/MessageDispatcherImpl.inl +++ b/include/sp/protocol/message/MessageDispatcherImpl.inl @@ -1,34 +1,36 @@ #pragma once template -void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler) { +void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { + assert(a_Handler); auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); if (found == m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].push_back(a_Handler); } template -void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler) { +void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); if (found != m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].erase(found); } template -void MessageDispatcher::UnregisterHandler(const std::shared_ptr& a_Handler) { +void MessageDispatcher::UnregisterHandler(MessageHandler* a_Handler) { for (auto& pair : m_Handlers) { if (pair.second.empty()) continue; MessageIdType type = pair.first; - m_Handlers[type].erase(std::remove(m_Handlers[type].begin(), m_Handlers[type].end(), a_Handler), m_Handlers[type].end()); + pair.second.erase(std::remove(pair.second.begin(), pair.second.end(), a_Handler), pair.second.end()); } } template void MessageDispatcher::Dispatch(const MessageBase& a_Message) { MessageIdType type = a_Message.GetId(); - for (auto& handler : m_Handlers[type]) + for (auto& handler : m_Handlers[type]) { a_Message.Dispatch(*handler); + } } diff --git a/src/sp/common/DataBuffer.cpp b/src/sp/common/DataBuffer.cpp index 00d8335..03c8a4e 100644 --- a/src/sp/common/DataBuffer.cpp +++ b/src/sp/common/DataBuffer.cpp @@ -8,6 +8,8 @@ namespace sp { DataBuffer::DataBuffer() : m_ReadOffset(0) {} +DataBuffer::DataBuffer(std::size_t a_InitialSize) : m_Buffer(a_InitialSize), m_ReadOffset(0) {} + DataBuffer::DataBuffer(const DataBuffer& other) : m_Buffer(other.m_Buffer), m_ReadOffset(other.m_ReadOffset) {} DataBuffer::DataBuffer(DataBuffer&& other) : m_Buffer(std::move(other.m_Buffer)), m_ReadOffset(std::move(other.m_ReadOffset)) {} diff --git a/src/sp/extensions/Compress.cpp b/src/sp/extensions/Compress.cpp index 7d50fd5..fbb67ac 100644 --- a/src/sp/extensions/Compress.cpp +++ b/src/sp/extensions/Compress.cpp @@ -11,7 +11,7 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz DataBuffer result; result.Resize(uncompressedSize); - uncompress(static_cast(result.data()), static_cast(&uncompressedSize), static_cast(source), + uncompress(static_cast(result.data()), reinterpret_cast(&uncompressedSize), static_cast(source), static_cast(size)); assert(result.GetSize() == uncompressedSize); diff --git a/src/sp/extensions/TcpListener.cpp b/src/sp/extensions/TcpListener.cpp new file mode 100644 index 0000000..330c975 --- /dev/null +++ b/src/sp/extensions/TcpListener.cpp @@ -0,0 +1,114 @@ +#include + + +#ifdef _WIN32 + +// Windows + +#include +#include + +#define ioctl ioctlsocket +#define WOULDBLOCK WSAEWOULDBLOCK +#define MSG_DONTWAIT 0 + +#else + +// Linux/Unix + +#include +#include +#include +#include +#include +#include +#include + +#define closesocket close +#define WOULDBLOCK EWOULDBLOCK +#define SD_BOTH SHUT_RDWR + +#endif + + + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif + + +namespace sp { +namespace io { + +TcpListener::TcpListener(std::uint16_t a_Port, int a_MaxConnexions) { + if ((m_Handle = static_cast(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP))) < 0) { + throw SocketError("Failed to create server socket"); + } + + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(a_Port); + + if (bind(m_Handle, reinterpret_cast(&address), sizeof(address)) < 0) + throw SocketError("Failed to create server socket"); + + if (listen(m_Handle, a_MaxConnexions) < 0) + throw SocketError("Failed to create server socket"); + + socklen_t len = sizeof(address); + if (getsockname(m_Handle, reinterpret_cast(&address), &len) < 0) + throw SocketError("Failed to create server socket"); + + m_Port = ntohs(address.sin_port); + m_MaxConnections = a_MaxConnexions; +} + + +TcpListener::~TcpListener() { + Close(); +} + +std::unique_ptr TcpListener::Accept() { + sockaddr remoteAddress; + int addrlen = sizeof(remoteAddress); + + auto newSocket = std::make_unique(); + + newSocket->m_Handle = static_cast( + accept(m_Handle, reinterpret_cast(&remoteAddress), reinterpret_cast(&addrlen))); + + if (newSocket->m_Handle < 0) + return nullptr; + + newSocket->m_Status = TcpSocket::Status::Connected; + return newSocket; +} + +void TcpListener::Close() { + if (m_Handle > 0) { + closesocket(m_Handle); + shutdown(m_Handle, SD_BOTH); + } +} + +bool TcpListener::SetBlocking(bool a_Blocking) { + unsigned long mode = !a_Blocking; + + if (ioctl(m_Handle, FIONBIO, &mode) < 0) { + return false; + } + + return true; +} + +std::uint16_t TcpListener::GetListeningPort() const { + return m_Port; +} + +int TcpListener::GetMaximumConnections() const { + return m_MaxConnections; +} + +} // namespace io +} // namespace sp diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp new file mode 100644 index 0000000..ab08b3a --- /dev/null +++ b/src/sp/extensions/TcpSocket.cpp @@ -0,0 +1,164 @@ +#include + +#ifdef _WIN32 + +// Windows + +#include +#include + +#define ioctl ioctlsocket +#define WOULDBLOCK WSAEWOULDBLOCK +#define MSG_DONTWAIT 0 + +#else + +// Linux/Unix + +#include +#include +#include +#include +#include +#include +#include + +#define closesocket close +#define WOULDBLOCK EWOULDBLOCK + +#endif + + + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif + + +namespace sp { +namespace io { + +TcpSocket::IOInterface() : m_Handle(static_cast(INVALID_SOCKET)), m_Status(Status::Disconnected) {} + +TcpSocket::IOInterface(const std::string& a_Host, std::uint16_t a_Port) : IOInterface() { + Connect(a_Host, a_Port); +} + +TcpSocket::IOInterface(IOInterface&& a_Other) { + std::swap(m_Handle, a_Other.m_Handle); + std::swap(m_Status, a_Other.m_Status); +} + +TcpSocket::~IOInterface() {} + +void TcpSocket::Connect(const std::string& a_Host, std::uint16_t a_Port) { + struct addrinfo hints {}; + + struct addrinfo* result = nullptr; + + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + m_Status = Status::Error; + + if (getaddrinfo(a_Host.c_str(), std::to_string(static_cast(a_Port)).c_str(), &hints, &result) != 0) { + throw SocketError("Failed to get address info"); + } + + m_Handle = static_cast(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + if (m_Handle < 0) { + throw SocketError("Failed to create socket"); + } + + struct addrinfo* ptr = nullptr; + for (ptr = result; ptr != nullptr; ptr = ptr->ai_next) { + struct sockaddr* sockaddr = ptr->ai_addr; + if (connect(m_Handle, sockaddr, sizeof(sockaddr_in)) == 0) { + break; + } + } + + freeaddrinfo(result); + + if (!ptr) { + throw SocketError("Could not find a suitable interface for connecting"); + } + + m_Status = Status::Connected; +} + +DataBuffer TcpSocket::Read(std::size_t a_Amount) { + DataBuffer buffer(a_Amount); + + std::size_t totalRecieved = 0; + + while (totalRecieved < a_Amount) { + int recvAmount = + recv(m_Handle, reinterpret_cast(buffer.data() + totalRecieved), static_cast(a_Amount - totalRecieved), 0); + if (recvAmount <= 0) { +#if defined(_WIN32) || defined(WIN32) + int err = WSAGetLastError(); +#else + int err = errno; +#endif + if (err == WOULDBLOCK) { + // we are in non blocking mode and nothing is available + return {}; + } + + Disconnect(); + m_Status = Status::Error; + throw SocketError("Error while reading"); + } + totalRecieved += recvAmount; + } + return buffer; +} + +void TcpSocket::Write(const sp::DataBuffer& a_Data) { + if (GetStatus() != Status::Connected) + return; + + std::size_t sent = 0; + + while (sent < a_Data.GetSize()) { + int cur = send(m_Handle, reinterpret_cast(a_Data.data() + sent), static_cast(a_Data.GetSize() - sent), 0); + + if (cur <= 0) { + Disconnect(); + m_Status = Status::Error; + return; + } + sent += static_cast(cur); + } +} + +bool TcpSocket::SetBlocking(bool a_Block) { + unsigned long mode = !a_Block; + + if (ioctl(m_Handle, FIONBIO, &mode) < 0) { + return false; + } + + return true; +} + +TcpSocket::Status TcpSocket::GetStatus() const { + return m_Status; +} + +void TcpSocket::Disconnect() { + if (m_Handle > 0) + closesocket(m_Handle); + m_Status = Status::Disconnected; +} + +TcpSocket& TcpSocket::operator=(IOInterface&& a_Other) { + std::swap(m_Handle, a_Other.m_Handle); + std::swap(m_Status, a_Other.m_Status); + return *this; +} + +} // namespace io +} // namespace sp diff --git a/test/test_file.cpp b/test/test_file.cpp index 2a89a3c..f777338 100644 --- a/test/test_file.cpp +++ b/test/test_file.cpp @@ -23,8 +23,8 @@ int main() { auto handler = std::make_shared(); FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out}, {}); - stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); - stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler.get()); + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler.get()); stream.SendMessage(KeepAlivePacket{96}); stream.SendMessage(KeepAlivePacket{69}); diff --git a/test/test_io.cpp b/test/test_io.cpp index c8affb6..d67c521 100644 --- a/test/test_io.cpp +++ b/test/test_io.cpp @@ -23,13 +23,13 @@ int main() { auto handler = std::make_shared(); DataBufferStream stream; - stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler.get()); // this should not be dispatched stream.SendMessage(KeepAlivePacket{96}); stream.RecieveMessages(); - stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler.get()); stream.SendMessage(KeepAlivePacket{69}); stream.RecieveMessages(); diff --git a/test/test_packets.cpp b/test/test_packets.cpp index e89e171..b17f582 100644 --- a/test/test_packets.cpp +++ b/test/test_packets.cpp @@ -48,10 +48,10 @@ int main() { packet->Dispatch(*handler); sp::PacketDispatcher dispatcher; - dispatcher.RegisterHandler(PacketId::KeepAlive, handler); + dispatcher.RegisterHandler(PacketId::KeepAlive, handler.get()); dispatcher.Dispatch(*packet); - dispatcher.UnregisterHandler(PacketId::KeepAlive, handler); - dispatcher.UnregisterHandler(handler); + dispatcher.UnregisterHandler(PacketId::KeepAlive, handler.get()); + dispatcher.UnregisterHandler(handler.get()); return 0; } \ No newline at end of file diff --git a/xmake.lua b/xmake.lua index 2012f06..1785976 100644 --- a/xmake.lua +++ b/xmake.lua @@ -6,12 +6,21 @@ local modules = { Compression = { Option = "zlib", Deps = {"zlib"}, - Packages = {"zlib"}, Includes = {"include/(sp/extensions/Compress.h)"}, Sources = {"src/sp/extensions/Compress.cpp"} + }, + TcpSocket = { + Option = "tcp", + Deps = {}, + Includes = {"include/(sp/extensions/Tcp.h)", "include/(sp/extensions/tcp/*.h)"}, + Sources = {"src/sp/extensions/Tcp*.cpp"} } } + + + + -- Map modules to options for name, module in table.orderpairs(modules) do if module.Option then @@ -19,6 +28,10 @@ for name, module in table.orderpairs(modules) do end end + + + + -- Add modules requirements for name, module in table.orderpairs(modules) do if module.Deps then @@ -26,6 +39,10 @@ for name, module in table.orderpairs(modules) do end end + + + + -- Add modules targets for name, module in table.orderpairs(modules) do if module.Deps and has_config(module.Option) then @@ -37,7 +54,7 @@ for name, module in table.orderpairs(modules) do for _, source in table.orderpairs(module.Sources) do add_files(source) end - for _, package in table.orderpairs(module.Packages) do + for _, package in table.orderpairs(module.Deps) do add_packages(package) end set_group("Library") @@ -45,14 +62,17 @@ for name, module in table.orderpairs(modules) do end end + + + + target("SimpleProtocol") add_includedirs("include") add_files("src/sp/**.cpp") - - local includeFolders = {"common", "default", "io", "protocol"} - for _, folder in ipairs(includeFolders) do - add_headerfiles("include/(sp/" .. folder .. "/**.h)") - end + set_group("Library") + set_kind("$(kind)") + + add_headerfiles("include/(sp/**.h)", "include/(sp/**.inl)") -- adding extensions for name, module in table.orderpairs(modules) do @@ -63,8 +83,14 @@ target("SimpleProtocol") -- we don't want extensions remove_files("src/sp/extensions/**.cpp") - set_group("Library") - set_kind("$(kind)") + remove_headerfiles("include/(sp/extension/**.h)") + + -- we need this for endian functions + if is_os("windows") then + add_links("ws2_32") + end + + -- Tests for _, file in ipairs(os.files("test/**.cpp")) do