From d0b123be26fbbc45c28a65265acc7d08a8c41785 Mon Sep 17 00:00:00 2001 From: yhirose Date: Thu, 23 Apr 2020 22:10:50 -0400 Subject: [PATCH] Support remote_addr and remote_port REMOTE_PORT header in client Request (#433) --- httplib.h | 93 +++++++++++++++++++++++++++++++--------------------- test/test.cc | 3 ++ 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/httplib.h b/httplib.h index af371d0..723c562 100644 --- a/httplib.h +++ b/httplib.h @@ -262,6 +262,9 @@ struct Request { Headers headers; std::string body; + std::string remote_addr; + int remote_port = -1; + // for server std::string version; std::string target; @@ -352,7 +355,7 @@ public: virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; - virtual std::string get_remote_addr() const = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; template ssize_t write_format(const char *fmt, const Args &... args); @@ -1283,7 +1286,7 @@ public: bool is_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; - std::string get_remote_addr() const override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: socket_t sock_; @@ -1302,7 +1305,7 @@ public: bool is_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; - std::string get_remote_addr() const override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; private: socket_t sock_; @@ -1321,7 +1324,7 @@ public: bool is_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; - std::string get_remote_addr() const override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; const std::string &get_buffer() const; @@ -1554,21 +1557,32 @@ inline socket_t create_client_socket(const char *host, int port, }); } -inline std::string get_remote_addr(socket_t sock) { - struct sockaddr_storage addr; - socklen_t len = sizeof(addr); - - if (!getpeername(sock, reinterpret_cast(&addr), &len)) { - std::array ipstr{}; - - if (!getnameinfo(reinterpret_cast(&addr), len, - ipstr.data(), static_cast(ipstr.size()), - nullptr, 0, NI_NUMERICHOST)) { - return ipstr.data(); - } +inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, + int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); } - return std::string(); + std::array ipstr{}; + if (!getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + ip = ipstr.data(); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { + get_remote_ip_and_port(addr, addr_len, ip, port); + } } inline const char * @@ -2910,11 +2924,11 @@ inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, inline SocketStream::~SocketStream() {} inline bool SocketStream::is_readable() const { - return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } inline bool SocketStream::is_writable() const { - return detail::select_write(sock_, 0, 0) > 0; + return select_write(sock_, 0, 0) > 0; } inline ssize_t SocketStream::read(char *ptr, size_t size) { @@ -2927,8 +2941,9 @@ inline ssize_t SocketStream::write(const char *ptr, size_t size) { return -1; } -inline std::string SocketStream::get_remote_addr() const { - return detail::get_remote_addr(sock_); +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); } // Buffer stream implementation @@ -2951,7 +2966,8 @@ inline ssize_t BufferStream::write(const char *ptr, size_t size) { return static_cast(size); } -inline std::string BufferStream::get_remote_addr() const { return ""; } +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} inline const std::string &BufferStream::get_buffer() const { return buffer; } @@ -3431,17 +3447,16 @@ inline int Server::bind_internal(const char *host, int port, int socket_flags) { if (svr_sock_ == INVALID_SOCKET) { return -1; } if (port == 0) { - struct sockaddr_storage address; - socklen_t len = sizeof(address); - if (getsockname(svr_sock_, reinterpret_cast(&address), - &len) == -1) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { return -1; } - if (address.ss_family == AF_INET) { - return ntohs(reinterpret_cast(&address)->sin_port); - } else if (address.ss_family == AF_INET6) { - return ntohs( - reinterpret_cast(&address)->sin6_port); + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); } else { return -1; } @@ -3646,7 +3661,9 @@ Server::process_request(Stream &strm, bool last_connection, connection_close = true; } - req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); @@ -4527,8 +4544,8 @@ inline bool process_and_close_socket_ssl( auto count = keep_alive_max_count; while (count > 0 && (is_client_request || - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); auto last_connection = count == 1; auto connection_close = false; @@ -4639,8 +4656,9 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { return -1; } -inline std::string SSLSocketStream::get_remote_addr() const { - return detail::get_remote_addr(sock_); +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); } static SSLInit sslinit_; @@ -5020,7 +5038,8 @@ inline std::shared_ptr Get(const char *url, Options &options) { SSLClient cli(next_host.c_str(), next_port, options.client_cert_path, options.client_key_path); cli.set_follow_location(options.follow_location); - cli.set_ca_cert_path(options.ca_cert_file_path.c_str(), options.ca_cert_dir_path.c_str()); + cli.set_ca_cert_path(options.ca_cert_file_path.c_str(), + options.ca_cert_dir_path.c_str()); cli.enable_server_certificate_verification( options.server_certificate_verification); return cli.Get(next_path.c_str()); diff --git a/test/test.cc b/test/test.cc index 15451df..b5f5cf5 100644 --- a/test/test.cc +++ b/test/test.cc @@ -801,6 +801,9 @@ protected: .Get("/remote_addr", [&](const Request &req, Response &res) { auto remote_addr = req.headers.find("REMOTE_ADDR")->second; + EXPECT_TRUE(req.has_header("REMOTE_PORT")); + EXPECT_EQ(req.remote_addr, req.get_header_value("REMOTE_ADDR")); + EXPECT_EQ(req.remote_port, std::stoi(req.get_header_value("REMOTE_PORT"))); res.set_content(remote_addr.c_str(), "text/plain"); }) .Get("/endwith%",