mirror of
https://github.com/esphome/esphome.git
synced 2026-02-25 21:43:14 -07:00
[api] Split ProtoVarInt::parse into 32-bit and 64-bit phases (#14039)
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
// See script/api_protobuf/api_protobuf.py
|
||||
#pragma once
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/string_ref.h"
|
||||
|
||||
#include "proto.h"
|
||||
|
||||
12
esphome/components/api/api_pb2_defines.h
Normal file
12
esphome/components/api/api_pb2_defines.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// This file was automatically generated with a tool.
|
||||
// See script/api_protobuf/api_protobuf.py
|
||||
#pragma once
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
#ifdef USE_BLUETOOTH_PROXY
|
||||
#ifndef USE_API_VARINT64
|
||||
#define USE_API_VARINT64
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace esphome::api {} // namespace esphome::api
|
||||
@@ -7,6 +7,23 @@ namespace esphome::api {
|
||||
|
||||
static const char *const TAG = "api.proto";
|
||||
|
||||
#ifdef USE_API_VARINT64
|
||||
/// Finish decoding a varint using 64-bit arithmetic, starting at byte index 4.
///
/// @param buffer    Start of the encoded varint (index 0 is its first byte).
/// @param len       Number of readable bytes in @p buffer.
/// @param consumed  Receives the total encoded length when decoding succeeds; untouched otherwise.
/// @param result32  Value already accumulated by the caller from the earlier bytes.
/// @return The decoded value, or an empty optional if no terminating byte
///         (high bit clear) is found within the first 10 bytes / within @p len.
optional<ProtoVarInt> ProtoVarInt::parse_wide(const uint8_t *buffer, uint32_t len, uint32_t *consumed,
                                              uint32_t result32) {
  // Never read past the buffer, and never past byte index 9: the largest
  // shift used below is therefore 9 * 7 = 63, which stays within uint64_t.
  const uint32_t end = len < uint32_t(10) ? len : uint32_t(10);

  uint64_t accumulated = result32;
  uint32_t shift = 4 * 7;  // bit offset contributed by byte index 4
  for (uint32_t index = 4; index < end; ++index, shift += 7) {
    const uint8_t byte = buffer[index];
    accumulated |= static_cast<uint64_t>(byte & 0x7F) << shift;

    const bool continues = (byte & 0x80) != 0;
    if (continues)
      continue;

    // High bit clear: this was the last byte of the varint.
    *consumed = index + 1;
    return ProtoVarInt(accumulated);
  }

  // No terminating byte seen in the permitted range.
  return {};
}
|
||||
#endif
|
||||
|
||||
uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size_t length, uint32_t target_field_id) {
|
||||
uint32_t count = 0;
|
||||
const uint8_t *ptr = buffer;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "api_pb2_defines.h"
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
@@ -110,59 +111,78 @@ class ProtoVarInt {
|
||||
#endif
|
||||
if (len == 0)
|
||||
return {};
|
||||
|
||||
// Most common case: single-byte varint (values 0-127)
|
||||
// Fast path: single-byte varints (0-127) are the most common case
|
||||
// (booleans, small enums, field tags). Avoid loop overhead entirely.
|
||||
if ((buffer[0] & 0x80) == 0) {
|
||||
*consumed = 1;
|
||||
return ProtoVarInt(buffer[0]);
|
||||
}
|
||||
|
||||
// General case for multi-byte varints
|
||||
// Since we know buffer[0]'s high bit is set, initialize with its value
|
||||
uint64_t result = buffer[0] & 0x7F;
|
||||
uint8_t bitpos = 7;
|
||||
|
||||
// A 64-bit varint is at most 10 bytes (ceil(64/7)). Reject overlong encodings
|
||||
// to avoid undefined behavior from shifting uint64_t by >= 64 bits.
|
||||
uint32_t max_len = std::min(len, uint32_t(10));
|
||||
|
||||
// Start from the second byte since we've already processed the first
|
||||
for (uint32_t i = 1; i < max_len; i++) {
|
||||
// 32-bit phase: process remaining bytes with native 32-bit shifts.
|
||||
// Without USE_API_VARINT64: cover bytes 1-4 (shifts 7, 14, 21, 28) — the uint32_t
|
||||
// shift at byte 4 (shift by 28) may lose bits 32-34, but those are always zero for valid uint32 values.
|
||||
// With USE_API_VARINT64: cover bytes 1-3 (shifts 7, 14, 21) so parse_wide handles
|
||||
// byte 4+ with full 64-bit arithmetic (avoids truncating values > UINT32_MAX).
|
||||
uint32_t result32 = buffer[0] & 0x7F;
|
||||
#ifdef USE_API_VARINT64
|
||||
uint32_t limit = std::min(len, uint32_t(4));
|
||||
#else
|
||||
uint32_t limit = std::min(len, uint32_t(5));
|
||||
#endif
|
||||
for (uint32_t i = 1; i < limit; i++) {
|
||||
uint8_t val = buffer[i];
|
||||
result |= uint64_t(val & 0x7F) << uint64_t(bitpos);
|
||||
bitpos += 7;
|
||||
result32 |= uint32_t(val & 0x7F) << (i * 7);
|
||||
if ((val & 0x80) == 0) {
|
||||
*consumed = i + 1;
|
||||
return ProtoVarInt(result);
|
||||
return ProtoVarInt(result32);
|
||||
}
|
||||
}
|
||||
|
||||
return {}; // Incomplete or invalid varint
|
||||
// 64-bit phase for remaining bytes (BLE addresses etc.)
|
||||
#ifdef USE_API_VARINT64
|
||||
return parse_wide(buffer, len, consumed, result32);
|
||||
#else
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_API_VARINT64
|
||||
protected:
|
||||
/// Continue parsing varint bytes 4-9 with 64-bit arithmetic.
|
||||
/// Separated to keep 64-bit shift code (__ashldi3 on 32-bit platforms) out of the common path.
|
||||
static optional<ProtoVarInt> parse_wide(const uint8_t *buffer, uint32_t len, uint32_t *consumed, uint32_t result32)
|
||||
__attribute__((noinline));
|
||||
|
||||
public:
|
||||
#endif
|
||||
|
||||
constexpr uint16_t as_uint16() const { return this->value_; }
|
||||
constexpr uint32_t as_uint32() const { return this->value_; }
|
||||
constexpr uint64_t as_uint64() const { return this->value_; }
|
||||
constexpr bool as_bool() const { return this->value_; }
|
||||
constexpr int32_t as_int32() const {
|
||||
// Not ZigZag encoded
|
||||
return static_cast<int32_t>(this->as_int64());
|
||||
}
|
||||
constexpr int64_t as_int64() const {
|
||||
// Not ZigZag encoded
|
||||
return static_cast<int64_t>(this->value_);
|
||||
return static_cast<int32_t>(this->value_);
|
||||
}
|
||||
constexpr int32_t as_sint32() const {
|
||||
// with ZigZag encoding
|
||||
return decode_zigzag32(static_cast<uint32_t>(this->value_));
|
||||
}
|
||||
#ifdef USE_API_VARINT64
|
||||
constexpr uint64_t as_uint64() const { return this->value_; }
|
||||
constexpr int64_t as_int64() const {
|
||||
// Not ZigZag encoded
|
||||
return static_cast<int64_t>(this->value_);
|
||||
}
|
||||
constexpr int64_t as_sint64() const {
|
||||
// with ZigZag encoding
|
||||
return decode_zigzag64(this->value_);
|
||||
}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
#ifdef USE_API_VARINT64
|
||||
uint64_t value_;
|
||||
#else
|
||||
uint32_t value_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// Forward declarations for decode_to_message, encode_message and encode_packed_sint32
|
||||
|
||||
@@ -144,6 +144,7 @@
|
||||
#define USE_API_HOMEASSISTANT_SERVICES
|
||||
#define USE_API_HOMEASSISTANT_STATES
|
||||
#define USE_API_NOISE
|
||||
#define USE_API_VARINT64
|
||||
#define USE_API_PLAINTEXT
|
||||
#define USE_API_USER_DEFINED_ACTIONS
|
||||
#define USE_API_CUSTOM_SERVICES
|
||||
|
||||
@@ -1913,6 +1913,37 @@ def build_type_usage_map(
|
||||
)
|
||||
|
||||
|
||||
def get_varint64_ifdef(
    file_desc: descriptor.FileDescriptorProto,
    message_ifdef_map: dict[str, str | None],
) -> tuple[bool, str | None]:
    """Report whether any live message uses a 64-bit varint field, and under which guard.

    A message counts when it is not deprecated and has at least one
    non-deprecated field of type int64, uint64 or sint64. The guard of each
    such message is looked up in ``message_ifdef_map`` (missing entries are
    treated as ``None``, i.e. unconditional).

    Args:
        file_desc: Parsed proto file whose ``message_type`` entries are scanned.
        message_ifdef_map: Message name -> ifdef guard (or ``None``).

    Returns:
        ``(has_varint64, ifdef_guard)``. ``has_varint64`` is False when no
        qualifying field exists. ``ifdef_guard`` is the single guard shared by
        every qualifying message; it is ``None`` when any of them is
        unconditional or when they do not all share one guard.
    """
    wide_varint_types = (
        FieldDescriptorProto.TYPE_INT64,
        FieldDescriptorProto.TYPE_UINT64,
        FieldDescriptorProto.TYPE_SINT64,
    )

    guards: set[str | None] = set()
    for message in file_desc.message_type:
        if message.options.deprecated:
            continue
        has_wide_field = any(
            not candidate.options.deprecated and candidate.type in wide_varint_types
            for candidate in message.field
        )
        if has_wide_field:
            guards.add(message_ifdef_map.get(message.name))

    if not guards:
        return False, None

    # An unconditional user, or several differently guarded users, both mean
    # the define cannot be wrapped in one specific guard.
    if None in guards or len(guards) != 1:
        return True, None

    (shared_guard,) = guards
    return True, shared_guard
|
||||
|
||||
|
||||
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
|
||||
"""Builds the enum type.
|
||||
|
||||
@@ -2567,11 +2598,38 @@ def main() -> None:
|
||||
|
||||
file = d.file[0]
|
||||
|
||||
# Build dynamic ifdef mappings early so we can emit USE_API_VARINT64 before includes
|
||||
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
|
||||
build_type_usage_map(file)
|
||||
)
|
||||
|
||||
# Find the ifdef guard for 64-bit varint fields (int64/uint64/sint64).
|
||||
# Generated into api_pb2_defines.h so proto.h can include it, ensuring
|
||||
# consistent ProtoVarInt layout across all translation units.
|
||||
has_varint64, varint64_guard = get_varint64_ifdef(file, message_ifdef_map)
|
||||
|
||||
# Generate api_pb2_defines.h — included by proto.h to ensure all translation
|
||||
# units see USE_API_VARINT64 consistently (avoids ODR violations in ProtoVarInt).
|
||||
defines_content = FILE_HEADER
|
||||
defines_content += "#pragma once\n\n"
|
||||
defines_content += '#include "esphome/core/defines.h"\n'
|
||||
if has_varint64:
|
||||
lines = [
|
||||
"#ifndef USE_API_VARINT64",
|
||||
"#define USE_API_VARINT64",
|
||||
"#endif",
|
||||
]
|
||||
defines_content += "\n".join(wrap_with_ifdef(lines, varint64_guard))
|
||||
defines_content += "\n"
|
||||
defines_content += "\nnamespace esphome::api {} // namespace esphome::api\n"
|
||||
|
||||
with open(root / "api_pb2_defines.h", "w", encoding="utf-8") as f:
|
||||
f.write(defines_content)
|
||||
|
||||
content = FILE_HEADER
|
||||
content += """\
|
||||
#pragma once
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/string_ref.h"
|
||||
|
||||
#include "proto.h"
|
||||
@@ -2702,11 +2760,6 @@ static void dump_bytes_field(DumpBuffer &out, const char *field_name, const uint
|
||||
|
||||
content += "namespace enums {\n\n"
|
||||
|
||||
# Build dynamic ifdef mappings for both enums and messages
|
||||
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
|
||||
build_type_usage_map(file)
|
||||
)
|
||||
|
||||
# Simple grouping of enums by ifdef
|
||||
current_ifdef = None
|
||||
|
||||
|
||||
47
tests/integration/fixtures/varint_five_byte_device_id.yaml
Normal file
47
tests/integration/fixtures/varint_five_byte_device_id.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
esphome:
|
||||
name: varint-5byte-test
|
||||
# Define areas and devices - device_ids will be FNV hashes > 2^28,
|
||||
# requiring 5-byte varint encoding that exercises the 32-bit parse boundary.
|
||||
areas:
|
||||
- id: test_area
|
||||
name: Test Area
|
||||
devices:
|
||||
- id: sub_device_one
|
||||
name: Sub Device One
|
||||
area_id: test_area
|
||||
- id: sub_device_two
|
||||
name: Sub Device Two
|
||||
area_id: test_area
|
||||
|
||||
host:
|
||||
api:
|
||||
logger:
|
||||
|
||||
# Switches on sub-devices so we can send commands with large device_id varints
|
||||
switch:
|
||||
- platform: template
|
||||
name: Device Switch
|
||||
device_id: sub_device_one
|
||||
id: device_switch_one
|
||||
optimistic: true
|
||||
turn_on_action:
|
||||
- logger.log: "Switch one on"
|
||||
turn_off_action:
|
||||
- logger.log: "Switch one off"
|
||||
|
||||
- platform: template
|
||||
name: Device Switch
|
||||
device_id: sub_device_two
|
||||
id: device_switch_two
|
||||
optimistic: true
|
||||
turn_on_action:
|
||||
- logger.log: "Switch two on"
|
||||
turn_off_action:
|
||||
- logger.log: "Switch two off"
|
||||
|
||||
sensor:
|
||||
- platform: template
|
||||
name: Device Sensor
|
||||
device_id: sub_device_one
|
||||
lambda: return 42.0;
|
||||
update_interval: 0.1s
|
||||
120
tests/integration/test_varint_five_byte_device_id.py
Normal file
120
tests/integration/test_varint_five_byte_device_id.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Integration test for 5-byte varint parsing of device_id fields.
|
||||
|
||||
Device IDs are FNV hashes (uint32) that frequently exceed 2^28 (268435456),
|
||||
requiring 5 varint bytes. This test verifies that:
|
||||
1. The firmware correctly decodes 5-byte varint device_id in incoming commands
|
||||
2. The firmware correctly encodes large device_id values in state responses
|
||||
3. Switch commands with large device_id reach the correct entity
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from aioesphomeapi import EntityState, SwitchInfo, SwitchState
|
||||
import pytest
|
||||
|
||||
from .types import APIClientConnectedFactory, RunCompiledFunction
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_varint_five_byte_device_id(
    yaml_config: str,
    run_compiled: RunCompiledFunction,
    api_client_connected: APIClientConnectedFactory,
) -> None:
    """Round-trip device_id values that need five varint bytes.

    Checks that the reported devices include an id >= 2^28, that every device
    produces at least one state carrying its device_id, and that switch
    commands addressed by device_id toggle the matching switch on and off.
    """
    async with run_compiled(yaml_config), api_client_connected() as client:
        info = await client.device_info()
        devices = info.devices
        assert len(devices) >= 2, f"Expected at least 2 devices, got {len(devices)}"

        # 2^28 is the smallest value that no longer fits in four 7-bit groups.
        five_byte_threshold = 1 << 28
        assert any(dev.device_id >= five_byte_threshold for dev in devices), (
            "Expected at least one device_id >= 2^28 to exercise 5-byte varint path. "
            f"Got device_ids: {[d.device_id for d in devices]}"
        )

        # Collect the two identically named switches, one per sub-device.
        entities, _services = await client.list_entities_services()
        device_switches = [
            entity
            for entity in entities
            if isinstance(entity, SwitchInfo) and entity.name == "Device Switch"
        ]
        assert len(device_switches) == 2, (
            f"Expected 2 'Device Switch' entities, got {len(device_switches)}"
        )

        distinct_switch_ids = {sw.device_id for sw in device_switches}
        assert len(distinct_switch_ids) == 2, (
            "Switches should have different device_ids"
        )

        # State bookkeeping, keyed by (device_id, entity key).
        loop = asyncio.get_running_loop()
        states: dict[tuple[int, int], EntityState] = {}
        switch_futures: dict[tuple[int, int], asyncio.Future[EntityState]] = {}
        initial_done: asyncio.Future[bool] = loop.create_future()

        def on_state(state: EntityState) -> None:
            key = (state.device_id, state.key)
            states[key] = state

            # The initial snapshot is complete once three distinct entities reported.
            if not initial_done.done() and len(states) >= 3:
                initial_done.set_result(True)

            # Only resolve command futures after the initial snapshot.
            if not initial_done.done():
                return
            pending = switch_futures.get(key)
            if pending is None or pending.done():
                return
            if isinstance(state, SwitchState):
                pending.set_result(state)

        client.subscribe_states(on_state)

        try:
            await asyncio.wait_for(initial_done, timeout=10.0)
        except TimeoutError:
            pytest.fail(
                f"Timed out waiting for initial states. Got {len(states)} states"
            )

        # Every device must have reported at least one state under its own id.
        for device in devices:
            reported = [
                entry
                for (owner_id, _entity_key), entry in states.items()
                if owner_id == device.device_id
            ]
            assert len(reported) > 0, (
                f"No states received for device '{device.name}' "
                f"(device_id={device.device_id})"
            )

        # Critical path: the client encodes device_id as a varint in the
        # SwitchCommandRequest and the firmware has to decode it to route the
        # command to the right switch.
        for switch in device_switches:
            state_key = (switch.device_id, switch.key)
            timeout_messages = {
                True: (
                    f"Timed out waiting for switch ON state "
                    f"(device_id={switch.device_id}, key={switch.key}). "
                    f"This likely means the firmware failed to decode the "
                    f"5-byte varint device_id in SwitchCommandRequest."
                ),
                False: (
                    f"Timed out waiting for switch OFF state "
                    f"(device_id={switch.device_id}, key={switch.key})"
                ),
            }

            # Turn on first, then off, waiting for the echoed state each time.
            for target in (True, False):
                switch_futures[state_key] = loop.create_future()
                client.switch_command(switch.key, target, device_id=switch.device_id)
                try:
                    await asyncio.wait_for(switch_futures[state_key], timeout=2.0)
                except TimeoutError:
                    pytest.fail(timeout_messages[target])
                assert states[state_key].state is target
|
||||
Reference in New Issue
Block a user