fix compression

This commit is contained in:
2025-02-26 15:09:23 +01:00
parent 68fcd514a3
commit 6a52b7fe2a
5 changed files with 25 additions and 31 deletions

View File

@@ -16,14 +16,7 @@ namespace zlib {
* \param buffer the data to compress
* \return the compressed data
*/
DataBuffer Compress(const DataBuffer& buffer);
/**
* \brief Reads the packet lenght and uncompress it
* \param buffer the data to uncompress
* \return the uncompressed data
*/
DataBuffer Decompress(DataBuffer& buffer);
DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold = 64);
/**
* \brief Uncompress some data

View File

@@ -1,6 +1,7 @@
#pragma once
#include <stdexcept>
#include <sp/extensions/Compress.h>
namespace sp {
namespace io {
@@ -15,8 +16,8 @@ template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
void Stream<IOTag, MessageDispatcher, MessageFactory>::SendMessage(const MessageBase& a_Message) {
// TODO: process compress + encryption
DataBuffer data = a_Message.Write();
DataBuffer dataSize;
m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data);
DataBuffer compressed = zlib::Compress(data);
m_Interface.Write(compressed);
}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
@@ -60,6 +61,8 @@ void Stream<IOTag, MessageDispatcher, MessageFactory>::RecieveMessages() {
// TODO: process compress + encryption
buffer = zlib::Decompress(buffer, lenghtValue);
VarInt packetType;
buffer >> packetType;

View File

@@ -4,8 +4,6 @@
#include <sp/common/VarInt.h>
#include <zlib.h>
#define COMPRESSION_THRESHOLD 64
namespace sp {
namespace zlib {
@@ -13,8 +11,8 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz
DataBuffer result;
result.Resize(uncompressedSize);
uncompress(reinterpret_cast<Bytef*>(result.data()), reinterpret_cast<uLongf*>(&uncompressedSize),
reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
uncompress(static_cast<Bytef*>(result.data()), static_cast<uLongf*>(&uncompressedSize), static_cast<const Bytef*>(source),
static_cast<uLong>(size));
assert(result.GetSize() == uncompressedSize);
return result;
@@ -25,20 +23,19 @@ static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) {
uLongf compressedSize = size;
result.Resize(size); // Resize for the compressed data to fit into
compress(
reinterpret_cast<Bytef*>(result.data()), &compressedSize, reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
compress(static_cast<Bytef*>(result.data()), &compressedSize, static_cast<const Bytef*>(source), static_cast<uLong>(size));
result.Resize(compressedSize); // Resize to cut useless data
return result;
}
DataBuffer Compress(const DataBuffer& buffer) {
DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) {
DataBuffer packet;
if (buffer.GetSize() < COMPRESSION_THRESHOLD) {
if (buffer.GetSize() < a_CompressionThreshold) {
// Don't compress since it's a small packet
VarInt compressedDataLength = 0;
std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize();
VarInt packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize();
packet << packetLength;
packet << compressedDataLength;
@@ -49,11 +46,19 @@ DataBuffer Compress(const DataBuffer& buffer) {
DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize());
VarInt uncompressedDataLength = buffer.GetSize();
std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize();
VarInt packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize();
packet << packetLength;
if (compressedData.GetSize() >= buffer.GetSize()) {
// the compression is overkill so we don't send the compressed buffer
packet << VarInt{0};
packet << buffer;
} else {
packet << uncompressedDataLength;
packet.WriteSome(compressedData.data(), compressedData.GetSize());
packet << compressedData;
}
return packet;
}
@@ -75,12 +80,5 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue());
}
DataBuffer Decompress(DataBuffer& buffer) {
std::uint64_t packetLength;
buffer >> packetLength;
return Decompress(buffer, packetLength);
}
} // namespace zlib
} // namespace sp

View File

@@ -28,7 +28,7 @@ int main() {
stream.SendMessage(KeepAlivePacket{96});
stream.SendMessage(KeepAlivePacket{69});
stream.SendMessage(DisconnectPacket{"This is in the file !"});
stream.SendMessage(DisconnectPacket{"This is in the fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"});
stream.RecieveMessages();

View File

@@ -33,7 +33,7 @@ int main() {
stream.SendMessage(KeepAlivePacket{69});
stream.RecieveMessages();
stream.SendMessage(DisconnectPacket{"I don't know"});
stream.SendMessage(DisconnectPacket{"A valid reason"});
stream.RecieveMessages();
return 0;