From efcfae69dba48396ca6d3145f814b236be673daf Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 11:30:54 +0100 Subject: [PATCH 01/12] fix windows build --- src/sp/extensions/Compress.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sp/extensions/Compress.cpp b/src/sp/extensions/Compress.cpp index 7d50fd5..fbb67ac 100644 --- a/src/sp/extensions/Compress.cpp +++ b/src/sp/extensions/Compress.cpp @@ -11,7 +11,7 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz DataBuffer result; result.Resize(uncompressedSize); - uncompress(static_cast(result.data()), static_cast(&uncompressedSize), static_cast(source), + uncompress(static_cast(result.data()), reinterpret_cast(&uncompressedSize), static_cast(source), static_cast(size)); assert(result.GetSize() == uncompressedSize); -- 2.49.1 From 4b2e4ca132461a6bda2d63c5b4348f70458295fb Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 11:31:07 +0100 Subject: [PATCH 02/12] add tcp support --- include/sp/common/DataBuffer.h | 1 + include/sp/common/NonCopyable.h | 25 ++++ include/sp/extensions/Extensions.h | 4 + include/sp/extensions/Tcp.h | 4 + include/sp/extensions/tcp/TcpListener.h | 65 ++++++++++ include/sp/extensions/tcp/TcpSocket.h | 85 +++++++++++++ src/sp/common/DataBuffer.cpp | 2 + src/sp/extensions/TcpListener.cpp | 114 +++++++++++++++++ src/sp/extensions/TcpSocket.cpp | 157 ++++++++++++++++++++++++ xmake.lua | 36 +++++- 10 files changed, 489 insertions(+), 4 deletions(-) create mode 100644 include/sp/common/NonCopyable.h create mode 100644 include/sp/extensions/Tcp.h create mode 100644 include/sp/extensions/tcp/TcpListener.h create mode 100644 include/sp/extensions/tcp/TcpSocket.h create mode 100644 src/sp/extensions/TcpListener.cpp create mode 100644 src/sp/extensions/TcpSocket.cpp diff --git a/include/sp/common/DataBuffer.h b/include/sp/common/DataBuffer.h index d9cf312..33e5697 100644 --- a/include/sp/common/DataBuffer.h +++ b/include/sp/common/DataBuffer.h @@ -35,6 +35,7 @@ class DataBuffer { typedef Data::difference_type difference_type; DataBuffer(); + DataBuffer(std::size_t a_InitialSize); DataBuffer(const DataBuffer& other); DataBuffer(const DataBuffer& other, difference_type offset); DataBuffer(DataBuffer&& other); diff --git a/include/sp/common/NonCopyable.h b/include/sp/common/NonCopyable.h new file mode 100644 index 0000000..85b155e --- /dev/null +++ b/include/sp/common/NonCopyable.h @@ -0,0 +1,25 @@ +#pragma once + +/** + * \file NonCopyable.h + * \brief File containing the sp::NonCopyable class + */ + +namespace sp { + +/** + * \class NonCopyable + * \brief Class used to make a class non copyable + * \note Inherit from this class privately to make a class non copyable + */ +class NonCopyable { + public: + NonCopyable(const NonCopyable&) = delete; + NonCopyable& operator=(const NonCopyable&) = delete; + + protected: + NonCopyable() {} + ~NonCopyable() {} +}; + +} // namespace sp diff --git a/include/sp/extensions/Extensions.h b/include/sp/extensions/Extensions.h index 68c30f4..8fd8c6a 100644 --- a/include/sp/extensions/Extensions.h +++ b/include/sp/extensions/Extensions.h @@ -2,4 +2,8 @@ #if __has_include() #include +#endif + +#if __has_include() + #include #endif \ No newline at end of file diff --git a/include/sp/extensions/Tcp.h b/include/sp/extensions/Tcp.h new file mode 100644 index 0000000..5315211 --- /dev/null +++ b/include/sp/extensions/Tcp.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include \ No newline at end of file diff --git a/include/sp/extensions/tcp/TcpListener.h b/include/sp/extensions/tcp/TcpListener.h new file mode 100644 index 0000000..98ed8bb --- /dev/null +++ b/include/sp/extensions/tcp/TcpListener.h @@ -0,0 +1,65 @@ +#pragma once + +#include + +namespace sp { +namespace io { + +/** + * \class TcpListener + */ +class TcpListener : private NonCopyable { + public: + /** + * \brief Starts listening for guests to connect + * \param port The port to listen to + * \param maxConnexions The maximum amount of connexion that can happen at the same time. \n + * Every other guests will be kicked if this amount is reached. + * \return Whether this action was succesfull + */ + TcpListener(std::uint16_t a_Port, int a_MaxConnexions); + + /** + * \brief Default destructor + */ + ~TcpListener(); + + /** + * \brief Tries to accept an incoming request to connect + * \return the new socket if a new connexion was accepted or nullptr + */ + std::unique_ptr Accept(); + + /** + * \brief Closes the socket + */ + void Close(); + + /** + * \brief Allows to set the socket in non blocking/blocking mode + * \param a_Blocking If set to true, every call to Read will wait until the socket receives something + * \return true if the operation was successful + */ + bool SetBlocking(bool a_Blocking); + + /** + * \brief Getter of the m_Port member + * \return The port which the socket listen to + */ + std::uint16_t GetListeningPort() const; + + /** + * \brief Getter of the m_MaxConnections member + * \return The maximum amount of connexions that can happen at the same time. + */ + int GetMaximumConnections() const; + + private: + SocketHandle m_Handle; + std::uint16_t m_Port; + int m_MaxConnections; +}; + + +} // namespace io +} // namespace sp diff --git a/include/sp/extensions/tcp/TcpSocket.h b/include/sp/extensions/tcp/TcpSocket.h new file mode 100644 index 0000000..b84b724 --- /dev/null +++ b/include/sp/extensions/tcp/TcpSocket.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include + +namespace sp { +namespace io { + +using SocketHandle = int; + +struct TcpTag {}; + +class SocketError : public std::exception { + private: + std::string m_Error; + + public: + SocketError(std::string&& a_Msg) : m_Error(std::move(a_Msg)) {} + + virtual const char* what() const noexcept override { + return m_Error.c_str(); + } +}; + +template <> +class IOInterface : private NonCopyable { + public: + /** + * \enum Status + * \brief Describes the state of a socket + */ + enum class Status { + /** The socket is connected */ + Connected, + /** The socket is not connected */ + Disconnected, + /** Something bad happened */ + Error, + }; + + IOInterface(); + IOInterface(const std::string& a_Host, std::uint16_t a_Port); + IOInterface(IOInterface&& a_Other); + virtual ~IOInterface(); + + DataBuffer Read(std::size_t a_Amount); + void Write(const sp::DataBuffer& a_Data); + + /** + * \brief Allows to set the socket in non blocking/blocking mode + * \param a_Block If set to true, every call to Read will wait until the socket receives something + * \return true if the operation was successful + */ + bool SetBlocking(bool a_Block); + + /** + * \brief Getter of the m_Status member + * \return The TcpSocket::Status of this socket + */ + Status GetStatus() const; + + /** + * \brief Disconnects the socket from the remote + * \note Does nothing if the socket is not connected. \n + * This function is also called by the destructor. + */ + void Disconnect(); + + + private: + SocketHandle m_Handle; + Status m_Status; + + void Connect(const std::string& a_Host, std::uint16_t a_Port); + + friend class TcpListener; +}; + +/** + * \typedef TcpSocket + */ +using TcpSocket = IOInterface; + +} // namespace io +} // namespace sp diff --git a/src/sp/common/DataBuffer.cpp b/src/sp/common/DataBuffer.cpp index 00d8335..03c8a4e 100644 --- a/src/sp/common/DataBuffer.cpp +++ b/src/sp/common/DataBuffer.cpp @@ -8,6 +8,8 @@ namespace sp { DataBuffer::DataBuffer() : m_ReadOffset(0) {} +DataBuffer::DataBuffer(std::size_t a_InitialSize) : m_Buffer(a_InitialSize), m_ReadOffset(0) {} + DataBuffer::DataBuffer(const DataBuffer& other) : m_Buffer(other.m_Buffer), m_ReadOffset(other.m_ReadOffset) {} DataBuffer::DataBuffer(DataBuffer&& other) : m_Buffer(std::move(other.m_Buffer)), m_ReadOffset(std::move(other.m_ReadOffset)) {} diff --git a/src/sp/extensions/TcpListener.cpp b/src/sp/extensions/TcpListener.cpp new file mode 100644 index 0000000..330c975 --- /dev/null +++ b/src/sp/extensions/TcpListener.cpp @@ -0,0 +1,114 @@ +#include + + +#ifdef _WIN32 + +// Windows + +#include +#include + +#define ioctl ioctlsocket +#define WOULDBLOCK WSAEWOULDBLOCK +#define MSG_DONTWAIT 0 + +#else + +// Linux/Unix + +#include +#include +#include +#include +#include +#include +#include + +#define closesocket close +#define WOULDBLOCK EWOULDBLOCK +#define SD_BOTH SHUT_RDWR + +#endif + + + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif + + +namespace sp { +namespace io { + +TcpListener::TcpListener(std::uint16_t a_Port, int a_MaxConnexions) { + if ((m_Handle = static_cast(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP))) < 0) { + throw SocketError("Failed to create server socket"); + } + + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(a_Port); + + if (bind(m_Handle, reinterpret_cast(&address), sizeof(address)) < 0) + throw SocketError("Failed to create server socket"); + + if (listen(m_Handle, a_MaxConnexions) < 0) + throw SocketError("Failed to create server socket"); + + socklen_t len = sizeof(address); + if (getsockname(m_Handle, reinterpret_cast(&address), &len) < 0) + throw SocketError("Failed to create server socket"); + + m_Port = ntohs(address.sin_port); + m_MaxConnections = a_MaxConnexions; +} + + +TcpListener::~TcpListener() { + Close(); +} + +std::unique_ptr TcpListener::Accept() { + sockaddr remoteAddress; + int addrlen = sizeof(remoteAddress); + + auto newSocket = std::make_unique(); + + newSocket->m_Handle = static_cast( + accept(m_Handle, reinterpret_cast(&remoteAddress), reinterpret_cast(&addrlen))); + + if (newSocket->m_Handle < 0) + return nullptr; + + newSocket->m_Status = TcpSocket::Status::Connected; + return newSocket; +} + +void TcpListener::Close() { + if (m_Handle > 0) { + closesocket(m_Handle); + shutdown(m_Handle, SD_BOTH); + } +} + +bool TcpListener::SetBlocking(bool a_Blocking) { + unsigned long mode = !a_Blocking; + + if (ioctl(m_Handle, FIONBIO, &mode) < 0) { + return false; + } + + return true; +} + +std::uint16_t TcpListener::GetListeningPort() const { + return m_Port; +} + +int TcpListener::GetMaximumConnections() const { + return m_MaxConnections; +} + +} // namespace io +} // namespace sp diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp new file mode 100644 index 0000000..c2da877 --- /dev/null +++ b/src/sp/extensions/TcpSocket.cpp @@ -0,0 +1,157 @@ +#include + +#ifdef _WIN32 + +// Windows + +#include +#include + +#define ioctl ioctlsocket +#define WOULDBLOCK WSAEWOULDBLOCK +#define MSG_DONTWAIT 0 + +#else + +// Linux/Unix + +#include +#include +#include +#include +#include +#include +#include + +#define closesocket close +#define WOULDBLOCK EWOULDBLOCK + +#endif + + + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif + + +namespace sp { +namespace io { + +TcpSocket::IOInterface() : m_Handle(static_cast(INVALID_SOCKET)), m_Status(Status::Disconnected) {} + +TcpSocket::IOInterface(const std::string& a_Host, std::uint16_t a_Port) : IOInterface() { + Connect(a_Host, a_Port); +} + +TcpSocket::IOInterface(IOInterface&& a_Other) {} + +TcpSocket::~IOInterface() {} + +void TcpSocket::Connect(const std::string& a_Host, std::uint16_t a_Port) { + struct addrinfo hints {}; + + struct addrinfo* result = nullptr; + + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + m_Status = Status::Error; + + if (getaddrinfo(a_Host.c_str(), std::to_string(static_cast(a_Port)).c_str(), &hints, &result) != 0) { + throw SocketError("Failed to get address info"); + } + + m_Handle = static_cast(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + if (m_Handle < 0) { + throw SocketError("Failed to create socket"); + } + + struct addrinfo* ptr = nullptr; + for (ptr = result; ptr != nullptr; ptr = ptr->ai_next) { + struct sockaddr* sockaddr = ptr->ai_addr; + if (connect(m_Handle, sockaddr, sizeof(sockaddr_in)) == 0) { + break; + } + } + + freeaddrinfo(result); + + if (!ptr) { + throw SocketError("Could not find a suitable interface for connecting"); + } + + m_Status = Status::Connected; +} + +DataBuffer TcpSocket::Read(std::size_t a_Amount) { + DataBuffer buffer(a_Amount); + + std::size_t totalRecieved = 0; + + while (totalRecieved < a_Amount) { + int recvAmount = recv(m_Handle, reinterpret_cast(buffer.data() + totalRecieved), static_cast(a_Amount - totalRecieved), 0); + if (recvAmount <= 0) { +#if defined(_WIN32) || defined(WIN32) + int err = WSAGetLastError(); +#else + int err = errno; +#endif + if (err == WOULDBLOCK) { + m_Status = Status::Error; + throw SocketError("Error while reading"); + } + + Disconnect(); + m_Status = Status::Error; + throw SocketError("Error while reading"); + } + totalRecieved += recvAmount; + } + return buffer; +} + +void TcpSocket::Write(const sp::DataBuffer& a_Data) { + if (GetStatus() != Status::Connected) + return; + + std::size_t sent = 0; + + while (sent < a_Data.GetSize()) { + int cur = send(m_Handle, reinterpret_cast(a_Data.GetSize() + sent), static_cast(a_Data.GetSize() - sent), 0); + + if (cur <= 0) { + Disconnect(); + m_Status = Status::Error; + return; + } + sent += static_cast(cur); + } +} + +bool TcpSocket::SetBlocking(bool a_Block) { + unsigned long mode = !a_Block; + + if (ioctl(m_Handle, FIONBIO, &mode) < 0) { + return false; + } + + return true; +} + +TcpSocket::Status TcpSocket::GetStatus() const { + return m_Status; +} + +void TcpSocket::Disconnect() { + if (m_Handle > 0) + closesocket(m_Handle); + m_Status = Status::Disconnected; +} + + + +} // namespace io + +} // namespace sp diff --git a/xmake.lua b/xmake.lua index 2012f06..f7d3b3e 100644 --- a/xmake.lua +++ b/xmake.lua @@ -6,12 +6,21 @@ local modules = { Compression = { Option = "zlib", Deps = {"zlib"}, - Packages = {"zlib"}, Includes = {"include/(sp/extensions/Compress.h)"}, Sources = {"src/sp/extensions/Compress.cpp"} + }, + TcpSocket = { + Option = "tcp", + Deps = {}, + Includes = {"include/(sp/extensions/Tcp.h)"}, + Sources = {"src/sp/extensions/Tcp*.cpp"} } } + + + + -- Map modules to options for name, module in table.orderpairs(modules) do if module.Option then @@ -19,6 +28,10 @@ for name, module in table.orderpairs(modules) do end end + + + + -- Add modules requirements for name, module in table.orderpairs(modules) do if module.Deps then @@ -26,6 +39,10 @@ for name, module in table.orderpairs(modules) do end end + + + + -- Add modules targets for name, module in table.orderpairs(modules) do if module.Deps and has_config(module.Option) then @@ -37,7 +54,7 @@ for name, module in table.orderpairs(modules) do for _, source in table.orderpairs(module.Sources) do add_files(source) end - for _, package in table.orderpairs(module.Packages) do + for _, package in table.orderpairs(module.Deps) do add_packages(package) end set_group("Library") @@ -45,9 +62,15 @@ for name, module in table.orderpairs(modules) do end end + + + + target("SimpleProtocol") add_includedirs("include") add_files("src/sp/**.cpp") + set_group("Library") + set_kind("$(kind)") local includeFolders = {"common", "default", "io", "protocol"} for _, folder in ipairs(includeFolders) do @@ -63,8 +86,13 @@ target("SimpleProtocol") -- we don't want extensions remove_files("src/sp/extensions/**.cpp") - set_group("Library") - set_kind("$(kind)") + + -- we need this for endian functions + if is_os("windows") then + add_links("ws2_32") + end + + -- Tests for _, file in ipairs(os.files("test/**.cpp")) do -- 2.49.1 From 04ca498b0cad7f7f67f2330ab759ed71107f573f Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 17:36:17 +0100 Subject: [PATCH 03/12] add missing headers --- xmake.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xmake.lua b/xmake.lua index f7d3b3e..1241360 100644 --- a/xmake.lua +++ b/xmake.lua @@ -12,7 +12,7 @@ local modules = { TcpSocket = { Option = "tcp", Deps = {}, - Includes = {"include/(sp/extensions/Tcp.h)"}, + Includes = {"include/(sp/extensions/Tcp.h)", "include/(sp/extensions/tcp/*.h)"}, Sources = {"src/sp/extensions/Tcp*.cpp"} } } -- 2.49.1 From 32d30c7f446bfc50e1bfff2ee93669512a27c435 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 17:52:37 +0100 Subject: [PATCH 04/12] fix includes --- xmake.lua | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/xmake.lua b/xmake.lua index 1241360..1785976 100644 --- a/xmake.lua +++ b/xmake.lua @@ -71,11 +71,8 @@ target("SimpleProtocol") add_files("src/sp/**.cpp") set_group("Library") set_kind("$(kind)") - - local includeFolders = {"common", "default", "io", "protocol"} - for _, folder in ipairs(includeFolders) do - add_headerfiles("include/(sp/" .. folder .. "/**.h)") - end + + add_headerfiles("include/(sp/**.h)", "include/(sp/**.inl)") -- adding extensions for name, module in table.orderpairs(modules) do @@ -86,6 +83,7 @@ target("SimpleProtocol") -- we don't want extensions remove_files("src/sp/extensions/**.cpp") + remove_headerfiles("include/(sp/extension/**.h)") -- we need this for endian functions if is_os("windows") then -- 2.49.1 From 643da71e34d46ca149e114097f89976e44cbae4e Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 20:33:59 +0100 Subject: [PATCH 05/12] add socket move --- include/sp/extensions/tcp/TcpSocket.h | 1 + src/sp/extensions/TcpSocket.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/include/sp/extensions/tcp/TcpSocket.h b/include/sp/extensions/tcp/TcpSocket.h index b84b724..1b14f3f 100644 --- a/include/sp/extensions/tcp/TcpSocket.h +++ b/include/sp/extensions/tcp/TcpSocket.h @@ -41,6 +41,7 @@ class IOInterface : private NonCopyable { IOInterface(); IOInterface(const std::string& a_Host, std::uint16_t a_Port); IOInterface(IOInterface&& a_Other); + IOInterface& operator=(IOInterface&& a_Other); virtual ~IOInterface(); DataBuffer Read(std::size_t a_Amount); diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp index c2da877..48c3662 100644 --- a/src/sp/extensions/TcpSocket.cpp +++ b/src/sp/extensions/TcpSocket.cpp @@ -91,7 +91,8 @@ DataBuffer TcpSocket::Read(std::size_t a_Amount) { std::size_t totalRecieved = 0; while (totalRecieved < a_Amount) { - int recvAmount = recv(m_Handle, reinterpret_cast(buffer.data() + totalRecieved), static_cast(a_Amount - totalRecieved), 0); + int recvAmount = + recv(m_Handle, reinterpret_cast(buffer.data() + totalRecieved), static_cast(a_Amount - totalRecieved), 0); if (recvAmount <= 0) { #if defined(_WIN32) || defined(WIN32) int err = WSAGetLastError(); @@ -150,8 +151,10 @@ void TcpSocket::Disconnect() { m_Status = Status::Disconnected; } - +TcpSocket& TcpSocket::operator=(IOInterface&& a_Other) { + std::swap(m_Handle, a_Other.m_Handle); + std::swap(m_Status, a_Other.m_Status); +} } // namespace io - } // namespace sp -- 2.49.1 From 8d3d9e38eeef41662cfc2d9f7797d01200c730e3 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 20:37:54 +0100 Subject: [PATCH 06/12] dispatcher: use weak_ptr --- include/sp/protocol/MessageDispatcher.h | 10 ++++--- .../message/MessageDispatcherImpl.inl | 28 +++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/include/sp/protocol/MessageDispatcher.h b/include/sp/protocol/MessageDispatcher.h index cd6b5cd..2f5e9a0 100644 --- a/include/sp/protocol/MessageDispatcher.h +++ b/include/sp/protocol/MessageDispatcher.h @@ -17,7 +17,7 @@ namespace sp { template class MessageDispatcher { private: - std::map>> m_Handlers; + std::map>> m_Handlers; public: using MessageBaseType = MessageBase; @@ -38,18 +38,20 @@ class MessageDispatcher { * \param type The packet type * \param handler The packet handler */ - void RegisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler); + void RegisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler); + /** * \brief Unregister a packet handler * \param type The packet type * \param handler The packet handler */ - void UnregisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler); + void UnregisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler); + /** * \brief Unregister a packet handler * \param handler The packet handler */ - void UnregisterHandler(const std::shared_ptr& a_Handler); + void UnregisterHandler(const std::weak_ptr& a_Handler); }; #include diff --git a/include/sp/protocol/message/MessageDispatcherImpl.inl b/include/sp/protocol/message/MessageDispatcherImpl.inl index 3531452..568c4b7 100644 --- a/include/sp/protocol/message/MessageDispatcherImpl.inl +++ b/include/sp/protocol/message/MessageDispatcherImpl.inl @@ -1,34 +1,46 @@ #pragma once template -void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler) { - auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); +void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler) { + auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const std::weak_ptr& handler){ + return a_Handler.lock() == handler.lock(); + }); if (found == m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].push_back(a_Handler); } template -void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, const std::shared_ptr& a_Handler) { - auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); +void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler) { + auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const std::weak_ptr& handler){ + return a_Handler.lock() == handler.lock(); + }); if (found != m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].erase(found); } template -void MessageDispatcher::UnregisterHandler(const std::shared_ptr& a_Handler) { +void MessageDispatcher::UnregisterHandler(const std::weak_ptr& a_Handler) { for (auto& pair : m_Handlers) { if (pair.second.empty()) continue; MessageIdType type = pair.first; - m_Handlers[type].erase(std::remove(m_Handlers[type].begin(), m_Handlers[type].end(), a_Handler), m_Handlers[type].end()); + auto it = std::find_if(pair.second.begin(), pair.second.end(), [&a_Handler](const std::weak_ptr& handler){ + return handler.lock() == a_Handler.lock(); + }); + + if (it != pair.second.end()) + pair.second.erase(it); + } } template void MessageDispatcher::Dispatch(const MessageBase& a_Message) { MessageIdType type = a_Message.GetId(); - for (auto& handler : m_Handlers[type]) - a_Message.Dispatch(*handler); + for (auto& handler : m_Handlers[type]) { + if (!handler.expired()) + a_Message.Dispatch(*handler.lock()); + } } -- 2.49.1 From afc41894506734cc248e795ccfec29e77d28a701 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 20:38:44 +0100 Subject: [PATCH 07/12] fix tcpsocket error --- src/sp/extensions/TcpSocket.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp index 48c3662..d7e6e85 100644 --- a/src/sp/extensions/TcpSocket.cpp +++ b/src/sp/extensions/TcpSocket.cpp @@ -100,8 +100,8 @@ DataBuffer TcpSocket::Read(std::size_t a_Amount) { int err = errno; #endif if (err == WOULDBLOCK) { - m_Status = Status::Error; - throw SocketError("Error while reading"); + // we are in non blocking mode and nothing is available + return {}; } Disconnect(); -- 2.49.1 From b687ac65f1be1b596261315314cd2d14ac39c912 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 21:04:14 +0100 Subject: [PATCH 08/12] fix tcpsocket move --- src/sp/extensions/TcpSocket.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp index d7e6e85..80a2f58 100644 --- a/src/sp/extensions/TcpSocket.cpp +++ b/src/sp/extensions/TcpSocket.cpp @@ -44,7 +44,10 @@ TcpSocket::IOInterface(const std::string& a_Host, std::uint16_t a_Port) : IOInte Connect(a_Host, a_Port); } -TcpSocket::IOInterface(IOInterface&& a_Other) {} +TcpSocket::IOInterface(IOInterface&& a_Other) { + std::swap(m_Handle, a_Other.m_Handle); + std::swap(m_Status, a_Other.m_Status); +} TcpSocket::~IOInterface() {} @@ -154,6 +157,7 @@ void TcpSocket::Disconnect() { TcpSocket& TcpSocket::operator=(IOInterface&& a_Other) { std::swap(m_Handle, a_Other.m_Handle); std::swap(m_Status, a_Other.m_Status); + return *this; } } // namespace io -- 2.49.1 From 025a9c14692dc5375a07465d865db1283290e37e Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 21:13:07 +0100 Subject: [PATCH 09/12] fix socket write --- src/sp/extensions/TcpSocket.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp index 80a2f58..ab08b3a 100644 --- a/src/sp/extensions/TcpSocket.cpp +++ b/src/sp/extensions/TcpSocket.cpp @@ -123,7 +123,7 @@ void TcpSocket::Write(const sp::DataBuffer& a_Data) { std::size_t sent = 0; while (sent < a_Data.GetSize()) { - int cur = send(m_Handle, reinterpret_cast(a_Data.GetSize() + sent), static_cast(a_Data.GetSize() - sent), 0); + int cur = send(m_Handle, reinterpret_cast(a_Data.data() + sent), static_cast(a_Data.GetSize() - sent), 0); if (cur <= 0) { Disconnect(); -- 2.49.1 From 6466e4bc452c3e2400f9bf52c19610ca018e3539 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sun, 2 Mar 2025 22:29:31 +0100 Subject: [PATCH 10/12] raw ptr test --- include/sp/protocol/MessageDispatcher.h | 8 +++---- .../message/MessageDispatcherImpl.inl | 22 +++++++++---------- test/test_file.cpp | 4 ++-- test/test_io.cpp | 4 ++-- test/test_packets.cpp | 6 ++--- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/include/sp/protocol/MessageDispatcher.h b/include/sp/protocol/MessageDispatcher.h index 2f5e9a0..1cdba2c 100644 --- a/include/sp/protocol/MessageDispatcher.h +++ b/include/sp/protocol/MessageDispatcher.h @@ -17,7 +17,7 @@ namespace sp { template class MessageDispatcher { private: - std::map>> m_Handlers; + std::map> m_Handlers; public: using MessageBaseType = MessageBase; @@ -38,20 +38,20 @@ class MessageDispatcher { * \param type The packet type * \param handler The packet handler */ - void RegisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler); + void RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); /** * \brief Unregister a packet handler * \param type The packet type * \param handler The packet handler */ - void UnregisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler); + void UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); /** * \brief Unregister a packet handler * \param handler The packet handler */ - void UnregisterHandler(const std::weak_ptr& a_Handler); + void UnregisterHandler(MessageHandler* a_Handler); }; #include diff --git a/include/sp/protocol/message/MessageDispatcherImpl.inl b/include/sp/protocol/message/MessageDispatcherImpl.inl index 568c4b7..734fab6 100644 --- a/include/sp/protocol/message/MessageDispatcherImpl.inl +++ b/include/sp/protocol/message/MessageDispatcherImpl.inl @@ -1,33 +1,33 @@ #pragma once template -void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler) { - auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const std::weak_ptr& handler){ - return a_Handler.lock() == handler.lock(); +void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { + auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const MessageHandler* handler){ + return a_Handler == handler; }); if (found == m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].push_back(a_Handler); } template -void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler) { - auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const std::weak_ptr& handler){ - return a_Handler.lock() == handler.lock(); +void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { + auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const MessageHandler* handler){ + return a_Handler == handler; }); if (found != m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].erase(found); } template -void MessageDispatcher::UnregisterHandler(const std::weak_ptr& a_Handler) { +void MessageDispatcher::UnregisterHandler(MessageHandler* a_Handler) { for (auto& pair : m_Handlers) { if (pair.second.empty()) continue; MessageIdType type = pair.first; - auto it = std::find_if(pair.second.begin(), pair.second.end(), [&a_Handler](const std::weak_ptr& handler){ - return handler.lock() == a_Handler.lock(); + auto it = std::find_if(pair.second.begin(), pair.second.end(), [&a_Handler](MessageHandler* handler){ + return handler == a_Handler; }); if (it != pair.second.end()) @@ -40,7 +40,7 @@ template void MessageDispatcher::Dispatch(const MessageBase& a_Message) { MessageIdType type = a_Message.GetId(); for (auto& handler : m_Handlers[type]) { - if (!handler.expired()) - a_Message.Dispatch(*handler.lock()); + if (handler) + a_Message.Dispatch(*handler); } } diff --git a/test/test_file.cpp b/test/test_file.cpp index 2a89a3c..f777338 100644 --- a/test/test_file.cpp +++ b/test/test_file.cpp @@ -23,8 +23,8 @@ int main() { auto handler = std::make_shared(); FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out}, {}); - stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); - stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler.get()); + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler.get()); stream.SendMessage(KeepAlivePacket{96}); stream.SendMessage(KeepAlivePacket{69}); diff --git a/test/test_io.cpp b/test/test_io.cpp index c8affb6..d67c521 100644 --- a/test/test_io.cpp +++ b/test/test_io.cpp @@ -23,13 +23,13 @@ int main() { auto handler = std::make_shared(); DataBufferStream stream; - stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler.get()); // this should not be dispatched stream.SendMessage(KeepAlivePacket{96}); stream.RecieveMessages(); - stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler.get()); stream.SendMessage(KeepAlivePacket{69}); stream.RecieveMessages(); diff --git a/test/test_packets.cpp b/test/test_packets.cpp index e89e171..b17f582 100644 --- a/test/test_packets.cpp +++ b/test/test_packets.cpp @@ -48,10 +48,10 @@ int main() { packet->Dispatch(*handler); sp::PacketDispatcher dispatcher; - dispatcher.RegisterHandler(PacketId::KeepAlive, handler); + dispatcher.RegisterHandler(PacketId::KeepAlive, handler.get()); dispatcher.Dispatch(*packet); - dispatcher.UnregisterHandler(PacketId::KeepAlive, handler); - dispatcher.UnregisterHandler(handler); + dispatcher.UnregisterHandler(PacketId::KeepAlive, handler.get()); + dispatcher.UnregisterHandler(handler.get()); return 0; } \ No newline at end of file -- 2.49.1 From c46c5bed6f770f36dd3f1f4a793bc7b3e068ec6b Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Mon, 3 Mar 2025 11:00:36 +0100 Subject: [PATCH 11/12] simplify dispatcher --- .../protocol/message/MessageDispatcherImpl.inl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/include/sp/protocol/message/MessageDispatcherImpl.inl b/include/sp/protocol/message/MessageDispatcherImpl.inl index 734fab6..23e9a87 100644 --- a/include/sp/protocol/message/MessageDispatcherImpl.inl +++ b/include/sp/protocol/message/MessageDispatcherImpl.inl @@ -2,18 +2,14 @@ template void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { - auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const MessageHandler* handler){ - return a_Handler == handler; - }); + auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); if (found == m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].push_back(a_Handler); } template void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { - auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const MessageHandler* handler){ - return a_Handler == handler; - }); + auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); if (found != m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].erase(found); } @@ -26,13 +22,7 @@ void MessageDispatcher::UnregisterHa MessageIdType type = pair.first; - auto it = std::find_if(pair.second.begin(), pair.second.end(), [&a_Handler](MessageHandler* handler){ - return handler == a_Handler; - }); - - if (it != pair.second.end()) - pair.second.erase(it); - + pair.second.erase(std::remove(pair.second.begin(), pair.second.end(), a_Handler), pair.second.end()); } } -- 2.49.1 From 7a4b2aeb4a5af36259b8c8f23aa95850cef679fd Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Mon, 3 Mar 2025 11:02:56 +0100 Subject: [PATCH 12/12] dispatcher nullcheck --- include/sp/protocol/message/MessageDispatcherImpl.inl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/sp/protocol/message/MessageDispatcherImpl.inl b/include/sp/protocol/message/MessageDispatcherImpl.inl index 23e9a87..7319fbf 100644 --- a/include/sp/protocol/message/MessageDispatcherImpl.inl +++ b/include/sp/protocol/message/MessageDispatcherImpl.inl @@ -2,6 +2,7 @@ template void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { + assert(a_Handler); auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); if (found == m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].push_back(a_Handler); @@ -30,7 +31,6 @@ template void MessageDispatcher::Dispatch(const MessageBase& a_Message) { MessageIdType type = a_Message.GetId(); for (auto& handler : m_Handlers[type]) { - if (handler) - a_Message.Dispatch(*handler); + a_Message.Dispatch(*handler); } } -- 2.49.1