diff --git a/include/sp/extensions/Compress.h b/include/sp/extensions/Compress.h new file mode 100644 index 0000000..c5383ac --- /dev/null +++ b/include/sp/extensions/Compress.h @@ -0,0 +1,37 @@ +#pragma once + +/** + * \file Compression.h + * \brief File containing compress utilities + */ + +#include +#include + +namespace sp { +namespace zlib { + +/** + * \brief Compress some data + * \param buffer the data to compress + * \return the compressed data + */ +DataBuffer Compress(const DataBuffer& buffer); + +/** + * \brief Reads the packet lenght and uncompress it + * \param buffer the data to uncompress + * \return the uncompressed data + */ +DataBuffer Decompress(DataBuffer& buffer); + +/** + * \brief Uncompress some data + * \param buffer the data to uncompress + * \param packetLength lenght of data + * \return the uncompressed data + */ +DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength); + +} // namespace zlib +} // namespace sp diff --git a/src/sp/extensions/Compress.cpp b/src/sp/extensions/Compress.cpp new file mode 100644 index 0000000..8995f3f --- /dev/null +++ b/src/sp/extensions/Compress.cpp @@ -0,0 +1,86 @@ +#include + +#include +#include +#include + +#define COMPRESSION_THRESHOLD 64 + +namespace sp { +namespace zlib { + +static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::size_t uncompressedSize) { + DataBuffer result; + result.Resize(uncompressedSize); + + uncompress(reinterpret_cast(result.data()), reinterpret_cast(&uncompressedSize), + reinterpret_cast(source), static_cast(size)); + + assert(result.GetSize() == uncompressedSize); + return result; +} + +static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) { + DataBuffer result; + uLongf compressedSize = size; + + result.Resize(size); // Resize for the compressed data to fit into + compress( + reinterpret_cast(result.data()), &compressedSize, reinterpret_cast(source), static_cast(size)); + result.Resize(compressedSize); // Resize to cut useless data + + return result; +} + +DataBuffer Compress(const DataBuffer& buffer) { + DataBuffer packet; + + if (buffer.GetSize() < COMPRESSION_THRESHOLD) { + // Don't compress since it's a small packet + VarInt compressedDataLength = 0; + std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize(); + + packet << packetLength; + packet << compressedDataLength; + packet << buffer; + return packet; + } + + DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize()); + + VarInt uncompressedDataLength = buffer.GetSize(); + std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize(); + + packet << packetLength; + packet << uncompressedDataLength; + packet.WriteSome(compressedData.data(), compressedData.GetSize()); + return packet; +} + +DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) { + VarInt uncompressedLength; + buffer >> uncompressedLength; + + std::uint64_t compressedLength = packetLength - uncompressedLength.GetSerializedLength(); + + if (uncompressedLength.GetValue() == 0) { + // Data already uncompressed. Nothing to do + DataBuffer ret; + buffer.ReadSome(ret, compressedLength); + return ret; + } + + assert(buffer.GetReadOffset() + compressedLength <= buffer.GetSize()); + + return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue()); +} + +DataBuffer Decompress(DataBuffer& buffer) { + std::uint64_t packetLength; + buffer >> packetLength; + + return Decompress(buffer, packetLength); +} + +} // namespace zlib +} // namespace sp diff --git a/xmake.lua b/xmake.lua index 0c1c8af..6bd68e8 100644 --- a/xmake.lua +++ b/xmake.lua @@ -2,11 +2,54 @@ add_rules("mode.debug", "mode.release") set_languages("c++17") +local modules = { + Compression = { + Option = "zlib", + Deps = {"zlib"}, + Packages = {"zlib"}, + Includes = {"include/(sp/extensions/Compress.h)"}, + Sources = {"src/sp/extensions/Compress.cpp"} + } +} + +-- Map modules to options +for name, module in table.orderpairs(modules) do + if module.Option then + option(module.Option, { description = "Enables the " .. name .. " module", default = true, category = "Modules" }) + end +end + +-- Add modules requirements +for name, module in table.orderpairs(modules) do + if module.Deps then + add_requires(module.Deps) + end +end + +-- Add modules targets +for name, module in table.orderpairs(modules) do + if module.Deps and has_config(module.Option) then + target("SimpleProtocolLib-" .. name) + add_includedirs("include") + for _, include in table.orderpairs(module.Includes) do + add_headerfiles(include) + end + for _, source in table.orderpairs(module.Sources) do + add_files(source) + end + for _, package in table.orderpairs(module.Packages) do + add_packages(package) + end + set_group("Library") + set_kind("$(kind)") + end +end + target("SimpleProtocolLib") add_includedirs("include") - add_headerfiles("include/(sp/**.h)") + add_headerfiles("include/(sp/common/**.h)", "include/(sp/common/**.h)", "include/(sp/common/**.h)") set_group("Library") - add_files("src/sp/**.cpp") + add_files("src/sp/common/*.cpp") set_kind("$(kind)") -- Tests