zlib support #9

Merged
Persson-dev merged 8 commits from zlib into main 2025-03-01 18:20:51 +00:00
9 changed files with 158 additions and 56 deletions

View File

@@ -166,7 +166,7 @@ class DataBuffer {
K newKey;
V newValue;
*this >> newKey >> newValue;
data.insert({newKey, newValue});
data.emplace(newKey, newValue);
}
return *this;
}

View File

@@ -21,6 +21,8 @@ class VarInt {
std::uint64_t m_Value;
public:
static const std::uint64_t MAX_VALUE = static_cast<std::uint64_t>(-1) >> 8;
VarInt() : m_Value(0) {}
/**
* \brief Construct a variable integer from a value

View File

@@ -8,6 +8,19 @@
#include <cstdint>
#include <sp/common/DataBuffer.h>
namespace sp {
namespace option {
struct ZlibCompress {
bool m_Enabled = true;
std::size_t m_CompressionThreshold = 64;
};
} // namespace option
} // namespace sp
#include <sp/io/IOInterface.h>
namespace sp {
namespace zlib {
@@ -16,14 +29,7 @@ namespace zlib {
* \param buffer the data to compress
* \return the compressed data
*/
DataBuffer Compress(const DataBuffer& buffer);
/**
* \brief Reads the packet lenght and uncompress it
* \param buffer the data to uncompress
* \return the uncompressed data
*/
DataBuffer Decompress(DataBuffer& buffer);
DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold = 64);
/**
* \brief Uncompress some data
@@ -34,4 +40,15 @@ DataBuffer Decompress(DataBuffer& buffer);
DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength);
} // namespace zlib
namespace io {
template <>
class MessageEncapsulator<option::ZlibCompress> {
public:
static DataBuffer Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option);
static DataBuffer Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option);
};
} // namespace io
} // namespace sp

View File

@@ -13,25 +13,43 @@ class IOInterface {
void Write(const DataBuffer& a_Data);
};
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
template <typename TOption>
class MessageEncapsulator {
public:
static DataBuffer Encapsulate(const DataBuffer& a_Data, const TOption& a_Option);
static DataBuffer Decapsulate(DataBuffer& a_Data, const TOption& a_Option);
};
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
class Stream {
protected:
MessageDispatcher m_Dispatcher;
IOInterface<IOTag> m_Interface;
std::tuple<TOptions...> m_Options;
using MessageBase = typename MessageDispatcher::MessageBaseType;
public:
Stream() {}
Stream(IOInterface<IOTag>&& a_Interface);
Stream(IOInterface<IOTag>&& a_Interface, TOptions&&... a_Options);
Stream(Stream&& a_Stream);
void RecieveMessages();
void SendMessage(const MessageBase& a_Message);
template <typename TOption>
TOption& GetOption() {
return std::get<TOption>(m_Options);
}
MessageDispatcher& GetDispatcher() {
return m_Dispatcher;
}
private:
static DataBuffer Encapsulate(const DataBuffer& a_Data, const TOptions&... a_Options);
static DataBuffer Decapsulate(DataBuffer& a_Data, const TOptions&... a_Options);
};
} // namespace io

View File

@@ -1,29 +1,64 @@
#pragma once
#include <sp/extensions/Compress.h>
#include <stdexcept>
namespace sp {
namespace io {
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
Stream<IOTag, MessageDispatcher, MessageFactory>::Stream(IOInterface<IOTag>&& a_Interface) : m_Interface(std::move(a_Interface)) {}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
Stream<IOTag, MessageDispatcher, MessageFactory>::Stream(Stream<IOTag, MessageDispatcher, MessageFactory>&& a_Stream) :
namespace details {
template <typename... TOptions>
struct MessageEncapsulatorPack {};
template <>
struct MessageEncapsulatorPack<> {
static DataBuffer Encapsulate(const DataBuffer& a_Data) {
return a_Data;
}
static DataBuffer Decapsulate(DataBuffer& a_Data) {
return a_Data;
}
};
template <typename TOption, typename... TOptions>
struct MessageEncapsulatorPack<TOption, TOptions...> {
static DataBuffer Encapsulate(const DataBuffer& a_Data, const TOption& a_Option, const TOptions&... a_Options) {
DataBuffer data = io::MessageEncapsulator<TOption>::Encapsulate(a_Data, a_Option);
return MessageEncapsulatorPack<TOptions...>::Encapsulate(data, a_Options...);
}
static DataBuffer Decapsulate(DataBuffer& a_Data, const TOption& a_Option, const TOptions&... a_Options) {
DataBuffer data = io::MessageEncapsulator<TOption>::Decapsulate(a_Data, a_Option);
return MessageEncapsulatorPack<TOptions...>::Decapsulate(data, a_Options...);
}
};
} // namespace details
namespace io {
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>::Stream(IOInterface<IOTag>&& a_Interface, TOptions&&... a_Options) :
m_Interface(std::move(a_Interface)), m_Options(std::make_tuple<TOptions...>(std::move(a_Options)...)) {}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>::Stream(
Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>&& a_Stream) :
m_Dispatcher(std::move(a_Stream.m_Dispatcher)), m_Interface(std::move(a_Stream.m_Interface)) {}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
void Stream<IOTag, MessageDispatcher, MessageFactory>::SendMessage(const MessageBase& a_Message) {
// TODO: process compress + encryption
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
void Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>::SendMessage(const MessageBase& a_Message) {
DataBuffer data = a_Message.Write();
DataBuffer dataSize;
m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data);
DataBuffer encapsulated = std::apply([&data](const auto&... a_Options) {
return Encapsulate(data, a_Options...);
}, m_Options);
DataBuffer finalData;
finalData << VarInt{encapsulated.GetSize()} << encapsulated;
m_Interface.Write(finalData);
}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
void Stream<IOTag, MessageDispatcher, MessageFactory>::RecieveMessages() {
// TODO: process compress + encryption
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
void Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>::RecieveMessages() {
while (true) {
// reading the first VarInt part byte by byte
std::uint64_t lenghtValue = 0;
unsigned int readPos = 0;
@@ -58,7 +93,9 @@ void Stream<IOTag, MessageDispatcher, MessageFactory>::RecieveMessages() {
DataBuffer buffer;
buffer = m_Interface.Read(lenghtValue);
// TODO: process compress + encryption
buffer = std::apply([&buffer, lenghtValue](const auto&... a_Options) {
return Decapsulate(buffer, a_Options...);
}, m_Options);
VarInt packetType;
buffer >> packetType;
@@ -75,5 +112,17 @@ void Stream<IOTag, MessageDispatcher, MessageFactory>::RecieveMessages() {
}
}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
DataBuffer Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>::Encapsulate(
const DataBuffer& a_Data, const TOptions&... a_Options) {
return details::MessageEncapsulatorPack<TOptions...>::Encapsulate(a_Data, a_Options...);
}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory, typename... TOptions>
DataBuffer Stream<IOTag, MessageDispatcher, MessageFactory, TOptions...>::Decapsulate(
DataBuffer& a_Data, const TOptions&... a_Options) {
return details::MessageEncapsulatorPack<TOptions...>::Decapsulate(a_Data, a_Options...);
}
} // namespace io
} // namespace sp

View File

@@ -4,8 +4,6 @@
#include <sp/common/VarInt.h>
#include <zlib.h>
#define COMPRESSION_THRESHOLD 64
namespace sp {
namespace zlib {
@@ -13,8 +11,8 @@ static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::siz
DataBuffer result;
result.Resize(uncompressedSize);
uncompress(reinterpret_cast<Bytef*>(result.data()), reinterpret_cast<uLongf*>(&uncompressedSize),
reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
uncompress(static_cast<Bytef*>(result.data()), static_cast<uLongf*>(&uncompressedSize), static_cast<const Bytef*>(source),
static_cast<uLong>(size));
assert(result.GetSize() == uncompressedSize);
return result;
@@ -25,35 +23,34 @@ static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) {
uLongf compressedSize = size;
result.Resize(size); // Resize for the compressed data to fit into
compress(
reinterpret_cast<Bytef*>(result.data()), &compressedSize, reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
compress(static_cast<Bytef*>(result.data()), &compressedSize, static_cast<const Bytef*>(source), static_cast<uLong>(size));
result.Resize(compressedSize); // Resize to cut useless data
return result;
}
DataBuffer Compress(const DataBuffer& buffer) {
DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) {
DataBuffer packet;
if (buffer.GetSize() < COMPRESSION_THRESHOLD) {
if (buffer.GetSize() < a_CompressionThreshold) {
// Don't compress since it's a small packet
VarInt compressedDataLength = 0;
std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize();
packet << packetLength;
packet << compressedDataLength;
packet << VarInt{0};
packet << buffer;
return packet;
}
DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize());
VarInt uncompressedDataLength = buffer.GetSize();
std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize();
packet << packetLength;
packet << uncompressedDataLength;
packet.WriteSome(compressedData.data(), compressedData.GetSize());
if (compressedData.GetSize() >= buffer.GetSize()) {
// the compression is overkill so we don't send the compressed buffer
packet << VarInt{0};
packet << buffer;
} else {
packet << uncompressedDataLength;
packet << compressedData;
}
return packet;
}
@@ -75,12 +72,18 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue());
}
DataBuffer Decompress(DataBuffer& buffer) {
std::uint64_t packetLength;
buffer >> packetLength;
} // namespace zlib
return Decompress(buffer, packetLength);
namespace io {
DataBuffer MessageEncapsulator<option::ZlibCompress>::Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option) {
static constexpr std::size_t MAX_COMPRESS_THRESHOLD = VarInt::MAX_VALUE;
return zlib::Compress(a_Data, a_Option.m_Enabled ? a_Option.m_CompressionThreshold : MAX_COMPRESS_THRESHOLD);
}
} // namespace zlib
DataBuffer MessageEncapsulator<option::ZlibCompress>::Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option) {
return zlib::Decompress(a_Data, a_Data.GetSize());
}
} // namespace io
} // namespace sp

View File

@@ -17,18 +17,24 @@ class CustomPacketHandler : public sp::PacketHandler {
}
};
using FileStream = sp::io::Stream<sp::io::FileTag, sp::PacketDispatcher, sp::PacketFactory>;
using FileStream = sp::io::Stream<sp::io::FileTag, sp::PacketDispatcher, sp::PacketFactory, sp::option::ZlibCompress>;
int main() {
auto handler = std::make_shared<CustomPacketHandler>();
FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out});
FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out}, {});
stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler);
stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler);
stream.SendMessage(KeepAlivePacket{96});
stream.SendMessage(KeepAlivePacket{69});
stream.SendMessage(DisconnectPacket{"This is in the file !"});
stream.SendMessage(DisconnectPacket{
"This is in the "
"fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"});
stream.GetOption<sp::option::ZlibCompress>().m_Enabled = false;
stream.SendMessage(DisconnectPacket{
"This is in the "
"fiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiile !"});
stream.RecieveMessages();

View File

@@ -33,7 +33,7 @@ int main() {
stream.SendMessage(KeepAlivePacket{69});
stream.RecieveMessages();
stream.SendMessage(DisconnectPacket{"I don't know"});
stream.SendMessage(DisconnectPacket{"A valid reason"});
stream.RecieveMessages();
return 0;

View File

@@ -29,7 +29,7 @@ end
-- Add modules targets
for name, module in table.orderpairs(modules) do
if module.Deps and has_config(module.Option) then
target("SimpleProtocolLib-" .. name)
target("SimpleProtocol-" .. name)
add_includedirs("include")
for _, include in table.orderpairs(module.Includes) do
add_headerfiles(include)
@@ -45,7 +45,7 @@ for name, module in table.orderpairs(modules) do
end
end
target("SimpleProtocolLib")
target("SimpleProtocol")
add_includedirs("include")
add_files("src/sp/**.cpp")
@@ -54,6 +54,13 @@ target("SimpleProtocolLib")
add_headerfiles("include/(sp/" .. folder .. "/**.h)")
end
-- adding extensions
for name, module in table.orderpairs(modules) do
if module.Deps and has_config(module.Option) then
add_deps("SimpleProtocol-" .. name)
end
end
-- we don't want extensions
remove_files("src/sp/extensions/**.cpp")
set_group("Library")
@@ -68,7 +75,7 @@ for _, file in ipairs(os.files("test/**.cpp")) do
add_files(file)
add_includedirs("include")
add_deps("SimpleProtocolLib")
add_deps("SimpleProtocol")
add_tests("compile_and_run")
end