diff --git a/include/sp/protocol/MessageDispatcher.h b/include/sp/protocol/MessageDispatcher.h index c7e2f6d..8ee0d06 100644 --- a/include/sp/protocol/MessageDispatcher.h +++ b/include/sp/protocol/MessageDispatcher.h @@ -19,7 +19,6 @@ template class MessageDispatcher { public: using MessageBaseType = MessageBase; - using MessageIdType = typename MessageBase::MessageIdType; using MessageHandler = typename MessageBase::HandlerType; /** @@ -38,23 +37,16 @@ class MessageDispatcher { * \param type The packet type * \param handler The packet handler */ - void RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); - - /** - * \brief Unregister a packet handler - * \param type The packet type - * \param handler The packet handler - */ - void UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler); + void RegisterHandler(const std::shared_ptr& a_Handler); /** * \brief Unregister a packet handler * \param handler The packet handler */ - void UnregisterHandler(MessageHandler* a_Handler); + void UnregisterHandler(const std::shared_ptr& a_Handler); private: - std::map> m_Handlers; + std::vector> m_Handlers; }; } // namespace sp diff --git a/include/sp/protocol/MessageDispatcherImpl.inl b/include/sp/protocol/MessageDispatcherImpl.inl index 408a9b9..1000faf 100644 --- a/include/sp/protocol/MessageDispatcherImpl.inl +++ b/include/sp/protocol/MessageDispatcherImpl.inl @@ -6,37 +6,22 @@ namespace sp { template -void MessageDispatcher::RegisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { +void MessageDispatcher::RegisterHandler(const std::shared_ptr& a_Handler) { assert(a_Handler); - auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); - if (found == m_Handlers[a_MessageType].end()) - m_Handlers[a_MessageType].push_back(a_Handler); + m_Handlers.push_back(a_Handler); } template -void MessageDispatcher::UnregisterHandler(MessageIdType a_MessageType, MessageHandler* a_Handler) { - auto found = std::find(m_Handlers[a_MessageType].begin(), m_Handlers[a_MessageType].end(), a_Handler); - if (found != m_Handlers[a_MessageType].end()) - m_Handlers[a_MessageType].erase(found); -} - -template -void MessageDispatcher::UnregisterHandler(MessageHandler* a_Handler) { - for (auto& pair : m_Handlers) { - if (pair.second.empty()) - continue; - - MessageIdType type = pair.first; - - pair.second.erase(std::remove(pair.second.begin(), pair.second.end(), a_Handler), pair.second.end()); - } +void MessageDispatcher::UnregisterHandler(const std::shared_ptr& a_Handler) { + auto found = std::find(m_Handlers.begin(), m_Handlers.end(), a_Handler); + if (found != m_Handlers.end()) + m_Handlers.erase(found); } template void MessageDispatcher::Dispatch(const MessageBase& a_Message) { - MessageIdType type = a_Message.GetId(); - for (auto& handler : m_Handlers[type]) { - a_Message.Dispatch(*handler); + for (auto& handler : m_Handlers) { + a_Message.Dispatch(*handler.lock()); } } diff --git a/test/test_message.cpp b/test/test_message.cpp index 8a10ef8..8798680 100644 --- a/test/test_message.cpp +++ b/test/test_message.cpp @@ -61,10 +61,9 @@ int main() { // dispatch tests - MyHandler h; + auto h = std::make_shared(); PacketDispatcher d; - d.RegisterHandler(PacketID::KeepAlive, &h); - d.RegisterHandler(PacketID::MDC, &h); + d.RegisterHandler(h); d.Dispatch(m); PacketFactory f; auto message = f.CreateMessage(PacketID::KeepAlive);