#include "misc/Compression.h"

#include <cassert>
#include <cstdint>
#include <string>

#include <zlib.h>

// Payloads smaller than this many bytes are sent uncompressed.
#define COMPRESSION_THRESHOLD 64

namespace td {
namespace utils {

/**
 * Decompresses the zlib stream in `source` into the pre-sized buffer `dest`.
 *
 * The caller must size `dest` to the expected decompressed length beforehand;
 * `dest.size()` is passed to zlib as the output capacity. `dest` is not resized.
 *
 * @param source zlib-compressed bytes.
 * @param dest   output buffer, already sized to the expected decompressed length.
 * @return the number of bytes written into `dest`, or 0 if `dest` is empty or
 *         zlib reports an error (corrupt stream, `dest` too small, out of memory).
 */
unsigned long inflate(const std::string& source, std::string& dest) {
    if (dest.empty())
        return 0;

    uLongf size = static_cast<uLongf>(dest.size());
    const int result = uncompress(reinterpret_cast<Bytef*>(&dest[0]), &size,
                                  reinterpret_cast<const Bytef*>(source.data()),
                                  static_cast<uLong>(source.length()));
    return result == Z_OK ? size : 0;
}

/**
 * Compresses `source` with zlib into `dest`.
 *
 * On success `dest` is resized to exactly the compressed length.
 * On failure `dest` is cleared.
 *
 * @param source bytes to compress.
 * @param dest   receives the compressed bytes.
 * @return the compressed length in bytes, or 0 if zlib reports an error.
 */
unsigned long deflate(const std::string& source, std::string& dest) {
    // compressBound() is zlib's worst-case output size. Sizing the buffer to
    // source.length() instead makes compress() fail (Z_BUF_ERROR, truncated
    // stream) whenever the input does not shrink, e.g. small or random data.
    uLongf size = compressBound(static_cast<uLong>(source.length()));
    dest.resize(size);

    const int result = compress(reinterpret_cast<Bytef*>(&dest[0]), &size,
                                reinterpret_cast<const Bytef*>(source.data()),
                                static_cast<uLong>(source.length()));
    if (result != Z_OK) {
        dest.clear();
        return 0;
    }

    dest.resize(size);
    return size;
}

namespace {

// Builds a packet that carries `buffer` verbatim. A stored data length of 0
// tells Decompress() that the payload is not compressed.
DataBuffer MakeUncompressedPacket(const DataBuffer& buffer) {
    DataBuffer packet;
    std::uint64_t dataLength = 0;
    std::uint64_t packetLength = buffer.GetSize() + sizeof(dataLength);
    packet << packetLength;
    packet << dataLength;
    packet << buffer;
    return packet;
}

} // namespace

/**
 * Wraps `buffer` into a length-prefixed packet, compressing it when worthwhile.
 *
 * Layout written: [uint64 packetLength][uint64 dataLength][payload], where
 * packetLength counts the dataLength field plus the payload. dataLength is the
 * original (uncompressed) size, or 0 when the payload is stored uncompressed.
 *
 * @param buffer data to wrap.
 * @return the packet, ready to be sent.
 */
DataBuffer Compress(const DataBuffer& buffer) {
    if (buffer.GetSize() < COMPRESSION_THRESHOLD)
        return MakeUncompressedPacket(buffer);

    std::string compressedData;
    const unsigned long compressedSize = deflate(buffer.ToString(), compressedData);

    // zlib failed, or compression did not actually save space: send raw bytes.
    // Decompress() already handles the dataLength == 0 format.
    if (compressedSize == 0 ||
        static_cast<std::uint64_t>(compressedData.length()) >= static_cast<std::uint64_t>(buffer.GetSize()))
        return MakeUncompressedPacket(buffer);

    DataBuffer packet;
    std::uint64_t dataLength = buffer.GetSize();
    std::uint64_t packetLength = compressedData.length() + sizeof(dataLength);
    packet << packetLength;
    packet << dataLength;
    packet << compressedData;
    return packet;
}

/**
 * Reads one packet body (everything after the packetLength field) from `buffer`
 * and returns the original payload.
 *
 * @param buffer       source buffer, positioned just after packetLength.
 *                     Its read offset advances past the packet.
 * @param packetLength the already-read packetLength field. It counts the
 *                     8-byte dataLength field plus the payload.
 * @return the decompressed payload. In release builds, an empty buffer is
 *         returned if zlib rejects the stream.
 */
DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
    std::uint64_t uncompressedLength = 0;
    buffer >> uncompressedLength;

    // A packetLength smaller than its own dataLength field is malformed and
    // would underflow below.
    assert(packetLength >= sizeof(uncompressedLength));
    const std::uint64_t compressedLength = packetLength - sizeof(uncompressedLength);
    assert(buffer.GetReadOffset() + compressedLength <= buffer.GetSize());

    if (uncompressedLength == 0) {
        // Payload was stored uncompressed.
        DataBuffer ret;
        buffer.ReadSome(ret, compressedLength);
        return ret;
    }

    std::string deflatedData;
    buffer.ReadSome(deflatedData, compressedLength);

    // NOTE(review): uncompressedLength comes straight from the packet. If
    // packets can come from untrusted peers, consider capping it before
    // allocating.
    std::string inflated;
    inflated.resize(uncompressedLength);
    const unsigned long written = inflate(deflatedData, inflated);
    assert(written == uncompressedLength);

    // Drop any tail zlib did not fill (0 bytes kept on error) rather than
    // returning zero padding as if it were data.
    inflated.resize(written);
    return DataBuffer(inflated);
}

/**
 * Reads a full packet (packetLength prefix included) from `buffer` and returns
 * the original payload. See Decompress(DataBuffer&, std::uint64_t).
 */
DataBuffer Decompress(DataBuffer& buffer) {
    std::uint64_t packetLength = 0;
    buffer >> packetLength;
    return Decompress(buffer, packetLength);
}

} // namespace utils
} // namespace td