diff options
-rw-r--r-- | http/routing/websocketrule.hpp | 12 | ||||
-rw-r--r-- | http/websocket.hpp | 56 | ||||
-rw-r--r-- | include/nbd_proxy.hpp | 4 |
3 files changed, 40 insertions, 32 deletions
diff --git a/http/routing/websocketrule.hpp b/http/routing/websocketrule.hpp index 5e558f2edb..bf6daad6d0 100644 --- a/http/routing/websocketrule.hpp +++ b/http/routing/websocketrule.hpp @@ -37,9 +37,9 @@ class WebSocketRule : public BaseRule crow::websocket::ConnectionImpl<boost::asio::ip::tcp::socket>> myConnection = std::make_shared< crow::websocket::ConnectionImpl<boost::asio::ip::tcp::socket>>( - req, req.url(), std::move(adaptor), openHandler, messageHandler, - messageExHandler, closeHandler, errorHandler); - myConnection->start(); + req.url(), req.session, std::move(adaptor), openHandler, + messageHandler, messageExHandler, closeHandler, errorHandler); + myConnection->start(req); } #else void handleUpgrade(const Request& req, @@ -52,9 +52,9 @@ class WebSocketRule : public BaseRule boost::beast::ssl_stream<boost::asio::ip::tcp::socket>>> myConnection = std::make_shared<crow::websocket::ConnectionImpl< boost::beast::ssl_stream<boost::asio::ip::tcp::socket>>>( - req, req.url(), std::move(adaptor), openHandler, messageHandler, - messageExHandler, closeHandler, errorHandler); - myConnection->start(); + req.url(), req.session, std::move(adaptor), openHandler, + messageHandler, messageExHandler, closeHandler, errorHandler); + myConnection->start(req); } #endif diff --git a/http/websocket.hpp b/http/websocket.hpp index 216c4131db..0fda7ee275 100644 --- a/http/websocket.hpp +++ b/http/websocket.hpp @@ -27,7 +27,7 @@ enum class MessageType struct Connection : std::enable_shared_from_this<Connection> { public: - explicit Connection(const crow::Request& reqIn) : req(reqIn.req) {} + Connection() = default; Connection(const Connection&) = delete; Connection(Connection&&) = delete; @@ -46,15 +46,17 @@ struct Connection : std::enable_shared_from_this<Connection> virtual boost::asio::io_context& getIoContext() = 0; virtual ~Connection() = default; virtual boost::urls::url_view url() = 0; - boost::beast::http::request<boost::beast::http::string_body> req; }; template <typename Adaptor> class ConnectionImpl : public Connection { + using self_t = ConnectionImpl<Adaptor>; + public: ConnectionImpl( - const crow::Request& reqIn, boost::urls::url_view urlViewIn, + const boost::urls::url_view& urlViewIn, + const std::shared_ptr<persistent_data::UserSession>& sessionIn, Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn, std::function<void(Connection&, const std::string&, bool)> messageHandlerIn, @@ -64,13 +66,13 @@ class ConnectionImpl : public Connection messageExHandlerIn, std::function<void(Connection&, const std::string&)> closeHandlerIn, std::function<void(Connection&)> errorHandlerIn) : - Connection(reqIn), - uri(urlViewIn), ws(std::move(adaptorIn)), inBuffer(inString, 131088), + uri(urlViewIn), + ws(std::move(adaptorIn)), inBuffer(inString, 131088), openHandler(std::move(openHandlerIn)), messageHandler(std::move(messageHandlerIn)), messageExHandler(std::move(messageExHandlerIn)), closeHandler(std::move(closeHandlerIn)), - errorHandler(std::move(errorHandlerIn)), session(reqIn.session) + errorHandler(std::move(errorHandlerIn)), session(sessionIn) { /* Turn on the timeouts on websocket stream to server role */ ws.set_option(boost::beast::websocket::stream_base::timeout::suggested( @@ -84,17 +86,16 @@ class ConnectionImpl : public Connection ws.get_executor().context()); } - void start() + void start(const crow::Request& req) { BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this)); using bf = boost::beast::http::field; - - std::string_view protocol = req[bf::sec_websocket_protocol]; + std::string protocolHeader = req.req[bf::sec_websocket_protocol]; ws.set_option(boost::beast::websocket::stream_base::decorator( - [session{session}, protocol{std::string(protocol)}]( - boost::beast::websocket::response_type& m) { + [session{session}, + protocolHeader](boost::beast::websocket::response_type& m) { #ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION if (session != nullptr) @@ -102,7 +103,7 @@ class ConnectionImpl : public Connection // use protocol for csrf checking if (session->cookieAuth && !crow::utility::constantTimeStringCompare( - protocol, session->csrfToken)) + protocolHeader, session->csrfToken)) { BMCWEB_LOG_ERROR("Websocket CSRF error"); m.result(boost::beast::http::status::unauthorized); @@ -110,9 +111,9 @@ class ConnectionImpl : public Connection } } #endif - if (!protocol.empty()) + if (!protocolHeader.empty()) { - m.insert(bf::sec_websocket_protocol, protocol); + m.insert(bf::sec_websocket_protocol, protocolHeader); } m.insert(bf::strict_transport_security, "max-age=31536000; " @@ -126,16 +127,15 @@ class ConnectionImpl : public Connection m.insert("X-Content-Type-Options", "nosniff"); })); + // Make a pointer to keep the req alive while we accept it. + using Body = + boost::beast::http::request<boost::beast::http::string_body>; + std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req); + Body* ptr = mobile.get(); // Perform the websocket upgrade - ws.async_accept(req, [this, self(shared_from_this())]( - const boost::system::error_code& ec) { - if (ec) - { - BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec); - return; - } - acceptDone(); - }); + ws.async_accept(*ptr, + std::bind_front(&self_t::acceptDone, this, + shared_from_this(), std::move(mobile))); } void sendBinary(std::string_view msg) override @@ -221,8 +221,16 @@ class ConnectionImpl : public Connection return uri; } - void acceptDone() + void acceptDone(const std::shared_ptr<Connection>& /*self*/, + const std::unique_ptr<boost::beast::http::request< + boost::beast::http::string_body>>& /*req*/, + const boost::system::error_code& ec) { + if (ec) + { + BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec); + return; + } BMCWEB_LOG_DEBUG("Websocket accepted connection"); if (openHandler) diff --git a/include/nbd_proxy.hpp b/include/nbd_proxy.hpp index 5de3039765..5540886419 100644 --- a/include/nbd_proxy.hpp +++ b/include/nbd_proxy.hpp @@ -289,7 +289,7 @@ inline void } if ((endpointValue != nullptr) && (socketValue != nullptr) && - *endpointValue == conn.req.target()) + *endpointValue == conn.url().path()) { endpointObjectPath = &objectPath.str; break; @@ -305,7 +305,7 @@ inline void for (const auto& session : sessions) { - if (session.second->getEndpointId() == conn.req.target()) + if (session.second->getEndpointId() == conn.url().path()) { BMCWEB_LOG_ERROR("Cannot open new connection - socket is in use"); conn.close("Slot is in use"); |