[api] Split ProtoVarInt::parse into 32-bit and 64-bit phases (#14039)

This commit is contained in:
J. Nick Koston
2026-02-25 12:23:13 -06:00
committed by GitHub
parent ede8235aae
commit 8bb577de64
8 changed files with 301 additions and 32 deletions

View File

@@ -2,7 +2,6 @@
// See script/api_protobuf/api_protobuf.py
#pragma once
#include "esphome/core/defines.h"
#include "esphome/core/string_ref.h"
#include "proto.h"

View File

@@ -0,0 +1,12 @@
// This file was automatically generated with a tool.
// See script/api_protobuf/api_protobuf.py
#pragma once
#include "esphome/core/defines.h"
#ifdef USE_BLUETOOTH_PROXY
#ifndef USE_API_VARINT64
#define USE_API_VARINT64
#endif
#endif
namespace esphome::api {} // namespace esphome::api

View File

@@ -7,6 +7,23 @@ namespace esphome::api {
static const char *const TAG = "api.proto";
#ifdef USE_API_VARINT64
optional<ProtoVarInt> ProtoVarInt::parse_wide(const uint8_t *buffer, uint32_t len, uint32_t *consumed,
uint32_t result32) {
uint64_t result64 = result32;
uint32_t limit = std::min(len, uint32_t(10));
for (uint32_t i = 4; i < limit; i++) {
uint8_t val = buffer[i];
result64 |= uint64_t(val & 0x7F) << (i * 7);
if ((val & 0x80) == 0) {
*consumed = i + 1;
return ProtoVarInt(result64);
}
}
return {};
}
#endif
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
uint32_t count = 0;
const uint8_t *ptr = buffer;

View File

@@ -1,5 +1,6 @@
#pragma once
#include "api_pb2_defines.h"
#include "esphome/core/component.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
@@ -110,59 +111,78 @@ class ProtoVarInt {
#endif
if (len == 0)
return {};
// Most common case: single-byte varint (values 0-127)
// Fast path: single-byte varints (0-127) are the most common case
// (booleans, small enums, field tags). Avoid loop overhead entirely.
if ((buffer[0] & 0x80) == 0) {
*consumed = 1;
return ProtoVarInt(buffer[0]);
}
// General case for multi-byte varints
// Since we know buffer[0]'s high bit is set, initialize with its value
uint64_t result = buffer[0] & 0x7F;
uint8_t bitpos = 7;
// A 64-bit varint is at most 10 bytes (ceil(64/7)). Reject overlong encodings
// to avoid undefined behavior from shifting uint64_t by >= 64 bits.
uint32_t max_len = std::min(len, uint32_t(10));
// Start from the second byte since we've already processed the first
for (uint32_t i = 1; i < max_len; i++) {
// 32-bit phase: process remaining bytes with native 32-bit shifts.
// Without USE_API_VARINT64: cover bytes 1-4 (shifts 7, 14, 21, 28) — the uint32_t
// shift at byte 4 (shift by 28) may lose bits 32-34, but those are always zero for valid uint32 values.
// With USE_API_VARINT64: cover bytes 1-3 (shifts 7, 14, 21) so parse_wide handles
// byte 4+ with full 64-bit arithmetic (avoids truncating values > UINT32_MAX).
uint32_t result32 = buffer[0] & 0x7F;
#ifdef USE_API_VARINT64
uint32_t limit = std::min(len, uint32_t(4));
#else
uint32_t limit = std::min(len, uint32_t(5));
#endif
for (uint32_t i = 1; i < limit; i++) {
uint8_t val = buffer[i];
result |= uint64_t(val & 0x7F) << uint64_t(bitpos);
bitpos += 7;
result32 |= uint32_t(val & 0x7F) << (i * 7);
if ((val & 0x80) == 0) {
*consumed = i + 1;
return ProtoVarInt(result);
return ProtoVarInt(result32);
}
}
return {}; // Incomplete or invalid varint
// 64-bit phase for remaining bytes (BLE addresses etc.)
#ifdef USE_API_VARINT64
return parse_wide(buffer, len, consumed, result32);
#else
return {};
#endif
}
#ifdef USE_API_VARINT64
protected:
/// Continue parsing varint bytes 4-9 with 64-bit arithmetic.
/// Separated to keep 64-bit shift code (__ashldi3 on 32-bit platforms) out of the common path.
static optional<ProtoVarInt> parse_wide(const uint8_t *buffer, uint32_t len, uint32_t *consumed, uint32_t result32)
__attribute__((noinline));
public:
#endif
constexpr uint16_t as_uint16() const { return this->value_; }
constexpr uint32_t as_uint32() const { return this->value_; }
constexpr uint64_t as_uint64() const { return this->value_; }
constexpr bool as_bool() const { return this->value_; }
constexpr int32_t as_int32() const {
// Not ZigZag encoded
return static_cast<int32_t>(this->as_int64());
}
constexpr int64_t as_int64() const {
// Not ZigZag encoded
return static_cast<int64_t>(this->value_);
return static_cast<int32_t>(this->value_);
}
constexpr int32_t as_sint32() const {
// with ZigZag encoding
return decode_zigzag32(static_cast<uint32_t>(this->value_));
}
#ifdef USE_API_VARINT64
constexpr uint64_t as_uint64() const { return this->value_; }
constexpr int64_t as_int64() const {
// Not ZigZag encoded
return static_cast<int64_t>(this->value_);
}
constexpr int64_t as_sint64() const {
// with ZigZag encoding
return decode_zigzag64(this->value_);
}
#endif
protected:
#ifdef USE_API_VARINT64
uint64_t value_;
#else
uint32_t value_;
#endif
};
// Forward declarations for decode_to_message, encode_message and encode_packed_sint32

View File

@@ -144,6 +144,7 @@
#define USE_API_HOMEASSISTANT_SERVICES
#define USE_API_HOMEASSISTANT_STATES
#define USE_API_NOISE
#define USE_API_VARINT64
#define USE_API_PLAINTEXT
#define USE_API_USER_DEFINED_ACTIONS
#define USE_API_CUSTOM_SERVICES

View File

@@ -1913,6 +1913,37 @@ def build_type_usage_map(
)
def get_varint64_ifdef(
file_desc: descriptor.FileDescriptorProto,
message_ifdef_map: dict[str, str | None],
) -> tuple[bool, str | None]:
"""Check if 64-bit varint fields exist and get their common ifdef guard.
Returns:
(has_varint64, ifdef_guard) - has_varint64 is True if any fields exist,
ifdef_guard is the common guard or None if unconditional.
"""
varint64_types = {
FieldDescriptorProto.TYPE_INT64,
FieldDescriptorProto.TYPE_UINT64,
FieldDescriptorProto.TYPE_SINT64,
}
ifdefs: set[str | None] = {
message_ifdef_map.get(msg.name)
for msg in file_desc.message_type
if not msg.options.deprecated
for field in msg.field
if not field.options.deprecated and field.type in varint64_types
}
if not ifdefs:
return False, None
if None in ifdefs:
# At least one 64-bit varint field is unconditional, so the guard must be unconditional.
return True, None
ifdefs.discard(None)
return True, ifdefs.pop() if len(ifdefs) == 1 else None
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
"""Builds the enum type.
@@ -2567,11 +2598,38 @@ def main() -> None:
file = d.file[0]
# Build dynamic ifdef mappings early so we can emit USE_API_VARINT64 before includes
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
build_type_usage_map(file)
)
# Find the ifdef guard for 64-bit varint fields (int64/uint64/sint64).
# Generated into api_pb2_defines.h so proto.h can include it, ensuring
# consistent ProtoVarInt layout across all translation units.
has_varint64, varint64_guard = get_varint64_ifdef(file, message_ifdef_map)
# Generate api_pb2_defines.h — included by proto.h to ensure all translation
# units see USE_API_VARINT64 consistently (avoids ODR violations in ProtoVarInt).
defines_content = FILE_HEADER
defines_content += "#pragma once\n\n"
defines_content += '#include "esphome/core/defines.h"\n'
if has_varint64:
lines = [
"#ifndef USE_API_VARINT64",
"#define USE_API_VARINT64",
"#endif",
]
defines_content += "\n".join(wrap_with_ifdef(lines, varint64_guard))
defines_content += "\n"
defines_content += "\nnamespace esphome::api {} // namespace esphome::api\n"
with open(root / "api_pb2_defines.h", "w", encoding="utf-8") as f:
f.write(defines_content)
content = FILE_HEADER
content += """\
#pragma once
#include "esphome/core/defines.h"
#include "esphome/core/string_ref.h"
#include "proto.h"
@@ -2702,11 +2760,6 @@ static void dump_bytes_field(DumpBuffer &out, const char *field_name, const uint
content += "namespace enums {\n\n"
# Build dynamic ifdef mappings for both enums and messages
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
build_type_usage_map(file)
)
# Simple grouping of enums by ifdef
current_ifdef = None

View File

@@ -0,0 +1,47 @@
esphome:
name: varint-5byte-test
# Define areas and devices - device_ids will be FNV hashes > 2^28,
# requiring 5-byte varint encoding that exercises the 32-bit parse boundary.
areas:
- id: test_area
name: Test Area
devices:
- id: sub_device_one
name: Sub Device One
area_id: test_area
- id: sub_device_two
name: Sub Device Two
area_id: test_area
host:
api:
logger:
# Switches on sub-devices so we can send commands with large device_id varints
switch:
- platform: template
name: Device Switch
device_id: sub_device_one
id: device_switch_one
optimistic: true
turn_on_action:
- logger.log: "Switch one on"
turn_off_action:
- logger.log: "Switch one off"
- platform: template
name: Device Switch
device_id: sub_device_two
id: device_switch_two
optimistic: true
turn_on_action:
- logger.log: "Switch two on"
turn_off_action:
- logger.log: "Switch two off"
sensor:
- platform: template
name: Device Sensor
device_id: sub_device_one
lambda: return 42.0;
update_interval: 0.1s

View File

@@ -0,0 +1,120 @@
"""Integration test for 5-byte varint parsing of device_id fields.
Device IDs are FNV hashes (uint32) that frequently exceed 2^28 (268435456),
requiring 5 varint bytes. This test verifies that:
1. The firmware correctly decodes 5-byte varint device_id in incoming commands
2. The firmware correctly encodes large device_id values in state responses
3. Switch commands with large device_id reach the correct entity
"""
from __future__ import annotations
import asyncio
from aioesphomeapi import EntityState, SwitchInfo, SwitchState
import pytest
from .types import APIClientConnectedFactory, RunCompiledFunction
@pytest.mark.asyncio
async def test_varint_five_byte_device_id(
yaml_config: str,
run_compiled: RunCompiledFunction,
api_client_connected: APIClientConnectedFactory,
) -> None:
"""Test that device_id values requiring 5-byte varints parse correctly."""
async with run_compiled(yaml_config), api_client_connected() as client:
device_info = await client.device_info()
devices = device_info.devices
assert len(devices) >= 2, f"Expected at least 2 devices, got {len(devices)}"
# Verify at least one device_id exceeds the 4-byte varint boundary (2^28)
large_ids = [d for d in devices if d.device_id >= (1 << 28)]
assert len(large_ids) > 0, (
"Expected at least one device_id >= 2^28 to exercise 5-byte varint path. "
f"Got device_ids: {[d.device_id for d in devices]}"
)
# Get entities
all_entities, _ = await client.list_entities_services()
switch_entities = [e for e in all_entities if isinstance(e, SwitchInfo)]
# Find switches named "Device Switch" — one per sub-device
device_switches = [e for e in switch_entities if e.name == "Device Switch"]
assert len(device_switches) == 2, (
f"Expected 2 'Device Switch' entities, got {len(device_switches)}"
)
# Verify switches have different device_ids matching the sub-devices
switch_device_ids = {s.device_id for s in device_switches}
assert len(switch_device_ids) == 2, "Switches should have different device_ids"
# Subscribe to states and wait for initial states
loop = asyncio.get_running_loop()
states: dict[tuple[int, int], EntityState] = {}
switch_futures: dict[tuple[int, int], asyncio.Future[EntityState]] = {}
initial_done: asyncio.Future[bool] = loop.create_future()
def on_state(state: EntityState) -> None:
key = (state.device_id, state.key)
states[key] = state
if len(states) >= 3 and not initial_done.done():
initial_done.set_result(True)
if initial_done.done() and key in switch_futures:
fut = switch_futures[key]
if not fut.done() and isinstance(state, SwitchState):
fut.set_result(state)
client.subscribe_states(on_state)
try:
await asyncio.wait_for(initial_done, timeout=10.0)
except TimeoutError:
pytest.fail(
f"Timed out waiting for initial states. Got {len(states)} states"
)
# Verify state responses contain correct large device_id values
for device in devices:
device_states = [
s for (did, _), s in states.items() if did == device.device_id
]
assert len(device_states) > 0, (
f"No states received for device '{device.name}' "
f"(device_id={device.device_id})"
)
# Test switch commands with large device_id varints —
# this is the critical path: the client encodes device_id as a varint
# in the SwitchCommandRequest, and the firmware must decode it correctly.
for switch in device_switches:
state_key = (switch.device_id, switch.key)
# Turn on
switch_futures[state_key] = loop.create_future()
client.switch_command(switch.key, True, device_id=switch.device_id)
try:
await asyncio.wait_for(switch_futures[state_key], timeout=2.0)
except TimeoutError:
pytest.fail(
f"Timed out waiting for switch ON state "
f"(device_id={switch.device_id}, key={switch.key}). "
f"This likely means the firmware failed to decode the "
f"5-byte varint device_id in SwitchCommandRequest."
)
assert states[state_key].state is True
# Turn off
switch_futures[state_key] = loop.create_future()
client.switch_command(switch.key, False, device_id=switch.device_id)
try:
await asyncio.wait_for(switch_futures[state_key], timeout=2.0)
except TimeoutError:
pytest.fail(
f"Timed out waiting for switch OFF state "
f"(device_id={switch.device_id}, key={switch.key})"
)
assert states[state_key].state is False