diff --git a/httplib.h b/httplib.h index 2fdd76c..6fa1019 100644 --- a/httplib.h +++ b/httplib.h @@ -597,6 +597,7 @@ inline void default_socket_options(socket_t sock) { class Server { public: using Handler = std::function; + using HandlerWithReturn = std::function; using HandlerWithContentReader = std::function; using Expect100ContinueHandler = @@ -627,7 +628,11 @@ public: const char *mime); void set_file_request_handler(Handler handler); + void set_error_handler(HandlerWithReturn handler); void set_error_handler(Handler handler); + void set_pre_routing_handler(HandlerWithReturn handler); + void set_post_routing_handler(Handler handler); + void set_expect_100_continue_handler(Expect100ContinueHandler handler); void set_logger(Logger logger); @@ -734,7 +739,9 @@ private: Handlers delete_handlers_; HandlersForContentReader delete_handlers_for_content_reader_; Handlers options_handlers_; - Handler error_handler_; + HandlerWithReturn error_handler_; + HandlerWithReturn pre_routing_handler_; + Handler post_routing_handler_; Logger logger_; Expect100ContinueHandler expect_100_continue_handler_; @@ -4160,14 +4167,23 @@ inline void Server::set_file_request_handler(Handler handler) { file_request_handler_ = std::move(handler); } -inline void Server::set_error_handler(Handler handler) { +inline void Server::set_error_handler(HandlerWithReturn handler) { error_handler_ = std::move(handler); } -inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void Server::set_error_handler(Handler handler) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return true; + }; +} -inline void Server::set_socket_options(SocketOptions socket_options) { - socket_options_ = std::move(socket_options); +inline void Server::set_pre_routing_handler(HandlerWithReturn handler) { + pre_routing_handler_ = std::move(handler); +} + +inline void Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); } inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } @@ -4177,6 +4193,12 @@ Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { expect_100_continue_handler_ = std::move(handler); } +inline void Server::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + inline void Server::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; } @@ -4268,8 +4290,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, bool need_apply_ranges) { assert(res.status != -1); - if (400 <= res.status && error_handler_) { - error_handler_(req, res); + if (400 <= res.status && error_handler_ && error_handler_(req, res)) { need_apply_ranges = true; } @@ -4277,7 +4298,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, std::string boundary; if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } - // Preapre additional headers + // Prepare additional headers if (close_connection || req.get_header_value("Connection") == "close") { res.set_header("Connection", "close"); } else { @@ -4301,6 +4322,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, res.set_header("Accept-Ranges", "bytes"); } + if (post_routing_handler_) { post_routing_handler_(req, res); } + // Response line and headers { detail::BufferStream bstrm; @@ -4604,6 +4627,8 @@ inline bool Server::listen_internal() { } inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && pre_routing_handler_(req, res)) { return true; } + // File handler bool is_head_request = req.method == "HEAD"; if ((req.method == "GET" || is_head_request) && @@ -5302,7 +5327,7 @@ inline bool ClientImpl::write_content_with_provider(Stream &strm, inline bool ClientImpl::write_request(Stream &strm, const Request &req, bool close_connection, Error &error) { - // Prepare additonal headers + // Prepare additional headers Headers headers; if (close_connection) { headers.emplace("Connection", "close"); } diff --git a/test/test.cc b/test/test.cc index 4417501..51290ac 100644 --- a/test/test.cc +++ b/test/test.cc @@ -953,6 +953,77 @@ TEST(ErrorHandlerTest, ContentLength) { ASSERT_FALSE(svr.is_running()); } +TEST(RoutingHandlerTest, PreRoutingHandler) { + Server svr; + + svr.set_pre_routing_handler([](const Request &req, Response &res) { + if (req.path == "/routing_handler") { + res.set_header("PRE_ROUTING", "on"); + res.set_content("Routing Handler", "text/plain"); + return true; + } + return false; + }); + + svr.set_error_handler([](const Request & /*req*/, Response &res) { + res.set_content("Error", "text/html"); + }); + + svr.set_post_routing_handler([](const Request &req, Response &res) { + if (req.path == "/routing_handler") { + res.set_header("POST_ROUTING", "on"); + } + }); + + svr.Get("/hi", [](const Request & /*req*/, Response &res) { + res.set_content("Hello World!\n", "text/plain"); + }); + + auto thread = std::thread([&]() { svr.listen(HOST, PORT); }); + + // Give GET time to get a few messages. + std::this_thread::sleep_for(std::chrono::seconds(1)); + + { + Client cli(HOST, PORT); + + auto res = cli.Get("/routing_handler"); + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + EXPECT_EQ("Routing Handler", res->body); + EXPECT_EQ(1, res->get_header_value_count("PRE_ROUTING")); + EXPECT_EQ("on", res->get_header_value("PRE_ROUTING")); + EXPECT_EQ(1, res->get_header_value_count("POST_ROUTING")); + EXPECT_EQ("on", res->get_header_value("POST_ROUTING")); + } + + { + Client cli(HOST, PORT); + + auto res = cli.Get("/hi"); + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + EXPECT_EQ("Hello World!\n", res->body); + EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING")); + EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING")); + } + + { + Client cli(HOST, PORT); + + auto res = cli.Get("/aaa"); + ASSERT_TRUE(res); + EXPECT_EQ(404, res->status); + EXPECT_EQ("Error", res->body); + EXPECT_EQ(0, res->get_header_value_count("PRE_ROUTING")); + EXPECT_EQ(0, res->get_header_value_count("POST_ROUTING")); + } + + svr.stop(); + thread.join(); + ASSERT_FALSE(svr.is_running()); +} + TEST(InvalidFormatTest, StatusCode) { Server svr;