v2.0 #15

Merged
Persson-dev merged 9 commits from v2.0 into main 2025-07-10 13:15:42 +00:00
11 changed files with 443 additions and 54 deletions
Showing only changes of commit ed0b06f78d - Show all commits

View File

@@ -6,20 +6,7 @@
*/
#include <cstdint>
#include <sp/common/DataBuffer.h>
namespace sp {
namespace option {
struct ZlibCompress {
bool m_Enabled = true;
std::size_t m_CompressionThreshold = 64;
};
} // namespace option
} // namespace sp
#include <sp/io/IOInterface.h>
#include <sp/io/MessageEncapsulator.h>
namespace sp {
namespace zlib {
@@ -41,14 +28,20 @@ DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength);
} // namespace zlib
namespace io {
template <>
class MessageEncapsulator<option::ZlibCompress> {
class ZlibCompress : public MessageEncapsulator {
private:
std::size_t m_CompressionThreshold;
public:
static DataBuffer Encapsulate(const DataBuffer& a_Data, const option::ZlibCompress& a_Option);
static DataBuffer Decapsulate(DataBuffer& a_Data, const option::ZlibCompress& a_Option);
ZlibCompress() : m_CompressionThreshold(64) {}
ZlibCompress(const ZlibCompress&) = default;
virtual ~ZlibCompress() {}
protected:
virtual DataBuffer EncapsulateImpl(const DataBuffer& a_Data) override;
virtual DataBuffer DecapsulateImpl(DataBuffer& a_Data) override;
};
} // namespace io
} // namespace sp
} // namespace sp

View File

@@ -1,6 +1,7 @@
#pragma once
#include <sp/extensions/tcp/TcpSocket.h>
#include <memory>
namespace sp {
namespace io {
@@ -10,7 +11,7 @@ namespace io {
*/
class TcpListener : private NonCopyable {
public:
using SocketHandle = TcpTag::SocketHandle;
using SocketHandle = TcpSocket::SocketHandle;
/**
* \brief Starts listening for guests to connect

View File

@@ -1,14 +1,12 @@
#pragma once
#include <sp/common/NonCopyable.h>
#include <sp/io/IOInterface.h>
#include <sp/io/IoInterface.h>
namespace sp {
namespace io {
struct TcpTag {
using SocketHandle = int;
};
class TcpListener;
class SocketError : public std::exception {
private:
@@ -22,10 +20,9 @@ class SocketError : public std::exception {
}
};
template <>
class IOInterface<TcpTag> : private NonCopyable {
class TcpSocket : public sp::IoInterface {
public:
using SocketHandle = TcpTag::SocketHandle;
using SocketHandle = int;
/**
* \enum Status
@@ -40,14 +37,14 @@ class IOInterface<TcpTag> : private NonCopyable {
Error,
};
IOInterface();
IOInterface(const std::string& a_Host, std::uint16_t a_Port);
IOInterface(IOInterface&& a_Other);
IOInterface& operator=(IOInterface&& a_Other);
virtual ~IOInterface();
TcpSocket();
TcpSocket(const std::string& a_Host, std::uint16_t a_Port);
TcpSocket(TcpSocket&& a_Other);
TcpSocket& operator=(TcpSocket&& a_Other);
virtual ~TcpSocket();
DataBuffer Read(std::size_t a_Amount);
void Write(const sp::DataBuffer& a_Data);
virtual DataBuffer Read(std::size_t a_Amount) override;
virtual void Write(const sp::DataBuffer& a_Data) override;
/**
* \brief Allows to set the socket in non blocking/blocking mode
@@ -79,10 +76,5 @@ class IOInterface<TcpTag> : private NonCopyable {
friend class TcpListener;
};
/**
* \typedef TcpSocket
*/
using TcpSocket = IOInterface<TcpTag>;
} // namespace io
} // namespace sp

View File

@@ -1,10 +1,11 @@
#pragma once
#include <sp/common/DataBuffer.h>
#include <sp/common/NonCopyable.h>
namespace sp {
class IoInterface {
class IoInterface : private NonCopyable {
public:
virtual DataBuffer Read(std::size_t a_Amount) = 0;
virtual void Write(const DataBuffer& a_Data) = 0;

View File

@@ -5,12 +5,27 @@
namespace sp {
class MessageEncapsulator {
protected:
bool m_Enabled = true;
public:
MessageEncapsulator() {}
virtual ~MessageEncapsulator() {}
virtual DataBuffer Encapsulate(const DataBuffer& a_Data) = 0;
virtual DataBuffer Decapsulate(DataBuffer& a_Data) = 0;
DataBuffer Encapsulate(const DataBuffer& a_Data) {
if (!m_Enabled)
return a_Data;
return EncapsulateImpl(a_Data);
}
DataBuffer Decapsulate(DataBuffer& a_Data) {
if (!m_Enabled)
return a_Data;
return DecapsulateImpl(a_Data);
}
protected:
virtual DataBuffer EncapsulateImpl(const DataBuffer& a_Data) = 0;
virtual DataBuffer DecapsulateImpl(DataBuffer& a_Data) = 0;
};
} // namespace sp

View File

@@ -9,8 +9,8 @@ namespace sp {
template <typename TMessageFactory>
class MessageStream {
private:
std::vector<MessageEncapsulator> m_Encapsulators;
protected:
std::vector<std::shared_ptr<MessageEncapsulator>> m_Encapsulators;
std::shared_ptr<IoInterface> m_Stream;
using MessageBaseType = typename TMessageFactory::MessageBaseType;
@@ -19,11 +19,30 @@ class MessageStream {
public:
MessageStream(std::shared_ptr<IoInterface>&& a_Stream) : m_Stream(std::move(a_Stream)) {}
template<typename... TEnc>
MessageStream(std::shared_ptr<IoInterface>&& a_Stream, TEnc&&... a_Encapsulators) :
m_Stream(std::move(a_Stream)){
m_Encapsulators.reserve(sizeof...(a_Encapsulators));
AddEncapsulators(std::move(a_Encapsulators ...));
}
std::unique_ptr<MessageBaseType> ReadMessage();
std::unique_ptr<MessageBaseType> ReadMessage(MessageIdType a_Id);
void WriteMessage(const MessageBaseType& a_Message, bool a_WriteId = true);
template<typename... Args>
void AddEncapsulators(Args&& ... a_Encapsulators) {
AddEncapsulators(std::move(std::make_tuple<>(a_Encapsulators ...)));
}
template<typename... Args>
void AddEncapsulators(std::tuple<Args...>&& a_Encapsulators) {
TupleForEach([this](auto&& a_Encapsulator){
m_Encapsulators.push_back(std::move(a_Encapsulator));
}, a_Encapsulators);
}
private:
DataBuffer ReadAndDecapsulate();
std::unique_ptr<MessageBaseType> MakeMessage(DataBuffer& buffer, MessageIdType a_Id);

View File

@@ -12,8 +12,8 @@ DataBuffer MessageStream<TMessageFactory>::ReadAndDecapsulate() {
std::size_t amount = messageLength.GetValue();
DataBuffer buffer = m_Stream->Read(amount);
for (MessageEncapsulator& enc : m_Encapsulators) {
buffer = enc.Decapsulate(buffer);
for (auto& enc : m_Encapsulators) {
buffer = enc->Decapsulate(buffer);
}
return buffer;
@@ -51,8 +51,8 @@ void MessageStream<TMessageFactory>::WriteMessage(const MessageBaseType& a_Messa
if (a_WriteId)
buffer << VarInt{static_cast<std::uint64_t>(a_Message.GetId())};
buffer << a_Message.Write();
for (MessageEncapsulator& enc : m_Encapsulators) {
buffer = enc.Encapsulate(buffer);
for (auto& enc : m_Encapsulators) {
buffer = enc->Encapsulate(buffer);
}
DataBuffer header;
header << VarInt{buffer.GetSize()};

View File

@@ -0,0 +1,85 @@
#include <sp/extensions/Compress.h>
#include <cassert>
#include <sp/common/VarInt.h>
#include <zlib.h>
namespace sp {
namespace zlib {
static DataBuffer Inflate(const std::uint8_t* source, std::size_t size, std::size_t uncompressedSize) {
DataBuffer result;
result.Resize(uncompressedSize);
uncompress(static_cast<Bytef*>(result.data()), reinterpret_cast<uLongf*>(&uncompressedSize), static_cast<const Bytef*>(source),
static_cast<uLong>(size));
assert(result.GetSize() == uncompressedSize);
return result;
}
static DataBuffer Deflate(const std::uint8_t* source, std::size_t size) {
DataBuffer result;
uLongf compressedSize = size;
result.Resize(size); // Resize for the compressed data to fit into
compress(static_cast<Bytef*>(result.data()), &compressedSize, static_cast<const Bytef*>(source), static_cast<uLong>(size));
result.Resize(compressedSize); // Resize to cut useless data
return result;
}
DataBuffer Compress(const DataBuffer& buffer, std::size_t a_CompressionThreshold) {
DataBuffer packet;
if (buffer.GetSize() < a_CompressionThreshold) {
// Don't compress since it's a small packet
packet << VarInt{0};
packet << buffer;
return packet;
}
DataBuffer compressedData = Deflate(buffer.data(), buffer.GetSize());
VarInt uncompressedDataLength = buffer.GetSize();
if (compressedData.GetSize() >= buffer.GetSize()) {
// the compression is overkill so we don't send the compressed buffer
packet << VarInt{0};
packet << buffer;
} else {
packet << uncompressedDataLength;
packet << compressedData;
}
return packet;
}
DataBuffer Decompress(DataBuffer& buffer, std::uint64_t packetLength) {
VarInt uncompressedLength;
buffer >> uncompressedLength;
std::uint64_t compressedLength = packetLength - uncompressedLength.GetSerializedLength();
if (uncompressedLength.GetValue() == 0) {
// Data already uncompressed. Nothing to do
DataBuffer ret;
buffer.ReadSome(ret, compressedLength);
return ret;
}
assert(buffer.GetReadOffset() + compressedLength <= buffer.GetSize());
return Inflate(buffer.data() + buffer.GetReadOffset(), compressedLength, uncompressedLength.GetValue());
}
} // namespace zlib
DataBuffer ZlibCompress::EncapsulateImpl(const DataBuffer& a_Data) {
return zlib::Compress(a_Data, m_CompressionThreshold);
}
DataBuffer ZlibCompress::DecapsulateImpl(DataBuffer& a_Data) {
return zlib::Decompress(a_Data, a_Data.GetSize());
}
} // namespace sp

View File

@@ -0,0 +1,114 @@
#include <sp/extensions/tcp/TcpListener.h>
#ifdef _WIN32
// Windows
#include <winsock2.h>
#include <ws2tcpip.h>
#define ioctl ioctlsocket
#define WOULDBLOCK WSAEWOULDBLOCK
#define MSG_DONTWAIT 0
#else
// Linux/Unix
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#define closesocket close
#define WOULDBLOCK EWOULDBLOCK
#define SD_BOTH SHUT_RDWR
#endif
#ifndef INVALID_SOCKET
#define INVALID_SOCKET -1
#endif
namespace sp {
namespace io {
TcpListener::TcpListener(std::uint16_t a_Port, int a_MaxConnexions) {
if ((m_Handle = static_cast<SocketHandle>(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP))) < 0) {
throw SocketError("Failed to create server socket");
}
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(a_Port);
if (bind(m_Handle, reinterpret_cast<sockaddr*>(&address), sizeof(address)) < 0)
throw SocketError("Failed to create server socket");
if (listen(m_Handle, a_MaxConnexions) < 0)
throw SocketError("Failed to create server socket");
socklen_t len = sizeof(address);
if (getsockname(m_Handle, reinterpret_cast<sockaddr*>(&address), &len) < 0)
throw SocketError("Failed to create server socket");
m_Port = ntohs(address.sin_port);
m_MaxConnections = a_MaxConnexions;
}
TcpListener::~TcpListener() {
Close();
}
std::unique_ptr<TcpSocket> TcpListener::Accept() {
sockaddr remoteAddress;
int addrlen = sizeof(remoteAddress);
auto newSocket = std::make_unique<TcpSocket>();
newSocket->m_Handle = static_cast<SocketHandle>(
accept(m_Handle, reinterpret_cast<sockaddr*>(&remoteAddress), reinterpret_cast<socklen_t*>(&addrlen)));
if (newSocket->m_Handle < 0)
return nullptr;
newSocket->m_Status = TcpSocket::Status::Connected;
return newSocket;
}
void TcpListener::Close() {
if (m_Handle > 0) {
closesocket(m_Handle);
shutdown(m_Handle, SD_BOTH);
}
}
bool TcpListener::SetBlocking(bool a_Blocking) {
unsigned long mode = !a_Blocking;
if (ioctl(m_Handle, FIONBIO, &mode) < 0) {
return false;
}
return true;
}
std::uint16_t TcpListener::GetListeningPort() const {
return m_Port;
}
int TcpListener::GetMaximumConnections() const {
return m_MaxConnections;
}
} // namespace io
} // namespace sp

View File

@@ -0,0 +1,164 @@
#include <sp/extensions/tcp/TcpSocket.h>
#ifdef _WIN32
// Windows
#include <winsock2.h>
#include <ws2tcpip.h>
#define ioctl ioctlsocket
#define WOULDBLOCK WSAEWOULDBLOCK
#define MSG_DONTWAIT 0
#else
// Linux/Unix
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#define closesocket close
#define WOULDBLOCK EWOULDBLOCK
#endif
#ifndef INVALID_SOCKET
#define INVALID_SOCKET -1
#endif
namespace sp {
namespace io {
TcpSocket::TcpSocket() : m_Handle(static_cast<SocketHandle>(INVALID_SOCKET)), m_Status(Status::Disconnected) {}
TcpSocket::TcpSocket(const std::string& a_Host, std::uint16_t a_Port) : TcpSocket() {
Connect(a_Host, a_Port);
}
TcpSocket::TcpSocket(TcpSocket&& a_Other) {
std::swap(m_Handle, a_Other.m_Handle);
std::swap(m_Status, a_Other.m_Status);
}
TcpSocket::~TcpSocket() {}
void TcpSocket::Connect(const std::string& a_Host, std::uint16_t a_Port) {
struct addrinfo hints {};
struct addrinfo* result = nullptr;
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
m_Status = Status::Error;
if (getaddrinfo(a_Host.c_str(), std::to_string(static_cast<int>(a_Port)).c_str(), &hints, &result) != 0) {
throw SocketError("Failed to get address info");
}
m_Handle = static_cast<SocketHandle>(socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
if (m_Handle < 0) {
throw SocketError("Failed to create socket");
}
struct addrinfo* ptr = nullptr;
for (ptr = result; ptr != nullptr; ptr = ptr->ai_next) {
struct sockaddr* sockaddr = ptr->ai_addr;
if (connect(m_Handle, sockaddr, sizeof(sockaddr_in)) == 0) {
break;
}
}
freeaddrinfo(result);
if (!ptr) {
throw SocketError("Could not find a suitable interface for connecting");
}
m_Status = Status::Connected;
}
DataBuffer TcpSocket::Read(std::size_t a_Amount) {
DataBuffer buffer(a_Amount);
std::size_t totalRecieved = 0;
while (totalRecieved < a_Amount) {
int recvAmount =
recv(m_Handle, reinterpret_cast<char*>(buffer.data() + totalRecieved), static_cast<int>(a_Amount - totalRecieved), 0);
if (recvAmount <= 0) {
#if defined(_WIN32) || defined(WIN32)
int err = WSAGetLastError();
#else
int err = errno;
#endif
if (err == WOULDBLOCK) {
// we are in non blocking mode and nothing is available
return {};
}
Disconnect();
m_Status = Status::Error;
throw SocketError("Error while reading");
}
totalRecieved += recvAmount;
}
return buffer;
}
void TcpSocket::Write(const sp::DataBuffer& a_Data) {
if (GetStatus() != Status::Connected)
return;
std::size_t sent = 0;
while (sent < a_Data.GetSize()) {
int cur = send(m_Handle, reinterpret_cast<const char*>(a_Data.data() + sent), static_cast<int>(a_Data.GetSize() - sent), 0);
if (cur <= 0) {
Disconnect();
m_Status = Status::Error;
return;
}
sent += static_cast<std::size_t>(cur);
}
}
bool TcpSocket::SetBlocking(bool a_Block) {
unsigned long mode = !a_Block;
if (ioctl(m_Handle, FIONBIO, &mode) < 0) {
return false;
}
return true;
}
TcpSocket::Status TcpSocket::GetStatus() const {
return m_Status;
}
void TcpSocket::Disconnect() {
if (m_Handle > 0)
closesocket(m_Handle);
m_Status = Status::Disconnected;
}
TcpSocket& TcpSocket::operator=(TcpSocket&& a_Other) {
std::swap(m_Handle, a_Other.m_Handle);
std::swap(m_Status, a_Other.m_Status);
return *this;
}
} // namespace io
} // namespace sp

View File

@@ -5,6 +5,8 @@
#include <sp/io/MessageStream.h>
#include <sp/io/StdIo.h>
#include <sp/extensions/Compress.h>
#include <cstdint>
#include <iostream>
#include <fstream>
@@ -20,6 +22,7 @@ using Message = sp::ConcreteMessage<TData, PacketID, ID, PacketHandler>;
struct KeepAlivePacket {
std::uint64_t m_KeepAlive;
std::string mdc;
};
using KeepAliveMessage = Message<KeepAlivePacket, PacketID::KeepAlive>;
@@ -31,7 +34,7 @@ class PacketHandler : public sp::MessageHandler<AllMessages> {};
class MyHandler : public PacketHandler {
public:
virtual void Handle(const KeepAlivePacket& msg) {
std::cout << "I recieved a keep alive : " << msg.m_KeepAlive << "\n";
std::cout << "I recieved a keep alive : " << msg.m_KeepAlive << " : " << msg.mdc << "\n";
}
};
@@ -42,7 +45,7 @@ using PacketFactory = sp::MessageFactory<PacketBase, AllMessages>;
using PacketStream = sp::MessageStream<PacketFactory>;
int main() {
KeepAliveMessage m{69UL};
KeepAliveMessage m{69UL, "ceci est une mdc aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"};
// dispatch tests
@@ -58,9 +61,11 @@ int main() {
// write tests
auto compress = std::make_shared<sp::ZlibCompress>();
std::ofstream file {"test.bin"};
PacketStream p(std::make_shared<sp::StdOuput>(file));
PacketStream p(std::make_shared<sp::StdOuput>(file), compress);
p.WriteMessage(m);
@@ -68,7 +73,7 @@ int main() {
std::ifstream file2 {"test.bin"};
PacketStream p2(std::make_shared<sp::StdInput>(file2));
PacketStream p2(std::make_shared<sp::StdInput>(file2), compress);
auto message2 = p2.ReadMessage();