diff --git a/README.md b/README.md index 5cec670..2c92fa0 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@
-[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/kuba2k2/libretiny/docs.yml?label=docs&logo=markdown)](https://kuba2k2.github.io/libretiny/) +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/kuba2k2/libretiny/push-master.yml?label=docs&logo=markdown)](https://docs.libretiny.eu/) ![GitHub last commit](https://img.shields.io/github/last-commit/kuba2k2/libretiny?logo=github) [![Code style: clang-format](https://img.shields.io/badge/code%20style-clang--format-purple.svg)](.clang-format) diff --git a/cores/common/arduino/libraries/common/MD5/MD5PolarSSLImpl.h b/cores/common/arduino/libraries/common/MD5/MD5PolarSSLImpl.h deleted file mode 100644 index 6eb7eaa..0000000 --- a/cores/common/arduino/libraries/common/MD5/MD5PolarSSLImpl.h +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) Kuba SzczodrzyƄski 2022-06-03. */ - -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - unsigned long total[2]; /*!< number of bytes processed */ - unsigned long state[4]; /*!< intermediate digest state */ - unsigned char buffer[64]; /*!< data block being processed */ -} md5_context; - -#define LT_MD5_CTX_T md5_context - -#ifdef __cplusplus -} // extern "C" -#endif diff --git a/cores/common/arduino/libraries/common/Update/Update.cpp b/cores/common/arduino/libraries/common/Update/Update.cpp index 8c1ff70..0ec8782 100644 --- a/cores/common/arduino/libraries/common/Update/Update.cpp +++ b/cores/common/arduino/libraries/common/Update/Update.cpp @@ -61,6 +61,10 @@ bool UpdateClass::begin( lt_ota_begin(this->ctx, size); this->ctx->callback = reinterpret_cast(progressHandler); this->ctx->callback_param = this; + + this->md5Ctx = static_cast(malloc(sizeof(LT_MD5_CTX_T))); + MD5Init(this->md5Ctx); + return true; } @@ -79,6 +83,9 @@ bool UpdateClass::end(bool evenIfRemaining) { // abort if not finished this->errArd = UPDATE_ERROR_ABORT; + this->md5Digest = static_cast(malloc(16)); + MD5Final(this->md5Digest, this->md5Ctx); + this->cleanup(/* clearError= */ evenIfRemaining); return !this->hasError(); } @@ -97,6 +104,10 @@ void UpdateClass::cleanup(bool clearError) { // activating firmware failed this->errArd = UPDATE_ERROR_ACTIVATE; this->errUf2 = UF2_ERR_OK; + } else if (this->md5Digest && this->md5Expected && memcmp(this->md5Digest, this->md5Expected, 16) != 0) { + // MD5 doesn't match + this->errArd = UPDATE_ERROR_MD5; + this->errUf2 = UF2_ERR_OK; } else if (clearError) { // successful finish and activation, clear error codes this->clearError(); @@ -116,6 +127,12 @@ void UpdateClass::cleanup(bool clearError) { free(this->ctx); this->ctx = nullptr; + free(this->md5Ctx); + this->md5Ctx = nullptr; + free(this->md5Digest); + this->md5Digest = nullptr; + free(this->md5Expected); + this->md5Expected = nullptr; } /** @@ -132,6 +149,7 @@ size_t UpdateClass::write(const uint8_t *data, size_t len) { return 0; size_t written = lt_ota_write(ctx, data, len); + MD5Update(this->md5Ctx, data, len); if (written != len) this->cleanup(/* clearError= */ false); return written; @@ -171,6 +189,8 @@ size_t UpdateClass::writeStream(Stream &data) { // read data to fit in the remaining buffer space auto bufSize = this->ctx->buf_pos - this->ctx->buf; auto read = data.readBytes(this->ctx->buf_pos, UF2_BLOCK_SIZE - bufSize); + // update MD5 + MD5Update(this->md5Ctx, this->ctx->buf_pos, read); // increment buffer writing head this->ctx->buf_pos += read; // process the block if complete diff --git a/cores/common/arduino/libraries/common/Update/Update.h b/cores/common/arduino/libraries/common/Update/Update.h index e5277e5..1d7963b 100644 --- a/cores/common/arduino/libraries/common/Update/Update.h +++ b/cores/common/arduino/libraries/common/Update/Update.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -56,6 +57,9 @@ class UpdateClass { UpdateClass &onProgress(THandlerFunction_Progress handler); static bool canRollBack(); static bool rollBack(); + bool setMD5(const char *md5); + String md5String(); + void md5(uint8_t *result); uint16_t getErrorCode() const; bool hasError() const; void clearError(); @@ -71,6 +75,9 @@ class UpdateClass { uf2_err_t errUf2{UF2_ERR_OK}; UpdateError errArd{UPDATE_ERROR_OK}; THandlerFunction_Progress callback{nullptr}; + LT_MD5_CTX_T *md5Ctx{nullptr}; + uint8_t *md5Digest{nullptr}; + uint8_t *md5Expected{nullptr}; public: /** diff --git a/cores/common/arduino/libraries/common/Update/UpdateUtil.cpp b/cores/common/arduino/libraries/common/Update/UpdateUtil.cpp index 4fe09ee..22fa02f 100644 --- a/cores/common/arduino/libraries/common/Update/UpdateUtil.cpp +++ b/cores/common/arduino/libraries/common/Update/UpdateUtil.cpp @@ -71,6 +71,41 @@ bool UpdateClass::rollBack() { return lt_ota_switch(/* revert= */ false); } +/** + * @brief Set the expected MD5 of the firmware (hexadecimal string). + */ +bool UpdateClass::setMD5(const char *md5) { + if (strlen(md5) != 32) + return false; + this->md5Expected = static_cast(malloc(16)); + if (!this->md5Expected) + return false; + lt_xtob(md5, 32, this->md5Expected); + return true; +} + +/** + * @brief Return a hexadecimal string of calculated firmware MD5 sum. + */ +String UpdateClass::md5String() { + if (!this->md5Digest) + return ""; + char out[32 + 1]; + lt_btox(this->md5Digest, 16, out); + return String(out); +} + +/** + * @brief Get calculated MD5 digest of the firmware. + */ +void UpdateClass::md5(uint8_t *result) { + if (!this->md5Digest) { + memset(result, '\0', 16); + return; + } + memcpy(result, this->md5Digest, 16); +} + /** * @brief Get combined error code of the update. */ diff --git a/cores/common/base/api/lt_utils.c b/cores/common/base/api/lt_utils.c index 521a5bd..e4d571e 100644 --- a/cores/common/base/api/lt_utils.c +++ b/cores/common/base/api/lt_utils.c @@ -39,3 +39,36 @@ void hexdump(const uint8_t *buf, size_t len, uint32_t offset, uint8_t width) { pos += lineWidth; } } + +char *lt_btox(const uint8_t *src, int len, char *dest) { + // https://stackoverflow.com/a/53966346 + const char hex[] = "0123456789abcdef"; + len *= 2; + dest[len] = '\0'; + while (--len >= 0) + dest[len] = hex[(src[len >> 1] >> ((1 - (len & 1)) << 2)) & 0xF]; + return dest; +} + +uint8_t *lt_xtob(const char *src, int len, uint8_t *dest) { + // https://gist.github.com/vi/dd3b5569af8a26b97c8e20ae06e804cb + + // mapping of ASCII characters to hex values + // (16-byte swapped to reduce XOR 0x10 operation) + const uint8_t mapping[] = { + 0x00, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x00, // @ABCDEFG + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // HIJKLMNO + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // 01234567 + 0x08, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 89:;<=>? + }; + + int j = 0; + uint8_t idx0; + uint8_t idx1; + for (int i = 0; i < len; i += 2) { + idx0 = ((uint8_t)src[i + 0] & 0x1F); + idx1 = ((uint8_t)src[i + 1] & 0x1F); + dest[j++] = (mapping[idx0] << 4) | (mapping[idx1] << 0); + } + return dest; +} diff --git a/cores/common/base/api/lt_utils.h b/cores/common/base/api/lt_utils.h index be13c95..ab942bb 100644 --- a/cores/common/base/api/lt_utils.h +++ b/cores/common/base/api/lt_utils.h @@ -45,3 +45,23 @@ void hexdump( uint8_t width #endif ); + +/** + * @brief Convert a byte array to hexadecimal string. + * + * @param src source byte array + * @param len source length (bytes) + * @param dest destination string + * @return destination string + */ +char *lt_btox(const uint8_t *src, int len, char *dest); + +/** + * @brief Convert a hexadecimal string to byte array. + * + * @param src source string + * @param len source length (chars) + * @param dest destination byte array + * @return destination byte array + */ +uint8_t *lt_xtob(const char *src, int len, uint8_t *dest);