Merge branch 'api-proto-write-presized-buffer' into integration

# Conflicts:
#	esphome/components/api/proto.h
This commit is contained in:
J. Nick Koston
2026-02-16 15:17:06 -06:00
4 changed files with 48 additions and 22 deletions

View File

@@ -111,7 +111,7 @@ void DeviceInfoResponse::encode(ProtoWriteBuffer buffer) const {
}
#endif
#ifdef USE_AREAS
buffer.encode_message(22, this->area);
buffer.encode_message(22, this->area, false);
#endif
#ifdef USE_ZWAVE_PROXY
buffer.encode_uint32(23, this->zwave_proxy_feature_flags);
@@ -2435,7 +2435,7 @@ void VoiceAssistantRequest::encode(ProtoWriteBuffer buffer) const {
buffer.encode_bool(1, this->start);
buffer.encode_string(2, this->conversation_id);
buffer.encode_uint32(3, this->flags);
buffer.encode_message(4, this->audio_settings);
buffer.encode_message(4, this->audio_settings, false);
buffer.encode_string(5, this->wake_word_phrase);
}
void VoiceAssistantRequest::calculate_size(ProtoSize &size) const {

View File

@@ -70,6 +70,16 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size
return count;
}
#ifdef ESPHOME_DEBUG_API
void ProtoWriteBuffer::debug_check_bounds_(size_t bytes, const char *caller) {
if (this->pos_ + bytes > this->buffer_->data() + this->buffer_->size()) {
ESP_LOGE(TAG, "ProtoWriteBuffer bounds check failed in %s: bytes=%zu offset=%td buf_size=%zu", caller, bytes,
this->pos_ - this->buffer_->data(), this->buffer_->size());
abort();
}
}
#endif
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
const uint8_t *ptr = buffer;
const uint8_t *end = buffer + length;

View File

@@ -220,10 +220,6 @@ class ProtoWriteBuffer {
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer), pos_(buffer->data() + buffer->size()) {}
ProtoWriteBuffer(std::vector<uint8_t> *buffer, size_t write_pos)
: buffer_(buffer), pos_(buffer->data() + write_pos) {}
void write(uint8_t value) {
this->debug_check_bounds_(1);
*this->pos_++ = value;
}
void encode_varint_raw(uint32_t value) {
while (value > 0x7F) {
this->debug_check_bounds_(1);
@@ -254,10 +250,7 @@ class ProtoWriteBuffer {
*
* Following https://protobuf.dev/programming-guides/encoding/#structure
*/
void encode_field_raw(uint32_t field_id, uint32_t type) {
uint32_t val = (field_id << 3) | (type & WIRE_TYPE_MASK);
this->encode_varint_raw(val);
}
void encode_field_raw(uint32_t field_id, uint32_t type) { this->encode_varint_raw((field_id << 3) | type); }
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
if (len == 0 && !force)
return;
@@ -298,15 +291,22 @@ class ProtoWriteBuffer {
this->debug_check_bounds_(1);
*this->pos_++ = value ? 0x01 : 0x00;
}
void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
__attribute__((noinline)) void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
if (value == 0 && !force)
return;
this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32
this->write((value >> 0) & 0xFF);
this->write((value >> 8) & 0xFF);
this->write((value >> 16) & 0xFF);
this->write((value >> 24) & 0xFF);
this->debug_check_bounds_(4);
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
// Protobuf fixed32 is little-endian, so direct copy works
std::memcpy(this->pos_, &value, 4);
this->pos_ += 4;
#else
*this->pos_++ = (value >> 0) & 0xFF;
*this->pos_++ = (value >> 8) & 0xFF;
*this->pos_++ = (value >> 16) & 0xFF;
*this->pos_++ = (value >> 24) & 0xFF;
#endif
}
// NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally
// not supported to reduce overhead on embedded systems. All ESPHome devices are
@@ -342,15 +342,17 @@ class ProtoWriteBuffer {
}
/// Encode a packed repeated sint32 field (zero-copy from vector)
void encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values);
void encode_message(uint32_t field_id, const ProtoMessage &value);
/// Encode a nested message field (force=true for repeated, false for singular)
void encode_message(uint32_t field_id, const ProtoMessage &value, bool force = true);
std::vector<uint8_t> *get_buffer() const { return buffer_; }
protected:
void debug_check_bounds_([[maybe_unused]] size_t bytes) {
#ifdef ESPHOME_DEBUG_API
assert(this->pos_ + bytes <= this->buffer_->data() + this->buffer_->size());
void debug_check_bounds_(size_t bytes, const char *caller = __builtin_FUNCTION());
#else
void debug_check_bounds_(size_t bytes, const char *caller = __builtin_FUNCTION()) {}
#endif
}
std::vector<uint8_t> *buffer_;
uint8_t *pos_;
};
@@ -921,14 +923,19 @@ inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std:
}
// Implementation of encode_message - must be after ProtoMessage is defined
inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value) {
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value, bool force) {
// Calculate the message size first
ProtoSize msg_size;
value.calculate_size(msg_size);
uint32_t msg_length_bytes = msg_size.get_size();
// Skip empty singular messages (matches add_message_field which skips when nested_size == 0)
// Repeated messages (force=true) are always encoded since an empty item is meaningful
if (msg_length_bytes == 0 && !force)
return;
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
// Write the length varint directly through pos_
this->encode_varint_raw(msg_length_bytes);
@@ -936,6 +943,7 @@ inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessa
// The copy writes msg_length_bytes bytes starting from our current pos_.
// We then advance our pos_ by the known message size.
value.encode(*this);
this->debug_check_bounds_(msg_length_bytes);
this->pos_ += msg_length_bytes;
}

View File

@@ -689,6 +689,14 @@ class MessageType(TypeInfo):
def encode_func(self) -> str:
return "encode_message"
@property
def encode_content(self) -> str:
# Singular message fields pass force=false (skip empty messages)
# The default for encode_nested_message is force=true (for repeated fields)
return (
f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, false);"
)
@property
def decode_length(self) -> str:
# Override to return None for message types because we can't use template-based