[core] Fix and enable SSL client support

This commit is contained in:
Kuba Szczodrzyński
2022-05-08 18:43:10 +02:00
parent ccf63a4cdb
commit 2e80469ab3
7 changed files with 35 additions and 23 deletions

View File

@@ -185,7 +185,7 @@ Wi-Fi Client (SSL) | ✔️ (✔️)
Wi-Fi Server | ✔️
Wi-Fi Events | ❌
IPv6 | ❌
HTTP Client (SSL) | ✔️ ()
HTTP Client (SSL) | ✔️ (✔️)
HTTP Server | ✔️
NVS / Preferences | ❌
SPIFFS | ❌

View File

@@ -29,7 +29,7 @@
#ifdef HTTPCLIENT_1_1_COMPATIBLE
#include <WiFi.h>
// #include <WiFiClientSecure.h>
#include <WiFiClientSecure.h>
#endif
// #include <StreamString.h>
@@ -63,7 +63,7 @@ class TLSTraits : public TransportTraits {
TLSTraits(const char *CAcert, const char *clicert = nullptr, const char *clikey = nullptr)
: _cacert(CAcert), _clicert(clicert), _clikey(clikey) {}
/* std::unique_ptr<WiFiClient> create() override {
std::unique_ptr<WiFiClient> create() override {
return std::unique_ptr<WiFiClient>(new WiFiClientSecure());
}
@@ -77,7 +77,7 @@ class TLSTraits : public TransportTraits {
wcs.setPrivateKey(_clikey);
}
return true;
} */
}
protected:
const char *_cacert;

View File

@@ -33,7 +33,7 @@
#include <Arduino.h>
#include <WiFiClient.h>
// #include <WiFiClientSecure.h>
#include <WiFiClientSecure.h>
#include <memory>
/// Cookie jar support

View File

@@ -19,9 +19,13 @@ extern "C" {
} // extern "C"
#endif
MbedTLSClient::MbedTLSClient() : WiFiClient() {}
MbedTLSClient::MbedTLSClient() : WiFiClient() {
init(); // ensure the context is zero filled
}
MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {}
MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {
init(); // ensure the context is zero filled
}
void MbedTLSClient::stop() {
WiFiClient::stop();
@@ -46,33 +50,33 @@ void MbedTLSClient::init() {
}
int MbedTLSClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
return connect(ipToString(ip).c_str(), port, timeout) == 0;
return connect(ipToString(ip).c_str(), port, timeout);
}
int MbedTLSClient::connect(const char *host, uint16_t port, int32_t timeout) {
if (_pskIdentStr && _pskStr)
return connect(host, port, NULL, NULL, NULL, _pskIdentStr, _pskStr, _alpnProtocols) == 0;
return connect(host, port, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL, _alpnProtocols) == 0;
return connect(host, port, timeout, NULL, NULL, NULL, _pskIdentStr, _pskStr) == 0;
return connect(host, port, timeout, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL) == 0;
}
int MbedTLSClient::connect(
IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
) {
return connect(ipToString(ip).c_str(), port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
return connect(ipToString(ip).c_str(), port, 0, rootCABuf, clientCert, clientKey, NULL, NULL) == 0;
}
int MbedTLSClient::connect(
const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
) {
return connect(host, port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
return connect(host, port, 0, rootCABuf, clientCert, clientKey, NULL, NULL) == 0;
}
int MbedTLSClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) {
return connect(ipToString(ip).c_str(), port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
return connect(ipToString(ip).c_str(), port, 0, NULL, NULL, NULL, pskIdent, psk) == 0;
}
int MbedTLSClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) {
return connect(host, port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
return connect(host, port, 0, NULL, NULL, NULL, pskIdent, psk) == 0;
}
static int ssl_random(void *data, unsigned char *output, size_t len) {
@@ -96,23 +100,26 @@ void debug_cb(void *ctx, int level, const char *file, int line, const char *str)
int MbedTLSClient::connect(
const char *host,
uint16_t port,
int32_t timeout,
const char *rootCABuf,
const char *clientCert,
const char *clientKey,
const char *pskIdent,
const char *psk,
const char **alpnProtocols
const char *psk
) {
LT_D_SSL("Free heap before TLS: TODO");
if (!rootCABuf && !pskIdent && !psk && !_insecure && !_useRootCA)
return -1;
if (timeout <= 0)
timeout = _timeout; // use default when -1 passed as timeout
IPAddress addr = WiFi.hostByName(host);
if (!(uint32_t)addr)
return -1;
int ret = WiFiClient::connect(addr, port, _timeout);
int ret = WiFiClient::connect(addr, port, timeout);
if (ret < 0) {
LT_E("SSL socket failed");
return ret;
@@ -135,8 +142,8 @@ int MbedTLSClient::connect(
LT_RET_NZ(ret);
#ifdef MBEDTLS_SSL_ALPN
if (alpnProtocols) {
ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, alpnProtocols);
if (_alpnProtocols) {
ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, _alpnProtocols);
LT_RET_NZ(ret);
}
#endif
@@ -208,10 +215,11 @@ int MbedTLSClient::connect(
_sockTls = fd();
mbedtls_ssl_set_bio(&_sslCtx, &_sockTls, mbedtls_net_send, mbedtls_net_recv, NULL);
mbedtls_net_set_nonblock((mbedtls_net_context *)&_sockTls);
LT_V_SSL("SSL handshake");
if (_handshakeTimeout == 0)
_handshakeTimeout = _timeout * 1000;
_handshakeTimeout = timeout;
unsigned long start = millis();
while (ret = mbedtls_ssl_handshake(&_sslCtx)) {
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {

View File

@@ -42,12 +42,12 @@ class MbedTLSClient : public WiFiClient, public IWiFiClientSecure {
int connect(
const char *host,
uint16_t port,
int32_t timeout,
const char *rootCABuf,
const char *clientCert,
const char *clientKey,
const char *pskIdent,
const char *psk,
const char **alpnProtocols
const char *psk
);
public:

View File

@@ -18,6 +18,7 @@ extern "C" {
#endif
#include "WiFiClient.h"
#include "WiFiClientSecure.h"
#include "WiFiServer.h"
class WiFiClass : public IWiFiClass,

View File

@@ -68,6 +68,9 @@ int WiFiClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
return -1;
}
if (timeout <= 0)
timeout = _timeout; // use default when -1 passed as timeout
lwip_fcntl(sock, F_SETFL, lwip_fcntl(sock, F_GETFL, 0) | O_NONBLOCK);
struct sockaddr_in addr;
@@ -80,7 +83,7 @@ int WiFiClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
FD_ZERO(&fdset);
FD_SET(sock, &fdset);
tv.tv_sec = 0;
tv.tv_usec = timeout * 1000;
tv.tv_usec = timeout * 1000; // millis -> micros
int res = lwip_connect(sock, (struct sockaddr *)&addr, sizeof(addr));
if (res < 0 && errno != EINPROGRESS) {