Compare commits
6 Commits
2022.2.1
...
native-mqt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3647e32692 | ||
|
|
a81fc6e85d | ||
|
|
c19d893e4e | ||
|
|
f4183778e3 | ||
|
|
11e8bd77e2 | ||
|
|
a39b2c4ac7 |
561
esphome/components/mqtt/mqtt_backend.cpp
Normal file
561
esphome/components/mqtt/mqtt_backend.cpp
Normal file
@@ -0,0 +1,561 @@
|
||||
#include "mqtt_backend.h"
|
||||
|
||||
#ifdef USE_MQTT
|
||||
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/core/hal.h"
|
||||
#include "esphome/core/helpers.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace mqtt {
|
||||
|
||||
static const char *TAG = "mqtt.backend";
|
||||
|
||||
ErrorCode util::ConnectionEstablisher::init(const std::string &host, uint16_t port, uint32_t timeout) {
|
||||
if (state_ != State::UNINIT) {
|
||||
enter_error_();
|
||||
ESP_LOGD(TAG, "conn init bad state");
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
struct addrinfo hints;
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
std::string port_s = to_string(port);
|
||||
getaddrinfo_ = socket::getaddrinfo_async(host.c_str(), port_s.c_str(), &hints);
|
||||
if (!getaddrinfo_) {
|
||||
enter_error_();
|
||||
return ErrorCode::RESOLVE_ERROR;
|
||||
}
|
||||
start_ = millis();
|
||||
timeout_ = timeout;
|
||||
state_ = State::RESOLVING;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
void util::ConnectionEstablisher::enter_error_() {
|
||||
state_ = State::ERROR;
|
||||
getaddrinfo_.reset();
|
||||
socket_.reset();
|
||||
}
|
||||
|
||||
ErrorCode util::ConnectionEstablisher::loop() {
|
||||
ErrorCode ec;
|
||||
uint32_t now = millis();
|
||||
|
||||
switch (state_) {
|
||||
case State::UNINIT: {
|
||||
enter_error_();
|
||||
state_ = State::ERROR;
|
||||
ESP_LOGD(TAG, "conn uninit bad state");
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
|
||||
case State::RESOLVING: {
|
||||
if (getaddrinfo_->completed()) {
|
||||
struct addrinfo *res;
|
||||
int r = getaddrinfo_->fetch_result(&res);
|
||||
if (r != 0) {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Address resolve failed with error %s", gai_strerror(r));
|
||||
return ErrorCode::RESOLVE_ERROR;
|
||||
}
|
||||
if (res == nullptr) {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Address resolve returned no results");
|
||||
return ErrorCode::RESOLVE_ERROR;
|
||||
}
|
||||
|
||||
ESP_LOGD(TAG, "address resolved!");
|
||||
|
||||
socket_ = socket::socket(res->ai_family, res->ai_socktype, res->ai_protocol);
|
||||
if (!socket_) {
|
||||
freeaddrinfo(res);
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Socket creation failed with error %s", strerror(errno));
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
}
|
||||
|
||||
r = socket_->setblocking(false);
|
||||
if (r != 0) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Setting nonblocking socket failed with error %s", strerror(errno));
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
}
|
||||
|
||||
r = socket_->connect(res->ai_addr, res->ai_addrlen);
|
||||
freeaddrinfo(res);
|
||||
|
||||
if (r == 0) {
|
||||
// connection established immediately
|
||||
getaddrinfo_.reset();
|
||||
state_ = State::CONNECTED;
|
||||
return ErrorCode::OK;
|
||||
} else if (errno == EINPROGRESS) {
|
||||
getaddrinfo_.reset();
|
||||
state_ = State::CONNECTING;
|
||||
} else {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Socket connect failed with error %s", strerror(errno));
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (now - start_ >= timeout_) {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Timeout resolving address");
|
||||
return ErrorCode::TIMEOUT;
|
||||
}
|
||||
|
||||
return ErrorCode::IN_PROGRESS;
|
||||
}
|
||||
|
||||
case State::CONNECTING: {
|
||||
int r = socket_->connect_finished();
|
||||
if (r == 0) {
|
||||
// connection established
|
||||
state_ = State::CONNECTED;
|
||||
return ErrorCode::OK;
|
||||
} else if (errno == EINPROGRESS) {
|
||||
// not established yet
|
||||
|
||||
if (now - start_ >= timeout_) {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Timeout connecting to address");
|
||||
return ErrorCode::TIMEOUT;
|
||||
}
|
||||
|
||||
return ErrorCode::IN_PROGRESS;
|
||||
} else {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Socket connect failed with error %s", strerror(errno));
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
case State::CONNECTED: {
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
case State::FINISHED:
|
||||
case State::ERROR:
|
||||
default: {
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<socket::Socket> util::ConnectionEstablisher::extract_socket() {
|
||||
if (state_ != State::CONNECTED)
|
||||
return nullptr;
|
||||
state_ = State::FINISHED;
|
||||
return std::move(socket_);
|
||||
}
|
||||
|
||||
ErrorCode util::BufferedWriter::write(const std::unique_ptr<socket::Socket> &sock, const uint8_t *data, size_t len, bool do_buffer) {
|
||||
if (len == 0)
|
||||
return ErrorCode::OK;
|
||||
ErrorCode ec;
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// try to empty tx_buf_ first
|
||||
ec = try_drain(sock);
|
||||
if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK)
|
||||
return ec;
|
||||
}
|
||||
|
||||
if (!tx_buf_.empty()) {
|
||||
// tx buf not empty, can't write now because then stream would be inconsistent
|
||||
if (!do_buffer)
|
||||
return ErrorCode::WOULD_BLOCK;
|
||||
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
ssize_t sent = sock->write(data, len);
|
||||
if (sent == 0 || (sent == -1 && (errno == EWOULDBLOCK || errno == EAGAIN))) {
|
||||
// operation would block, add to tx_buf if buffering
|
||||
if (!do_buffer)
|
||||
return ErrorCode::WOULD_BLOCK;
|
||||
tx_buf_.insert(tx_buf_.end(), data, data + len);
|
||||
return ErrorCode::OK;
|
||||
} else if (sent == -1) {
|
||||
// an error occured
|
||||
ESP_LOGV(TAG, "Socket write failed with errno %d", errno);
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
} else if ((size_t) sent != len) {
|
||||
// partially sent, add end to tx_buf (even if not set to buffering, to prevent
|
||||
// partial packet transmission)
|
||||
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
// fully sent
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode util::BufferedWriter::try_drain(const std::unique_ptr<socket::Socket> &sock) {
|
||||
// try send from tx_buf
|
||||
while (!tx_buf_.empty()) {
|
||||
ssize_t sent = sock->write(tx_buf_.data(), tx_buf_.size());
|
||||
if (sent == 0 || (sent == -1 && (errno = EWOULDBLOCK || errno == EAGAIN))) {
|
||||
break;
|
||||
} else if (sent == -1) {
|
||||
ESP_LOGV(TAG, "Socket write failed with errno %d", errno);
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
}
|
||||
|
||||
// TODO: inefficient if multiple packets in txbuf
|
||||
// replace with deque of buffers
|
||||
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
|
||||
}
|
||||
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
ErrorCode MQTTConnection::init(ConnectParams *params, MQTTSession *session) {
|
||||
if (state_ != State::UNINIT) {
|
||||
enter_error_();
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
params_ = params;
|
||||
session_ = session;
|
||||
connection_establisher_ = make_unique<util::ConnectionEstablisher>();
|
||||
ErrorCode ec = connection_establisher_->init(params_->host, params_->port, 5000);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
return ec;
|
||||
}
|
||||
|
||||
state_ = State::CONNECTING;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
ErrorCode MQTTConnection::loop() {
|
||||
ErrorCode ec;
|
||||
|
||||
switch (state_) {
|
||||
case State::CONNECTING: {
|
||||
ec = connection_establisher_->loop();
|
||||
if (ec == ErrorCode::OK) {
|
||||
// connection established
|
||||
socket_ = connection_establisher_->extract_socket();
|
||||
|
||||
ConnectPacket packet{};
|
||||
packet.client_id = params_->client_id;
|
||||
packet.username = params_->username;
|
||||
packet.password = params_->password;
|
||||
packet.will_topic = params_->will_topic;
|
||||
packet.will_message = params_->will_message;
|
||||
packet.will_qos = params_->will_qos;
|
||||
packet.will_retain = params_->will_retain;
|
||||
packet.clean_session = true;
|
||||
packet.keep_alive = params_->keep_alive;
|
||||
|
||||
std::vector<uint8_t> packet_enc;
|
||||
ec = packet.encode(packet_enc);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
return ec;
|
||||
}
|
||||
|
||||
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
return ec;
|
||||
}
|
||||
|
||||
state_ = State::WAIT_CONNACK;
|
||||
connection_establisher_.reset();
|
||||
} else if (ec != ErrorCode::IN_PROGRESS) {
|
||||
enter_error_();
|
||||
return ec;
|
||||
}
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
case State::WAIT_CONNACK:
|
||||
case State::CONNECTED: {
|
||||
ec = writer_.try_drain(socket_);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
return ec;
|
||||
}
|
||||
|
||||
ec = read_packet_();
|
||||
if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) {
|
||||
enter_error_();
|
||||
return ec;
|
||||
}
|
||||
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
case State::UNINIT:
|
||||
case State::DISCONNECTED:
|
||||
case State::ERROR:
|
||||
default: {
|
||||
enter_error_();
|
||||
ESP_LOGD(TAG, "bad state %d", (int) state_);
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MQTTConnection::enter_error_() {
|
||||
ESP_LOGD(TAG, "enter_error");
|
||||
connection_establisher_.reset();
|
||||
socket_.reset();
|
||||
rx_header_buf_ = {};
|
||||
rx_buf_ = {};
|
||||
writer_.stop();
|
||||
state_ = State::ERROR;
|
||||
}
|
||||
|
||||
ErrorCode MQTTConnection::read_packet_() {
|
||||
if (state_ != State::CONNECTED && state_ != State::WAIT_CONNACK) {
|
||||
enter_error_();
|
||||
ESP_LOGD(TAG, "read_packet_ bad state");
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
ErrorCode ec;
|
||||
|
||||
if (!rx_header_parsed_) {
|
||||
while (true) {
|
||||
uint8_t v;
|
||||
ssize_t received = socket_->read(&v, 1);
|
||||
if (received == -1 && (errno == EWOULDBLOCK || errno == EAGAIN)) {
|
||||
// would block
|
||||
return ErrorCode::WOULD_BLOCK;
|
||||
} else if (received == -1) {
|
||||
// error
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Socket read failed with errno %d", errno);
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
} else if (received == 0) {
|
||||
// EOF
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Socket EOF");
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
}
|
||||
rx_header_buf_.push_back(v);
|
||||
|
||||
// try parse buf
|
||||
if (rx_header_buf_.size() == 1)
|
||||
continue;
|
||||
|
||||
rx_header_parsed_type_ = (rx_header_buf_[0] >> 4) & 0x0F;
|
||||
rx_header_parsed_flags_ = (rx_header_buf_[0] >> 0) & 0x0F;
|
||||
|
||||
size_t multiplier = 1, value = 0;
|
||||
size_t i = 1;
|
||||
uint8_t enc;
|
||||
bool parsed = true;
|
||||
do {
|
||||
if (i >= rx_header_buf_.size()) {
|
||||
// not enough data yet
|
||||
parsed = false;
|
||||
break;
|
||||
}
|
||||
enc = rx_header_buf_[i];
|
||||
value += (enc & 0x7F) * multiplier;
|
||||
multiplier <<= 7;
|
||||
} while (enc & 0x80);
|
||||
if (!parsed)
|
||||
continue;
|
||||
|
||||
rx_header_parsed_ = true;
|
||||
rx_header_parsed_len_ = value;
|
||||
}
|
||||
}
|
||||
// header reading done
|
||||
|
||||
// reserve space for body
|
||||
if (rx_buf_.size() != rx_header_parsed_len_) {
|
||||
rx_buf_.resize(rx_header_parsed_len_);
|
||||
}
|
||||
|
||||
if (rx_buf_len_ < rx_header_parsed_len_) {
|
||||
// more data to read
|
||||
size_t to_read = rx_header_parsed_len_ - rx_buf_len_;
|
||||
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
|
||||
if (received == -1) {
|
||||
if (errno == EWOULDBLOCK || errno == EAGAIN) {
|
||||
return ErrorCode::WOULD_BLOCK;
|
||||
}
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Socket read failed with errno %d", errno);
|
||||
return ErrorCode::SOCKET_ERROR;
|
||||
} else if (received == 0) {
|
||||
enter_error_();
|
||||
ESP_LOGD(TAG, "Connection closed");
|
||||
return ErrorCode::CONNECTION_CLOSED;
|
||||
}
|
||||
rx_buf_len_ += received;
|
||||
if ((size_t) received != to_read) {
|
||||
// not all read
|
||||
return ErrorCode::WOULD_BLOCK;
|
||||
}
|
||||
}
|
||||
// body reading done
|
||||
|
||||
ec = handle_packet_(rx_header_parsed_type_, rx_header_parsed_flags_, rx_buf_.data(), rx_buf_.size());
|
||||
// prepare for next packet
|
||||
rx_header_parsed_ = false;
|
||||
return ec;
|
||||
}
|
||||
|
||||
ErrorCode MQTTConnection::handle_packet_(uint8_t packet_type, uint8_t flags, const uint8_t *data, size_t len) {
|
||||
util::Parser parser(data, len);
|
||||
ErrorCode ec;
|
||||
|
||||
switch (static_cast<PacketType>(packet_type)) {
|
||||
case PacketType::CONNACK: {
|
||||
if (state_ != State::WAIT_CONNACK) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Bad state for connack %d", (int) state_);
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
ConnackPacket packet{};
|
||||
ec = packet.decode(flags, parser);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Error decoding connack packet %d", (int) ec);
|
||||
return ec;
|
||||
}
|
||||
|
||||
if (packet.connect_return_code != ConnectReturnCode::ACCEPTED) {
|
||||
const char *reason;
|
||||
switch (packet.connect_return_code) {
|
||||
case ConnectReturnCode::UNACCEPTABLE_PROTOCOL_VERSION:
|
||||
reason = "unacceptable protocol version";
|
||||
break;
|
||||
case ConnectReturnCode::IDENTIFIER_REJECTED:
|
||||
reason = "identifier rejected";
|
||||
break;
|
||||
case ConnectReturnCode::SERVER_UNAVAILABLE:
|
||||
reason = "server unavailable";
|
||||
break;
|
||||
case ConnectReturnCode::BAD_USER_NAME_OR_PASSWORD:
|
||||
reason = "bad user name or password";
|
||||
break;
|
||||
case ConnectReturnCode::NOT_AUTHORIZED:
|
||||
reason = "not authorized";
|
||||
break;
|
||||
default:
|
||||
reason = "unknown";
|
||||
break;
|
||||
}
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Connect failed: %s", reason);
|
||||
return ErrorCode::PROTOCOL_ERROR;
|
||||
}
|
||||
|
||||
ESP_LOGD(TAG, "Connected!");
|
||||
state_ = State::CONNECTED;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
case PacketType::PUBLISH: {
|
||||
if (state_ != State::CONNECTED) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "Bad state for publish %d", (int) state_);
|
||||
return ErrorCode::BAD_STATE;
|
||||
}
|
||||
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
case PacketType::PUBACK:
|
||||
case PacketType::PUBREC:
|
||||
case PacketType::PUBREL:
|
||||
case PacketType::PUBCOMP:
|
||||
case PacketType::SUBACK:
|
||||
case PacketType::UNSUBACK: {
|
||||
// TODO
|
||||
ESP_LOGD(TAG, "Received packet with type %u", packet_type);
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
case PacketType::PINGRESP: {
|
||||
// TODO rx timer
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
case PacketType::CONNECT:
|
||||
case PacketType::DISCONNECT:
|
||||
case PacketType::SUBSCRIBE:
|
||||
case PacketType::UNSUBSCRIBE:
|
||||
case PacketType::PINGREQ:
|
||||
default: {
|
||||
enter_error_();
|
||||
ESP_LOGW(TAG, "Received unknown packet type %u", packet_type);
|
||||
return ErrorCode::UNEXPECTED;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode MQTTConnection::publish(std::string topic, std::vector<uint8_t> message, bool retain, QOSLevel qos) {
|
||||
PublishPacket packet{};
|
||||
packet.topic = std::move(topic);
|
||||
packet.message = std::move(message);
|
||||
packet.retain = retain;
|
||||
packet.qos = qos;
|
||||
if (packet.qos != QOSLevel::QOS0) {
|
||||
packet.packet_identifier = session_->create_packet_id();
|
||||
}
|
||||
packet.dup = false;
|
||||
|
||||
std::vector<uint8_t> packet_enc;
|
||||
ErrorCode ec = packet.encode(packet_enc);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
|
||||
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), qos != QOSLevel::QOS0);
|
||||
if (ec != ErrorCode::OK && ec != ErrorCode::WOULD_BLOCK) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "publish write failed");
|
||||
return ec;
|
||||
}
|
||||
|
||||
return ec;
|
||||
}
|
||||
ErrorCode MQTTConnection::subscribe(std::vector<Subscription> subscriptions) {
|
||||
SubscribePacket packet{};
|
||||
packet.subscriptions = std::move(subscriptions);
|
||||
packet.packet_identifier = session_->create_packet_id();
|
||||
|
||||
std::vector<uint8_t> packet_enc;
|
||||
ErrorCode ec = packet.encode(packet_enc);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
|
||||
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "subscribe write failed");
|
||||
return ec;
|
||||
}
|
||||
return ec;
|
||||
}
|
||||
ErrorCode MQTTConnection::unsubscribe(std::vector<std::string> topic_filters) {
|
||||
UnsubscribePacket packet{};
|
||||
packet.topic_filters = std::move(topic_filters);
|
||||
packet.packet_identifier = session_->create_packet_id();
|
||||
|
||||
std::vector<uint8_t> packet_enc;
|
||||
ErrorCode ec = packet.encode(packet_enc);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
|
||||
ec = writer_.write(socket_, packet_enc.data(), packet_enc.size(), true);
|
||||
if (ec != ErrorCode::OK) {
|
||||
enter_error_();
|
||||
ESP_LOGV(TAG, "unsubscribe write failed");
|
||||
return ec;
|
||||
}
|
||||
return ec;
|
||||
}
|
||||
|
||||
} // namespace mqtt
|
||||
} // namespace esphome
|
||||
|
||||
#endif // USE_MQTT
|
||||
143
esphome/components/mqtt/mqtt_backend.h
Normal file
143
esphome/components/mqtt/mqtt_backend.h
Normal file
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_MQTT
|
||||
|
||||
#include "packets.h"
|
||||
#include "esphome/components/socket/socket.h"
|
||||
#include "esphome/components/socket/getaddrinfo.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
namespace esphome {
|
||||
namespace mqtt {
|
||||
|
||||
namespace util {
|
||||
|
||||
class ConnectionEstablisher {
|
||||
public:
|
||||
ErrorCode init(const std::string &host, uint16_t port, uint32_t timeout);
|
||||
ErrorCode loop();
|
||||
// Should only be called when loop() returns OK, is guaranteed to succeed
|
||||
std::unique_ptr<socket::Socket> extract_socket();
|
||||
|
||||
protected:
|
||||
void enter_error_();
|
||||
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
std::unique_ptr<socket::GetaddrinfoFuture> getaddrinfo_;
|
||||
uint32_t timeout_;
|
||||
uint32_t start_;
|
||||
|
||||
enum class State {
|
||||
UNINIT = 0,
|
||||
RESOLVING = 1,
|
||||
CONNECTING = 2,
|
||||
CONNECTED = 3,
|
||||
FINISHED = 4,
|
||||
ERROR = 5,
|
||||
} state_ = State::UNINIT;
|
||||
};
|
||||
|
||||
class BufferedWriter {
|
||||
public:
|
||||
ErrorCode write(const std::unique_ptr<socket::Socket> &sock, const uint8_t *data, size_t len,
|
||||
bool do_buffer);
|
||||
ErrorCode try_drain(const std::unique_ptr<socket::Socket> &sock);
|
||||
void stop() {
|
||||
tx_buf_ = {};
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<uint8_t> tx_buf_;
|
||||
};
|
||||
|
||||
} // namespace util
|
||||
|
||||
struct ConnectParams {
|
||||
std::string host;
|
||||
uint16_t port;
|
||||
|
||||
std::string client_id;
|
||||
optional<std::string> username;
|
||||
optional<std::vector<uint8_t>> password;
|
||||
std::string will_topic;
|
||||
std::vector<uint8_t> will_message;
|
||||
QOSLevel will_qos;
|
||||
bool will_retain;
|
||||
uint16_t keep_alive;
|
||||
};
|
||||
|
||||
class MQTTSession {
|
||||
public:
|
||||
bool get_has_session() const { return has_session_; }
|
||||
void set_has_session(bool has_session) { has_session_ = has_session; }
|
||||
void clean_session() {
|
||||
packet_id_counter_ = 0;
|
||||
used_packet_ids_.clear();
|
||||
}
|
||||
uint16_t create_packet_id() {
|
||||
while (true) {
|
||||
packet_id_counter_++;
|
||||
if (packet_id_counter_ == 0 || used_packet_ids_.count(packet_id_counter_) > 0) {
|
||||
continue;
|
||||
}
|
||||
used_packet_ids_.insert(packet_id_counter_);
|
||||
return packet_id_counter_;
|
||||
}
|
||||
}
|
||||
void return_packet_id(uint16_t packet_id) {
|
||||
used_packet_ids_.erase(packet_id);
|
||||
}
|
||||
protected:
|
||||
bool has_session_ = false;
|
||||
uint32_t packet_id_counter_ = 0;
|
||||
std::set<uint16_t> used_packet_ids_;
|
||||
};
|
||||
|
||||
class MQTTConnection {
|
||||
public:
|
||||
ErrorCode init(ConnectParams *params, MQTTSession *session);
|
||||
ErrorCode loop();
|
||||
bool is_connected() { return state_ == State::CONNECTED; }
|
||||
|
||||
ErrorCode publish(std::string topic, std::vector<uint8_t> message, bool retain, QOSLevel qos);
|
||||
ErrorCode subscribe(std::vector<Subscription> subscriptions);
|
||||
ErrorCode unsubscribe(std::vector<std::string> topic_filters);
|
||||
|
||||
protected:
|
||||
void enter_error_();
|
||||
ErrorCode read_packet_();
|
||||
ErrorCode handle_packet_(uint8_t packet_type, uint8_t flags, const uint8_t *data, size_t len);
|
||||
|
||||
ConnectParams *params_;
|
||||
MQTTSession *session_;
|
||||
std::unique_ptr<util::ConnectionEstablisher> connection_establisher_;
|
||||
std::unique_ptr<socket::Socket> socket_;
|
||||
|
||||
std::vector<uint8_t> rx_header_buf_;
|
||||
bool rx_header_parsed_ = false;
|
||||
uint8_t rx_header_parsed_type_ = 0;
|
||||
uint8_t rx_header_parsed_flags_ = 0;
|
||||
size_t rx_header_parsed_len_ = 0;
|
||||
|
||||
std::vector<uint8_t> rx_buf_;
|
||||
size_t rx_buf_len_ = 0;
|
||||
|
||||
util::BufferedWriter writer_;
|
||||
|
||||
enum class State {
|
||||
UNINIT = 0,
|
||||
CONNECTING = 1,
|
||||
WAIT_CONNACK = 2,
|
||||
CONNECTED = 3,
|
||||
DISCONNECTED = 4,
|
||||
ERROR = 5,
|
||||
} state_ = State::UNINIT;
|
||||
};
|
||||
|
||||
} // namespace mqtt
|
||||
} // namespace esphome
|
||||
|
||||
#endif // USE_MQTT
|
||||
@@ -28,7 +28,7 @@ MQTTClientComponent::MQTTClientComponent() {
|
||||
// Connection
|
||||
void MQTTClientComponent::setup() {
|
||||
ESP_LOGCONFIG(TAG, "Setting up MQTT...");
|
||||
this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties,
|
||||
/*this->mqtt_client_.onMessage([this](char const *topic, char *payload, AsyncMqttClientMessageProperties properties,
|
||||
size_t len, size_t index, size_t total) {
|
||||
if (index == 0)
|
||||
this->payload_buffer_.reserve(total);
|
||||
@@ -45,7 +45,7 @@ void MQTTClientComponent::setup() {
|
||||
this->mqtt_client_.onDisconnect([this](AsyncMqttClientDisconnectReason reason) {
|
||||
this->state_ = MQTT_CLIENT_DISCONNECTED;
|
||||
this->disconnect_reason_ = reason;
|
||||
});
|
||||
});*/
|
||||
#ifdef USE_LOGGER
|
||||
if (this->is_log_message_enabled() && logger::global_logger != nullptr) {
|
||||
logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) {
|
||||
@@ -58,12 +58,11 @@ void MQTTClientComponent::setup() {
|
||||
#endif
|
||||
|
||||
this->last_connected_ = millis();
|
||||
this->start_dnslookup_();
|
||||
this->start_connect_();
|
||||
}
|
||||
void MQTTClientComponent::dump_config() {
|
||||
ESP_LOGCONFIG(TAG, "MQTT:");
|
||||
ESP_LOGCONFIG(TAG, " Server Address: %s:%u (%s)", this->credentials_.address.c_str(), this->credentials_.port,
|
||||
this->ip_.str().c_str());
|
||||
ESP_LOGCONFIG(TAG, " Server Address: %s:%u", this->credentials_.address.c_str(), this->credentials_.port);
|
||||
ESP_LOGCONFIG(TAG, " Username: " LOG_SECRET("'%s'"), this->credentials_.username.c_str());
|
||||
ESP_LOGCONFIG(TAG, " Client ID: " LOG_SECRET("'%s'"), this->credentials_.client_id.c_str());
|
||||
if (!this->discovery_info_.prefix.empty()) {
|
||||
@@ -80,131 +79,68 @@ void MQTTClientComponent::dump_config() {
|
||||
}
|
||||
bool MQTTClientComponent::can_proceed() { return this->is_connected(); }
|
||||
|
||||
void MQTTClientComponent::start_dnslookup_() {
|
||||
|
||||
void MQTTClientComponent::start_connect_() {
|
||||
if (!network::is_connected())
|
||||
return;
|
||||
|
||||
for (auto &subscription : this->subscriptions_) {
|
||||
subscription.subscribed = false;
|
||||
subscription.resubscribe_timeout = 0;
|
||||
}
|
||||
|
||||
this->status_set_warning();
|
||||
this->dns_resolve_error_ = false;
|
||||
this->dns_resolved_ = false;
|
||||
ip_addr_t addr;
|
||||
#ifdef USE_ESP32
|
||||
err_t err = dns_gethostbyname_addrtype(this->credentials_.address.c_str(), &addr,
|
||||
MQTTClientComponent::dns_found_callback, this, LWIP_DNS_ADDRTYPE_IPV4);
|
||||
#endif
|
||||
#ifdef USE_ESP8266
|
||||
err_t err = dns_gethostbyname(this->credentials_.address.c_str(), &addr,
|
||||
esphome::mqtt::MQTTClientComponent::dns_found_callback, this);
|
||||
#endif
|
||||
switch (err) {
|
||||
case ERR_OK: {
|
||||
// Got IP immediately
|
||||
this->dns_resolved_ = true;
|
||||
#ifdef USE_ESP32
|
||||
this->ip_ = addr.u_addr.ip4.addr;
|
||||
#endif
|
||||
#ifdef USE_ESP8266
|
||||
this->ip_ = addr.addr;
|
||||
#endif
|
||||
this->start_connect_();
|
||||
return;
|
||||
}
|
||||
case ERR_INPROGRESS: {
|
||||
// wait for callback
|
||||
ESP_LOGD(TAG, "Resolving MQTT broker IP address...");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
case ERR_ARG: {
|
||||
// error
|
||||
#if defined(USE_ESP8266)
|
||||
ESP_LOGW(TAG, "Error resolving MQTT broker IP address: %ld", err);
|
||||
#else
|
||||
ESP_LOGW(TAG, "Error resolving MQTT broker IP address: %d", err);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
this->state_ = MQTT_CLIENT_RESOLVING_ADDRESS;
|
||||
this->connect_begin_ = millis();
|
||||
}
|
||||
void MQTTClientComponent::check_dnslookup_() {
|
||||
if (!this->dns_resolved_ && millis() - this->connect_begin_ > 20000) {
|
||||
this->dns_resolve_error_ = true;
|
||||
}
|
||||
|
||||
if (this->dns_resolve_error_) {
|
||||
ESP_LOGW(TAG, "Couldn't resolve IP address for '%s'!", this->credentials_.address.c_str());
|
||||
this->state_ = MQTT_CLIENT_DISCONNECTED;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!this->dns_resolved_) {
|
||||
return;
|
||||
}
|
||||
|
||||
ESP_LOGD(TAG, "Resolved broker IP address to %s", this->ip_.str().c_str());
|
||||
this->start_connect_();
|
||||
}
|
||||
#if defined(USE_ESP8266) && LWIP_VERSION_MAJOR == 1
|
||||
void MQTTClientComponent::dns_found_callback(const char *name, ip_addr_t *ipaddr, void *callback_arg) {
|
||||
#else
|
||||
void MQTTClientComponent::dns_found_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg) {
|
||||
#endif
|
||||
auto *a_this = (MQTTClientComponent *) callback_arg;
|
||||
if (ipaddr == nullptr) {
|
||||
a_this->dns_resolve_error_ = true;
|
||||
} else {
|
||||
#ifdef USE_ESP32
|
||||
a_this->ip_ = ipaddr->u_addr.ip4.addr;
|
||||
#endif
|
||||
#ifdef USE_ESP8266
|
||||
a_this->ip_ = ipaddr->addr;
|
||||
#endif
|
||||
a_this->dns_resolved_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void MQTTClientComponent::start_connect_() {
|
||||
if (!network::is_connected())
|
||||
return;
|
||||
|
||||
ESP_LOGI(TAG, "Connecting to MQTT...");
|
||||
// Force disconnect first
|
||||
this->mqtt_client_.disconnect(true);
|
||||
|
||||
this->mqtt_client_.setClientId(this->credentials_.client_id.c_str());
|
||||
const char *username = nullptr;
|
||||
if (!this->credentials_.username.empty())
|
||||
username = this->credentials_.username.c_str();
|
||||
const char *password = nullptr;
|
||||
if (!this->credentials_.password.empty())
|
||||
password = this->credentials_.password.c_str();
|
||||
conn_params_.host = credentials_.address;
|
||||
conn_params_.port = credentials_.port;
|
||||
conn_params_.client_id = credentials_.client_id;
|
||||
|
||||
this->mqtt_client_.setCredentials(username, password);
|
||||
|
||||
this->mqtt_client_.setServer((uint32_t) this->ip_, this->credentials_.port);
|
||||
if (!this->last_will_.topic.empty()) {
|
||||
this->mqtt_client_.setWill(this->last_will_.topic.c_str(), this->last_will_.qos, this->last_will_.retain,
|
||||
this->last_will_.payload.c_str(), this->last_will_.payload.length());
|
||||
if (!credentials_.username.empty())
|
||||
conn_params_.username = credentials_.username;
|
||||
else
|
||||
conn_params_.username.reset();
|
||||
if (!credentials_.password.empty()) {
|
||||
std::vector<uint8_t> pwd{credentials_.password.begin(), credentials_.password.end()};
|
||||
conn_params_.password = pwd;
|
||||
} else {
|
||||
conn_params_.password.reset();
|
||||
}
|
||||
|
||||
if (!last_will_.topic.empty()) {
|
||||
conn_params_.will_topic = last_will_.topic;
|
||||
std::vector<uint8_t> msg{last_will_.payload.begin(), last_will_.payload.end()};
|
||||
conn_params_.will_message = msg;
|
||||
conn_params_.will_retain = last_will_.retain;
|
||||
conn_params_.will_qos = static_cast<QOSLevel>(last_will_.qos);
|
||||
} else {
|
||||
conn_params_.will_topic = "";
|
||||
conn_params_.will_message.clear();
|
||||
conn_params_.will_retain = false;
|
||||
conn_params_.will_qos = QOSLevel::QOS0;
|
||||
}
|
||||
|
||||
conn_ = make_unique<MQTTConnection>();
|
||||
ErrorCode ec = conn_->init(&conn_params_, &sess_);
|
||||
if (ec != ErrorCode::OK) {
|
||||
ESP_LOGW(TAG, "connection init failed: %d", (int) ec);
|
||||
return;
|
||||
}
|
||||
|
||||
this->mqtt_client_.connect();
|
||||
this->state_ = MQTT_CLIENT_CONNECTING;
|
||||
this->connect_begin_ = millis();
|
||||
}
|
||||
bool MQTTClientComponent::is_connected() {
|
||||
return this->state_ == MQTT_CLIENT_CONNECTED && this->mqtt_client_.connected();
|
||||
return this->state_ == MQTT_CLIENT_CONNECTED && this->conn_->is_connected();
|
||||
}
|
||||
|
||||
void MQTTClientComponent::check_connected() {
|
||||
if (!this->mqtt_client_.connected()) {
|
||||
if (millis() - this->connect_begin_ > 60000) {
|
||||
this->state_ = MQTT_CLIENT_DISCONNECTED;
|
||||
this->start_dnslookup_();
|
||||
if (conn_ && !conn_->is_connected()) {
|
||||
ErrorCode ec = conn_->loop();
|
||||
if (ec != ErrorCode::OK) {
|
||||
ESP_LOGW(TAG, "check connected loop failed: %d", (int) ec);
|
||||
state_ = MQTT_CLIENT_DISCONNECTED;
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -223,64 +159,25 @@ void MQTTClientComponent::check_connected() {
|
||||
}
|
||||
|
||||
void MQTTClientComponent::loop() {
|
||||
if (this->disconnect_reason_.has_value()) {
|
||||
const LogString *reason_s;
|
||||
switch (*this->disconnect_reason_) {
|
||||
case AsyncMqttClientDisconnectReason::TCP_DISCONNECTED:
|
||||
reason_s = LOG_STR("TCP disconnected");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::MQTT_UNACCEPTABLE_PROTOCOL_VERSION:
|
||||
reason_s = LOG_STR("Unacceptable Protocol Version");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::MQTT_IDENTIFIER_REJECTED:
|
||||
reason_s = LOG_STR("Identifier Rejected");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::MQTT_SERVER_UNAVAILABLE:
|
||||
reason_s = LOG_STR("Server Unavailable");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::MQTT_MALFORMED_CREDENTIALS:
|
||||
reason_s = LOG_STR("Malformed Credentials");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::MQTT_NOT_AUTHORIZED:
|
||||
reason_s = LOG_STR("Not Authorized");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::ESP8266_NOT_ENOUGH_SPACE:
|
||||
reason_s = LOG_STR("Not Enough Space");
|
||||
break;
|
||||
case AsyncMqttClientDisconnectReason::TLS_BAD_FINGERPRINT:
|
||||
reason_s = LOG_STR("TLS Bad Fingerprint");
|
||||
break;
|
||||
default:
|
||||
reason_s = LOG_STR("Unknown");
|
||||
break;
|
||||
}
|
||||
if (!network::is_connected()) {
|
||||
reason_s = LOG_STR("WiFi disconnected");
|
||||
}
|
||||
ESP_LOGW(TAG, "MQTT Disconnected: %s.", LOG_STR_ARG(reason_s));
|
||||
this->disconnect_reason_.reset();
|
||||
}
|
||||
|
||||
const uint32_t now = millis();
|
||||
|
||||
switch (this->state_) {
|
||||
case MQTT_CLIENT_DISCONNECTED:
|
||||
if (now - this->connect_begin_ > 5000) {
|
||||
this->start_dnslookup_();
|
||||
}
|
||||
break;
|
||||
case MQTT_CLIENT_RESOLVING_ADDRESS:
|
||||
this->check_dnslookup_();
|
||||
break;
|
||||
case MQTT_CLIENT_CONNECTING:
|
||||
this->check_connected();
|
||||
break;
|
||||
case MQTT_CLIENT_CONNECTED:
|
||||
if (!this->mqtt_client_.connected()) {
|
||||
if (!this->conn_->is_connected()) {
|
||||
this->state_ = MQTT_CLIENT_DISCONNECTED;
|
||||
ESP_LOGW(TAG, "Lost MQTT Client connection!");
|
||||
this->start_dnslookup_();
|
||||
this->start_connect_();
|
||||
} else {
|
||||
ErrorCode ec = conn_->loop();
|
||||
if (ec != ErrorCode::OK) {
|
||||
ESP_LOGW(TAG, "loop loop failed");
|
||||
}
|
||||
|
||||
if (!this->birth_message_.topic.empty() && !this->sent_birth_message_) {
|
||||
this->sent_birth_message_ = this->publish(this->birth_message_);
|
||||
}
|
||||
@@ -303,17 +200,18 @@ bool MQTTClientComponent::subscribe_(const char *topic, uint8_t qos) {
|
||||
if (!this->is_connected())
|
||||
return false;
|
||||
|
||||
uint16_t ret = this->mqtt_client_.subscribe(topic, qos);
|
||||
yield();
|
||||
Subscription sub{};
|
||||
sub.topic_filter = topic;
|
||||
sub.requested_qos = static_cast<QOSLevel>(qos);
|
||||
ErrorCode ec = this->conn_->subscribe({sub});
|
||||
|
||||
if (ret != 0) {
|
||||
ESP_LOGV(TAG, "subscribe(topic='%s')", topic);
|
||||
} else {
|
||||
delay(5);
|
||||
if (ec != ErrorCode::OK) {
|
||||
ESP_LOGV(TAG, "Subscribe failed for topic='%s'. Will retry later.", topic);
|
||||
this->status_momentary_warning("subscribe", 1000);
|
||||
}
|
||||
return ret != 0;
|
||||
|
||||
ESP_LOGV(TAG, "subscribe(topic='%s')", topic);
|
||||
return ec == ErrorCode::OK;
|
||||
}
|
||||
void MQTTClientComponent::resubscribe_subscription_(MQTTSubscription *sub) {
|
||||
if (sub->subscribed)
|
||||
@@ -361,15 +259,12 @@ void MQTTClientComponent::subscribe_json(const std::string &topic, const mqtt_js
|
||||
}
|
||||
|
||||
void MQTTClientComponent::unsubscribe(const std::string &topic) {
|
||||
uint16_t ret = this->mqtt_client_.unsubscribe(topic.c_str());
|
||||
yield();
|
||||
if (ret != 0) {
|
||||
ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str());
|
||||
} else {
|
||||
delay(5);
|
||||
ErrorCode ec = this->conn_->unsubscribe({topic});
|
||||
if (ec != ErrorCode::OK) {
|
||||
ESP_LOGV(TAG, "Unsubscribe failed for topic='%s'.", topic.c_str());
|
||||
this->status_momentary_warning("unsubscribe", 1000);
|
||||
}
|
||||
ESP_LOGV(TAG, "unsubscribe(topic='%s')", topic.c_str());
|
||||
|
||||
auto it = subscriptions_.begin();
|
||||
while (it != subscriptions_.end()) {
|
||||
@@ -393,24 +288,22 @@ bool MQTTClientComponent::publish(const std::string &topic, const char *payload,
|
||||
return false;
|
||||
}
|
||||
bool logging_topic = topic == this->log_message_.topic;
|
||||
uint16_t ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length);
|
||||
delay(0);
|
||||
if (ret == 0 && !logging_topic && this->is_connected()) {
|
||||
delay(0);
|
||||
ret = this->mqtt_client_.publish(topic.c_str(), qos, retain, payload, payload_length);
|
||||
delay(0);
|
||||
std::vector<uint8_t> msg;
|
||||
for (size_t i = 0; i < payload_length; i++) {
|
||||
msg.push_back(static_cast<uint8_t>(payload[i]));
|
||||
}
|
||||
ErrorCode ec = this->conn_->publish(topic, std::move(msg), retain, static_cast<QOSLevel>(qos));
|
||||
|
||||
if (!logging_topic) {
|
||||
if (ret != 0) {
|
||||
ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain);
|
||||
} else {
|
||||
if (ec != ErrorCode::OK) {
|
||||
ESP_LOGV(TAG, "Publish failed for topic='%s' (len=%u). will retry later..", topic.c_str(),
|
||||
payload_length); // NOLINT
|
||||
this->status_momentary_warning("publish", 1000);
|
||||
} else {
|
||||
ESP_LOGV(TAG, "Publish(topic='%s' payload='%s' retain=%d)", topic.c_str(), payload, retain);
|
||||
}
|
||||
}
|
||||
return ret != 0;
|
||||
return ec == ErrorCode::OK;
|
||||
}
|
||||
|
||||
bool MQTTClientComponent::publish(const MQTTMessage &message) {
|
||||
@@ -500,7 +393,7 @@ bool MQTTClientComponent::is_log_message_enabled() const { return !this->log_mes
|
||||
void MQTTClientComponent::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; }
|
||||
void MQTTClientComponent::register_mqtt_component(MQTTComponent *component) { this->children_.push_back(component); }
|
||||
void MQTTClientComponent::set_log_level(int level) { this->log_level_ = level; }
|
||||
void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { this->mqtt_client_.setKeepAlive(keep_alive_s); }
|
||||
void MQTTClientComponent::set_keep_alive(uint16_t keep_alive_s) { conn_params_.keep_alive = keep_alive_s; }
|
||||
void MQTTClientComponent::set_log_message_template(MQTTMessage &&message) { this->log_message_ = std::move(message); }
|
||||
const MQTTDiscoveryInfo &MQTTClientComponent::get_discovery_info() const { return this->discovery_info_; }
|
||||
void MQTTClientComponent::set_topic_prefix(std::string topic_prefix) { this->topic_prefix_ = std::move(topic_prefix); }
|
||||
@@ -556,7 +449,7 @@ void MQTTClientComponent::on_shutdown() {
|
||||
this->publish(this->shutdown_message_);
|
||||
yield();
|
||||
}
|
||||
this->mqtt_client_.disconnect(true);
|
||||
// this->mqtt_client_.disconnect(true);
|
||||
}
|
||||
|
||||
#if ASYNC_TCP_SSL_ENABLED
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
#include "esphome/core/log.h"
|
||||
#include "esphome/components/json/json_util.h"
|
||||
#include "esphome/components/network/ip_address.h"
|
||||
#include <AsyncMqttClient.h>
|
||||
#include "lwip/ip_addr.h"
|
||||
#include "mqtt_backend.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace mqtt {
|
||||
@@ -74,7 +73,6 @@ struct MQTTDiscoveryInfo {
|
||||
|
||||
enum MQTTClientState {
|
||||
MQTT_CLIENT_DISCONNECTED = 0,
|
||||
MQTT_CLIENT_RESOLVING_ADDRESS,
|
||||
MQTT_CLIENT_CONNECTING,
|
||||
MQTT_CLIENT_CONNECTED,
|
||||
};
|
||||
@@ -116,22 +114,6 @@ class MQTTClientComponent : public Component {
|
||||
void disable_discovery();
|
||||
bool is_discovery_enabled() const;
|
||||
|
||||
#if ASYNC_TCP_SSL_ENABLED
|
||||
/** Add a SSL fingerprint to use for TCP SSL connections to the MQTT broker.
|
||||
*
|
||||
* To use this feature you first have to globally enable the `ASYNC_TCP_SSL_ENABLED` define flag.
|
||||
* This function can be called multiple times and any certificate that matches any of the provided fingerprints
|
||||
* will match. Calling this method will also automatically disable all non-ssl connections.
|
||||
*
|
||||
* @warning This is *not* secure and *not* how SSL is usually done. You'll have to add
|
||||
* a separate fingerprint for every certificate you use. Additionally, the hashing
|
||||
* algorithm used here due to the constraints of the MCU, SHA1, is known to be insecure.
|
||||
*
|
||||
* @param fingerprint The SSL fingerprint as a 20 value long std::array.
|
||||
*/
|
||||
void add_ssl_fingerprint(const std::array<uint8_t, SHA1_SIZE> &fingerprint);
|
||||
#endif
|
||||
|
||||
const Availability &get_availability();
|
||||
|
||||
/** Set the topic prefix that will be prepended to all topics together with "/". This will, in most cases,
|
||||
@@ -237,13 +219,6 @@ class MQTTClientComponent : public Component {
|
||||
protected:
|
||||
/// Reconnect to the MQTT broker if not already connected.
|
||||
void start_connect_();
|
||||
void start_dnslookup_();
|
||||
void check_dnslookup_();
|
||||
#if defined(USE_ESP8266) && LWIP_VERSION_MAJOR == 1
|
||||
static void dns_found_callback(const char *name, ip_addr_t *ipaddr, void *callback_arg);
|
||||
#else
|
||||
static void dns_found_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg);
|
||||
#endif
|
||||
|
||||
/// Re-calculate the availability property.
|
||||
void recalculate_availability_();
|
||||
@@ -272,20 +247,18 @@ class MQTTClientComponent : public Component {
|
||||
};
|
||||
std::string topic_prefix_{};
|
||||
MQTTMessage log_message_;
|
||||
std::string payload_buffer_;
|
||||
int log_level_{ESPHOME_LOG_LEVEL};
|
||||
|
||||
std::vector<MQTTSubscription> subscriptions_;
|
||||
AsyncMqttClient mqtt_client_;
|
||||
MQTTClientState state_{MQTT_CLIENT_DISCONNECTED};
|
||||
network::IPAddress ip_;
|
||||
bool dns_resolved_{false};
|
||||
bool dns_resolve_error_{false};
|
||||
std::vector<MQTTComponent *> children_;
|
||||
uint32_t reboot_timeout_{300000};
|
||||
uint32_t connect_begin_;
|
||||
uint32_t last_connected_{0};
|
||||
optional<AsyncMqttClientDisconnectReason> disconnect_reason_{};
|
||||
|
||||
std::unique_ptr<MQTTConnection> conn_;
|
||||
ConnectParams conn_params_{};
|
||||
MQTTSession sess_{};
|
||||
};
|
||||
|
||||
extern MQTTClientComponent *global_mqtt_client; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
135
esphome/components/mqtt/packets.cpp
Normal file
135
esphome/components/mqtt/packets.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include "packets.h"
|
||||
|
||||
#ifdef USE_MQTT
|
||||
|
||||
namespace esphome {
|
||||
namespace mqtt {
|
||||
|
||||
namespace util {
|
||||
|
||||
ErrorCode encode_uint16(std::vector<uint8_t> &target, uint16_t value) {
|
||||
target.push_back((value >> 8) & 0xFF);
|
||||
target.push_back((value >> 0) & 0xFF);
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode decode_uint16(Parser *parser, uint16_t *value) {
|
||||
if (parser->size_left() < 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
*value = 0;
|
||||
*value |= static_cast<uint16_t>(parser->consume()) << 8;
|
||||
*value |= static_cast<uint16_t>(parser->consume()) << 0;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode encode_bytes(std::vector<uint8_t> &target, const std::vector<uint8_t> &value) {
|
||||
if (value.size() > 65535)
|
||||
return ErrorCode::VALUE_TOO_LONG;
|
||||
encode_uint16(target, value.size());
|
||||
target.insert(target.end(), value.begin(), value.end());
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode decode_bytes(Parser *parser, std::vector<uint8_t> *value) {
|
||||
uint16_t len;
|
||||
ErrorCode ec = decode_uint16(parser, &len);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
if (len > parser->size_left())
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
value->clear();
|
||||
value->reserve(len);
|
||||
for (size_t i = 0; i < len; i++)
|
||||
value->push_back(parser->consume());
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode encode_utf8(std::vector<uint8_t> &target, const std::string &value) {
|
||||
if (value.size() > 65535)
|
||||
return ErrorCode::VALUE_TOO_LONG;
|
||||
encode_uint16(target, value.size());
|
||||
for (char c : value)
|
||||
target.push_back(static_cast<uint8_t>(c));
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode decode_utf8(Parser *parser, std::string *value) {
|
||||
uint16_t len;
|
||||
ErrorCode ec = decode_uint16(parser, &len);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
if (len > parser->size_left())
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
value->clear();
|
||||
value->reserve(len);
|
||||
for (size_t i = 0; i < len; i++)
|
||||
value->push_back(static_cast<char>(parser->consume()));
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
ErrorCode encode_varint(std::vector<uint8_t> &target, size_t value) {
|
||||
do {
|
||||
uint8_t encbyte = value % 0x80;
|
||||
value >>= 7;
|
||||
if (value > 0)
|
||||
encbyte |= 0x80;
|
||||
target.push_back(encbyte);
|
||||
} while (value > 0);
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
ErrorCode encode_fixed_header(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags,
|
||||
size_t remaining_length
|
||||
) {
|
||||
uint8_t head = 0;
|
||||
head |= static_cast<uint8_t>(packet_type) << 4;
|
||||
head |= flags;
|
||||
target.push_back(head);
|
||||
return encode_varint(target, remaining_length);
|
||||
}
|
||||
|
||||
ErrorCode encode_packet(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags
|
||||
) {
|
||||
return encode_fixed_header(target, packet_type, flags, 0);
|
||||
}
|
||||
|
||||
ErrorCode encode_packet(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags,
|
||||
const std::vector<uint8_t> &variable_header
|
||||
) {
|
||||
ErrorCode ec = encode_fixed_header(
|
||||
target, packet_type, flags,
|
||||
variable_header.size()
|
||||
);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
target.insert(target.end(), variable_header.begin(), variable_header.end());
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
ErrorCode encode_packet(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags,
|
||||
const std::vector<uint8_t> &variable_header,
|
||||
const std::vector<uint8_t> &payload
|
||||
) {
|
||||
ErrorCode ec = encode_fixed_header(
|
||||
target, packet_type, flags,
|
||||
variable_header.size() + payload.size()
|
||||
);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
target.insert(target.end(), variable_header.begin(), variable_header.end());
|
||||
target.insert(target.end(), payload.begin(), payload.end());
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
|
||||
} // namespace util
|
||||
|
||||
} // namespace mqtt
|
||||
} // namespace esphome
|
||||
|
||||
#endif // USE_MQTT
|
||||
653
esphome/components/mqtt/packets.h
Normal file
653
esphome/components/mqtt/packets.h
Normal file
@@ -0,0 +1,653 @@
|
||||
#pragma once
|
||||
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_MQTT
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include "esphome/core/optional.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace mqtt {
|
||||
|
||||
enum class ErrorCode {
|
||||
OK = 0,
|
||||
VALUE_TOO_LONG = 1,
|
||||
BAD_FLAGS = 2,
|
||||
MALFORMED_PACKET = 3,
|
||||
BAD_STATE = 4,
|
||||
RESOLVE_ERROR = 5,
|
||||
SOCKET_ERROR = 6,
|
||||
TIMEOUT = 7,
|
||||
IN_PROGRESS = 8,
|
||||
WOULD_BLOCK = 9,
|
||||
CONNECTION_CLOSED = 10,
|
||||
UNEXPECTED = 11,
|
||||
PROTOCOL_ERROR = 12,
|
||||
};
|
||||
|
||||
enum class PacketType : uint8_t {
|
||||
CONNECT = 1,
|
||||
CONNACK = 2,
|
||||
PUBLISH = 3,
|
||||
PUBACK = 4,
|
||||
PUBREC = 5,
|
||||
PUBREL = 6,
|
||||
PUBCOMP = 7,
|
||||
SUBSCRIBE = 8,
|
||||
SUBACK = 9,
|
||||
UNSUBSCRIBE = 10,
|
||||
UNSUBACK = 11,
|
||||
PINGREQ = 12,
|
||||
PINGRESP = 13,
|
||||
DISCONNECT = 14,
|
||||
};
|
||||
|
||||
namespace util {
|
||||
class Parser {
|
||||
public:
|
||||
Parser(const uint8_t *data, size_t len) : data_(data), len_(len) {}
|
||||
|
||||
size_t size_left() const {
|
||||
return len_ - at_;
|
||||
}
|
||||
|
||||
uint8_t consume() {
|
||||
return data_[at_++];
|
||||
}
|
||||
|
||||
void consume(size_t amount) {
|
||||
at_ += amount;
|
||||
}
|
||||
|
||||
private:
|
||||
const uint8_t *data_;
|
||||
size_t len_;
|
||||
size_t at_ = 0;
|
||||
};
|
||||
|
||||
ErrorCode encode_uint16(std::vector<uint8_t> &target, uint16_t value);
|
||||
ErrorCode decode_uint16(Parser *parser, uint16_t *value);
|
||||
ErrorCode encode_bytes(std::vector<uint8_t> &target, const std::vector<uint8_t> &value);
|
||||
ErrorCode decode_bytes(Parser *parser, std::vector<uint8_t> *value);
|
||||
ErrorCode encode_utf8(std::vector<uint8_t> &target, const std::string &value);
|
||||
ErrorCode decode_utf8(Parser *parser, std::string *value);
|
||||
ErrorCode encode_varint(std::vector<uint8_t> &target, size_t value);
|
||||
ErrorCode encode_fixed_header(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags,
|
||||
size_t remaining_length
|
||||
);
|
||||
ErrorCode encode_packet(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags
|
||||
);
|
||||
ErrorCode encode_packet(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags,
|
||||
const std::vector<uint8_t> &variable_header
|
||||
);
|
||||
ErrorCode encode_packet(
|
||||
std::vector<uint8_t> &target,
|
||||
PacketType packet_type,
|
||||
uint8_t flags,
|
||||
const std::vector<uint8_t> &variable_header,
|
||||
const std::vector<uint8_t> &payload
|
||||
);
|
||||
|
||||
} // namespace util
|
||||
|
||||
class MQTTPacket {
|
||||
public:
|
||||
virtual ErrorCode encode(std::vector<uint8_t> &target) const = 0;
|
||||
virtual ErrorCode decode(uint8_t flags, util::Parser parser) = 0;
|
||||
};
|
||||
|
||||
enum class QOSLevel : uint8_t {
|
||||
QOS0 = 0,
|
||||
QOS1 = 1,
|
||||
QOS2 = 2,
|
||||
};
|
||||
|
||||
class ConnectPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.1
|
||||
std::string client_id;
|
||||
optional<std::string> username;
|
||||
optional<std::vector<uint8_t>> password;
|
||||
std::string will_topic;
|
||||
std::vector<uint8_t> will_message;
|
||||
QOSLevel will_qos = QOSLevel::QOS0;
|
||||
bool will_retain = false;
|
||||
bool clean_session = true;
|
||||
uint8_t protocol_level = 4;
|
||||
uint16_t keep_alive;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
uint8_t connect_flags = 0;
|
||||
if (username.has_value())
|
||||
connect_flags |= 0x80;
|
||||
if (password.has_value())
|
||||
connect_flags |= 0x40;
|
||||
if (will_retain)
|
||||
connect_flags |= 0x20;
|
||||
connect_flags |= static_cast<uint8_t>(will_qos) << 3;
|
||||
if (!will_topic.empty())
|
||||
connect_flags |= 0x04;
|
||||
if (clean_session)
|
||||
connect_flags |= 0x02;
|
||||
std::vector<uint8_t> variable_header;
|
||||
variable_header.push_back(0x00);
|
||||
variable_header.push_back(0x04);
|
||||
variable_header.push_back('M');
|
||||
variable_header.push_back('Q');
|
||||
variable_header.push_back('T');
|
||||
variable_header.push_back('T');
|
||||
variable_header.push_back(protocol_level);
|
||||
variable_header.push_back(connect_flags);
|
||||
ErrorCode ec = util::encode_uint16(variable_header, keep_alive);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
|
||||
std::vector<uint8_t> payload;
|
||||
ec = util::encode_utf8(payload, client_id);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
if (!will_topic.empty()) {
|
||||
ec = util::encode_utf8(payload, will_topic);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
ec = util::encode_bytes(payload, will_message);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
if (username.has_value()) {
|
||||
ec = util::encode_utf8(payload, *username);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
if (password.has_value()) {
|
||||
ec = util::encode_bytes(payload, *password);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
|
||||
return util::encode_packet(
|
||||
target, PacketType::CONNECT, 0,
|
||||
variable_header, payload
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (
|
||||
parser.size_left() < 10
|
||||
|| parser.consume() != '\x00'
|
||||
|| parser.consume() != '\x04'
|
||||
|| parser.consume() != 'M'
|
||||
|| parser.consume() != 'Q'
|
||||
|| parser.consume() != 'T'
|
||||
|| parser.consume() != 'T'
|
||||
)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
protocol_level = parser.consume();
|
||||
uint8_t connect_flags = parser.consume();
|
||||
bool username_flag = connect_flags & 0x80;
|
||||
bool password_flag = connect_flags & 0x40;
|
||||
will_retain = connect_flags & 0x20;
|
||||
will_qos = static_cast<QOSLevel>((connect_flags >> 3) & 3);
|
||||
bool will_flag = connect_flags & 0x04;
|
||||
clean_session = connect_flags & 0x02;
|
||||
if ((flags & 1) != 0)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
ErrorCode ec = util::decode_uint16(&parser, &keep_alive);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
|
||||
ec = util::decode_utf8(&parser, &client_id);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
|
||||
will_topic.clear();
|
||||
will_message.clear();
|
||||
if (will_flag) {
|
||||
ec = util::decode_utf8(&parser, &will_topic);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
ec = util::decode_bytes(&parser, &will_message);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
|
||||
username.reset();
|
||||
if (username_flag) {
|
||||
username = {""};
|
||||
ec = util::decode_utf8(&parser, &(*username));
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
password.reset();
|
||||
if (password_flag) {
|
||||
password = std::vector<uint8_t>{};
|
||||
ec = util::decode_bytes(&parser, &(*password));
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
enum class ConnectReturnCode : uint8_t {
|
||||
ACCEPTED = 0,
|
||||
UNACCEPTABLE_PROTOCOL_VERSION = 1,
|
||||
IDENTIFIER_REJECTED = 2,
|
||||
SERVER_UNAVAILABLE = 3,
|
||||
BAD_USER_NAME_OR_PASSWORD = 4,
|
||||
NOT_AUTHORIZED = 5,
|
||||
};
|
||||
|
||||
class ConnackPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.2
|
||||
bool session_present;
|
||||
ConnectReturnCode connect_return_code;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
variable_header.push_back(static_cast<uint8_t>(session_present));
|
||||
variable_header.push_back(static_cast<uint8_t>(connect_return_code));
|
||||
return util::encode_packet(
|
||||
target, PacketType::CONNACK, 0,
|
||||
variable_header
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
session_present = parser.consume() & 1;
|
||||
connect_return_code = static_cast<ConnectReturnCode>(parser.consume());
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class PublishPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.3
|
||||
std::string topic;
|
||||
std::vector<uint8_t> message;
|
||||
bool dup = false;
|
||||
QOSLevel qos = QOSLevel::QOS0;
|
||||
bool retain = false;
|
||||
optional<uint16_t> packet_identifier;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
uint8_t flags = 0;
|
||||
if (dup)
|
||||
flags |= 0x08;
|
||||
flags |= static_cast<uint8_t>(qos) << 1;
|
||||
if (retain)
|
||||
flags |= 0x01;
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_utf8(variable_header, topic);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
if (packet_identifier.has_value()) {
|
||||
ec = util::encode_uint16(variable_header, *packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
|
||||
return util::encode_packet(
|
||||
target, PacketType::PUBLISH, flags,
|
||||
variable_header, message
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
dup = flags & 0x08;
|
||||
qos = static_cast<QOSLevel>((flags >> 1) & 3);
|
||||
retain = flags & 0x01;
|
||||
ErrorCode ec = util::decode_utf8(&parser, &topic);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
if (qos == QOSLevel::QOS1 || qos == QOSLevel::QOS2) {
|
||||
packet_identifier = 0;
|
||||
ec = util::decode_uint16(&parser, &(*packet_identifier));
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
} else {
|
||||
packet_identifier.reset();
|
||||
}
|
||||
message.clear();
|
||||
message.reserve(parser.size_left());
|
||||
while (parser.size_left())
|
||||
message.push_back(parser.consume());
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class PubackPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.4
|
||||
uint16_t packet_identifier;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
return util::encode_packet(
|
||||
target, PacketType::PUBACK, 0,
|
||||
variable_header
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return util::decode_uint16(&parser, &packet_identifier);
|
||||
}
|
||||
};
|
||||
|
||||
class PubrecPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.5
|
||||
uint16_t packet_identifier;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
return util::encode_packet(
|
||||
target, PacketType::PUBREC, 0,
|
||||
variable_header
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return util::decode_uint16(&parser, &packet_identifier);
|
||||
}
|
||||
};
|
||||
|
||||
class PubrelPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.6
|
||||
uint16_t packet_identifier;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
return util::encode_packet(
|
||||
target, PacketType::PUBREL, 2,
|
||||
variable_header
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 2)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return util::decode_uint16(&parser, &packet_identifier);
|
||||
}
|
||||
};
|
||||
|
||||
class PubcompPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.7
|
||||
uint16_t packet_identifier;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
return util::encode_packet(
|
||||
target, PacketType::PUBCOMP, 2,
|
||||
variable_header
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 2)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return util::decode_uint16(&parser, &packet_identifier);
|
||||
}
|
||||
};
|
||||
|
||||
struct Subscription {
|
||||
std::string topic_filter;
|
||||
QOSLevel requested_qos = QOSLevel::QOS0;
|
||||
};
|
||||
|
||||
class SubscribePacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.8
|
||||
uint16_t packet_identifier;
|
||||
std::vector<Subscription> subscriptions;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
std::vector<uint8_t> payload;
|
||||
for (const auto &sub : subscriptions) {
|
||||
ec = util::encode_utf8(payload, sub.topic_filter);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
payload.push_back(static_cast<uint8_t>(sub.requested_qos));
|
||||
}
|
||||
return util::encode_packet(
|
||||
target, PacketType::SUBSCRIBE, 2,
|
||||
variable_header, payload
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 2)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
|
||||
ErrorCode ec = util::decode_uint16(&parser, &packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
subscriptions.clear();
|
||||
while (parser.size_left()) {
|
||||
Subscription sub{};
|
||||
ec = util::decode_utf8(&parser, &sub.topic_filter);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
if (parser.size_left() < 1)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
sub.requested_qos = static_cast<QOSLevel>(parser.consume());
|
||||
subscriptions.push_back(sub);
|
||||
}
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
enum class SubackReturnCode : uint8_t {
|
||||
SUCCESS_MAX_QOS0 = 0x00,
|
||||
SUCCESS_MAX_QOS1 = 0x01,
|
||||
SUCCESS_MAX_QOS2 = 0x02,
|
||||
FAILURE = 0x80,
|
||||
};
|
||||
|
||||
class SubackPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.9
|
||||
uint16_t packet_identifier;
|
||||
std::vector<SubackReturnCode> return_codes;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
std::vector<uint8_t> payload;
|
||||
payload.reserve(return_codes.size());
|
||||
for (SubackReturnCode rc : return_codes) {
|
||||
payload.push_back(static_cast<uint8_t>(rc));
|
||||
}
|
||||
return util::encode_packet(
|
||||
target, PacketType::SUBACK, 0,
|
||||
variable_header, payload
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 2)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
ErrorCode ec = util::decode_uint16(&parser, &packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
return_codes.clear();
|
||||
return_codes.reserve(parser.size_left());
|
||||
while (parser.size_left()) {
|
||||
return_codes.push_back(static_cast<SubackReturnCode>(parser.consume()));
|
||||
}
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class UnsubscribePacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.10
|
||||
uint16_t packet_identifier;
|
||||
std::vector<std::string> topic_filters;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
std::vector<uint8_t> payload;
|
||||
for (const auto &topic : topic_filters) {
|
||||
ec = util::encode_utf8(payload, topic);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
}
|
||||
return util::encode_packet(
|
||||
target, PacketType::UNSUBACK, 2,
|
||||
variable_header, payload
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 2)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
|
||||
ErrorCode ec = util::decode_uint16(&parser, &packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
topic_filters.clear();
|
||||
while (parser.size_left()) {
|
||||
std::string topic;
|
||||
ec = util::decode_utf8(&parser, &topic);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
topic_filters.push_back(topic);
|
||||
}
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class UnsubackPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.11
|
||||
uint16_t packet_identifier;
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
std::vector<uint8_t> variable_header;
|
||||
ErrorCode ec = util::encode_uint16(variable_header, packet_identifier);
|
||||
if (ec != ErrorCode::OK)
|
||||
return ec;
|
||||
return util::encode_packet(
|
||||
target, PacketType::UNSUBACK, 0,
|
||||
variable_header
|
||||
);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 2)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return util::decode_uint16(&parser, &packet_identifier);
|
||||
}
|
||||
};
|
||||
|
||||
class PingreqPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.12
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
return util::encode_packet(target, PacketType::PINGREQ, 0);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 0)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class PingrespPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.13
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
return util::encode_packet(target, PacketType::PINGRESP, 0);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 0)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
class DisconnectPacket : public MQTTPacket {
|
||||
public:
|
||||
// 3.14
|
||||
|
||||
ErrorCode encode(std::vector<uint8_t> &target) const final {
|
||||
return util::encode_packet(target, PacketType::DISCONNECT, 0);
|
||||
}
|
||||
|
||||
ErrorCode decode(uint8_t flags, util::Parser parser) override final {
|
||||
if (flags != 0)
|
||||
return ErrorCode::BAD_FLAGS;
|
||||
if (parser.size_left() != 0)
|
||||
return ErrorCode::MALFORMED_PACKET;
|
||||
return ErrorCode::OK;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mqtt
|
||||
} // namespace esphome
|
||||
|
||||
#endif // USE_MQTT
|
||||
@@ -1,5 +1,6 @@
|
||||
import esphome.config_validation as cv
|
||||
import esphome.codegen as cg
|
||||
from esphome.core import CORE
|
||||
|
||||
CODEOWNERS = ["@esphome/core"]
|
||||
|
||||
@@ -26,3 +27,6 @@ async def to_code(config):
|
||||
cg.add_define("USE_SOCKET_IMPL_LWIP_TCP")
|
||||
elif impl == IMPLEMENTATION_BSD_SOCKETS:
|
||||
cg.add_define("USE_SOCKET_IMPL_BSD_SOCKETS")
|
||||
|
||||
if CORE.target_platform in ["esp8266", "esp32"]:
|
||||
cg.add_define("USE_SOCKET_HAS_LWIP")
|
||||
|
||||
@@ -48,6 +48,43 @@ class BSDSocketImpl : public Socket {
|
||||
return make_unique<BSDSocketImpl>(fd);
|
||||
}
|
||||
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); }
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override { return ::connect(fd_, addr, addrlen); }
|
||||
int connect_finished() override {
|
||||
fd_set wfds;
|
||||
struct timeval tv;
|
||||
FD_ZERO(&wfds);
|
||||
FD_SET(fd_, &wfds);
|
||||
tv.tv_sec = 0;
|
||||
tv.tv_usec = 0;
|
||||
int retval = ::select(fd_ + 1, nullptr, &wfds, nullptr, &tv);
|
||||
if (retval == -1) {
|
||||
// reuse errno
|
||||
return -1;
|
||||
}
|
||||
if (retval == 0) {
|
||||
// timeout, not writable yet
|
||||
errno = EINPROGRESS;
|
||||
return -1;
|
||||
}
|
||||
if (!FD_ISSET(fd_, &wfds)) {
|
||||
errno = ECONNREFUSED;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int so_error;
|
||||
socklen_t len = sizeof(so_error);
|
||||
int ret = this->getsockopt(SOL_SOCKET, SO_ERROR, &so_error, &len);
|
||||
if (ret == -1) {
|
||||
// reuse errno
|
||||
return -1;
|
||||
}
|
||||
if (so_error == 0) {
|
||||
return 0;
|
||||
}
|
||||
errno = ECONNREFUSED;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int close() override {
|
||||
int ret = ::close(fd_);
|
||||
closed_ = true;
|
||||
|
||||
34
esphome/components/socket/getaddrinfo.h
Normal file
34
esphome/components/socket/getaddrinfo.h
Normal file
@@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include "headers.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
struct GetaddrinfoFuture {
|
||||
public:
|
||||
virtual ~GetaddrinfoFuture() = default;
|
||||
// returns true when the request has completed (successfully or with an error)
|
||||
virtual bool completed() = 0;
|
||||
/**
|
||||
* @brief Fetch the completed result into res.
|
||||
*
|
||||
* Should only be called after completed() returned true.
|
||||
* Make sure to call freeaddrinfo() to free the addrinfo storage
|
||||
* when it's no longer needed.
|
||||
*
|
||||
* @return See posix getaddrinfo() return values.
|
||||
*/
|
||||
virtual int fetch_result(struct addrinfo **res) = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<GetaddrinfoFuture> getaddrinfo_async(const char *node, const char *service,
|
||||
const struct addrinfo *hints);
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
#ifdef USE_ESP8266
|
||||
void freeaddrinfo(struct addrinfo *ai);
|
||||
const char *gai_strerror(int errcode);
|
||||
#endif
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#define LWIP_INTERNAL
|
||||
#include "lwip/inet.h"
|
||||
#include "lwip/netdb.h"
|
||||
#include <cerrno>
|
||||
#include <cstdint>
|
||||
#include <sys/types.h>
|
||||
@@ -107,6 +108,34 @@ struct iovec {
|
||||
#define ESPHOME_INADDR_NONE INADDR_NONE
|
||||
#endif
|
||||
|
||||
#ifndef EAI_FAIL
|
||||
#define EAI_BADFLAGS (-1)
|
||||
#define EAI_NONAME (-2)
|
||||
#define EAI_AGAIN (-3)
|
||||
#define EAI_FAIL (-4)
|
||||
#define EAI_FAMILY (-6)
|
||||
#define EAI_SOCKTYPE (-7)
|
||||
#define EAI_SERVICE (-8)
|
||||
#define EAI_MEMORY (-10)
|
||||
#define EAI_SYSTEM (-11)
|
||||
#define EAI_OVERFLOW (-12)
|
||||
#endif // !EAI_FAIL
|
||||
|
||||
#ifndef IPPROTO_UDP
|
||||
#define IPPROTO_UDP 17
|
||||
#endif
|
||||
|
||||
struct addrinfo { // NOLINT(readability-identifier-naming)
|
||||
int ai_flags;
|
||||
int ai_family;
|
||||
int ai_socktype;
|
||||
int ai_protocol;
|
||||
socklen_t ai_addrlen;
|
||||
struct sockaddr *ai_addr;
|
||||
char *ai_canonname;
|
||||
struct addrinfo *ai_next;
|
||||
};
|
||||
|
||||
#endif // USE_SOCKET_IMPL_LWIP_TCP
|
||||
|
||||
#ifdef USE_SOCKET_IMPL_BSD_SOCKETS
|
||||
@@ -118,6 +147,7 @@ struct iovec {
|
||||
#include <sys/types.h>
|
||||
#include <sys/uio.h>
|
||||
#include <unistd.h>
|
||||
#include <netdb.h>
|
||||
|
||||
#ifdef USE_ARDUINO
|
||||
// arduino-esp32 declares a global var called INADDR_NONE which is replaced
|
||||
|
||||
197
esphome/components/socket/lwip_getaddrinfo_impl.cpp
Normal file
197
esphome/components/socket/lwip_getaddrinfo_impl.cpp
Normal file
@@ -0,0 +1,197 @@
|
||||
#include "getaddrinfo.h"
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifdef USE_SOCKET_HAS_LWIP
|
||||
|
||||
#include <utility>
|
||||
#include "lwip/dns.h"
|
||||
#include "lwip/ip_addr.h"
|
||||
#include "lwip/netdb.h"
|
||||
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
static const char *const TAG = "socket.lwipgetaddrinfo";
|
||||
|
||||
struct LwipDNSResult {
|
||||
bool completed;
|
||||
bool error;
|
||||
ip_addr_t ipaddr;
|
||||
};
|
||||
|
||||
struct LwipDNSCallbackArg {
|
||||
std::weak_ptr<LwipDNSResult> res;
|
||||
};
|
||||
|
||||
void lwip_dns_callback(const char *name, const ip_addr_t *ipaddr, void *callback_arg) {
|
||||
LwipDNSCallbackArg *arg = reinterpret_cast<LwipDNSCallbackArg *>(callback_arg);
|
||||
{
|
||||
std::shared_ptr<LwipDNSResult> result = arg->res.lock();
|
||||
if (result) {
|
||||
if (ipaddr == nullptr) {
|
||||
result->error = true;
|
||||
} else {
|
||||
result->error = false;
|
||||
ip_addr_copy(result->ipaddr, *ipaddr);
|
||||
}
|
||||
result->completed = true;
|
||||
}
|
||||
}
|
||||
delete arg; // NOLINT(cppcoreguidelines-owning-memory)
|
||||
}
|
||||
|
||||
class LwipGetaddrinfoFuture : public GetaddrinfoFuture {
|
||||
public:
|
||||
LwipGetaddrinfoFuture(std::shared_ptr<LwipDNSResult> result, int hint_ai_socktype, int hint_ai_protocol,
|
||||
uint16_t portno)
|
||||
: result_(std::move(result)),
|
||||
hint_ai_socktype_(hint_ai_socktype),
|
||||
hint_ai_protocol_(hint_ai_protocol),
|
||||
portno_(portno) {}
|
||||
~LwipGetaddrinfoFuture() override = default;
|
||||
|
||||
bool completed() override { return result_->completed; }
|
||||
int fetch_result(struct addrinfo **res) override {
|
||||
if (res == nullptr)
|
||||
return EAI_FAIL;
|
||||
*res = nullptr;
|
||||
if (!result_->completed)
|
||||
return EAI_FAIL;
|
||||
if (result_->error)
|
||||
return EAI_FAIL;
|
||||
|
||||
size_t alloc_size = sizeof(struct addrinfo) + sizeof(struct sockaddr_storage);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory,cppcoreguidelines-no-malloc)
|
||||
void *storage = malloc(alloc_size);
|
||||
memset(storage, 0, alloc_size);
|
||||
struct addrinfo *ai = reinterpret_cast<struct addrinfo *>(storage);
|
||||
struct sockaddr_storage *sa = reinterpret_cast<struct sockaddr_storage *>(ai + 1);
|
||||
|
||||
bool isipv6 = IP_IS_V6(result_->ipaddr);
|
||||
|
||||
bool istcp = true;
|
||||
if ((hint_ai_socktype_ != 0 && hint_ai_socktype_ == SOCK_DGRAM) ||
|
||||
(hint_ai_protocol_ != 0 && hint_ai_protocol_ == IPPROTO_UDP)) {
|
||||
istcp = false;
|
||||
}
|
||||
|
||||
ai->ai_family = isipv6 ? AF_INET6 : AF_INET;
|
||||
ai->ai_socktype = istcp ? SOCK_STREAM : SOCK_DGRAM;
|
||||
ai->ai_protocol = istcp ? IPPROTO_TCP : IPPROTO_UDP;
|
||||
|
||||
if (isipv6) {
|
||||
#if LWIP_IPV6
|
||||
struct sockaddr_in6 *sa6 = reinterpret_cast<struct sockaddr_in6 *>(sa);
|
||||
inet6_addr_from_ip6addr(&sa6->sin6_addr, ip_2_ip6(&result_->ipaddr)) sa6->sin6_family = AF_INET6;
|
||||
sa6->sin6_len = sizeof(struct sockaddr_in6);
|
||||
sa6->sin6_port = htons(portno_);
|
||||
#endif // LWIP_IPV6
|
||||
} else {
|
||||
struct sockaddr_in *sa4 = reinterpret_cast<struct sockaddr_in *>(sa);
|
||||
inet_addr_from_ip4addr(&sa4->sin_addr, ip_2_ip4(&result_->ipaddr));
|
||||
sa4->sin_family = AF_INET;
|
||||
sa4->sin_len = sizeof(struct sockaddr_in);
|
||||
sa4->sin_port = htons(portno_);
|
||||
}
|
||||
|
||||
ai->ai_addrlen = sizeof(struct sockaddr_storage);
|
||||
ai->ai_addr = reinterpret_cast<struct sockaddr *>(sa);
|
||||
*res = ai;
|
||||
return 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<LwipDNSResult> result_;
|
||||
int hint_ai_socktype_;
|
||||
int hint_ai_protocol_;
|
||||
uint16_t portno_;
|
||||
};
|
||||
|
||||
std::unique_ptr<GetaddrinfoFuture> getaddrinfo_async(const char *node, const char *service,
|
||||
const struct addrinfo *hints) {
|
||||
std::shared_ptr<LwipDNSResult> result = std::make_shared<LwipDNSResult>();
|
||||
result->completed = false;
|
||||
|
||||
uint16_t portno = 0;
|
||||
if (service != nullptr) {
|
||||
optional<uint16_t> i = parse_number<uint16_t>(service);
|
||||
if (!i.has_value()) {
|
||||
result->completed = true;
|
||||
result->error = true;
|
||||
return std::unique_ptr<GetaddrinfoFuture>{new LwipGetaddrinfoFuture(result, 0, 0, 0)};
|
||||
}
|
||||
portno = *i;
|
||||
}
|
||||
|
||||
int hint_ai_socktype = 0, hint_ai_protocol = 0;
|
||||
uint8_t dns_addrtype = LWIP_DNS_ADDRTYPE_DEFAULT;
|
||||
if (hints != nullptr) {
|
||||
hint_ai_socktype = hints->ai_socktype;
|
||||
hint_ai_protocol = hints->ai_protocol;
|
||||
if (hints->ai_family == AF_INET) {
|
||||
dns_addrtype = LWIP_DNS_ADDRTYPE_IPV4;
|
||||
} else if (hints->ai_family == AF_INET6) {
|
||||
dns_addrtype = LWIP_DNS_ADDRTYPE_IPV6;
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
|
||||
LwipDNSCallbackArg *callback_arg = new LwipDNSCallbackArg;
|
||||
callback_arg->res = result;
|
||||
|
||||
ip_addr_t immediate_result;
|
||||
err_t err = dns_gethostbyname_addrtype(node, &immediate_result, lwip_dns_callback, callback_arg, dns_addrtype);
|
||||
if (err == ERR_OK) {
|
||||
// immediate result
|
||||
result->completed = true;
|
||||
result->error = false;
|
||||
ip_addr_copy(result->ipaddr, immediate_result);
|
||||
|
||||
// callback won't be called
|
||||
delete callback_arg; // NOLINT(cppcoreguidelines-owning-memory)
|
||||
} else if (err == ERR_INPROGRESS) {
|
||||
// result notified via callback
|
||||
} else {
|
||||
// error
|
||||
result->completed = true;
|
||||
result->error = true;
|
||||
|
||||
// callback won't be called
|
||||
delete callback_arg; // NOLINT(cppcoreguidelines-owning-memory)
|
||||
}
|
||||
|
||||
return std::unique_ptr<GetaddrinfoFuture>{
|
||||
new LwipGetaddrinfoFuture(result, hint_ai_socktype, hint_ai_protocol, portno)};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
#ifdef USE_ESP8266
|
||||
void freeaddrinfo(struct addrinfo *ai) {
|
||||
while (ai != nullptr) {
|
||||
struct addrinfo *next = ai->ai_next;
|
||||
delete ai; // NOLINT(cppcoreguidelines-owning-memory)
|
||||
ai = next;
|
||||
}
|
||||
}
|
||||
const char *gai_strerror(int errcode) {
|
||||
switch (errcode) {
|
||||
case EAI_BADFLAGS: return "badflags";
|
||||
case EAI_NONAME: return "noname";
|
||||
case EAI_AGAIN: return "again";
|
||||
case EAI_FAMILY: return "family";
|
||||
case EAI_SOCKTYPE: return "socktype";
|
||||
case EAI_SERVICE: return "service";
|
||||
case EAI_MEMORY: return "memory";
|
||||
case EAI_SYSTEM: return "system";
|
||||
case EAI_OVERFLOW: return "overflow";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // USE_SOCKET_HAS_LWIP
|
||||
@@ -69,7 +69,7 @@ class LWIPRawImpl : public Socket {
|
||||
}
|
||||
if (name == nullptr) {
|
||||
errno = EINVAL;
|
||||
return 0;
|
||||
return -1;
|
||||
}
|
||||
ip_addr_t ip;
|
||||
in_port_t port;
|
||||
@@ -126,6 +126,76 @@ class LWIPRawImpl : public Socket {
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
int connect(const struct sockaddr *addr, socklen_t addrlen) override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = EBADF;
|
||||
return -1;
|
||||
}
|
||||
if (addr == nullptr) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
if (connecting_) {
|
||||
errno = EALREADY;
|
||||
return -1;
|
||||
}
|
||||
|
||||
ip_addr_t ipaddr;
|
||||
uint16_t port;
|
||||
|
||||
if (addr->sa_family == AF_INET) {
|
||||
const struct sockaddr_in *sa4 = reinterpret_cast<const struct sockaddr_in *>(addr);
|
||||
inet_addr_to_ip4addr(ip_2_ip4(&ipaddr), &sa4->sin_addr);
|
||||
#if LWIP_IPV4 && LWIP_IPV6
|
||||
ipaddr.type = IPADDR_TYPE_V4;
|
||||
#endif
|
||||
port = ntohs(sa4->sin_port);
|
||||
#if LWIP_IPV6
|
||||
} else if (addr->sa_family == AF_INET6) {
|
||||
const struct sockaddr_in6 *sa6 = reinterpret_cast<const struct sockaddr_in6 *>(addr);
|
||||
inet6_addr_to_ip6addr(ip_2_ip6(&ipaddr), &sa6->sin_addr);
|
||||
ipaddr.type = IPADDR_TYPE_V6;
|
||||
port = ntohs(sa6->sin_port);
|
||||
#endif // LWIP_IPV6
|
||||
} else {
|
||||
errno = EAFNOSUPPORT;
|
||||
return -1;
|
||||
}
|
||||
|
||||
connecting_ = true;
|
||||
connected_ = false;
|
||||
connect_error_ = false;
|
||||
LWIP_LOG("tcp_connect(%u)", port);
|
||||
err_t err = tcp_connect(pcb_, &ipaddr, port, LWIPRawImpl::s_connected_fn);
|
||||
if (err == ERR_VAL) {
|
||||
errno = EINVAL;
|
||||
return -1;
|
||||
}
|
||||
if (err != ERR_OK) {
|
||||
errno = EIO;
|
||||
return -1;
|
||||
}
|
||||
|
||||
errno = EINPROGRESS;
|
||||
return -1;
|
||||
}
|
||||
int connect_finished() override {
|
||||
if (connected_) {
|
||||
return 0;
|
||||
}
|
||||
if (connect_error_) {
|
||||
errno = ECONNREFUSED;
|
||||
return -1;
|
||||
}
|
||||
if (connecting_) {
|
||||
errno = EINPROGRESS;
|
||||
return -1;
|
||||
}
|
||||
// no connect started
|
||||
errno = EALREADY;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int close() override {
|
||||
if (pcb_ == nullptr) {
|
||||
errno = ECONNRESET;
|
||||
@@ -369,9 +439,10 @@ class LWIPRawImpl : public Socket {
|
||||
for (int i = 0; i < iovcnt; i++) {
|
||||
ssize_t err = read(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len);
|
||||
if (err == -1) {
|
||||
if (ret != 0)
|
||||
if (ret != 0) {
|
||||
// if we already read some don't return an error
|
||||
break;
|
||||
}
|
||||
return err;
|
||||
}
|
||||
ret += err;
|
||||
@@ -433,9 +504,10 @@ class LWIPRawImpl : public Socket {
|
||||
ssize_t written = internal_write(buf, len);
|
||||
if (written == -1)
|
||||
return -1;
|
||||
if (written == 0)
|
||||
if (written == 0) {
|
||||
// no need to output if nothing written
|
||||
return 0;
|
||||
}
|
||||
if (nodelay_) {
|
||||
int err = internal_output();
|
||||
if (err == -1)
|
||||
@@ -448,18 +520,20 @@ class LWIPRawImpl : public Socket {
|
||||
for (int i = 0; i < iovcnt; i++) {
|
||||
ssize_t err = internal_write(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len);
|
||||
if (err == -1) {
|
||||
if (written != 0)
|
||||
if (written != 0) {
|
||||
// if we already read some don't return an error
|
||||
break;
|
||||
}
|
||||
return err;
|
||||
}
|
||||
written += err;
|
||||
if ((size_t) err != iov[i].iov_len)
|
||||
break;
|
||||
}
|
||||
if (written == 0)
|
||||
if (written == 0) {
|
||||
// no need to output if nothing written
|
||||
return 0;
|
||||
}
|
||||
if (nodelay_) {
|
||||
int err = internal_output();
|
||||
if (err == -1)
|
||||
@@ -524,6 +598,18 @@ class LWIPRawImpl : public Socket {
|
||||
}
|
||||
return ERR_OK;
|
||||
}
|
||||
err_t connected_fn(err_t err) {
|
||||
LWIP_LOG("connected(err=%d)", err);
|
||||
if (err != ERR_OK) {
|
||||
connected_ = false;
|
||||
connect_error_ = false;
|
||||
} else {
|
||||
connected_ = true;
|
||||
connect_error_ = true;
|
||||
}
|
||||
connecting_ = false;
|
||||
return ERR_OK;
|
||||
}
|
||||
|
||||
static err_t s_accept_fn(void *arg, struct tcp_pcb *newpcb, err_t err) {
|
||||
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
|
||||
@@ -540,6 +626,11 @@ class LWIPRawImpl : public Socket {
|
||||
return arg_this->recv_fn(pb, err);
|
||||
}
|
||||
|
||||
static err_t s_connected_fn(void *arg, struct tcp_pcb *pcb, err_t err) {
|
||||
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
|
||||
return arg_this->connected_fn(err);
|
||||
}
|
||||
|
||||
protected:
|
||||
int ip2sockaddr_(ip_addr_t *ip, uint16_t port, struct sockaddr *name, socklen_t *addrlen) {
|
||||
if (family_ == AF_INET) {
|
||||
@@ -590,6 +681,9 @@ class LWIPRawImpl : public Socket {
|
||||
// instead use it for determining whether to call lwip_output
|
||||
bool nodelay_ = false;
|
||||
sa_family_t family_ = 0;
|
||||
bool connecting_ = false;
|
||||
bool connected_ = false;
|
||||
bool connect_error_ = false;
|
||||
};
|
||||
|
||||
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
|
||||
|
||||
@@ -7,7 +7,7 @@ namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
std::unique_ptr<Socket> socket_ip(int type, int protocol) {
|
||||
#if LWIP_IPV6
|
||||
#ifdef USE_SOCKET_IPV6
|
||||
return socket(AF_INET6, type, protocol);
|
||||
#else
|
||||
return socket(AF_INET, type, protocol);
|
||||
@@ -15,7 +15,7 @@ std::unique_ptr<Socket> socket_ip(int type, int protocol) {
|
||||
}
|
||||
|
||||
socklen_t set_sockaddr_any(struct sockaddr *addr, socklen_t addrlen, uint16_t port) {
|
||||
#if LWIP_IPV6
|
||||
#if USE_SOCKET_IPV6
|
||||
if (addrlen < sizeof(sockaddr_in6)) {
|
||||
errno = EINVAL;
|
||||
return 0;
|
||||
|
||||
@@ -5,6 +5,12 @@
|
||||
#include "headers.h"
|
||||
#include "esphome/core/optional.h"
|
||||
|
||||
#ifdef USE_SOCKET_IMPL_LWIP_TCP
|
||||
#if LWIP_IPV6
|
||||
#define USE_SOCKET_IPV6
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
@@ -17,10 +23,17 @@ class Socket {
|
||||
|
||||
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
/**
|
||||
* @brief Helper to check if a socket connect() that was EINPROGRESS is now finished.
|
||||
*
|
||||
* If the connect finnished successfully, returns 0.
|
||||
* If it's still in progress, returns -1 and sets errno to EINPROGRESS.
|
||||
* Other errors result in return code -1 and errno like in blocking connect().
|
||||
*/
|
||||
virtual int connect_finished() = 0;
|
||||
|
||||
virtual int close() = 0;
|
||||
// not supported yet:
|
||||
// virtual int connect(const std::string &address) = 0;
|
||||
// virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
|
||||
virtual int shutdown(int how) = 0;
|
||||
|
||||
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
|
||||
|
||||
99
esphome/components/socket/thread_getaddrinfo_impl.cpp
Normal file
99
esphome/components/socket/thread_getaddrinfo_impl.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
#include "getaddrinfo.h"
|
||||
#include "esphome/core/defines.h"
|
||||
|
||||
#ifndef USE_SOCKET_HAS_LWIP
|
||||
|
||||
#include <thread>
|
||||
#include <sys/types.h>
|
||||
#include <sys/socket.h>
|
||||
#include <netdb.h>
|
||||
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
namespace esphome {
|
||||
namespace socket {
|
||||
|
||||
static const char *const TAG = "socket.threadgetaddrinfo";
|
||||
|
||||
struct ThreadGetaddrinfoResult {
|
||||
bool completed;
|
||||
int return_code;
|
||||
struct addrinfo *res;
|
||||
};
|
||||
|
||||
class ThreadGetaddrinfoFuture : public GetaddrinfoFuture {
|
||||
public:
|
||||
ThreadGetaddrinfoFuture(std::shared_ptr<ThreadGetaddrinfoResult> result) : result_(result) {}
|
||||
~ThreadGetaddrinfoFuture() override = default;
|
||||
|
||||
bool completed() override { return result_->completed; }
|
||||
int fetch_result(struct addrinfo **res) {
|
||||
if (res == nullptr)
|
||||
return EAI_FAIL;
|
||||
*res = nullptr;
|
||||
if (!result_->completed)
|
||||
return EAI_FAIL;
|
||||
if (result_->return_code != 0)
|
||||
return result_->return_code;
|
||||
|
||||
*res = result_->res;
|
||||
return 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ThreadGetaddrinfoResult> result_;
|
||||
};
|
||||
|
||||
void worker(std::shared_ptr<ThreadGetaddrinfoResult> result, const char *node, const char *service,
|
||||
const struct addrinfo *hints) {
|
||||
result->return_code = getaddrinfo(node, service, hints, &result->res);
|
||||
result->completed = true;
|
||||
if (hints != nullptr) {
|
||||
delete hints->ai_addr;
|
||||
delete hints->ai_canonname;
|
||||
delete hints;
|
||||
}
|
||||
delete node;
|
||||
delete service;
|
||||
}
|
||||
|
||||
std::unique_ptr<GetaddrinfoFuture> getaddrinfo_async(const char *node, const char *service,
|
||||
const struct addrinfo *hints) {
|
||||
std::shared_ptr<ThreadGetaddrinfoResult> result = std::make_shared<ThreadGetaddrinfoResult>();
|
||||
result->completed = false;
|
||||
|
||||
struct addrinfo *hints_copy = nullptr;
|
||||
if (hints != nullptr) {
|
||||
hints_copy = new struct addrinfo;
|
||||
hints_copy->ai_flags = hints->ai_flags;
|
||||
hints_copy->ai_family = hints->ai_family;
|
||||
hints_copy->ai_socktype = hints->ai_socktype;
|
||||
hints_copy->ai_protocol = hints->ai_protocol;
|
||||
hints_copy->ai_addrlen = hints->ai_addrlen;
|
||||
if (ai->ai_addr != nullptr) {
|
||||
hints_copy->ai_addr = malloc(hints->ai_addrlen);
|
||||
memcpy(hints_copy->ai_addr, hints->ai_addr, hints->ai_addrlen);
|
||||
}
|
||||
if (ai->ai_canonname != nullptr) {
|
||||
hints_copy->ai_canonname = strdup(hints->ai_canonname);
|
||||
}
|
||||
hints_copy->ai_next = nullptr;
|
||||
}
|
||||
|
||||
const char *node_copy = nullptr, *service_copy = nullptr;
|
||||
if (node != nullptr)
|
||||
node_copy = strdup(node);
|
||||
if (service != nullptr)
|
||||
service_copy = strdup(service);
|
||||
|
||||
std::thread thread(worker, result, node_copy, service_copy, hints_copy);
|
||||
thread.detach();
|
||||
|
||||
return std::unique_ptr<GetaddrinfoFuture>{new ThreadGetaddrinfoFuture(result)};
|
||||
}
|
||||
|
||||
} // namespace socket
|
||||
} // namespace esphome
|
||||
|
||||
#endif // !USE_SOCKET_HAS_LWIP
|
||||
@@ -70,6 +70,7 @@
|
||||
#ifdef USE_ESP_IDF
|
||||
#define USE_ARDUINO_VERSION_CODE VERSION_CODE(4, 3, 0)
|
||||
#endif
|
||||
#define USE_SOCKET_HAS_LWIP
|
||||
#endif
|
||||
|
||||
// ESP8266-specific feature flags
|
||||
@@ -79,6 +80,7 @@
|
||||
#define USE_ESP8266_PREFERENCES_FLASH
|
||||
#define USE_HTTP_REQUEST_ESP8266_HTTPS
|
||||
#define USE_SOCKET_IMPL_LWIP_TCP
|
||||
#define USE_SOCKET_HAS_LWIP
|
||||
#endif
|
||||
|
||||
// Disabled feature flags
|
||||
|
||||
Reference in New Issue
Block a user