Write error handling

This commit is contained in:
yhirose 2019-07-21 21:34:49 -04:00
parent 4c18ac2b18
commit 7267b3f3e2

View file

@ -197,7 +197,7 @@ public:
virtual std::string get_remote_addr() const = 0;
template <typename... Args>
void write_format(const char *fmt, const Args &... args);
int write_format(const char *fmt, const Args &... args);
};
class SocketStream : public Stream {
@ -286,7 +286,7 @@ private:
bool dispatch_request(Request &req, Response &res, Handlers &handlers);
bool parse_request_line(const char *s, Request &req);
void write_response(Stream &strm, bool last_connection, const Request &req,
bool write_response(Stream &strm, bool last_connection, const Request &req,
Response &res);
virtual bool read_and_close_socket(socket_t sock);
@ -1228,18 +1228,29 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status,
return ret;
}
template <typename T> inline void write_headers(Stream &strm, const T &info) {
template <typename T> inline int write_headers(Stream &strm, const T &info) {
auto write_len = 0;
for (const auto &x : info.headers) {
strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
if (len < 0) {
return len;
}
strm.write("\r\n");
write_len += len;
}
auto len = strm.write("\r\n");
if (len < 0) {
return len;
}
write_len += len;
return write_len;
}
template <typename T>
inline void write_content_chunked(Stream &strm, const T &x) {
inline int write_content_chunked(Stream &strm, const T &x) {
auto chunked_response = !x.has_header("Content-Length");
uint64_t offset = 0;
auto data_available = true;
auto write_len = 0;
while (data_available) {
auto chunk = x.content_producer(offset);
offset += chunk.size();
@ -1250,10 +1261,13 @@ inline void write_content_chunked(Stream &strm, const T &x) {
chunk = from_i_to_hex(chunk.size()) + "\r\n" + chunk + "\r\n";
}
if (strm.write(chunk.c_str(), chunk.size()) < 0) {
break; // Stop on error
auto len = strm.write(chunk.c_str(), chunk.size());
if (len < 0) {
return len;
}
write_len += len;
}
return write_len;
}
inline std::string encode_url(const std::string &s) {
@ -1560,7 +1574,7 @@ inline void Response::set_content(const std::string &s,
// Rstream implementation
template <typename... Args>
inline void Stream::write_format(const char *fmt, const Args &... args) {
inline int Stream::write_format(const char *fmt, const Args &... args) {
const auto bufsiz = 2048;
char buf[bufsiz];
@ -1569,7 +1583,10 @@ inline void Stream::write_format(const char *fmt, const Args &... args) {
#else
auto n = snprintf(buf, bufsiz - 1, fmt, args...);
#endif
if (n > 0) {
if (n <= 0) {
return n;
}
if (n >= bufsiz - 1) {
std::vector<char> glowable_buf(bufsiz);
@ -1582,10 +1599,9 @@ inline void Stream::write_format(const char *fmt, const Args &... args) {
n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...);
#endif
}
write(&glowable_buf[0], n);
return write(&glowable_buf[0], n);
} else {
write(buf, n);
}
return write(buf, n);
}
}
@ -1745,15 +1761,17 @@ inline bool Server::parse_request_line(const char *s, Request &req) {
return false;
}
inline void Server::write_response(Stream &strm, bool last_connection,
inline bool Server::write_response(Stream &strm, bool last_connection,
const Request &req, Response &res) {
assert(res.status != -1);
if (400 <= res.status && error_handler_) { error_handler_(req, res); }
// Response line
strm.write_format("HTTP/1.1 %d %s\r\n", res.status,
detail::status_message(res.status));
if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status,
detail::status_message(res.status))) {
return false;
}
// Headers
if (last_connection || req.get_header_value("Connection") == "close") {
@ -1793,19 +1811,27 @@ inline void Server::write_response(Stream &strm, bool last_connection,
res.set_header("Content-Length", length.c_str());
}
detail::write_headers(strm, res);
if (!detail::write_headers(strm, res)) {
return false;
}
// Body
if (req.method != "HEAD") {
if (!res.body.empty()) {
strm.write(res.body.c_str(), res.body.size());
if (!strm.write(res.body.c_str(), res.body.size())) {
return false;
}
} else if (res.content_producer) {
detail::write_content_chunked(strm, res);
if (!detail::write_content_chunked(strm, res)) {
return false;
}
}
}
// Log
if (logger_) { logger_(req, res); }
return true;
}
inline bool Server::handle_file_request(Request &req, Response &res) {
@ -1978,16 +2004,14 @@ Server::process_request(Stream &strm, bool last_connection,
// Check if the request URI doesn't exceed the limit
if (reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
res.status = 414;
write_response(strm, last_connection, req, res);
return true;
return write_response(strm, last_connection, req, res);
}
// Request line and headers
if (!parse_request_line(reader.ptr(), req) ||
!detail::read_headers(strm, req.headers)) {
res.status = 400;
write_response(strm, last_connection, req, res);
return true;
return write_response(strm, last_connection, req, res);
}
if (req.get_header_value("Connection") == "close") {
@ -2001,8 +2025,7 @@ Server::process_request(Stream &strm, bool last_connection,
if (!detail::read_content(
strm, req, payload_max_length_, res.status, Progress(),
[&](const char *buf, size_t n) { req.body.append(buf, n); })) {
write_response(strm, last_connection, req, res);
return true;
return write_response(strm, last_connection, req, res);
}
const auto &content_type = req.get_header_value("Content-Type");
@ -2014,8 +2037,7 @@ Server::process_request(Stream &strm, bool last_connection,
if (!detail::parse_multipart_boundary(content_type, boundary) ||
!detail::parse_multipart_formdata(boundary, req.body, req.files)) {
res.status = 400;
write_response(strm, last_connection, req, res);
return true;
return write_response(strm, last_connection, req, res);
}
}
}
@ -2029,8 +2051,7 @@ Server::process_request(Stream &strm, bool last_connection,
res.status = 404;
}
write_response(strm, last_connection, req, res);
return true;
return write_response(strm, last_connection, req, res);
}
inline bool Server::is_valid() const { return true; }