mirror of
https://github.com/esphome/esphome.git
synced 2026-02-18 15:35:59 -07:00
Merge auth check into base read_message, eliminate APIServerConnection
Move the authentication/connection check switch into APIServerConnectionBase::read_message before the dispatch switch, eliminating the APIServerConnection class entirely. APIConnection now inherits directly from APIServerConnectionBase.
This commit is contained in:
@@ -28,7 +28,7 @@ static constexpr size_t MAX_INITIAL_PER_BATCH = 34; // For clients >= AP
|
||||
static_assert(MAX_MESSAGES_PER_BATCH >= MAX_INITIAL_PER_BATCH,
|
||||
"MAX_MESSAGES_PER_BATCH must be >= MAX_INITIAL_PER_BATCH");
|
||||
|
||||
class APIConnection final : public APIServerConnection {
|
||||
class APIConnection final : public APIServerConnectionBase {
|
||||
public:
|
||||
friend class APIServer;
|
||||
friend class ListEntitiesIterator;
|
||||
|
||||
@@ -21,6 +21,23 @@ void APIServerConnectionBase::log_receive_message_(const LogString *name) {
|
||||
#endif
|
||||
|
||||
void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {
|
||||
// Check authentication/connection requirements
|
||||
switch (msg_type) {
|
||||
case HelloRequest::MESSAGE_TYPE: // No setup required
|
||||
case DisconnectRequest::MESSAGE_TYPE: // No setup required
|
||||
case PingRequest::MESSAGE_TYPE: // No setup required
|
||||
break;
|
||||
case DeviceInfoRequest::MESSAGE_TYPE: // Connection setup only
|
||||
if (!this->check_connection_setup_()) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
if (!this->check_authenticated_()) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
switch (msg_type) {
|
||||
case HelloRequest::MESSAGE_TYPE: {
|
||||
HelloRequest msg;
|
||||
@@ -623,28 +640,4 @@ void APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type,
|
||||
}
|
||||
}
|
||||
|
||||
void APIServerConnection::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {
|
||||
// Check authentication/connection requirements for messages
|
||||
switch (msg_type) {
|
||||
case HelloRequest::MESSAGE_TYPE: // No setup required
|
||||
case DisconnectRequest::MESSAGE_TYPE: // No setup required
|
||||
case PingRequest::MESSAGE_TYPE: // No setup required
|
||||
break; // Skip all checks for these messages
|
||||
case DeviceInfoRequest::MESSAGE_TYPE: // Connection setup only
|
||||
if (!this->check_connection_setup_()) {
|
||||
return; // Connection not setup
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// All other messages require authentication (which includes connection check)
|
||||
if (!this->check_authenticated_()) {
|
||||
return; // Authentication failed
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Call base implementation to process the message
|
||||
APIServerConnectionBase::read_message(msg_size, msg_type, msg_data);
|
||||
}
|
||||
|
||||
} // namespace esphome::api
|
||||
|
||||
@@ -228,9 +228,4 @@ class APIServerConnectionBase : public ProtoService {
|
||||
void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override;
|
||||
};
|
||||
|
||||
class APIServerConnection : public APIServerConnectionBase {
|
||||
protected:
|
||||
void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override;
|
||||
};
|
||||
|
||||
} // namespace esphome::api
|
||||
|
||||
@@ -2881,9 +2881,78 @@ static const char *const TAG = "api.service";
|
||||
|
||||
cases = list(RECEIVE_CASES.items())
|
||||
cases.sort()
|
||||
|
||||
serv = file.service[0]
|
||||
|
||||
# Build a mapping of message input types to their authentication requirements
|
||||
message_auth_map: dict[str, bool] = {}
|
||||
message_conn_map: dict[str, bool] = {}
|
||||
|
||||
for m in serv.method:
|
||||
inp = m.input_type[1:]
|
||||
needs_conn = get_opt(m, pb.needs_setup_connection, True)
|
||||
needs_auth = get_opt(m, pb.needs_authentication, True)
|
||||
|
||||
# Store authentication requirements for message types
|
||||
message_auth_map[inp] = needs_auth
|
||||
message_conn_map[inp] = needs_conn
|
||||
|
||||
# Categorize messages by their authentication requirements
|
||||
no_conn_ids: set[int] = set()
|
||||
conn_only_ids: set[int] = set()
|
||||
|
||||
for id_, (_, _, case_msg_name) in cases:
|
||||
if case_msg_name in message_auth_map:
|
||||
needs_auth = message_auth_map[case_msg_name]
|
||||
needs_conn = message_conn_map[case_msg_name]
|
||||
|
||||
if not needs_conn:
|
||||
no_conn_ids.add(id_)
|
||||
elif not needs_auth:
|
||||
conn_only_ids.add(id_)
|
||||
|
||||
# Helper to generate case statements with ifdefs
|
||||
def generate_cases(ids: set[int], comment: str) -> str:
|
||||
result = ""
|
||||
for id_ in sorted(ids):
|
||||
_, ifdef, msg_name = RECEIVE_CASES[id_]
|
||||
if ifdef:
|
||||
result += f"#ifdef {ifdef}\n"
|
||||
result += f" case {msg_name}::MESSAGE_TYPE: {comment}\n"
|
||||
if ifdef:
|
||||
result += "#endif\n"
|
||||
return result
|
||||
|
||||
# Generate read_message with auth check before dispatch
|
||||
hpp += " protected:\n"
|
||||
hpp += " void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override;\n"
|
||||
|
||||
out = f"void {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {{\n"
|
||||
|
||||
# Auth check block before dispatch switch
|
||||
if no_conn_ids or conn_only_ids:
|
||||
out += " // Check authentication/connection requirements\n"
|
||||
out += " switch (msg_type) {\n"
|
||||
|
||||
if no_conn_ids:
|
||||
out += generate_cases(no_conn_ids, "// No setup required")
|
||||
out += " break;\n"
|
||||
|
||||
if conn_only_ids:
|
||||
out += generate_cases(conn_only_ids, "// Connection setup only")
|
||||
out += " if (!this->check_connection_setup_()) {\n"
|
||||
out += " return;\n"
|
||||
out += " }\n"
|
||||
out += " break;\n"
|
||||
|
||||
out += " default:\n"
|
||||
out += " if (!this->check_authenticated_()) {\n"
|
||||
out += " return;\n"
|
||||
out += " }\n"
|
||||
out += " break;\n"
|
||||
out += " }\n"
|
||||
|
||||
# Dispatch switch
|
||||
out += " switch (msg_type) {\n"
|
||||
for i, (case, ifdef, message_name) in cases:
|
||||
if ifdef is not None:
|
||||
@@ -2902,89 +2971,6 @@ static const char *const TAG = "api.service";
|
||||
cpp += out
|
||||
hpp += "};\n"
|
||||
|
||||
serv = file.service[0]
|
||||
class_name = "APIServerConnection"
|
||||
hpp += "\n"
|
||||
hpp += f"class {class_name} : public {class_name}Base {{\n"
|
||||
hpp_protected = ""
|
||||
cpp += "\n"
|
||||
|
||||
# Build a mapping of message input types to their authentication requirements
|
||||
message_auth_map: dict[str, bool] = {}
|
||||
message_conn_map: dict[str, bool] = {}
|
||||
|
||||
for m in serv.method:
|
||||
inp = m.input_type[1:]
|
||||
needs_conn = get_opt(m, pb.needs_setup_connection, True)
|
||||
needs_auth = get_opt(m, pb.needs_authentication, True)
|
||||
|
||||
# Store authentication requirements for message types
|
||||
message_auth_map[inp] = needs_auth
|
||||
message_conn_map[inp] = needs_conn
|
||||
|
||||
# Generate optimized read_message with authentication checking
|
||||
# Categorize messages by their authentication requirements
|
||||
no_conn_ids: set[int] = set()
|
||||
conn_only_ids: set[int] = set()
|
||||
|
||||
for id_, (_, _, case_msg_name) in cases:
|
||||
if case_msg_name in message_auth_map:
|
||||
needs_auth = message_auth_map[case_msg_name]
|
||||
needs_conn = message_conn_map[case_msg_name]
|
||||
|
||||
if not needs_conn:
|
||||
no_conn_ids.add(id_)
|
||||
elif not needs_auth:
|
||||
conn_only_ids.add(id_)
|
||||
|
||||
# Generate override if we have messages that skip checks
|
||||
if no_conn_ids or conn_only_ids:
|
||||
# Helper to generate case statements with ifdefs
|
||||
def generate_cases(ids: set[int], comment: str) -> str:
|
||||
result = ""
|
||||
for id_ in sorted(ids):
|
||||
_, ifdef, msg_name = RECEIVE_CASES[id_]
|
||||
if ifdef:
|
||||
result += f"#ifdef {ifdef}\n"
|
||||
result += f" case {msg_name}::MESSAGE_TYPE: {comment}\n"
|
||||
if ifdef:
|
||||
result += "#endif\n"
|
||||
return result
|
||||
|
||||
hpp_protected += " void read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) override;\n"
|
||||
|
||||
cpp += f"\nvoid {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {{\n"
|
||||
cpp += " // Check authentication/connection requirements for messages\n"
|
||||
cpp += " switch (msg_type) {\n"
|
||||
|
||||
# Messages that don't need any checks
|
||||
if no_conn_ids:
|
||||
cpp += generate_cases(no_conn_ids, "// No setup required")
|
||||
cpp += " break; // Skip all checks for these messages\n"
|
||||
|
||||
# Messages that only need connection setup
|
||||
if conn_only_ids:
|
||||
cpp += generate_cases(conn_only_ids, "// Connection setup only")
|
||||
cpp += " if (!this->check_connection_setup_()) {\n"
|
||||
cpp += " return; // Connection not setup\n"
|
||||
cpp += " }\n"
|
||||
cpp += " break;\n"
|
||||
|
||||
cpp += " default:\n"
|
||||
cpp += " // All other messages require authentication (which includes connection check)\n"
|
||||
cpp += " if (!this->check_authenticated_()) {\n"
|
||||
cpp += " return; // Authentication failed\n"
|
||||
cpp += " }\n"
|
||||
cpp += " break;\n"
|
||||
cpp += " }\n\n"
|
||||
cpp += " // Call base implementation to process the message\n"
|
||||
cpp += f" {class_name}Base::read_message(msg_size, msg_type, msg_data);\n"
|
||||
cpp += "}\n"
|
||||
|
||||
hpp += " protected:\n"
|
||||
hpp += hpp_protected
|
||||
hpp += "};\n"
|
||||
|
||||
hpp += """\
|
||||
|
||||
} // namespace esphome::api
|
||||
|
||||
Reference in New Issue
Block a user