Pipeline updates (#15926)

* Allow using TTS

* Allow streaming audio from frontens to STT

* Improve stop recording

* Even better stop
This commit is contained in:
Paulus Schoutsen 2023-03-26 22:42:08 -04:00 committed by GitHub
parent 395358b192
commit 520f489830
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 226 additions and 39 deletions

View File

@ -12,6 +12,9 @@ interface PipelineRunStartEvent extends PipelineEventBase {
data: {
pipeline: string;
language: string;
runner_data: {
stt_binary_handler_id: number | null;
};
};
}
interface PipelineRunEndEvent extends PipelineEventBase {

View File

@ -1,6 +1,7 @@
import { css, html, LitElement, TemplateResult } from "lit";
import { customElement, property, query, state } from "lit/decorators";
import { css, html, LitElement, PropertyValues, TemplateResult } from "lit";
import { customElement, property, state } from "lit/decorators";
import "../../../../../../components/ha-card";
import "../../../../../../components/ha-alert";
import "../../../../../../components/ha-button";
import "../../../../../../components/ha-circular-progress";
import "../../../../../../components/ha-expansion-panel";
@ -14,17 +15,13 @@ import { SubscribeMixin } from "../../../../../../mixins/subscribe-mixin";
import { haStyle } from "../../../../../../resources/styles";
import type { HomeAssistant } from "../../../../../../types";
import { formatNumber } from "../../../../../../common/number/format_number";
import { showPromptDialog } from "../../../../../../dialogs/generic/show-dialog-box";
const RUN_DATA = {
pipeline: "Pipeline",
language: "Language",
};
const ERROR_DATA = {
code: "Code",
message: "Message",
};
const STT_DATA = {
engine: "Engine",
};
@ -52,6 +49,20 @@ const hasStage = (run: PipelineRun, stage: PipelineRun["stage"]) =>
STAGES[run.init_options.start_stage] <= STAGES[stage] &&
STAGES[stage] <= STAGES[run.init_options.end_stage];
const maybeRenderError = (
run: PipelineRun,
stage: string,
lastRunStage: string
) => {
if (run.stage !== "error" || lastRunStage !== stage) {
return "";
}
return html`<ha-alert alert-type="error">
${run.error!.message} (${run.error!.code})
</ha-alert>`;
};
const renderProgress = (
hass: HomeAssistant,
pipelineRun: PipelineRun,
@ -68,6 +79,10 @@ const renderProgress = (
return "";
}
if (pipelineRun.stage === "error") {
return html``;
}
if (!finishEvent) {
return html`<ha-circular-progress
size="tiny"
@ -117,13 +132,19 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
@property({ type: Boolean }) public narrow!: boolean;
@query("#run-input", true)
private _newRunInput!: HTMLInputElement;
@state() private _pipelineRun?: PipelineRun;
@state()
private _pipelineRun?: PipelineRun;
@state() private _stopRecording?: () => void;
private _audioBuffer?: Int16Array[];
protected render(): TemplateResult {
const lastRunStage: string = this._pipelineRun
? ["tts", "intent", "stt"].find(
(stage) => this._pipelineRun![stage] !== undefined
) || "ready"
: "ready";
return html`
<hass-subpage
.narrow=${this.narrow}
@ -131,24 +152,25 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
header="Assist Pipeline"
>
<div class="content">
<ha-card header="Run pipeline" class="run-pipeline-card">
<div class="card-content">
<ha-textfield
id="run-input"
label="Input"
value="Are the lights on?"
></ha-textfield>
</div>
<div class="card-actions">
<ha-button
@click=${this._runPipeline}
.disabled=${this._pipelineRun &&
!["error", "done"].includes(this._pipelineRun.stage)}
>
Run
</ha-button>
</div>
</ha-card>
<div class="start-row">
<ha-button
raised
@click=${this._runTextPipeline}
.disabled=${this._pipelineRun &&
!["error", "done"].includes(this._pipelineRun.stage)}
>
Run Text Pipeline
</ha-button>
<ha-button
raised
@click=${this._runAudioPipeline}
.disabled=${this._pipelineRun &&
!["error", "done"].includes(this._pipelineRun.stage)}
>
Run Audio Pipeline
</ha-button>
</div>
${this._pipelineRun
? html`
<ha-card>
@ -159,12 +181,10 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
</div>
${renderData(this._pipelineRun.run, RUN_DATA)}
${this._pipelineRun.error
? renderData(this._pipelineRun.error, ERROR_DATA)
: ""}
</div>
</ha-card>
${maybeRenderError(this._pipelineRun, "ready", lastRunStage)}
${hasStage(this._pipelineRun, "stt")
? html`
<ha-card>
@ -189,9 +209,20 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
`
: ""}
</div>
${this._pipelineRun.stage === "stt" &&
this._stopRecording
? html`
<div class="card-actions">
<ha-button @click=${this._stopRecording}>
Stop Recording
</ha-button>
</div>
`
: ""}
</ha-card>
`
: ""}
${maybeRenderError(this._pipelineRun, "stt", lastRunStage)}
${hasStage(this._pipelineRun, "intent")
? html`
<ha-card>
@ -222,6 +253,7 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
</ha-card>
`
: ""}
${maybeRenderError(this._pipelineRun, "intent", lastRunStage)}
${hasStage(this._pipelineRun, "tts")
? html`
<ha-card>
@ -238,17 +270,23 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
? html`
<div class="card-content">
${renderData(this._pipelineRun.tts, TTS_DATA)}
${dataMinusKeysRender(
this._pipelineRun.tts,
TTS_DATA
)}
</div>
`
: ""}
</div>
${this._pipelineRun?.tts?.tts_output
? html`
<div class="card-actions">
<ha-button @click=${this._playTTS}>
Play Audio
</ha-button>
</div>
`
: ""}
</ha-card>
`
: ""}
${maybeRenderError(this._pipelineRun, "tts", lastRunStage)}
<ha-card>
<ha-expansion-panel>
<span slot="header">Raw</span>
@ -262,7 +300,40 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
`;
}
private _runPipeline(): void {
protected updated(changedProperties: PropertyValues): void {
super.updated(changedProperties);
if (
!changedProperties.has("_pipelineRun") ||
!this._pipelineRun ||
this._pipelineRun.init_options.start_stage !== "stt"
) {
return;
}
if (this._pipelineRun.stage === "stt" && this._audioBuffer) {
// Send the buffer over the WS to the STT engine.
for (const buffer of this._audioBuffer) {
this._sendAudioChunk(buffer);
}
this._audioBuffer = undefined;
}
if (this._pipelineRun.stage !== "stt" && this._stopRecording) {
this._stopRecording();
}
}
private async _runTextPipeline() {
const text = await showPromptDialog(this, {
title: "Input text",
confirmText: "Run",
});
if (!text) {
return;
}
this._pipelineRun = undefined;
runPipelineFromText(
this.hass,
@ -272,11 +343,72 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
{
start_stage: "intent",
end_stage: "intent",
input: { text: this._newRunInput.value },
input: { text },
}
);
}
private async _runAudioPipeline() {
// @ts-ignore-next-line
const context = new (window.AudioContext || window.webkitAudioContext)();
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
await context.audioWorklet.addModule(
new URL("./recorder.worklet.js", import.meta.url)
);
const source = context.createMediaStreamSource(stream);
const recorder = new AudioWorkletNode(context, "recorder.worklet");
this.hass.connection.socket!.binaryType = "arraybuffer";
this._stopRecording = () => {
stream.getTracks()[0].stop();
context.close();
this._stopRecording = undefined;
this._audioBuffer = undefined;
// Send empty message to indicate we're done streaming.
this._sendAudioChunk(new Int16Array());
};
this._audioBuffer = [];
source.connect(recorder).connect(context.destination);
recorder.port.onmessage = (e) => {
if (this._audioBuffer) {
this._audioBuffer.push(e.data);
return;
}
if (this._pipelineRun?.stage !== "stt") {
return;
}
this._sendAudioChunk(e.data);
};
this._pipelineRun = undefined;
runPipelineFromText(
this.hass,
(run) => {
this._pipelineRun = run;
},
{
start_stage: "stt",
end_stage: "tts",
}
);
}
private _sendAudioChunk(chunk: Int16Array) {
// Turn into 8 bit so we can prefix our handler ID.
const data = new Uint8Array(1 + chunk.length * 2);
data[0] = this._pipelineRun!.run.runner_data.stt_binary_handler_id!;
data.set(new Uint8Array(chunk.buffer), 1);
this.hass.connection.socket!.send(data);
}
private _playTTS(): void {
const url = this._pipelineRun!.tts!.tts_output!.url;
const audio = new Audio(url);
audio.play();
}
static styles = [
haStyle,
css`
@ -286,7 +418,15 @@ export class AssistPipelineDebug extends SubscribeMixin(LitElement) {
margin: 0 auto;
direction: ltr;
}
ha-card {
.start-row {
text-align: center;
}
.start-row ha-button {
margin: 16px;
}
ha-card,
ha-alert {
display: block;
margin-bottom: 16px;
}
.run-pipeline-card ha-textfield {

View File

@ -0,0 +1,21 @@
class RecorderProcessor extends AudioWorkletProcessor {
process(inputList, _outputList, _parameters) {
if (inputList[0].length < 1) {
return true;
}
const float32Data = inputList[0][0];
const int16Data = new Int16Array(float32Data.length);
for (let i = 0; i < float32Data.length; i++) {
const s = Math.max(-1, Math.min(1, float32Data[i]));
int16Data[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
}
this.port.postMessage(int16Data);
return true;
}
}
registerProcessor("recorder.worklet", RecorderProcessor);

23
src/types/audio.d.ts vendored Normal file
View File

@ -0,0 +1,23 @@
interface AudioWorkletProcessor {
readonly port: MessagePort;
process(
inputs: Float32Array[][],
outputs: Float32Array[][],
parameters: Record<string, Float32Array>
): boolean;
}
/* eslint-disable */
declare var AudioWorkletProcessor: {
prototype: AudioWorkletProcessor;
new (options?: AudioWorkletNodeOptions): AudioWorkletProcessor;
};
declare function registerProcessor(
name: string,
processorCtor: (new (
options?: AudioWorkletNodeOptions
) => AudioWorkletProcessor) & {
parameterDescriptors?: AudioParamDescriptor[];
}
);