diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index ae7f864568..7f738a9bfd 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -28,7 +28,7 @@ static constexpr size_t MAX_INITIAL_PER_BATCH = 34; // For clients >= AP static_assert(MAX_MESSAGES_PER_BATCH >= MAX_INITIAL_PER_BATCH, "MAX_MESSAGES_PER_BATCH must be >= MAX_INITIAL_PER_BATCH"); -class APIConnection final : public APIServerConnection { +class APIConnection final : public APIServerConnectionBase { public: friend class APIServer; friend class ListEntitiesIterator; diff --git a/esphome/components/api/api_pb2_service.cpp b/esphome/components/api/api_pb2_service.cpp index 1c04eacc82..2d15deb90d 100644 --- a/esphome/components/api/api_pb2_service.cpp +++ b/esphome/components/api/api_pb2_service.cpp @@ -21,6 +21,23 @@ void APIServerConnectionBase::log_receive_message_(const LogString *name) { #endif void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) { + // Check authentication/connection requirements + switch (msg_type) { + case HelloRequest::MESSAGE_TYPE: // No setup required + case DisconnectRequest::MESSAGE_TYPE: // No setup required + case PingRequest::MESSAGE_TYPE: // No setup required + break; + case DeviceInfoRequest::MESSAGE_TYPE: // Connection setup only + if (!this->check_connection_setup_()) { + return; + } + break; + default: + if (!this->check_authenticated_()) { + return; + } + break; + } switch (msg_type) { case HelloRequest::MESSAGE_TYPE: { HelloRequest msg; @@ -623,28 +640,4 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, } } -void APIServerConnection::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) { - // Check authentication/connection requirements for messages - switch (msg_type) { - case HelloRequest::MESSAGE_TYPE: // No setup required - case DisconnectRequest::MESSAGE_TYPE: // No setup required - case PingRequest::MESSAGE_TYPE: // No setup required - break; // Skip all checks for these messages - case DeviceInfoRequest::MESSAGE_TYPE: // Connection setup only - if (!this->check_connection_setup_()) { - return; // Connection not setup - } - break; - default: - // All other messages require authentication (which includes connection check) - if (!this->check_authenticated_()) { - return; // Authentication failed - } - break; - } - - // Call base implementation to process the message - APIServerConnectionBase::read_message(msg_size, msg_type, msg_data); -} - } // namespace esphome::api diff --git a/esphome/components/api/api_pb2_service.h b/esphome/components/api/api_pb2_service.h index 4dc6ce27d0..1441507406 100644 --- a/esphome/components/api/api_pb2_service.h +++ b/esphome/components/api/api_pb2_service.h @@ -228,9 +228,4 @@ class APIServerConnectionBase : public ProtoService { void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override; }; -class APIServerConnection : public APIServerConnectionBase { - protected: - void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override; -}; - } // namespace esphome::api diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 8673996a25..e022b9e2d2 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -2881,9 +2881,78 @@ static const char *const TAG = "api.service"; cases = list(RECEIVE_CASES.items()) cases.sort() + + serv = file.service[0] + + # Build a mapping of message input types to their authentication requirements + message_auth_map: dict[str, bool] = {} + message_conn_map: dict[str, bool] = {} + + for m in serv.method: + inp = m.input_type[1:] + needs_conn = get_opt(m, pb.needs_setup_connection, True) + needs_auth = get_opt(m, pb.needs_authentication, True) + + # Store authentication requirements for message types + message_auth_map[inp] = needs_auth + message_conn_map[inp] = needs_conn + + # Categorize messages by their authentication requirements + no_conn_ids: set[int] = set() + conn_only_ids: set[int] = set() + + for id_, (_, _, case_msg_name) in cases: + if case_msg_name in message_auth_map: + needs_auth = message_auth_map[case_msg_name] + needs_conn = message_conn_map[case_msg_name] + + if not needs_conn: + no_conn_ids.add(id_) + elif not needs_auth: + conn_only_ids.add(id_) + + # Helper to generate case statements with ifdefs + def generate_cases(ids: set[int], comment: str) -> str: + result = "" + for id_ in sorted(ids): + _, ifdef, msg_name = RECEIVE_CASES[id_] + if ifdef: + result += f"#ifdef {ifdef}\n" + result += f" case {msg_name}::MESSAGE_TYPE: {comment}\n" + if ifdef: + result += "#endif\n" + return result + + # Generate read_message with auth check before dispatch hpp += " protected:\n" hpp += " void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override;\n" + out = f"void {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {{\n" + + # Auth check block before dispatch switch + if no_conn_ids or conn_only_ids: + out += " // Check authentication/connection requirements\n" + out += " switch (msg_type) {\n" + + if no_conn_ids: + out += generate_cases(no_conn_ids, "// No setup required") + out += " break;\n" + + if conn_only_ids: + out += generate_cases(conn_only_ids, "// Connection setup only") + out += " if (!this->check_connection_setup_()) {\n" + out += " return;\n" + out += " }\n" + out += " break;\n" + + out += " default:\n" + out += " if (!this->check_authenticated_()) {\n" + out += " return;\n" + out += " }\n" + out += " break;\n" + out += " }\n" + + # Dispatch switch out += " switch (msg_type) {\n" for i, (case, ifdef, message_name) in cases: if ifdef is not None: @@ -2902,89 +2971,6 @@ static const char *const TAG = "api.service"; cpp += out hpp += "};\n" - serv = file.service[0] - class_name = "APIServerConnection" - hpp += "\n" - hpp += f"class {class_name} : public {class_name}Base {{\n" - hpp_protected = "" - cpp += "\n" - - # Build a mapping of message input types to their authentication requirements - message_auth_map: dict[str, bool] = {} - message_conn_map: dict[str, bool] = {} - - for m in serv.method: - inp = m.input_type[1:] - needs_conn = get_opt(m, pb.needs_setup_connection, True) - needs_auth = get_opt(m, pb.needs_authentication, True) - - # Store authentication requirements for message types - message_auth_map[inp] = needs_auth - message_conn_map[inp] = needs_conn - - # Generate optimized read_message with authentication checking - # Categorize messages by their authentication requirements - no_conn_ids: set[int] = set() - conn_only_ids: set[int] = set() - - for id_, (_, _, case_msg_name) in cases: - if case_msg_name in message_auth_map: - needs_auth = message_auth_map[case_msg_name] - needs_conn = message_conn_map[case_msg_name] - - if not needs_conn: - no_conn_ids.add(id_) - elif not needs_auth: - conn_only_ids.add(id_) - - # Generate override if we have messages that skip checks - if no_conn_ids or conn_only_ids: - # Helper to generate case statements with ifdefs - def generate_cases(ids: set[int], comment: str) -> str: - result = "" - for id_ in sorted(ids): - _, ifdef, msg_name = RECEIVE_CASES[id_] - if ifdef: - result += f"#ifdef {ifdef}\n" - result += f" case {msg_name}::MESSAGE_TYPE: {comment}\n" - if ifdef: - result += "#endif\n" - return result - - hpp_protected += " void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override;\n" - - cpp += f"\nvoid {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {{\n" - cpp += " // Check authentication/connection requirements for messages\n" - cpp += " switch (msg_type) {\n" - - # Messages that don't need any checks - if no_conn_ids: - cpp += generate_cases(no_conn_ids, "// No setup required") - cpp += " break; // Skip all checks for these messages\n" - - # Messages that only need connection setup - if conn_only_ids: - cpp += generate_cases(conn_only_ids, "// Connection setup only") - cpp += " if (!this->check_connection_setup_()) {\n" - cpp += " return; // Connection not setup\n" - cpp += " }\n" - cpp += " break;\n" - - cpp += " default:\n" - cpp += " // All other messages require authentication (which includes connection check)\n" - cpp += " if (!this->check_authenticated_()) {\n" - cpp += " return; // Authentication failed\n" - cpp += " }\n" - cpp += " break;\n" - cpp += " }\n\n" - cpp += " // Call base implementation to process the message\n" - cpp += f" {class_name}Base::read_message(msg_size, msg_type, msg_data);\n" - cpp += "}\n" - - hpp += " protected:\n" - hpp += hpp_protected - hpp += "};\n" - hpp += """\ } // namespace esphome::api