diff --git a/arduino/libretuya/api/LibreTuyaAPI.cpp b/arduino/libretuya/api/LibreTuyaAPI.cpp index ab7404c..1be6967 100644 --- a/arduino/libretuya/api/LibreTuyaAPI.cpp +++ b/arduino/libretuya/api/LibreTuyaAPI.cpp @@ -7,3 +7,9 @@ __weak char *strdup(const char *s) { return NULL; return (char *)memcpy(newp, s, len); } + +String ipToString(const IPAddress &ip) { + char szRet[16]; + sprintf(szRet, "%hhu.%hhu.%hhu.%hhu", ip[0], ip[1], ip[2], ip[3]); + return String(szRet); +} diff --git a/arduino/libretuya/api/LibreTuyaAPI.h b/arduino/libretuya/api/LibreTuyaAPI.h index c5f6556..11f6c99 100644 --- a/arduino/libretuya/api/LibreTuyaAPI.h +++ b/arduino/libretuya/api/LibreTuyaAPI.h @@ -1,22 +1,22 @@ #pragma once +// LibreTuya version macros #ifndef LT_VERSION #define LT_VERSION 1.0.0 #endif - #ifndef LT_BOARD #define LT_BOARD unknown #endif - #define STRINGIFY(x) #x #define STRINGIFY_MACRO(x) STRINGIFY(x) #define LT_VERSION_STR STRINGIFY_MACRO(LT_VERSION) #define LT_BOARD_STR STRINGIFY_MACRO(LT_BOARD) +// Includes +#include "LibreTuyaConfig.h" #include -#include "LibreTuyaConfig.h" - +// C includes #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -27,16 +27,23 @@ extern "C" { } // extern "C" #endif +// Functional macros #define LT_BANNER() \ LT_LOG( \ LT_LEVEL_INFO, \ - "main.cpp", \ + __FUNCTION__, \ __LINE__, \ "LibreTuya v" LT_VERSION_STR " on " LT_BOARD_STR ", compiled at " __DATE__ " " __TIME__ \ ) -extern char *strdup(const char *); - -// ArduinCore-API doesn't define these anymore +// ArduinoCore-API doesn't define these anymore #define FPSTR(pstr_pointer) (reinterpret_cast(pstr_pointer)) #define PGM_VOID_P const void * + +// C functions +extern char *strdup(const char *); + +// C++ only functions +#ifdef __cplusplus +String ipToString(const IPAddress &ip); +#endif diff --git a/arduino/libretuya/api/LibreTuyaConfig.h b/arduino/libretuya/api/LibreTuyaConfig.h index e197742..a606ae5 100644 --- a/arduino/libretuya/api/LibreTuyaConfig.h +++ b/arduino/libretuya/api/LibreTuyaConfig.h @@ -21,8 +21,8 @@ #define LT_LOGGER_TIMESTAMP 1 #endif -#ifndef LT_LOGGER_FILE -#define LT_LOGGER_FILE 0 +#ifndef LT_LOGGER_CALLER +#define LT_LOGGER_CALLER 1 #endif #ifndef LT_LOGGER_TASK @@ -62,3 +62,7 @@ #ifndef LT_DEBUG_WIFI_AP #define LT_DEBUG_WIFI_AP 0 #endif + +#ifndef LT_DEBUG_SSL +#define LT_DEBUG_SSL 0 +#endif diff --git a/arduino/libretuya/api/WiFiClientSecure.h b/arduino/libretuya/api/WiFiClientSecure.h index c05d1b9..2391d77 100644 --- a/arduino/libretuya/api/WiFiClientSecure.h +++ b/arduino/libretuya/api/WiFiClientSecure.h @@ -22,42 +22,28 @@ #include -#include "WiFi.h" +#include "WiFiClient.h" -class IWiFiClientSecure : public IWiFiClient { +class IWiFiClientSecure { public: - int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); - int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); - int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey); - int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey); + virtual int + connect(IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey) = 0; + virtual int + connect(const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey) = 0; + virtual int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) = 0; + virtual int connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) = 0; - int lastError(char *buf, const size_t size); - void setInsecure(); // Don't validate the chain, just accept whatever is given. VERY INSECURE! - void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex - void setCACert(const char *rootCA); - void setCertificate(const char *client_ca); - void setPrivateKey(const char *private_key); - bool loadCACert(Stream &stream, size_t size); - bool loadCertificate(Stream &stream, size_t size); - bool loadPrivateKey(Stream &stream, size_t size); - bool verify(const char *fingerprint, const char *domain_name); - void setHandshakeTimeout(unsigned long handshake_timeout); - - WiFiClientSecure &operator=(const WiFiClientSecure &other); - - bool operator==(const bool value) { - return bool() == value; - } - - bool operator!=(const bool value) { - return bool() != value; - } - - bool operator==(const WiFiClientSecure &); - - bool operator!=(const WiFiClientSecure &rhs) { - return !this->operator==(rhs); - }; - - using Print::write; + virtual int lastError(char *buf, const size_t size) = 0; + virtual void setInsecure() = 0; // Don't validate the chain, just accept whatever is given. VERY INSECURE! + virtual void setPreSharedKey(const char *pskIdent, const char *psk) = 0; // psk in hex + virtual void setCACert(const char *rootCA) = 0; + virtual void setCertificate(const char *clientCA) = 0; + virtual void setPrivateKey(const char *privateKey) = 0; + virtual bool loadCACert(Stream &stream, size_t size) = 0; + virtual bool loadCertificate(Stream &stream, size_t size) = 0; + virtual bool loadPrivateKey(Stream &stream, size_t size) = 0; + virtual bool verify(const char *fingerprint, const char *domainName) = 0; + virtual void setHandshakeTimeout(unsigned long handshakeTimeout) = 0; + virtual void setAlpnProtocols(const char **alpnProtocols) = 0; + virtual bool getFingerprintSHA256(uint8_t result[32]) = 0; }; diff --git a/arduino/libretuya/api/lt_logger.c b/arduino/libretuya/api/lt_logger.c index 93f4998..e0a25fa 100644 --- a/arduino/libretuya/api/lt_logger.c +++ b/arduino/libretuya/api/lt_logger.c @@ -40,8 +40,8 @@ const uint8_t colors[] = { unsigned long millis(void); -#if LT_LOGGER_FILE -void lt_log(const uint8_t level, const char *filename, const unsigned short line, const char *format, ...) { +#if LT_LOGGER_CALLER +void lt_log(const uint8_t level, const char *caller, const unsigned short line, const char *format, ...) { #else void lt_log(const uint8_t level, const char *format, ...) { #endif @@ -85,8 +85,8 @@ void lt_log(const uint8_t level, const char *format, ...) { #if LT_LOGGER_COLOR "\e[0m" #endif -#if LT_LOGGER_FILE - "%s:%hu: " +#if LT_LOGGER_CALLER + "%s():%hu: " #endif #if LT_LOGGER_TASK "%s%c " @@ -106,9 +106,9 @@ void lt_log(const uint8_t level, const char *format, ...) { zero // append missing zeroes if printf "%11.3f" prints "0." #endif #endif -#if LT_LOGGER_FILE +#if LT_LOGGER_CALLER , - filename, + caller, line #endif #if LT_LOGGER_TASK diff --git a/arduino/libretuya/api/lt_logger.h b/arduino/libretuya/api/lt_logger.h index 9d5abeb..65a3df3 100644 --- a/arduino/libretuya/api/lt_logger.h +++ b/arduino/libretuya/api/lt_logger.h @@ -3,48 +3,48 @@ #include "LibreTuyaConfig.h" #include -#if LT_LOGGER_FILE -#define LT_LOG(level, file, line, ...) lt_log(level, file, line, __VA_ARGS__) -void lt_log(const uint8_t level, const char *filename, const unsigned short line, const char *format, ...); +#if LT_LOGGER_CALLER +#define LT_LOG(level, caller, line, ...) lt_log(level, caller, line, __VA_ARGS__) +void lt_log(const uint8_t level, const char *caller, const unsigned short line, const char *format, ...); #else -#define LT_LOG(level, file, line, ...) lt_log(level, __VA_ARGS__) +#define LT_LOG(level, caller, line, ...) lt_log(level, __VA_ARGS__) void lt_log(const uint8_t level, const char *format, ...); #endif #if LT_LEVEL_TRACE >= LT_LOGLEVEL -#define LT_T(...) LT_LOG(LT_LEVEL_TRACE, __FILE__, __LINE__, __VA_ARGS__) -#define LT_V(...) LT_LOG(LT_LEVEL_TRACE, __FILE__, __LINE__, __VA_ARGS__) +#define LT_T(...) LT_LOG(LT_LEVEL_TRACE, __FUNCTION__, __LINE__, __VA_ARGS__) +#define LT_V(...) LT_LOG(LT_LEVEL_TRACE, __FUNCTION__, __LINE__, __VA_ARGS__) #else #define LT_T(...) #define LT_V(...) #endif #if LT_LEVEL_DEBUG >= LT_LOGLEVEL -#define LT_D(...) LT_LOG(LT_LEVEL_DEBUG, __FILE__, __LINE__, __VA_ARGS__) +#define LT_D(...) LT_LOG(LT_LEVEL_DEBUG, __FUNCTION__, __LINE__, __VA_ARGS__) #else #define LT_D(...) #endif #if LT_LEVEL_INFO >= LT_LOGLEVEL -#define LT_I(...) LT_LOG(LT_LEVEL_INFO, __FILE__, __LINE__, __VA_ARGS__) +#define LT_I(...) LT_LOG(LT_LEVEL_INFO, __FUNCTION__, __LINE__, __VA_ARGS__) #else #define LT_I(...) #endif #if LT_LEVEL_WARN >= LT_LOGLEVEL -#define LT_W(...) LT_LOG(LT_LEVEL_WARN, __FILE__, __LINE__, __VA_ARGS__) +#define LT_W(...) LT_LOG(LT_LEVEL_WARN, __FUNCTION__, __LINE__, __VA_ARGS__) #else #define LT_W(...) #endif #if LT_LEVEL_ERROR >= LT_LOGLEVEL -#define LT_E(...) LT_LOG(LT_LEVEL_ERROR, __FILE__, __LINE__, __VA_ARGS__) +#define LT_E(...) LT_LOG(LT_LEVEL_ERROR, __FUNCTION__, __LINE__, __VA_ARGS__) #else #define LT_E(...) #endif #if LT_LEVEL_FATAL >= LT_LOGLEVEL -#define LT_F(...) LT_LOG(LT_LEVEL_FATAL, __FILE__, __LINE__, __VA_ARGS__) +#define LT_F(...) LT_LOG(LT_LEVEL_FATAL, __FUNCTION__, __LINE__, __VA_ARGS__) #else #define LT_F(...) #endif @@ -88,6 +88,42 @@ void lt_log(const uint8_t level, const char *format, ...); } \ } while (0) +#define LT_RET(ret) \ + LT_E("ret=%d", ret); \ + return ret; + +#define LT_RET_NZ(ret) \ + if (ret) { \ + LT_E("ret=%d", ret); \ + return ret; \ + } +#define LT_RET_LZ(ret) \ + if (ret < 0) { \ + LT_E("ret=%d", ret); \ + return ret; \ + } +#define LT_RET_LEZ(ret) \ + if (ret <= 0) { \ + LT_E("ret=%d", ret); \ + return ret; \ + } + +#define LT_ERRNO_NZ(ret) \ + if (ret) { \ + LT_E("errno=%d, ret=%d", errno, ret); \ + return ret; \ + } +#define LT_ERRNO_LZ(ret) \ + if (ret < 0) { \ + LT_E("errno=%d, ret=%d", errno, ret); \ + return ret; \ + } +#define LT_ERRNO_LEZ(ret) \ + if (ret <= 0) { \ + LT_E("errno=%d, ret=%d", errno, ret); \ + return ret; \ + } + // WiFi.cpp #define LT_T_WG(...) LT_T_MOD(LT_DEBUG_WIFI, __VA_ARGS__) #define LT_V_WG(...) LT_T_MOD(LT_DEBUG_WIFI, __VA_ARGS__) @@ -112,3 +148,8 @@ void lt_log(const uint8_t level, const char *format, ...); #define LT_T_WAP(...) LT_T_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__) #define LT_V_WAP(...) LT_T_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__) #define LT_D_WAP(...) LT_D_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__) + +// WiFiClientSecure.cpp & implementations +#define LT_T_SSL(...) LT_T_MOD(LT_DEBUG_SSL, __VA_ARGS__) +#define LT_V_SSL(...) LT_T_MOD(LT_DEBUG_SSL, __VA_ARGS__) +#define LT_D_SSL(...) LT_D_MOD(LT_DEBUG_SSL, __VA_ARGS__) diff --git a/arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.cpp b/arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.cpp new file mode 100644 index 0000000..6f0cf29 --- /dev/null +++ b/arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.cpp @@ -0,0 +1,443 @@ +/* Copyright (c) Kuba SzczodrzyƄski 2022-04-30. */ + +#include "MbedTLSClient.h" + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include +#include +#include +#include + +#ifdef __cplusplus +} // extern "C" +#endif + +MbedTLSClient::MbedTLSClient() : WiFiClient() {} + +MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {} + +void MbedTLSClient::stop() { + WiFiClient::stop(); + LT_V_SSL("Closing SSL connection"); + + if (_sslCfg.ca_chain) { + mbedtls_x509_crt_free(&_caCert); + } + if (_sslCfg.key_cert) { + mbedtls_x509_crt_free(&_clientCert); + mbedtls_pk_free(&_clientKey); + } + mbedtls_ssl_free(&_sslCtx); + mbedtls_ssl_config_free(&_sslCfg); +} + +void MbedTLSClient::init() { + // Realtek AmbZ: init platform here to ensure HW crypto is initialized in ssl_init + mbedtls_platform_set_calloc_free(calloc, free); + mbedtls_ssl_init(&_sslCtx); + mbedtls_ssl_config_init(&_sslCfg); +} + +int MbedTLSClient::connect(IPAddress ip, uint16_t port, int32_t timeout) { + return connect(ipToString(ip).c_str(), port, timeout) == 0; +} + +int MbedTLSClient::connect(const char *host, uint16_t port, int32_t timeout) { + if (_pskIdentStr && _pskStr) + return connect(host, port, NULL, NULL, NULL, _pskIdentStr, _pskStr, _alpnProtocols) == 0; + return connect(host, port, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL, _alpnProtocols) == 0; +} + +int MbedTLSClient::connect( + IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey +) { + return connect(ipToString(ip).c_str(), port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0; +} + +int MbedTLSClient::connect( + const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey +) { + return connect(host, port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0; +} + +int MbedTLSClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) { + return connect(ipToString(ip).c_str(), port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0; +} + +int MbedTLSClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) { + return connect(host, port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0; +} + +static int ssl_random(void *data, unsigned char *output, size_t len) { + int *buf = (int *)output; + size_t i; + for (i = 0; len >= sizeof(int); len -= sizeof(int)) { + buf[i++] = rand(); + } + if (len) { + int rem = rand(); + unsigned char *pRem = (unsigned char *)&rem; + memcpy(output + i * sizeof(int), pRem, len); + } + return 0; +} + +void debug_cb(void *ctx, int level, const char *file, int line, const char *str) { + LT_I("%04d: |%d| %s", line, level, str); +} + +int MbedTLSClient::connect( + const char *host, + uint16_t port, + const char *rootCABuf, + const char *clientCert, + const char *clientKey, + const char *pskIdent, + const char *psk, + const char **alpnProtocols +) { + LT_D_SSL("Free heap before TLS: TODO"); + + if (!rootCABuf && !pskIdent && !psk && !_insecure && !_useRootCA) + return -1; + + IPAddress addr = WiFi.hostByName(host); + if (!(uint32_t)addr) + return -1; + + int ret = WiFiClient::connect(addr, port, _timeout); + if (ret < 0) { + LT_E("SSL socket failed"); + return ret; + } + + char *uid = "lt-ssl"; // TODO + + LT_V_SSL("Init SSL"); + init(); + + // mbedtls_debug_set_threshold(4); + // mbedtls_ssl_conf_dbg(&_sslCfg, debug_cb, NULL); + + ret = mbedtls_ssl_config_defaults( + &_sslCfg, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT + ); + LT_RET_NZ(ret); + +#ifdef MBEDTLS_SSL_ALPN + if (alpnProtocols) { + ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, alpnProtocols); + LT_RET_NZ(ret); + } +#endif + + if (_insecure) { + mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_NONE); + } else if (rootCABuf) { + mbedtls_x509_crt_init(&_caCert); + mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_REQUIRED); + ret = mbedtls_x509_crt_parse(&_caCert, (const unsigned char *)rootCABuf, strlen(rootCABuf) + 1); + mbedtls_ssl_conf_ca_chain(&_sslCfg, &_caCert, NULL); + if (ret < 0) { + mbedtls_x509_crt_free(&_caCert); + LT_RET(ret); + } + } else if (_useRootCA) { + return -1; // not implemented + } else if (pskIdent && psk) { +#ifdef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED + uint16_t len = strlen(psk); + if ((len & 1) != 0 || len > 2 * MBEDTLS_PSK_MAX_LEN) { + LT_E("PSK length invalid"); + return -1; + } + unsigned char pskBin[MBEDTLS_PSK_MAX_LEN] = {}; + for (uint8_t i = 0; i < len; i++) { + uint8_t c = psk[i]; + c |= 0b00100000; // make lowercase + c -= '0' * (c >= '0' && c <= '9'); + c -= ('a' - 10) * (c >= 'a' && c <= 'z'); + if (c > 0xf) + return -1; + pskBin[i / 2] |= c << (4 * ((i & 1) ^ 1)); + } + ret = mbedtls_ssl_conf_psk(&_sslCfg, pskBin, len / 2, (const unsigned char *)pskIdent, strlen(pskIdent)); + LT_RET_NZ(ret); +#else + return -1; +#endif + } else { + return -1; + } + + if (!_insecure && clientCert && clientKey) { + mbedtls_x509_crt_init(&_clientCert); + mbedtls_pk_init(&_clientKey); + LT_V_SSL("Loading client cert"); + ret = mbedtls_x509_crt_parse(&_clientCert, (const unsigned char *)clientCert, strlen(clientCert) + 1); + if (ret < 0) { + mbedtls_x509_crt_free(&_clientCert); + LT_RET(ret); + } + LT_V_SSL("Loading private key"); + ret = mbedtls_pk_parse_key(&_clientKey, (const unsigned char *)clientKey, strlen(clientKey) + 1, NULL, 0); + if (ret < 0) { + mbedtls_x509_crt_free(&_clientCert); + LT_RET(ret); + } + mbedtls_ssl_conf_own_cert(&_sslCfg, &_clientCert, &_clientKey); + } + + LT_V_SSL("Setting TLS hostname"); + ret = mbedtls_ssl_set_hostname(&_sslCtx, host); + LT_RET_NZ(ret); + + mbedtls_ssl_conf_rng(&_sslCfg, ssl_random, NULL); + ret = mbedtls_ssl_setup(&_sslCtx, &_sslCfg); + LT_RET_NZ(ret); + + _sockTls = fd(); + mbedtls_ssl_set_bio(&_sslCtx, &_sockTls, mbedtls_net_send, mbedtls_net_recv, NULL); + + LT_V_SSL("SSL handshake"); + if (_handshakeTimeout == 0) + _handshakeTimeout = _timeout * 1000; + unsigned long start = millis(); + while (ret = mbedtls_ssl_handshake(&_sslCtx)) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + LT_RET(ret); + } + if ((millis() - start) > _handshakeTimeout) { + LT_E("SSL handshake timeout"); + return -1; + } + delay(2); + } + + if (clientCert && clientKey) { + LT_D_SSL( + "Protocol %s, ciphersuite %s", + mbedtls_ssl_get_version(&_sslCtx), + mbedtls_ssl_get_ciphersuite(&_sslCtx) + ); + ret = mbedtls_ssl_get_record_expansion(&_sslCtx); + if (ret >= 0) + LT_D_SSL("Record expansion: %d", ret); + else { + LT_W("Record expansion unknown"); + } + } + + LT_V_SSL("Verifying certificate"); + ret = mbedtls_ssl_get_verify_result(&_sslCtx); + if (ret) { + char buf[512]; + memset(buf, 0, sizeof(buf)); + mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", ret); + LT_E("Failed to verify peer certificate! Verification info: %s", buf); + return ret; + } + + if (rootCABuf) + mbedtls_x509_crt_free(&_caCert); + if (clientCert) + mbedtls_x509_crt_free(&_clientCert); + if (clientKey != NULL) + mbedtls_pk_free(&_clientKey); + return 0; // OK +} + +size_t MbedTLSClient::write(const uint8_t *buf, size_t size) { + int ret = -1; + while ((ret = mbedtls_ssl_write(&_sslCtx, buf, size)) <= 0) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) { + LT_RET(ret); + } + delay(2); + } + return ret; +} + +int MbedTLSClient::available() { + bool peeked = _peeked >= 0; + if (!connected()) + return peeked; + + int ret = mbedtls_ssl_read(&_sslCtx, NULL, 0); + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) { + stop(); + return peeked ? peeked : ret; + } + return mbedtls_ssl_get_bytes_avail(&_sslCtx) + peeked; +} + +int MbedTLSClient::read(uint8_t *buf, size_t size) { + bool peeked = false; + int toRead = available(); + if ((!buf && size) || toRead <= 0) + return -1; + if (!size) + return 0; + if (_peeked >= 0) { + buf[0] = _peeked; + _peeked = -1; + size--; + toRead--; + if (!size || !toRead) + return 1; + buf++; + peeked = true; + } + + int ret = mbedtls_ssl_read(&_sslCtx, buf, size); + if (ret < 0) { + stop(); + return peeked ? peeked : ret; + } + return ret + peeked; +} + +int MbedTLSClient::peek() { + if (_peeked >= 0) + return _peeked; + _peeked = timedRead(); + return _peeked; +} + +void MbedTLSClient::flush() {} + +int MbedTLSClient::lastError(char *buf, const size_t size) { + return 0; // TODO (?) +} + +void MbedTLSClient::setInsecure() { + _caCertStr = NULL; + _clientCertStr = NULL; + _clientKeyStr = NULL; + _pskIdentStr = NULL; + _pskStr = NULL; + _insecure = true; +} + +void MbedTLSClient::setPreSharedKey(const char *pskIdent, const char *psk) { + _pskIdentStr = pskIdent; + _pskStr = psk; +} + +void MbedTLSClient::setCACert(const char *rootCA) { + _caCertStr = rootCA; +} + +void MbedTLSClient::setCertificate(const char *clientCA) { + _clientCertStr = clientCA; +} + +void MbedTLSClient::setPrivateKey(const char *privateKey) { + _clientKeyStr = privateKey; +} + +char *streamToStr(Stream &stream, size_t size) { + char *buf = (char *)malloc(size + 1); + if (!buf) + return NULL; + if (size != stream.readBytes(buf, size)) { + free(buf); + return NULL; + } + buf[size] = '\0'; + return buf; +} + +bool MbedTLSClient::loadCACert(Stream &stream, size_t size) { + char *str = streamToStr(stream, size); + if (str) { + _caCertStr = str; + return true; + } + return false; +} + +bool MbedTLSClient::loadCertificate(Stream &stream, size_t size) { + char *str = streamToStr(stream, size); + if (str) { + _clientCertStr = str; + return true; + } + return false; +} + +bool MbedTLSClient::loadPrivateKey(Stream &stream, size_t size) { + char *str = streamToStr(stream, size); + if (str) { + _clientKeyStr = str; + return true; + } + return false; +} + +bool MbedTLSClient::verify(const char *fingerprint, const char *domainName) { + uint8_t fpLocal[32] = {}; + uint16_t len = strlen(fingerprint); + uint8_t byte = 0; + for (uint8_t i = 0; i < len; i++) { + uint8_t c = fingerprint[i]; + while ((c == ' ' || c == ':') && i < len) { + c = fingerprint[++i]; + } + c |= 0b00100000; // make lowercase + c -= '0' * (c >= '0' && c <= '9'); + c -= ('a' - 10) * (c >= 'a' && c <= 'z'); + if (c > 0xf) + return -1; + fpLocal[byte / 2] |= c << (4 * ((byte & 1) ^ 1)); + byte++; + if (byte >= 64) + break; + } + + uint8_t fpRemote[32]; + if (!getFingerprintSHA256(fpRemote)) + return false; + + if (memcmp(fpLocal, fpRemote, 32)) { + LT_D_SSL("Fingerprints don't match"); + return false; + } + + if (!domainName) + return true; + // TODO domain name verification + return true; +} + +void MbedTLSClient::setHandshakeTimeout(unsigned long handshakeTimeout) { + _handshakeTimeout = handshakeTimeout * 1000; +} + +void MbedTLSClient::setAlpnProtocols(const char **alpnProtocols) { + _alpnProtocols = alpnProtocols; +} + +bool MbedTLSClient::getFingerprintSHA256(uint8_t result[32]) { + const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&_sslCtx); + if (!cert) { + LT_E("Failed to get peer certificate"); + return false; + } + mbedtls_sha256_context shaCtx; + mbedtls_sha256_init(&shaCtx); + mbedtls_sha256_starts(&shaCtx, false); + mbedtls_sha256_update(&shaCtx, cert->raw.p, cert->raw.len); + mbedtls_sha256_finish(&shaCtx, result); + return true; +} diff --git a/arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.h b/arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.h new file mode 100644 index 0000000..6bcd472 --- /dev/null +++ b/arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.h @@ -0,0 +1,90 @@ +/* Copyright (c) Kuba SzczodrzyƄski 2022-04-30. */ + +#pragma once + +#include +#include + +#include // extend platform's WiFiClient impl + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include + +#ifdef __cplusplus +} // extern "C" +#endif + +class MbedTLSClient : public WiFiClient, public IWiFiClientSecure { + private: + mbedtls_ssl_context _sslCtx; + mbedtls_ssl_config _sslCfg; + mbedtls_x509_crt _caCert; + mbedtls_x509_crt _clientCert; + mbedtls_pk_context _clientKey; + uint32_t _handshakeTimeout = 0; + + void init(); + int _sockTls = -1; + bool _insecure = false; + bool _useRootCA = false; + int _peeked = -1; + + const char *_caCertStr; + const char *_clientCertStr; + const char *_clientKeyStr; + const char *_pskIdentStr; + const char *_pskStr; + const char **_alpnProtocols; + + int connect( + const char *host, + uint16_t port, + const char *rootCABuf, + const char *clientCert, + const char *clientKey, + const char *pskIdent, + const char *psk, + const char **alpnProtocols + ); + + public: + MbedTLSClient(); + MbedTLSClient(int sock); + + int connect(IPAddress ip, uint16_t port, int32_t timeout); + int connect(const char *host, uint16_t port, int32_t timeout); + + int connect(IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey); + int connect(const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey); + int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk); + int connect(const char *host, uint16_t port, const char *pskIdent, const char *psk); + + size_t write(const uint8_t *buf, size_t size); + + int available(); + + int read(uint8_t *buf, size_t size); + int peek(); + void flush(); + void stop(); + + int lastError(char *buf, const size_t size); + void setInsecure(); // Don't validate the chain, just accept whatever is given. VERY INSECURE! + void setPreSharedKey(const char *pskIdent, const char *psk); // psk in hex + void setCACert(const char *rootCA); + void setCertificate(const char *clientCA); + void setPrivateKey(const char *privateKey); + bool loadCACert(Stream &stream, size_t size); + bool loadCertificate(Stream &stream, size_t size); + bool loadPrivateKey(Stream &stream, size_t size); + bool verify(const char *fingerprint, const char *domainName); + void setHandshakeTimeout(unsigned long handshakeTimeout); + void setAlpnProtocols(const char **alpnProtocols); + bool getFingerprintSHA256(uint8_t result[32]); + + using WiFiClient::connect; + using WiFiClient::read; +};