[logger] Add thread-safe logging for host platform (#13010)

This commit is contained in:
J. Nick Koston
2026-01-07 08:29:50 -10:00
committed by GitHub
parent d86d1f9f52
commit 25ac89e9b5
10 changed files with 698 additions and 45 deletions

View File

@@ -0,0 +1,91 @@
esphome:
name: host-logger-thread-test
host:
api:
logger:
button:
- platform: template
name: "Start Thread Race Test"
id: start_test_button
on_press:
- lambda: |-
// Number of threads and messages per thread
static const int NUM_THREADS = 3;
static const int MESSAGES_PER_THREAD = 100;
// Counters
static std::atomic<int> total_messages_logged{0};
// Thread function - must be a regular function pointer for pthread
struct ThreadTest {
static void *thread_func(void *arg) {
int thread_id = *static_cast<int *>(arg);
// Set thread name (different signatures on macOS vs Linux)
char thread_name[16];
snprintf(thread_name, sizeof(thread_name), "LogThread%d", thread_id);
#ifdef __APPLE__
pthread_setname_np(thread_name);
#else
pthread_setname_np(pthread_self(), thread_name);
#endif
// Log messages with different log levels
for (int i = 0; i < MESSAGES_PER_THREAD; i++) {
switch (i % 4) {
case 0:
ESP_LOGI("thread_test", "THREAD%d_MSG%03d_INFO_MESSAGE_WITH_DATA_%08X",
thread_id, i, i * 12345);
break;
case 1:
ESP_LOGD("thread_test", "THREAD%d_MSG%03d_DEBUG_MESSAGE_WITH_DATA_%08X",
thread_id, i, i * 12345);
break;
case 2:
ESP_LOGW("thread_test", "THREAD%d_MSG%03d_WARN_MESSAGE_WITH_DATA_%08X",
thread_id, i, i * 12345);
break;
case 3:
ESP_LOGE("thread_test", "THREAD%d_MSG%03d_ERROR_MESSAGE_WITH_DATA_%08X",
thread_id, i, i * 12345);
break;
}
total_messages_logged.fetch_add(1, std::memory_order_relaxed);
// Small busy loop to vary timing between threads
int delay_count = (thread_id + 1) * 10;
while (delay_count-- > 0) {
asm volatile("" ::: "memory"); // Prevent optimization
}
}
return nullptr;
}
};
ESP_LOGI("thread_test", "RACE_TEST_START: Starting %d threads with %d messages each",
NUM_THREADS, MESSAGES_PER_THREAD);
// Reset counter for this test run
total_messages_logged.store(0, std::memory_order_relaxed);
pthread_t threads[NUM_THREADS];
int thread_ids[NUM_THREADS];
// Create all threads
for (int i = 0; i < NUM_THREADS; i++) {
thread_ids[i] = i;
int ret = pthread_create(&threads[i], nullptr, ThreadTest::thread_func, &thread_ids[i]);
if (ret != 0) {
ESP_LOGE("thread_test", "RACE_TEST_ERROR: Failed to create thread %d", i);
return;
}
}
// Wait for all threads to complete
for (int i = 0; i < NUM_THREADS; i++) {
pthread_join(threads[i], nullptr);
}
ESP_LOGI("thread_test", "RACE_TEST_COMPLETE: All threads finished, total messages: %d",
total_messages_logged.load(std::memory_order_relaxed));

View File

@@ -0,0 +1,182 @@
"""Integration test for host logger thread safety.
This test verifies that the logger's MPSC ring buffer correctly handles
multiple threads racing to log messages without corruption or data loss.
"""
from __future__ import annotations
import asyncio
import re
import pytest
from .types import APIClientConnectedFactory, RunCompiledFunction
# Expected pattern for log messages from threads
# Format: THREADn_MSGnnn_LEVEL_MESSAGE_WITH_DATA_xxxxxxxx
THREAD_MSG_PATTERN = re.compile(
r"THREAD(\d+)_MSG(\d{3})_(INFO|DEBUG|WARN|ERROR)_MESSAGE_WITH_DATA_([0-9A-F]{8})"
)
# Pattern for test start/complete markers
TEST_START_PATTERN = re.compile(r"RACE_TEST_START.*Starting (\d+) threads")
TEST_COMPLETE_PATTERN = re.compile(r"RACE_TEST_COMPLETE.*total messages: (\d+)")
# Expected values
NUM_THREADS = 3
MESSAGES_PER_THREAD = 100
EXPECTED_TOTAL_MESSAGES = NUM_THREADS * MESSAGES_PER_THREAD
@pytest.mark.asyncio
async def test_host_logger_thread_safety(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test that multiple threads can log concurrently without corruption.
This test:
1. Spawns 3 threads that each log 100 messages
2. Collects all log output
3. Verifies no lines are corrupted (partially written or interleaved)
4. Verifies all expected messages were received
"""
collected_lines: list[str] = []
test_complete_event = asyncio.Event()
def line_callback(line: str) -> None:
"""Collect log lines and detect test completion."""
collected_lines.append(line)
if "RACE_TEST_COMPLETE" in line:
test_complete_event.set()
# Run the test binary and collect output
async with (
run_compiled(yaml_config, line_callback=line_callback),
api_client_connected() as client,
):
# Verify connection works
device_info = await client.device_info()
assert device_info is not None
assert device_info.name == "host-logger-thread-test"
# Get the button entity - find by name
entities, _ = await client.list_entities_services()
button_entities = [e for e in entities if e.name == "Start Thread Race Test"]
assert button_entities, "Could not find Start Thread Race Test button"
button_key = button_entities[0].key
# Press the button to start the thread race test
client.button_command(button_key)
# Wait for test to complete (with timeout)
try:
await asyncio.wait_for(test_complete_event.wait(), timeout=30.0)
except TimeoutError:
pytest.fail(
"Test did not complete within timeout. "
f"Collected {len(collected_lines)} lines."
)
# Give a bit more time for any remaining buffered messages
await asyncio.sleep(0.5)
# Analyze collected log lines
thread_messages: dict[int, set[int]] = {i: set() for i in range(NUM_THREADS)}
corrupted_lines: list[str] = []
test_started = False
test_completed = False
reported_total = 0
for line in collected_lines:
# Check for test start
start_match = TEST_START_PATTERN.search(line)
if start_match:
test_started = True
assert int(start_match.group(1)) == NUM_THREADS, (
f"Unexpected thread count: {start_match.group(1)}"
)
continue
# Check for test completion
complete_match = TEST_COMPLETE_PATTERN.search(line)
if complete_match:
test_completed = True
reported_total = int(complete_match.group(1))
continue
# Check for thread messages
msg_match = THREAD_MSG_PATTERN.search(line)
if msg_match:
thread_id = int(msg_match.group(1))
msg_num = int(msg_match.group(2))
# level = msg_match.group(3) # INFO, DEBUG, WARN, ERROR
data_hex = msg_match.group(4)
# Verify data value matches expected calculation
expected_data = f"{msg_num * 12345:08X}"
if data_hex != expected_data:
corrupted_lines.append(
f"Data mismatch in line: {line} "
f"(expected {expected_data}, got {data_hex})"
)
continue
# Track which messages we received from each thread
if 0 <= thread_id < NUM_THREADS:
thread_messages[thread_id].add(msg_num)
else:
corrupted_lines.append(f"Invalid thread ID in line: {line}")
continue
# Check for partial/corrupted thread messages
# If a line contains part of a thread message pattern but doesn't match fully
# This could indicate line corruption from interleaving
if (
"THREAD" in line
and "MSG" in line
and not msg_match
and "_MESSAGE_WITH_DATA_" in line
):
corrupted_lines.append(f"Possibly corrupted line: {line}")
# Assertions
assert test_started, "Test start marker not found in output"
assert test_completed, "Test completion marker not found in output"
assert reported_total == EXPECTED_TOTAL_MESSAGES, (
f"Reported total {reported_total} != expected {EXPECTED_TOTAL_MESSAGES}"
)
# Check for corrupted lines
assert not corrupted_lines, (
f"Found {len(corrupted_lines)} corrupted lines:\n"
+ "\n".join(corrupted_lines[:10]) # Show first 10
)
# Count total messages received
total_received = sum(len(msgs) for msgs in thread_messages.values())
# We may not receive all messages due to ring buffer overflow when buffer is full
# The test primarily verifies no corruption, not that we receive every message
# However, we should receive a reasonable number of messages
min_expected = EXPECTED_TOTAL_MESSAGES // 2 # At least 50%
assert total_received >= min_expected, (
f"Received only {total_received} messages, expected at least {min_expected}. "
f"Per-thread breakdown: "
+ ", ".join(f"Thread{i}: {len(msgs)}" for i, msgs in thread_messages.items())
)
# Verify we got messages from all threads (proves concurrent logging worked)
for thread_id in range(NUM_THREADS):
assert thread_messages[thread_id], (
f"No messages received from thread {thread_id}"
)
# Log summary for debugging
print("\nThread safety test summary:")
print(f" Total messages received: {total_received}/{EXPECTED_TOTAL_MESSAGES}")
for thread_id in range(NUM_THREADS):
received = len(thread_messages[thread_id])
print(f" Thread {thread_id}: {received}/{MESSAGES_PER_THREAD} messages")