summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--http/routing/websocketrule.hpp12
-rw-r--r--http/websocket.hpp56
-rw-r--r--include/nbd_proxy.hpp4
3 files changed, 40 insertions, 32 deletions
diff --git a/http/routing/websocketrule.hpp b/http/routing/websocketrule.hpp
index 5e558f2edb..bf6daad6d0 100644
--- a/http/routing/websocketrule.hpp
+++ b/http/routing/websocketrule.hpp
@@ -37,9 +37,9 @@ class WebSocketRule : public BaseRule
crow::websocket::ConnectionImpl<boost::asio::ip::tcp::socket>>
myConnection = std::make_shared<
crow::websocket::ConnectionImpl<boost::asio::ip::tcp::socket>>(
- req, req.url(), std::move(adaptor), openHandler, messageHandler,
- messageExHandler, closeHandler, errorHandler);
- myConnection->start();
+ req.url(), req.session, std::move(adaptor), openHandler,
+ messageHandler, messageExHandler, closeHandler, errorHandler);
+ myConnection->start(req);
}
#else
void handleUpgrade(const Request& req,
@@ -52,9 +52,9 @@ class WebSocketRule : public BaseRule
boost::beast::ssl_stream<boost::asio::ip::tcp::socket>>>
myConnection = std::make_shared<crow::websocket::ConnectionImpl<
boost::beast::ssl_stream<boost::asio::ip::tcp::socket>>>(
- req, req.url(), std::move(adaptor), openHandler, messageHandler,
- messageExHandler, closeHandler, errorHandler);
- myConnection->start();
+ req.url(), req.session, std::move(adaptor), openHandler,
+ messageHandler, messageExHandler, closeHandler, errorHandler);
+ myConnection->start(req);
}
#endif
diff --git a/http/websocket.hpp b/http/websocket.hpp
index 216c4131db..0fda7ee275 100644
--- a/http/websocket.hpp
+++ b/http/websocket.hpp
@@ -27,7 +27,7 @@ enum class MessageType
struct Connection : std::enable_shared_from_this<Connection>
{
public:
- explicit Connection(const crow::Request& reqIn) : req(reqIn.req) {}
+ Connection() = default;
Connection(const Connection&) = delete;
Connection(Connection&&) = delete;
@@ -46,15 +46,17 @@ struct Connection : std::enable_shared_from_this<Connection>
virtual boost::asio::io_context& getIoContext() = 0;
virtual ~Connection() = default;
virtual boost::urls::url_view url() = 0;
- boost::beast::http::request<boost::beast::http::string_body> req;
};
template <typename Adaptor>
class ConnectionImpl : public Connection
{
+ using self_t = ConnectionImpl<Adaptor>;
+
public:
ConnectionImpl(
- const crow::Request& reqIn, boost::urls::url_view urlViewIn,
+ const boost::urls::url_view& urlViewIn,
+ const std::shared_ptr<persistent_data::UserSession>& sessionIn,
Adaptor adaptorIn, std::function<void(Connection&)> openHandlerIn,
std::function<void(Connection&, const std::string&, bool)>
messageHandlerIn,
@@ -64,13 +66,13 @@ class ConnectionImpl : public Connection
messageExHandlerIn,
std::function<void(Connection&, const std::string&)> closeHandlerIn,
std::function<void(Connection&)> errorHandlerIn) :
- Connection(reqIn),
- uri(urlViewIn), ws(std::move(adaptorIn)), inBuffer(inString, 131088),
+ uri(urlViewIn),
+ ws(std::move(adaptorIn)), inBuffer(inString, 131088),
openHandler(std::move(openHandlerIn)),
messageHandler(std::move(messageHandlerIn)),
messageExHandler(std::move(messageExHandlerIn)),
closeHandler(std::move(closeHandlerIn)),
- errorHandler(std::move(errorHandlerIn)), session(reqIn.session)
+ errorHandler(std::move(errorHandlerIn)), session(sessionIn)
{
/* Turn on the timeouts on websocket stream to server role */
ws.set_option(boost::beast::websocket::stream_base::timeout::suggested(
@@ -84,17 +86,16 @@ class ConnectionImpl : public Connection
ws.get_executor().context());
}
- void start()
+ void start(const crow::Request& req)
{
BMCWEB_LOG_DEBUG("starting connection {}", logPtr(this));
using bf = boost::beast::http::field;
-
- std::string_view protocol = req[bf::sec_websocket_protocol];
+ std::string protocolHeader = req.req[bf::sec_websocket_protocol];
ws.set_option(boost::beast::websocket::stream_base::decorator(
- [session{session}, protocol{std::string(protocol)}](
- boost::beast::websocket::response_type& m) {
+ [session{session},
+ protocolHeader](boost::beast::websocket::response_type& m) {
#ifndef BMCWEB_INSECURE_DISABLE_CSRF_PREVENTION
if (session != nullptr)
@@ -102,7 +103,7 @@ class ConnectionImpl : public Connection
// use protocol for csrf checking
if (session->cookieAuth &&
!crow::utility::constantTimeStringCompare(
- protocol, session->csrfToken))
+ protocolHeader, session->csrfToken))
{
BMCWEB_LOG_ERROR("Websocket CSRF error");
m.result(boost::beast::http::status::unauthorized);
@@ -110,9 +111,9 @@ class ConnectionImpl : public Connection
}
}
#endif
- if (!protocol.empty())
+ if (!protocolHeader.empty())
{
- m.insert(bf::sec_websocket_protocol, protocol);
+ m.insert(bf::sec_websocket_protocol, protocolHeader);
}
m.insert(bf::strict_transport_security, "max-age=31536000; "
@@ -126,16 +127,15 @@ class ConnectionImpl : public Connection
m.insert("X-Content-Type-Options", "nosniff");
}));
+ // Make a pointer to keep the req alive while we accept it.
+ using Body =
+ boost::beast::http::request<boost::beast::http::string_body>;
+ std::unique_ptr<Body> mobile = std::make_unique<Body>(req.req);
+ Body* ptr = mobile.get();
// Perform the websocket upgrade
- ws.async_accept(req, [this, self(shared_from_this())](
- const boost::system::error_code& ec) {
- if (ec)
- {
- BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
- return;
- }
- acceptDone();
- });
+ ws.async_accept(*ptr,
+ std::bind_front(&self_t::acceptDone, this,
+ shared_from_this(), std::move(mobile)));
}
void sendBinary(std::string_view msg) override
@@ -221,8 +221,16 @@ class ConnectionImpl : public Connection
return uri;
}
- void acceptDone()
+ void acceptDone(const std::shared_ptr<Connection>& /*self*/,
+ const std::unique_ptr<boost::beast::http::request<
+ boost::beast::http::string_body>>& /*req*/,
+ const boost::system::error_code& ec)
{
+ if (ec)
+ {
+ BMCWEB_LOG_ERROR("Error in ws.async_accept {}", ec);
+ return;
+ }
BMCWEB_LOG_DEBUG("Websocket accepted connection");
if (openHandler)
diff --git a/include/nbd_proxy.hpp b/include/nbd_proxy.hpp
index 5de3039765..5540886419 100644
--- a/include/nbd_proxy.hpp
+++ b/include/nbd_proxy.hpp
@@ -289,7 +289,7 @@ inline void
}
if ((endpointValue != nullptr) && (socketValue != nullptr) &&
- *endpointValue == conn.req.target())
+ *endpointValue == conn.url().path())
{
endpointObjectPath = &objectPath.str;
break;
@@ -305,7 +305,7 @@ inline void
for (const auto& session : sessions)
{
- if (session.second->getEndpointId() == conn.req.target())
+ if (session.second->getEndpointId() == conn.url().path())
{
BMCWEB_LOG_ERROR("Cannot open new connection - socket is in use");
conn.close("Slot is in use");