diff --git a/include/sp/extensions/Compress.h b/include/sp/extensions/Compress.h index c5383ac..17d36c5 100644 --- a/include/sp/extensions/Compress.h +++ b/include/sp/extensions/Compress.h @@ -16,14 +16,7 @@ namespace zlib { * \param buffer the data to compress * \return the compressed data */ -DataBuffer Compress(const DataBuffer& buffer); - -/** - * \brief Reads the packet lenght and uncompress it - * \param buffer the data to uncompress - * \return the uncompressed data - */ -DataBuffer Decompress(DataBuffer& buffer); +DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold = 64); /** * \brief Uncompress some data diff --git a/include/sp/io/IOInterfaceImpl.inl b/include/sp/io/IOInterfaceImpl.inl index 501c30a..be73ff0 100644 --- a/include/sp/io/IOInterfaceImpl.inl +++ b/include/sp/io/IOInterfaceImpl.inl @@ -1,6 +1,7 @@ #pragma once #include +#include namespace sp { namespace io { @@ -15,8 +16,8 @@ template void Stream::SendMessage(const MessageBase& a_Message) { // TODO: process compress + encryption DataBuffer data = a_Message.Write(); - DataBuffer dataSize; - m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data); + DataBuffer compressed = zlib::Compress(data); + m_Interface.Write(compressed); } template @@ -60,6 +61,8 @@ void Stream::RecieveMessages() { // TODO: process compress + encryption + buffer = zlib::Decompress(buffer, lenghtValue); + VarInt packetType; buffer >> packetType; diff --git a/src/sp/extensions/Compress.cpp b/src/sp/extensions/Compress.cpp index 8995f3f..2af8ac0 100644 --- a/src/sp/extensions/Compress.cpp +++ b/src/sp/extensions/Compress.cpp @@ -4,8 +4,6 @@ #include #include -#define COMPRESSION_THRESHOLD 64 - namespace sp { namespace zlib { @@ -13,8 +11,8 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz DataBuffer result; result.Resize(uncompressedSize); - uncompress(reinterpret_cast(result.data()), reinterpret_cast(&uncompressedSize), - reinterpret_cast(source), static_cast(size)); + uncompress(static_cast(result.data()), static_cast(&uncompressedSize), static_cast(source), + static_cast(size)); assert(result.GetSize() == uncompressedSize); return result; @@ -25,20 +23,19 @@ static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) { uLongf compressedSize = size; result.Resize(size); // Resize for the compressed data to fit into - compress( - reinterpret_cast(result.data()), &compressedSize, reinterpret_cast(source), static_cast(size)); + compress(static_cast(result.data()), &compressedSize, static_cast(source), static_cast(size)); result.Resize(compressedSize); // Resize to cut useless data return result; } -DataBuffer Compress(const DataBuffer& buffer) { +DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) { DataBuffer packet; - if (buffer.GetSize() < COMPRESSION_THRESHOLD) { + if (buffer.GetSize() < a_CompressionThreshold) { // Don't compress since it's a small packet VarInt compressedDataLength = 0; - std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize(); + VarInt packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize(); packet << packetLength; packet << compressedDataLength; @@ -49,11 +46,19 @@ DataBuffer Compress(const DataBuffer& buffer) { DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize()); VarInt uncompressedDataLength = buffer.GetSize(); - std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize(); + VarInt packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize(); packet << packetLength; - packet << uncompressedDataLength; - packet.WriteSome(compressedData.data(), compressedData.GetSize()); + + if (compressedData.GetSize() >= buffer.GetSize()) { + // the compression is overkill so we don't send the compressed buffer + packet << VarInt{0}; + packet << buffer; + } else { + packet << uncompressedDataLength; + packet << compressedData; + } + return packet; } @@ -75,12 +80,5 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) { return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue()); } -DataBuffer Decompress(DataBuffer& buffer) { - std::uint64_t packetLength; - buffer >> packetLength; - - return Decompress(buffer, packetLength); -} - } // namespace zlib } // namespace sp diff --git a/test/test_file.cpp b/test/test_file.cpp index f05d696..0d92188 100644 --- a/test/test_file.cpp +++ b/test/test_file.cpp @@ -28,7 +28,7 @@ int main() { stream.SendMessage(KeepAlivePacket{96}); stream.SendMessage(KeepAlivePacket{69}); - stream.SendMessage(DisconnectPacket{"This is in the file !"}); + stream.SendMessage(DisconnectPacket{"This is in the fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"}); stream.RecieveMessages(); diff --git a/test/test_io.cpp b/test/test_io.cpp index 8bb78ec..c8affb6 100644 --- a/test/test_io.cpp +++ b/test/test_io.cpp @@ -33,7 +33,7 @@ int main() { stream.SendMessage(KeepAlivePacket{69}); stream.RecieveMessages(); - stream.SendMessage(DisconnectPacket{"I don't know"}); + stream.SendMessage(DisconnectPacket{"A valid reason"}); stream.RecieveMessages(); return 0;