diff --git a/include/examples/PacketExample.h b/include/examples/PacketExample.h index c185c8f..bf51532 100644 --- a/include/examples/PacketExample.h +++ b/include/examples/PacketExample.h @@ -6,12 +6,13 @@ enum PacketId { UpgradeTower, }; -#include #include +#include #include // they must be in the same order as in the enum ! using AllPackets = std::tuple; -#include -#include \ No newline at end of file +#include +#include +#include \ No newline at end of file diff --git a/include/sp/io/IOInterface.h b/include/sp/io/IOInterface.h new file mode 100644 index 0000000..31b96a7 --- /dev/null +++ b/include/sp/io/IOInterface.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +namespace sp { + +template +class IOInterface { + public: + DataBuffer Read(std::size_t a_Amount); + void Write(const DataBuffer& a_Data); +}; + +template +class IOStream { + protected: + MessageDispatcher m_Dispatcher; + IOInterface m_Interface; + + using MessageBase = typename MessageDispatcher::MessageBaseType; + using MsgIdType = typename MessageBase::MsgIdType; + + public: + IOStream() {} + IOStream(IOInterface&& a_Interface) : m_Interface(std::move(a_Interface)) {} + IOStream(IOStream&& a_Stream) : m_Dispatcher(std::move(a_Stream.m_Dispatcher)), m_Interface(std::move(a_Stream.m_Interface)) {} + + void RecieveMessages(); + void SendMessage(const MessageBase& a_Message); + + MessageDispatcher& GetDispatcher() { + return m_Dispatcher; + } +}; + +template +void IOStream::SendMessage(const MessageBase& a_Message) { + // TODO: process compress + encryption + DataBuffer data = a_Message.Write(); + DataBuffer dataSize; + m_Interface.Write(dataSize << sp::VarInt{data.GetSize()} << data); +} + +template +void IOStream::RecieveMessages() { + // TODO: process compress + encryption + while (true) { + + // reading the first VarInt part byte by byte + std::uint64_t lenghtValue = 0; + unsigned int readPos = 0; + + while (true) { + static constexpr int SEGMENT_BITS = (1 << 7) - 1; + static constexpr int CONTINUE_BIT = 1 << 7; + + DataBuffer buffer = m_Interface.Read(sizeof(std::uint8_t)); + + // if non-blocking call + if (buffer.GetSize() == 0) + return; + + std::uint8_t part; + buffer >> part; + lenghtValue |= static_cast(part & SEGMENT_BITS) << readPos; + + if ((part & CONTINUE_BIT) == 0) + break; + + readPos += 7; + + if (readPos >= 8 * sizeof(lenghtValue)) + throw std::runtime_error("VarInt is too big"); + } + + + DataBuffer buffer; + buffer = m_Interface.Read(lenghtValue); + + // TODO: process compress + encryption + + MsgIdType packetType; + buffer >> packetType; + + static const MessageFactory messageFactory; + + std::unique_ptr message = messageFactory.CreateMessage(packetType); + + assert(message != nullptr); + + message->Read(buffer); + + GetDispatcher().Dispatch(*message); + } +} + +} // namespace sp diff --git a/include/sp/io/IOInterfaceImpl.inl b/include/sp/io/IOInterfaceImpl.inl new file mode 100644 index 0000000..6f70f09 --- /dev/null +++ b/include/sp/io/IOInterfaceImpl.inl @@ -0,0 +1 @@ +#pragma once diff --git a/include/sp/protocol/MessageDispatcher.h b/include/sp/protocol/MessageDispatcher.h index eb346e6..cd6b5cd 100644 --- a/include/sp/protocol/MessageDispatcher.h +++ b/include/sp/protocol/MessageDispatcher.h @@ -6,6 +6,7 @@ */ #include +#include namespace sp { @@ -19,6 +20,8 @@ class MessageDispatcher { std::map>> m_Handlers; public: + using MessageBaseType = MessageBase; + /** * \brief Constructor */ @@ -51,4 +54,4 @@ class MessageDispatcher { #include -} // namespace blitz +} // namespace sp diff --git a/include/sp/protocol/MessageFactory.h b/include/sp/protocol/MessageFactory.h index 3212dd9..20742c4 100644 --- a/include/sp/protocol/MessageFactory.h +++ b/include/sp/protocol/MessageFactory.h @@ -16,7 +16,7 @@ class MessageFactory { MessageFactory() : m_Factory(details::ArrayFiller::ArrayCreate()) {} - std::unique_ptr CreateMessage(IdType id) { + std::unique_ptr CreateMessage(IdType id) const { if (id >= m_Factory.size()) return nullptr; return m_Factory.at(id)(); diff --git a/include/sp/protocol/message/MessageInterfacesImpl.h b/include/sp/protocol/message/MessageInterfacesImpl.h index 5f464a8..e04f054 100644 --- a/include/sp/protocol/message/MessageInterfacesImpl.h +++ b/include/sp/protocol/message/MessageInterfacesImpl.h @@ -75,6 +75,13 @@ class MessageInterfaceWriteBase : public TBase { WriteImpl(buffer); } + // helper + DataBuffer Write() const { + DataBuffer buffer; + this->Write(buffer); + return buffer; + } + protected: virtual void WriteImpl(DataBuffer& buffer) const = 0; }; @@ -113,6 +120,13 @@ class MessageInterfaceWriteIdBase : public TBase { this->WriteData(this->GetId(), buffer); this->WriteImpl(buffer); } + + // helper + DataBuffer Write() const { + DataBuffer buffer; + this->Write(buffer); + return buffer; + } }; } // namespace details diff --git a/test/test_io.cpp b/test/test_io.cpp new file mode 100644 index 0000000..01ff612 --- /dev/null +++ b/test/test_io.cpp @@ -0,0 +1,62 @@ +#include + +#include +#include + +struct DBTag {}; + +template <> +class sp::IOInterface { + private: + sp::DataBuffer m_VirtualIO; + + public: + sp::DataBuffer Read(std::size_t a_Amount) { + // since we are just testing it, we ignore reads that overflows + if (m_VirtualIO.GetReadOffset() + a_Amount > m_VirtualIO.GetSize()) + return {}; + + DataBuffer data; + m_VirtualIO.ReadSome(data, a_Amount); + return data; + } + void Write(const sp::DataBuffer& a_Data) { + m_VirtualIO << a_Data; + } +}; + +using DataBufferStream = sp::IOStream; + +class CustomPacketHandler : public sp::PacketHandler { + void Handle(const KeepAlivePacket& packet) { + std::cout << "KeepAlive handled ! " << packet.GetKeepAliveId() << "\n"; + } + + void Handle(const DisconnectPacket& packet) { + std::cout << "Disconnect handled ! " << packet.GetReason() << "\n"; + } + + void Handle(const UpgradeTowerPacket& packet) { + std::cout << "UpgradeTower handled !\n"; + } +}; + +int main() { + auto handler = std::make_shared(); + + DataBufferStream stream; + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); + + // this should not be dispatched + stream.SendMessage(KeepAlivePacket{96}); + stream.RecieveMessages(); + + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + + stream.SendMessage(KeepAlivePacket{69}); + stream.RecieveMessages(); // here, it's non-blocking + stream.SendMessage(DisconnectPacket{"I don't know"}); + stream.RecieveMessages(); // here, it's non-blocking + + return 0; +} \ No newline at end of file diff --git a/test/test_packets.cpp b/test/test_packets.cpp index 3af7410..e89e171 100644 --- a/test/test_packets.cpp +++ b/test/test_packets.cpp @@ -3,8 +3,8 @@ #include #include -#include #include +#include class KeepAliveHandler : public sp::PacketHandler { void Handle(const KeepAlivePacket& packet) { @@ -28,8 +28,7 @@ int main() { auto handler = std::make_shared(); msg->Dispatch(*handler); - sp::DataBuffer buffer; - msg->Write(buffer); + sp::DataBuffer buffer = msg->Write(); std::uint8_t msgId; buffer >> msgId; @@ -37,7 +36,7 @@ int main() { auto upgradeTower2 = std::make_unique(); upgradeTower2->Read(buffer); - std::cout << "Test : " << (unsigned) upgradeTower2->GetTowerId() << "\n"; + std::cout << "Test : " << (unsigned)upgradeTower2->GetTowerId() << "\n"; sp::PacketFactory factory; auto packet = factory.CreateMessage(msgId);