diff --git a/include/sp/protocol/MessageDispatcher.h b/include/sp/protocol/MessageDispatcher.h index 2f5e9a0..1cdba2c 100644 --- a/include/sp/protocol/MessageDispatcher.h +++ b/include/sp/protocol/MessageDispatcher.h @@ -17,7 +17,7 @@ namespace sp { template class MessageDispatcher { private: - std::map>> m_Handlers; + std::map> m_Handlers; public: using MessageBaseType = MessageBase; @@ -38,20 +38,20 @@ class MessageDispatcher { * \param type The packet type * \param handler The packet handler */ - void RegisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler); + void RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); /** * \brief Unregister a packet handler * \param type The packet type * \param handler The packet handler */ - void UnregisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler); + void UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); /** * \brief Unregister a packet handler * \param handler The packet handler */ - void UnregisterHandler(const std::weak_ptr& a_Handler); + void UnregisterHandler(MessageHandler* a_Handler); }; #include diff --git a/include/sp/protocol/message/MessageDispatcherImpl.inl b/include/sp/protocol/message/MessageDispatcherImpl.inl index 568c4b7..734fab6 100644 --- a/include/sp/protocol/message/MessageDispatcherImpl.inl +++ b/include/sp/protocol/message/MessageDispatcherImpl.inl @@ -1,33 +1,33 @@ #pragma once template -void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler) { - auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const std::weak_ptr& handler){ - return a_Handler.lock() == handler.lock(); +void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { + auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const MessageHandler* handler){ + return a_Handler == handler; }); if (found == m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].push_back(a_Handler); } template -void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, const std::weak_ptr& a_Handler) { - auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const std::weak_ptr& handler){ - return a_Handler.lock() == handler.lock(); +void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { + auto found = std::find_if(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), [&a_Handler](const MessageHandler* handler){ + return a_Handler == handler; }); if (found != m_Handlers[a_MessageType].end()) m_Handlers[a_MessageType].erase(found); } template -void MessageDispatcher::UnregisterHandler(const std::weak_ptr& a_Handler) { +void MessageDispatcher::UnregisterHandler(MessageHandler* a_Handler) { for (auto& pair : m_Handlers) { if (pair.second.empty()) continue; MessageIdType type = pair.first; - auto it = std::find_if(pair.second.begin(), pair.second.end(), [&a_Handler](const std::weak_ptr& handler){ - return handler.lock() == a_Handler.lock(); + auto it = std::find_if(pair.second.begin(), pair.second.end(), [&a_Handler](MessageHandler* handler){ + return handler == a_Handler; }); if (it != pair.second.end()) @@ -40,7 +40,7 @@ template void MessageDispatcher::Dispatch(const MessageBase& a_Message) { MessageIdType type = a_Message.GetId(); for (auto& handler : m_Handlers[type]) { - if (!handler.expired()) - a_Message.Dispatch(*handler.lock()); + if (handler) + a_Message.Dispatch(*handler); } } diff --git a/test/test_file.cpp b/test/test_file.cpp index 2a89a3c..f777338 100644 --- a/test/test_file.cpp +++ b/test/test_file.cpp @@ -23,8 +23,8 @@ int main() { auto handler = std::make_shared(); FileStream stream(sp::io::File{"test.txt", sp::io::FileTag::In | sp::io::FileTag::Out}, {}); - stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); - stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler.get()); + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler.get()); stream.SendMessage(KeepAlivePacket{96}); stream.SendMessage(KeepAlivePacket{69}); diff --git a/test/test_io.cpp b/test/test_io.cpp index c8affb6..d67c521 100644 --- a/test/test_io.cpp +++ b/test/test_io.cpp @@ -23,13 +23,13 @@ int main() { auto handler = std::make_shared(); DataBufferStream stream; - stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler); + stream.GetDispatcher().RegisterHandler(PacketId::Disconnect, handler.get()); // this should not be dispatched stream.SendMessage(KeepAlivePacket{96}); stream.RecieveMessages(); - stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler); + stream.GetDispatcher().RegisterHandler(PacketId::KeepAlive, handler.get()); stream.SendMessage(KeepAlivePacket{69}); stream.RecieveMessages(); diff --git a/test/test_packets.cpp b/test/test_packets.cpp index e89e171..b17f582 100644 --- a/test/test_packets.cpp +++ b/test/test_packets.cpp @@ -48,10 +48,10 @@ int main() { packet->Dispatch(*handler); sp::PacketDispatcher dispatcher; - dispatcher.RegisterHandler(PacketId::KeepAlive, handler); + dispatcher.RegisterHandler(PacketId::KeepAlive, handler.get()); dispatcher.Dispatch(*packet); - dispatcher.UnregisterHandler(PacketId::KeepAlive, handler); - dispatcher.UnregisterHandler(handler); + dispatcher.UnregisterHandler(PacketId::KeepAlive, handler.get()); + dispatcher.UnregisterHandler(handler.get()); return 0; } \ No newline at end of file