fix compression

This commit is contained in:
2025-02-26 15:09:23 +01:00
parent 68fcd514a3
commit 6a52b7fe2a
5 changed files with 25 additions and 31 deletions

View File

@@ -16,14 +16,7 @@ namespace zlib {
* \param buffer the data to compress * \param buffer the data to compress
* \return the compressed data * \return the compressed data
*/ */
DataBuffer Compress(const DataBuffer& buffer); DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold = 64);
/**
* \brief Reads the packet lenght and uncompress it
* \param buffer the data to uncompress
* \return the uncompressed data
*/
DataBuffer Decompress(DataBuffer& buffer);
/** /**
* \brief Uncompress some data * \brief Uncompress some data

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include <stdexcept> #include <stdexcept>
#include <sp/extensions/Compress.h>
namespace sp { namespace sp {
namespace io { namespace io {
@@ -15,8 +16,8 @@ template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
void Stream<IOTag, MessageDispatcher, MessageFactory>::SendMessage(const MessageBase& a_Message) { void Stream<IOTag, MessageDispatcher, MessageFactory>::SendMessage(const MessageBase& a_Message) {
// TODO: process compress + encryption // TODO: process compress + encryption
DataBuffer data = a_Message.Write(); DataBuffer data = a_Message.Write();
DataBuffer dataSize; DataBuffer compressed = zlib::Compress(data);
m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data); m_Interface.Write(compressed);
} }
template <typename IOTag, typename MessageDispatcher, typename MessageFactory> template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
@@ -60,6 +61,8 @@ void Stream<IOTag, MessageDispatcher, MessageFactory>::RecieveMessages() {
// TODO: process compress + encryption // TODO: process compress + encryption
buffer = zlib::Decompress(buffer, lenghtValue);
VarInt packetType; VarInt packetType;
buffer >> packetType; buffer >> packetType;

View File

@@ -4,8 +4,6 @@
#include <sp/common/VarInt.h> #include <sp/common/VarInt.h>
#include <zlib.h> #include <zlib.h>
#define COMPRESSION_THRESHOLD 64
namespace sp { namespace sp {
namespace zlib { namespace zlib {
@@ -13,8 +11,8 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz
DataBuffer result; DataBuffer result;
result.Resize(uncompressedSize); result.Resize(uncompressedSize);
uncompress(reinterpret_cast<Bytef*>(result.data()), reinterpret_cast<uLongf*>(&uncompressedSize), uncompress(static_cast<Bytef*>(result.data()), static_cast<uLongf*>(&uncompressedSize), static_cast<const Bytef*>(source),
reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size)); static_cast<uLong>(size));
assert(result.GetSize() == uncompressedSize); assert(result.GetSize() == uncompressedSize);
return result; return result;
@@ -25,20 +23,19 @@ static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) {
uLongf compressedSize = size; uLongf compressedSize = size;
result.Resize(size); // Resize for the compressed data to fit into result.Resize(size); // Resize for the compressed data to fit into
compress( compress(static_cast<Bytef*>(result.data()), &compressedSize, static_cast<const Bytef*>(source), static_cast<uLong>(size));
reinterpret_cast<Bytef*>(result.data()), &compressedSize, reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
result.Resize(compressedSize); // Resize to cut useless data result.Resize(compressedSize); // Resize to cut useless data
return result; return result;
} }
DataBuffer Compress(const DataBuffer& buffer) { DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) {
DataBuffer packet; DataBuffer packet;
if (buffer.GetSize() < COMPRESSION_THRESHOLD) { if (buffer.GetSize() < a_CompressionThreshold) {
// Don't compress since it's a small packet // Don't compress since it's a small packet
VarInt compressedDataLength = 0; VarInt compressedDataLength = 0;
std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize(); VarInt packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize();
packet << packetLength; packet << packetLength;
packet << compressedDataLength; packet << compressedDataLength;
@@ -49,11 +46,19 @@ DataBuffer Compress(const DataBuffer& buffer) {
DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize()); DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize());
VarInt uncompressedDataLength = buffer.GetSize(); VarInt uncompressedDataLength = buffer.GetSize();
std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize(); VarInt packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize();
packet << packetLength; packet << packetLength;
packet << uncompressedDataLength;
packet.WriteSome(compressedData.data(), compressedData.GetSize()); if (compressedData.GetSize() >= buffer.GetSize()) {
// the compression is overkill so we don't send the compressed buffer
packet << VarInt{0};
packet << buffer;
} else {
packet << uncompressedDataLength;
packet << compressedData;
}
return packet; return packet;
} }
@@ -75,12 +80,5 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue()); return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue());
} }
DataBuffer Decompress(DataBuffer& buffer) {
std::uint64_t packetLength;
buffer >> packetLength;
return Decompress(buffer, packetLength);
}
} // namespace zlib } // namespace zlib
} // namespace sp } // namespace sp

View File

@@ -28,7 +28,7 @@ int main() {
stream.SendMessage(KeepAlivePacket{96}); stream.SendMessage(KeepAlivePacket{96});
stream.SendMessage(KeepAlivePacket{69}); stream.SendMessage(KeepAlivePacket{69});
stream.SendMessage(DisconnectPacket{"This is in the file !"}); stream.SendMessage(DisconnectPacket{"This is in the fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"});
stream.RecieveMessages(); stream.RecieveMessages();

View File

@@ -33,7 +33,7 @@ int main() {
stream.SendMessage(KeepAlivePacket{69}); stream.SendMessage(KeepAlivePacket{69});
stream.RecieveMessages(); stream.RecieveMessages();
stream.SendMessage(DisconnectPacket{"I don't know"}); stream.SendMessage(DisconnectPacket{"A valid reason"});
stream.RecieveMessages(); stream.RecieveMessages();
return 0; return 0;