diff --git a/README.md b/README.md index 69a7f11..291e4bd 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,11 @@ params["note"] = "coder"; auto res = cli.post("/post", params); ``` +### Connection Timeout + +```c++ +httplib::Client cli("localhost", 8080, 5); // timeouts in 5 seconds +``` ### With Progress Callback ```cpp diff --git a/httplib.h b/httplib.h index cded0e1..cc000e6 100644 --- a/httplib.h +++ b/httplib.h @@ -27,7 +27,6 @@ #define S_ISDIR(m) (((m)&S_IFDIR)==S_IFDIR) #endif -#include #include #include #include @@ -57,6 +56,7 @@ typedef int socket_t; #include #include #include +#include #include #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -207,6 +207,8 @@ protected: private: typedef std::vector> Handlers; + socket_t create_server_socket(const char* host, int port, int socket_flags) const; + bool routing(Request& req, Response& res); bool handle_file_request(Request& req, Response& res); bool dispatch_request(Request& req, Response& res, Handlers& handlers); @@ -226,7 +228,12 @@ private: class Client { public: - Client(const char* host, int port, HttpVersion http_version = HttpVersion::v1_0); + Client( + const char* host, + int port = 80, + size_t timeout_sec = 300, + HttpVersion http_version = HttpVersion::v1_0); + virtual ~Client(); virtual bool is_valid() const; @@ -250,10 +257,12 @@ protected: const std::string host_; const int port_; + size_t timeout_sec_; const HttpVersion http_version_; const std::string host_and_port_; private: + socket_t create_client_socket() const; bool read_response_line(Stream& strm, Response& res); void write_request(Stream& strm, Request& req); @@ -292,7 +301,12 @@ private: class SSLClient : public Client { public: - SSLClient(const char* host, int port, HttpVersion http_version = HttpVersion::v1_0); + SSLClient( + const char* host, + int port = 80, + size_t timeout_sec = 300, + HttpVersion http_version = HttpVersion::v1_0); + virtual ~SSLClient(); virtual bool is_valid() const; @@ -406,7 +420,7 @@ inline int close_socket(socket_t sock) #endif } -inline int select(socket_t sock, size_t sec, size_t usec) +inline int select_read(socket_t sock, size_t sec, size_t usec) { fd_set fds; FD_ZERO(&fds); @@ -416,7 +430,28 @@ inline int select(socket_t sock, size_t sec, size_t usec) tv.tv_sec = sec; tv.tv_usec = usec; - return ::select(sock + 1, &fds, NULL, NULL, &tv); + return select(sock + 1, &fds, NULL, NULL, &tv); +} + +inline bool is_socket_writable(socket_t sock, size_t sec, size_t usec) +{ + fd_set fdsw; + FD_ZERO(&fdsw); + FD_SET(sock, &fdsw); + + fd_set fdse; + FD_ZERO(&fdse); + FD_SET(sock, &fdse); + + timeval tv; + tv.tv_sec = sec; + tv.tv_usec = usec; + + if (select(sock + 1, NULL, &fdsw, &fdse, &tv) <= 0) { + return false; + } + + return FD_ISSET(sock, &fdsw) != 0; } template @@ -427,7 +462,7 @@ inline bool read_and_close_socket(socket_t sock, bool keep_alive, T callback) if (keep_alive) { auto count = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; while (count > 0 && - detail::select(sock, + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { auto last_connection = count == 1; @@ -507,27 +542,24 @@ socket_t create_socket(const char* host, int port, Fn fn, int socket_flags = 0) return -1; } -inline socket_t create_server_socket(const char* host, int port, int socket_flags) +inline void set_nonblocking(socket_t sock, bool nonblocking) { - return create_socket(host, port, [](socket_t sock, struct addrinfo& ai) -> socket_t { - if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) { - return false; - } - if (listen(sock, 5)) { // Listen through 5 channels - return false; - } - return true; - }, socket_flags); +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif } -inline socket_t create_client_socket(const char* host, int port) +inline bool is_connection_error() { - return create_socket(host, port, [](socket_t sock, struct addrinfo& ai) -> socket_t { - if (connect(sock, ai.ai_addr, ai.ai_addrlen)) { - return false; - } - return true; - }); +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif } inline bool is_file(const std::string& path) @@ -1339,7 +1371,7 @@ inline bool Server::listen(const char* host, int port, int socket_flags) return false; } - svr_sock_ = detail::create_server_socket(host, port, socket_flags); + svr_sock_ = create_server_socket(host, port, socket_flags); if (svr_sock_ == -1) { return false; } @@ -1347,7 +1379,7 @@ inline bool Server::listen(const char* host, int port, int socket_flags) auto ret = true; for (;;) { - auto val = detail::select(svr_sock_, 0, 100000); + auto val = detail::select_read(svr_sock_, 0, 100000); if (val == 0) { // Timeout if (svr_sock_ == -1) { @@ -1480,6 +1512,20 @@ inline bool Server::handle_file_request(Request& req, Response& res) return false; } +inline socket_t Server::create_server_socket(const char* host, int port, int socket_flags) const +{ + return detail::create_socket(host, port, + [](socket_t sock, struct addrinfo& ai) -> bool { + if (::bind(sock, ai.ai_addr, ai.ai_addrlen)) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }, socket_flags); +} + inline bool Server::routing(Request& req, Response& res) { if (req.method == "GET" && handle_file_request(req, res)) { @@ -1590,9 +1636,11 @@ inline bool Server::read_and_close_socket(socket_t sock) } // HTTP client implementation -inline Client::Client(const char* host, int port, HttpVersion http_version) +inline Client::Client( + const char* host, int port, size_t timeout_sec, HttpVersion http_version) : host_(host) , port_(port) + , timeout_sec_(timeout_sec) , http_version_(http_version) , host_and_port_(host_ + ":" + std::to_string(port_)) { @@ -1607,6 +1655,23 @@ inline bool Client::is_valid() const return true; } +inline socket_t Client::create_client_socket() const +{ + return detail::create_socket(host_.c_str(), port_, + [=](socket_t sock, struct addrinfo& ai) -> bool { + detail::set_nonblocking(sock, true); + + auto ret = connect(sock, ai.ai_addr, ai.ai_addrlen); + if (ret == -1 && detail::is_connection_error()) { + return false; + } + + detail::set_nonblocking(sock, false); + + return detail::is_socket_writable(sock, timeout_sec_, 0); + }); +} + inline bool Client::read_response_line(Stream& strm, Response& res) { const auto bufsiz = 2048; @@ -1634,7 +1699,7 @@ inline bool Client::send(Request& req, Response& res) return false; } - auto sock = detail::create_client_socket(host_.c_str(), port_); + auto sock = create_client_socket(); if (sock == -1) { return false; } @@ -1826,7 +1891,7 @@ inline bool read_and_close_socket_ssl( if (keep_alive) { auto count = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; while (count > 0 && - detail::select(sock, + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { auto last_connection = count == 1; @@ -1936,8 +2001,9 @@ inline bool SSLServer::read_and_close_socket(socket_t sock) } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char* host, int port, HttpVersion http_version) - : Client(host, port, http_version) +inline SSLClient::SSLClient( + const char* host, int port, size_t timeout_sec, HttpVersion http_version) + : Client(host, port, timeout_sec, http_version) { ctx_ = SSL_CTX_new(SSLv23_client_method()); } diff --git a/test/test.cc b/test/test.cc index 1edb8af..48e776b 100644 --- a/test/test.cc +++ b/test/test.cc @@ -63,24 +63,6 @@ TEST(ParseQueryTest, ParseQueryString) EXPECT_EQ("val3", dic.find("key3")->second); } -TEST(SocketTest, OpenClose) -{ - socket_t sock = detail::create_server_socket(HOST, PORT, 0); - ASSERT_NE(-1, sock); - - auto ret = detail::close_socket(sock); - EXPECT_EQ(0, ret); -} - -TEST(SocketTest, OpenCloseWithAI_PASSIVE) -{ - socket_t sock = detail::create_server_socket(nullptr, PORT, AI_PASSIVE); - ASSERT_NE(-1, sock); - - auto ret = detail::close_socket(sock); - EXPECT_EQ(0, ret); -} - TEST(GetHeaderValueTest, DefaultValue) { Headers headers = {{"Dummy","Dummy"}}; @@ -139,13 +121,14 @@ TEST(GetHeaderValueTest, Range) void testChunkedEncoding(httplib::HttpVersion ver) { auto host = "www.httpwatch.com"; + auto sec = 5; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT auto port = 443; - httplib::SSLClient cli(host, port, ver); + httplib::SSLClient cli(host, port, sec, ver); #else auto port = 80; - httplib::Client cli(host, port, ver); + httplib::Client cli(host, port, sec, ver); #endif auto res = cli.get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137"); @@ -167,13 +150,15 @@ TEST(ChunkedEncodingTest, FromHTTPWatch) TEST(RangeTest, FromHTTPBin) { auto host = "httpbin.org"; + auto sec = 5; + auto ver = httplib::HttpVersion::v1_1; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT auto port = 443; - httplib::SSLClient cli(host, port, httplib::HttpVersion::v1_1); + httplib::SSLClient cli(host, port, sec, ver); #else auto port = 80; - httplib::Client cli(host, port, httplib::HttpVersion::v1_1); + httplib::Client cli(host, port, sec, ver); #endif { @@ -631,7 +616,7 @@ protected: res.set_content("Hello World!", "text/plain"); }); - t_ = thread([&](){ + t_ = thread([&]() { svr_.listen(nullptr, PORT, AI_PASSIVE); });