diff --git a/httplib.h b/httplib.h index 59acc84..2b32b0d 100644 --- a/httplib.h +++ b/httplib.h @@ -114,6 +114,7 @@ using socket_t = SOCKET; #include #include +#include #include #include #ifdef CPPHTTPLIB_USE_POLL @@ -743,6 +744,8 @@ public: void set_compress(bool on); + void set_interface(const char *intf); + protected: bool process_request(Stream &strm, const Request &req, Response &res, bool last_connection, bool &connection_close); @@ -758,6 +761,7 @@ protected: std::string username_; std::string password_; bool compress_; + std::string interface_; private: socket_t create_client_socket() const; @@ -1348,10 +1352,62 @@ inline bool is_connection_error() { #endif } +inline bool bind_ip_address(socket_t sock, const char *host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host, "0", &hints, &result)) { return false; } + + bool ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +inline std::string if2ip(const std::string &ifn) { +#ifndef _WIN32 + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); +#endif + return std::string(); +} + inline socket_t create_client_socket(const char *host, int port, - time_t timeout_sec) { + time_t timeout_sec, + const std::string &intf) { return create_socket( - host, port, [=](socket_t sock, struct addrinfo &ai) -> bool { + host, port, [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { + auto ip = if2ip(intf); + if (ip.empty()) { ip = intf; } + if (!bind_ip_address(sock, ip.c_str())) { return false; } + } + set_nonblocking(sock, true); auto ret = ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); @@ -3312,7 +3368,8 @@ inline Client::~Client() {} inline bool Client::is_valid() const { return true; } inline socket_t Client::create_client_socket() const { - return detail::create_client_socket(host_.c_str(), port_, timeout_sec_); + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, + interface_); } inline bool Client::read_response_line(Stream &strm, Response &res) { @@ -3942,6 +3999,10 @@ inline void Client::set_follow_location(bool on) { follow_location_ = on; } inline void Client::set_compress(bool on) { compress_ = on; } +inline void Client::set_interface(const char *intf) { + interface_ = intf; +} + /* * SSL Implementation */ diff --git a/test/test.cc b/test/test.cc index 83dfe92..d436646 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1817,8 +1817,8 @@ TEST_F(ServerTest, MultipartFormDataGzip) { // Sends a raw request to a server listening at HOST:PORT. static bool send_request(time_t read_timeout_sec, const std::string &req) { - auto client_sock = - detail::create_client_socket(HOST, PORT, /*timeout_sec=*/5); + auto client_sock = detail::create_client_socket(HOST, PORT, /*timeout_sec=*/5, + std::string()); if (client_sock == INVALID_SOCKET) { return false; }