[core] Add MbedTLSClient

This commit is contained in:
Kuba Szczodrzyński
2022-05-05 20:56:37 +02:00
parent 783955cc5d
commit 9659ff8afa
8 changed files with 639 additions and 62 deletions

View File

@@ -7,3 +7,9 @@ __weak char *strdup(const char *s) {
return NULL;
return (char *)memcpy(newp, s, len);
}
String ipToString(const IPAddress &ip) {
char szRet[16];
sprintf(szRet, "%hhu.%hhu.%hhu.%hhu", ip[0], ip[1], ip[2], ip[3]);
return String(szRet);
}

View File

@@ -1,22 +1,22 @@
#pragma once
// LibreTuya version macros
#ifndef LT_VERSION
#define LT_VERSION 1.0.0
#endif
#ifndef LT_BOARD
#define LT_BOARD unknown
#endif
#define STRINGIFY(x) #x
#define STRINGIFY_MACRO(x) STRINGIFY(x)
#define LT_VERSION_STR STRINGIFY_MACRO(LT_VERSION)
#define LT_BOARD_STR STRINGIFY_MACRO(LT_BOARD)
// Includes
#include "LibreTuyaConfig.h"
#include <Arduino.h>
#include "LibreTuyaConfig.h"
// C includes
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
@@ -27,16 +27,23 @@ extern "C" {
} // extern "C"
#endif
// Functional macros
#define LT_BANNER() \
LT_LOG( \
LT_LEVEL_INFO, \
"main.cpp", \
__FUNCTION__, \
__LINE__, \
"LibreTuya v" LT_VERSION_STR " on " LT_BOARD_STR ", compiled at " __DATE__ " " __TIME__ \
)
extern char *strdup(const char *);
// ArduinCore-API doesn't define these anymore
// ArduinoCore-API doesn't define these anymore
#define FPSTR(pstr_pointer) (reinterpret_cast<const __FlashStringHelper *>(pstr_pointer))
#define PGM_VOID_P const void *
// C functions
extern char *strdup(const char *);
// C++ only functions
#ifdef __cplusplus
String ipToString(const IPAddress &ip);
#endif

View File

@@ -21,8 +21,8 @@
#define LT_LOGGER_TIMESTAMP 1
#endif
#ifndef LT_LOGGER_FILE
#define LT_LOGGER_FILE 0
#ifndef LT_LOGGER_CALLER
#define LT_LOGGER_CALLER 1
#endif
#ifndef LT_LOGGER_TASK
@@ -62,3 +62,7 @@
#ifndef LT_DEBUG_WIFI_AP
#define LT_DEBUG_WIFI_AP 0
#endif
#ifndef LT_DEBUG_SSL
#define LT_DEBUG_SSL 0
#endif

View File

@@ -22,42 +22,28 @@
#include <Arduino.h>
#include "WiFi.h"
#include "WiFiClient.h"
class IWiFiClientSecure : public IWiFiClient {
class IWiFiClientSecure {
public:
int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey);
int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey);
virtual int
connect(IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey) = 0;
virtual int
connect(const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey) = 0;
virtual int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) = 0;
virtual int connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) = 0;
int lastError(char *buf, const size_t size);
void setInsecure(); // Don't validate the chain, just accept whatever is given. VERY INSECURE!
void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex
void setCACert(const char *rootCA);
void setCertificate(const char *client_ca);
void setPrivateKey(const char *private_key);
bool loadCACert(Stream &stream, size_t size);
bool loadCertificate(Stream &stream, size_t size);
bool loadPrivateKey(Stream &stream, size_t size);
bool verify(const char *fingerprint, const char *domain_name);
void setHandshakeTimeout(unsigned long handshake_timeout);
WiFiClientSecure &operator=(const WiFiClientSecure &other);
bool operator==(const bool value) {
return bool() == value;
}
bool operator!=(const bool value) {
return bool() != value;
}
bool operator==(const WiFiClientSecure &);
bool operator!=(const WiFiClientSecure &rhs) {
return !this->operator==(rhs);
};
using Print::write;
virtual int lastError(char *buf, const size_t size) = 0;
virtual void setInsecure() = 0; // Don't validate the chain, just accept whatever is given. VERY INSECURE!
virtual void setPreSharedKey(const char *pskIdent, const char *psk) = 0; // psk in hex
virtual void setCACert(const char *rootCA) = 0;
virtual void setCertificate(const char *clientCA) = 0;
virtual void setPrivateKey(const char *privateKey) = 0;
virtual bool loadCACert(Stream &stream, size_t size) = 0;
virtual bool loadCertificate(Stream &stream, size_t size) = 0;
virtual bool loadPrivateKey(Stream &stream, size_t size) = 0;
virtual bool verify(const char *fingerprint, const char *domainName) = 0;
virtual void setHandshakeTimeout(unsigned long handshakeTimeout) = 0;
virtual void setAlpnProtocols(const char **alpnProtocols) = 0;
virtual bool getFingerprintSHA256(uint8_t result[32]) = 0;
};

View File

@@ -40,8 +40,8 @@ const uint8_t colors[] = {
unsigned long millis(void);
#if LT_LOGGER_FILE
void lt_log(const uint8_t level, const char *filename, const unsigned short line, const char *format, ...) {
#if LT_LOGGER_CALLER
void lt_log(const uint8_t level, const char *caller, const unsigned short line, const char *format, ...) {
#else
void lt_log(const uint8_t level, const char *format, ...) {
#endif
@@ -85,8 +85,8 @@ void lt_log(const uint8_t level, const char *format, ...) {
#if LT_LOGGER_COLOR
"\e[0m"
#endif
#if LT_LOGGER_FILE
"%s:%hu: "
#if LT_LOGGER_CALLER
"%s():%hu: "
#endif
#if LT_LOGGER_TASK
"%s%c "
@@ -106,9 +106,9 @@ void lt_log(const uint8_t level, const char *format, ...) {
zero // append missing zeroes if printf "%11.3f" prints "0."
#endif
#endif
#if LT_LOGGER_FILE
#if LT_LOGGER_CALLER
,
filename,
caller,
line
#endif
#if LT_LOGGER_TASK

View File

@@ -3,48 +3,48 @@
#include "LibreTuyaConfig.h"
#include <stdint.h>
#if LT_LOGGER_FILE
#define LT_LOG(level, file, line, ...) lt_log(level, file, line, __VA_ARGS__)
void lt_log(const uint8_t level, const char *filename, const unsigned short line, const char *format, ...);
#if LT_LOGGER_CALLER
#define LT_LOG(level, caller, line, ...) lt_log(level, caller, line, __VA_ARGS__)
void lt_log(const uint8_t level, const char *caller, const unsigned short line, const char *format, ...);
#else
#define LT_LOG(level, file, line, ...) lt_log(level, __VA_ARGS__)
#define LT_LOG(level, caller, line, ...) lt_log(level, __VA_ARGS__)
void lt_log(const uint8_t level, const char *format, ...);
#endif
#if LT_LEVEL_TRACE >= LT_LOGLEVEL
#define LT_T(...) LT_LOG(LT_LEVEL_TRACE, __FILE__, __LINE__, __VA_ARGS__)
#define LT_V(...) LT_LOG(LT_LEVEL_TRACE, __FILE__, __LINE__, __VA_ARGS__)
#define LT_T(...) LT_LOG(LT_LEVEL_TRACE, __FUNCTION__, __LINE__, __VA_ARGS__)
#define LT_V(...) LT_LOG(LT_LEVEL_TRACE, __FUNCTION__, __LINE__, __VA_ARGS__)
#else
#define LT_T(...)
#define LT_V(...)
#endif
#if LT_LEVEL_DEBUG >= LT_LOGLEVEL
#define LT_D(...) LT_LOG(LT_LEVEL_DEBUG, __FILE__, __LINE__, __VA_ARGS__)
#define LT_D(...) LT_LOG(LT_LEVEL_DEBUG, __FUNCTION__, __LINE__, __VA_ARGS__)
#else
#define LT_D(...)
#endif
#if LT_LEVEL_INFO >= LT_LOGLEVEL
#define LT_I(...) LT_LOG(LT_LEVEL_INFO, __FILE__, __LINE__, __VA_ARGS__)
#define LT_I(...) LT_LOG(LT_LEVEL_INFO, __FUNCTION__, __LINE__, __VA_ARGS__)
#else
#define LT_I(...)
#endif
#if LT_LEVEL_WARN >= LT_LOGLEVEL
#define LT_W(...) LT_LOG(LT_LEVEL_WARN, __FILE__, __LINE__, __VA_ARGS__)
#define LT_W(...) LT_LOG(LT_LEVEL_WARN, __FUNCTION__, __LINE__, __VA_ARGS__)
#else
#define LT_W(...)
#endif
#if LT_LEVEL_ERROR >= LT_LOGLEVEL
#define LT_E(...) LT_LOG(LT_LEVEL_ERROR, __FILE__, __LINE__, __VA_ARGS__)
#define LT_E(...) LT_LOG(LT_LEVEL_ERROR, __FUNCTION__, __LINE__, __VA_ARGS__)
#else
#define LT_E(...)
#endif
#if LT_LEVEL_FATAL >= LT_LOGLEVEL
#define LT_F(...) LT_LOG(LT_LEVEL_FATAL, __FILE__, __LINE__, __VA_ARGS__)
#define LT_F(...) LT_LOG(LT_LEVEL_FATAL, __FUNCTION__, __LINE__, __VA_ARGS__)
#else
#define LT_F(...)
#endif
@@ -88,6 +88,42 @@ void lt_log(const uint8_t level, const char *format, ...);
} \
} while (0)
#define LT_RET(ret) \
LT_E("ret=%d", ret); \
return ret;
#define LT_RET_NZ(ret) \
if (ret) { \
LT_E("ret=%d", ret); \
return ret; \
}
#define LT_RET_LZ(ret) \
if (ret < 0) { \
LT_E("ret=%d", ret); \
return ret; \
}
#define LT_RET_LEZ(ret) \
if (ret <= 0) { \
LT_E("ret=%d", ret); \
return ret; \
}
#define LT_ERRNO_NZ(ret) \
if (ret) { \
LT_E("errno=%d, ret=%d", errno, ret); \
return ret; \
}
#define LT_ERRNO_LZ(ret) \
if (ret < 0) { \
LT_E("errno=%d, ret=%d", errno, ret); \
return ret; \
}
#define LT_ERRNO_LEZ(ret) \
if (ret <= 0) { \
LT_E("errno=%d, ret=%d", errno, ret); \
return ret; \
}
// WiFi.cpp
#define LT_T_WG(...) LT_T_MOD(LT_DEBUG_WIFI, __VA_ARGS__)
#define LT_V_WG(...) LT_T_MOD(LT_DEBUG_WIFI, __VA_ARGS__)
@@ -112,3 +148,8 @@ void lt_log(const uint8_t level, const char *format, ...);
#define LT_T_WAP(...) LT_T_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__)
#define LT_V_WAP(...) LT_T_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__)
#define LT_D_WAP(...) LT_D_MOD(LT_DEBUG_WIFI_AP, __VA_ARGS__)
// WiFiClientSecure.cpp & implementations
#define LT_T_SSL(...) LT_T_MOD(LT_DEBUG_SSL, __VA_ARGS__)
#define LT_V_SSL(...) LT_T_MOD(LT_DEBUG_SSL, __VA_ARGS__)
#define LT_D_SSL(...) LT_D_MOD(LT_DEBUG_SSL, __VA_ARGS__)

View File

@@ -0,0 +1,443 @@
/* Copyright (c) Kuba Szczodrzyński 2022-04-30. */
#include "MbedTLSClient.h"
#include <IPAddress.h>
#include <WiFi.h>
#include <WiFiClient.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
#include <mbedtls/debug.h>
#include <mbedtls/platform.h>
#include <mbedtls/sha256.h>
#include <mbedtls/ssl.h>
#ifdef __cplusplus
} // extern "C"
#endif
MbedTLSClient::MbedTLSClient() : WiFiClient() {}
MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {}
void MbedTLSClient::stop() {
WiFiClient::stop();
LT_V_SSL("Closing SSL connection");
if (_sslCfg.ca_chain) {
mbedtls_x509_crt_free(&_caCert);
}
if (_sslCfg.key_cert) {
mbedtls_x509_crt_free(&_clientCert);
mbedtls_pk_free(&_clientKey);
}
mbedtls_ssl_free(&_sslCtx);
mbedtls_ssl_config_free(&_sslCfg);
}
void MbedTLSClient::init() {
// Realtek AmbZ: init platform here to ensure HW crypto is initialized in ssl_init
mbedtls_platform_set_calloc_free(calloc, free);
mbedtls_ssl_init(&_sslCtx);
mbedtls_ssl_config_init(&_sslCfg);
}
int MbedTLSClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
return connect(ipToString(ip).c_str(), port, timeout) == 0;
}
int MbedTLSClient::connect(const char *host, uint16_t port, int32_t timeout) {
if (_pskIdentStr && _pskStr)
return connect(host, port, NULL, NULL, NULL, _pskIdentStr, _pskStr, _alpnProtocols) == 0;
return connect(host, port, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL, _alpnProtocols) == 0;
}
int MbedTLSClient::connect(
IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
) {
return connect(ipToString(ip).c_str(), port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
}
int MbedTLSClient::connect(
const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
) {
return connect(host, port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
}
int MbedTLSClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) {
return connect(ipToString(ip).c_str(), port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
}
int MbedTLSClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) {
return connect(host, port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
}
static int ssl_random(void *data, unsigned char *output, size_t len) {
int *buf = (int *)output;
size_t i;
for (i = 0; len >= sizeof(int); len -= sizeof(int)) {
buf[i++] = rand();
}
if (len) {
int rem = rand();
unsigned char *pRem = (unsigned char *)&rem;
memcpy(output + i * sizeof(int), pRem, len);
}
return 0;
}
void debug_cb(void *ctx, int level, const char *file, int line, const char *str) {
LT_I("%04d: |%d| %s", line, level, str);
}
int MbedTLSClient::connect(
const char *host,
uint16_t port,
const char *rootCABuf,
const char *clientCert,
const char *clientKey,
const char *pskIdent,
const char *psk,
const char **alpnProtocols
) {
LT_D_SSL("Free heap before TLS: TODO");
if (!rootCABuf && !pskIdent && !psk && !_insecure && !_useRootCA)
return -1;
IPAddress addr = WiFi.hostByName(host);
if (!(uint32_t)addr)
return -1;
int ret = WiFiClient::connect(addr, port, _timeout);
if (ret < 0) {
LT_E("SSL socket failed");
return ret;
}
char *uid = "lt-ssl"; // TODO
LT_V_SSL("Init SSL");
init();
// mbedtls_debug_set_threshold(4);
// mbedtls_ssl_conf_dbg(&_sslCfg, debug_cb, NULL);
ret = mbedtls_ssl_config_defaults(
&_sslCfg,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT
);
LT_RET_NZ(ret);
#ifdef MBEDTLS_SSL_ALPN
if (alpnProtocols) {
ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, alpnProtocols);
LT_RET_NZ(ret);
}
#endif
if (_insecure) {
mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_NONE);
} else if (rootCABuf) {
mbedtls_x509_crt_init(&_caCert);
mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_REQUIRED);
ret = mbedtls_x509_crt_parse(&_caCert, (const unsigned char *)rootCABuf, strlen(rootCABuf) + 1);
mbedtls_ssl_conf_ca_chain(&_sslCfg, &_caCert, NULL);
if (ret < 0) {
mbedtls_x509_crt_free(&_caCert);
LT_RET(ret);
}
} else if (_useRootCA) {
return -1; // not implemented
} else if (pskIdent && psk) {
#ifdef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED
uint16_t len = strlen(psk);
if ((len & 1) != 0 || len > 2 * MBEDTLS_PSK_MAX_LEN) {
LT_E("PSK length invalid");
return -1;
}
unsigned char pskBin[MBEDTLS_PSK_MAX_LEN] = {};
for (uint8_t i = 0; i < len; i++) {
uint8_t c = psk[i];
c |= 0b00100000; // make lowercase
c -= '0' * (c >= '0' && c <= '9');
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
if (c > 0xf)
return -1;
pskBin[i / 2] |= c << (4 * ((i & 1) ^ 1));
}
ret = mbedtls_ssl_conf_psk(&_sslCfg, pskBin, len / 2, (const unsigned char *)pskIdent, strlen(pskIdent));
LT_RET_NZ(ret);
#else
return -1;
#endif
} else {
return -1;
}
if (!_insecure && clientCert && clientKey) {
mbedtls_x509_crt_init(&_clientCert);
mbedtls_pk_init(&_clientKey);
LT_V_SSL("Loading client cert");
ret = mbedtls_x509_crt_parse(&_clientCert, (const unsigned char *)clientCert, strlen(clientCert) + 1);
if (ret < 0) {
mbedtls_x509_crt_free(&_clientCert);
LT_RET(ret);
}
LT_V_SSL("Loading private key");
ret = mbedtls_pk_parse_key(&_clientKey, (const unsigned char *)clientKey, strlen(clientKey) + 1, NULL, 0);
if (ret < 0) {
mbedtls_x509_crt_free(&_clientCert);
LT_RET(ret);
}
mbedtls_ssl_conf_own_cert(&_sslCfg, &_clientCert, &_clientKey);
}
LT_V_SSL("Setting TLS hostname");
ret = mbedtls_ssl_set_hostname(&_sslCtx, host);
LT_RET_NZ(ret);
mbedtls_ssl_conf_rng(&_sslCfg, ssl_random, NULL);
ret = mbedtls_ssl_setup(&_sslCtx, &_sslCfg);
LT_RET_NZ(ret);
_sockTls = fd();
mbedtls_ssl_set_bio(&_sslCtx, &_sockTls, mbedtls_net_send, mbedtls_net_recv, NULL);
LT_V_SSL("SSL handshake");
if (_handshakeTimeout == 0)
_handshakeTimeout = _timeout * 1000;
unsigned long start = millis();
while (ret = mbedtls_ssl_handshake(&_sslCtx)) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
LT_RET(ret);
}
if ((millis() - start) > _handshakeTimeout) {
LT_E("SSL handshake timeout");
return -1;
}
delay(2);
}
if (clientCert && clientKey) {
LT_D_SSL(
"Protocol %s, ciphersuite %s",
mbedtls_ssl_get_version(&_sslCtx),
mbedtls_ssl_get_ciphersuite(&_sslCtx)
);
ret = mbedtls_ssl_get_record_expansion(&_sslCtx);
if (ret >= 0)
LT_D_SSL("Record expansion: %d", ret);
else {
LT_W("Record expansion unknown");
}
}
LT_V_SSL("Verifying certificate");
ret = mbedtls_ssl_get_verify_result(&_sslCtx);
if (ret) {
char buf[512];
memset(buf, 0, sizeof(buf));
mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", ret);
LT_E("Failed to verify peer certificate! Verification info: %s", buf);
return ret;
}
if (rootCABuf)
mbedtls_x509_crt_free(&_caCert);
if (clientCert)
mbedtls_x509_crt_free(&_clientCert);
if (clientKey != NULL)
mbedtls_pk_free(&_clientKey);
return 0; // OK
}
size_t MbedTLSClient::write(const uint8_t *buf, size_t size) {
int ret = -1;
while ((ret = mbedtls_ssl_write(&_sslCtx, buf, size)) <= 0) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
LT_RET(ret);
}
delay(2);
}
return ret;
}
int MbedTLSClient::available() {
bool peeked = _peeked >= 0;
if (!connected())
return peeked;
int ret = mbedtls_ssl_read(&_sslCtx, NULL, 0);
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
stop();
return peeked ? peeked : ret;
}
return mbedtls_ssl_get_bytes_avail(&_sslCtx) + peeked;
}
int MbedTLSClient::read(uint8_t *buf, size_t size) {
bool peeked = false;
int toRead = available();
if ((!buf && size) || toRead <= 0)
return -1;
if (!size)
return 0;
if (_peeked >= 0) {
buf[0] = _peeked;
_peeked = -1;
size--;
toRead--;
if (!size || !toRead)
return 1;
buf++;
peeked = true;
}
int ret = mbedtls_ssl_read(&_sslCtx, buf, size);
if (ret < 0) {
stop();
return peeked ? peeked : ret;
}
return ret + peeked;
}
int MbedTLSClient::peek() {
if (_peeked >= 0)
return _peeked;
_peeked = timedRead();
return _peeked;
}
void MbedTLSClient::flush() {}
int MbedTLSClient::lastError(char *buf, const size_t size) {
return 0; // TODO (?)
}
void MbedTLSClient::setInsecure() {
_caCertStr = NULL;
_clientCertStr = NULL;
_clientKeyStr = NULL;
_pskIdentStr = NULL;
_pskStr = NULL;
_insecure = true;
}
void MbedTLSClient::setPreSharedKey(const char *pskIdent, const char *psk) {
_pskIdentStr = pskIdent;
_pskStr = psk;
}
void MbedTLSClient::setCACert(const char *rootCA) {
_caCertStr = rootCA;
}
void MbedTLSClient::setCertificate(const char *clientCA) {
_clientCertStr = clientCA;
}
void MbedTLSClient::setPrivateKey(const char *privateKey) {
_clientKeyStr = privateKey;
}
char *streamToStr(Stream &stream, size_t size) {
char *buf = (char *)malloc(size + 1);
if (!buf)
return NULL;
if (size != stream.readBytes(buf, size)) {
free(buf);
return NULL;
}
buf[size] = '\0';
return buf;
}
bool MbedTLSClient::loadCACert(Stream &stream, size_t size) {
char *str = streamToStr(stream, size);
if (str) {
_caCertStr = str;
return true;
}
return false;
}
bool MbedTLSClient::loadCertificate(Stream &stream, size_t size) {
char *str = streamToStr(stream, size);
if (str) {
_clientCertStr = str;
return true;
}
return false;
}
bool MbedTLSClient::loadPrivateKey(Stream &stream, size_t size) {
char *str = streamToStr(stream, size);
if (str) {
_clientKeyStr = str;
return true;
}
return false;
}
bool MbedTLSClient::verify(const char *fingerprint, const char *domainName) {
uint8_t fpLocal[32] = {};
uint16_t len = strlen(fingerprint);
uint8_t byte = 0;
for (uint8_t i = 0; i < len; i++) {
uint8_t c = fingerprint[i];
while ((c == ' ' || c == ':') && i < len) {
c = fingerprint[++i];
}
c |= 0b00100000; // make lowercase
c -= '0' * (c >= '0' && c <= '9');
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
if (c > 0xf)
return -1;
fpLocal[byte / 2] |= c << (4 * ((byte & 1) ^ 1));
byte++;
if (byte >= 64)
break;
}
uint8_t fpRemote[32];
if (!getFingerprintSHA256(fpRemote))
return false;
if (memcmp(fpLocal, fpRemote, 32)) {
LT_D_SSL("Fingerprints don't match");
return false;
}
if (!domainName)
return true;
// TODO domain name verification
return true;
}
void MbedTLSClient::setHandshakeTimeout(unsigned long handshakeTimeout) {
_handshakeTimeout = handshakeTimeout * 1000;
}
void MbedTLSClient::setAlpnProtocols(const char **alpnProtocols) {
_alpnProtocols = alpnProtocols;
}
bool MbedTLSClient::getFingerprintSHA256(uint8_t result[32]) {
const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&_sslCtx);
if (!cert) {
LT_E("Failed to get peer certificate");
return false;
}
mbedtls_sha256_context shaCtx;
mbedtls_sha256_init(&shaCtx);
mbedtls_sha256_starts(&shaCtx, false);
mbedtls_sha256_update(&shaCtx, cert->raw.p, cert->raw.len);
mbedtls_sha256_finish(&shaCtx, result);
return true;
}

View File

@@ -0,0 +1,90 @@
/* Copyright (c) Kuba Szczodrzyński 2022-04-30. */
#pragma once
#include <api/WiFiClient.h>
#include <api/WiFiClientSecure.h>
#include <WiFiClient.h> // extend platform's WiFiClient impl
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
#include <mbedtls/net.h>
#ifdef __cplusplus
} // extern "C"
#endif
class MbedTLSClient : public WiFiClient, public IWiFiClientSecure {
private:
mbedtls_ssl_context _sslCtx;
mbedtls_ssl_config _sslCfg;
mbedtls_x509_crt _caCert;
mbedtls_x509_crt _clientCert;
mbedtls_pk_context _clientKey;
uint32_t _handshakeTimeout = 0;
void init();
int _sockTls = -1;
bool _insecure = false;
bool _useRootCA = false;
int _peeked = -1;
const char *_caCertStr;
const char *_clientCertStr;
const char *_clientKeyStr;
const char *_pskIdentStr;
const char *_pskStr;
const char **_alpnProtocols;
int connect(
const char *host,
uint16_t port,
const char *rootCABuf,
const char *clientCert,
const char *clientKey,
const char *pskIdent,
const char *psk,
const char **alpnProtocols
);
public:
MbedTLSClient();
MbedTLSClient(int sock);
int connect(IPAddress ip, uint16_t port, int32_t timeout);
int connect(const char *host, uint16_t port, int32_t timeout);
int connect(IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey);
int connect(const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey);
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk);
int connect(const char *host, uint16_t port, const char *pskIdent, const char *psk);
size_t write(const uint8_t *buf, size_t size);
int available();
int read(uint8_t *buf, size_t size);
int peek();
void flush();
void stop();
int lastError(char *buf, const size_t size);
void setInsecure(); // Don't validate the chain, just accept whatever is given. VERY INSECURE!
void setPreSharedKey(const char *pskIdent, const char *psk); // psk in hex
void setCACert(const char *rootCA);
void setCertificate(const char *clientCA);
void setPrivateKey(const char *privateKey);
bool loadCACert(Stream &stream, size_t size);
bool loadCertificate(Stream &stream, size_t size);
bool loadPrivateKey(Stream &stream, size_t size);
bool verify(const char *fingerprint, const char *domainName);
void setHandshakeTimeout(unsigned long handshakeTimeout);
void setAlpnProtocols(const char **alpnProtocols);
bool getFingerprintSHA256(uint8_t result[32]);
using WiFiClient::connect;
using WiFiClient::read;
};