Add optional user defined header writer (#1683)

* Add optional user defined header writer

* Fix errors and add test
This commit is contained in:
PabloMK7 2023-10-01 04:13:14 +02:00 committed by GitHub
parent c029597a5a
commit a609330e4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 2 deletions

View file

@ -737,6 +737,8 @@ private:
std::regex regex_; std::regex regex_;
}; };
ssize_t write_headers(Stream &strm, const Headers &headers);
} // namespace detail } // namespace detail
class Server { class Server {
@ -800,6 +802,8 @@ public:
Server &set_socket_options(SocketOptions socket_options); Server &set_socket_options(SocketOptions socket_options);
Server &set_default_headers(Headers headers); Server &set_default_headers(Headers headers);
Server &
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_max_count(size_t count);
Server &set_keep_alive_timeout(time_t sec); Server &set_keep_alive_timeout(time_t sec);
@ -934,6 +938,8 @@ private:
SocketOptions socket_options_ = default_socket_options; SocketOptions socket_options_ = default_socket_options;
Headers default_headers_; Headers default_headers_;
std::function<ssize_t(Stream &, Headers &)> header_writer_ =
detail::write_headers;
}; };
enum class Error { enum class Error {
@ -1164,6 +1170,9 @@ public:
void set_default_headers(Headers headers); void set_default_headers(Headers headers);
void
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
void set_address_family(int family); void set_address_family(int family);
void set_tcp_nodelay(bool on); void set_tcp_nodelay(bool on);
void set_socket_options(SocketOptions socket_options); void set_socket_options(SocketOptions socket_options);
@ -1273,6 +1282,10 @@ protected:
// Default headers // Default headers
Headers default_headers_; Headers default_headers_;
// Header writer
std::function<ssize_t(Stream &, Headers &)> header_writer_ =
detail::write_headers;
// Settings // Settings
std::string client_cert_path_; std::string client_cert_path_;
std::string client_key_path_; std::string client_key_path_;
@ -1539,6 +1552,9 @@ public:
void set_default_headers(Headers headers); void set_default_headers(Headers headers);
void
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
void set_address_family(int family); void set_address_family(int family);
void set_tcp_nodelay(bool on); void set_tcp_nodelay(bool on);
void set_socket_options(SocketOptions socket_options); void set_socket_options(SocketOptions socket_options);
@ -5672,6 +5688,12 @@ inline Server &Server::set_default_headers(Headers headers) {
return *this; return *this;
} }
inline Server &Server::set_header_writer(
std::function<ssize_t(Stream &, Headers &)> const &writer) {
header_writer_ = writer;
return *this;
}
inline Server &Server::set_keep_alive_max_count(size_t count) { inline Server &Server::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count; keep_alive_max_count_ = count;
return *this; return *this;
@ -5866,7 +5888,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
return false; return false;
} }
if (!detail::write_headers(bstrm, res.headers)) { return false; } if (!header_writer_(bstrm, res.headers)) { return false; }
// Flush buffer // Flush buffer
auto &data = bstrm.get_buffer(); auto &data = bstrm.get_buffer();
@ -7105,7 +7127,7 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path; const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path;
bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str());
detail::write_headers(bstrm, req.headers); header_writer_(bstrm, req.headers);
// Flush buffer // Flush buffer
auto &data = bstrm.get_buffer(); auto &data = bstrm.get_buffer();
@ -7916,6 +7938,11 @@ inline void ClientImpl::set_default_headers(Headers headers) {
default_headers_ = std::move(headers); default_headers_ = std::move(headers);
} }
inline void ClientImpl::set_header_writer(
std::function<ssize_t(Stream &, Headers &)> const &writer) {
header_writer_ = writer;
}
inline void ClientImpl::set_address_family(int family) { inline void ClientImpl::set_address_family(int family) {
address_family_ = family; address_family_ = family;
} }
@ -9110,6 +9137,11 @@ inline void Client::set_default_headers(Headers headers) {
cli_->set_default_headers(std::move(headers)); cli_->set_default_headers(std::move(headers));
} }
inline void Client::set_header_writer(
std::function<ssize_t(Stream &, Headers &)> const &writer) {
cli_->set_header_writer(writer);
}
inline void Client::set_address_family(int family) { inline void Client::set_address_family(int family) {
cli_->set_address_family(family); cli_->set_address_family(family);
} }

View file

@ -1592,6 +1592,46 @@ TEST(URLFragmentTest, WithFragment) {
} }
} }
TEST(HeaderWriter, SetHeaderWriter) {
Server svr;
svr.set_header_writer([](Stream &strm, Headers &hdrs) {
hdrs.emplace("CustomServerHeader", "CustomServerValue");
return detail::write_headers(strm, hdrs);
});
svr.Get("/hi", [](const Request &req, Response &res) {
auto it = req.headers.find("CustomClientHeader");
EXPECT_TRUE(it != req.headers.end());
EXPECT_EQ(it->second, "CustomClientValue");
res.set_content("Hello World!\n", "text/plain");
});
auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
auto se = detail::scope_exit([&] {
svr.stop();
thread.join();
ASSERT_FALSE(svr.is_running());
});
std::this_thread::sleep_for(std::chrono::seconds(1));
{
Client cli(HOST, PORT);
cli.set_header_writer([](Stream &strm, Headers &hdrs) {
hdrs.emplace("CustomClientHeader", "CustomClientValue");
return detail::write_headers(strm, hdrs);
});
auto res = cli.Get("/hi");
EXPECT_TRUE(res);
EXPECT_EQ(200, res->status);
auto it = res->headers.find("CustomServerHeader");
EXPECT_TRUE(it != res->headers.end());
EXPECT_EQ(it->second, "CustomServerValue");
}
}
class ServerTest : public ::testing::Test { class ServerTest : public ::testing::Test {
protected: protected:
ServerTest() ServerTest()