This commit is contained in:
yhirose 2020-07-25 09:37:57 -04:00
parent 15c4106a36
commit 9ca1fa8b18
2 changed files with 111 additions and 76 deletions

171
httplib.h
View file

@ -349,6 +349,8 @@ struct Request {
bool has_header(const char *key) const;
std::string get_header_value(const char *key, size_t id = 0) const;
template <typename T>
T get_header_value(const char *key, size_t id = 0) const;
size_t get_header_value_count(const char *key) const;
void set_header(const char *key, const char *val);
void set_header(const char *key, const std::string &val);
@ -374,6 +376,8 @@ struct Response {
bool has_header(const char *key) const;
std::string get_header_value(const char *key, size_t id = 0) const;
template <typename T>
T get_header_value(const char *key, size_t id = 0) const;
size_t get_header_value_count(const char *key) const;
void set_header(const char *key, const char *val);
void set_header(const char *key, const std::string &val);
@ -1580,6 +1584,74 @@ inline bool is_valid_path(const std::string &path) {
return true;
}
inline std::string encode_url(const std::string &s) {
std::string result;
for (size_t i = 0; s[i]; i++) {
switch (s[i]) {
case ' ': result += "%20"; break;
case '+': result += "%2B"; break;
case '\r': result += "%0D"; break;
case '\n': result += "%0A"; break;
case '\'': result += "%27"; break;
case ',': result += "%2C"; break;
// case ':': result += "%3A"; break; // ok? probably...
case ';': result += "%3B"; break;
default:
auto c = static_cast<uint8_t>(s[i]);
if (c >= 0x80) {
result += '%';
char hex[4];
auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c);
assert(len == 2);
result.append(hex, static_cast<size_t>(len));
} else {
result += s[i];
}
break;
}
}
return result;
}
inline std::string decode_url(const std::string &s,
bool convert_plus_to_space) {
std::string result;
for (size_t i = 0; i < s.size(); i++) {
if (s[i] == '%' && i + 1 < s.size()) {
if (s[i + 1] == 'u') {
int val = 0;
if (from_hex_to_i(s, i + 2, 4, val)) {
// 4 digits Unicode codes
char buff[4];
size_t len = to_utf8(val, buff);
if (len > 0) { result.append(buff, len); }
i += 5; // 'u0000'
} else {
result += s[i];
}
} else {
int val = 0;
if (from_hex_to_i(s, i + 1, 2, val)) {
// 2 digits hex codes
result += static_cast<char>(val);
i += 2; // '00'
} else {
result += s[i];
}
}
} else if (convert_plus_to_space && s[i] == '+') {
result += ' ';
} else {
result += s[i];
}
}
return result;
}
inline void read_file(const std::string &path, std::string &out) {
std::ifstream fs(path, std::ios_base::binary);
fs.seekg(0, std::ios_base::end);
@ -2379,10 +2451,18 @@ inline const char *get_header_value(const Headers &headers, const char *key,
return def;
}
inline uint64_t get_header_value_uint64(const Headers &headers, const char *key,
uint64_t def = 0) {
auto it = headers.find(key);
if (it != headers.end()) {
template <typename T>
inline T get_header_value(const Headers & /*headers*/, const char * /*key*/,
size_t /*id*/ = 0, uint64_t /*def*/ = 0) {}
template <>
inline uint64_t get_header_value<uint64_t>(const Headers &headers,
const char *key, size_t id,
uint64_t def) {
auto rng = headers.equal_range(key);
auto it = rng.first;
std::advance(it, static_cast<ssize_t>(id));
if (it != rng.second) {
return std::strtoull(it->second.data(), nullptr, 10);
}
return def;
@ -2404,7 +2484,8 @@ inline void parse_header(const char *beg, const char *end, Headers &headers) {
while (p < end) {
p++;
}
headers.emplace(std::string(beg, key_end), std::string(val_begin, end));
headers.emplace(std::string(beg, key_end),
decode_url(std::string(val_begin, end), true));
}
}
}
@ -2574,7 +2655,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
} else if (!has_header(x.headers, "Content-Length")) {
ret = read_content_without_length(strm, out);
} else {
auto len = get_header_value_uint64(x.headers, "Content-Length", 0);
auto len = get_header_value<uint64_t>(x.headers, "Content-Length");
if (len > payload_max_length) {
exceed_payload_max_length = true;
skip_content_with_length(strm, len);
@ -2765,74 +2846,6 @@ inline bool redirect(T &cli, const Request &req, Response &res,
return ret;
}
inline std::string encode_url(const std::string &s) {
std::string result;
for (size_t i = 0; s[i]; i++) {
switch (s[i]) {
case ' ': result += "%20"; break;
case '+': result += "%2B"; break;
case '\r': result += "%0D"; break;
case '\n': result += "%0A"; break;
case '\'': result += "%27"; break;
case ',': result += "%2C"; break;
// case ':': result += "%3A"; break; // ok? probably...
case ';': result += "%3B"; break;
default:
auto c = static_cast<uint8_t>(s[i]);
if (c >= 0x80) {
result += '%';
char hex[4];
auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c);
assert(len == 2);
result.append(hex, static_cast<size_t>(len));
} else {
result += s[i];
}
break;
}
}
return result;
}
inline std::string decode_url(const std::string &s,
bool convert_plus_to_space) {
std::string result;
for (size_t i = 0; i < s.size(); i++) {
if (s[i] == '%' && i + 1 < s.size()) {
if (s[i + 1] == 'u') {
int val = 0;
if (from_hex_to_i(s, i + 2, 4, val)) {
// 4 digits Unicode codes
char buff[4];
size_t len = to_utf8(val, buff);
if (len > 0) { result.append(buff, len); }
i += 5; // 'u0000'
} else {
result += s[i];
}
} else {
int val = 0;
if (from_hex_to_i(s, i + 1, 2, val)) {
// 2 digits hex codes
result += static_cast<char>(val);
i += 2; // '00'
} else {
result += s[i];
}
}
} else if (convert_plus_to_space && s[i] == '+') {
result += ' ';
} else {
result += s[i];
}
}
return result;
}
inline std::string params_to_query_str(const Params &params) {
std::string query;
@ -3458,6 +3471,11 @@ inline std::string Request::get_header_value(const char *key, size_t id) const {
return detail::get_header_value(headers, key, id, "");
}
template <typename T>
inline T Request::get_header_value(const char *key, size_t id) const {
return detail::get_header_value<T>(headers, key, id, 0);
}
inline size_t Request::get_header_value_count(const char *key) const {
auto r = headers.equal_range(key);
return static_cast<size_t>(std::distance(r.first, r.second));
@ -3517,6 +3535,11 @@ inline std::string Response::get_header_value(const char *key,
return detail::get_header_value(headers, key, id, "");
}
template <typename T>
inline T Response::get_header_value(const char *key, size_t id) const {
return detail::get_header_value<T>(headers, key, id, 0);
}
inline size_t Response::get_header_value_count(const char *key) const {
auto r = headers.equal_range(key);
return static_cast<size_t>(std::distance(r.first, r.second));

View file

@ -100,7 +100,8 @@ TEST(GetHeaderValueTest, DefaultValue) {
TEST(GetHeaderValueTest, DefaultValueInt) {
Headers headers = {{"Dummy", "Dummy"}};
auto val = detail::get_header_value_uint64(headers, "Content-Length", 100);
auto val =
detail::get_header_value<uint64_t>(headers, "Content-Length", 0, 100);
EXPECT_EQ(100ull, val);
}
@ -112,7 +113,8 @@ TEST(GetHeaderValueTest, RegularValue) {
TEST(GetHeaderValueTest, RegularValueInt) {
Headers headers = {{"Content-Length", "100"}, {"Dummy", "Dummy"}};
auto val = detail::get_header_value_uint64(headers, "Content-Length", 0);
auto val =
detail::get_header_value<uint64_t>(headers, "Content-Length", 0, 0);
EXPECT_EQ(100ull, val);
}
@ -716,6 +718,16 @@ TEST(RedirectToDifferentPort, Redirect) {
ASSERT_FALSE(svr8080.is_running());
ASSERT_FALSE(svr8081.is_running());
}
TEST(UrlWithSpace, Redirect) {
httplib::SSLClient cli("edge.forgecdn.net");
cli.set_follow_location(true);
auto res = cli.Get("/files/2595/310/Neat 1.4-17.jar");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ(18527, res->get_header_value<uint64_t>("Content-Length"));
}
#endif
TEST(Server, BindDualStack) {