7 Commits

Author SHA1 Message Date
37ff881819 add file input/output 2025-02-26 00:07:06 +01:00
a2eb10ec6d add file interface 2025-02-26 00:01:51 +01:00
132c3c3c8d add base io interface 2025-02-25 23:25:12 +01:00
8a5286d0ce move ArrayFiller in its own file 2025-02-25 20:30:26 +01:00
a194774925 add MessageDispatcher 2025-02-25 20:29:59 +01:00
8f32b09b17 add extensions include 2025-02-25 18:29:55 +01:00
60bb4ea06e begin compression module 2025-02-25 14:06:56 +01:00
17 changed files with 594 additions and 48 deletions

View File

@@ -6,12 +6,13 @@ enum PacketId {
UpgradeTower,
};
#include <examples/KeepAlivePacket.h>
#include <examples/DisconnectPacket.h>
#include <examples/KeepAlivePacket.h>
#include <examples/UpgradeTowerPacket.h>
// they must be in the same order as in the enum !
using AllPackets = std::tuple<KeepAlivePacket, DisconnectPacket, UpgradeTowerPacket>;
#include <sp/default/DefaultPacketHandler.h>
#include <sp/default/DefaultPacketFactory.h>
#include <sp/default/DefaultPacketDispatcher.h>
#include <sp/default/DefaultPacketFactory.h>
#include <sp/default/DefaultPacketHandler.h>

View File

@@ -0,0 +1,15 @@
#pragma once
#include <sp/default/DefaultPacket.h>
#include <sp/default/DefaultPacketHandler.h>
#include <sp/protocol/MessageDispatcher.h>
namespace sp {
using PacketDispatcher = MessageDispatcher<
PacketMessage::ParsedOptions::MsgIdType,
PacketMessage,
PacketMessage::ParsedOptions::HandlerType::HandlerT
>;
} // namespace sp

View File

@@ -0,0 +1,37 @@
#pragma once
/**
* \file Compression.h
* \brief File containing compress utilities
*/
#include <cstdint>
#include <sp/common/DataBuffer.h>
namespace sp {
namespace zlib {
/**
* \brief Compress some data
* \param buffer the data to compress
* \return the compressed data
*/
DataBuffer Compress(const DataBuffer& buffer);
/**
* \brief Reads the packet lenght and uncompress it
* \param buffer the data to uncompress
* \return the uncompressed data
*/
DataBuffer Decompress(DataBuffer& buffer);
/**
* \brief Uncompress some data
* \param buffer the data to uncompress
* \param packetLength lenght of data
* \return the uncompressed data
*/
DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength);
} // namespace zlib
} // namespace sp

View File

@@ -0,0 +1,5 @@
#pragma once
#if __has_include(<sp/extensions/Compress.h>)
#include <sp/extensions/Compress.h>
#endif

42
include/sp/io/FileIO.h Normal file
View File

@@ -0,0 +1,42 @@
#pragma once
#include <fstream>
#include <sp/io/IOInterface.h>
namespace sp {
struct FileTag {};
template <>
class IOInterface<FileTag> {
private:
std::unique_ptr<std::ifstream> m_FileInput;
std::unique_ptr<std::ofstream> m_FileOutput;
public:
IOInterface(const std::string& fileInput, const std::string& fileOutput) {
if (!fileInput.empty())
m_FileInput = std::make_unique<std::ifstream>(fileInput);
if (!fileOutput.empty())
m_FileOutput = std::make_unique<std::ofstream>(fileOutput);
}
IOInterface(IOInterface&& other) : m_FileOutput(std::move(other.m_FileOutput)), m_FileInput(std::move(other.m_FileInput)) {}
DataBuffer Read(std::size_t a_Amount) {
DataBuffer buffer;
buffer.Resize(a_Amount);
assert(m_FileInput != nullptr);
m_FileInput->read(reinterpret_cast<char*>(buffer.data()), a_Amount);
return buffer;
}
void Write(const sp::DataBuffer& a_Data) {
assert(m_FileOutput != nullptr);
m_FileOutput->write(reinterpret_cast<const char*>(a_Data.data()), a_Data.GetSize());
m_FileOutput->flush();
}
};
using FileIO = IOInterface<FileTag>;
} // namespace sp

101
include/sp/io/IOInterface.h Normal file
View File

@@ -0,0 +1,101 @@
#pragma once
#include <memory>
#include <sp/common/DataBuffer.h>
namespace sp {
template <typename IOTag>
class IOInterface {
public:
DataBuffer Read(std::size_t a_Amount);
void Write(const DataBuffer& a_Data);
};
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
class IOStream {
protected:
MessageDispatcher m_Dispatcher;
IOInterface<IOTag> m_Interface;
using MessageBase = typename MessageDispatcher::MessageBaseType;
using MsgIdType = typename MessageBase::MsgIdType;
public:
IOStream() {}
IOStream(IOInterface<IOTag>&& a_Interface) : m_Interface(std::move(a_Interface)) {}
IOStream(IOStream&& a_Stream) : m_Dispatcher(std::move(a_Stream.m_Dispatcher)), m_Interface(std::move(a_Stream.m_Interface)) {}
void RecieveMessages();
void SendMessage(const MessageBase& a_Message);
MessageDispatcher& GetDispatcher() {
return m_Dispatcher;
}
};
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
void IOStream<IOTag, MessageDispatcher, MessageFactory>::SendMessage(const MessageBase& a_Message) {
// TODO: process compress + encryption
DataBuffer data = a_Message.Write();
DataBuffer dataSize;
m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data);
}
template <typename IOTag, typename MessageDispatcher, typename MessageFactory>
void IOStream<IOTag, MessageDispatcher, MessageFactory>::RecieveMessages() {
// TODO: process compress + encryption
while (true) {
// reading the first VarInt part byte by byte
std::uint64_t lenghtValue = 0;
unsigned int readPos = 0;
while (true) {
static constexpr int SEGMENT_BITS = (1 << 7) - 1;
static constexpr int CONTINUE_BIT = 1 << 7;
DataBuffer buffer = m_Interface.Read(sizeof(std::uint8_t));
// if non-blocking call
if (buffer.GetSize() == 0)
return;
std::uint8_t part;
buffer >> part;
lenghtValue |= static_cast<std::uint64_t>(part & SEGMENT_BITS) << readPos;
if ((part & CONTINUE_BIT) == 0)
break;
readPos += 7;
if (readPos >= 8 * sizeof(lenghtValue))
throw std::runtime_error("VarInt is too big");
}
// nothing to read
if (lenghtValue == 0)
return;
DataBuffer buffer;
buffer = m_Interface.Read(lenghtValue);
// TODO: process compress + encryption
MsgIdType packetType;
buffer >> packetType;
static const MessageFactory messageFactory;
std::unique_ptr<MessageBase> message = messageFactory.CreateMessage(packetType);
assert(message != nullptr);
message->Read(buffer);
GetDispatcher().Dispatch(*message);
}
}
} // namespace sp

View File

@@ -0,0 +1 @@
#pragma once

View File

@@ -0,0 +1,57 @@
#pragma once
/**
* \file MessageDispatcher.h
* \brief File containing the sp::MessageDispatcher class
*/
#include <map>
#include <memory>
namespace sp {
/**
* \class MessageDispatcher
* \brief Class used to dispatch messages
*/
template <typename MessageIdType, typename MessageBase, typename MessageHandler>
class MessageDispatcher {
private:
std::map<MessageIdType, std::vector<std::shared_ptr<MessageHandler>>> m_Handlers;
public:
using MessageBaseType = MessageBase;
/**
* \brief Constructor
*/
MessageDispatcher() {}
/**
* \brief Dispatch a packet
* \param packet The packet to dispatch
*/
void Dispatch(const MessageBase& a_Message);
/**
* \brief Register a packet handler
* \param type The packet type
* \param handler The packet handler
*/
void RegisterHandler(MessageIdType a_MessageType, const std::shared_ptr<MessageHandler>& a_Handler);
/**
* \brief Unregister a packet handler
* \param type The packet type
* \param handler The packet handler
*/
void UnregisterHandler(MessageIdType a_MessageType, const std::shared_ptr<MessageHandler>& a_Handler);
/**
* \brief Unregister a packet handler
* \param handler The packet handler
*/
void UnregisterHandler(const std::shared_ptr<MessageHandler>& a_Handler);
};
#include <sp/protocol/message/MessageDispatcherImpl.inl>
} // namespace sp

View File

@@ -7,42 +7,7 @@
namespace sp {
namespace details {
template <typename TBase>
using ArrayType = std::vector<std::function<std::unique_ptr<TBase>(void)>>;
template <typename TBase, typename... TMessages>
struct ArrayFiller {};
template <typename TBase, typename... TMessages>
struct ArrayFiller<TBase, std::tuple<TMessages...>> {
static ArrayType<TBase> ArrayCreate() {
ArrayType<TBase> array;
array.reserve(sizeof...(TMessages));
ArrayFiller<TBase, TMessages...>::ArrayAppend(array);
return array;
}
};
template <typename TBase, typename TMessage, typename... TMessages>
struct ArrayFiller<TBase, TMessage, TMessages...> {
static void ArrayAppend(details::ArrayType<TBase>& array) {
ArrayFiller<TBase, TMessage>::ArrayAppend(array);
ArrayFiller<TBase, TMessages...>::ArrayAppend(array);
}
};
template <typename TBase, typename TMessage>
struct ArrayFiller<TBase, TMessage> {
static void ArrayAppend(details::ArrayType<TBase>& array) {
array.push_back([]() -> std::unique_ptr<TBase> { return std::make_unique<TMessage>(); });
}
};
} // namespace details
#include <sp/protocol/message/ArrayFillerImpl.inl>
template <typename TBase, typename TTMessages>
class MessageFactory {
@@ -51,7 +16,7 @@ class MessageFactory {
MessageFactory() : m_Factory(details::ArrayFiller<TBase, TTMessages>::ArrayCreate()) {}
std::unique_ptr<TBase> CreateMessage(IdType id) {
std::unique_ptr<TBase> CreateMessage(IdType id) const {
if (id >= m_Factory.size())
return nullptr;
return m_Factory.at(id)();

View File

@@ -0,0 +1,38 @@
#pragma once
namespace details {
template <typename TBase>
using ArrayType = std::vector<std::function<std::unique_ptr<TBase>(void)>>;
template <typename TBase, typename... TMessages>
struct ArrayFiller {};
template <typename TBase, typename... TMessages>
struct ArrayFiller<TBase, std::tuple<TMessages...>> {
static ArrayType<TBase> ArrayCreate() {
ArrayType<TBase> array;
array.reserve(sizeof...(TMessages));
ArrayFiller<TBase, TMessages...>::ArrayAppend(array);
return array;
}
};
template <typename TBase, typename TMessage, typename... TMessages>
struct ArrayFiller<TBase, TMessage, TMessages...> {
static void ArrayAppend(details::ArrayType<TBase>& array) {
ArrayFiller<TBase, TMessage>::ArrayAppend(array);
ArrayFiller<TBase, TMessages...>::ArrayAppend(array);
}
};
template <typename TBase, typename TMessage>
struct ArrayFiller<TBase, TMessage> {
static void ArrayAppend(details::ArrayType<TBase>& array) {
array.push_back([]() -> std::unique_ptr<TBase> { return std::make_unique<TMessage>(); });
}
};
} // namespace details

View File

@@ -0,0 +1,34 @@
#pragma once
template <typename MessageIdType, typename MessageBase, typename MessageHandler>
void MessageDispatcher<MessageIdType, MessageBase, MessageHandler>::RegisterHandler(MessageIdType a_MessageType, const std::shared_ptr<MessageHandler>& a_Handler) {
auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler);
if (found == m_Handlers[a_MessageType].end())
m_Handlers[a_MessageType].push_back(a_Handler);
}
template <typename MessageIdType, typename MessageBase, typename MessageHandler>
void MessageDispatcher<MessageIdType, MessageBase, MessageHandler>::UnregisterHandler(MessageIdType a_MessageType, const std::shared_ptr<MessageHandler>& a_Handler) {
auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler);
if (found != m_Handlers[a_MessageType].end())
m_Handlers[a_MessageType].erase(found);
}
template <typename MessageIdType, typename MessageBase, typename MessageHandler>
void MessageDispatcher<MessageIdType, MessageBase, MessageHandler>::UnregisterHandler(const std::shared_ptr<MessageHandler>& a_Handler) {
for (auto& pair : m_Handlers) {
if (pair.second.empty())
continue;
MessageIdType type = pair.first;
m_Handlers[type].erase(std::remove(m_Handlers[type].begin(), m_Handlers[type].end(), a_Handler), m_Handlers[type].end());
}
}
template <typename MessageIdType, typename MessageBase, typename MessageHandler>
void MessageDispatcher<MessageIdType, MessageBase, MessageHandler>::Dispatch(const MessageBase& a_Message) {
MessageIdType type = a_Message.GetId();
for (auto& handler : m_Handlers[type])
a_Message.Dispatch(*handler);
}

View File

@@ -75,6 +75,13 @@ class MessageInterfaceWriteBase : public TBase {
WriteImpl(buffer);
}
// helper
DataBuffer Write() const {
DataBuffer buffer;
this->Write(buffer);
return buffer;
}
protected:
virtual void WriteImpl(DataBuffer& buffer) const = 0;
};
@@ -113,6 +120,13 @@ class MessageInterfaceWriteIdBase : public TBase {
this->WriteData(this->GetId(), buffer);
this->WriteImpl(buffer);
}
// helper
DataBuffer Write() const {
DataBuffer buffer;
this->Write(buffer);
return buffer;
}
};
} // namespace details

View File

@@ -0,0 +1,86 @@
#include <sp/extensions/Compress.h>
#include <cassert>
#include <sp/common/VarInt.h>
#include <zlib.h>
#define COMPRESSION_THRESHOLD 64
namespace sp {
namespace zlib {
static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::size_t uncompressedSize) {
DataBuffer result;
result.Resize(uncompressedSize);
uncompress(reinterpret_cast<Bytef*>(result.data()), reinterpret_cast<uLongf*>(&uncompressedSize),
reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
assert(result.GetSize() == uncompressedSize);
return result;
}
static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) {
DataBuffer result;
uLongf compressedSize = size;
result.Resize(size); // Resize for the compressed data to fit into
compress(
reinterpret_cast<Bytef*>(result.data()), &compressedSize, reinterpret_cast<const Bytef*>(source), static_cast<uLong>(size));
result.Resize(compressedSize); // Resize to cut useless data
return result;
}
DataBuffer Compress(const DataBuffer& buffer) {
DataBuffer packet;
if (buffer.GetSize() < COMPRESSION_THRESHOLD) {
// Don't compress since it's a small packet
VarInt compressedDataLength = 0;
std::uint64_t packetLength = compressedDataLength.GetSerializedLength() + buffer.GetSize();
packet << packetLength;
packet << compressedDataLength;
packet << buffer;
return packet;
}
DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize());
VarInt uncompressedDataLength = buffer.GetSize();
std::uint64_t packetLength = uncompressedDataLength.GetSerializedLength() + compressedData.GetSize();
packet << packetLength;
packet << uncompressedDataLength;
packet.WriteSome(compressedData.data(), compressedData.GetSize());
return packet;
}
DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
VarInt uncompressedLength;
buffer >> uncompressedLength;
std::uint64_t compressedLength = packetLength - uncompressedLength.GetSerializedLength();
if (uncompressedLength.GetValue() == 0) {
// Data already uncompressed. Nothing to do
DataBuffer ret;
buffer.ReadSome(ret, compressedLength);
return ret;
}
assert(buffer.GetReadOffset() + compressedLength <= buffer.GetSize());
return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue());
}
DataBuffer Decompress(DataBuffer& buffer) {
std::uint64_t packetLength;
buffer >> packetLength;
return Decompress(buffer, packetLength);
}
} // namespace zlib
} // namespace sp

37
test/test_file.cpp Normal file
View File

@@ -0,0 +1,37 @@
#include <iostream>
#include <examples/PacketExample.h>
#include <sp/io/FileIO.h>
class CustomPacketHandler : public sp::PacketHandler {
void Handle(const KeepAlivePacket& packet) {
std::cout << "KeepAlive handled ! " << packet.GetKeepAliveId() << "\n";
}
void Handle(const DisconnectPacket& packet) {
std::cout << "Disconnect handled ! " << packet.GetReason() << "\n";
}
void Handle(const UpgradeTowerPacket& packet) {
std::cout << "UpgradeTower handled !\n";
}
};
using FileStream = sp::IOStream<sp::FileTag, sp::PacketDispatcher, sp::PacketFactory>;
int main() {
auto handler = std::make_shared<CustomPacketHandler>();
FileStream stream(sp::FileIO{"test.txt", "text.txt"});
stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler);
stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler);
stream.SendMessage(KeepAlivePacket{96});
stream.SendMessage(KeepAlivePacket{69});
stream.SendMessage(DisconnectPacket{"This is in the file !"});
stream.RecieveMessages();
return 0;
}

62
test/test_io.cpp Normal file
View File

@@ -0,0 +1,62 @@
#include <iostream>
#include <examples/PacketExample.h>
#include <sp/io/IOInterface.h>
struct DBTag {};
template <>
class sp::IOInterface<DBTag> {
private:
sp::DataBuffer m_VirtualIO;
public:
sp::DataBuffer Read(std::size_t a_Amount) {
// since we are just testing it, we ignore reads that overflows
if (m_VirtualIO.GetReadOffset() + a_Amount > m_VirtualIO.GetSize())
return {};
DataBuffer data;
m_VirtualIO.ReadSome(data, a_Amount);
return data;
}
void Write(const sp::DataBuffer& a_Data) {
m_VirtualIO << a_Data;
}
};
using DataBufferStream = sp::IOStream<DBTag, sp::PacketDispatcher, sp::PacketFactory>;
class CustomPacketHandler : public sp::PacketHandler {
void Handle(const KeepAlivePacket& packet) {
std::cout << "KeepAlive handled ! " << packet.GetKeepAliveId() << "\n";
}
void Handle(const DisconnectPacket& packet) {
std::cout << "Disconnect handled ! " << packet.GetReason() << "\n";
}
void Handle(const UpgradeTowerPacket& packet) {
std::cout << "UpgradeTower handled !\n";
}
};
int main() {
auto handler = std::make_shared<CustomPacketHandler>();
DataBufferStream stream;
stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler);
// this should not be dispatched
stream.SendMessage(KeepAlivePacket{96});
stream.RecieveMessages();
stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler);
stream.SendMessage(KeepAlivePacket{69});
stream.RecieveMessages(); // here, it's non-blocking
stream.SendMessage(DisconnectPacket{"I don't know"});
stream.RecieveMessages(); // here, it's non-blocking
return 0;
}

View File

@@ -3,6 +3,9 @@
#include <examples/PacketExample.h>
#include <memory>
#include <sp/default/DefaultPacketDispatcher.h>
#include <sp/extensions/Extensions.h>
class KeepAliveHandler : public sp::PacketHandler {
void Handle(const KeepAlivePacket& packet) {
std::cout << "KeepAlive handled !\n";
@@ -22,11 +25,10 @@ int main() {
sp::PacketMessage* msg = upgradeTower.get();
KeepAliveHandler handler;
msg->Dispatch(handler);
auto handler = std::make_shared<KeepAliveHandler>();
msg->Dispatch(*handler);
sp::DataBuffer buffer;
msg->Write(buffer);
sp::DataBuffer buffer = msg->Write();
std::uint8_t msgId;
buffer >> msgId;
@@ -34,7 +36,7 @@ int main() {
auto upgradeTower2 = std::make_unique<UpgradeTowerPacket>();
upgradeTower2->Read(buffer);
std::cout << "Test : " << (unsigned) upgradeTower2->GetTowerId() << "\n";
std::cout << "Test : " << (unsigned)upgradeTower2->GetTowerId() << "\n";
sp::PacketFactory factory;
auto packet = factory.CreateMessage(msgId);
@@ -43,7 +45,13 @@ int main() {
return 1;
}
std::cout << (unsigned)packet->GetId() << std::endl;
packet->Dispatch(handler);
packet->Dispatch(*handler);
sp::PacketDispatcher dispatcher;
dispatcher.RegisterHandler(PacketId::KeepAlive, handler);
dispatcher.Dispatch(*packet);
dispatcher.UnregisterHandler(PacketId::KeepAlive, handler);
dispatcher.UnregisterHandler(handler);
return 0;
}

View File

@@ -2,11 +2,54 @@ add_rules("mode.debug", "mode.release")
set_languages("c++17")
local modules = {
Compression = {
Option = "zlib",
Deps = {"zlib"},
Packages = {"zlib"},
Includes = {"include/(sp/extensions/Compress.h)"},
Sources = {"src/sp/extensions/Compress.cpp"}
}
}
-- Map modules to options
for name, module in table.orderpairs(modules) do
if module.Option then
option(module.Option, { description = "Enables the " .. name .. " module", default = true, category = "Modules" })
end
end
-- Add modules requirements
for name, module in table.orderpairs(modules) do
if module.Deps then
add_requires(module.Deps)
end
end
-- Add modules targets
for name, module in table.orderpairs(modules) do
if module.Deps and has_config(module.Option) then
target("SimpleProtocolLib-" .. name)
add_includedirs("include")
for _, include in table.orderpairs(module.Includes) do
add_headerfiles(include)
end
for _, source in table.orderpairs(module.Sources) do
add_files(source)
end
for _, package in table.orderpairs(module.Packages) do
add_packages(package)
end
set_group("Library")
set_kind("$(kind)")
end
end
target("SimpleProtocolLib")
add_includedirs("include")
add_headerfiles("include/(sp/**.h)")
add_headerfiles("include/(sp/common/**.h)", "include/(sp/common/**.h)", "include/(sp/common/**.h)")
set_group("Library")
add_files("src/sp/**.cpp")
add_files("src/sp/common/*.cpp")
set_kind("$(kind)")
-- Tests