#include #include #include #include #define COMPRESSION_THRESHOLD 64 namespace sp { namespace zlib { static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::size_t uncompressedSize) { DataBuffer result; result.Resize(uncompressedSize); uncompress(reinterpret_cast(result.data()), reinterpret_cast(&uncompressedSize), reinterpret_cast(source), static_cast(size)); assert(result.GetSize() == uncompressedSize); return result; } static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) { DataBuffer result; uLongf compressedSize = size; result.Resize(size); // Resize for the compressed data to fit into compress( reinterpret_cast(result.data()), &compressedSize, reinterpret_cast(source), static_cast(size)); result.Resize(compressedSize); // Resize to cut useless data return result; } DataBuffer Compress(const DataBuffer& buffer) { DataBuffer packet; if (buffer.GetSize() < COMPRESSION_THRESHOLD) { // Don't compress since it's a small packet VarInt compressedDataLength = 0; std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize(); packet << packetLength; packet << compressedDataLength; packet << buffer; return packet; } DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize()); VarInt uncompressedDataLength = buffer.GetSize(); std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize(); packet << packetLength; packet << uncompressedDataLength; packet.WriteSome(compressedData.data(), compressedData.GetSize()); return packet; } DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) { VarInt uncompressedLength; buffer >> uncompressedLength; std::uint64_t compressedLength = packetLength - uncompressedLength.GetSerializedLength(); if (uncompressedLength.GetValue() == 0) { // Data already uncompressed. Nothing to do DataBuffer ret; buffer.ReadSome(ret, compressedLength); return ret; } assert(buffer.GetReadOffset() + compressedLength <= buffer.GetSize()); return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue()); } DataBuffer Decompress(DataBuffer& buffer) { std::uint64_t packetLength; buffer >> packetLength; return Decompress(buffer, packetLength); } } // namespace zlib } // namespace sp