diff --git a/include/sp/common/ByteSwapping.h b/include/sp/common/ByteSwapping.h index 19578a5..b7cf531 100644 --- a/include/sp/common/ByteSwapping.h +++ b/include/sp/common/ByteSwapping.h @@ -4,34 +4,13 @@ namespace sp { -/** - * \brief Serialize value to (network byte order) big endian - */ -template -void ToNetwork(T& value) {} +bool IsLittleEndian(); -template <> -void ToNetwork(std::uint16_t& value); +void SwapBytes(std::uint8_t* begin, std::uint8_t* end); -template <> -void ToNetwork(std::uint32_t& value); - -template <> -void ToNetwork(std::uint64_t& value); - -/** - * \brief Deserialize value from (network byte order) big endian - */ -template -void FromNetwork(T& value) {} - -template <> -void FromNetwork(std::uint16_t& value); - -template <> -void FromNetwork(std::uint32_t& value); - -template <> -void FromNetwork(std::uint64_t& value); +template +void SwapBytes(T& a_Data) { + SwapBytes(reinterpret_cast(&a_Data), reinterpret_cast(&a_Data) + sizeof(T)); +} } // namespace sp diff --git a/include/sp/common/DataBuffer.h b/include/sp/common/DataBuffer.h index 33e5697..b8d0b97 100644 --- a/include/sp/common/DataBuffer.h +++ b/include/sp/common/DataBuffer.h @@ -10,10 +10,11 @@ #include #include #include -#include -#include -#include #include +#include +#include +#include +#include namespace sp { @@ -23,16 +24,18 @@ namespace sp { */ class DataBuffer { private: - typedef std::vector Data; + using Data = std::vector; + + private: Data m_Buffer; std::size_t m_ReadOffset; public: - typedef Data::iterator iterator; - typedef Data::const_iterator const_iterator; - typedef Data::reference reference; - typedef Data::const_reference const_reference; - typedef Data::difference_type difference_type; + using iterator = Data::iterator; + using const_iterator = Data::const_iterator; + using reference = Data::reference; + using const_reference = Data::const_reference; + using difference_type = Data::difference_type; DataBuffer(); DataBuffer(std::size_t a_InitialSize); @@ -46,21 +49,23 @@ class DataBuffer { /** * \brief Append data to the buffer + * \warning No endian checks */ template - void Append(const T& data) { - std::size_t size = sizeof(data); + void Append(const T& a_Data) { + std::size_t size = sizeof(a_Data); std::size_t end_pos = m_Buffer.size(); m_Buffer.resize(m_Buffer.size() + size); - std::memcpy(&m_Buffer[end_pos], &data, size); + std::memcpy(&m_Buffer[end_pos], &a_Data, size); } /** - * \brief Append data to the buffer + * \brief Append data to the buffer (converted to big endian) */ template - DataBuffer& operator<<(const T& data) { - Append(data); + DataBuffer& operator<<(T a_Data) { + SwapBytes(a_Data); + Append(a_Data); return *this; } @@ -115,14 +120,24 @@ class DataBuffer { return *this; } + /** + * \brief Read data into a_Data + * \warning No endian checks + */ + template + void Read(T& a_Data) { + assert(m_ReadOffset + sizeof(T) <= GetSize()); + std::memcpy(&a_Data, m_Buffer.data() + m_ReadOffset, sizeof(T)); + m_ReadOffset += sizeof(T); + } + /** * \brief Read some data from the buffer and assign to desired variable */ template - DataBuffer& operator>>(T& data) { - assert(m_ReadOffset + sizeof(T) <= GetSize()); - data = *(reinterpret_cast(&m_Buffer[m_ReadOffset])); - m_ReadOffset += sizeof(T); + DataBuffer& operator>>(T& a_Data) { + Read(a_Data); + SwapBytes(a_Data); return *this; } diff --git a/include/sp/io/BitBuffer.h b/include/sp/io/BitBuffer.h new file mode 100644 index 0000000..fc68928 --- /dev/null +++ b/include/sp/io/BitBuffer.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +namespace sp { + +// TODO: flush if offset exceeds 64 +class BitBuffer { + private: + using Data = std::uint64_t; + + private: + DataBuffer& m_Buffer; + Data m_Data; + std::size_t m_Offset; + bool m_WasBitField; + + public: + BitBuffer(DataBuffer& a_Buffer) : m_Buffer(a_Buffer), m_Data(0), m_Offset(0), m_WasBitField(false) {} + + void UpdateWrite(bool a_IsBitField) { + if (m_Offset == 0) + return; + if ((m_Offset % 8 == 0) || (!a_IsBitField && m_WasBitField)) { + Flush(); + } + m_WasBitField = a_IsBitField; + } + + void UpdateRead(bool a_IsBitField) { + if (m_Offset == 0) + return; + if ((m_Offset % 8 == 0) || (!a_IsBitField && m_WasBitField)) { + MoveReadOffset(); + } + m_WasBitField = a_IsBitField; + } + + template + void Append(T a_Data) { + Data bin = static_cast(a_Data); + bin &= ((1 << BitSize) - 1); // prevents overflow + std::size_t pushCount = sizeof(Data) * 8 - m_Offset - BitSize; + m_Data |= bin << pushCount; + m_Offset += BitSize; + } + + template + void Read(T& a_Data) { + std::size_t byteCount = GetByteCount(m_Offset + BitSize); + constexpr Data dataMask = (1 << BitSize) - 1; + m_Buffer.ReadSome(reinterpret_cast(&m_Data), byteCount); + SwapBytes(reinterpret_cast(&m_Data), reinterpret_cast(&m_Data) + byteCount); + m_Data >>= byteCount * 8 - m_Offset - BitSize; + m_Data &= dataMask; + a_Data = T(m_Data); + m_Buffer.SetReadOffset(m_Buffer.GetReadOffset() - byteCount); + m_Offset += BitSize; + } + + private: + void Flush() { + std::size_t byteCount = GetByteCount(); + m_Data >>= (sizeof(Data) - byteCount) * 8; + SwapBytes(reinterpret_cast(&m_Data), reinterpret_cast(&m_Data) + byteCount); + m_Buffer.WriteSome(reinterpret_cast(&m_Data), byteCount); + m_Offset = 0; + m_WasBitField = false; + m_Data = 0; + } + + void MoveReadOffset() { + std::size_t byteCount = GetByteCount(); + m_Buffer.SetReadOffset(m_Buffer.GetReadOffset() + byteCount); + m_Offset = 0; + m_WasBitField = false; + m_Data = 0; + } + + std::size_t GetByteCount(std::size_t a_Offset = -1) const { + if (a_Offset == static_cast(-1)) + a_Offset = m_Offset; + if (a_Offset <= 8) + return 1; + return (a_Offset - 1) / 8 + 1; + } +}; + +} // namespace sp diff --git a/include/sp/io/MessageIO.h b/include/sp/io/MessageIO.h index 89f0ed5..74df617 100644 --- a/include/sp/io/MessageIO.h +++ b/include/sp/io/MessageIO.h @@ -1,78 +1,53 @@ #pragma once #include -#include -#include -#include +#include namespace sp { namespace details { +template +void WriteField(DataBuffer& a_Buffer, const BitField& a_Data, BitBuffer& a_BitBuffer) { + a_BitBuffer.Append(*a_Data); + a_BitBuffer.UpdateWrite(true); +} + template -void WriteBitField(DataBuffer& a_Buffer, std::uint64_t& a_DataRaw, std::size_t& a_Offset) { - T filled = static_cast(a_DataRaw); - ToNetwork(filled); - a_Buffer << filled; - a_Offset = 0; - a_DataRaw = 0; +void WriteField(DataBuffer& a_Buffer, const T& a_Data, BitBuffer& a_BitBuffer) { + a_Buffer << a_Data; + a_BitBuffer.UpdateWrite(false); } template -void WriteField(DataBuffer& a_Buffer, const BitField& a_Data, std::uint64_t& a_DataRaw, std::size_t& a_Offset) { - T cut = *a_Data & ((1 << a_Data.GetBitSize()) - 1); - std::size_t pushCount = sizeof(T) * 8 - a_Offset - a_Data.GetBitSize(); - a_DataRaw |= cut << pushCount; - a_Offset += a_Data.GetBitSize(); - if (a_Offset == sizeof(T) * 8) { - WriteBitField(a_Buffer, a_DataRaw, a_Offset); - } +void ReadField(DataBuffer& a_Buffer, BitField& a_Data, BitBuffer& a_BitBuffer) { + a_BitBuffer.Read(*a_Data); + a_BitBuffer.UpdateRead(true); } template -void WriteField(DataBuffer& a_Buffer, const T& a_Data, std::uint64_t& a_DataRaw, std::size_t& a_Offset) { - T swapped = a_Data; - ToNetwork(swapped); - a_Buffer << swapped; -} - -template -void ReadField(DataBuffer& a_Buffer, BitField& a_Data, std::size_t& a_Offset) { - a_Buffer >> *a_Data; - FromNetwork(*a_Data); - - *a_Data >>= sizeof(T) * 8 - a_Offset - a_Data.GetBitSize(); - *a_Data &= (1 << a_Data.GetBitSize()) - 1; - - if (a_Offset != sizeof(T) * 8) { - a_Buffer.SetReadOffset(a_Buffer.GetReadOffset() - sizeof(T)); - a_Offset += a_Data.GetBitSize(); - } else { - a_Offset = 0; - } -} - -template -void ReadField(DataBuffer& a_Buffer, T& a_Data, std::size_t& a_Offset) { +void ReadField(DataBuffer& a_Buffer, T& a_Data, BitBuffer& a_BitBuffer) { a_Buffer >> a_Data; - FromNetwork(a_Data); + a_BitBuffer.UpdateRead(false); } +} // namespace details + template DataBuffer WriteMessage(const TData& a_MessageData) { DataBuffer buffer; - std::size_t currentOffset = 0; - std::uint64_t dataRaw = 0; - boost::pfr::for_each_field(a_MessageData, - [&buffer, &dataRaw, ¤tOffset](const auto& a_Field) { WriteField(buffer, a_Field, dataRaw, currentOffset); }); + BitBuffer bitBuffer(buffer); + boost::pfr::for_each_field( + a_MessageData, [&buffer, &bitBuffer](const auto& a_Field) { details::WriteField(buffer, a_Field, bitBuffer); }); + bitBuffer.UpdateWrite(false); return buffer; } template void ReadMessage(DataBuffer& a_Buffer, TData& a_MessageData) { - std::size_t currentOffset = 0; + BitBuffer bitBuffer(a_Buffer); boost::pfr::for_each_field( - a_MessageData, [&a_Buffer, ¤tOffset](auto& a_Field) { ReadField(a_Buffer, a_Field, currentOffset); }); + a_MessageData, [&a_Buffer, &bitBuffer](auto& a_Field) { details::ReadField(a_Buffer, a_Field, bitBuffer); }); + bitBuffer.UpdateRead(false); } -} // namespace details } // namespace sp diff --git a/include/sp/protocol/ConcreteMessage.h b/include/sp/protocol/ConcreteMessage.h index ab96f8d..6c82965 100644 --- a/include/sp/protocol/ConcreteMessage.h +++ b/include/sp/protocol/ConcreteMessage.h @@ -27,11 +27,11 @@ class ConcreteMessage : public MessageBase { } virtual void Read(DataBuffer& a_Buffer) override { - details::ReadMessage(a_Buffer, m_Data); + ReadMessage(a_Buffer, m_Data); } virtual DataBuffer Write() const override { - return details::WriteMessage(m_Data); + return WriteMessage(m_Data); } DataType& operator*() { diff --git a/src/sp/common/ByteSwapping.cpp b/src/sp/common/ByteSwapping.cpp index 3a02546..d9a75fe 100644 --- a/src/sp/common/ByteSwapping.cpp +++ b/src/sp/common/ByteSwapping.cpp @@ -1,49 +1,21 @@ #include -#ifdef _WIN32 - -#include - -#else - -#include -#include - -#define htonll htobe64 -#define ntohll be64toh - -#endif +#include namespace sp { -template <> -void ToNetwork(std::uint16_t& value) { - value = htons(value); +bool IsLittleEndian() { +#ifdef SP_BIG_ENDIAN + return false; +#else + return true; +#endif } -template <> -void ToNetwork(std::uint32_t& value) { - value = htonl(value); -} - -template <> -void ToNetwork(std::uint64_t& value) { - value = htonll(value); -} - -template <> -void FromNetwork(std::uint16_t& value) { - value = ntohs(value); -} - -template <> -void FromNetwork(std::uint32_t& value) { - value = ntohl(value); -} - -template <> -void FromNetwork(std::uint64_t& value) { - value = ntohll(value); +void SwapBytes(std::uint8_t* begin, std::uint8_t* end) { + if (IsLittleEndian()) { + std::reverse(begin, end); + } } } // namespace sp diff --git a/test/test_message.cpp b/test/test_message.cpp index 65617a0..31b52e1 100644 --- a/test/test_message.cpp +++ b/test/test_message.cpp @@ -12,7 +12,7 @@ #include #include -enum class PacketID { KeepAlive = 0, MDC = 1 }; +enum class PacketID : std::uint8_t { KeepAlive = 0, MDC = 1 }; class PacketHandler; @@ -26,8 +26,13 @@ struct KeepAlivePacket { sp::BitField two; }; +struct MDCPacket { + sp::BitField one; + sp::BitField two; +}; + using KeepAliveMessage = Message; -using MDCMessage = Message; +using MDCMessage = Message; using AllMessages = std::tuple; @@ -39,7 +44,7 @@ class MyHandler : public PacketHandler { std::cout << "I recieved a keep alive : " << *msg->one << " : " << *msg->two << "\n"; } virtual void Handle(const MDCMessage& msg) override { - std::cout << "I recieved a keep alive : " << *msg->one << " : " << *msg->two << "\n"; + std::cout << "I recieved a mdc : " << *msg->one << " : " << static_cast(*msg->two) << "\n"; } }; @@ -57,6 +62,7 @@ int main() { MyHandler h; PacketDispatcher d; d.RegisterHandler(PacketID::KeepAlive, &h); + d.RegisterHandler(PacketID::MDC, &h); d.Dispatch(m); PacketFactory f; auto message = f.CreateMessage(PacketID::KeepAlive); @@ -73,6 +79,7 @@ int main() { PacketStream p(std::make_shared(file)); p.WriteMessage(m); + p.WriteMessage(MDCMessage{42, PacketID::MDC}); file.flush(); @@ -81,10 +88,10 @@ int main() { PacketStream p2(std::make_shared(file2)); auto message2 = p2.ReadMessage(); + auto message3 = p2.ReadMessage(); d.Dispatch(*message2); - - // Todo : verify bitfields + d.Dispatch(*message3); // message->Write(file); // file << std::endl; diff --git a/xmake.lua b/xmake.lua index e925610..29a13e5 100644 --- a/xmake.lua +++ b/xmake.lua @@ -1,5 +1,7 @@ add_rules("mode.debug", "mode.release") +includes("@builtin/check") + add_requires("boost_pfr") set_warnings("all") @@ -76,6 +78,7 @@ target("SimpleProtocol") set_group("Library") set_kind("$(kind)") add_packages("boost_pfr", {public = true}) + check_bigendian("SP_BIG_ENDIAN") add_headerfiles("include/(sp/**.h)", "include/(sp/**.inl)")