diff options
-rw-r--r-- | http/utility.hpp | 138 | ||||
-rw-r--r-- | test/http/utility_test.cpp | 21 |
2 files changed, 126 insertions, 33 deletions
diff --git a/http/utility.hpp b/http/utility.hpp index dc7ea7f1b2..bf708915b6 100644 --- a/http/utility.hpp +++ b/http/utility.hpp @@ -114,62 +114,134 @@ constexpr size_t numArgsFromTag(int tag) return ret; }; -inline std::string base64encode(std::string_view data) +class Base64Encoder { - const std::array<char, 64> key = { + char overflow1 = '\0'; + char overflow2 = '\0'; + uint8_t overflowCount = 0; + + constexpr static std::array<char, 64> key = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; - size_t size = data.size(); - std::string ret; - ret.resize((size + 2) / 3 * 4); - auto it = ret.begin(); - - size_t i = 0; - while (i < size) + // Takes 3 ascii chars, and encodes them as 4 base64 chars + static void encodeTriple(char first, char second, char third, + std::string& output) { size_t keyIndex = 0; - keyIndex = static_cast<size_t>(data[i] & 0xFC) >> 2; - *it++ = key[keyIndex]; + keyIndex = static_cast<size_t>(first & 0xFC) >> 2; + output += key[keyIndex]; - if (i + 1 < size) - { - keyIndex = static_cast<size_t>(data[i] & 0x03) << 4; - keyIndex += static_cast<size_t>(data[i + 1] & 0xF0) >> 4; - *it++ = key[keyIndex]; + keyIndex = static_cast<size_t>(first & 0x03) << 4; + keyIndex += static_cast<size_t>(second & 0xF0) >> 4; + output += key[keyIndex]; - if (i + 2 < size) - { - keyIndex = static_cast<size_t>(data[i + 1] & 0x0F) << 2; - keyIndex += static_cast<size_t>(data[i + 2] & 0xC0) >> 6; - *it++ = key[keyIndex]; + keyIndex = static_cast<size_t>(second & 0x0F) << 2; + keyIndex += static_cast<size_t>(third & 0xC0) >> 6; + output += key[keyIndex]; - keyIndex = static_cast<size_t>(data[i + 2] & 0x3F); - *it++ = key[keyIndex]; + keyIndex = static_cast<size_t>(third & 0x3F); + output += key[keyIndex]; + } + + public: + // Accepts a partial string to encode, and writes the encoded characters to + // the output stream. requires subsequently calling finalize to complete + // stream. + void encode(std::string_view data, std::string& output) + { + // Encode the last round of overflow chars first + if (overflowCount == 2) + { + if (!data.empty()) + { + encodeTriple(overflow1, overflow2, data[0], output); + overflowCount = 0; + data.remove_prefix(1); } - else + } + else if (overflowCount == 1) + { + if (data.size() >= 2) { - keyIndex = static_cast<size_t>(data[i + 1] & 0x0F) << 2; - *it++ = key[keyIndex]; - *it++ = '='; + encodeTriple(overflow1, data[0], data[1], output); + overflowCount = 0; + data.remove_prefix(2); } } + + while (data.size() >= 3) + { + encodeTriple(data[0], data[1], data[2], output); + data.remove_prefix(3); + } + + if (!data.empty() && overflowCount == 0) + { + overflow1 = data[0]; + overflowCount++; + data.remove_prefix(1); + } + + if (!data.empty() && overflowCount == 1) + { + overflow2 = data[0]; + overflowCount++; + data.remove_prefix(1); + } + } + + // Completes a base64 output, by writing any MOD(3) characters to the + // output, as well as any required trailing = + void finalize(std::string& output) + { + if (overflowCount == 0) + { + return; + } + size_t keyIndex = static_cast<size_t>(overflow1 & 0xFC) >> 2; + output += key[keyIndex]; + + keyIndex = static_cast<size_t>(overflow1 & 0x03) << 4; + if (overflowCount == 2) + { + keyIndex += static_cast<size_t>(overflow2 & 0xF0) >> 4; + output += key[keyIndex]; + keyIndex = static_cast<size_t>(overflow2 & 0x0F) << 2; + output += key[keyIndex]; + } else { - keyIndex = static_cast<size_t>(data[i] & 0x03) << 4; - *it++ = key[keyIndex]; - *it++ = '='; - *it++ = '='; + output += key[keyIndex]; + output += '='; } + output += '='; + overflowCount = 0; + } - i += 3; + // Returns the required output buffer in characters for an input of size + // inputSize + static size_t constexpr encodedSize(size_t inputSize) + { + // Base64 encodes 3 character blocks as 4 character blocks + // With a possibility of 2 trailing = characters + return (inputSize + 2) / 3 * 4; } +}; - return ret; +inline std::string base64encode(std::string_view data) +{ + // Encodes a 3 character stream into a 4 character stream + std::string out; + Base64Encoder base64; + out.reserve(Base64Encoder::encodedSize(data.size())); + base64.encode(data, out); + base64.finalize(out); + return out; } // TODO this is temporary and should be deleted once base64 is refactored out of diff --git a/test/http/utility_test.cpp b/test/http/utility_test.cpp index 5f62d3f3d1..c0b6412af2 100644 --- a/test/http/utility_test.cpp +++ b/test/http/utility_test.cpp @@ -71,6 +71,27 @@ TEST(Utility, Base64EncodeString) EXPECT_EQ(encoded, "ZjAAIEJhcg=="); } +TEST(Utility, Base64Encoder) +{ + using namespace std::string_literals; + std::string data = "f0\0 Bar"s; + for (size_t chunkSize = 1; chunkSize < 6; chunkSize++) + { + std::string_view testString(data); + std::string out; + Base64Encoder encoder; + while (!testString.empty()) + { + size_t thisChunk = std::min(testString.size(), chunkSize); + encoder.encode(testString.substr(0, thisChunk), out); + testString.remove_prefix(thisChunk); + } + + encoder.finalize(out); + EXPECT_EQ(out, "ZjAAIEJhcg=="); + } +} + TEST(Utility, Base64EncodeDecodeString) { using namespace std::string_literals; |