diff --git a/httplib.h b/httplib.h index 2a3742b..45b3152 100644 --- a/httplib.h +++ b/httplib.h @@ -183,6 +183,7 @@ using socket_t = SOCKET; #include #include #include +#include #include using socket_t = int; @@ -2570,6 +2571,30 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, hints.ai_flags = socket_flags; } +#ifndef _WIN32 + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) return INVALID_SOCKET; + + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); + if (sock != INVALID_SOCKET) { + sockaddr_un addr; + addr.sun_family = AF_UNIX; + std::copy(host.begin(), host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast( + sizeof(addr) - sizeof(addr.sun_path) + addrlen); + + if (!bind_or_connect(sock, hints)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + auto service = std::to_string(port); if (getaddrinfo(node, service.c_str(), &hints, &result)) { @@ -7858,12 +7883,12 @@ inline Client::Client(const std::string &scheme_host_port, if (is_ssl) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - cli_ = detail::make_unique(host.c_str(), port, + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); is_ssl_ = is_ssl; #endif } else { - cli_ = detail::make_unique(host.c_str(), port, + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); } } else { diff --git a/test/test.cc b/test/test.cc index 204ae43..fc0fa09 100644 --- a/test/test.cc +++ b/test/test.cc @@ -5064,3 +5064,71 @@ TEST(MultipartFormDataTest, WithPreamble) { #endif +#ifndef _WIN32 +class UnixSocketTest : public ::testing::Test { +protected: + void TearDown() override { + std::remove(pathname_.c_str()); + } + + void client_GET(const std::string &addr) { + httplib::Client cli{addr}; + cli.set_address_family(AF_UNIX); + ASSERT_TRUE(cli.is_valid()); + + const auto &result = cli.Get(pattern_); + ASSERT_TRUE(result) << "error: " << result.error(); + + const auto &resp = result.value(); + EXPECT_EQ(resp.status, 200); + EXPECT_EQ(resp.body, content_); + } + + const std::string pathname_ {"./httplib-server.sock"}; + const std::string pattern_ {"/hi"}; + const std::string content_ {"Hello World!"}; +}; + +TEST_F(UnixSocketTest, pathname) { + httplib::Server svr; + svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) { + res.set_content(content_, "text/plain"); + }); + + std::thread t {[&] { + ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(pathname_, 80)); }}; + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ASSERT_TRUE(svr.is_running()); + + client_GET(pathname_); + + svr.stop(); + t.join(); +} + +#ifdef __linux__ +TEST_F(UnixSocketTest, abstract) { + constexpr char svr_path[] {"\x00httplib-server.sock"}; + const std::string abstract_addr {svr_path, sizeof(svr_path) - 1}; + + httplib::Server svr; + svr.Get(pattern_, [&](const httplib::Request &, httplib::Response &res) { + res.set_content(content_, "text/plain"); + }); + + std::thread t {[&] { + ASSERT_TRUE(svr.set_address_family(AF_UNIX).listen(abstract_addr, 80)); }}; + while (!svr.is_running()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ASSERT_TRUE(svr.is_running()); + + client_GET(abstract_addr); + + svr.stop(); + t.join(); +} +#endif +#endif // #ifndef _WIN32