[core] Add MbedTLSClient
This commit is contained in:
@@ -7,3 +7,9 @@ __weak char *strdup(const char *s) {
|
||||
return NULL;
|
||||
return (char *)memcpy(newp, s, len);
|
||||
}
|
||||
|
||||
String ipToString(const IPAddress &ip) {
|
||||
char szRet[16];
|
||||
sprintf(szRet, "%hhu.%hhu.%hhu.%hhu", ip[0], ip[1], ip[2], ip[3]);
|
||||
return String(szRet);
|
||||
}
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
// LibreTuya version macros
|
||||
#ifndef LT_VERSION
|
||||
#define LT_VERSION 1.0.0
|
||||
#endif
|
||||
|
||||
#ifndef LT_BOARD
|
||||
#define LT_BOARD unknown
|
||||
#endif
|
||||
|
||||
#define STRINGIFY(x) #x
|
||||
#define STRINGIFY_MACRO(x) STRINGIFY(x)
|
||||
#define LT_VERSION_STR STRINGIFY_MACRO(LT_VERSION)
|
||||
#define LT_BOARD_STR STRINGIFY_MACRO(LT_BOARD)
|
||||
|
||||
// Includes
|
||||
#include "LibreTuyaConfig.h"
|
||||
#include <Arduino.h>
|
||||
|
||||
#include "LibreTuyaConfig.h"
|
||||
|
||||
// C includes
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
@@ -27,16 +27,23 @@ extern "C" {
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
// Functional macros
|
||||
#define LT_BANNER() \
|
||||
LT_LOG( \
|
||||
LT_LEVEL_INFO, \
|
||||
"main.cpp", \
|
||||
__FUNCTION__, \
|
||||
__LINE__, \
|
||||
"LibreTuya v" LT_VERSION_STR " on " LT_BOARD_STR ", compiled at " __DATE__ " " __TIME__ \
|
||||
)
|
||||
|
||||
extern char *strdup(const char *);
|
||||
|
||||
// ArduinCore-API doesn't define these anymore
|
||||
// ArduinoCore-API doesn't define these anymore
|
||||
#define FPSTR(pstr_pointer) (reinterpret_cast<const __FlashStringHelper *>(pstr_pointer))
|
||||
#define PGM_VOID_P const void *
|
||||
|
||||
// C functions
|
||||
extern char *strdup(const char *);
|
||||
|
||||
// C++ only functions
|
||||
#ifdef __cplusplus
|
||||
String ipToString(const IPAddress &ip);
|
||||
#endif
|
||||
|
||||
@@ -21,8 +21,8 @@
|
||||
#define LT_LOGGER_TIMESTAMP 1
|
||||
#endif
|
||||
|
||||
#ifndef LT_LOGGER_FILE
|
||||
#define LT_LOGGER_FILE 0
|
||||
#ifndef LT_LOGGER_CALLER
|
||||
#define LT_LOGGER_CALLER 1
|
||||
#endif
|
||||
|
||||
#ifndef LT_LOGGER_TASK
|
||||
@@ -62,3 +62,7 @@
|
||||
#ifndef LT_DEBUG_WIFI_AP
|
||||
#define LT_DEBUG_WIFI_AP 0
|
||||
#endif
|
||||
|
||||
#ifndef LT_DEBUG_SSL
|
||||
#define LT_DEBUG_SSL 0
|
||||
#endif
|
||||
|
||||
@@ -22,42 +22,28 @@
|
||||
|
||||
#include <Arduino.h>
|
||||
|
||||
#include "WiFi.h"
|
||||
#include "WiFiClient.h"
|
||||
|
||||
class IWiFiClientSecure : public IWiFiClient {
|
||||
class IWiFiClientSecure {
|
||||
public:
|
||||
int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
|
||||
int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
|
||||
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey);
|
||||
int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey);
|
||||
virtual int
|
||||
connect(IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey) = 0;
|
||||
virtual int
|
||||
connect(const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey) = 0;
|
||||
virtual int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) = 0;
|
||||
virtual int connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) = 0;
|
||||
|
||||
int lastError(char *buf, const size_t size);
|
||||
void setInsecure(); // Don't validate the chain, just accept whatever is given. VERY INSECURE!
|
||||
void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex
|
||||
void setCACert(const char *rootCA);
|
||||
void setCertificate(const char *client_ca);
|
||||
void setPrivateKey(const char *private_key);
|
||||
bool loadCACert(Stream &stream, size_t size);
|
||||
bool loadCertificate(Stream &stream, size_t size);
|
||||
bool loadPrivateKey(Stream &stream, size_t size);
|
||||
bool verify(const char *fingerprint, const char *domain_name);
|
||||
void setHandshakeTimeout(unsigned long handshake_timeout);
|
||||
|
||||
WiFiClientSecure &operator=(const WiFiClientSecure &other);
|
||||
|
||||
bool operator==(const bool value) {
|
||||
return bool() == value;
|
||||
}
|
||||
|
||||
bool operator!=(const bool value) {
|
||||
return bool() != value;
|
||||
}
|
||||
|
||||
bool operator==(const WiFiClientSecure &);
|
||||
|
||||
bool operator!=(const WiFiClientSecure &rhs) {
|
||||
return !this->operator==(rhs);
|
||||
};
|
||||
|
||||
using Print::write;
|
||||
virtual int lastError(char *buf, const size_t size) = 0;
|
||||
virtual void setInsecure() = 0; // Don't validate the chain, just accept whatever is given. VERY INSECURE!
|
||||
virtual void setPreSharedKey(const char *pskIdent, const char *psk) = 0; // psk in hex
|
||||
virtual void setCACert(const char *rootCA) = 0;
|
||||
virtual void setCertificate(const char *clientCA) = 0;
|
||||
virtual void setPrivateKey(const char *privateKey) = 0;
|
||||
virtual bool loadCACert(Stream &stream, size_t size) = 0;
|
||||
virtual bool loadCertificate(Stream &stream, size_t size) = 0;
|
||||
virtual bool loadPrivateKey(Stream &stream, size_t size) = 0;
|
||||
virtual bool verify(const char *fingerprint, const char *domainName) = 0;
|
||||
virtual void setHandshakeTimeout(unsigned long handshakeTimeout) = 0;
|
||||
virtual void setAlpnProtocols(const char **alpnProtocols) = 0;
|
||||
virtual bool getFingerprintSHA256(uint8_t result[32]) = 0;
|
||||
};
|
||||
|
||||
@@ -40,8 +40,8 @@ const uint8_t colors[] = {
|
||||
|
||||
unsigned long millis(void);
|
||||
|
||||
#if LT_LOGGER_FILE
|
||||
void lt_log(const uint8_t level, const char *filename, const unsigned short line, const char *format, ...) {
|
||||
#if LT_LOGGER_CALLER
|
||||
void lt_log(const uint8_t level, const char *caller, const unsigned short line, const char *format, ...) {
|
||||
#else
|
||||
void lt_log(const uint8_t level, const char *format, ...) {
|
||||
#endif
|
||||
@@ -85,8 +85,8 @@ void lt_log(const uint8_t level, const char *format, ...) {
|
||||
#if LT_LOGGER_COLOR
|
||||
"\e[0m"
|
||||
#endif
|
||||
#if LT_LOGGER_FILE
|
||||
"%s:%hu: "
|
||||
#if LT_LOGGER_CALLER
|
||||
"%s():%hu: "
|
||||
#endif
|
||||
#if LT_LOGGER_TASK
|
||||
"%s%c "
|
||||
@@ -106,9 +106,9 @@ void lt_log(const uint8_t level, const char *format, ...) {
|
||||
zero // append missing zeroes if printf "%11.3f" prints "0."
|
||||
#endif
|
||||
#endif
|
||||
#if LT_LOGGER_FILE
|
||||
#if LT_LOGGER_CALLER
|
||||
,
|
||||
filename,
|
||||
caller,
|
||||
line
|
||||
#endif
|
||||
#if LT_LOGGER_TASK
|
||||
|
||||
@@ -3,48 +3,48 @@
|
||||
#include "LibreTuyaConfig.h"
|
||||
#include <stdint.h>
|
||||
|
||||
#if LT_LOGGER_FILE
|
||||
#define LT_LOG(level, file, line, ...) lt_log(level, file, line, __VA_ARGS__)
|
||||
void lt_log(const uint8_t level, const char *filename, const unsigned short line, const char *format, ...);
|
||||
#if LT_LOGGER_CALLER
|
||||
#define LT_LOG(level, caller, line, ...) lt_log(level, caller, line, __VA_ARGS__)
|
||||
void lt_log(const uint8_t level, const char *caller, const unsigned short line, const char *format, ...);
|
||||
#else
|
||||
#define LT_LOG(level, file, line, ...) lt_log(level, __VA_ARGS__)
|
||||
#define LT_LOG(level, caller, line, ...) lt_log(level, __VA_ARGS__)
|
||||
void lt_log(const uint8_t level, const char *format, ...);
|
||||
#endif
|
||||
|
||||
#if LT_LEVEL_TRACE >= LT_LOGLEVEL
|
||||
#define LT_T(...) LT_LOG(LT_LEVEL_TRACE, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_V(...) LT_LOG(LT_LEVEL_TRACE, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_T(...) LT_LOG(LT_LEVEL_TRACE, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#define LT_V(...) LT_LOG(LT_LEVEL_TRACE, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
#define LT_T(...)
|
||||
#define LT_V(...)
|
||||
#endif
|
||||
|
||||
#if LT_LEVEL_DEBUG >= LT_LOGLEVEL
|
||||
#define LT_D(...) LT_LOG(LT_LEVEL_DEBUG, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_D(...) LT_LOG(LT_LEVEL_DEBUG, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
#define LT_D(...)
|
||||
#endif
|
||||
|
||||
#if LT_LEVEL_INFO >= LT_LOGLEVEL
|
||||
#define LT_I(...) LT_LOG(LT_LEVEL_INFO, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_I(...) LT_LOG(LT_LEVEL_INFO, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
#define LT_I(...)
|
||||
#endif
|
||||
|
||||
#if LT_LEVEL_WARN >= LT_LOGLEVEL
|
||||
#define LT_W(...) LT_LOG(LT_LEVEL_WARN, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_W(...) LT_LOG(LT_LEVEL_WARN, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
#define LT_W(...)
|
||||
#endif
|
||||
|
||||
#if LT_LEVEL_ERROR >= LT_LOGLEVEL
|
||||
#define LT_E(...) LT_LOG(LT_LEVEL_ERROR, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_E(...) LT_LOG(LT_LEVEL_ERROR, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
#define LT_E(...)
|
||||
#endif
|
||||
|
||||
#if LT_LEVEL_FATAL >= LT_LOGLEVEL
|
||||
#define LT_F(...) LT_LOG(LT_LEVEL_FATAL, __FILE__, __LINE__, __VA_ARGS__)
|
||||
#define LT_F(...) LT_LOG(LT_LEVEL_FATAL, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
#else
|
||||
#define LT_F(...)
|
||||
#endif
|
||||
@@ -88,6 +88,42 @@ void lt_log(const uint8_t level, const char *format, ...);
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LT_RET(ret) \
|
||||
LT_E("ret=%d", ret); \
|
||||
return ret;
|
||||
|
||||
#define LT_RET_NZ(ret) \
|
||||
if (ret) { \
|
||||
LT_E("ret=%d", ret); \
|
||||
return ret; \
|
||||
}
|
||||
#define LT_RET_LZ(ret) \
|
||||
if (ret < 0) { \
|
||||
LT_E("ret=%d", ret); \
|
||||
return ret; \
|
||||
}
|
||||
#define LT_RET_LEZ(ret) \
|
||||
if (ret <= 0) { \
|
||||
LT_E("ret=%d", ret); \
|
||||
return ret; \
|
||||
}
|
||||
|
||||
#define LT_ERRNO_NZ(ret) \
|
||||
if (ret) { \
|
||||
LT_E("errno=%d, ret=%d", errno, ret); \
|
||||
return ret; \
|
||||
}
|
||||
#define LT_ERRNO_LZ(ret) \
|
||||
if (ret < 0) { \
|
||||
LT_E("errno=%d, ret=%d", errno, ret); \
|
||||
return ret; \
|
||||
}
|
||||
#define LT_ERRNO_LEZ(ret) \
|
||||
if (ret <= 0) { \
|
||||
LT_E("errno=%d, ret=%d", errno, ret); \
|
||||
return ret; \
|
||||
}
|
||||
|
||||
// WiFi.cpp
|
||||
#define LT_T_WG(...) LT_T_MOD(LT_DEBUG_WIFI, __VA_ARGS__)
|
||||
#define LT_V_WG(...) LT_T_MOD(LT_DEBUG_WIFI, __VA_ARGS__)
|
||||
@@ -112,3 +148,8 @@ void lt_log(const uint8_t level, const char *format, ...);
|
||||
#define LT_T_WAP(...) LT_T_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__)
|
||||
#define LT_V_WAP(...) LT_T_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__)
|
||||
#define LT_D_WAP(...) LT_D_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__)
|
||||
|
||||
// WiFiClientSecure.cpp & implementations
|
||||
#define LT_T_SSL(...) LT_T_MOD(LT_DEBUG_SSL, __VA_ARGS__)
|
||||
#define LT_V_SSL(...) LT_T_MOD(LT_DEBUG_SSL, __VA_ARGS__)
|
||||
#define LT_D_SSL(...) LT_D_MOD(LT_DEBUG_SSL, __VA_ARGS__)
|
||||
|
||||
443
arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.cpp
Normal file
443
arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.cpp
Normal file
@@ -0,0 +1,443 @@
|
||||
/* Copyright (c) Kuba Szczodrzyński 2022-04-30. */
|
||||
|
||||
#include "MbedTLSClient.h"
|
||||
|
||||
#include <IPAddress.h>
|
||||
#include <WiFi.h>
|
||||
#include <WiFiClient.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
#include <mbedtls/debug.h>
|
||||
#include <mbedtls/platform.h>
|
||||
#include <mbedtls/sha256.h>
|
||||
#include <mbedtls/ssl.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
MbedTLSClient::MbedTLSClient() : WiFiClient() {}
|
||||
|
||||
MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {}
|
||||
|
||||
void MbedTLSClient::stop() {
|
||||
WiFiClient::stop();
|
||||
LT_V_SSL("Closing SSL connection");
|
||||
|
||||
if (_sslCfg.ca_chain) {
|
||||
mbedtls_x509_crt_free(&_caCert);
|
||||
}
|
||||
if (_sslCfg.key_cert) {
|
||||
mbedtls_x509_crt_free(&_clientCert);
|
||||
mbedtls_pk_free(&_clientKey);
|
||||
}
|
||||
mbedtls_ssl_free(&_sslCtx);
|
||||
mbedtls_ssl_config_free(&_sslCfg);
|
||||
}
|
||||
|
||||
void MbedTLSClient::init() {
|
||||
// Realtek AmbZ: init platform here to ensure HW crypto is initialized in ssl_init
|
||||
mbedtls_platform_set_calloc_free(calloc, free);
|
||||
mbedtls_ssl_init(&_sslCtx);
|
||||
mbedtls_ssl_config_init(&_sslCfg);
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
|
||||
return connect(ipToString(ip).c_str(), port, timeout) == 0;
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(const char *host, uint16_t port, int32_t timeout) {
|
||||
if (_pskIdentStr && _pskStr)
|
||||
return connect(host, port, NULL, NULL, NULL, _pskIdentStr, _pskStr, _alpnProtocols) == 0;
|
||||
return connect(host, port, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL, _alpnProtocols) == 0;
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(
|
||||
IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
|
||||
) {
|
||||
return connect(ipToString(ip).c_str(), port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(
|
||||
const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
|
||||
) {
|
||||
return connect(host, port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) {
|
||||
return connect(ipToString(ip).c_str(), port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) {
|
||||
return connect(host, port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
|
||||
}
|
||||
|
||||
static int ssl_random(void *data, unsigned char *output, size_t len) {
|
||||
int *buf = (int *)output;
|
||||
size_t i;
|
||||
for (i = 0; len >= sizeof(int); len -= sizeof(int)) {
|
||||
buf[i++] = rand();
|
||||
}
|
||||
if (len) {
|
||||
int rem = rand();
|
||||
unsigned char *pRem = (unsigned char *)&rem;
|
||||
memcpy(output + i * sizeof(int), pRem, len);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void debug_cb(void *ctx, int level, const char *file, int line, const char *str) {
|
||||
LT_I("%04d: |%d| %s", line, level, str);
|
||||
}
|
||||
|
||||
int MbedTLSClient::connect(
|
||||
const char *host,
|
||||
uint16_t port,
|
||||
const char *rootCABuf,
|
||||
const char *clientCert,
|
||||
const char *clientKey,
|
||||
const char *pskIdent,
|
||||
const char *psk,
|
||||
const char **alpnProtocols
|
||||
) {
|
||||
LT_D_SSL("Free heap before TLS: TODO");
|
||||
|
||||
if (!rootCABuf && !pskIdent && !psk && !_insecure && !_useRootCA)
|
||||
return -1;
|
||||
|
||||
IPAddress addr = WiFi.hostByName(host);
|
||||
if (!(uint32_t)addr)
|
||||
return -1;
|
||||
|
||||
int ret = WiFiClient::connect(addr, port, _timeout);
|
||||
if (ret < 0) {
|
||||
LT_E("SSL socket failed");
|
||||
return ret;
|
||||
}
|
||||
|
||||
char *uid = "lt-ssl"; // TODO
|
||||
|
||||
LT_V_SSL("Init SSL");
|
||||
init();
|
||||
|
||||
// mbedtls_debug_set_threshold(4);
|
||||
// mbedtls_ssl_conf_dbg(&_sslCfg, debug_cb, NULL);
|
||||
|
||||
ret = mbedtls_ssl_config_defaults(
|
||||
&_sslCfg,
|
||||
MBEDTLS_SSL_IS_CLIENT,
|
||||
MBEDTLS_SSL_TRANSPORT_STREAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT
|
||||
);
|
||||
LT_RET_NZ(ret);
|
||||
|
||||
#ifdef MBEDTLS_SSL_ALPN
|
||||
if (alpnProtocols) {
|
||||
ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, alpnProtocols);
|
||||
LT_RET_NZ(ret);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (_insecure) {
|
||||
mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_NONE);
|
||||
} else if (rootCABuf) {
|
||||
mbedtls_x509_crt_init(&_caCert);
|
||||
mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_REQUIRED);
|
||||
ret = mbedtls_x509_crt_parse(&_caCert, (const unsigned char *)rootCABuf, strlen(rootCABuf) + 1);
|
||||
mbedtls_ssl_conf_ca_chain(&_sslCfg, &_caCert, NULL);
|
||||
if (ret < 0) {
|
||||
mbedtls_x509_crt_free(&_caCert);
|
||||
LT_RET(ret);
|
||||
}
|
||||
} else if (_useRootCA) {
|
||||
return -1; // not implemented
|
||||
} else if (pskIdent && psk) {
|
||||
#ifdef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED
|
||||
uint16_t len = strlen(psk);
|
||||
if ((len & 1) != 0 || len > 2 * MBEDTLS_PSK_MAX_LEN) {
|
||||
LT_E("PSK length invalid");
|
||||
return -1;
|
||||
}
|
||||
unsigned char pskBin[MBEDTLS_PSK_MAX_LEN] = {};
|
||||
for (uint8_t i = 0; i < len; i++) {
|
||||
uint8_t c = psk[i];
|
||||
c |= 0b00100000; // make lowercase
|
||||
c -= '0' * (c >= '0' && c <= '9');
|
||||
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
|
||||
if (c > 0xf)
|
||||
return -1;
|
||||
pskBin[i / 2] |= c << (4 * ((i & 1) ^ 1));
|
||||
}
|
||||
ret = mbedtls_ssl_conf_psk(&_sslCfg, pskBin, len / 2, (const unsigned char *)pskIdent, strlen(pskIdent));
|
||||
LT_RET_NZ(ret);
|
||||
#else
|
||||
return -1;
|
||||
#endif
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!_insecure && clientCert && clientKey) {
|
||||
mbedtls_x509_crt_init(&_clientCert);
|
||||
mbedtls_pk_init(&_clientKey);
|
||||
LT_V_SSL("Loading client cert");
|
||||
ret = mbedtls_x509_crt_parse(&_clientCert, (const unsigned char *)clientCert, strlen(clientCert) + 1);
|
||||
if (ret < 0) {
|
||||
mbedtls_x509_crt_free(&_clientCert);
|
||||
LT_RET(ret);
|
||||
}
|
||||
LT_V_SSL("Loading private key");
|
||||
ret = mbedtls_pk_parse_key(&_clientKey, (const unsigned char *)clientKey, strlen(clientKey) + 1, NULL, 0);
|
||||
if (ret < 0) {
|
||||
mbedtls_x509_crt_free(&_clientCert);
|
||||
LT_RET(ret);
|
||||
}
|
||||
mbedtls_ssl_conf_own_cert(&_sslCfg, &_clientCert, &_clientKey);
|
||||
}
|
||||
|
||||
LT_V_SSL("Setting TLS hostname");
|
||||
ret = mbedtls_ssl_set_hostname(&_sslCtx, host);
|
||||
LT_RET_NZ(ret);
|
||||
|
||||
mbedtls_ssl_conf_rng(&_sslCfg, ssl_random, NULL);
|
||||
ret = mbedtls_ssl_setup(&_sslCtx, &_sslCfg);
|
||||
LT_RET_NZ(ret);
|
||||
|
||||
_sockTls = fd();
|
||||
mbedtls_ssl_set_bio(&_sslCtx, &_sockTls, mbedtls_net_send, mbedtls_net_recv, NULL);
|
||||
|
||||
LT_V_SSL("SSL handshake");
|
||||
if (_handshakeTimeout == 0)
|
||||
_handshakeTimeout = _timeout * 1000;
|
||||
unsigned long start = millis();
|
||||
while (ret = mbedtls_ssl_handshake(&_sslCtx)) {
|
||||
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||||
LT_RET(ret);
|
||||
}
|
||||
if ((millis() - start) > _handshakeTimeout) {
|
||||
LT_E("SSL handshake timeout");
|
||||
return -1;
|
||||
}
|
||||
delay(2);
|
||||
}
|
||||
|
||||
if (clientCert && clientKey) {
|
||||
LT_D_SSL(
|
||||
"Protocol %s, ciphersuite %s",
|
||||
mbedtls_ssl_get_version(&_sslCtx),
|
||||
mbedtls_ssl_get_ciphersuite(&_sslCtx)
|
||||
);
|
||||
ret = mbedtls_ssl_get_record_expansion(&_sslCtx);
|
||||
if (ret >= 0)
|
||||
LT_D_SSL("Record expansion: %d", ret);
|
||||
else {
|
||||
LT_W("Record expansion unknown");
|
||||
}
|
||||
}
|
||||
|
||||
LT_V_SSL("Verifying certificate");
|
||||
ret = mbedtls_ssl_get_verify_result(&_sslCtx);
|
||||
if (ret) {
|
||||
char buf[512];
|
||||
memset(buf, 0, sizeof(buf));
|
||||
mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", ret);
|
||||
LT_E("Failed to verify peer certificate! Verification info: %s", buf);
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (rootCABuf)
|
||||
mbedtls_x509_crt_free(&_caCert);
|
||||
if (clientCert)
|
||||
mbedtls_x509_crt_free(&_clientCert);
|
||||
if (clientKey != NULL)
|
||||
mbedtls_pk_free(&_clientKey);
|
||||
return 0; // OK
|
||||
}
|
||||
|
||||
size_t MbedTLSClient::write(const uint8_t *buf, size_t size) {
|
||||
int ret = -1;
|
||||
while ((ret = mbedtls_ssl_write(&_sslCtx, buf, size)) <= 0) {
|
||||
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
|
||||
LT_RET(ret);
|
||||
}
|
||||
delay(2);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int MbedTLSClient::available() {
|
||||
bool peeked = _peeked >= 0;
|
||||
if (!connected())
|
||||
return peeked;
|
||||
|
||||
int ret = mbedtls_ssl_read(&_sslCtx, NULL, 0);
|
||||
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
|
||||
stop();
|
||||
return peeked ? peeked : ret;
|
||||
}
|
||||
return mbedtls_ssl_get_bytes_avail(&_sslCtx) + peeked;
|
||||
}
|
||||
|
||||
int MbedTLSClient::read(uint8_t *buf, size_t size) {
|
||||
bool peeked = false;
|
||||
int toRead = available();
|
||||
if ((!buf && size) || toRead <= 0)
|
||||
return -1;
|
||||
if (!size)
|
||||
return 0;
|
||||
if (_peeked >= 0) {
|
||||
buf[0] = _peeked;
|
||||
_peeked = -1;
|
||||
size--;
|
||||
toRead--;
|
||||
if (!size || !toRead)
|
||||
return 1;
|
||||
buf++;
|
||||
peeked = true;
|
||||
}
|
||||
|
||||
int ret = mbedtls_ssl_read(&_sslCtx, buf, size);
|
||||
if (ret < 0) {
|
||||
stop();
|
||||
return peeked ? peeked : ret;
|
||||
}
|
||||
return ret + peeked;
|
||||
}
|
||||
|
||||
int MbedTLSClient::peek() {
|
||||
if (_peeked >= 0)
|
||||
return _peeked;
|
||||
_peeked = timedRead();
|
||||
return _peeked;
|
||||
}
|
||||
|
||||
void MbedTLSClient::flush() {}
|
||||
|
||||
int MbedTLSClient::lastError(char *buf, const size_t size) {
|
||||
return 0; // TODO (?)
|
||||
}
|
||||
|
||||
void MbedTLSClient::setInsecure() {
|
||||
_caCertStr = NULL;
|
||||
_clientCertStr = NULL;
|
||||
_clientKeyStr = NULL;
|
||||
_pskIdentStr = NULL;
|
||||
_pskStr = NULL;
|
||||
_insecure = true;
|
||||
}
|
||||
|
||||
void MbedTLSClient::setPreSharedKey(const char *pskIdent, const char *psk) {
|
||||
_pskIdentStr = pskIdent;
|
||||
_pskStr = psk;
|
||||
}
|
||||
|
||||
void MbedTLSClient::setCACert(const char *rootCA) {
|
||||
_caCertStr = rootCA;
|
||||
}
|
||||
|
||||
void MbedTLSClient::setCertificate(const char *clientCA) {
|
||||
_clientCertStr = clientCA;
|
||||
}
|
||||
|
||||
void MbedTLSClient::setPrivateKey(const char *privateKey) {
|
||||
_clientKeyStr = privateKey;
|
||||
}
|
||||
|
||||
char *streamToStr(Stream &stream, size_t size) {
|
||||
char *buf = (char *)malloc(size + 1);
|
||||
if (!buf)
|
||||
return NULL;
|
||||
if (size != stream.readBytes(buf, size)) {
|
||||
free(buf);
|
||||
return NULL;
|
||||
}
|
||||
buf[size] = '\0';
|
||||
return buf;
|
||||
}
|
||||
|
||||
bool MbedTLSClient::loadCACert(Stream &stream, size_t size) {
|
||||
char *str = streamToStr(stream, size);
|
||||
if (str) {
|
||||
_caCertStr = str;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MbedTLSClient::loadCertificate(Stream &stream, size_t size) {
|
||||
char *str = streamToStr(stream, size);
|
||||
if (str) {
|
||||
_clientCertStr = str;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MbedTLSClient::loadPrivateKey(Stream &stream, size_t size) {
|
||||
char *str = streamToStr(stream, size);
|
||||
if (str) {
|
||||
_clientKeyStr = str;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MbedTLSClient::verify(const char *fingerprint, const char *domainName) {
|
||||
uint8_t fpLocal[32] = {};
|
||||
uint16_t len = strlen(fingerprint);
|
||||
uint8_t byte = 0;
|
||||
for (uint8_t i = 0; i < len; i++) {
|
||||
uint8_t c = fingerprint[i];
|
||||
while ((c == ' ' || c == ':') && i < len) {
|
||||
c = fingerprint[++i];
|
||||
}
|
||||
c |= 0b00100000; // make lowercase
|
||||
c -= '0' * (c >= '0' && c <= '9');
|
||||
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
|
||||
if (c > 0xf)
|
||||
return -1;
|
||||
fpLocal[byte / 2] |= c << (4 * ((byte & 1) ^ 1));
|
||||
byte++;
|
||||
if (byte >= 64)
|
||||
break;
|
||||
}
|
||||
|
||||
uint8_t fpRemote[32];
|
||||
if (!getFingerprintSHA256(fpRemote))
|
||||
return false;
|
||||
|
||||
if (memcmp(fpLocal, fpRemote, 32)) {
|
||||
LT_D_SSL("Fingerprints don't match");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!domainName)
|
||||
return true;
|
||||
// TODO domain name verification
|
||||
return true;
|
||||
}
|
||||
|
||||
void MbedTLSClient::setHandshakeTimeout(unsigned long handshakeTimeout) {
|
||||
_handshakeTimeout = handshakeTimeout * 1000;
|
||||
}
|
||||
|
||||
void MbedTLSClient::setAlpnProtocols(const char **alpnProtocols) {
|
||||
_alpnProtocols = alpnProtocols;
|
||||
}
|
||||
|
||||
bool MbedTLSClient::getFingerprintSHA256(uint8_t result[32]) {
|
||||
const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&_sslCtx);
|
||||
if (!cert) {
|
||||
LT_E("Failed to get peer certificate");
|
||||
return false;
|
||||
}
|
||||
mbedtls_sha256_context shaCtx;
|
||||
mbedtls_sha256_init(&shaCtx);
|
||||
mbedtls_sha256_starts(&shaCtx, false);
|
||||
mbedtls_sha256_update(&shaCtx, cert->raw.p, cert->raw.len);
|
||||
mbedtls_sha256_finish(&shaCtx, result);
|
||||
return true;
|
||||
}
|
||||
90
arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.h
Normal file
90
arduino/libretuya/libraries/NetUtils/ssl/MbedTLSClient.h
Normal file
@@ -0,0 +1,90 @@
|
||||
/* Copyright (c) Kuba Szczodrzyński 2022-04-30. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <api/WiFiClient.h>
|
||||
#include <api/WiFiClientSecure.h>
|
||||
|
||||
#include <WiFiClient.h> // extend platform's WiFiClient impl
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
#include <mbedtls/net.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
class MbedTLSClient : public WiFiClient, public IWiFiClientSecure {
|
||||
private:
|
||||
mbedtls_ssl_context _sslCtx;
|
||||
mbedtls_ssl_config _sslCfg;
|
||||
mbedtls_x509_crt _caCert;
|
||||
mbedtls_x509_crt _clientCert;
|
||||
mbedtls_pk_context _clientKey;
|
||||
uint32_t _handshakeTimeout = 0;
|
||||
|
||||
void init();
|
||||
int _sockTls = -1;
|
||||
bool _insecure = false;
|
||||
bool _useRootCA = false;
|
||||
int _peeked = -1;
|
||||
|
||||
const char *_caCertStr;
|
||||
const char *_clientCertStr;
|
||||
const char *_clientKeyStr;
|
||||
const char *_pskIdentStr;
|
||||
const char *_pskStr;
|
||||
const char **_alpnProtocols;
|
||||
|
||||
int connect(
|
||||
const char *host,
|
||||
uint16_t port,
|
||||
const char *rootCABuf,
|
||||
const char *clientCert,
|
||||
const char *clientKey,
|
||||
const char *pskIdent,
|
||||
const char *psk,
|
||||
const char **alpnProtocols
|
||||
);
|
||||
|
||||
public:
|
||||
MbedTLSClient();
|
||||
MbedTLSClient(int sock);
|
||||
|
||||
int connect(IPAddress ip, uint16_t port, int32_t timeout);
|
||||
int connect(const char *host, uint16_t port, int32_t timeout);
|
||||
|
||||
int connect(IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey);
|
||||
int connect(const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey);
|
||||
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk);
|
||||
int connect(const char *host, uint16_t port, const char *pskIdent, const char *psk);
|
||||
|
||||
size_t write(const uint8_t *buf, size_t size);
|
||||
|
||||
int available();
|
||||
|
||||
int read(uint8_t *buf, size_t size);
|
||||
int peek();
|
||||
void flush();
|
||||
void stop();
|
||||
|
||||
int lastError(char *buf, const size_t size);
|
||||
void setInsecure(); // Don't validate the chain, just accept whatever is given. VERY INSECURE!
|
||||
void setPreSharedKey(const char *pskIdent, const char *psk); // psk in hex
|
||||
void setCACert(const char *rootCA);
|
||||
void setCertificate(const char *clientCA);
|
||||
void setPrivateKey(const char *privateKey);
|
||||
bool loadCACert(Stream &stream, size_t size);
|
||||
bool loadCertificate(Stream &stream, size_t size);
|
||||
bool loadPrivateKey(Stream &stream, size_t size);
|
||||
bool verify(const char *fingerprint, const char *domainName);
|
||||
void setHandshakeTimeout(unsigned long handshakeTimeout);
|
||||
void setAlpnProtocols(const char **alpnProtocols);
|
||||
bool getFingerprintSHA256(uint8_t result[32]);
|
||||
|
||||
using WiFiClient::connect;
|
||||
using WiFiClient::read;
|
||||
};
|
||||
Reference in New Issue
Block a user