From 5beb5e92a77f9db43a05db98579e1da7c64c2c08 Mon Sep 17 00:00:00 2001 From: Persson-dev Date: Sat, 1 Mar 2025 18:20:51 +0000 Subject: [PATCH] zlib support (#9) Reviewed-on: https://git.ale-pri.com/Persson-dev/Simple-Protocol-Lib/pulls/9 Co-authored-by: Persson-dev Co-committed-by: Persson-dev --- include/sp/common/DataBuffer.h | 2 +- include/sp/common/VarInt.h | 2 + include/sp/extensions/Compress.h | 33 +++++++++---- include/sp/io/IOInterface.h | 22 ++++++++- include/sp/io/IOInterfaceImpl.inl | 79 +++++++++++++++++++++++++------ src/sp/extensions/Compress.cpp | 49 ++++++++++--------- test/test_file.cpp | 12 +++-- test/test_io.cpp | 2 +- xmake.lua | 13 +++-- 9 files changed, 158 insertions(+), 56 deletions(-) diff --git a/include/sp/common/DataBuffer.h b/include/sp/common/DataBuffer.h index c387eb9..d9cf312 100644 --- a/include/sp/common/DataBuffer.h +++ b/include/sp/common/DataBuffer.h @@ -166,7 +166,7 @@ class DataBuffer { K newKey; V newValue; *this >> newKey >> newValue; - data.insert({newKey, newValue}); + data.emplace(newKey, newValue); } return *this; } diff --git a/include/sp/common/VarInt.h b/include/sp/common/VarInt.h index 8da0069..4d9e601 100644 --- a/include/sp/common/VarInt.h +++ b/include/sp/common/VarInt.h @@ -21,6 +21,8 @@ class VarInt { std::uint64_t m_Value; public: + static const std::uint64_t MAX_VALUE = static_cast(-1) >> 8; + VarInt() : m_Value(0) {} /** * \brief Construct a variable integer from a value diff --git a/include/sp/extensions/Compress.h b/include/sp/extensions/Compress.h index c5383ac..b3b06e3 100644 --- a/include/sp/extensions/Compress.h +++ b/include/sp/extensions/Compress.h @@ -8,6 +8,19 @@ #include #include +namespace sp { +namespace option { + +struct ZlibCompress { + bool m_Enabled = true; + std::size_t m_CompressionThreshold = 64; +}; + +} // namespace option +} // namespace sp + +#include + namespace sp { namespace zlib { @@ -16,14 +29,7 @@ namespace zlib { * \param buffer the data to compress * \return the compressed data */ -DataBuffer Compress(const DataBuffer& buffer); - -/** - * \brief Reads the packet lenght and uncompress it - * \param buffer the data to uncompress - * \return the uncompressed data - */ -DataBuffer Decompress(DataBuffer& buffer); +DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold = 64); /** * \brief Uncompress some data @@ -34,4 +40,15 @@ DataBuffer Decompress(DataBuffer& buffer); DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength); } // namespace zlib + +namespace io { + +template <> +class MessageEncapsulator { + public: + static DataBuffer Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option); + static DataBuffer Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option); +}; + +} // namespace io } // namespace sp diff --git a/include/sp/io/IOInterface.h b/include/sp/io/IOInterface.h index 3ab050b..aca8cbd 100644 --- a/include/sp/io/IOInterface.h +++ b/include/sp/io/IOInterface.h @@ -13,25 +13,43 @@ class IOInterface { void Write(const DataBuffer& a_Data); }; -template +template +class MessageEncapsulator { + public: + static DataBuffer Encapsulate(const DataBuffer& a_Data, const TOption& a_Option); + static DataBuffer Decapsulate(DataBuffer& a_Data, const TOption& a_Option); +}; + +template class Stream { protected: MessageDispatcher m_Dispatcher; IOInterface m_Interface; + std::tuple m_Options; using MessageBase = typename MessageDispatcher::MessageBaseType; public: Stream() {} - Stream(IOInterface&& a_Interface); + Stream(IOInterface&& a_Interface, TOptions&&... a_Options); Stream(Stream&& a_Stream); void RecieveMessages(); void SendMessage(const MessageBase& a_Message); + + template + TOption& GetOption() { + return std::get(m_Options); + } + MessageDispatcher& GetDispatcher() { return m_Dispatcher; } + + private: + static DataBuffer Encapsulate(const DataBuffer& a_Data, const TOptions&... a_Options); + static DataBuffer Decapsulate(DataBuffer& a_Data, const TOptions&... a_Options); }; } // namespace io diff --git a/include/sp/io/IOInterfaceImpl.inl b/include/sp/io/IOInterfaceImpl.inl index 501c30a..5aabee2 100644 --- a/include/sp/io/IOInterfaceImpl.inl +++ b/include/sp/io/IOInterfaceImpl.inl @@ -1,29 +1,64 @@ #pragma once +#include #include namespace sp { -namespace io { -template -Stream::Stream(IOInterface&& a_Interface) : m_Interface(std::move(a_Interface)) {} -template -Stream::Stream(Stream&& a_Stream) : +namespace details { + +template +struct MessageEncapsulatorPack {}; + +template <> +struct MessageEncapsulatorPack<> { + static DataBuffer Encapsulate(const DataBuffer& a_Data) { + return a_Data; + } + static DataBuffer Decapsulate(DataBuffer& a_Data) { + return a_Data; + } +}; + +template +struct MessageEncapsulatorPack { + static DataBuffer Encapsulate(const DataBuffer& a_Data, const TOption& a_Option, const TOptions&... a_Options) { + DataBuffer data = io::MessageEncapsulator::Encapsulate(a_Data, a_Option); + return MessageEncapsulatorPack::Encapsulate(data, a_Options...); + } + static DataBuffer Decapsulate(DataBuffer& a_Data, const TOption& a_Option, const TOptions&... a_Options) { + DataBuffer data = io::MessageEncapsulator::Decapsulate(a_Data, a_Option); + return MessageEncapsulatorPack::Decapsulate(data, a_Options...); + } +}; + +} // namespace details + +namespace io { + +template +Stream::Stream(IOInterface&& a_Interface, TOptions&&... a_Options) : + m_Interface(std::move(a_Interface)), m_Options(std::make_tuple(std::move(a_Options)...)) {} + +template +Stream::Stream( + Stream&& a_Stream) : m_Dispatcher(std::move(a_Stream.m_Dispatcher)), m_Interface(std::move(a_Stream.m_Interface)) {} -template -void Stream::SendMessage(const MessageBase& a_Message) { - // TODO: process compress + encryption +template +void Stream::SendMessage(const MessageBase& a_Message) { DataBuffer data = a_Message.Write(); - DataBuffer dataSize; - m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data); + DataBuffer encapsulated = std::apply([&data](const auto&... a_Options) { + return Encapsulate(data, a_Options...); + }, m_Options); + DataBuffer finalData; + finalData << VarInt{encapsulated.GetSize()} << encapsulated; + m_Interface.Write(finalData); } -template -void Stream::RecieveMessages() { - // TODO: process compress + encryption +template +void Stream::RecieveMessages() { while (true) { - // reading the first VarInt part byte by byte std::uint64_t lenghtValue = 0; unsigned int readPos = 0; @@ -58,7 +93,9 @@ void Stream::RecieveMessages() { DataBuffer buffer; buffer = m_Interface.Read(lenghtValue); - // TODO: process compress + encryption + buffer = std::apply([&buffer, lenghtValue](const auto&... a_Options) { + return Decapsulate(buffer, a_Options...); + }, m_Options); VarInt packetType; buffer >> packetType; @@ -75,5 +112,17 @@ void Stream::RecieveMessages() { } } +template +DataBuffer Stream::Encapsulate( + const DataBuffer& a_Data, const TOptions&... a_Options) { + return details::MessageEncapsulatorPack::Encapsulate(a_Data, a_Options...); +} + +template +DataBuffer Stream::Decapsulate( + DataBuffer& a_Data, const TOptions&... a_Options) { + return details::MessageEncapsulatorPack::Decapsulate(a_Data, a_Options...); +} + } // namespace io } // namespace sp \ No newline at end of file diff --git a/src/sp/extensions/Compress.cpp b/src/sp/extensions/Compress.cpp index 8995f3f..7d50fd5 100644 --- a/src/sp/extensions/Compress.cpp +++ b/src/sp/extensions/Compress.cpp @@ -4,8 +4,6 @@ #include #include -#define COMPRESSION_THRESHOLD 64 - namespace sp { namespace zlib { @@ -13,8 +11,8 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz DataBuffer result; result.Resize(uncompressedSize); - uncompress(reinterpret_cast(result.data()), reinterpret_cast(&uncompressedSize), - reinterpret_cast(source), static_cast(size)); + uncompress(static_cast(result.data()), static_cast(&uncompressedSize), static_cast(source), + static_cast(size)); assert(result.GetSize() == uncompressedSize); return result; @@ -25,35 +23,34 @@ static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) { uLongf compressedSize = size; result.Resize(size); // Resize for the compressed data to fit into - compress( - reinterpret_cast(result.data()), &compressedSize, reinterpret_cast(source), static_cast(size)); + compress(static_cast(result.data()), &compressedSize, static_cast(source), static_cast(size)); result.Resize(compressedSize); // Resize to cut useless data return result; } -DataBuffer Compress(const DataBuffer& buffer) { +DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) { DataBuffer packet; - if (buffer.GetSize() < COMPRESSION_THRESHOLD) { + if (buffer.GetSize() < a_CompressionThreshold) { // Don't compress since it's a small packet - VarInt compressedDataLength = 0; - std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize(); - - packet << packetLength; - packet << compressedDataLength; + packet << VarInt{0}; packet << buffer; return packet; } DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize()); - VarInt uncompressedDataLength = buffer.GetSize(); - std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize(); - packet << packetLength; - packet << uncompressedDataLength; - packet.WriteSome(compressedData.data(), compressedData.GetSize()); + if (compressedData.GetSize() >= buffer.GetSize()) { + // the compression is overkill so we don't send the compressed buffer + packet << VarInt{0}; + packet << buffer; + } else { + packet << uncompressedDataLength; + packet << compressedData; + } + return packet; } @@ -75,12 +72,18 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) { return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue()); } -DataBuffer Decompress(DataBuffer& buffer) { - std::uint64_t packetLength; - buffer >> packetLength; +} // namespace zlib - return Decompress(buffer, packetLength); +namespace io { + +DataBuffer MessageEncapsulator::Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option) { + static constexpr std::size_t MAX_COMPRESS_THRESHOLD = VarInt::MAX_VALUE; + return zlib::Compress(a_Data, a_Option.m_Enabled ? a_Option.m_CompressionThreshold : MAX_COMPRESS_THRESHOLD); } -} // namespace zlib +DataBuffer MessageEncapsulator::Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option) { + return zlib::Decompress(a_Data, a_Data.GetSize()); +} + +} // namespace io } // namespace sp diff --git a/test/test_file.cpp b/test/test_file.cpp index f05d696..2a89a3c 100644 --- a/test/test_file.cpp +++ b/test/test_file.cpp @@ -17,18 +17,24 @@ class CustomPacketHandler : public sp::PacketHandler { } }; -using FileStream = sp::io::Stream; +using FileStream = sp::io::Stream; int main() { auto handler = std::make_shared(); - FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out}); + FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out}, {}); stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); stream.SendMessage(KeepAlivePacket{96}); stream.SendMessage(KeepAlivePacket{69}); - stream.SendMessage(DisconnectPacket{"This is in the file !"}); + stream.SendMessage(DisconnectPacket{ + "This is in the " + "fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"}); + stream.GetOption().m_Enabled = false; + stream.SendMessage(DisconnectPacket{ + "This is in the " + "fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"}); stream.RecieveMessages(); diff --git a/test/test_io.cpp b/test/test_io.cpp index 8bb78ec..c8affb6 100644 --- a/test/test_io.cpp +++ b/test/test_io.cpp @@ -33,7 +33,7 @@ int main() { stream.SendMessage(KeepAlivePacket{69}); stream.RecieveMessages(); - stream.SendMessage(DisconnectPacket{"I don't know"}); + stream.SendMessage(DisconnectPacket{"A valid reason"}); stream.RecieveMessages(); return 0; diff --git a/xmake.lua b/xmake.lua index 79ca7d5..2012f06 100644 --- a/xmake.lua +++ b/xmake.lua @@ -29,7 +29,7 @@ end -- Add modules targets for name, module in table.orderpairs(modules) do if module.Deps and has_config(module.Option) then - target("SimpleProtocolLib-" .. name) + target("SimpleProtocol-" .. name) add_includedirs("include") for _, include in table.orderpairs(module.Includes) do add_headerfiles(include) @@ -45,7 +45,7 @@ for name, module in table.orderpairs(modules) do end end -target("SimpleProtocolLib") +target("SimpleProtocol") add_includedirs("include") add_files("src/sp/**.cpp") @@ -54,6 +54,13 @@ target("SimpleProtocolLib") add_headerfiles("include/(sp/" .. folder .. "/**.h)") end + -- adding extensions + for name, module in table.orderpairs(modules) do + if module.Deps and has_config(module.Option) then + add_deps("SimpleProtocol-" .. name) + end + end + -- we don't want extensions remove_files("src/sp/extensions/**.cpp") set_group("Library") @@ -68,7 +75,7 @@ for _, file in ipairs(os.files("test/**.cpp")) do add_files(file) add_includedirs("include") - add_deps("SimpleProtocolLib") + add_deps("SimpleProtocol") add_tests("compile_and_run") end