[ota] Replace std::function callbacks with listener interface (#12167)
This commit is contained in:
@@ -5,7 +5,7 @@ import logging
|
||||
|
||||
from esphome import automation
|
||||
import esphome.codegen as cg
|
||||
from esphome.components import esp32_ble
|
||||
from esphome.components import esp32_ble, ota
|
||||
from esphome.components.esp32 import add_idf_sdkconfig_option
|
||||
from esphome.components.esp32_ble import (
|
||||
IDF_MAX_CONNECTIONS,
|
||||
@@ -328,7 +328,7 @@ async def to_code(config):
|
||||
# Note: CONFIG_BT_ACL_CONNECTIONS and CONFIG_BTDM_CTRL_BLE_MAX_CONN are now
|
||||
# configured in esp32_ble component based on max_connections setting
|
||||
|
||||
cg.add_define("USE_OTA_STATE_CALLBACK") # To be notified when an OTA update starts
|
||||
ota.request_ota_state_listeners() # To be notified when an OTA update starts
|
||||
cg.add_define("USE_ESP32_BLE_CLIENT")
|
||||
|
||||
CORE.add_job(_add_ble_features)
|
||||
|
||||
@@ -71,21 +71,24 @@ void ESP32BLETracker::setup() {
|
||||
|
||||
global_esp32_ble_tracker = this;
|
||||
|
||||
#ifdef USE_OTA
|
||||
ota::get_global_ota_callback()->add_on_state_callback(
|
||||
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
this->stop_scan();
|
||||
#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT
|
||||
for (auto *client : this->clients_) {
|
||||
client->disconnect();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
});
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
ota::get_global_ota_callback()->add_global_state_listener(this);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
void ESP32BLETracker::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
this->stop_scan();
|
||||
#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT
|
||||
for (auto *client : this->clients_) {
|
||||
client->disconnect();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void ESP32BLETracker::loop() {
|
||||
if (!this->parent_->is_active()) {
|
||||
this->ble_was_disabled_ = true;
|
||||
|
||||
@@ -22,6 +22,10 @@
|
||||
#include "esphome/components/esp32_ble/ble_uuid.h"
|
||||
#include "esphome/components/esp32_ble/ble_scan_result.h"
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
#include "esphome/components/ota/ota_backend.h"
|
||||
#endif
|
||||
|
||||
namespace esphome::esp32_ble_tracker {
|
||||
|
||||
using namespace esp32_ble;
|
||||
@@ -241,6 +245,9 @@ class ESP32BLETracker : public Component,
|
||||
public GAPScanEventHandler,
|
||||
public GATTcEventHandler,
|
||||
public BLEStatusEventHandler,
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
public ota::OTAGlobalStateListener,
|
||||
#endif
|
||||
public Parented<ESP32BLE> {
|
||||
public:
|
||||
void set_scan_duration(uint32_t scan_duration) { scan_duration_ = scan_duration; }
|
||||
@@ -274,6 +281,10 @@ class ESP32BLETracker : public Component,
|
||||
void gap_scan_event_handler(const BLEScanResult &scan_result) override;
|
||||
void ble_before_disabled_event_handler() override;
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
|
||||
#endif
|
||||
|
||||
/// Add a listener for scanner state changes
|
||||
void add_scanner_state_listener(BLEScannerStateListener *listener) {
|
||||
this->scanner_state_listeners_.push_back(listener);
|
||||
|
||||
@@ -41,10 +41,6 @@ static constexpr size_t SHA256_HEX_SIZE = 64; // SHA256 hash as hex string (32
|
||||
#endif // USE_OTA_PASSWORD
|
||||
|
||||
void ESPHomeOTAComponent::setup() {
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
ota::register_ota_platform(this);
|
||||
#endif
|
||||
|
||||
this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections
|
||||
if (this->server_ == nullptr) {
|
||||
this->log_socket_error_(LOG_STR("creation"));
|
||||
@@ -297,8 +293,8 @@ void ESPHomeOTAComponent::handle_data_() {
|
||||
// accidentally trigger the update process.
|
||||
this->log_start_(LOG_STR("update"));
|
||||
this->status_set_warning();
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_STARTED, 0.0f, 0);
|
||||
#endif
|
||||
|
||||
// This will block for a few seconds as it locks flash
|
||||
@@ -357,8 +353,8 @@ void ESPHomeOTAComponent::handle_data_() {
|
||||
last_progress = now;
|
||||
float percentage = (total * 100.0f) / ota_size;
|
||||
ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0);
|
||||
#endif
|
||||
// feed watchdog and give other tasks a chance to run
|
||||
this->yield_and_feed_watchdog_();
|
||||
@@ -387,8 +383,8 @@ void ESPHomeOTAComponent::handle_data_() {
|
||||
delay(10);
|
||||
ESP_LOGI(TAG, "Update complete");
|
||||
this->status_clear_warning();
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_COMPLETED, 100.0f, 0);
|
||||
#endif
|
||||
delay(100); // NOLINT
|
||||
App.safe_reboot();
|
||||
@@ -402,8 +398,8 @@ error:
|
||||
}
|
||||
|
||||
this->status_momentary_error("err", 5000);
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -16,12 +16,6 @@ namespace http_request {
|
||||
|
||||
static const char *const TAG = "http_request.ota";
|
||||
|
||||
void OtaHttpRequestComponent::setup() {
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
ota::register_ota_platform(this);
|
||||
#endif
|
||||
}
|
||||
|
||||
void OtaHttpRequestComponent::dump_config() { ESP_LOGCONFIG(TAG, "Over-The-Air updates via HTTP request"); };
|
||||
|
||||
void OtaHttpRequestComponent::set_md5_url(const std::string &url) {
|
||||
@@ -48,24 +42,24 @@ void OtaHttpRequestComponent::flash() {
|
||||
}
|
||||
|
||||
ESP_LOGI(TAG, "Starting update");
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_STARTED, 0.0f, 0);
|
||||
#endif
|
||||
|
||||
auto ota_status = this->do_ota_();
|
||||
|
||||
switch (ota_status) {
|
||||
case ota::OTA_RESPONSE_OK:
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, ota_status);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_COMPLETED, 100.0f, ota_status);
|
||||
#endif
|
||||
delay(10);
|
||||
App.safe_reboot();
|
||||
break;
|
||||
|
||||
default:
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_ERROR, 0.0f, ota_status);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_ERROR, 0.0f, ota_status);
|
||||
#endif
|
||||
this->md5_computed_.clear(); // will be reset at next attempt
|
||||
this->md5_expected_.clear(); // will be reset at next attempt
|
||||
@@ -165,8 +159,8 @@ uint8_t OtaHttpRequestComponent::do_ota_() {
|
||||
last_progress = now;
|
||||
float percentage = container->get_bytes_read() * 100.0f / container->content_length;
|
||||
ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0);
|
||||
#endif
|
||||
}
|
||||
} // while
|
||||
|
||||
@@ -24,7 +24,6 @@ enum OtaHttpRequestError : uint8_t {
|
||||
|
||||
class OtaHttpRequestComponent : public ota::OTAComponent, public Parented<HttpRequestComponent> {
|
||||
public:
|
||||
void setup() override;
|
||||
void dump_config() override;
|
||||
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import esphome.codegen as cg
|
||||
from esphome.components import update
|
||||
from esphome.components import ota, update
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import CONF_SOURCE
|
||||
|
||||
@@ -38,6 +38,6 @@ async def to_code(config):
|
||||
|
||||
cg.add(var.set_source_url(config[CONF_SOURCE]))
|
||||
|
||||
cg.add_define("USE_OTA_STATE_CALLBACK")
|
||||
ota.request_ota_state_listeners()
|
||||
|
||||
await cg.register_component(var, config)
|
||||
|
||||
@@ -20,19 +20,19 @@ static const char *const TAG = "http_request.update";
|
||||
|
||||
static const size_t MAX_READ_SIZE = 256;
|
||||
|
||||
void HttpRequestUpdate::setup() {
|
||||
this->ota_parent_->add_on_state_callback([this](ota::OTAState state, float progress, uint8_t err) {
|
||||
if (state == ota::OTAState::OTA_IN_PROGRESS) {
|
||||
this->state_ = update::UPDATE_STATE_INSTALLING;
|
||||
this->update_info_.has_progress = true;
|
||||
this->update_info_.progress = progress;
|
||||
this->publish_state();
|
||||
} else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) {
|
||||
this->state_ = update::UPDATE_STATE_AVAILABLE;
|
||||
this->status_set_error(LOG_STR("Failed to install firmware"));
|
||||
this->publish_state();
|
||||
}
|
||||
});
|
||||
void HttpRequestUpdate::setup() { this->ota_parent_->add_state_listener(this); }
|
||||
|
||||
void HttpRequestUpdate::on_ota_state(ota::OTAState state, float progress, uint8_t error) {
|
||||
if (state == ota::OTAState::OTA_IN_PROGRESS) {
|
||||
this->state_ = update::UPDATE_STATE_INSTALLING;
|
||||
this->update_info_.has_progress = true;
|
||||
this->update_info_.progress = progress;
|
||||
this->publish_state();
|
||||
} else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) {
|
||||
this->state_ = update::UPDATE_STATE_AVAILABLE;
|
||||
this->status_set_error(LOG_STR("Failed to install firmware"));
|
||||
this->publish_state();
|
||||
}
|
||||
}
|
||||
|
||||
void HttpRequestUpdate::update() {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
namespace esphome {
|
||||
namespace http_request {
|
||||
|
||||
class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent {
|
||||
class HttpRequestUpdate final : public update::UpdateEntity, public PollingComponent, public ota::OTAStateListener {
|
||||
public:
|
||||
void setup() override;
|
||||
void update() override;
|
||||
@@ -29,6 +29,8 @@ class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent {
|
||||
|
||||
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
|
||||
|
||||
void on_ota_state(ota::OTAState state, float progress, uint8_t error) override;
|
||||
|
||||
protected:
|
||||
HttpRequestComponent *request_parent_;
|
||||
OtaHttpRequestComponent *ota_parent_;
|
||||
|
||||
@@ -7,7 +7,7 @@ from urllib.parse import urljoin
|
||||
from esphome import automation, external_files, git
|
||||
from esphome.automation import register_action, register_condition
|
||||
import esphome.codegen as cg
|
||||
from esphome.components import esp32, microphone, socket
|
||||
from esphome.components import esp32, microphone, ota, socket
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import (
|
||||
CONF_FILE,
|
||||
@@ -452,7 +452,7 @@ async def to_code(config):
|
||||
cg.add(var.set_microphone_source(mic_source))
|
||||
|
||||
cg.add_define("USE_MICRO_WAKE_WORD")
|
||||
cg.add_define("USE_OTA_STATE_CALLBACK")
|
||||
ota.request_ota_state_listeners()
|
||||
|
||||
esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1")
|
||||
|
||||
|
||||
@@ -119,18 +119,21 @@ void MicroWakeWord::setup() {
|
||||
}
|
||||
});
|
||||
|
||||
#ifdef USE_OTA
|
||||
ota::get_global_ota_callback()->add_on_state_callback(
|
||||
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
this->suspend_task_();
|
||||
} else if (state == ota::OTA_ERROR) {
|
||||
this->resume_task_();
|
||||
}
|
||||
});
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
ota::get_global_ota_callback()->add_global_state_listener(this);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
void MicroWakeWord::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
this->suspend_task_();
|
||||
} else if (state == ota::OTA_ERROR) {
|
||||
this->resume_task_();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void MicroWakeWord::inference_task(void *params) {
|
||||
MicroWakeWord *this_mww = (MicroWakeWord *) params;
|
||||
|
||||
|
||||
@@ -9,8 +9,13 @@
|
||||
|
||||
#include "esphome/core/automation.h"
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/ring_buffer.h"
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
#include "esphome/components/ota/ota_backend.h"
|
||||
#endif
|
||||
|
||||
#include <freertos/event_groups.h>
|
||||
|
||||
#include <frontend.h>
|
||||
@@ -26,13 +31,22 @@ enum State {
|
||||
STOPPED,
|
||||
};
|
||||
|
||||
class MicroWakeWord : public Component {
|
||||
class MicroWakeWord : public Component
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
,
|
||||
public ota::OTAGlobalStateListener
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
void setup() override;
|
||||
void loop() override;
|
||||
float get_setup_priority() const override;
|
||||
void dump_config() override;
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
|
||||
#endif
|
||||
|
||||
void start();
|
||||
void stop();
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ from esphome.const import (
|
||||
from esphome.core import CORE, coroutine_with_priority
|
||||
from esphome.coroutine import CoroPriority
|
||||
|
||||
OTA_STATE_LISTENER_KEY = "ota_state_listener"
|
||||
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
AUTO_LOAD = ["md5", "safe_mode"]
|
||||
|
||||
@@ -86,6 +88,7 @@ BASE_OTA_SCHEMA = cv.Schema(
|
||||
@coroutine_with_priority(CoroPriority.OTA_UPDATES)
|
||||
async def to_code(config):
|
||||
cg.add_define("USE_OTA")
|
||||
CORE.add_job(final_step)
|
||||
|
||||
if CORE.is_rp2040 and CORE.using_arduino:
|
||||
cg.add_library("Updater", None)
|
||||
@@ -119,7 +122,24 @@ async def ota_to_code(var, config):
|
||||
await automation.build_automation(trigger, [(cg.uint8, "x")], conf)
|
||||
use_state_callback = True
|
||||
if use_state_callback:
|
||||
cg.add_define("USE_OTA_STATE_CALLBACK")
|
||||
request_ota_state_listeners()
|
||||
|
||||
|
||||
def request_ota_state_listeners() -> None:
|
||||
"""Request that OTA state listeners be compiled in.
|
||||
|
||||
Components that need to be notified about OTA state changes (start, progress,
|
||||
complete, error) should call this function during their code generation.
|
||||
This enables the add_state_listener() API on OTAComponent.
|
||||
"""
|
||||
CORE.data[OTA_STATE_LISTENER_KEY] = True
|
||||
|
||||
|
||||
@coroutine_with_priority(CoroPriority.FINAL)
|
||||
async def final_step():
|
||||
"""Final code generation step to configure optional OTA features."""
|
||||
if CORE.data.get(OTA_STATE_LISTENER_KEY, False):
|
||||
cg.add_define("USE_OTA_STATE_LISTENER")
|
||||
|
||||
|
||||
FILTER_SOURCE_FILES = filter_source_files_from_platform(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
#include "ota_backend.h"
|
||||
|
||||
#include "esphome/core/automation.h"
|
||||
@@ -7,70 +7,64 @@
|
||||
namespace esphome {
|
||||
namespace ota {
|
||||
|
||||
class OTAStateChangeTrigger : public Trigger<OTAState> {
|
||||
class OTAStateChangeTrigger final : public Trigger<OTAState>, public OTAStateListener {
|
||||
public:
|
||||
explicit OTAStateChangeTrigger(OTAComponent *parent) {
|
||||
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
|
||||
if (!parent->is_failed()) {
|
||||
trigger(state);
|
||||
}
|
||||
});
|
||||
explicit OTAStateChangeTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
|
||||
|
||||
void on_ota_state(OTAState state, float progress, uint8_t error) override {
|
||||
if (!this->parent_->is_failed()) {
|
||||
this->trigger(state);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
OTAComponent *parent_;
|
||||
};
|
||||
|
||||
class OTAStartTrigger : public Trigger<> {
|
||||
template<OTAState State> class OTAStateTrigger final : public Trigger<>, public OTAStateListener {
|
||||
public:
|
||||
explicit OTAStartTrigger(OTAComponent *parent) {
|
||||
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
|
||||
if (state == OTA_STARTED && !parent->is_failed()) {
|
||||
trigger();
|
||||
}
|
||||
});
|
||||
explicit OTAStateTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
|
||||
|
||||
void on_ota_state(OTAState state, float progress, uint8_t error) override {
|
||||
if (state == State && !this->parent_->is_failed()) {
|
||||
this->trigger();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
OTAComponent *parent_;
|
||||
};
|
||||
|
||||
class OTAProgressTrigger : public Trigger<float> {
|
||||
using OTAStartTrigger = OTAStateTrigger<OTA_STARTED>;
|
||||
using OTAEndTrigger = OTAStateTrigger<OTA_COMPLETED>;
|
||||
using OTAAbortTrigger = OTAStateTrigger<OTA_ABORT>;
|
||||
|
||||
class OTAProgressTrigger final : public Trigger<float>, public OTAStateListener {
|
||||
public:
|
||||
explicit OTAProgressTrigger(OTAComponent *parent) {
|
||||
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
|
||||
if (state == OTA_IN_PROGRESS && !parent->is_failed()) {
|
||||
trigger(progress);
|
||||
}
|
||||
});
|
||||
explicit OTAProgressTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
|
||||
|
||||
void on_ota_state(OTAState state, float progress, uint8_t error) override {
|
||||
if (state == OTA_IN_PROGRESS && !this->parent_->is_failed()) {
|
||||
this->trigger(progress);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
OTAComponent *parent_;
|
||||
};
|
||||
|
||||
class OTAEndTrigger : public Trigger<> {
|
||||
class OTAErrorTrigger final : public Trigger<uint8_t>, public OTAStateListener {
|
||||
public:
|
||||
explicit OTAEndTrigger(OTAComponent *parent) {
|
||||
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
|
||||
if (state == OTA_COMPLETED && !parent->is_failed()) {
|
||||
trigger();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
explicit OTAErrorTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
|
||||
|
||||
class OTAAbortTrigger : public Trigger<> {
|
||||
public:
|
||||
explicit OTAAbortTrigger(OTAComponent *parent) {
|
||||
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
|
||||
if (state == OTA_ABORT && !parent->is_failed()) {
|
||||
trigger();
|
||||
}
|
||||
});
|
||||
void on_ota_state(OTAState state, float progress, uint8_t error) override {
|
||||
if (state == OTA_ERROR && !this->parent_->is_failed()) {
|
||||
this->trigger(error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class OTAErrorTrigger : public Trigger<uint8_t> {
|
||||
public:
|
||||
explicit OTAErrorTrigger(OTAComponent *parent) {
|
||||
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
|
||||
if (state == OTA_ERROR && !parent->is_failed()) {
|
||||
trigger(error);
|
||||
}
|
||||
});
|
||||
}
|
||||
protected:
|
||||
OTAComponent *parent_;
|
||||
};
|
||||
|
||||
} // namespace ota
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
namespace esphome {
|
||||
namespace ota {
|
||||
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
OTAGlobalCallback *global_ota_callback{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
OTAGlobalCallback *get_global_ota_callback() {
|
||||
@@ -13,7 +13,12 @@ OTAGlobalCallback *get_global_ota_callback() {
|
||||
return global_ota_callback;
|
||||
}
|
||||
|
||||
void register_ota_platform(OTAComponent *ota_caller) { get_global_ota_callback()->register_ota(ota_caller); }
|
||||
void OTAComponent::notify_state_(OTAState state, float progress, uint8_t error) {
|
||||
for (auto *listener : this->state_listeners_) {
|
||||
listener->on_ota_state(state, progress, error);
|
||||
}
|
||||
get_global_ota_callback()->notify_ota_state(state, progress, error, this);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ota
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
#include "esphome/core/automation.h"
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
#include <vector>
|
||||
#endif
|
||||
|
||||
namespace esphome {
|
||||
@@ -60,62 +60,75 @@ class OTABackend {
|
||||
virtual bool supports_compression() = 0;
|
||||
};
|
||||
|
||||
class OTAComponent : public Component {
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
/** Listener interface for OTA state changes.
|
||||
*
|
||||
* Components can implement this interface to receive OTA state updates
|
||||
* without the overhead of std::function callbacks.
|
||||
*/
|
||||
class OTAStateListener {
|
||||
public:
|
||||
void add_on_state_callback(std::function<void(ota::OTAState, float, uint8_t)> &&callback) {
|
||||
this->state_callback_.add(std::move(callback));
|
||||
}
|
||||
virtual ~OTAStateListener() = default;
|
||||
virtual void on_ota_state(OTAState state, float progress, uint8_t error) = 0;
|
||||
};
|
||||
|
||||
class OTAComponent : public Component {
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
public:
|
||||
void add_state_listener(OTAStateListener *listener) { this->state_listeners_.push_back(listener); }
|
||||
|
||||
protected:
|
||||
/** Extended callback manager with deferred call support.
|
||||
void notify_state_(OTAState state, float progress, uint8_t error);
|
||||
|
||||
/** Notify state with deferral to main loop (for thread safety).
|
||||
*
|
||||
* This adds a call_deferred() method for thread-safe execution from other tasks.
|
||||
* This should be used by OTA implementations that run in separate tasks
|
||||
* (like web_server OTA) to ensure listeners execute in the main loop.
|
||||
*/
|
||||
class StateCallbackManager : public CallbackManager<void(OTAState, float, uint8_t)> {
|
||||
public:
|
||||
StateCallbackManager(OTAComponent *component) : component_(component) {}
|
||||
void notify_state_deferred_(OTAState state, float progress, uint8_t error) {
|
||||
this->defer([this, state, progress, error]() { this->notify_state_(state, progress, error); });
|
||||
}
|
||||
|
||||
/** Call callbacks with deferral to main loop (for thread safety).
|
||||
*
|
||||
* This should be used by OTA implementations that run in separate tasks
|
||||
* (like web_server OTA) to ensure callbacks execute in the main loop.
|
||||
*/
|
||||
void call_deferred(ota::OTAState state, float progress, uint8_t error) {
|
||||
component_->defer([this, state, progress, error]() { this->call(state, progress, error); });
|
||||
}
|
||||
|
||||
private:
|
||||
OTAComponent *component_;
|
||||
};
|
||||
|
||||
StateCallbackManager state_callback_{this};
|
||||
std::vector<OTAStateListener *> state_listeners_;
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
|
||||
/** Listener interface for global OTA state changes (includes OTA component pointer).
|
||||
*
|
||||
* Used by OTAGlobalCallback to aggregate state from multiple OTA components.
|
||||
*/
|
||||
class OTAGlobalStateListener {
|
||||
public:
|
||||
virtual ~OTAGlobalStateListener() = default;
|
||||
virtual void on_ota_global_state(OTAState state, float progress, uint8_t error, OTAComponent *component) = 0;
|
||||
};
|
||||
|
||||
/** Global callback that aggregates OTA state from all OTA components.
|
||||
*
|
||||
* OTA components call notify_ota_state() directly with their pointer,
|
||||
* which forwards the event to all registered global listeners.
|
||||
*/
|
||||
class OTAGlobalCallback {
|
||||
public:
|
||||
void register_ota(OTAComponent *ota_caller) {
|
||||
ota_caller->add_on_state_callback([this, ota_caller](OTAState state, float progress, uint8_t error) {
|
||||
this->state_callback_.call(state, progress, error, ota_caller);
|
||||
});
|
||||
}
|
||||
void add_on_state_callback(std::function<void(OTAState, float, uint8_t, OTAComponent *)> &&callback) {
|
||||
this->state_callback_.add(std::move(callback));
|
||||
void add_global_state_listener(OTAGlobalStateListener *listener) { this->global_listeners_.push_back(listener); }
|
||||
|
||||
void notify_ota_state(OTAState state, float progress, uint8_t error, OTAComponent *component) {
|
||||
for (auto *listener : this->global_listeners_) {
|
||||
listener->on_ota_global_state(state, progress, error, component);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
CallbackManager<void(OTAState, float, uint8_t, OTAComponent *)> state_callback_{};
|
||||
std::vector<OTAGlobalStateListener *> global_listeners_;
|
||||
};
|
||||
|
||||
OTAGlobalCallback *get_global_ota_callback();
|
||||
void register_ota_platform(OTAComponent *ota_caller);
|
||||
|
||||
// OTA implementations should use:
|
||||
// - state_callback_.call() when already in main loop (e.g., esphome OTA)
|
||||
// - state_callback_.call_deferred() when in separate task (e.g., web_server OTA)
|
||||
// This ensures proper callback execution in all contexts.
|
||||
// - notify_state_() when already in main loop (e.g., esphome OTA)
|
||||
// - notify_state_deferred_() when in separate task (e.g., web_server OTA)
|
||||
// This ensures proper listener execution in all contexts.
|
||||
#endif
|
||||
std::unique_ptr<ota::OTABackend> make_ota_backend();
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
from esphome import automation, external_files
|
||||
import esphome.codegen as cg
|
||||
from esphome.components import audio, esp32, media_player, network, psram, speaker
|
||||
from esphome.components import audio, esp32, media_player, network, ota, psram, speaker
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import (
|
||||
CONF_BUFFER_SIZE,
|
||||
@@ -342,7 +342,7 @@ async def to_code(config):
|
||||
var = await media_player.new_media_player(config)
|
||||
await cg.register_component(var, config)
|
||||
|
||||
cg.add_define("USE_OTA_STATE_CALLBACK")
|
||||
ota.request_ota_state_listeners()
|
||||
|
||||
cg.add(var.set_buffer_size(config[CONF_BUFFER_SIZE]))
|
||||
|
||||
|
||||
@@ -66,25 +66,8 @@ void SpeakerMediaPlayer::setup() {
|
||||
this->set_mute_state_(false);
|
||||
}
|
||||
|
||||
#ifdef USE_OTA
|
||||
ota::get_global_ota_callback()->add_on_state_callback(
|
||||
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
if (this->media_pipeline_ != nullptr) {
|
||||
this->media_pipeline_->suspend_tasks();
|
||||
}
|
||||
if (this->announcement_pipeline_ != nullptr) {
|
||||
this->announcement_pipeline_->suspend_tasks();
|
||||
}
|
||||
} else if (state == ota::OTA_ERROR) {
|
||||
if (this->media_pipeline_ != nullptr) {
|
||||
this->media_pipeline_->resume_tasks();
|
||||
}
|
||||
if (this->announcement_pipeline_ != nullptr) {
|
||||
this->announcement_pipeline_->resume_tasks();
|
||||
}
|
||||
}
|
||||
});
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
ota::get_global_ota_callback()->add_global_state_listener(this);
|
||||
#endif
|
||||
|
||||
this->announcement_pipeline_ =
|
||||
@@ -300,6 +283,27 @@ void SpeakerMediaPlayer::watch_media_commands_() {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
void SpeakerMediaPlayer::on_ota_global_state(ota::OTAState state, float progress, uint8_t error,
|
||||
ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
if (this->media_pipeline_ != nullptr) {
|
||||
this->media_pipeline_->suspend_tasks();
|
||||
}
|
||||
if (this->announcement_pipeline_ != nullptr) {
|
||||
this->announcement_pipeline_->suspend_tasks();
|
||||
}
|
||||
} else if (state == ota::OTA_ERROR) {
|
||||
if (this->media_pipeline_ != nullptr) {
|
||||
this->media_pipeline_->resume_tasks();
|
||||
}
|
||||
if (this->announcement_pipeline_ != nullptr) {
|
||||
this->announcement_pipeline_->resume_tasks();
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void SpeakerMediaPlayer::loop() {
|
||||
this->watch_media_commands_();
|
||||
|
||||
|
||||
@@ -5,14 +5,18 @@
|
||||
#include "audio_pipeline.h"
|
||||
|
||||
#include "esphome/components/audio/audio.h"
|
||||
|
||||
#include "esphome/components/media_player/media_player.h"
|
||||
#include "esphome/components/speaker/speaker.h"
|
||||
|
||||
#include "esphome/core/automation.h"
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/preferences.h"
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
#include "esphome/components/ota/ota_backend.h"
|
||||
#endif
|
||||
|
||||
#include <deque>
|
||||
#include <freertos/FreeRTOS.h>
|
||||
#include <freertos/queue.h>
|
||||
@@ -39,12 +43,22 @@ struct VolumeRestoreState {
|
||||
bool is_muted;
|
||||
};
|
||||
|
||||
class SpeakerMediaPlayer : public Component, public media_player::MediaPlayer {
|
||||
class SpeakerMediaPlayer : public Component,
|
||||
public media_player::MediaPlayer
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
,
|
||||
public ota::OTAGlobalStateListener
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
float get_setup_priority() const override { return esphome::setup_priority::PROCESSOR; }
|
||||
void setup() override;
|
||||
void loop() override;
|
||||
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
|
||||
#endif
|
||||
|
||||
// MediaPlayer implementations
|
||||
media_player::MediaPlayerTraits get_traits() override;
|
||||
bool is_muted() const override { return this->is_muted_; }
|
||||
|
||||
@@ -84,9 +84,9 @@ void OTARequestHandler::report_ota_progress_(AsyncWebServerRequest *request) {
|
||||
} else {
|
||||
ESP_LOGD(TAG, "OTA in progress: %" PRIu32 " bytes read", this->ota_read_length_);
|
||||
}
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
// Report progress - use call_deferred since we're in web server task
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_IN_PROGRESS, percentage, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
// Report progress - use notify_state_deferred_ since we're in web server task
|
||||
this->parent_->notify_state_deferred_(ota::OTA_IN_PROGRESS, percentage, 0);
|
||||
#endif
|
||||
this->last_ota_progress_ = now;
|
||||
}
|
||||
@@ -114,9 +114,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
|
||||
// Initialize OTA on first call
|
||||
this->ota_init_(filename.c_str());
|
||||
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
// Notify OTA started - use call_deferred since we're in web server task
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_STARTED, 0.0f, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
// Notify OTA started - use notify_state_deferred_ since we're in web server task
|
||||
this->parent_->notify_state_deferred_(ota::OTA_STARTED, 0.0f, 0);
|
||||
#endif
|
||||
|
||||
// Platform-specific pre-initialization
|
||||
@@ -134,9 +134,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
|
||||
this->ota_backend_ = ota::make_ota_backend();
|
||||
if (!this->ota_backend_) {
|
||||
ESP_LOGE(TAG, "Failed to create OTA backend");
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f,
|
||||
static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_UNKNOWN));
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f,
|
||||
static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_UNKNOWN));
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
@@ -148,8 +148,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
|
||||
if (error_code != ota::OTA_RESPONSE_OK) {
|
||||
ESP_LOGE(TAG, "OTA begin failed: %d", error_code);
|
||||
this->ota_backend_.reset();
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
@@ -166,8 +166,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
|
||||
ESP_LOGE(TAG, "OTA write failed: %d", error_code);
|
||||
this->ota_backend_->abort();
|
||||
this->ota_backend_.reset();
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
@@ -186,15 +186,15 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
|
||||
error_code = this->ota_backend_->end();
|
||||
if (error_code == ota::OTA_RESPONSE_OK) {
|
||||
this->ota_success_ = true;
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
// Report completion before reboot - use call_deferred since we're in web server task
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_COMPLETED, 100.0f, 0);
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
// Report completion before reboot - use notify_state_deferred_ since we're in web server task
|
||||
this->parent_->notify_state_deferred_(ota::OTA_COMPLETED, 100.0f, 0);
|
||||
#endif
|
||||
this->schedule_ota_reboot_();
|
||||
} else {
|
||||
ESP_LOGE(TAG, "OTA end failed: %d", error_code);
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#ifdef USE_OTA_STATE_LISTENER
|
||||
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
|
||||
#endif
|
||||
}
|
||||
this->ota_backend_.reset();
|
||||
@@ -232,10 +232,6 @@ void WebServerOTAComponent::setup() {
|
||||
|
||||
// AsyncWebServer takes ownership of the handler and will delete it when the server is destroyed
|
||||
base->add_handler(new OTARequestHandler(this)); // NOLINT
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
// Register with global OTA callback system
|
||||
ota::register_ota_platform(this);
|
||||
#endif
|
||||
}
|
||||
|
||||
void WebServerOTAComponent::dump_config() { ESP_LOGCONFIG(TAG, "Web Server OTA"); }
|
||||
|
||||
@@ -146,7 +146,7 @@
|
||||
#define USE_OTA_PASSWORD
|
||||
#define USE_OTA_SHA256
|
||||
#define ALLOW_OTA_DOWNGRADE_MD5
|
||||
#define USE_OTA_STATE_CALLBACK
|
||||
#define USE_OTA_STATE_LISTENER
|
||||
#define USE_OTA_VERSION 2
|
||||
#define USE_TIME_TIMEZONE
|
||||
#define USE_WIFI
|
||||
|
||||
Reference in New Issue
Block a user