Compare commits

..

6 Commits

Author SHA1 Message Date
J. Nick Koston
4267e5dda2 Merge branch 'dev' into esp8266-logger-vsnprintf-p 2026-02-02 22:24:57 +01:00
J. Nick Koston
bd3b7aa50a naming 2026-02-02 21:52:34 +01:00
J. Nick Koston
bce4a9c9ab force in 2026-02-02 17:45:25 +01:00
J. Nick Koston
9ba295d334 preen 2026-02-02 17:40:26 +01:00
J. Nick Koston
1501db38b1 tweak 2026-02-02 13:28:00 +01:00
J. Nick Koston
bc6d88fabe [logger] Use vsnprintf_P directly for ESP8266 flash format strings
Instead of copying the format string from flash to RAM before
formatting, use vsnprintf_P to read the format string directly
from flash memory.

This eliminates:
- The byte-by-byte copy loop from PROGMEM
- The complex dual-purpose buffer management
- Potential buffer overflow if format string is very long

The new format_body_to_buffer_P_() function is a simple variant
that uses vsnprintf_P instead of vsnprintf.
2026-02-02 07:58:49 +01:00
4 changed files with 65 additions and 181 deletions

View File

@@ -128,22 +128,7 @@ void HOT Logger::log_vprintf_(uint8_t level, const char *tag, int line, const ch
// Note: USE_STORE_LOG_STR_IN_FLASH is only defined for ESP8266.
//
// This function handles format strings stored in flash memory (PROGMEM) to save RAM.
// The buffer is used in a special way to avoid allocating extra memory:
//
// Memory layout during execution:
// Step 1: Copy format string from flash to buffer
// tx_buffer_: [format_string][null][.....................]
// tx_buffer_at_: ------------------^
// msg_start: saved here -----------^
//
// Step 2: format_log_to_buffer_with_terminator_ reads format string from beginning
// and writes formatted output starting at msg_start position
// tx_buffer_: [format_string][null][formatted_message][null]
// tx_buffer_at_: -------------------------------------^
//
// Step 3: Output the formatted message (starting at msg_start)
// write_msg_ and callbacks receive: this->tx_buffer_ + msg_start
// which points to: [formatted_message][null]
// Uses vsnprintf_P to read the format string directly from flash without copying to RAM.
//
void Logger::log_vprintf_(uint8_t level, const char *tag, int line, const __FlashStringHelper *format,
va_list args) { // NOLINT
@@ -153,35 +138,25 @@ void Logger::log_vprintf_(uint8_t level, const char *tag, int line, const __Flas
RecursionGuard guard(global_recursion_guard_);
this->tx_buffer_at_ = 0;
// Copy format string from progmem
auto *format_pgm_p = reinterpret_cast<const uint8_t *>(format);
char ch = '.';
while (this->tx_buffer_at_ < this->tx_buffer_size_ && ch != '\0') {
this->tx_buffer_[this->tx_buffer_at_++] = ch = (char) progmem_read_byte(format_pgm_p++);
}
// Write header, format body directly from flash, and write footer
this->write_header_to_buffer_(level, tag, line, nullptr, this->tx_buffer_, &this->tx_buffer_at_,
this->tx_buffer_size_);
this->format_body_to_buffer_P_(this->tx_buffer_, &this->tx_buffer_at_, this->tx_buffer_size_,
reinterpret_cast<PGM_P>(format), args);
this->write_footer_to_buffer_(this->tx_buffer_, &this->tx_buffer_at_, this->tx_buffer_size_);
// Buffer full from copying format - RAII guard handles cleanup on return
if (this->tx_buffer_at_ >= this->tx_buffer_size_) {
return;
}
// Save the offset before calling format_log_to_buffer_with_terminator_
// since it will increment tx_buffer_at_ to the end of the formatted string
uint16_t msg_start = this->tx_buffer_at_;
this->format_log_to_buffer_with_terminator_(level, tag, line, this->tx_buffer_, args, this->tx_buffer_,
&this->tx_buffer_at_, this->tx_buffer_size_);
uint16_t msg_length =
this->tx_buffer_at_ - msg_start; // Don't subtract 1 - tx_buffer_at_ is already at the null terminator position
// Ensure null termination
uint16_t null_pos = this->tx_buffer_at_ >= this->tx_buffer_size_ ? this->tx_buffer_size_ - 1 : this->tx_buffer_at_;
this->tx_buffer_[null_pos] = '\0';
// Listeners get message first (before console write)
#ifdef USE_LOG_LISTENERS
for (auto *listener : this->log_listeners_)
listener->on_log(level, tag, this->tx_buffer_ + msg_start, msg_length);
listener->on_log(level, tag, this->tx_buffer_, this->tx_buffer_at_);
#endif
// Write to console starting at the msg_start
this->write_tx_buffer_to_console_(msg_start, &msg_length);
// Write to console
this->write_tx_buffer_to_console_();
}
#endif // USE_STORE_LOG_STR_IN_FLASH

View File

@@ -597,31 +597,40 @@ class Logger : public Component {
*buffer_at = pos;
}
// Helper to process vsnprintf return value and strip trailing newlines.
// Updates buffer_at with the formatted length, handling truncation:
// - When vsnprintf truncates (ret >= remaining), it writes (remaining - 1) chars + null terminator
// - When it doesn't truncate (ret < remaining), it writes ret chars + null terminator
__attribute__((always_inline)) static inline void process_vsnprintf_result(const char *buffer, uint16_t *buffer_at,
uint16_t remaining, int ret) {
if (ret < 0)
return; // Encoding error, do not increment buffer_at
*buffer_at += (ret >= remaining) ? (remaining - 1) : static_cast<uint16_t>(ret);
// Remove all trailing newlines right after formatting
while (*buffer_at > 0 && buffer[*buffer_at - 1] == '\n')
(*buffer_at)--;
}
inline void HOT format_body_to_buffer_(char *buffer, uint16_t *buffer_at, uint16_t buffer_size, const char *format,
va_list args) {
// Get remaining capacity in the buffer
// Check remaining capacity in the buffer
if (*buffer_at >= buffer_size)
return;
const uint16_t remaining = buffer_size - *buffer_at;
const int ret = vsnprintf(buffer + *buffer_at, remaining, format, args);
if (ret < 0) {
return; // Encoding error, do not increment buffer_at
}
// Update buffer_at with the formatted length (handle truncation)
// When vsnprintf truncates (ret >= remaining), it writes (remaining - 1) chars + null terminator
// When it doesn't truncate (ret < remaining), it writes ret chars + null terminator
uint16_t formatted_len = (ret >= remaining) ? (remaining - 1) : ret;
*buffer_at += formatted_len;
// Remove all trailing newlines right after formatting
while (*buffer_at > 0 && buffer[*buffer_at - 1] == '\n') {
(*buffer_at)--;
}
process_vsnprintf_result(buffer, buffer_at, remaining, vsnprintf(buffer + *buffer_at, remaining, format, args));
}
#ifdef USE_STORE_LOG_STR_IN_FLASH
// ESP8266 variant that reads format string directly from flash using vsnprintf_P
inline void HOT format_body_to_buffer_P_(char *buffer, uint16_t *buffer_at, uint16_t buffer_size, PGM_P format,
va_list args) {
if (*buffer_at >= buffer_size)
return;
const uint16_t remaining = buffer_size - *buffer_at;
process_vsnprintf_result(buffer, buffer_at, remaining, vsnprintf_P(buffer + *buffer_at, remaining, format, args));
}
#endif
inline void HOT write_footer_to_buffer_(char *buffer, uint16_t *buffer_at, uint16_t buffer_size) {
static constexpr uint16_t RESET_COLOR_LEN = sizeof(ESPHOME_LOG_RESET_COLOR) - 1;
this->write_body_to_buffer_(ESPHOME_LOG_RESET_COLOR, RESET_COLOR_LEN, buffer, buffer_at, buffer_size);

View File

@@ -1,7 +1,5 @@
import base64
from pathlib import Path
import random
import secrets
import string
from typing import Literal, NotRequired, TypedDict, Unpack
import unicodedata
@@ -118,6 +116,7 @@ class WizardFileKwargs(TypedDict):
board: str
ssid: NotRequired[str]
psk: NotRequired[str]
password: NotRequired[str]
ota_password: NotRequired[str]
api_encryption_key: NotRequired[str]
friendly_name: NotRequired[str]
@@ -145,7 +144,9 @@ def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str:
config += API_CONFIG
# Configure API encryption
# Configure API
if "password" in kwargs:
config += f' password: "{kwargs["password"]}"\n'
if "api_encryption_key" in kwargs:
config += f' encryption:\n key: "{kwargs["api_encryption_key"]}"\n'
@@ -154,6 +155,8 @@ def wizard_file(**kwargs: Unpack[WizardFileKwargs]) -> str:
config += " - platform: esphome\n"
if "ota_password" in kwargs:
config += f' password: "{kwargs["ota_password"]}"'
elif "password" in kwargs:
config += f' password: "{kwargs["password"]}"'
# Configuring wifi
config += "\n\nwifi:\n"
@@ -202,6 +205,7 @@ class WizardWriteKwargs(TypedDict):
platform: NotRequired[str]
ssid: NotRequired[str]
psk: NotRequired[str]
password: NotRequired[str]
ota_password: NotRequired[str]
api_encryption_key: NotRequired[str]
friendly_name: NotRequired[str]
@@ -228,7 +232,7 @@ def wizard_write(path: Path, **kwargs: Unpack[WizardWriteKwargs]) -> bool:
else: # "basic"
board = kwargs["board"]
for key in ("ssid", "psk", "ota_password"):
for key in ("ssid", "psk", "password", "ota_password"):
if key in kwargs:
kwargs[key] = sanitize_double_quotes(kwargs[key])
if "platform" not in kwargs:
@@ -518,54 +522,26 @@ def wizard(path: Path) -> int:
"Almost there! ESPHome can automatically upload custom firmwares over WiFi "
"(over the air) and integrates into Home Assistant with a native API."
)
safe_print()
sleep(0.5)
# Generate encryption key (32 bytes, base64 encoded) for secure API communication
noise_psk = secrets.token_bytes(32)
api_encryption_key = base64.b64encode(noise_psk).decode()
safe_print(
"For secure API communication, I've generated a random encryption key."
)
safe_print()
safe_print(
f"Your {color(AnsiFore.GREEN, 'API encryption key')} is: "
f"{color(AnsiFore.BOLD_WHITE, api_encryption_key)}"
)
safe_print()
safe_print("You'll need this key when adding the device to Home Assistant.")
sleep(1)
safe_print()
safe_print(
f"Do you want to set a {color(AnsiFore.GREEN, 'password')} for OTA updates? "
"This can be insecure if you do not trust the WiFi network."
f"This can be insecure if you do not trust the WiFi network. Do you want to set a {color(AnsiFore.GREEN, 'password')} for connecting to this ESP?"
)
safe_print()
sleep(0.25)
safe_print("Press ENTER for no password")
ota_password = safe_input(color(AnsiFore.BOLD_WHITE, "(password): "))
password = safe_input(color(AnsiFore.BOLD_WHITE, "(password): "))
else:
ssid, psk = "", ""
api_encryption_key = None
ota_password = ""
ssid, password, psk = "", "", ""
kwargs = {
"path": path,
"name": name,
"platform": platform,
"board": board,
"ssid": ssid,
"psk": psk,
"type": "basic",
}
if api_encryption_key:
kwargs["api_encryption_key"] = api_encryption_key
if ota_password:
kwargs["ota_password"] = ota_password
if not wizard_write(**kwargs):
if not wizard_write(
path=path,
name=name,
platform=platform,
board=board,
ssid=ssid,
psk=psk,
password=password,
type="basic",
):
return 1
safe_print()

View File

@@ -25,6 +25,7 @@ def default_config() -> dict[str, Any]:
"board": "esp01_1m",
"ssid": "test_ssid",
"psk": "test_psk",
"password": "",
}
@@ -36,7 +37,7 @@ def wizard_answers() -> list[str]:
"nodemcuv2", # board
"SSID", # ssid
"psk", # wifi password
"", # ota password (empty for no password)
"ota_pass", # ota password
]
@@ -104,35 +105,16 @@ def test_config_file_should_include_ota_when_password_set(
default_config: dict[str, Any],
):
"""
The Over-The-Air update should be enabled when an OTA password is set
The Over-The-Air update should be enabled when a password is set
"""
# Given
default_config["ota_password"] = "foo"
default_config["password"] = "foo"
# When
config = wz.wizard_file(**default_config)
# Then
assert "ota:" in config
assert 'password: "foo"' in config
def test_config_file_should_include_api_encryption_key(
default_config: dict[str, Any],
):
"""
The API encryption key should be included when set
"""
# Given
default_config["api_encryption_key"] = "test_encryption_key_base64=="
# When
config = wz.wizard_file(**default_config)
# Then
assert "api:" in config
assert "encryption:" in config
assert 'key: "test_encryption_key_base64=="' in config
def test_wizard_write_sets_platform(
@@ -574,61 +556,3 @@ def test_wizard_write_protects_existing_config(
# Then
assert result is False # Should return False when file exists
assert config_file.read_text() == original_content
def test_wizard_accepts_ota_password(
tmp_path: Path, monkeypatch: MonkeyPatch, wizard_answers: list[str]
):
"""
The wizard should pass ota_password to wizard_write when the user provides one
"""
# Given
wizard_answers[5] = "my_ota_password" # Set OTA password
config_file = tmp_path / "test.yaml"
input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0)
wizard_write_mock = MagicMock(return_value=True)
monkeypatch.setattr(wz, "wizard_write", wizard_write_mock)
# When
retval = wz.wizard(config_file)
# Then
assert retval == 0
call_kwargs = wizard_write_mock.call_args.kwargs
assert "ota_password" in call_kwargs
assert call_kwargs["ota_password"] == "my_ota_password"
def test_wizard_accepts_rpipico_board(tmp_path: Path, monkeypatch: MonkeyPatch):
"""
The wizard should handle rpipico board which doesn't support WiFi.
This tests the branch where api_encryption_key is None.
"""
# Given
wizard_answers_rp2040 = [
"test-node", # Name of the node
"RP2040", # platform
"rpipico", # board (no WiFi support)
]
config_file = tmp_path / "test.yaml"
input_mock = MagicMock(side_effect=wizard_answers_rp2040)
monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0)
wizard_write_mock = MagicMock(return_value=True)
monkeypatch.setattr(wz, "wizard_write", wizard_write_mock)
# When
retval = wz.wizard(config_file)
# Then
assert retval == 0
call_kwargs = wizard_write_mock.call_args.kwargs
# rpipico doesn't support WiFi, so no api_encryption_key or ota_password
assert "api_encryption_key" not in call_kwargs
assert "ota_password" not in call_kwargs