90 lines
2.7 KiB
C++
90 lines
2.7 KiB
C++
#include <sp/extensions/Compress.h>
|
|
|
|
#include <cassert>
|
|
#include <sp/common/VarInt.h>
|
|
#include <zlib.h>
|
|
|
|
namespace sp {
|
|
namespace zlib {
|
|
|
|
// Decompresses a zlib stream into a newly created buffer.
//
// source           : first byte of the compressed data
// size             : number of compressed bytes readable from source
// uncompressedSize : size the data is expected to have once decompressed
//
// Returns the decompressed bytes. If zlib reports an error, or if the stream does not
// expand to exactly uncompressedSize bytes, the assertion fires in debug builds and an
// empty buffer is returned otherwise.
static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::size_t uncompressedSize) {
	DataBuffer result;
	result.Resize(uncompressedSize);

	// uncompress() takes the destination length as an in/out uLongf. Give it a genuine uLongf
	// rather than a reinterpret_cast of the std::size_t parameter: uLongf is an unsigned long,
	// which is narrower than std::size_t on LLP64 platforms.
	uLongf inflatedSize = static_cast<uLongf>(uncompressedSize);

	const int status = uncompress(static_cast<Bytef*>(result.data()), &inflatedSize, static_cast<const Bytef*>(source),
		static_cast<uLong>(size));

	// Both the zlib status and the produced length must match what the sender announced.
	const bool inflatedCorrectly = (status == Z_OK) && (inflatedSize == uncompressedSize);
	assert(inflatedCorrectly);

	if (!inflatedCorrectly) {
		// Do not hand back partially written bytes as if they were a valid payload.
		result.Resize(0);
	}

	return result;
}
|
|
|
|
// Compresses `size` bytes starting at `source` with zlib.
//
// The destination is deliberately limited to `size` bytes: Compress() discards any result
// that is not strictly smaller than the input, so a larger destination would never be used.
//
// Returns the compressed bytes on success. If compress() fails (including the case where
// the compressed form does not fit into `size` bytes), the returned buffer keeps exactly
// `size` bytes of unspecified content, which the caller must treat as "not compressible".
static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) {
	DataBuffer result;
	result.Resize(size); // Room for the compressed data to fit into

	uLongf compressedSize = static_cast<uLongf>(size);
	const int status = compress(static_cast<Bytef*>(result.data()), &compressedSize, static_cast<const Bytef*>(source),
		static_cast<uLong>(size));

	if (status != Z_OK) {
		// The length written back by a failed compress() is not a usable stream length.
		// Shrinking to it could make a truncated stream look like a successful compression,
		// so leave the buffer at full size.
		return result;
	}

	result.Resize(compressedSize); // Cut the unused tail
	return result;
}
|
|
|
|
// Builds the wire form of `buffer`: a VarInt header followed by the payload.
//
// The header is 0 when the payload is stored as-is, otherwise it holds the size of the
// original (uncompressed) data and the payload is the deflated data.
// Data shorter than a_CompressionThreshold is never deflated, and deflated data is only
// used when it is strictly smaller than the original.
DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) {
	const std::size_t rawSize = buffer.GetSize();
	DataBuffer packet;

	if (rawSize >= a_CompressionThreshold) {
		// Large enough to be worth a compression attempt
		DataBuffer deflated = Deflate(buffer.data(), rawSize);
		VarInt announcedLength = rawSize;

		if (deflated.GetSize() < rawSize) {
			// Compression paid off: announce the original size, then send the deflated bytes
			packet << announcedLength;
			packet << deflated;
			return packet;
		}
	}

	// Either the packet is small or compressing it gained nothing: send it untouched
	packet << VarInt{0};
	packet << buffer;
	return packet;
}
|
|
|
|
// Reads one packet produced by Compress() from `buffer`, starting at its read offset.
//
// buffer       : buffer positioned on the VarInt header of the packet
// packetLength : length of the whole packet (header + payload) in bytes
//
// A header value of 0 means the payload is stored uncompressed and is copied out with
// ReadSome(). Any other value is the expected uncompressed size and the payload is inflated.
//
// Returns the payload. A malformed packet (packetLength shorter than its own header, or a
// compressed payload reaching past the end of `buffer`) trips an assertion in debug builds
// and yields an empty buffer otherwise.
DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
	VarInt uncompressedLength;
	buffer >> uncompressedLength;

	// Guard the subtraction below: an unsigned underflow would turn a short packet
	// into an enormous payload length.
	const std::uint64_t headerLength = uncompressedLength.GetSerializedLength();
	assert(packetLength >= headerLength);
	if (packetLength < headerLength) {
		DataBuffer empty;
		return empty;
	}

	const std::uint64_t compressedLength = packetLength - headerLength;

	if (uncompressedLength.GetValue() == 0) {
		// Data already uncompressed. Nothing to do
		DataBuffer ret;
		buffer.ReadSome(ret, compressedLength);
		return ret;
	}

	// Compare against the remaining bytes instead of adding to the offset, so the check
	// itself cannot wrap around. It is enforced at runtime because Inflate() reads
	// compressedLength bytes straight from the buffer's storage.
	const bool payloadInBounds = buffer.GetReadOffset() <= buffer.GetSize()
		&& compressedLength <= buffer.GetSize() - buffer.GetReadOffset();
	assert(payloadInBounds);
	if (!payloadInBounds) {
		DataBuffer empty;
		return empty;
	}

	return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue());
}
|
|
|
|
} // namespace zlib
|
|
|
|
namespace io {
|
|
|
|
// Wraps a_Data in the zlib packet format.
// When compression is disabled the threshold is raised to VarInt::MAX_VALUE, so data
// shorter than that is passed through with a zero header instead of being deflated.
DataBuffer MessageEncapsulator<option::ZlibCompress>::Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option) {
	constexpr std::size_t disabledThreshold = VarInt::MAX_VALUE;

	if (!a_Option.m_Enabled) {
		return zlib::Compress(a_Data, disabledThreshold);
	}

	return zlib::Compress(a_Data, a_Option.m_CompressionThreshold);
}
|
|
|
|
// Extracts the payload of a zlib packet held in a_Data.
// a_Option is not consulted here: the packet's own VarInt header tells zlib::Decompress
// whether the payload is compressed.
DataBuffer MessageEncapsulator<option::ZlibCompress>::Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option) {
	// zlib::Decompress starts reading at the current read offset, so the packet length is
	// the number of bytes left from there, not the total size of the buffer. The two are
	// the same when nothing has been read from a_Data yet.
	return zlib::Decompress(a_Data, a_Data.GetSize() - a_Data.GetReadOffset());
}
|
|
|
|
} // namespace io
|
|
} // namespace sp
|