Compare commits
7 Commits
release
...
socket-ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88632b22e2 | ||
|
|
44041d2526 | ||
|
|
7cfc36cb70 | ||
|
|
08dd72e716 | ||
|
|
7b7e5f7db5 | ||
|
|
c9b170eab9 | ||
|
|
40dd9c5dce |
@@ -192,6 +192,11 @@ class APIClient(threading.Thread):
|
||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._socket.settimeout(10.0)
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
import ssl
|
||||
|
||||
context = ssl.SSLContext()
|
||||
self._socket = context.wrap_socket(self._socket)
|
||||
|
||||
try:
|
||||
self._socket.connect((ip, self._port))
|
||||
except OSError as err:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import base64
|
||||
|
||||
import esphome.codegen as cg
|
||||
import esphome.config_validation as cv
|
||||
from esphome import automation
|
||||
@@ -6,6 +8,7 @@ from esphome.const import (
|
||||
CONF_DATA,
|
||||
CONF_DATA_TEMPLATE,
|
||||
CONF_ID,
|
||||
CONF_KEY,
|
||||
CONF_PASSWORD,
|
||||
CONF_PORT,
|
||||
CONF_REBOOT_TIMEOUT,
|
||||
@@ -19,7 +22,7 @@ from esphome.const import (
|
||||
from esphome.core import coroutine_with_priority
|
||||
|
||||
DEPENDENCIES = ["network"]
|
||||
AUTO_LOAD = ["async_tcp"]
|
||||
AUTO_LOAD = ["socket"]
|
||||
CODEOWNERS = ["@OttoWinter"]
|
||||
|
||||
api_ns = cg.esphome_ns.namespace("api")
|
||||
@@ -41,6 +44,22 @@ SERVICE_ARG_NATIVE_TYPES = {
|
||||
"float[]": cg.std_vector.template(float),
|
||||
"string[]": cg.std_vector.template(cg.std_string),
|
||||
}
|
||||
CONF_ENCRYPTION = "encryption"
|
||||
|
||||
|
||||
def validate_encryption_key(value):
|
||||
value = cv.string_strict(value)
|
||||
try:
|
||||
decoded = base64.b64decode(value, validate=True)
|
||||
except ValueError as err:
|
||||
raise cv.Invalid("Invalid key format, please check it's using base64") from err
|
||||
|
||||
if len(decoded) != 32:
|
||||
raise cv.Invalid("Encryption key must be base64 and 32 bytes long")
|
||||
|
||||
# Return original data for roundtrip conversion
|
||||
return value
|
||||
|
||||
|
||||
CONFIG_SCHEMA = cv.Schema(
|
||||
{
|
||||
@@ -63,6 +82,11 @@ CONFIG_SCHEMA = cv.Schema(
|
||||
),
|
||||
}
|
||||
),
|
||||
cv.Optional(CONF_ENCRYPTION): cv.Schema(
|
||||
{
|
||||
cv.Required(CONF_KEY): validate_encryption_key,
|
||||
}
|
||||
),
|
||||
}
|
||||
).extend(cv.COMPONENT_SCHEMA)
|
||||
|
||||
@@ -92,6 +116,14 @@ async def to_code(config):
|
||||
cg.add(var.register_user_service(trigger))
|
||||
await automation.build_automation(trigger, func_args, conf)
|
||||
|
||||
if CONF_ENCRYPTION in config:
|
||||
conf = config[CONF_ENCRYPTION]
|
||||
decoded = base64.b64decode(conf[CONF_KEY])
|
||||
cg.add(var.set_noise_psk(list(decoded)))
|
||||
cg.add_define("USE_API_NOISE")
|
||||
else:
|
||||
cg.add_define("USE_API_PLAINTEXT")
|
||||
|
||||
cg.add_define("USE_API")
|
||||
cg.add_global(api_ns.using)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/util.h"
|
||||
#include "esphome/core/version.h"
|
||||
#include <errno.h>
|
||||
|
||||
#ifdef USE_DEEP_SLEEP
|
||||
#include "esphome/components/deep_sleep/deep_sleep_component.h"
|
||||
@@ -18,74 +19,36 @@ namespace api {
|
||||
|
||||
static const char *const TAG = "api.connection";
|
||||
|
||||
APIConnection::APIConnection(AsyncClient *client, APIServer *parent)
|
||||
: client_(client), parent_(parent), initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) {
|
||||
this->client_->onError([](void *s, AsyncClient *c, int8_t error) { ((APIConnection *) s)->on_error_(error); }, this);
|
||||
this->client_->onDisconnect([](void *s, AsyncClient *c) { ((APIConnection *) s)->on_disconnect_(); }, this);
|
||||
this->client_->onTimeout([](void *s, AsyncClient *c, uint32_t time) { ((APIConnection *) s)->on_timeout_(time); },
|
||||
this);
|
||||
this->client_->onData([](void *s, AsyncClient *c, void *buf,
|
||||
size_t len) { ((APIConnection *) s)->on_data_(reinterpret_cast<uint8_t *>(buf), len); },
|
||||
this);
|
||||
APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *parent)
|
||||
: parent_(parent),
|
||||
initial_state_iterator_(parent, this),
|
||||
list_entities_iterator_(parent, this) {
|
||||
this->proto_write_buffer_.reserve(64);
|
||||
|
||||
this->send_buffer_.reserve(64);
|
||||
this->recv_buffer_.reserve(32);
|
||||
this->client_info_ = this->client_->remoteIP().toString().c_str();
|
||||
#ifdef USE_API_NOISE
|
||||
helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())};
|
||||
#elif defined(USE_API_PLAINTEXT)
|
||||
helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
|
||||
#else
|
||||
#error "No api frame helper enabled"
|
||||
#endif
|
||||
|
||||
}
|
||||
void APIConnection::start() {
|
||||
this->last_traffic_ = millis();
|
||||
}
|
||||
APIConnection::~APIConnection() { delete this->client_; }
|
||||
void APIConnection::on_error_(int8_t error) { this->remove_ = true; }
|
||||
void APIConnection::on_disconnect_() { this->remove_ = true; }
|
||||
void APIConnection::on_timeout_(uint32_t time) { this->on_fatal_error(); }
|
||||
void APIConnection::on_data_(uint8_t *buf, size_t len) {
|
||||
if (len == 0 || buf == nullptr)
|
||||
|
||||
APIError err = helper_->init();
|
||||
if (err != APIError::OK) {
|
||||
ESP_LOGW(TAG, "Helper init failed: %d errno=%d", (int) err, errno);
|
||||
remove_ = true;
|
||||
return;
|
||||
this->recv_buffer_.insert(this->recv_buffer_.end(), buf, buf + len);
|
||||
}
|
||||
void APIConnection::parse_recv_buffer_() {
|
||||
if (this->recv_buffer_.empty() || this->remove_)
|
||||
return;
|
||||
|
||||
while (!this->recv_buffer_.empty()) {
|
||||
if (this->recv_buffer_[0] != 0x00) {
|
||||
ESP_LOGW(TAG, "Invalid preamble from %s", this->client_info_.c_str());
|
||||
this->on_fatal_error();
|
||||
return;
|
||||
}
|
||||
uint32_t i = 1;
|
||||
const uint32_t size = this->recv_buffer_.size();
|
||||
uint32_t consumed;
|
||||
auto msg_size_varint = ProtoVarInt::parse(&this->recv_buffer_[i], size - i, &consumed);
|
||||
if (!msg_size_varint.has_value())
|
||||
// not enough data there yet
|
||||
return;
|
||||
i += consumed;
|
||||
uint32_t msg_size = msg_size_varint->as_uint32();
|
||||
|
||||
auto msg_type_varint = ProtoVarInt::parse(&this->recv_buffer_[i], size - i, &consumed);
|
||||
if (!msg_type_varint.has_value())
|
||||
// not enough data there yet
|
||||
return;
|
||||
i += consumed;
|
||||
uint32_t msg_type = msg_type_varint->as_uint32();
|
||||
|
||||
if (size - i < msg_size)
|
||||
// message body not fully received
|
||||
return;
|
||||
|
||||
uint8_t *msg = &this->recv_buffer_[i];
|
||||
this->read_message(msg_size, msg_type, msg);
|
||||
if (this->remove_)
|
||||
return;
|
||||
// pop front
|
||||
uint32_t total = i + msg_size;
|
||||
this->recv_buffer_.erase(this->recv_buffer_.begin(), this->recv_buffer_.begin() + total);
|
||||
this->last_traffic_ = millis();
|
||||
}
|
||||
client_info_ = helper_->getpeername();
|
||||
helper_->set_log_info(client_info_);
|
||||
}
|
||||
|
||||
void APIConnection::disconnect_client() {
|
||||
this->client_->close();
|
||||
void APIConnection::force_disconnect_client() {
|
||||
this->helper_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
|
||||
@@ -93,61 +56,78 @@ void APIConnection::loop() {
|
||||
if (this->remove_)
|
||||
return;
|
||||
|
||||
if (this->next_close_) {
|
||||
this->disconnect_client();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!network_is_connected()) {
|
||||
// when network is disconnected force disconnect immediately
|
||||
// don't wait for timeout
|
||||
this->on_fatal_error();
|
||||
return;
|
||||
}
|
||||
if (this->client_->disconnected()) {
|
||||
// failsafe for disconnect logic
|
||||
this->on_disconnect_();
|
||||
if (this->next_close_) {
|
||||
this->helper_->close();
|
||||
this->remove_ = true;
|
||||
return;
|
||||
}
|
||||
this->parse_recv_buffer_();
|
||||
|
||||
APIError err = helper_->loop();
|
||||
if (err != APIError::OK) {
|
||||
on_fatal_error();
|
||||
ESP_LOGW(TAG, "%s: Socket operation failed: %d", client_info_.c_str(), (int) err);
|
||||
return;
|
||||
}
|
||||
ReadPacketBuffer buffer;
|
||||
err = helper_->read_packet(&buffer);
|
||||
if (err == APIError::WOULD_BLOCK) {
|
||||
// pass
|
||||
} else if (err != APIError::OK) {
|
||||
on_fatal_error();
|
||||
ESP_LOGW(TAG, "%s: Reading failed: %d", client_info_.c_str(), (int) err);
|
||||
return;
|
||||
} else {
|
||||
this->last_traffic_ = millis();
|
||||
// read a packet
|
||||
this->read_message(
|
||||
buffer.data_len,
|
||||
buffer.type,
|
||||
&buffer.container[buffer.data_offset]
|
||||
);
|
||||
if (this->remove_)
|
||||
return;
|
||||
}
|
||||
|
||||
this->list_entities_iterator_.advance();
|
||||
this->initial_state_iterator_.advance();
|
||||
|
||||
const uint32_t keepalive = 60000;
|
||||
const uint32_t now = millis();
|
||||
if (this->sent_ping_) {
|
||||
// Disconnect if not responded within 2.5*keepalive
|
||||
if (millis() - this->last_traffic_ > (keepalive * 5) / 2) {
|
||||
if (now - this->last_traffic_ > (keepalive * 5) / 2) {
|
||||
this->force_disconnect_client();
|
||||
ESP_LOGW(TAG, "'%s' didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str());
|
||||
this->disconnect_client();
|
||||
}
|
||||
} else if (millis() - this->last_traffic_ > keepalive) {
|
||||
} else if (now - this->last_traffic_ > keepalive) {
|
||||
this->sent_ping_ = true;
|
||||
this->send_ping_request(PingRequest());
|
||||
}
|
||||
|
||||
#ifdef USE_ESP32_CAMERA
|
||||
if (this->image_reader_.available()) {
|
||||
uint32_t space = this->client_->space();
|
||||
// reserve 15 bytes for metadata, and at least 64 bytes of data
|
||||
if (space >= 15 + 64) {
|
||||
uint32_t to_send = std::min(space - 15, this->image_reader_.available());
|
||||
auto buffer = this->create_buffer();
|
||||
// fixed32 key = 1;
|
||||
buffer.encode_fixed32(1, esp32_camera::global_esp32_camera->get_object_id_hash());
|
||||
// bytes data = 2;
|
||||
buffer.encode_bytes(2, this->image_reader_.peek_data_buffer(), to_send);
|
||||
// bool done = 3;
|
||||
bool done = this->image_reader_.available() == to_send;
|
||||
buffer.encode_bool(3, done);
|
||||
bool success = this->send_buffer(buffer, 44);
|
||||
if (this->image_reader_.available() && this->helper_->can_write_without_blocking()) {
|
||||
uint32_t to_send = std::min((size_t) 1024, this->image_reader_.available());
|
||||
auto buffer = this->create_buffer();
|
||||
// fixed32 key = 1;
|
||||
buffer.encode_fixed32(1, esp32_camera::global_esp32_camera->get_object_id_hash());
|
||||
// bytes data = 2;
|
||||
buffer.encode_bytes(2, this->image_reader_.peek_data_buffer(), to_send);
|
||||
// bool done = 3;
|
||||
bool done = this->image_reader_.available() == to_send;
|
||||
buffer.encode_bool(3, done);
|
||||
bool success = this->send_buffer(buffer, 44);
|
||||
|
||||
if (success) {
|
||||
this->image_reader_.consume_data(to_send);
|
||||
}
|
||||
if (success && done) {
|
||||
this->image_reader_.return_image();
|
||||
}
|
||||
if (success) {
|
||||
this->image_reader_.consume_data(to_send);
|
||||
}
|
||||
if (success && done) {
|
||||
this->image_reader_.return_image();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -718,8 +698,8 @@ bool APIConnection::send_log_message(int level, const char *tag, const char *lin
|
||||
}
|
||||
|
||||
HelloResponse APIConnection::hello(const HelloRequest &msg) {
|
||||
this->client_info_ = msg.client_info + " (" + this->client_->remoteIP().toString().c_str();
|
||||
this->client_info_ += ")";
|
||||
this->client_info_ = msg.client_info + " (" + this->helper_->getpeername() + ")";
|
||||
this->helper_->set_log_info(client_info_);
|
||||
ESP_LOGV(TAG, "Hello from client: '%s'", this->client_info_.c_str());
|
||||
|
||||
HelloResponse resp;
|
||||
@@ -797,44 +777,35 @@ void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistant
|
||||
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) {
|
||||
if (this->remove_)
|
||||
return false;
|
||||
if (!this->helper_->can_write_without_blocking())
|
||||
return false;
|
||||
|
||||
std::vector<uint8_t> header;
|
||||
header.push_back(0x00);
|
||||
ProtoVarInt(buffer.get_buffer()->size()).encode(header);
|
||||
ProtoVarInt(message_type).encode(header);
|
||||
|
||||
size_t needed_space = buffer.get_buffer()->size() + header.size();
|
||||
|
||||
if (needed_space > this->client_->space()) {
|
||||
delay(0);
|
||||
if (needed_space > this->client_->space()) {
|
||||
// SubscribeLogsResponse
|
||||
if (message_type != 29) {
|
||||
ESP_LOGV(TAG, "Cannot send message because of TCP buffer space");
|
||||
}
|
||||
delay(0);
|
||||
return false;
|
||||
}
|
||||
APIError err = this->helper_->write_packet(
|
||||
message_type,
|
||||
buffer.get_buffer()->data(),
|
||||
buffer.get_buffer()->size()
|
||||
);
|
||||
if (err == APIError::WOULD_BLOCK)
|
||||
return false;
|
||||
if (err != APIError::OK) {
|
||||
on_fatal_error();
|
||||
ESP_LOGW(TAG, "%s: Packet write failed %d errno=%d", client_info_.c_str(), (int) err, errno);
|
||||
return false;
|
||||
}
|
||||
|
||||
this->client_->add(reinterpret_cast<char *>(header.data()), header.size(),
|
||||
ASYNC_WRITE_FLAG_COPY | ASYNC_WRITE_FLAG_MORE);
|
||||
this->client_->add(reinterpret_cast<char *>(buffer.get_buffer()->data()), buffer.get_buffer()->size(),
|
||||
ASYNC_WRITE_FLAG_COPY);
|
||||
bool ret = this->client_->send();
|
||||
return ret;
|
||||
this->last_traffic_ = millis();
|
||||
return true;
|
||||
}
|
||||
void APIConnection::on_unauthenticated_access() {
|
||||
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
|
||||
this->on_fatal_error();
|
||||
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
|
||||
}
|
||||
void APIConnection::on_no_setup_connection() {
|
||||
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
|
||||
this->on_fatal_error();
|
||||
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
|
||||
}
|
||||
void APIConnection::on_fatal_error() {
|
||||
ESP_LOGV(TAG, "Error: Disconnecting %s", this->client_info_.c_str());
|
||||
this->client_->close();
|
||||
this->helper_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,16 +5,18 @@
|
||||
#include "api_pb2.h"
|
||||
#include "api_pb2_service.h"
|
||||
#include "api_server.h"
|
||||
#include "api_frame_helper.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
class APIConnection : public APIServerConnection {
|
||||
public:
|
||||
APIConnection(AsyncClient *client, APIServer *parent);
|
||||
virtual ~APIConnection();
|
||||
APIConnection(std::unique_ptr<socket::Socket> socket, APIServer *parent);
|
||||
virtual ~APIConnection() = default;
|
||||
|
||||
void disconnect_client();
|
||||
void start();
|
||||
void force_disconnect_client();
|
||||
void loop();
|
||||
|
||||
bool send_list_info_done() {
|
||||
@@ -87,8 +89,8 @@ class APIConnection : public APIServerConnection {
|
||||
#endif
|
||||
|
||||
void on_disconnect_response(const DisconnectResponse &value) override {
|
||||
// we initiated disconnect_client
|
||||
this->next_close_ = true;
|
||||
this->helper_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
void on_ping_response(const PingResponse &value) override {
|
||||
// we initiated ping
|
||||
@@ -102,6 +104,8 @@ class APIConnection : public APIServerConnection {
|
||||
ConnectResponse connect(const ConnectRequest &msg) override;
|
||||
DisconnectResponse disconnect(const DisconnectRequest &msg) override {
|
||||
// remote initiated disconnect_client
|
||||
// don't close yet, we still need to send the disconnect response
|
||||
// close will happen on next loop
|
||||
this->next_close_ = true;
|
||||
DisconnectResponse resp;
|
||||
return resp;
|
||||
@@ -135,19 +139,16 @@ class APIConnection : public APIServerConnection {
|
||||
void on_unauthenticated_access() override;
|
||||
void on_no_setup_connection() override;
|
||||
ProtoWriteBuffer create_buffer() override {
|
||||
this->send_buffer_.clear();
|
||||
return {&this->send_buffer_};
|
||||
// FIXME: ensure no recursive writes can happen
|
||||
this->proto_write_buffer_.clear();
|
||||
return {&this->proto_write_buffer_};
|
||||
}
|
||||
bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override;
|
||||
|
||||
protected:
|
||||
friend APIServer;
|
||||
|
||||
void on_error_(int8_t error);
|
||||
void on_disconnect_();
|
||||
void on_timeout_(uint32_t time);
|
||||
void on_data_(uint8_t *buf, size_t len);
|
||||
void parse_recv_buffer_();
|
||||
bool send_(const void *buf, size_t len, bool force);
|
||||
|
||||
enum class ConnectionState {
|
||||
WAITING_FOR_HELLO,
|
||||
@@ -157,8 +158,10 @@ class APIConnection : public APIServerConnection {
|
||||
|
||||
bool remove_{false};
|
||||
|
||||
std::vector<uint8_t> send_buffer_;
|
||||
std::vector<uint8_t> recv_buffer_;
|
||||
// Buffer used to encode proto messages
|
||||
// Re-use to prevent allocations
|
||||
std::vector<uint8_t> proto_write_buffer_;
|
||||
std::unique_ptr<APIFrameHelper> helper_;
|
||||
|
||||
std::string client_info_;
|
||||
#ifdef USE_ESP32_CAMERA
|
||||
@@ -170,9 +173,7 @@ class APIConnection : public APIServerConnection {
|
||||
uint32_t last_traffic_;
|
||||
bool sent_ping_{false};
|
||||
bool service_call_subscription_{false};
|
||||
bool current_nodelay_{false};
|
||||
bool next_close_{false};
|
||||
AsyncClient *client_;
|
||||
bool next_close_ = false;
|
||||
APIServer *parent_;
|
||||
InitialStateIterator initial_state_iterator_;
|
||||
ListEntitiesIterator list_entities_iterator_;
|
||||
|
||||
907
esphome/components/api/api_frame_helper.cpp
Normal file
907
esphome/components/api/api_frame_helper.cpp
Normal file
@@ -0,0 +1,907 @@
|
||||
#include "api_frame_helper.h"
|
||||
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "proto.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
static const char *const TAG = "api.socket";
|
||||
|
||||
/// Is the given return value (from read/write syscalls) a wouldblock error?
|
||||
bool is_would_block(ssize_t ret) {
|
||||
if (ret == -1) {
|
||||
return errno == EWOULDBLOCK || errno == EAGAIN;
|
||||
}
|
||||
return ret == 0;
|
||||
}
|
||||
|
||||
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
|
||||
|
||||
/// Convert a noise error code to a readable error
|
||||
std::string noise_err_to_str(int err) {
|
||||
if (err == NOISE_ERROR_NO_MEMORY)
|
||||
return "NO_MEMORY";
|
||||
if (err == NOISE_ERROR_UNKNOWN_ID)
|
||||
return "UNKNOWN_ID";
|
||||
if (err == NOISE_ERROR_UNKNOWN_NAME)
|
||||
return "UNKNOWN_NAME";
|
||||
if (err == NOISE_ERROR_MAC_FAILURE)
|
||||
return "MAC_FAILURE";
|
||||
if (err == NOISE_ERROR_NOT_APPLICABLE)
|
||||
return "NOT_APPLICABLE";
|
||||
if (err == NOISE_ERROR_SYSTEM)
|
||||
return "SYSTEM";
|
||||
if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
|
||||
return "REMOTE_KEY_REQUIRED";
|
||||
if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
|
||||
return "LOCAL_KEY_REQUIRED";
|
||||
if (err == NOISE_ERROR_PSK_REQUIRED)
|
||||
return "PSK_REQUIRED";
|
||||
if (err == NOISE_ERROR_INVALID_LENGTH)
|
||||
return "INVALID_LENGTH";
|
||||
if (err == NOISE_ERROR_INVALID_PARAM)
|
||||
return "INVALID_PARAM";
|
||||
if (err == NOISE_ERROR_INVALID_STATE)
|
||||
return "INVALID_STATE";
|
||||
if (err == NOISE_ERROR_INVALID_NONCE)
|
||||
return "INVALID_NONCE";
|
||||
if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
|
||||
return "INVALID_PRIVATE_KEY";
|
||||
if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
|
||||
return "INVALID_PUBLIC_KEY";
|
||||
if (err == NOISE_ERROR_INVALID_FORMAT)
|
||||
return "INVALID_FORMAT";
|
||||
if (err == NOISE_ERROR_INVALID_SIGNATURE)
|
||||
return "INVALID_SIGNATURE";
|
||||
return to_string(err);
|
||||
}
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APINoiseFrameHelper::init() {
|
||||
if (state_ != State::INITIALIZE || socket_ == nullptr) {
|
||||
HELPER_LOG("Bad state for init %d", (int) state_);
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
int err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Setting nonblocking failed with errno %d", errno);
|
||||
return APIError::TCP_NONBLOCKING_FAILED;
|
||||
}
|
||||
int enable = 1;
|
||||
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Setting nodelay failed with errno %d", errno);
|
||||
return APIError::TCP_NODELAY_FAILED;
|
||||
}
|
||||
|
||||
// init prologue
|
||||
prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT));
|
||||
|
||||
state_ = State::CLIENT_HELLO;
|
||||
return APIError::OK;
|
||||
}
|
||||
/// Run through handshake messages (if in that phase)
|
||||
APIError APINoiseFrameHelper::loop() {
|
||||
APIError err = state_action_();
|
||||
if (err == APIError::WOULD_BLOCK)
|
||||
return APIError::OK;
|
||||
if (err != APIError::OK)
|
||||
return err;
|
||||
if (!tx_buf_.empty()) {
|
||||
err = try_send_tx_buf_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg_start: points to the start of the payload - this pointer is only valid until the next
|
||||
* try_receive_raw_ call
|
||||
*
|
||||
* @return 0 if a full packet is in rx_buf_
|
||||
* @return -1 if error, check errno.
|
||||
*
|
||||
* errno EWOULDBLOCK: Packet could not be read without blocking. Try again later.
|
||||
* errno ENOMEM: Not enough memory for reading packet.
|
||||
* errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
* errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
if (rx_header_buf_len_ < 3) {
|
||||
// no header information yet
|
||||
size_t to_read = 3 - rx_header_buf_len_;
|
||||
ssize_t received = socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read);
|
||||
if (is_would_block(received)) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket read failed with errno %d", errno);
|
||||
return APIError::SOCKET_READ_FAILED;
|
||||
}
|
||||
rx_header_buf_len_ += received;
|
||||
if (received != to_read) {
|
||||
// not a full read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
// header reading done
|
||||
}
|
||||
|
||||
// read body
|
||||
uint8_t indicator = rx_header_buf_[0];
|
||||
if (indicator != 0x01) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", indicator);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
|
||||
uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2];
|
||||
|
||||
if (state_ != State::DATA && msg_size > 128) {
|
||||
// for handshake message only permit up to 128 byte
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad packet len for handshake: %d", msg_size);
|
||||
return APIError::BAD_HANDSHAKE_PACKET_LEN;
|
||||
}
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != msg_size) {
|
||||
rx_buf_.resize(msg_size);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < msg_size) {
|
||||
// more data to read
|
||||
size_t to_read = msg_size - rx_buf_len_;
|
||||
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
if (is_would_block(received)) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket read failed with errno %d", errno);
|
||||
return APIError::SOCKET_READ_FAILED;
|
||||
}
|
||||
rx_buf_len_ += received;
|
||||
if (received != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
|
||||
frame->msg = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_len_ = 0;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** To be called from read/write methods.
|
||||
*
|
||||
* This method runs through the internal handshake methods, if in that state.
|
||||
*
|
||||
* If the handshake is still active when this method returns and a read/write can't take place at
|
||||
* the moment, returns WOULD_BLOCK.
|
||||
* If an error occured, returns that error. Only returns OK if the transport is ready for data
|
||||
* traffic.
|
||||
*/
|
||||
APIError APINoiseFrameHelper::state_action_() {
|
||||
int err;
|
||||
APIError aerr;
|
||||
if (state_ == State::INITIALIZE) {
|
||||
HELPER_LOG("Bad state for method: %d", (int) state_);
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
if (state_ == State::CLIENT_HELLO) {
|
||||
// waiting for client hello
|
||||
ParsedFrame frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
// ignore contents, may be used in future for flags
|
||||
prologue_.push_back((uint8_t) (frame.msg.size() >> 8));
|
||||
prologue_.push_back((uint8_t) frame.msg.size());
|
||||
prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end());
|
||||
|
||||
state_ = State::SERVER_HELLO;
|
||||
}
|
||||
if (state_ == State::SERVER_HELLO) {
|
||||
// send server hello
|
||||
uint8_t msg[1];
|
||||
msg[0] = 0x01; // chosen proto
|
||||
aerr = write_frame_(msg, 1);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
// start handshake
|
||||
aerr = init_handshake_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
state_ = State::HANDSHAKE;
|
||||
}
|
||||
if (state_ == State::HANDSHAKE) {
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE) {
|
||||
// waiting for handshake msg
|
||||
ParsedFrame frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr == APIError::BAD_INDICATOR) {
|
||||
send_explicit_handshake_reject_("Bad indicator byte");
|
||||
return aerr;
|
||||
}
|
||||
if (frame.msg.size() < 1 || frame.msg[0] != 0x00) {
|
||||
aerr = APIError::BAD_HANDSHAKE_PACKET_LEN;
|
||||
}
|
||||
if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
|
||||
send_explicit_handshake_reject_("Bad handshake packet len");
|
||||
return aerr;
|
||||
}
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1);
|
||||
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
// TODO: explicit rejection
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str());
|
||||
if (err == NOISE_ERROR_MAC_FAILURE) {
|
||||
send_explicit_handshake_reject_("Handshake MAC failure");
|
||||
} else {
|
||||
send_explicit_handshake_reject_("Handshake error");
|
||||
}
|
||||
return APIError::HANDSHAKESTATE_READ_FAILED;
|
||||
}
|
||||
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else if (action == NOISE_ACTION_WRITE_MESSAGE) {
|
||||
uint8_t buffer[65];
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
|
||||
|
||||
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_WRITE_FAILED;
|
||||
}
|
||||
buffer[0] = 0x00; // success
|
||||
|
||||
aerr = write_frame_(buffer, mbuf.size + 1);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = check_handshake_finished_();
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
} else {
|
||||
// bad state for action
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad action for handshake: %d", action);
|
||||
return APIError::HANDSHAKESTATE_BAD_STATE;
|
||||
}
|
||||
}
|
||||
if (state_ == State::CLOSED || state_ == State::FAILED) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) {
|
||||
std::vector<uint8_t> data;
|
||||
data.reserve(reason.size() + 1);
|
||||
data[0] = 0x01; // failure
|
||||
for (size_t i = 0; i < reason.size(); i++) {
|
||||
data[i+1] = (uint8_t) reason[i];
|
||||
}
|
||||
write_frame_(data.data(), data.size());
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
ParsedFrame frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, frame.msg.data(), frame.msg.size(), frame.msg.size());
|
||||
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_cipherstate_decrypt failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::CIPHERSTATE_DECRYPT_FAILED;
|
||||
}
|
||||
|
||||
size_t msg_size = mbuf.size;
|
||||
uint8_t *msg_data = frame.msg.data();
|
||||
if (msg_size < 4) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: size %d too short", msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
// uint16_t type;
|
||||
// uint16_t data_len;
|
||||
// uint8_t *data;
|
||||
// uint8_t *padding; zero or more bytes to fill up the rest of the packet
|
||||
uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1];
|
||||
uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3];
|
||||
if (data_len > msg_size - 4) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size);
|
||||
return APIError::BAD_DATA_PACKET;
|
||||
}
|
||||
|
||||
buffer->container = std::move(frame.msg);
|
||||
buffer->data_offset = 4;
|
||||
buffer->data_len = data_len;
|
||||
buffer->type = type;
|
||||
return APIError::OK;
|
||||
}
|
||||
bool APINoiseFrameHelper::can_write_without_blocking() {
|
||||
return state_ == State::DATA && tx_buf_.empty();
|
||||
}
|
||||
APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
aerr = state_action_();
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
size_t padding = 0;
|
||||
size_t msg_len = 4 + payload_len + padding;
|
||||
size_t frame_len = 3 + msg_len + noise_cipherstate_get_mac_length(send_cipher_);
|
||||
auto tmpbuf = std::unique_ptr<uint8_t[]>{new (std::nothrow) uint8_t[frame_len]};
|
||||
if (tmpbuf == nullptr) {
|
||||
HELPER_LOG("Could not allocate for writing packet");
|
||||
return APIError::OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
tmpbuf[0] = 0x01; // indicator
|
||||
// tmpbuf[1], tmpbuf[2] to be set later
|
||||
const uint8_t msg_offset = 3;
|
||||
const uint8_t payload_offset = msg_offset + 4;
|
||||
tmpbuf[msg_offset + 0] = (uint8_t) (type >> 8); // type
|
||||
tmpbuf[msg_offset + 1] = (uint8_t) type;
|
||||
tmpbuf[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len
|
||||
tmpbuf[msg_offset + 3] = (uint8_t) payload_len;
|
||||
// copy data
|
||||
std::copy(payload, payload + payload_len, &tmpbuf[payload_offset]);
|
||||
// fill padding with zeros
|
||||
std::fill(&tmpbuf[payload_offset + payload_len], &tmpbuf[frame_len], 0);
|
||||
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
noise_buffer_set_inout(mbuf, &tmpbuf[msg_offset], msg_len, frame_len - msg_offset);
|
||||
err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::CIPHERSTATE_ENCRYPT_FAILED;
|
||||
}
|
||||
|
||||
size_t total_len = 3 + mbuf.size;
|
||||
tmpbuf[1] = (uint8_t) (mbuf.size >> 8);
|
||||
tmpbuf[2] = (uint8_t) mbuf.size;
|
||||
// write raw to not have two packets sent if NAGLE disabled
|
||||
aerr = write_raw_(&tmpbuf[0], total_len);
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::try_send_tx_buf_() {
|
||||
// try send from tx_buf
|
||||
while (state_ != State::CLOSED && !tx_buf_.empty()) {
|
||||
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
|
||||
if (sent == -1) {
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
break;
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket write failed with errno %d", errno);
|
||||
return APIError::SOCKET_WRITE_FAILED;
|
||||
} else if (sent == 0) {
|
||||
break;
|
||||
}
|
||||
// TODO: inefficient if multiple packets in txbuf
|
||||
// replace with deque of buffers
|
||||
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
|
||||
}
|
||||
|
||||
return APIError::OK;
|
||||
}
|
||||
/** Write the data to the socket, or buffer it a write would block
|
||||
*
|
||||
* @param data The data to write
|
||||
* @param len The length of data
|
||||
*/
|
||||
APIError APINoiseFrameHelper::write_raw_(const uint8_t *data, size_t len) {
|
||||
if (len == 0)
|
||||
return APIError::OK;
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// try to empty tx_buf_ first
|
||||
aerr = try_send_tx_buf_();
|
||||
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// tx buf not empty, can't write now because then stream would be inconsistent
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
ssize_t sent = socket_->write(data, len);
|
||||
if (is_would_block(sent)) {
|
||||
// operation would block, add buffer to tx_buf
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return APIError::OK;
|
||||
} else if (sent == -1) {
|
||||
// an error occured
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket write failed with errno %d", errno);
|
||||
return APIError::SOCKET_WRITE_FAILED;
|
||||
} else if (sent != len) {
|
||||
// partially sent, add end to tx_buf
|
||||
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
|
||||
return APIError::OK;
|
||||
}
|
||||
// fully sent
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) {
|
||||
APIError aerr;
|
||||
|
||||
uint8_t header[3];
|
||||
header[0] = 0x01; // indicator
|
||||
header[1] = (uint8_t) (len >> 8);
|
||||
header[2] = (uint8_t) len;
|
||||
|
||||
aerr = write_raw_(header, 3);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = write_raw_(data, len);
|
||||
return aerr;
|
||||
}
|
||||
|
||||
/** Initiate the data structures for the handshake.
|
||||
*
|
||||
* @return 0 on success, -1 on error (check errno)
|
||||
*/
|
||||
APIError APINoiseFrameHelper::init_handshake_() {
|
||||
int err;
|
||||
memset(&nid_, 0, sizeof(nid_));
|
||||
// const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
|
||||
// err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
|
||||
nid_.pattern_id = NOISE_PATTERN_NN;
|
||||
nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
|
||||
nid_.dh_id = NOISE_DH_CURVE25519;
|
||||
nid_.prefix_id = NOISE_PREFIX_STANDARD;
|
||||
nid_.hybrid_id = NOISE_DH_NONE;
|
||||
nid_.hash_id = NOISE_HASH_SHA256;
|
||||
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;
|
||||
|
||||
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_new_by_id failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_SETUP_FAILED;
|
||||
}
|
||||
|
||||
const auto &psk = ctx_->get_psk();
|
||||
err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size());
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_set_pre_shared_key failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_SETUP_FAILED;
|
||||
}
|
||||
|
||||
err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size());
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_set_prologue failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_SETUP_FAILED;
|
||||
}
|
||||
// set_prologue copies it into handshakestate, so we can get rid of it now
|
||||
prologue_ = {};
|
||||
|
||||
err = noise_handshakestate_start(handshake_);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_start failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_SETUP_FAILED;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::check_handshake_finished_() {
|
||||
assert(state_ == State::HANDSHAKE);
|
||||
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE)
|
||||
return APIError::OK;
|
||||
if (action != NOISE_ACTION_SPLIT) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad action for handshake: %d", action);
|
||||
return APIError::HANDSHAKESTATE_BAD_STATE;
|
||||
}
|
||||
int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("noise_handshakestate_split failed: %s", noise_err_to_str(err).c_str());
|
||||
return APIError::HANDSHAKESTATE_SPLIT_FAILED;
|
||||
}
|
||||
|
||||
HELPER_LOG("Handshake complete!");
|
||||
noise_handshakestate_free(handshake_);
|
||||
handshake_ = nullptr;
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APINoiseFrameHelper::~APINoiseFrameHelper() {
|
||||
if (handshake_ != nullptr) {
|
||||
noise_handshakestate_free(handshake_);
|
||||
handshake_ = nullptr;
|
||||
}
|
||||
if (send_cipher_ != nullptr) {
|
||||
noise_cipherstate_free(send_cipher_);
|
||||
send_cipher_ = nullptr;
|
||||
}
|
||||
if (recv_cipher_ != nullptr) {
|
||||
noise_cipherstate_free(recv_cipher_);
|
||||
recv_cipher_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
APIError APINoiseFrameHelper::close() {
|
||||
state_ = State::CLOSED;
|
||||
int err = socket_->close();
|
||||
if (err == -1)
|
||||
return APIError::CLOSE_FAILED;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APINoiseFrameHelper::shutdown(int how) {
|
||||
int err = socket_->shutdown(how);
|
||||
if (err == -1)
|
||||
return APIError::SHUTDOWN_FAILED;
|
||||
if (how == SHUT_RDWR) {
|
||||
state_ = State::CLOSED;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
extern "C" {
|
||||
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
|
||||
void noise_rand_bytes(void *output, size_t len) {
|
||||
esphome::fill_random(reinterpret_cast<uint8_t *>(output), len);
|
||||
}
|
||||
}
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
|
||||
/// Initialize the frame helper, returns OK if successful.
|
||||
APIError APIPlaintextFrameHelper::init() {
|
||||
if (state_ != State::INITIALIZE || socket_ == nullptr) {
|
||||
HELPER_LOG("Bad state for init %d", (int) state_);
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
int err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Setting nonblocking failed with errno %d", errno);
|
||||
return APIError::TCP_NONBLOCKING_FAILED;
|
||||
}
|
||||
int enable = 1;
|
||||
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Setting nodelay failed with errno %d", errno);
|
||||
return APIError::TCP_NODELAY_FAILED;
|
||||
}
|
||||
|
||||
state_ = State::DATA;
|
||||
return APIError::OK;
|
||||
}
|
||||
/// Not used for plaintext
|
||||
APIError APIPlaintextFrameHelper::loop() {
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
// try send pending TX data
|
||||
if (!tx_buf_.empty()) {
|
||||
APIError err = try_send_tx_buf_();
|
||||
if (err != APIError::OK) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
|
||||
*
|
||||
* @param frame: The struct to hold the frame information in.
|
||||
* msg: store the parsed frame in that struct
|
||||
*
|
||||
* @return See APIError
|
||||
*
|
||||
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (frame == nullptr) {
|
||||
HELPER_LOG("Bad argument for try_read_frame_");
|
||||
return APIError::BAD_ARG;
|
||||
}
|
||||
|
||||
// read header
|
||||
while (!rx_header_parsed_) {
|
||||
uint8_t data;
|
||||
ssize_t received = socket_->read(&data, 1);
|
||||
if (is_would_block(received)) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket read failed with errno %d", errno);
|
||||
return APIError::SOCKET_READ_FAILED;
|
||||
}
|
||||
rx_header_buf_.push_back(data);
|
||||
|
||||
// try parse header
|
||||
if (rx_header_buf_[0] != 0x00) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
|
||||
return APIError::BAD_INDICATOR;
|
||||
}
|
||||
|
||||
size_t i = 1;
|
||||
size_t consumed = 0;
|
||||
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
|
||||
if (!msg_size_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
|
||||
i += consumed;
|
||||
rx_header_parsed_len_ = msg_size_varint->as_uint32();
|
||||
|
||||
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
|
||||
if (!msg_type_varint.has_value()) {
|
||||
// not enough data there yet
|
||||
continue;
|
||||
}
|
||||
rx_header_parsed_type_ = msg_type_varint->as_uint32();
|
||||
rx_header_parsed_ = true;
|
||||
}
|
||||
// header reading done
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != rx_header_parsed_len_) {
|
||||
rx_buf_.resize(rx_header_parsed_len_);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < rx_header_parsed_len_) {
|
||||
// more data to read
|
||||
size_t to_read = rx_header_parsed_len_ - rx_buf_len_;
|
||||
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
if (is_would_block(received)) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket read failed with errno %d", errno);
|
||||
return APIError::SOCKET_READ_FAILED;
|
||||
}
|
||||
rx_buf_len_ += received;
|
||||
if (received != to_read) {
|
||||
// not all read
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
|
||||
frame->msg = std::move(rx_buf_);
|
||||
// consume msg
|
||||
rx_buf_ = {};
|
||||
rx_buf_len_ = 0;
|
||||
rx_header_buf_.clear();
|
||||
rx_header_parsed_ = false;
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
ParsedFrame frame;
|
||||
aerr = try_read_frame_(&frame);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
|
||||
buffer->container = std::move(frame.msg);
|
||||
buffer->data_offset = 0;
|
||||
buffer->data_len = rx_header_parsed_len_;
|
||||
buffer->type = rx_header_parsed_type_;
|
||||
return APIError::OK;
|
||||
}
|
||||
bool APIPlaintextFrameHelper::can_write_without_blocking() {
|
||||
return state_ == State::DATA && tx_buf_.empty();
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) {
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
if (state_ != State::DATA) {
|
||||
return APIError::BAD_STATE;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> header;
|
||||
header.push_back(0x00);
|
||||
ProtoVarInt(payload_len).encode(header);
|
||||
ProtoVarInt(type).encode(header);
|
||||
|
||||
aerr = write_raw_(&header[0], header.size());
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
aerr = write_raw_(payload, payload_len);
|
||||
if (aerr != APIError::OK) {
|
||||
return aerr;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::try_send_tx_buf_() {
|
||||
// try send from tx_buf
|
||||
while (state_ != State::CLOSED && !tx_buf_.empty()) {
|
||||
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
|
||||
if (sent == -1) {
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
break;
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket write failed with errno %d", errno);
|
||||
return APIError::SOCKET_WRITE_FAILED;
|
||||
} else if (sent == 0) {
|
||||
break;
|
||||
}
|
||||
// TODO: inefficient if multiple packets in txbuf
|
||||
// replace with deque of buffers
|
||||
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
|
||||
}
|
||||
|
||||
return APIError::OK;
|
||||
}
|
||||
/** Write the data to the socket, or buffer it a write would block
|
||||
*
|
||||
* @param data The data to write
|
||||
* @param len The length of data
|
||||
*/
|
||||
APIError APIPlaintextFrameHelper::write_raw_(const uint8_t *data, size_t len) {
|
||||
if (len == 0)
|
||||
return APIError::OK;
|
||||
int err;
|
||||
APIError aerr;
|
||||
|
||||
// uncomment for even more debugging
|
||||
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// try to empty tx_buf_ first
|
||||
aerr = try_send_tx_buf_();
|
||||
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
|
||||
return aerr;
|
||||
}
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// tx buf not empty, can't write now because then stream would be inconsistent
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return APIError::OK;
|
||||
}
|
||||
|
||||
ssize_t sent = socket_->write(data, len);
|
||||
if (is_would_block(sent)) {
|
||||
// operation would block, add buffer to tx_buf
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return APIError::OK;
|
||||
} else if (sent == -1) {
|
||||
// an error occured
|
||||
state_ = State::FAILED;
|
||||
HELPER_LOG("Socket write failed with errno %d", errno);
|
||||
return APIError::SOCKET_WRITE_FAILED;
|
||||
} else if (sent != len) {
|
||||
// partially sent, add end to tx_buf
|
||||
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
|
||||
return APIError::OK;
|
||||
}
|
||||
// fully sent
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::write_frame_(const uint8_t *data, size_t len) {
|
||||
APIError aerr;
|
||||
|
||||
uint8_t header[3];
|
||||
header[0] = 0x01; // indicator
|
||||
header[1] = (uint8_t) (len >> 8);
|
||||
header[2] = (uint8_t) len;
|
||||
|
||||
aerr = write_raw_(header, 3);
|
||||
if (aerr != APIError::OK)
|
||||
return aerr;
|
||||
aerr = write_raw_(data, len);
|
||||
return aerr;
|
||||
}
|
||||
|
||||
APIError APIPlaintextFrameHelper::close() {
|
||||
state_ = State::CLOSED;
|
||||
int err = socket_->close();
|
||||
if (err == -1)
|
||||
return APIError::CLOSE_FAILED;
|
||||
return APIError::OK;
|
||||
}
|
||||
APIError APIPlaintextFrameHelper::shutdown(int how) {
|
||||
int err = socket_->shutdown(how);
|
||||
if (err == -1)
|
||||
return APIError::SHUTDOWN_FAILED;
|
||||
if (how == SHUT_RDWR) {
|
||||
state_ = State::CLOSED;
|
||||
}
|
||||
return APIError::OK;
|
||||
}
|
||||
#endif // USE_API_PLAINTEXT
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
186
esphome/components/api/api_frame_helper.h
Normal file
186
esphome/components/api/api_frame_helper.h
Normal file
@@ -0,0 +1,186 @@
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
#include "noise/protocol.h"
|
||||
#endif
|
||||
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "api_noise_context.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
struct ReadPacketBuffer {
|
||||
std::vector<uint8_t> container;
|
||||
uint16_t type;
|
||||
size_t data_offset;
|
||||
size_t data_len;
|
||||
};
|
||||
|
||||
struct PacketBuffer {
|
||||
const std::vector<uint8_t> container;
|
||||
uint16_t type;
|
||||
uint8_t data_offset;
|
||||
uint8_t data_len;
|
||||
};
|
||||
|
||||
enum class APIError : int {
|
||||
OK = 0,
|
||||
WOULD_BLOCK = 1001,
|
||||
BAD_HANDSHAKE_PACKET_LEN = 1002,
|
||||
BAD_INDICATOR = 1003,
|
||||
BAD_DATA_PACKET = 1004,
|
||||
TCP_NODELAY_FAILED = 1005,
|
||||
TCP_NONBLOCKING_FAILED = 1006,
|
||||
CLOSE_FAILED = 1007,
|
||||
SHUTDOWN_FAILED = 1008,
|
||||
BAD_STATE = 1009,
|
||||
BAD_ARG = 1010,
|
||||
SOCKET_READ_FAILED = 1011,
|
||||
SOCKET_WRITE_FAILED = 1012,
|
||||
HANDSHAKESTATE_READ_FAILED = 1013,
|
||||
HANDSHAKESTATE_WRITE_FAILED = 1014,
|
||||
HANDSHAKESTATE_BAD_STATE = 1015,
|
||||
CIPHERSTATE_DECRYPT_FAILED = 1016,
|
||||
CIPHERSTATE_ENCRYPT_FAILED = 1017,
|
||||
OUT_OF_MEMORY = 1018,
|
||||
HANDSHAKESTATE_SETUP_FAILED = 1019,
|
||||
HANDSHAKESTATE_SPLIT_FAILED = 1020,
|
||||
};
|
||||
|
||||
class APIFrameHelper {
|
||||
public:
|
||||
virtual APIError init() = 0;
|
||||
virtual APIError loop() = 0;
|
||||
virtual APIError read_packet(ReadPacketBuffer *buffer) = 0;
|
||||
virtual bool can_write_without_blocking() = 0;
|
||||
virtual APIError write_packet(uint16_t type, const uint8_t *data, size_t len) = 0;
|
||||
virtual std::string getpeername() = 0;
|
||||
virtual APIError close() = 0;
|
||||
virtual APIError shutdown(int how) = 0;
|
||||
// Give this helper a name for logging
|
||||
virtual void set_log_info(std::string info) = 0;
|
||||
};
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
class APINoiseFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx) : socket_(std::move(socket)), ctx_(ctx) {}
|
||||
~APINoiseFrameHelper();
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
bool can_write_without_blocking() override;
|
||||
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
|
||||
std::string getpeername() override{
|
||||
return socket_->getpeername();
|
||||
}
|
||||
APIError close() override;
|
||||
APIError shutdown(int how) override;
|
||||
// Give this helper a name for logging
|
||||
void set_log_info(std::string info) override {
|
||||
info_ = std::move(info);
|
||||
}
|
||||
|
||||
protected:
|
||||
struct ParsedFrame {
|
||||
std::vector<uint8_t> msg;
|
||||
};
|
||||
|
||||
APIError state_action_();
|
||||
APIError try_read_frame_(ParsedFrame *frame);
|
||||
APIError try_send_tx_buf_();
|
||||
APIError write_frame_(const uint8_t *data, size_t len);
|
||||
APIError write_raw_(const uint8_t *data, size_t len);
|
||||
APIError init_handshake_();
|
||||
APIError check_handshake_finished_();
|
||||
void send_explicit_handshake_reject_(const std::string &reason);
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
|
||||
std::string info_;
|
||||
uint8_t rx_header_buf_[3];
|
||||
size_t rx_header_buf_len_ = 0;
|
||||
std::vector<uint8_t> rx_buf_;
|
||||
size_t rx_buf_len_ = 0;
|
||||
|
||||
std::vector<uint8_t> tx_buf_;
|
||||
std::vector<uint8_t> prologue_;
|
||||
|
||||
std::shared_ptr<APINoiseContext> ctx_;
|
||||
NoiseHandshakeState *handshake_ = nullptr;
|
||||
NoiseCipherState *send_cipher_ = nullptr;
|
||||
NoiseCipherState *recv_cipher_ = nullptr;
|
||||
NoiseProtocolId nid_;
|
||||
|
||||
enum class State {
|
||||
INITIALIZE = 1,
|
||||
CLIENT_HELLO = 2,
|
||||
SERVER_HELLO = 3,
|
||||
HANDSHAKE = 4,
|
||||
DATA = 5,
|
||||
CLOSED = 6,
|
||||
FAILED = 7,
|
||||
} state_ = State::INITIALIZE;
|
||||
};
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
#ifdef USE_API_PLAINTEXT
|
||||
class APIPlaintextFrameHelper : public APIFrameHelper {
|
||||
public:
|
||||
APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
|
||||
~APIPlaintextFrameHelper() = default;
|
||||
APIError init() override;
|
||||
APIError loop() override;
|
||||
APIError read_packet(ReadPacketBuffer *buffer) override;
|
||||
bool can_write_without_blocking() override;
|
||||
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
|
||||
std::string getpeername() override {
|
||||
return socket_->getpeername();
|
||||
}
|
||||
APIError close() override;
|
||||
APIError shutdown(int how) override;
|
||||
// Give this helper a name for logging
|
||||
void set_log_info(std::string info) override {
|
||||
info_ = std::move(info);
|
||||
}
|
||||
|
||||
protected:
|
||||
struct ParsedFrame {
|
||||
std::vector<uint8_t> msg;
|
||||
};
|
||||
|
||||
APIError try_read_frame_(ParsedFrame *frame);
|
||||
APIError try_send_tx_buf_();
|
||||
APIError write_frame_(const uint8_t *data, size_t len);
|
||||
APIError write_raw_(const uint8_t *data, size_t len);
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
|
||||
std::string info_;
|
||||
std::vector<uint8_t> rx_header_buf_;
|
||||
bool rx_header_parsed_ = false;
|
||||
uint32_t rx_header_parsed_type_ = 0;
|
||||
uint32_t rx_header_parsed_len_ = 0;
|
||||
|
||||
std::vector<uint8_t> rx_buf_;
|
||||
size_t rx_buf_len_ = 0;
|
||||
|
||||
std::vector<uint8_t> tx_buf_;
|
||||
|
||||
enum class State {
|
||||
INITIALIZE = 1,
|
||||
DATA = 2,
|
||||
CLOSED = 3,
|
||||
FAILED = 4,
|
||||
} state_ = State::INITIALIZE;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
27
esphome/components/api/api_noise_context.h
Normal file
27
esphome/components/api/api_noise_context.h
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
#include <array>
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
using psk_t = std::array<uint8_t, 32>;
|
||||
|
||||
class APINoiseContext {
|
||||
public:
|
||||
void set_psk(psk_t psk) {
|
||||
psk_ = std::move(psk);
|
||||
}
|
||||
const psk_t &get_psk() const {
|
||||
return psk_;
|
||||
}
|
||||
|
||||
protected:
|
||||
psk_t psk_;
|
||||
};
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
} // namespace api
|
||||
} // namespace esphome
|
||||
@@ -5,6 +5,8 @@
|
||||
#include "esphome/core/util.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/version.h"
|
||||
#include <errno.h>
|
||||
//#include <arpa/inet.h>
|
||||
|
||||
#ifdef USE_LOGGER
|
||||
#include "esphome/components/logger/logger.h"
|
||||
@@ -21,20 +23,54 @@ static const char *const TAG = "api";
|
||||
void APIServer::setup() {
|
||||
ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server...");
|
||||
this->setup_controller();
|
||||
this->server_ = AsyncServer(this->port_);
|
||||
this->server_.setNoDelay(false);
|
||||
this->server_.begin();
|
||||
this->server_.onClient(
|
||||
[](void *s, AsyncClient *client) {
|
||||
if (client == nullptr)
|
||||
return;
|
||||
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (socket_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket.");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
|
||||
// we can still continue
|
||||
}
|
||||
err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
/*struct sockaddr_storage dest_addr;
|
||||
memset(&dest_addr, 0, sizeof(dest_addr));
|
||||
struct sockaddr_in *dest_addr_ip4 = (struct sockaddr_in *) &dest_addr;
|
||||
dest_addr_ip4->sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
dest_addr_ip4->sin_family = AF_INET;
|
||||
dest_addr_ip4->sin_port = htons(this->port_);
|
||||
|
||||
err = socket_->bind((struct sockaddr *) &dest_addr, sizeof(dest_addr));*/
|
||||
|
||||
struct sockaddr_in server;
|
||||
memset(&server, 0, sizeof(server));
|
||||
server.sin_family = AF_INET;
|
||||
server.sin_addr.s_addr = INADDR_ANY;
|
||||
server.sin_port = htons(this->port_);
|
||||
|
||||
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
err = socket_->listen(4);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
// can't print here because in lwIP thread
|
||||
// ESP_LOGD(TAG, "New client connected from %s", client->remoteIP().toString().c_str());
|
||||
auto *a_this = (APIServer *) s;
|
||||
a_this->clients_.push_back(new APIConnection(client, a_this));
|
||||
},
|
||||
this);
|
||||
#ifdef USE_LOGGER
|
||||
if (logger::global_logger != nullptr) {
|
||||
logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) {
|
||||
@@ -59,6 +95,20 @@ void APIServer::setup() {
|
||||
#endif
|
||||
}
|
||||
void APIServer::loop() {
|
||||
// Accept new clients
|
||||
while (true) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
if (!sock)
|
||||
break;
|
||||
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
|
||||
|
||||
auto *conn = new APIConnection(std::move(sock), this);
|
||||
clients_.push_back(conn);
|
||||
conn->start();
|
||||
}
|
||||
|
||||
// Partition clients into remove and active
|
||||
auto new_end =
|
||||
std::partition(this->clients_.begin(), this->clients_.end(), [](APIConnection *conn) { return !conn->remove_; });
|
||||
|
||||
@@ -4,19 +4,14 @@
|
||||
#include "esphome/core/controller.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "api_pb2.h"
|
||||
#include "api_pb2_service.h"
|
||||
#include "util.h"
|
||||
#include "list_entities.h"
|
||||
#include "subscribe_state.h"
|
||||
#include "user_services.h"
|
||||
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
#include <AsyncTCP.h>
|
||||
#endif
|
||||
#ifdef ARDUINO_ARCH_ESP8266
|
||||
#include <ESPAsyncTCP.h>
|
||||
#endif
|
||||
#include "api_noise_context.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace api {
|
||||
@@ -35,6 +30,16 @@ class APIServer : public Component, public Controller {
|
||||
void set_port(uint16_t port);
|
||||
void set_password(const std::string &password);
|
||||
void set_reboot_timeout(uint32_t reboot_timeout);
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
void set_noise_psk(psk_t psk) {
|
||||
noise_ctx_->set_psk(std::move(psk));
|
||||
}
|
||||
std::shared_ptr<APINoiseContext> get_noise_ctx() {
|
||||
return noise_ctx_;
|
||||
}
|
||||
#endif // USE_API_NOISE
|
||||
|
||||
void handle_disconnect(APIConnection *conn);
|
||||
#ifdef USE_BINARY_SENSOR
|
||||
void on_binary_sensor_update(binary_sensor::BinarySensor *obj, bool state) override;
|
||||
@@ -86,7 +91,7 @@ class APIServer : public Component, public Controller {
|
||||
const std::vector<UserServiceDescriptor *> &get_user_services() const { return this->user_services_; }
|
||||
|
||||
protected:
|
||||
AsyncServer server_{0};
|
||||
std::unique_ptr<socket::Socket> socket_ = nullptr;
|
||||
uint16_t port_{6053};
|
||||
uint32_t reboot_timeout_{300000};
|
||||
uint32_t last_connected_{0};
|
||||
@@ -94,6 +99,10 @@ class APIServer : public Component, public Controller {
|
||||
std::string password_;
|
||||
std::vector<HomeAssistantStateSubscription> state_subs_;
|
||||
std::vector<UserServiceDescriptor *> user_services_;
|
||||
|
||||
#ifdef USE_API_NOISE
|
||||
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();
|
||||
#endif // USE_API_NOISE
|
||||
};
|
||||
|
||||
extern APIServer *global_api_server; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
38
esphome/components/echo/__init__.py
Normal file
38
esphome/components/echo/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from esphome.const import CONF_ID
|
||||
import esphome.config_validation as cv
|
||||
import esphome.codegen as cg
|
||||
|
||||
AUTO_LOAD = ["socket"]
|
||||
|
||||
echo_ns = cg.esphome_ns.namespace("echo")
|
||||
EchoServer = echo_ns.class_("EchoServer", cg.Component)
|
||||
EchoNoiseServer = echo_ns.class_("EchoNoiseServer", cg.Component)
|
||||
|
||||
CONFIG_SCHEMA = cv.Schema(
|
||||
{
|
||||
cv.Optional("ssl"): cv.COMPONENT_SCHEMA.extend(
|
||||
{
|
||||
cv.GenerateID(): cv.declare_id(EchoServer),
|
||||
}
|
||||
),
|
||||
cv.Optional("noise"): cv.COMPONENT_SCHEMA.extend(
|
||||
{
|
||||
cv.GenerateID(): cv.declare_id(EchoNoiseServer),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def to_code(config):
|
||||
conf = config.get("ssl")
|
||||
if conf is not None:
|
||||
var = cg.new_Pvariable(conf[CONF_ID])
|
||||
await cg.register_component(var, conf)
|
||||
cg.add_define("USE_ECHO_SSL")
|
||||
conf = config.get("noise")
|
||||
if conf is not None:
|
||||
var = cg.new_Pvariable(conf[CONF_ID])
|
||||
await cg.register_component(var, conf)
|
||||
cg.add_define("USE_ECHO_NOISE")
|
||||
cg.add_library("esphome/noise-c", "0.1.0")
|
||||
500
esphome/components/echo/echo.cpp
Normal file
500
esphome/components/echo/echo.cpp
Normal file
@@ -0,0 +1,500 @@
|
||||
#include "echo.h"
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/esphal.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace echo {
|
||||
|
||||
static const char *const TAG = "echo";
|
||||
|
||||
#ifdef USE_ECHO_SSL
|
||||
void EchoServer::setup() {
|
||||
ESP_LOGCONFIG(TAG, "Setting up echo server...");
|
||||
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (socket_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket.");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
|
||||
// we can still continue
|
||||
}
|
||||
err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
struct sockaddr_in server;
|
||||
memset(&server, 0, sizeof(server));
|
||||
server.sin_family = AF_INET;
|
||||
server.sin_addr.s_addr = INADDR_ANY;
|
||||
server.sin_port = htons(6055);
|
||||
|
||||
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
err = socket_->listen(4);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
ssl_ = ssl::create_context();
|
||||
if (!ssl_) {
|
||||
ESP_LOGW(TAG, "Failed to create SSL context: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
ssl_->set_server_certificate(R"(-----BEGIN CERTIFICATE-----
|
||||
MIIB3zCCAYWgAwIBAgIUZvjl3kvRTeMLlFUXR8Vbhw0UXnMwCgYIKoZIzj0EAwIw
|
||||
RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu
|
||||
dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMTA4MTAxODA5NDNaFw0yNDA1MDYx
|
||||
ODA5NDNaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD
|
||||
VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwWTATBgcqhkjOPQIBBggqhkjO
|
||||
PQMBBwNCAAQrW9JCZlDynrY1DphZGpDxV1jkfoVSTiBWQLWuixolu7aJuR3+o+BJ
|
||||
ZrvPCdNHpEOyx7r1DV23SWSp1eIZR43co1MwUTAdBgNVHQ4EFgQUPmWned9/9QAq
|
||||
TCnb3I8dou8plDkwHwYDVR0jBBgwFoAUPmWned9/9QAqTCnb3I8dou8plDkwDwYD
|
||||
VR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNIADBFAiBi375FEb+w297p0J/12lgp
|
||||
iA9ppA4/QwtZdzioULmwVAIhALhGbVdbSAaLI+bwoICROHnuttY6mxJmDK8158Xe
|
||||
s2U4
|
||||
-----END CERTIFICATE-----
|
||||
)");
|
||||
ssl_->set_private_key(R"(-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEICqgqSPPEMmoWbwLpLm1lv4FQ48TsLOmXRbdceKs4DQ/oAoGCCqGSM49
|
||||
AwEHoUQDQgAEK1vSQmZQ8p62NQ6YWRqQ8VdY5H6FUk4gVkC1rosaJbu2ibkd/qPg
|
||||
SWa7zwnTR6RDsse69Q1dt0lkqdXiGUeN3A==
|
||||
-----END EC PRIVATE KEY-----
|
||||
)");
|
||||
err = ssl_->init();
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Failed to initialize SSL context: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void EchoServer::loop() {
|
||||
// Accept new clients
|
||||
while (true) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
if (!sock)
|
||||
break;
|
||||
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
|
||||
|
||||
// wrap socket
|
||||
auto sock2 = ssl_->wrap_socket(std::move(sock));
|
||||
if (!sock2) {
|
||||
ESP_LOGW(TAG, "Failed to wrap socket with SSL: errno %d", errno);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cli = std::unique_ptr<EchoClient>{new EchoClient(std::move(sock2))};
|
||||
cli->start();
|
||||
clients_.push_back(std::move(cli));
|
||||
}
|
||||
|
||||
auto new_end = std::partition(this->clients_.begin(), this->clients_.end(),
|
||||
[](const std::unique_ptr<EchoClient> &cli) { return !cli->remove_; });
|
||||
this->clients_.erase(new_end, this->clients_.end());
|
||||
|
||||
for (auto &client : this->clients_) {
|
||||
client->loop();
|
||||
}
|
||||
}
|
||||
|
||||
void EchoClient::start() {
|
||||
ESP_LOGD(TAG, "Starting socket");
|
||||
int err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket could not enable non-blocking, errno: %d", errno);
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
rx_buffer_.resize(64);
|
||||
}
|
||||
void EchoClient::loop() {
|
||||
uint32_t start = millis();
|
||||
int err = socket_->loop();
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket loop failed: errno %d", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
while (!this->remove_) {
|
||||
size_t capacity = this->rx_buffer_.size();
|
||||
|
||||
ssize_t received = socket_->read(rx_buffer_.data(), capacity);
|
||||
if (received == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
// read would block
|
||||
break;
|
||||
}
|
||||
if (errno == ECONNRESET) {
|
||||
// connection reset
|
||||
this->on_error_();
|
||||
ESP_LOGW(TAG, "Client disconnected");
|
||||
return;
|
||||
}
|
||||
this->on_error_();
|
||||
ESP_LOGW(TAG, "Error reading from socket: errno %d", errno);
|
||||
return;
|
||||
} else if (received == 0) {
|
||||
break;
|
||||
}
|
||||
ESP_LOGD(TAG, "received %s", hexencode(rx_buffer_.data(), received).c_str());
|
||||
tx_buffer_.insert(tx_buffer_.end(), rx_buffer_.begin(), rx_buffer_.begin() + received);
|
||||
|
||||
if (received != capacity)
|
||||
// done with reading
|
||||
break;
|
||||
}
|
||||
|
||||
while (!this->remove_ && !tx_buffer_.empty()) {
|
||||
int err = socket_->write(tx_buffer_.data(), tx_buffer_.size());
|
||||
if (err == -1) {
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN) {
|
||||
break;
|
||||
}
|
||||
if (errno == ECONNRESET) {
|
||||
this->on_error_();
|
||||
ESP_LOGW(TAG, "Client disconnected");
|
||||
return;
|
||||
}
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket write failed: errno %d", errno);
|
||||
return;
|
||||
} else if (err == 0) {
|
||||
break;
|
||||
}
|
||||
tx_buffer_.erase(tx_buffer_.begin(), tx_buffer_.begin() + err);
|
||||
}
|
||||
|
||||
uint32_t end = millis();
|
||||
if (end - start > 10) {
|
||||
ESP_LOGD(TAG, "loop took %u ms", end - start);
|
||||
}
|
||||
}
|
||||
|
||||
void EchoClient::on_error_() {
|
||||
this->socket_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ECHO_NOISE
|
||||
void EchoNoiseServer::setup() {
|
||||
ESP_LOGCONFIG(TAG, "Setting up echo server...");
|
||||
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (socket_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket.");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
|
||||
// we can still continue
|
||||
}
|
||||
err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
struct sockaddr_in server;
|
||||
memset(&server, 0, sizeof(server));
|
||||
server.sin_family = AF_INET;
|
||||
server.sin_addr.s_addr = INADDR_ANY;
|
||||
server.sin_port = htons(6056);
|
||||
|
||||
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
err = socket_->listen(4);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void EchoNoiseServer::loop() {
|
||||
// Accept new clients
|
||||
while (true) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
if (!sock)
|
||||
break;
|
||||
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
|
||||
|
||||
auto cli = std::unique_ptr<EchoNoiseClient>{new EchoNoiseClient(std::move(sock))};
|
||||
cli->start();
|
||||
clients_.push_back(std::move(cli));
|
||||
}
|
||||
|
||||
auto new_end = std::partition(this->clients_.begin(), this->clients_.end(),
|
||||
[](const std::unique_ptr<EchoNoiseClient> &cli) { return !cli->remove_; });
|
||||
this->clients_.erase(new_end, this->clients_.end());
|
||||
|
||||
for (auto &client : this->clients_) {
|
||||
client->loop();
|
||||
}
|
||||
}
|
||||
|
||||
void EchoNoiseClient::start() {
|
||||
ESP_LOGD(TAG, "Starting socket");
|
||||
int err = socket_->setblocking(false);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket could not enable non-blocking, errno: %d", errno);
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
memset(&nid_, 0, sizeof(nid_));
|
||||
const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
|
||||
err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_protocol_name_to_id failed: %d", err);
|
||||
return;
|
||||
}
|
||||
/*nid_.pattern_id = NOISE_PATTERN_NN;
|
||||
nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
|
||||
nid_.dh_id = NOISE_DH_CURVE25519;
|
||||
nid_.prefix_id = NOISE_PREFIX_STANDARD;
|
||||
nid_.hybrid_id = NOISE_DH_NONE;
|
||||
nid_.hash_id = NOISE_HASH_SHA256; // NOISE_HASH_BLAKE2s
|
||||
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;*/
|
||||
|
||||
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_new_by_id failed: %d", err);
|
||||
return;
|
||||
}
|
||||
|
||||
// initialize_handshake
|
||||
{
|
||||
const uint8_t psk[] = {0xC1, 0xD5, 0xE0, 0x72, 0xE7, 0x77, 0x58, 0x02, 0x45, 0xCB, 0x3A,
|
||||
0x81, 0x04, 0x1B, 0x2D, 0x90, 0x3A, 0x0F, 0x0E, 0xC7, 0x9C, 0xFC,
|
||||
0xB4, 0x2A, 0x50, 0xC0, 0xE6, 0x35, 0xA1, 0x54, 0x18, 0x12};
|
||||
static_assert(sizeof(psk) == 32, "error");
|
||||
// noise_handshakestate_set_prologue(handshake, prologue, strlen(prologue));
|
||||
err = noise_handshakestate_set_pre_shared_key(handshake_, psk, 32);
|
||||
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_set_pre_shared_key failed: %d", err);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Start the handshake
|
||||
err = noise_handshakestate_start(handshake_);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_start failed: %d", err);
|
||||
return;
|
||||
}
|
||||
|
||||
rx_buffer_.resize(64);
|
||||
msg_buffer_.resize(64);
|
||||
do_handshake_ = true;
|
||||
}
|
||||
|
||||
void EchoNoiseClient::loop() {
|
||||
if (this->remove_)
|
||||
return;
|
||||
|
||||
const uint32_t start = millis();
|
||||
int err = socket_->loop();
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket loop failed: errno %d", errno);
|
||||
return;
|
||||
}
|
||||
NoiseBuffer mbuf;
|
||||
noise_buffer_init(mbuf);
|
||||
|
||||
if (!this->remove_ && do_handshake_) {
|
||||
int action = noise_handshakestate_get_action(handshake_);
|
||||
if (action == NOISE_ACTION_WRITE_MESSAGE) {
|
||||
// Write the next handshake message with a zero-length payload
|
||||
noise_buffer_set_output(mbuf, msg_buffer_.data(), msg_buffer_.size());
|
||||
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
|
||||
if (err == 0) {
|
||||
tx_buffer_.push_back((uint8_t)(mbuf.size >> 8));
|
||||
tx_buffer_.push_back((uint8_t) mbuf.size);
|
||||
tx_buffer_.insert(tx_buffer_.end(), msg_buffer_.begin(), msg_buffer_.begin() + mbuf.size);
|
||||
} else {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_write_message failed: %d", err);
|
||||
return;
|
||||
}
|
||||
} else if (action == NOISE_ACTION_READ_MESSAGE) {
|
||||
if (rx_size_ >= 2) {
|
||||
uint16_t msg_size = ((uint16_t)(rx_buffer_[0]) << 8) | (rx_buffer_[1]);
|
||||
if (rx_size_ >= msg_size + 2) {
|
||||
ESP_LOGD(TAG, "Message: %s", hexencode(rx_buffer_.data() + 2, msg_size).c_str());
|
||||
noise_buffer_set_input(mbuf, rx_buffer_.data() + 2, msg_size);
|
||||
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_read_message failed: %d", err);
|
||||
return;
|
||||
}
|
||||
rx_size_ -= msg_size + 2;
|
||||
}
|
||||
}
|
||||
} else if (action == NOISE_ACTION_SPLIT) {
|
||||
err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_split failed: %d", err);
|
||||
return;
|
||||
}
|
||||
noise_handshakestate_free(handshake_);
|
||||
do_handshake_ = false;
|
||||
ESP_LOGI(TAG, "handshake finished");
|
||||
} else {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_handshakestate_get_action failed: %d", err);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
while (!this->remove_ && !this->do_handshake_) {
|
||||
if (rx_size_ < 2)
|
||||
break;
|
||||
uint16_t msg_size = ((uint16_t)(rx_buffer_[0]) << 8) | (rx_buffer_[1]);
|
||||
if (rx_size_ < msg_size + 2)
|
||||
break;
|
||||
noise_buffer_set_inout(mbuf, rx_buffer_.data() + 2, msg_size, rx_buffer_.size() - 2);
|
||||
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_cipherstate_decrypt failed: %d", err);
|
||||
return;
|
||||
}
|
||||
rx_size_ -= msg_size + 2;
|
||||
|
||||
err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
|
||||
if (err != 0) {
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "noise_cipherstate_encrypt failed: %d", err);
|
||||
return;
|
||||
}
|
||||
tx_buffer_.push_back((uint8_t)(mbuf.size >> 8));
|
||||
tx_buffer_.push_back((uint8_t) mbuf.size);
|
||||
tx_buffer_.insert(tx_buffer_.end(), rx_buffer_.begin() + 2, rx_buffer_.begin() + 2 + mbuf.size);
|
||||
}
|
||||
|
||||
while (!this->remove_) {
|
||||
size_t capacity = this->rx_buffer_.size();
|
||||
size_t used = rx_size_;
|
||||
size_t space = capacity - used;
|
||||
if (space == 0) {
|
||||
rx_buffer_.resize(capacity + 64);
|
||||
continue;
|
||||
}
|
||||
|
||||
ssize_t received = socket_->read(rx_buffer_.data() + used, space);
|
||||
if (received == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
// read would block
|
||||
break;
|
||||
}
|
||||
if (errno == ECONNRESET) {
|
||||
// connection reset
|
||||
this->on_error_();
|
||||
ESP_LOGW(TAG, "Client disconnected");
|
||||
return;
|
||||
}
|
||||
this->on_error_();
|
||||
ESP_LOGW(TAG, "Error reading from socket: errno %d", errno);
|
||||
return;
|
||||
} else if (received == 0) {
|
||||
break;
|
||||
}
|
||||
ESP_LOGD(TAG, "received %s", hexencode(rx_buffer_.data(), received).c_str());
|
||||
rx_size_ += received;
|
||||
|
||||
if (received != capacity)
|
||||
// done with reading
|
||||
break;
|
||||
}
|
||||
|
||||
while (!this->remove_ && !tx_buffer_.empty()) {
|
||||
int err = socket_->write(tx_buffer_.data(), tx_buffer_.size());
|
||||
if (err == -1) {
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN) {
|
||||
break;
|
||||
}
|
||||
if (errno == ECONNRESET) {
|
||||
this->on_error_();
|
||||
ESP_LOGW(TAG, "Client disconnected");
|
||||
return;
|
||||
}
|
||||
on_error_();
|
||||
ESP_LOGW(TAG, "Socket write failed: errno %d", errno);
|
||||
return;
|
||||
} else if (err == 0) {
|
||||
break;
|
||||
}
|
||||
tx_buffer_.erase(tx_buffer_.begin(), tx_buffer_.begin() + err);
|
||||
}
|
||||
|
||||
const uint32_t end = millis();
|
||||
if (end - start > 10) {
|
||||
ESP_LOGD(TAG, "noise took %u ms", end - start);
|
||||
}
|
||||
}
|
||||
|
||||
void EchoNoiseClient::on_error_() {
|
||||
this->socket_->close();
|
||||
this->remove_ = true;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace echo
|
||||
} // namespace esphome
|
||||
|
||||
extern "C" {
|
||||
void noise_rand_bytes(void *output, size_t len) { esp_fill_random(output, len); }
|
||||
}
|
||||
88
esphome/components/echo/echo.h
Normal file
88
esphome/components/echo/echo.h
Normal file
@@ -0,0 +1,88 @@
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/defines.h"
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#ifdef USE_ECHO_SSL
|
||||
#include "esphome/components/ssl/ssl_context.h"
|
||||
#endif
|
||||
#include "noise/protocol.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace echo {
|
||||
|
||||
#ifdef USE_ECHO_SSL
|
||||
class EchoClient {
|
||||
public:
|
||||
EchoClient(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
|
||||
void start();
|
||||
void loop();
|
||||
|
||||
protected:
|
||||
friend class EchoServer;
|
||||
|
||||
void on_error_();
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_ = nullptr;
|
||||
bool remove_ = false;
|
||||
std::vector<uint8_t> rx_buffer_;
|
||||
std::vector<uint8_t> tx_buffer_;
|
||||
};
|
||||
|
||||
class EchoServer : public Component {
|
||||
public:
|
||||
void setup() override;
|
||||
void loop() override;
|
||||
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
|
||||
|
||||
protected:
|
||||
friend class EchoClient;
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_ = nullptr;
|
||||
std::unique_ptr<ssl::SSLContext> ssl_ = nullptr;
|
||||
std::vector<std::unique_ptr<EchoClient>> clients_;
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef USE_ECHO_NOISE
|
||||
class EchoNoiseClient {
|
||||
public:
|
||||
EchoNoiseClient(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
|
||||
void start();
|
||||
void loop();
|
||||
|
||||
protected:
|
||||
friend class EchoNoiseServer;
|
||||
|
||||
void on_error_();
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_ = nullptr;
|
||||
bool remove_ = false;
|
||||
std::vector<uint8_t> rx_buffer_;
|
||||
size_t rx_size_ = 0;
|
||||
std::vector<uint8_t> tx_buffer_;
|
||||
std::vector<uint8_t> msg_buffer_;
|
||||
|
||||
NoiseHandshakeState *handshake_ = nullptr;
|
||||
NoiseCipherState *send_cipher_ = nullptr;
|
||||
NoiseCipherState *recv_cipher_ = nullptr;
|
||||
NoiseProtocolId nid_;
|
||||
bool do_handshake_ = false;
|
||||
};
|
||||
|
||||
class EchoNoiseServer : public Component {
|
||||
public:
|
||||
void setup() override;
|
||||
void loop() override;
|
||||
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
|
||||
|
||||
protected:
|
||||
friend class EchoNoiseClient;
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_ = nullptr;
|
||||
std::vector<std::unique_ptr<EchoNoiseClient>> clients_;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace echo
|
||||
} // namespace esphome
|
||||
74
esphome/components/echo/noise_script.py
Normal file
74
esphome/components/echo/noise_script.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import socket
|
||||
from noise.connection import NoiseConnection
|
||||
|
||||
proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||
proto.set_as_initiator()
|
||||
proto.set_psks(
|
||||
b"\xC1\xD5\xE0\x72\xE7\x77\x58\x02\x45\xCB\x3A\x81\x04\x1B\x2D\x90"
|
||||
b"\x3A\x0F\x0E\xC7\x9C\xFC\xB4\x2A\x50\xC0\xE6\x35\xA1\x54\x18\x12"
|
||||
)
|
||||
# sys.exit(1)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
print("[x] Connecting...")
|
||||
sock.connect(("192.168.178.154", 6053))
|
||||
print("[x] Connected!")
|
||||
prologue = b"NoiseAPIInit"
|
||||
|
||||
|
||||
def write(msg):
|
||||
print(f"[x] Writing frame {msg.hex()}")
|
||||
l = len(msg)
|
||||
buf = bytes(
|
||||
[
|
||||
0x01,
|
||||
(l >> 8) & 0xFF,
|
||||
(l >> 0) & 0xFF,
|
||||
]
|
||||
)
|
||||
buf += msg
|
||||
print(f" -> {buf.hex()}")
|
||||
sock.sendall(buf)
|
||||
|
||||
|
||||
def recv():
|
||||
buf = b""
|
||||
while len(buf) < 3:
|
||||
buf += sock.recv(3 - len(buf))
|
||||
assert buf[0] == 0x01
|
||||
l = (buf[1] << 8) | buf[2]
|
||||
buf = buf[3:]
|
||||
while len(buf) < l:
|
||||
buf += sock.recv(l - len(buf))
|
||||
|
||||
print(f"[x] Received frame {buf.hex()}")
|
||||
return buf
|
||||
|
||||
|
||||
write(b"")
|
||||
prologue += b"\x00\x00"
|
||||
buf = recv()
|
||||
print(f"Received msg {buf.hex()}")
|
||||
|
||||
proto.set_prologue(prologue)
|
||||
proto.start_handshake()
|
||||
do_write = True
|
||||
while not proto.handshake_finished:
|
||||
if do_write:
|
||||
msg = proto.write_message()
|
||||
write(msg)
|
||||
else:
|
||||
msg = recv()
|
||||
proto.read_message(msg)
|
||||
|
||||
do_write = not do_write
|
||||
|
||||
print(f"[x] Handshake done!")
|
||||
|
||||
while True:
|
||||
msg = input().encode()
|
||||
buf = proto.encrypt(msg)
|
||||
write(buf)
|
||||
buf = recv()
|
||||
msg2 = proto.decrypt(buf)
|
||||
print(msg2)
|
||||
@@ -43,21 +43,24 @@ void Logger::write_header_(int level, const char *tag, int line) {
|
||||
}
|
||||
|
||||
void HOT Logger::log_vprintf_(int level, const char *tag, int line, const char *format, va_list args) { // NOLINT
|
||||
if (level > this->level_for(tag))
|
||||
if (level > this->level_for(tag) || recursion_guard_)
|
||||
return;
|
||||
|
||||
recursion_guard_ = true;
|
||||
this->reset_buffer_();
|
||||
this->write_header_(level, tag, line);
|
||||
this->vprintf_to_buffer_(format, args);
|
||||
this->write_footer_();
|
||||
this->log_message_(level, tag);
|
||||
recursion_guard_ = false;
|
||||
}
|
||||
#ifdef USE_STORE_LOG_STR_IN_FLASH
|
||||
void Logger::log_vprintf_(int level, const char *tag, int line, const __FlashStringHelper *format,
|
||||
va_list args) { // NOLINT
|
||||
if (level > this->level_for(tag))
|
||||
if (level > this->level_for(tag) || recursion_guard_)
|
||||
return;
|
||||
|
||||
recursion_guard_ = true;
|
||||
this->reset_buffer_();
|
||||
// copy format string
|
||||
const char *format_pgm_p = (PGM_P) format;
|
||||
@@ -78,6 +81,7 @@ void Logger::log_vprintf_(int level, const char *tag, int line, const __FlashStr
|
||||
this->vprintf_to_buffer_(this->tx_buffer_, args);
|
||||
this->write_footer_();
|
||||
this->log_message_(level, tag, offset);
|
||||
recursion_guard_ = false;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -113,6 +113,8 @@ class Logger : public Component {
|
||||
};
|
||||
std::vector<LogLevelOverride> log_levels_;
|
||||
CallbackManager<void(int, const char *, const char *)> log_callback_{};
|
||||
/// Prevents recursive log calls, if true a log message is already being processed.
|
||||
bool recursion_guard_ = false;
|
||||
};
|
||||
|
||||
extern Logger *global_logger; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
@@ -15,6 +15,7 @@ from esphome.core import CORE, coroutine_with_priority
|
||||
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
DEPENDENCIES = ["network"]
|
||||
AUTO_LOAD = ["socket"]
|
||||
|
||||
CONF_ON_STATE_CHANGE = "on_state_change"
|
||||
CONF_ON_BEGIN = "on_begin"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/application.h"
|
||||
#include "esphome/core/util.h"
|
||||
|
||||
#include <errno.h>
|
||||
#include <cstdio>
|
||||
#include <MD5Builder.h>
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
@@ -19,8 +19,44 @@ static const char *const TAG = "ota";
|
||||
static const uint8_t OTA_VERSION_1_0 = 1;
|
||||
|
||||
void OTAComponent::setup() {
|
||||
this->server_ = new WiFiServer(this->port_);
|
||||
this->server_->begin();
|
||||
server_ = socket::socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (server_ == nullptr) {
|
||||
ESP_LOGW(TAG, "Could not create socket.");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
int enable = 1;
|
||||
int err = server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
|
||||
// we can still continue
|
||||
}
|
||||
err = server_->setblocking(false);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
struct sockaddr_in server;
|
||||
memset(&server, 0, sizeof(server));
|
||||
server.sin_family = AF_INET;
|
||||
server.sin_addr.s_addr = INADDR_ANY;
|
||||
server.sin_port = htons(this->port_);
|
||||
|
||||
err = server_->bind((struct sockaddr *) &server, sizeof(server));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
err = server_->listen(4);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
this->dump_config();
|
||||
}
|
||||
@@ -59,23 +95,28 @@ void OTAComponent::handle_() {
|
||||
uint8_t ota_features;
|
||||
(void) ota_features;
|
||||
|
||||
if (!this->client_.connected()) {
|
||||
this->client_ = this->server_->available();
|
||||
if (client_ == nullptr) {
|
||||
struct sockaddr_storage source_addr;
|
||||
socklen_t addr_len = sizeof(source_addr);
|
||||
client_ = server_->accept((struct sockaddr *) &source_addr, &addr_len);
|
||||
}
|
||||
if (client_ == nullptr)
|
||||
return;
|
||||
|
||||
if (!this->client_.connected())
|
||||
return;
|
||||
int enable = 1;
|
||||
int err = client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
|
||||
return;
|
||||
}
|
||||
|
||||
// enable nodelay for outgoing data
|
||||
this->client_.setNoDelay(true);
|
||||
|
||||
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_.remoteIP().toString().c_str());
|
||||
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_->getpeername().c_str());
|
||||
this->status_set_warning();
|
||||
#ifdef USE_OTA_STATE_CALLBACK
|
||||
this->state_callback_.call(OTA_STARTED, 0.0f, 0);
|
||||
#endif
|
||||
|
||||
if (!this->wait_receive_(buf, 5)) {
|
||||
if (!this->readall_(buf, 5)) {
|
||||
ESP_LOGW(TAG, "Reading magic bytes failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -88,11 +129,12 @@ void OTAComponent::handle_() {
|
||||
}
|
||||
|
||||
// Send OK and version - 2 bytes
|
||||
this->client_.write(OTA_RESPONSE_OK);
|
||||
this->client_.write(OTA_VERSION_1_0);
|
||||
buf[0] = OTA_RESPONSE_OK;
|
||||
buf[1] = OTA_VERSION_1_0;
|
||||
this->writeall_(buf, 2);
|
||||
|
||||
// Read features - 1 byte
|
||||
if (!this->wait_receive_(buf, 1)) {
|
||||
if (!this->readall_(buf, 1)) {
|
||||
ESP_LOGW(TAG, "Reading features failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -100,10 +142,12 @@ void OTAComponent::handle_() {
|
||||
ESP_LOGV(TAG, "OTA features is 0x%02X", ota_features);
|
||||
|
||||
// Acknowledge header - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_HEADER_OK);
|
||||
buf[0] = OTA_RESPONSE_HEADER_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
if (!this->password_.empty()) {
|
||||
this->client_.write(OTA_RESPONSE_REQUEST_AUTH);
|
||||
buf[0] = OTA_RESPONSE_REQUEST_AUTH;
|
||||
this->writeall_(buf, 1);
|
||||
MD5Builder md5_builder{};
|
||||
md5_builder.begin();
|
||||
sprintf(sbuf, "%08X", random_uint32());
|
||||
@@ -113,7 +157,7 @@ void OTAComponent::handle_() {
|
||||
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
|
||||
|
||||
// Send nonce, 32 bytes hex MD5
|
||||
if (this->client_.write(reinterpret_cast<uint8_t *>(sbuf), 32) != 32) {
|
||||
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
|
||||
ESP_LOGW(TAG, "Auth: Writing nonce failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -125,7 +169,7 @@ void OTAComponent::handle_() {
|
||||
md5_builder.add(sbuf);
|
||||
|
||||
// Receive cnonce, 32 bytes hex MD5
|
||||
if (!this->wait_receive_(buf, 32)) {
|
||||
if (!this->readall_(buf, 32)) {
|
||||
ESP_LOGW(TAG, "Auth: Reading cnonce failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -140,7 +184,7 @@ void OTAComponent::handle_() {
|
||||
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
|
||||
|
||||
// Receive result, 32 bytes hex MD5
|
||||
if (!this->wait_receive_(buf + 64, 32)) {
|
||||
if (!this->writeall_(buf + 64, 32)) {
|
||||
ESP_LOGW(TAG, "Auth: Reading response failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -159,10 +203,11 @@ void OTAComponent::handle_() {
|
||||
}
|
||||
|
||||
// Acknowledge auth OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_AUTH_OK);
|
||||
buf[0] = OTA_RESPONSE_AUTH_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
// Read size, 4 bytes MSB first
|
||||
if (!this->wait_receive_(buf, 4)) {
|
||||
if (!this->readall_(buf, 4)) {
|
||||
ESP_LOGW(TAG, "Reading size failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -211,10 +256,11 @@ void OTAComponent::handle_() {
|
||||
update_started = true;
|
||||
|
||||
// Acknowledge prepare OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_UPDATE_PREPARE_OK);
|
||||
buf[0] = OTA_RESPONSE_UPDATE_PREPARE_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
// Read binary MD5, 32 bytes
|
||||
if (!this->wait_receive_(buf, 32)) {
|
||||
if (!this->readall_(buf, 32)) {
|
||||
ESP_LOGW(TAG, "Reading binary MD5 checksum failed!");
|
||||
goto error;
|
||||
}
|
||||
@@ -223,17 +269,22 @@ void OTAComponent::handle_() {
|
||||
Update.setMD5(sbuf);
|
||||
|
||||
// Acknowledge MD5 OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_BIN_MD5_OK);
|
||||
buf[0] = OTA_RESPONSE_BIN_MD5_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
while (!Update.isFinished()) {
|
||||
size_t available = this->wait_receive_(buf, 0);
|
||||
if (!available) {
|
||||
// TODO: timeout check
|
||||
ssize_t read = this->client_->read(buf, sizeof(buf));
|
||||
if (read == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK)
|
||||
continue;
|
||||
ESP_LOGW(TAG, "Error receiving data for update, errno: %d", errno);
|
||||
goto error;
|
||||
}
|
||||
|
||||
uint32_t written = Update.write(buf, available);
|
||||
if (written != available) {
|
||||
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, available); // NOLINT
|
||||
uint32_t written = Update.write(buf, read);
|
||||
if (written != read) {
|
||||
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, read); // NOLINT
|
||||
error_code = OTA_RESPONSE_ERROR_WRITING_FLASH;
|
||||
goto error;
|
||||
}
|
||||
@@ -253,7 +304,8 @@ void OTAComponent::handle_() {
|
||||
}
|
||||
|
||||
// Acknowledge receive OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_RECEIVE_OK);
|
||||
buf[0] = OTA_RESPONSE_RECEIVE_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
if (!Update.end()) {
|
||||
error_code = OTA_RESPONSE_ERROR_UPDATE_END;
|
||||
@@ -261,16 +313,17 @@ void OTAComponent::handle_() {
|
||||
}
|
||||
|
||||
// Acknowledge Update end OK - 1 byte
|
||||
this->client_.write(OTA_RESPONSE_UPDATE_END_OK);
|
||||
buf[0] = OTA_RESPONSE_UPDATE_END_OK;
|
||||
this->writeall_(buf, 1);
|
||||
|
||||
// Read ACK
|
||||
if (!this->wait_receive_(buf, 1, false) || buf[0] != OTA_RESPONSE_OK) {
|
||||
if (!this->readall_(buf, 1) || buf[0] != OTA_RESPONSE_OK) {
|
||||
ESP_LOGW(TAG, "Reading back acknowledgement failed!");
|
||||
// do not go to error, this is not fatal
|
||||
}
|
||||
|
||||
this->client_.flush();
|
||||
this->client_.stop();
|
||||
this->client_->close();
|
||||
this->client_ = nullptr;
|
||||
delay(10);
|
||||
ESP_LOGI(TAG, "OTA update finished!");
|
||||
this->status_clear_warning();
|
||||
@@ -286,11 +339,10 @@ error:
|
||||
Update.printError(ss);
|
||||
ESP_LOGW(TAG, "Update end failed! Error: %s", ss.c_str());
|
||||
}
|
||||
if (this->client_.connected()) {
|
||||
this->client_.write(static_cast<uint8_t>(error_code));
|
||||
this->client_.flush();
|
||||
}
|
||||
this->client_.stop();
|
||||
buf[0] = static_cast<uint8_t>(error_code);
|
||||
this->writeall_(buf, 1);
|
||||
this->client_->close();
|
||||
this->client_ = nullptr;
|
||||
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
if (update_started) {
|
||||
@@ -314,52 +366,56 @@ error:
|
||||
#endif
|
||||
}
|
||||
|
||||
size_t OTAComponent::wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected) {
|
||||
size_t available = 0;
|
||||
bool OTAComponent::readall_(uint8_t *buf, size_t len) {
|
||||
uint32_t start = millis();
|
||||
do {
|
||||
App.feed_wdt();
|
||||
if (check_disconnected && !this->client_.connected()) {
|
||||
ESP_LOGW(TAG, "Error client disconnected while receiving data!");
|
||||
return 0;
|
||||
}
|
||||
int availi = this->client_.available();
|
||||
if (availi < 0) {
|
||||
ESP_LOGW(TAG, "Error reading data!");
|
||||
return 0;
|
||||
}
|
||||
uint32_t at = 0;
|
||||
while (len - at > 0) {
|
||||
uint32_t now = millis();
|
||||
if (availi == 0 && now - start > 10000) {
|
||||
ESP_LOGW(TAG, "Timeout waiting for data!");
|
||||
return 0;
|
||||
if (now - start > 1000) {
|
||||
ESP_LOGW(TAG, "Timed out reading %d bytes of data", len);
|
||||
return false;
|
||||
}
|
||||
available = size_t(availi);
|
||||
yield();
|
||||
} while (bytes == 0 ? available == 0 : available < bytes);
|
||||
|
||||
if (bytes == 0)
|
||||
bytes = std::min(available, size_t(1024));
|
||||
|
||||
bool success = false;
|
||||
for (uint32_t i = 0; !success && i < 100; i++) {
|
||||
int res = this->client_.read(buf, bytes);
|
||||
|
||||
if (res != int(bytes)) {
|
||||
// ESP32 implementation has an issue where calling read can fail with EAGAIN (race condition)
|
||||
// so just re-try it until it works (with generous timeout of 1s)
|
||||
// because we check with available() first this should not cause us any trouble in all other cases
|
||||
delay(10);
|
||||
ssize_t read = this->client_->read(buf + at, len - at);
|
||||
if (read == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
delay(1);
|
||||
continue;
|
||||
}
|
||||
ESP_LOGW(TAG, "Failed to read %d bytes of data, errno: %d", len, errno);
|
||||
return false;
|
||||
} else {
|
||||
success = true;
|
||||
at += read;
|
||||
}
|
||||
delay(1);
|
||||
}
|
||||
|
||||
if (!success) {
|
||||
ESP_LOGW(TAG, "Reading %u bytes of binary data failed!", bytes); // NOLINT
|
||||
return 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool OTAComponent::writeall_(const uint8_t *buf, size_t len) {
|
||||
uint32_t start = millis();
|
||||
uint32_t at = 0;
|
||||
while (len - at > 0) {
|
||||
uint32_t now = millis();
|
||||
if (now - start > 1000) {
|
||||
ESP_LOGW(TAG, "Timed out writing %d bytes of data", len);
|
||||
return false;
|
||||
}
|
||||
|
||||
return bytes;
|
||||
ssize_t written = this->client_->write(buf + at, len - at);
|
||||
if (written == -1) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK) {
|
||||
delay(1);
|
||||
continue;
|
||||
}
|
||||
ESP_LOGW(TAG, "Failed to write %d bytes of data, errno: %d", len, errno);
|
||||
return false;
|
||||
} else {
|
||||
at += written;
|
||||
}
|
||||
delay(1);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void OTAComponent::set_auth_password(const std::string &password) { this->password_ = password; }
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/preferences.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
#include <WiFiServer.h>
|
||||
#include <WiFiClient.h>
|
||||
|
||||
namespace esphome {
|
||||
namespace ota {
|
||||
@@ -74,14 +73,15 @@ class OTAComponent : public Component {
|
||||
uint32_t read_rtc_();
|
||||
|
||||
void handle_();
|
||||
size_t wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected = true);
|
||||
bool readall_(uint8_t *buf, size_t len);
|
||||
bool writeall_(const uint8_t *buf, size_t len);
|
||||
|
||||
std::string password_;
|
||||
|
||||
uint16_t port_;
|
||||
|
||||
WiFiServer *server_{nullptr};
|
||||
WiFiClient client_{};
|
||||
std::unique_ptr<socket::Socket> server_;
|
||||
std::unique_ptr<socket::Socket> client_;
|
||||
|
||||
bool has_safe_mode_{false}; ///< stores whether safe mode can be enabled.
|
||||
uint32_t safe_mode_start_time_; ///< stores when safe mode was enabled.
|
||||
|
||||
28
esphome/components/socket/__init__.py
Normal file
28
esphome/components/socket/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import esphome.config_validation as cv
|
||||
import esphome.codegen as cg
|
||||
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
|
||||
CONF_IMPLEMENTATION = "implementation"
|
||||
IMPLEMENTATION_LWIP_TCP = "lwip_tcp"
|
||||
IMPLEMENTATION_BSD_SOCKETS = "bsd_sockets"
|
||||
|
||||
CONFIG_SCHEMA = cv.Schema(
|
||||
{
|
||||
cv.SplitDefault(
|
||||
CONF_IMPLEMENTATION,
|
||||
esp8266=IMPLEMENTATION_LWIP_TCP,
|
||||
esp32=IMPLEMENTATION_BSD_SOCKETS,
|
||||
): cv.one_of(
|
||||
IMPLEMENTATION_LWIP_TCP, IMPLEMENTATION_BSD_SOCKETS, lower=True, space="_"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def to_code(config):
|
||||
impl = config[CONF_IMPLEMENTATION]
|
||||
if impl == IMPLEMENTATION_LWIP_TCP:
|
||||
cg.add_define("USE_SOCKET_IMPL_LWIP_TCP")
|
||||
elif impl == IMPLEMENTATION_BSD_SOCKETS:
|
||||
cg.add_define("USE_SOCKET_IMPL_BSD_SOCKETS")
|
||||
105
esphome/components/socket/bsd_sockets_impl.cpp
Normal file
105
esphome/components/socket/bsd_sockets_impl.cpp
Normal file
@@ -0,0 +1,105 @@
|
||||
#include "socket.h"
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
|
||||
#include <string.h>
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
std::string format_sockaddr(const struct sockaddr_storage &storage) {
|
||||
if (storage.ss_family == AF_INET) {
|
||||
const struct sockaddr_in *addr = reinterpret_cast<const struct sockaddr_in *>(&storage);
|
||||
char buf[INET_ADDRSTRLEN];
|
||||
const char *ret = inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf));
|
||||
if (ret == NULL)
|
||||
return {};
|
||||
return std::string{buf};
|
||||
} else if (storage.ss_family == AF_INET6) {
|
||||
const struct sockaddr_in6 *addr = reinterpret_cast<const struct sockaddr_in6 *>(&storage);
|
||||
char buf[INET6_ADDRSTRLEN];
|
||||
const char *ret = inet_ntop(AF_INET6, &addr->sin6_addr, buf, sizeof(buf));
|
||||
if (ret == NULL)
|
||||
return {};
|
||||
return std::string{buf};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
class BSDSocketImpl : public Socket {
|
||||
public:
|
||||
BSDSocketImpl(int fd) : Socket(), fd_(fd) {}
|
||||
~BSDSocketImpl() override {
|
||||
if (!closed_) {
|
||||
close();
|
||||
}
|
||||
}
|
||||
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
int fd = ::accept(fd_, addr, addrlen);
|
||||
if (fd == -1)
|
||||
return {};
|
||||
return std::unique_ptr<BSDSocketImpl>{new BSDSocketImpl(fd)};
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); }
|
||||
int close() override {
|
||||
int ret = ::close(fd_);
|
||||
closed_ = true;
|
||||
return ret;
|
||||
}
|
||||
int shutdown(int how) override { return ::shutdown(fd_, how); }
|
||||
|
||||
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return ::getpeername(fd_, addr, addrlen); }
|
||||
std::string getpeername() override {
|
||||
struct sockaddr_storage storage;
|
||||
socklen_t len = sizeof(storage);
|
||||
int err = this->getpeername((struct sockaddr *) &storage, &len);
|
||||
if (err != 0)
|
||||
return {};
|
||||
return format_sockaddr(storage);
|
||||
}
|
||||
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override { return ::getsockname(fd_, addr, addrlen); }
|
||||
std::string getsockname() override {
|
||||
struct sockaddr_storage storage;
|
||||
socklen_t len = sizeof(storage);
|
||||
int err = this->getsockname((struct sockaddr *) &storage, &len);
|
||||
if (err != 0)
|
||||
return {};
|
||||
return format_sockaddr(storage);
|
||||
}
|
||||
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
|
||||
return ::getsockopt(fd_, level, optname, optval, optlen);
|
||||
}
|
||||
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
|
||||
return ::setsockopt(fd_, level, optname, optval, optlen);
|
||||
}
|
||||
int listen(int backlog) override { return ::listen(fd_, backlog); }
|
||||
ssize_t read(void *buf, size_t len) override { return ::read(fd_, buf, len); }
|
||||
ssize_t write(const void *buf, size_t len) override { return ::write(fd_, buf, len); }
|
||||
int setblocking(bool blocking) override {
|
||||
int fl = ::fcntl(fd_, F_GETFL, 0);
|
||||
if (blocking) {
|
||||
fl &= ~O_NONBLOCK;
|
||||
} else {
|
||||
fl |= O_NONBLOCK;
|
||||
}
|
||||
::fcntl(fd_, F_SETFL, fl);
|
||||
return 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
int fd_;
|
||||
bool closed_ = false;
|
||||
};
|
||||
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
int ret = ::socket(domain, type, protocol);
|
||||
if (ret == -1)
|
||||
return nullptr;
|
||||
return std::unique_ptr<Socket>{new BSDSocketImpl(ret)};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
#endif // USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
118
esphome/components/socket/headers.h
Normal file
118
esphome/components/socket/headers.h
Normal file
@@ -0,0 +1,118 @@
|
||||
#pragma once
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
// Helper file to include all socket-related system headers (or use our own
|
||||
// definitions where system ones don't exist)
|
||||
|
||||
|
||||
#ifdef USE_SOCKET_IMPL_LWIP_TCP
|
||||
|
||||
#define LWIP_INTERNAL
|
||||
#include <sys/types.h>
|
||||
#include "lwip/inet.h"
|
||||
#include <stdint.h>
|
||||
#include <errno.h>
|
||||
|
||||
/* Address families. */
|
||||
#define AF_UNSPEC 0
|
||||
#define AF_INET 2
|
||||
#define AF_INET6 10
|
||||
#define PF_INET AF_INET
|
||||
#define PF_INET6 AF_INET6
|
||||
#define PF_UNSPEC AF_UNSPEC
|
||||
#define IPPROTO_IP 0
|
||||
#define IPPROTO_TCP 6
|
||||
#define IPPROTO_IPV6 41
|
||||
#define IPPROTO_ICMPV6 58
|
||||
|
||||
#define TCP_NODELAY 0x01
|
||||
|
||||
#define F_GETFL 3
|
||||
#define F_SETFL 4
|
||||
#define O_NONBLOCK 1
|
||||
|
||||
#define SHUT_RD 0
|
||||
#define SHUT_WR 1
|
||||
#define SHUT_RDWR 2
|
||||
|
||||
/* Socket protocol types (TCP/UDP/RAW) */
|
||||
#define SOCK_STREAM 1
|
||||
#define SOCK_DGRAM 2
|
||||
#define SOCK_RAW 3
|
||||
|
||||
#define SO_REUSEADDR 0x0004 /* Allow local address reuse */
|
||||
#define SO_KEEPALIVE 0x0008 /* keep connections alive */
|
||||
#define SO_BROADCAST 0x0020 /* permit to send and to receive broadcast messages (see IP_SOF_BROADCAST option) */
|
||||
|
||||
#define SOL_SOCKET 0xfff /* options for socket level */
|
||||
|
||||
typedef uint8_t sa_family_t;
|
||||
typedef uint16_t in_port_t;
|
||||
|
||||
struct sockaddr_in {
|
||||
uint8_t sin_len;
|
||||
sa_family_t sin_family;
|
||||
in_port_t sin_port;
|
||||
struct in_addr sin_addr;
|
||||
#define SIN_ZERO_LEN 8
|
||||
char sin_zero[SIN_ZERO_LEN];
|
||||
};
|
||||
|
||||
struct sockaddr_in6 {
|
||||
uint8_t sin6_len; /* length of this structure */
|
||||
sa_family_t sin6_family; /* AF_INET6 */
|
||||
in_port_t sin6_port; /* Transport layer port # */
|
||||
uint32_t sin6_flowinfo; /* IPv6 flow information */
|
||||
struct in6_addr sin6_addr; /* IPv6 address */
|
||||
uint32_t sin6_scope_id; /* Set of interfaces for scope */
|
||||
};
|
||||
|
||||
struct sockaddr {
|
||||
uint8_t sa_len;
|
||||
sa_family_t sa_family;
|
||||
char sa_data[14];
|
||||
};
|
||||
|
||||
struct sockaddr_storage {
|
||||
uint8_t s2_len;
|
||||
sa_family_t ss_family;
|
||||
char s2_data1[2];
|
||||
uint32_t s2_data2[3];
|
||||
uint32_t s2_data3[3];
|
||||
};
|
||||
typedef uint32_t socklen_t;
|
||||
|
||||
#ifdef ARDUINO_ARCH_ESP8266
|
||||
// arduino-esp8266 declares a global vars called INADDR_NONE/ANY which are invalid with the define
|
||||
#ifdef INADDR_ANY
|
||||
#undef INADDR_ANY
|
||||
#endif
|
||||
#ifdef INADDR_NONE
|
||||
#undef INADDR_NONE
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif // USE_SOCKET_IMPL_LWIP_TCP
|
||||
|
||||
#ifdef USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
// arduino-esp32 declares a global var called INADDR_NONE which is replaced
|
||||
// by the define
|
||||
#ifdef INADDR_NONE
|
||||
#undef INADDR_NONE
|
||||
#endif
|
||||
// not defined for ESP32
|
||||
typedef uint32_t socklen_t;
|
||||
#endif // ARDUINO_ARCH_ESP32
|
||||
|
||||
#endif // USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
|
||||
|
||||
470
esphome/components/socket/lwip_raw_tcp_impl.cpp
Normal file
470
esphome/components/socket/lwip_raw_tcp_impl.cpp
Normal file
@@ -0,0 +1,470 @@
|
||||
#include "socket.h"
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_SOCKET_IMPL_LWIP_TCP
|
||||
|
||||
#include <queue>
|
||||
#include <string.h>
|
||||
#include "lwip/opt.h"
|
||||
#include "lwip/ip.h"
|
||||
#include "lwip/tcp.h"
|
||||
#include "lwip/netif.h"
|
||||
#include "errno.h"
|
||||
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
static const char *const TAG = "lwip";
|
||||
|
||||
class LWIPRawImpl : public Socket {
|
||||
public:
|
||||
LWIPRawImpl(struct tcp_pcb *pcb) : pcb_(pcb) {}
|
||||
~LWIPRawImpl() override {
|
||||
if (pcb_ != nullptr) {
|
||||
tcp_abort(pcb_);
|
||||
pcb_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void init() {
|
||||
tcp_arg(pcb_, this);
|
||||
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
|
||||
tcp_recv(pcb_, LWIPRawImpl::s_recv_fn);
|
||||
tcp_err(pcb_, LWIPRawImpl::s_err_fn);
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return nullptr;
|
||||
}
|
||||
if (accepted_sockets_.empty()) {
|
||||
errno = EWOULDBLOCK;
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<LWIPRawImpl> sock = std::move(accepted_sockets_.front());
|
||||
accepted_sockets_.pop();
|
||||
if (addr != nullptr) {
|
||||
sock->getpeername(addr, addrlen);
|
||||
}
|
||||
sock->init();
|
||||
return std::unique_ptr<Socket>(std::move(sock));
|
||||
}
|
||||
int bind(const struct sockaddr *name, socklen_t addrlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (name == nullptr) {
|
||||
errno = EINVAL;
|
||||
return 0;
|
||||
}
|
||||
ip_addr_t ip;
|
||||
in_port_t port;
|
||||
auto family = name->sa_family;
|
||||
#if LWIP_IPV6
|
||||
if (family == AF_INET) {
|
||||
if (addrlen < sizeof(sockaddr_in6)) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
auto *addr4 = reinterpret_cast<const sockaddr_in *>(name);
|
||||
port = ntohs(addr4->sin_port);
|
||||
ip.type = IPADDR_TYPE_V4;
|
||||
ip.u_addr.ip4.addr = addr4->sin_addr.s_addr;
|
||||
|
||||
} else if (family == AF_INET6) {
|
||||
if (addrlen < sizeof(sockaddr_in)) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
auto *addr6 = reinterpret_cast<const sockaddr_in6 *>(name);
|
||||
port = ntohs(addr6->sin6_port);
|
||||
ip.type = IPADDR_TYPE_V6;
|
||||
memcpy(&ip.u_addr.ip6.addr, &addr6->sin6_addr.un.u8_addr, 16);
|
||||
} else {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
if (family != AF_INET) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
auto *addr4 = reinterpret_cast<const sockaddr_in *>(name);
|
||||
port = ntohs(addr4->sin_port);
|
||||
ip.addr = addr4->sin_addr.s_addr;
|
||||
#endif
|
||||
err_t err = tcp_bind(pcb_, &ip, port);
|
||||
if (err == ERR_USE) {
|
||||
errno = EADDRINUSE;
|
||||
return -1;
|
||||
}
|
||||
if (err == ERR_VAL) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
if (err != ERR_OK) {
|
||||
errno = EIO;
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
int close() {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
err_t err = tcp_close(pcb_);
|
||||
if (err != ERR_OK) {
|
||||
tcp_abort(pcb_);
|
||||
pcb_ = nullptr;
|
||||
errno = err == ERR_MEM ? ENOMEM : EIO;
|
||||
return -1;
|
||||
}
|
||||
pcb_ = nullptr;
|
||||
return 0;
|
||||
}
|
||||
int shutdown(int how) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
bool shut_rx = false, shut_tx = false;
|
||||
if (how == SHUT_RD) {
|
||||
shut_rx = true;
|
||||
} else if (how == SHUT_WR) {
|
||||
shut_tx = true;
|
||||
} else if (how == SHUT_RDWR) {
|
||||
shut_rx = shut_tx = true;
|
||||
} else {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
err_t err = tcp_shutdown(pcb_, shut_rx, shut_tx);
|
||||
if (err != ERR_OK) {
|
||||
errno = err == ERR_MEM ? ENOMEM : EIO;
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int getpeername(struct sockaddr *name, socklen_t *addrlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (name == nullptr || addrlen == nullptr) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
if (*addrlen < sizeof(struct sockaddr_in)) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
struct sockaddr_in *addr = reinterpret_cast<struct sockaddr_in *>(name);
|
||||
addr->sin_family = AF_INET;
|
||||
*addrlen = addr->sin_len = sizeof(struct sockaddr_in);
|
||||
addr->sin_port = pcb_->remote_port;
|
||||
addr->sin_addr.s_addr = pcb_->remote_ip.addr;
|
||||
return 0;
|
||||
}
|
||||
std::string getpeername() override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return "";
|
||||
}
|
||||
char buffer[24];
|
||||
uint32_t ip4 = pcb_->remote_ip.addr;
|
||||
snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", (ip4 >> 24) & 0xFF, (ip4 >> 16) & 0xFF, (ip4 >> 8) & 0xFF, (ip4 >> 0) & 0xFF);
|
||||
return std::string(buffer);
|
||||
}
|
||||
int getsockname(struct sockaddr *name, socklen_t *addrlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (name == nullptr || addrlen == nullptr) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
if (*addrlen < sizeof(struct sockaddr_in)) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
struct sockaddr_in *addr = reinterpret_cast<struct sockaddr_in *>(name);
|
||||
addr->sin_family = AF_INET;
|
||||
*addrlen = addr->sin_len = sizeof(struct sockaddr_in);
|
||||
addr->sin_port = pcb_->local_port;
|
||||
addr->sin_addr.s_addr = pcb_->local_ip.addr;
|
||||
return 0;
|
||||
}
|
||||
std::string getsockname() override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return "";
|
||||
}
|
||||
char buffer[24];
|
||||
uint32_t ip4 = pcb_->local_ip.addr;
|
||||
snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", (ip4 >> 24) & 0xFF, (ip4 >> 16) & 0xFF, (ip4 >> 8) & 0xFF, (ip4 >> 0) & 0xFF);
|
||||
return std::string(buffer);
|
||||
}
|
||||
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
|
||||
if (optlen < 4) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// lwip doesn't seem to have this feature. Don't send an error
|
||||
// to prevent warnings
|
||||
*reinterpret_cast<int *>(optval) = 1;
|
||||
*optlen = 4;
|
||||
return 0;
|
||||
}
|
||||
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
|
||||
if (optlen < 4) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
*reinterpret_cast<int *>(optval) = tcp_nagle_disabled(pcb_);
|
||||
*optlen = 4;
|
||||
return 0;
|
||||
}
|
||||
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
|
||||
if (optlen != 4) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// lwip doesn't seem to have this feature. Don't send an error
|
||||
// to prevent warnings
|
||||
return 0;
|
||||
}
|
||||
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
|
||||
if (optlen != 4) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
int val = *reinterpret_cast<const int *>(optval);
|
||||
if (val != 0) {
|
||||
tcp_nagle_disable(pcb_);
|
||||
} else {
|
||||
tcp_nagle_enable(pcb_);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
int listen(int backlog) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
struct tcp_pcb *listen_pcb = tcp_listen_with_backlog(pcb_, backlog);
|
||||
if (listen_pcb == nullptr) {
|
||||
tcp_abort(pcb_);
|
||||
pcb_ = nullptr;
|
||||
errno = EOPNOTSUPP;
|
||||
return -1;
|
||||
}
|
||||
// tcp_listen reallocates the pcb, replace ours
|
||||
pcb_ = listen_pcb;
|
||||
// set callbacks on new pcb
|
||||
tcp_arg(pcb_, this);
|
||||
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
|
||||
return 0;
|
||||
}
|
||||
ssize_t read(void *buf, size_t len) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (rx_closed_ && rx_buf_ == nullptr) {
|
||||
errno = ECONNRESET;
|
||||
return -1;
|
||||
}
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (rx_buf_ == nullptr) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t read = 0;
|
||||
uint8_t *buf8 = reinterpret_cast<uint8_t *>(buf);
|
||||
while (len) {
|
||||
size_t pb_len = rx_buf_->len;
|
||||
size_t pb_left = pb_len - rx_buf_offset_;
|
||||
if (pb_left == 0)
|
||||
break;
|
||||
size_t copysize = std::min(len, pb_left);
|
||||
memcpy(buf8, reinterpret_cast<uint8_t *>(rx_buf_->payload) + rx_buf_offset_, copysize);
|
||||
|
||||
if (pb_left == copysize) {
|
||||
// full pb copied, free it
|
||||
if (rx_buf_->next == nullptr) {
|
||||
// last buffer in chain
|
||||
pbuf_free(rx_buf_);
|
||||
rx_buf_ = nullptr;
|
||||
rx_buf_offset_ = 0;
|
||||
} else {
|
||||
auto *old_buf = rx_buf_;
|
||||
rx_buf_ = rx_buf_->next;
|
||||
pbuf_ref(rx_buf_);
|
||||
pbuf_free(old_buf);
|
||||
rx_buf_offset_ = 0;
|
||||
}
|
||||
} else {
|
||||
rx_buf_offset_ += copysize;
|
||||
}
|
||||
tcp_recved(pcb_, copysize);
|
||||
|
||||
buf8 += copysize;
|
||||
len -= copysize;
|
||||
read += copysize;
|
||||
}
|
||||
|
||||
return read;
|
||||
}
|
||||
ssize_t write(const void *buf, size_t len) {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (len == 0)
|
||||
return 0;
|
||||
if (buf == nullptr) {
|
||||
errno = EINVAL;
|
||||
return 0;
|
||||
}
|
||||
auto space = tcp_sndbuf(pcb_);
|
||||
if (space == 0) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
}
|
||||
size_t to_send = std::min((size_t) space, len);
|
||||
err_t err = tcp_write(pcb_, buf, to_send, TCP_WRITE_FLAG_COPY);
|
||||
if (err == ERR_MEM) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
}
|
||||
if (err != ERR_OK) {
|
||||
errno = EIO;
|
||||
return -1;
|
||||
}
|
||||
err = tcp_output(pcb_);
|
||||
if (err != ERR_OK) {
|
||||
errno = EIO;
|
||||
return -1;
|
||||
}
|
||||
return to_send;
|
||||
}
|
||||
int setblocking(bool blocking) {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (blocking) {
|
||||
// blocking operation not supported
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
err_t accept_fn(struct tcp_pcb *newpcb, err_t err) {
|
||||
if (err != ERR_OK || newpcb == 0) {
|
||||
// "An error code if there has been an error accepting. Only return ERR_ABRT if you have
|
||||
// called tcp_abort from within the callback function!"
|
||||
// https://www.nongnu.org/lwip/2_1_x/tcp_8h.html#a00517abce6856d6c82f0efebdafb734d
|
||||
// nothing to do here, we just don't push it to the queue
|
||||
return ERR_OK;
|
||||
}
|
||||
accepted_sockets_.emplace(new LWIPRawImpl(newpcb));
|
||||
return ERR_OK;
|
||||
}
|
||||
void err_fn(err_t err) {
|
||||
// "If a connection is aborted because of an error, the application is alerted of this event by
|
||||
// the err callback."
|
||||
// pcb is already freed when this callback is called
|
||||
// ERR_RST: connection was reset by remote host
|
||||
// ERR_ABRT: aborted through tcp_abort or TCP timer
|
||||
pcb_ = nullptr;
|
||||
}
|
||||
err_t recv_fn(struct pbuf *pb, err_t err) {
|
||||
if (err != 0) {
|
||||
// "An error code if there has been an error receiving Only return ERR_ABRT if you have
|
||||
// called tcp_abort from within the callback function!"
|
||||
rx_closed_ = true;
|
||||
return ERR_OK;
|
||||
}
|
||||
if (pb == nullptr) {
|
||||
rx_closed_ = true;
|
||||
return ERR_OK;
|
||||
}
|
||||
if (rx_buf_ == nullptr) {
|
||||
// no need to copy because lwIP gave control of it to us
|
||||
rx_buf_ = pb;
|
||||
rx_buf_offset_ = 0;
|
||||
} else {
|
||||
pbuf_cat(rx_buf_, pb);
|
||||
}
|
||||
return ERR_OK;
|
||||
}
|
||||
|
||||
|
||||
static err_t s_accept_fn(void *arg, struct tcp_pcb *newpcb, err_t err) {
|
||||
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
|
||||
return arg_this->accept_fn(newpcb, err);
|
||||
}
|
||||
|
||||
static void s_err_fn(void *arg, err_t err) {
|
||||
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
|
||||
return arg_this->err_fn(err);
|
||||
}
|
||||
|
||||
static err_t s_recv_fn(void *arg, struct tcp_pcb *pcb, struct pbuf *pb, err_t err) {
|
||||
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
|
||||
return arg_this->recv_fn(pb, err);
|
||||
}
|
||||
|
||||
protected:
|
||||
struct tcp_pcb *pcb_;
|
||||
std::queue<std::unique_ptr<LWIPRawImpl>> accepted_sockets_;
|
||||
bool rx_closed_ = false;
|
||||
pbuf *rx_buf_ = nullptr;
|
||||
size_t rx_buf_offset_ = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
auto *pcb = tcp_new();
|
||||
if (pcb == nullptr)
|
||||
return nullptr;
|
||||
auto *sock = new LWIPRawImpl(pcb);
|
||||
sock->init();
|
||||
return std::unique_ptr<Socket>{sock};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
#endif // USE_SOCKET_IMPL_LWIP_TCP
|
||||
42
esphome/components/socket/socket.h
Normal file
42
esphome/components/socket/socket.h
Normal file
@@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "headers.h"
|
||||
#include "esphome/core/optional.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
class Socket {
|
||||
public:
|
||||
Socket() = default;
|
||||
virtual ~Socket() = default;
|
||||
Socket(const Socket &) = delete;
|
||||
Socket &operator=(const Socket &) = delete;
|
||||
|
||||
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int close() = 0;
|
||||
// not supported yet:
|
||||
// virtual int connect(const std::string &address) = 0;
|
||||
// virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int shutdown(int how) = 0;
|
||||
|
||||
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual std::string getpeername() = 0;
|
||||
virtual int getsockname(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual std::string getsockname() = 0;
|
||||
virtual int getsockopt(int level, int optname, void *optval, socklen_t *optlen) = 0;
|
||||
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0;
|
||||
virtual int listen(int backlog) = 0;
|
||||
virtual ssize_t read(void *buf, size_t len) = 0;
|
||||
virtual ssize_t write(const void *buf, size_t len) = 0;
|
||||
virtual int setblocking(bool blocking) = 0;
|
||||
virtual int loop() { return 0; };
|
||||
};
|
||||
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
9
esphome/components/ssl/__init__.py
Normal file
9
esphome/components/ssl/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import esphome.config_validation as cv
|
||||
|
||||
AUTO_LOAD = ["socket"]
|
||||
|
||||
CONFIG_SCHEMA = cv.Schema({})
|
||||
|
||||
|
||||
async def to_code(config):
|
||||
pass
|
||||
272
esphome/components/ssl/mbedtls_impl.cpp
Normal file
272
esphome/components/ssl/mbedtls_impl.cpp
Normal file
@@ -0,0 +1,272 @@
|
||||
#include "ssl_context.h"
|
||||
#include <string.h>
|
||||
#include "mbedtls/platform.h"
|
||||
#include "mbedtls/net_sockets.h"
|
||||
#include "mbedtls/esp_debug.h"
|
||||
#include "mbedtls/ssl.h"
|
||||
#include "mbedtls/entropy.h"
|
||||
#include "mbedtls/ctr_drbg.h"
|
||||
#include "mbedtls/error.h"
|
||||
#include "mbedtls/certs.h"
|
||||
|
||||
#ifdef INADDR_NONE
|
||||
#undef INADDR_NONE
|
||||
#endif
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace ssl {
|
||||
|
||||
static const char *const TAG = "ssl.mbedtls";
|
||||
|
||||
static int entropy_hw_random_source(void *data, uint8_t *output, size_t len, size_t *olen) {
|
||||
esp_fill_random(output, len);
|
||||
*olen = len;
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct MbedTLSBioCtx {
|
||||
socket::Socket *sock;
|
||||
|
||||
static int send(void *raw, const uint8_t *buf, size_t len) {
|
||||
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
|
||||
ssize_t ret = ctx->sock->write(buf, len);
|
||||
if (ret != -1)
|
||||
return ret;
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
if (errno == EPIPE || errno == ECONNRESET)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
if (errno == EINTR)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
return MBEDTLS_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
static int recv(void *raw, uint8_t *buf, size_t len) {
|
||||
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
|
||||
ssize_t ret = ctx->sock->read(buf, len);
|
||||
if (ret != -1)
|
||||
return ret;
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
if (errno == EPIPE || errno == ECONNRESET)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
if (errno == EINTR)
|
||||
return MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
return MBEDTLS_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
};
|
||||
|
||||
void test();
|
||||
|
||||
class MbedTLSWrappedSocket : public socket::Socket {
|
||||
public:
|
||||
MbedTLSWrappedSocket(std::unique_ptr<socket::Socket> sock) : socket::Socket(), sock_(std::move(sock)) {}
|
||||
~MbedTLSWrappedSocket() override {
|
||||
mbedtls_ssl_free(&ssl_);
|
||||
sock_ = nullptr;
|
||||
}
|
||||
void init(const mbedtls_ssl_config *conf) {
|
||||
// TODO: reuse ssl contexts?
|
||||
mbedtls_ssl_init(&ssl_);
|
||||
int err = mbedtls_ssl_setup(&ssl_, conf);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_ssl_setup failed: %d", err);
|
||||
return;
|
||||
}
|
||||
// sock pointer does not fit in void*
|
||||
// instead store it in a heap-allocated var
|
||||
auto *ctx = new MbedTLSBioCtx;
|
||||
// unsafe, but should be fine because we free before sock is reset
|
||||
ctx->sock = sock_.get();
|
||||
mbedtls_ssl_set_bio(&ssl_, ctx, MbedTLSBioCtx::send, MbedTLSBioCtx::recv, nullptr);
|
||||
|
||||
do_handshake_ = true;
|
||||
}
|
||||
|
||||
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
|
||||
// only for server sockets
|
||||
errno = EBADF;
|
||||
return {};
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
int close() override {
|
||||
do_handshake_ = false;
|
||||
return sock_->close();
|
||||
}
|
||||
int connect(const std::string &address) override { return sock_->connect(address); }
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override { return sock_->connect(addr, addrlen); }
|
||||
int shutdown(int how) override {
|
||||
do_handshake_ = false;
|
||||
int ret = mbedtls_ssl_close_notify(&ssl_);
|
||||
if (ret != 0)
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
return this->sock_->shutdown(how);
|
||||
}
|
||||
|
||||
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return sock_->getpeername(addr, addrlen); }
|
||||
std::string getpeername() override { return sock_->getpeername(); }
|
||||
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override { return sock_->getsockname(addr, addrlen); }
|
||||
std::string getsockname() override { return sock_->getsockname(); }
|
||||
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
|
||||
return sock_->getsockopt(level, optname, optval, optlen);
|
||||
}
|
||||
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
|
||||
return sock_->setsockopt(level, optname, optval, optlen);
|
||||
}
|
||||
int listen(int backlog) override {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
ssize_t read(void *buf, size_t len) override {
|
||||
// mbedtls will automatically perform handshake here if necessary
|
||||
loop();
|
||||
if (do_handshake_) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
}
|
||||
int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast<uint8_t *>(buf), len);
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
}
|
||||
ssize_t write(const void *buf, size_t len) override {
|
||||
loop();
|
||||
if (do_handshake_) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
}
|
||||
int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast<const uint8_t *>(buf), len);
|
||||
return this->mbedtls_to_errno_(ret);
|
||||
}
|
||||
int setblocking(bool blocking) override {
|
||||
// TODO: handle blocking modes
|
||||
return sock_->setblocking(blocking);
|
||||
}
|
||||
|
||||
int loop() override {
|
||||
if (do_handshake_) {
|
||||
int err = mbedtls_ssl_handshake_step(&ssl_);
|
||||
if (err == 0) {
|
||||
do_handshake_ = false;
|
||||
} else if (err == MBEDTLS_ERR_SSL_WANT_WRITE || err == MBEDTLS_ERR_SSL_WANT_READ) {
|
||||
} else {
|
||||
do_handshake_ = false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
int mbedtls_to_errno_(int ret) {
|
||||
if (ret >= 0) {
|
||||
return ret;
|
||||
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
|
||||
errno = EWOULDBLOCK;
|
||||
return -1;
|
||||
} else if (ret == MBEDTLS_ERR_NET_CONN_RESET) {
|
||||
errno = ECONNRESET;
|
||||
return -1;
|
||||
} else {
|
||||
if (errno == 0)
|
||||
errno = EIO;
|
||||
ESP_LOGW(TAG, "mbedtls failed with %d", ret);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<socket::Socket> sock_;
|
||||
mbedtls_ssl_context ssl_;
|
||||
bool do_handshake_ = false;
|
||||
};
|
||||
|
||||
void debug_callback(void *arg, int level, const char *filename, int line_number, const char *msg) {
|
||||
ESP_LOGV(TAG, "mbedtls %d [%s:%d]: %s", level, filename, line_number, msg);
|
||||
}
|
||||
|
||||
class MbedTLSContext : public SSLContext {
|
||||
public:
|
||||
MbedTLSContext() = default;
|
||||
~MbedTLSContext() override {
|
||||
mbedtls_pk_free(&pkey_);
|
||||
mbedtls_entropy_free(&entropy_);
|
||||
mbedtls_ctr_drbg_free(&ctr_drbg_);
|
||||
mbedtls_x509_crt_free(&srv_cert_);
|
||||
mbedtls_ssl_config_free(&conf_);
|
||||
}
|
||||
|
||||
void set_server_certificate(const char *cert) override { this->srv_cert_str_ = cert; }
|
||||
void set_private_key(const char *private_key) override { this->privkey_str_ = private_key; }
|
||||
|
||||
int init() override {
|
||||
mbedtls_x509_crt_init(&srv_cert_);
|
||||
mbedtls_ctr_drbg_init(&ctr_drbg_);
|
||||
mbedtls_entropy_init(&entropy_);
|
||||
mbedtls_pk_init(&pkey_);
|
||||
mbedtls_ssl_config_init(&conf_);
|
||||
|
||||
// TODO check what this does
|
||||
int err = mbedtls_entropy_add_source(&entropy_, entropy_hw_random_source, NULL, 134, MBEDTLS_ENTROPY_SOURCE_STRONG);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_entropy_add_source failed: %d", err);
|
||||
return 1;
|
||||
}
|
||||
err = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, NULL, 0);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_ctr_drbg_seed failed: %d", err);
|
||||
return 1;
|
||||
}
|
||||
|
||||
err = mbedtls_x509_crt_parse(&srv_cert_, reinterpret_cast<const uint8_t *>(srv_cert_str_),
|
||||
// "including the terminating NULL byte"
|
||||
strlen(srv_cert_str_) + 1);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_x509_crt_parse failed: %d", err);
|
||||
return 1;
|
||||
}
|
||||
|
||||
err = mbedtls_pk_parse_key(&pkey_, reinterpret_cast<const uint8_t *>(privkey_str_),
|
||||
// "including the terminating NULL byte"
|
||||
strlen(privkey_str_) + 1, nullptr, 0);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_pk_parse_key failed: %d", err);
|
||||
return 1;
|
||||
}
|
||||
|
||||
err = mbedtls_ssl_config_defaults(&conf_, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_ssl_config_defaults failed: %d", err);
|
||||
return 1;
|
||||
}
|
||||
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
|
||||
err = mbedtls_ssl_conf_own_cert(&conf_, &srv_cert_, &pkey_);
|
||||
if (err != 0) {
|
||||
ESP_LOGW(TAG, "mbedtls_ssl_conf_own_cert failed: %d", err);
|
||||
return 1;
|
||||
}
|
||||
mbedtls_ssl_conf_dbg(&conf_, debug_callback, nullptr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) override {
|
||||
auto *wrapped = new MbedTLSWrappedSocket(std::move(sock));
|
||||
wrapped->init(&conf_);
|
||||
return std::unique_ptr<socket::Socket>{wrapped};
|
||||
}
|
||||
|
||||
protected:
|
||||
const char *srv_cert_str_ = nullptr;
|
||||
const char *privkey_str_ = nullptr;
|
||||
mbedtls_entropy_context entropy_;
|
||||
mbedtls_ctr_drbg_context ctr_drbg_;
|
||||
mbedtls_x509_crt srv_cert_;
|
||||
mbedtls_pk_context pkey_;
|
||||
mbedtls_ssl_config conf_;
|
||||
};
|
||||
|
||||
std::unique_ptr<SSLContext> create_context() { return std::unique_ptr<SSLContext>{new MbedTLSContext()}; }
|
||||
|
||||
} // namespace ssl
|
||||
} // namespace esphome
|
||||
24
esphome/components/ssl/ssl_context.h
Normal file
24
esphome/components/ssl/ssl_context.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include "esphome/components/socket/socket.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace ssl {
|
||||
|
||||
class SSLContext {
|
||||
public:
|
||||
SSLContext() = default;
|
||||
virtual ~SSLContext() = default;
|
||||
SSLContext(const SSLContext &) = delete;
|
||||
SSLContext &operator=(const SSLContext &) = delete;
|
||||
|
||||
virtual int init() = 0;
|
||||
virtual void set_server_certificate(const char *cert) = 0;
|
||||
virtual void set_private_key(const char *private_key) = 0;
|
||||
virtual std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<SSLContext> create_context();
|
||||
|
||||
} // namespace ssl
|
||||
} // namespace esphome
|
||||
@@ -29,3 +29,7 @@
|
||||
#define USE_CAPTIVE_PORTAL
|
||||
#define ESPHOME_BOARD "dummy_board"
|
||||
#define USE_MDNS
|
||||
#define USE_SOCKET_IMPL_LWIP_TCP
|
||||
#define USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
#define USE_API_NOISE
|
||||
#define USE_API_PLAINTEXT
|
||||
|
||||
@@ -55,6 +55,15 @@ double random_double() { return random_uint32() / double(UINT32_MAX); }
|
||||
|
||||
float random_float() { return float(random_double()); }
|
||||
|
||||
void fill_random(uint8_t *data, size_t len) {
|
||||
#ifdef ARDUINO_ARCH_ESP32
|
||||
esp_fill_random(data, len);
|
||||
#else
|
||||
int err = os_get_random(data, len);
|
||||
assert(err == 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
static uint32_t fast_random_seed = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
void fast_random_set_seed(uint32_t seed) { fast_random_seed = seed; }
|
||||
|
||||
@@ -109,6 +109,8 @@ double random_double();
|
||||
/// Returns a random float between 0 and 1. Essentially just casts random_double() to a float.
|
||||
float random_float();
|
||||
|
||||
void fill_random(uint8_t *data, size_t len);
|
||||
|
||||
void fast_random_set_seed(uint32_t seed);
|
||||
uint32_t fast_random_32();
|
||||
uint16_t fast_random_16();
|
||||
|
||||
@@ -36,6 +36,7 @@ lib_deps =
|
||||
6306@1.0.3 ; HM3301
|
||||
glmnet/Dsmr@0.3 ; used by dsmr
|
||||
rweather/Crypto@0.2.0 ; used by dsmr
|
||||
esphome/noise-c@0.1.0
|
||||
|
||||
build_flags =
|
||||
-DESPHOME_LOG_LEVEL=ESPHOME_LOG_LEVEL_VERY_VERBOSE
|
||||
|
||||
Reference in New Issue
Block a user