From ed0b06f78dc5696bf52565b02df7b72b73abf334 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Fri, 27 Jun 2025 18:53:03 +0200 Subject: [PATCH] finish io --- include/sp/extensions/Compress.h | 35 ++--- include/sp/extensions/tcp/TcpListener.h | 3 +- include/sp/extensions/tcp/TcpSocket.h | 30 ++--- include/sp/io/IoInterface.h | 3 +- include/sp/io/MessageEncapsulator.h | 19 ++- include/sp/io/MessageStream.h | 23 +++- include/sp/io/MessageStream.inl | 8 +- src/sp/extensions/Compress.cpp | 85 ++++++++++++ src/sp/extensions/TcpListener.cpp | 114 ++++++++++++++++ src/sp/extensions/TcpSocket.cpp | 164 ++++++++++++++++++++++++ test/test_message.cpp | 13 +- 11 files changed, 443 insertions(+), 54 deletions(-) diff --git a/include/sp/extensions/Compress.h b/include/sp/extensions/Compress.h index b3b06e3..e6b6ce6 100644 --- a/include/sp/extensions/Compress.h +++ b/include/sp/extensions/Compress.h @@ -6,20 +6,7 @@ */ #include -#include - -namespace sp { -namespace option { - -struct ZlibCompress { - bool m_Enabled = true; - std::size_t m_CompressionThreshold = 64; -}; - -} // namespace option -} // namespace sp - -#include +#include namespace sp { namespace zlib { @@ -41,14 +28,20 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength); } // namespace zlib -namespace io { -template <> -class MessageEncapsulator { +class ZlibCompress : public MessageEncapsulator { + private: + std::size_t m_CompressionThreshold; + public: - static DataBuffer Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option); - static DataBuffer Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option); + ZlibCompress() : m_CompressionThreshold(64) {} + ZlibCompress(const ZlibCompress&) = default; + virtual ~ZlibCompress() {} + + protected: + virtual DataBuffer EncapsulateImpl(const DataBuffer& a_Data) override; + virtual DataBuffer DecapsulateImpl(DataBuffer& a_Data) override; }; -} // namespace io -} // namespace sp + +} // namespace sp \ No newline at end of file diff --git a/include/sp/extensions/tcp/TcpListener.h b/include/sp/extensions/tcp/TcpListener.h index 07e2e06..6000e19 100644 --- a/include/sp/extensions/tcp/TcpListener.h +++ b/include/sp/extensions/tcp/TcpListener.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace sp { namespace io { @@ -10,7 +11,7 @@ namespace io { */ class TcpListener : private NonCopyable { public: - using SocketHandle = TcpTag::SocketHandle; + using SocketHandle = TcpSocket::SocketHandle; /** * \brief Starts listening for guests to connect diff --git a/include/sp/extensions/tcp/TcpSocket.h b/include/sp/extensions/tcp/TcpSocket.h index 2ec39d3..0abdfd6 100644 --- a/include/sp/extensions/tcp/TcpSocket.h +++ b/include/sp/extensions/tcp/TcpSocket.h @@ -1,14 +1,12 @@ #pragma once #include -#include +#include namespace sp { namespace io { -struct TcpTag { - using SocketHandle = int; -}; +class TcpListener; class SocketError : public std::exception { private: @@ -22,10 +20,9 @@ class SocketError : public std::exception { } }; -template <> -class IOInterface : private NonCopyable { +class TcpSocket : public sp::IoInterface { public: - using SocketHandle = TcpTag::SocketHandle; + using SocketHandle = int; /** * \enum Status @@ -40,14 +37,14 @@ class IOInterface : private NonCopyable { Error, }; - IOInterface(); - IOInterface(const std::string& a_Host, std::uint16_t a_Port); - IOInterface(IOInterface&& a_Other); - IOInterface& operator=(IOInterface&& a_Other); - virtual ~IOInterface(); + TcpSocket(); + TcpSocket(const std::string& a_Host, std::uint16_t a_Port); + TcpSocket(TcpSocket&& a_Other); + TcpSocket& operator=(TcpSocket&& a_Other); + virtual ~TcpSocket(); - DataBuffer Read(std::size_t a_Amount); - void Write(const sp::DataBuffer& a_Data); + virtual DataBuffer Read(std::size_t a_Amount) override; + virtual void Write(const sp::DataBuffer& a_Data) override; /** * \brief Allows to set the socket in non blocking/blocking mode @@ -79,10 +76,5 @@ class IOInterface : private NonCopyable { friend class TcpListener; }; -/** - * \typedef TcpSocket - */ -using TcpSocket = IOInterface; - } // namespace io } // namespace sp diff --git a/include/sp/io/IoInterface.h b/include/sp/io/IoInterface.h index 037dfef..32c48ef 100644 --- a/include/sp/io/IoInterface.h +++ b/include/sp/io/IoInterface.h @@ -1,10 +1,11 @@ #pragma once #include +#include namespace sp { -class IoInterface { +class IoInterface : private NonCopyable { public: virtual DataBuffer Read(std::size_t a_Amount) = 0; virtual void Write(const DataBuffer& a_Data) = 0; diff --git a/include/sp/io/MessageEncapsulator.h b/include/sp/io/MessageEncapsulator.h index d8796fd..4a2a30b 100644 --- a/include/sp/io/MessageEncapsulator.h +++ b/include/sp/io/MessageEncapsulator.h @@ -5,12 +5,27 @@ namespace sp { class MessageEncapsulator { + protected: + bool m_Enabled = true; + public: MessageEncapsulator() {} virtual ~MessageEncapsulator() {} - virtual DataBuffer Encapsulate(const DataBuffer& a_Data) = 0; - virtual DataBuffer Decapsulate(DataBuffer& a_Data) = 0; + DataBuffer Encapsulate(const DataBuffer& a_Data) { + if (!m_Enabled) + return a_Data; + return EncapsulateImpl(a_Data); + } + DataBuffer Decapsulate(DataBuffer& a_Data) { + if (!m_Enabled) + return a_Data; + return DecapsulateImpl(a_Data); + } + + protected: + virtual DataBuffer EncapsulateImpl(const DataBuffer& a_Data) = 0; + virtual DataBuffer DecapsulateImpl(DataBuffer& a_Data) = 0; }; } // namespace sp diff --git a/include/sp/io/MessageStream.h b/include/sp/io/MessageStream.h index c5442f8..3aea568 100644 --- a/include/sp/io/MessageStream.h +++ b/include/sp/io/MessageStream.h @@ -9,8 +9,8 @@ namespace sp { template class MessageStream { - private: - std::vector m_Encapsulators; + protected: + std::vector> m_Encapsulators; std::shared_ptr m_Stream; using MessageBaseType = typename TMessageFactory::MessageBaseType; @@ -19,11 +19,30 @@ class MessageStream { public: MessageStream(std::shared_ptr&& a_Stream) : m_Stream(std::move(a_Stream)) {} + template + MessageStream(std::shared_ptr&& a_Stream, TEnc&&... a_Encapsulators) : + m_Stream(std::move(a_Stream)){ + m_Encapsulators.reserve(sizeof...(a_Encapsulators)); + AddEncapsulators(std::move(a_Encapsulators ...)); + } + std::unique_ptr ReadMessage(); std::unique_ptr ReadMessage(MessageIdType a_Id); void WriteMessage(const MessageBaseType& a_Message, bool a_WriteId = true); + template + void AddEncapsulators(Args&& ... a_Encapsulators) { + AddEncapsulators(std::move(std::make_tuple<>(a_Encapsulators ...))); + } + + template + void AddEncapsulators(std::tuple&& a_Encapsulators) { + TupleForEach([this](auto&& a_Encapsulator){ + m_Encapsulators.push_back(std::move(a_Encapsulator)); + }, a_Encapsulators); + } + private: DataBuffer ReadAndDecapsulate(); std::unique_ptr MakeMessage(DataBuffer& buffer, MessageIdType a_Id); diff --git a/include/sp/io/MessageStream.inl b/include/sp/io/MessageStream.inl index 98231f9..4a6f300 100644 --- a/include/sp/io/MessageStream.inl +++ b/include/sp/io/MessageStream.inl @@ -12,8 +12,8 @@ DataBuffer MessageStream::ReadAndDecapsulate() { std::size_t amount = messageLength.GetValue(); DataBuffer buffer = m_Stream->Read(amount); - for (MessageEncapsulator& enc : m_Encapsulators) { - buffer = enc.Decapsulate(buffer); + for (auto& enc : m_Encapsulators) { + buffer = enc->Decapsulate(buffer); } return buffer; @@ -51,8 +51,8 @@ void MessageStream::WriteMessage(const MessageBaseType& a_Messa if (a_WriteId) buffer << VarInt{static_cast(a_Message.GetId())}; buffer << a_Message.Write(); - for (MessageEncapsulator& enc : m_Encapsulators) { - buffer = enc.Encapsulate(buffer); + for (auto& enc : m_Encapsulators) { + buffer = enc->Encapsulate(buffer); } DataBuffer header; header << VarInt{buffer.GetSize()}; diff --git a/src/sp/extensions/Compress.cpp b/src/sp/extensions/Compress.cpp index e69de29..49faa3e 100644 --- a/src/sp/extensions/Compress.cpp +++ b/src/sp/extensions/Compress.cpp @@ -0,0 +1,85 @@ +#include + +#include +#include +#include + +namespace sp { +namespace zlib { + +static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::size_t uncompressedSize) { + DataBuffer result; + result.Resize(uncompressedSize); + + uncompress(static_cast(result.data()), reinterpret_cast(&uncompressedSize), static_cast(source), + static_cast(size)); + + assert(result.GetSize() == uncompressedSize); + return result; +} + +static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) { + DataBuffer result; + uLongf compressedSize = size; + + result.Resize(size); // Resize for the compressed data to fit into + compress(static_cast(result.data()), &compressedSize, static_cast(source), static_cast(size)); + result.Resize(compressedSize); // Resize to cut useless data + + return result; +} + +DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) { + DataBuffer packet; + + if (buffer.GetSize() < a_CompressionThreshold) { + // Don't compress since it's a small packet + packet << VarInt{0}; + packet << buffer; + return packet; + } + + DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize()); + VarInt uncompressedDataLength = buffer.GetSize(); + + if (compressedData.GetSize() >= buffer.GetSize()) { + // the compression is overkill so we don't send the compressed buffer + packet << VarInt{0}; + packet << buffer; + } else { + packet << uncompressedDataLength; + packet << compressedData; + } + + return packet; +} + +DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) { + VarInt uncompressedLength; + buffer >> uncompressedLength; + + std::uint64_t compressedLength = packetLength - uncompressedLength.GetSerializedLength(); + + if (uncompressedLength.GetValue() == 0) { + // Data already uncompressed. Nothing to do + DataBuffer ret; + buffer.ReadSome(ret, compressedLength); + return ret; + } + + assert(buffer.GetReadOffset() + compressedLength <= buffer.GetSize()); + + return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue()); +} + +} // namespace zlib + +DataBuffer ZlibCompress::EncapsulateImpl(const DataBuffer& a_Data) { + return zlib::Compress(a_Data, m_CompressionThreshold); +} + +DataBuffer ZlibCompress::DecapsulateImpl(DataBuffer& a_Data) { + return zlib::Decompress(a_Data, a_Data.GetSize()); +} + +} // namespace sp diff --git a/src/sp/extensions/TcpListener.cpp b/src/sp/extensions/TcpListener.cpp index e69de29..330c975 100644 --- a/src/sp/extensions/TcpListener.cpp +++ b/src/sp/extensions/TcpListener.cpp @@ -0,0 +1,114 @@ +#include + + +#ifdef _WIN32 + +// Windows + +#include +#include + +#define ioctl ioctlsocket +#define WOULDBLOCK WSAEWOULDBLOCK +#define MSG_DONTWAIT 0 + +#else + +// Linux/Unix + +#include +#include +#include +#include +#include +#include +#include + +#define closesocket close +#define WOULDBLOCK EWOULDBLOCK +#define SD_BOTH SHUT_RDWR + +#endif + + + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif + + +namespace sp { +namespace io { + +TcpListener::TcpListener(std::uint16_t a_Port, int a_MaxConnexions) { + if ((m_Handle = static_cast(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP))) < 0) { + throw SocketError("Failed to create server socket"); + } + + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(a_Port); + + if (bind(m_Handle, reinterpret_cast(&address), sizeof(address)) < 0) + throw SocketError("Failed to create server socket"); + + if (listen(m_Handle, a_MaxConnexions) < 0) + throw SocketError("Failed to create server socket"); + + socklen_t len = sizeof(address); + if (getsockname(m_Handle, reinterpret_cast(&address), &len) < 0) + throw SocketError("Failed to create server socket"); + + m_Port = ntohs(address.sin_port); + m_MaxConnections = a_MaxConnexions; +} + + +TcpListener::~TcpListener() { + Close(); +} + +std::unique_ptr TcpListener::Accept() { + sockaddr remoteAddress; + int addrlen = sizeof(remoteAddress); + + auto newSocket = std::make_unique(); + + newSocket->m_Handle = static_cast( + accept(m_Handle, reinterpret_cast(&remoteAddress), reinterpret_cast(&addrlen))); + + if (newSocket->m_Handle < 0) + return nullptr; + + newSocket->m_Status = TcpSocket::Status::Connected; + return newSocket; +} + +void TcpListener::Close() { + if (m_Handle > 0) { + closesocket(m_Handle); + shutdown(m_Handle, SD_BOTH); + } +} + +bool TcpListener::SetBlocking(bool a_Blocking) { + unsigned long mode = !a_Blocking; + + if (ioctl(m_Handle, FIONBIO, &mode) < 0) { + return false; + } + + return true; +} + +std::uint16_t TcpListener::GetListeningPort() const { + return m_Port; +} + +int TcpListener::GetMaximumConnections() const { + return m_MaxConnections; +} + +} // namespace io +} // namespace sp diff --git a/src/sp/extensions/TcpSocket.cpp b/src/sp/extensions/TcpSocket.cpp index e69de29..047666f 100644 --- a/src/sp/extensions/TcpSocket.cpp +++ b/src/sp/extensions/TcpSocket.cpp @@ -0,0 +1,164 @@ +#include + +#ifdef _WIN32 + +// Windows + +#include +#include + +#define ioctl ioctlsocket +#define WOULDBLOCK WSAEWOULDBLOCK +#define MSG_DONTWAIT 0 + +#else + +// Linux/Unix + +#include +#include +#include +#include +#include +#include +#include + +#define closesocket close +#define WOULDBLOCK EWOULDBLOCK + +#endif + + + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif + + +namespace sp { +namespace io { + +TcpSocket::TcpSocket() : m_Handle(static_cast(INVALID_SOCKET)), m_Status(Status::Disconnected) {} + +TcpSocket::TcpSocket(const std::string& a_Host, std::uint16_t a_Port) : TcpSocket() { + Connect(a_Host, a_Port); +} + +TcpSocket::TcpSocket(TcpSocket&& a_Other) { + std::swap(m_Handle, a_Other.m_Handle); + std::swap(m_Status, a_Other.m_Status); +} + +TcpSocket::~TcpSocket() {} + +void TcpSocket::Connect(const std::string& a_Host, std::uint16_t a_Port) { + struct addrinfo hints {}; + + struct addrinfo* result = nullptr; + + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + m_Status = Status::Error; + + if (getaddrinfo(a_Host.c_str(), std::to_string(static_cast(a_Port)).c_str(), &hints, &result) != 0) { + throw SocketError("Failed to get address info"); + } + + m_Handle = static_cast(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + if (m_Handle < 0) { + throw SocketError("Failed to create socket"); + } + + struct addrinfo* ptr = nullptr; + for (ptr = result; ptr != nullptr; ptr = ptr->ai_next) { + struct sockaddr* sockaddr = ptr->ai_addr; + if (connect(m_Handle, sockaddr, sizeof(sockaddr_in)) == 0) { + break; + } + } + + freeaddrinfo(result); + + if (!ptr) { + throw SocketError("Could not find a suitable interface for connecting"); + } + + m_Status = Status::Connected; +} + +DataBuffer TcpSocket::Read(std::size_t a_Amount) { + DataBuffer buffer(a_Amount); + + std::size_t totalRecieved = 0; + + while (totalRecieved < a_Amount) { + int recvAmount = + recv(m_Handle, reinterpret_cast(buffer.data() + totalRecieved), static_cast(a_Amount - totalRecieved), 0); + if (recvAmount <= 0) { +#if defined(_WIN32) || defined(WIN32) + int err = WSAGetLastError(); +#else + int err = errno; +#endif + if (err == WOULDBLOCK) { + // we are in non blocking mode and nothing is available + return {}; + } + + Disconnect(); + m_Status = Status::Error; + throw SocketError("Error while reading"); + } + totalRecieved += recvAmount; + } + return buffer; +} + +void TcpSocket::Write(const sp::DataBuffer& a_Data) { + if (GetStatus() != Status::Connected) + return; + + std::size_t sent = 0; + + while (sent < a_Data.GetSize()) { + int cur = send(m_Handle, reinterpret_cast(a_Data.data() + sent), static_cast(a_Data.GetSize() - sent), 0); + + if (cur <= 0) { + Disconnect(); + m_Status = Status::Error; + return; + } + sent += static_cast(cur); + } +} + +bool TcpSocket::SetBlocking(bool a_Block) { + unsigned long mode = !a_Block; + + if (ioctl(m_Handle, FIONBIO, &mode) < 0) { + return false; + } + + return true; +} + +TcpSocket::Status TcpSocket::GetStatus() const { + return m_Status; +} + +void TcpSocket::Disconnect() { + if (m_Handle > 0) + closesocket(m_Handle); + m_Status = Status::Disconnected; +} + +TcpSocket& TcpSocket::operator=(TcpSocket&& a_Other) { + std::swap(m_Handle, a_Other.m_Handle); + std::swap(m_Status, a_Other.m_Status); + return *this; +} + +} // namespace io +} // namespace sp diff --git a/test/test_message.cpp b/test/test_message.cpp index 6b8a49e..fadd268 100644 --- a/test/test_message.cpp +++ b/test/test_message.cpp @@ -5,6 +5,8 @@ #include #include +#include + #include #include #include @@ -20,6 +22,7 @@ using Message = sp::ConcreteMessage; struct KeepAlivePacket { std::uint64_t m_KeepAlive; + std::string mdc; }; using KeepAliveMessage = Message; @@ -31,7 +34,7 @@ class PacketHandler : public sp::MessageHandler {}; class MyHandler : public PacketHandler { public: virtual void Handle(const KeepAlivePacket& msg) { - std::cout << "I recieved a keep alive : " << msg.m_KeepAlive << "\n"; + std::cout << "I recieved a keep alive : " << msg.m_KeepAlive << " : " << msg.mdc << "\n"; } }; @@ -42,7 +45,7 @@ using PacketFactory = sp::MessageFactory; using PacketStream = sp::MessageStream; int main() { - KeepAliveMessage m{69UL}; + KeepAliveMessage m{69UL, "ceci est une mdc aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}; // dispatch tests @@ -58,9 +61,11 @@ int main() { // write tests + auto compress = std::make_shared(); + std::ofstream file {"test.bin"}; - PacketStream p(std::make_shared(file)); + PacketStream p(std::make_shared(file), compress); p.WriteMessage(m); @@ -68,7 +73,7 @@ int main() { std::ifstream file2 {"test.bin"}; - PacketStream p2(std::make_shared(file2)); + PacketStream p2(std::make_shared(file2), compress); auto message2 = p2.ReadMessage();