Don't allow invalid status code format (It sould be a three-digit code.)

This commit is contained in:
yhirose 2020-12-15 19:06:52 -05:00
parent a6edfc730a
commit 7c1c952f5a
2 changed files with 34 additions and 8 deletions

View file

@ -1025,7 +1025,7 @@ protected:
private: private:
socket_t create_client_socket(Error &error) const; socket_t create_client_socket(Error &error) const;
bool read_response_line(Stream &strm, Response &res); bool read_response_line(Stream &strm, const Request &req, Response &res);
bool write_request(Stream &strm, const Request &req, bool close_connection, bool write_request(Stream &strm, const Request &req, bool close_connection,
Error &error); Error &error);
bool redirect(const Request &req, Response &res, Error &error); bool redirect(const Request &req, Response &res, Error &error);
@ -4947,17 +4947,20 @@ inline void ClientImpl::lock_socket_and_shutdown_and_close() {
close_socket(socket_); close_socket(socket_);
} }
inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { inline bool ClientImpl::read_response_line(Stream &strm, const Request &req,
Response &res) {
std::array<char, 2048> buf; std::array<char, 2048> buf;
detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); detail::stream_line_reader line_reader(strm, buf.data(), buf.size());
if (!line_reader.getline()) { return false; } if (!line_reader.getline()) { return false; }
const static std::regex re("(HTTP/1\\.[01]) (\\d+) (.*?)\r\n"); const static std::regex re("(HTTP/1\\.[01]) (\\d{3}) (.*?)\r\n");
std::cmatch m; std::cmatch m;
if (!std::regex_match(line_reader.ptr(), m, re)) { return true; } if (!std::regex_match(line_reader.ptr(), m, re)) {
return req.method == "CONNECT";
}
res.version = std::string(m[1]); res.version = std::string(m[1]);
res.status = std::stoi(std::string(m[2])); res.status = std::stoi(std::string(m[2]));
res.reason = std::string(m[3]); res.reason = std::string(m[3]);
@ -5404,7 +5407,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
if (!write_request(strm, req, close_connection, error)) { return false; } if (!write_request(strm, req, close_connection, error)) { return false; }
// Receive response and headers // Receive response and headers
if (!read_response_line(strm, res) || if (!read_response_line(strm, req, res) ||
!detail::read_headers(strm, res.headers)) { !detail::read_headers(strm, res.headers)) {
error = Error::Read; error = Error::Read;
return false; return false;
@ -5448,9 +5451,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(), if (!detail::read_content(strm, res, (std::numeric_limits<size_t>::max)(),
dummy_status, std::move(progress), std::move(out), dummy_status, std::move(progress), std::move(out),
decompress_)) { decompress_)) {
if (error != Error::Canceled) { if (error != Error::Canceled) { error = Error::Read; }
error = Error::Read;
}
return false; return false;
} }
} }

View file

@ -930,6 +930,31 @@ TEST(ErrorHandlerTest, ContentLength) {
ASSERT_FALSE(svr.is_running()); ASSERT_FALSE(svr.is_running());
} }
TEST(InvalidFormatTest, StatusCode) {
Server svr;
svr.Get("/hi", [](const Request & /*req*/, Response &res) {
res.set_content("Hello World!\n", "text/plain");
res.status = 9999; // Status should be a three-digit code...
});
auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
// Give GET time to get a few messages.
std::this_thread::sleep_for(std::chrono::seconds(1));
{
Client cli(HOST, PORT);
auto res = cli.Get("/hi");
ASSERT_FALSE(res);
}
svr.stop();
thread.join();
ASSERT_FALSE(svr.is_running());
}
class ServerTest : public ::testing::Test { class ServerTest : public ::testing::Test {
protected: protected:
ServerTest() ServerTest()