From db2b6d956b63713d2f7f6698fcf3b5697fa0d2ed Mon Sep 17 00:00:00 2001 From: David Kruger Date: Fri, 25 Apr 2025 23:56:14 -0700 Subject: [PATCH] Use librosa to stream for long files --- dnd_transcribe/argparse.py | 16 +++++++++++++++- dnd_transcribe/inference.py | 34 +++++++++++++++++++++++++++++++--- dnd_transcribe/main.py | 22 ++++++++++++++++------ requirements.txt | 3 ++- 4 files changed, 64 insertions(+), 11 deletions(-) diff --git a/dnd_transcribe/argparse.py b/dnd_transcribe/argparse.py index b86be73..babe7cd 100644 --- a/dnd_transcribe/argparse.py +++ b/dnd_transcribe/argparse.py @@ -1,4 +1,5 @@ import argparse +import os from dnd_transcribe.inference import DEFAULT_MODEL @@ -29,6 +30,19 @@ def build_argument_parser() -> argparse.ArgumentParser: "-f", "--audio-file", type=argparse.FileType(mode="rb"), - help="Audio file to process", + help="Audio file to process, for long audo see --stream-audio-file", + ) + parser.add_argument( + "-s", + "--stream-audio-file", + type=valid_file_path, + help="Audio file to process by streaming", ) return parser + + +def valid_file_path(path: str) -> str: + path = os.path.realpath(path) + if os.path.isfile(path): + return path + raise argparse.ArgumentTypeError("{} is not a valid file".format(path)) diff --git a/dnd_transcribe/inference.py b/dnd_transcribe/inference.py index 15bebfb..3665507 100644 --- a/dnd_transcribe/inference.py +++ b/dnd_transcribe/inference.py @@ -1,4 +1,4 @@ -import soundfile +import librosa import numpy.typing import time import torch @@ -46,10 +46,38 @@ class Inference: def is_wav2vec2(self) -> bool: return self.model_name.startswith("facebook/wav2vec2") - def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript: - audio_input, samplerate = soundfile.read(audio_file) + def stream_file_to_text(self, audio_file_path: str) -> InferredTranscript: + samplerate = librosa.get_samplerate(audio_file_path) if samplerate != AUDIO_SAMPLE_RATE: raise Exception(f"Unsupported sample rate {samplerate}") + stream = librosa.stream( + audio_file_path, + block_length=20, + frame_length=AUDIO_SAMPLE_RATE, + hop_length=AUDIO_SAMPLE_RATE, + ) + transcript = "" + confidence = None + processing_time = 0.0 + for block in stream: + if len(block.shape) > 1: + block = speech[:, 0] + speech[:, 1] + block_inference = self.buffer_to_text(block) + transcript += block_inference.transcript + " " + processing_time += block_inference.processing_time_sec + if block_inference.confidence_score is not None: + if confidence is None: + confidence = block_inference.confidence_score + else: + confidence *= block_inference.confidence_score + return InferredTranscript(transcript.strip(), confidence, processing_time) + + def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript: + audio_input, samplerate = librosa.load(audio_file) + if samplerate != AUDIO_SAMPLE_RATE: + audio_input = librosa.resample( + audio_input, orig_sr=samplerate, target_sr=AUDIO_SAMPLE_RATE + ) return self.buffer_to_text(audio_input) def buffer_to_text( diff --git a/dnd_transcribe/main.py b/dnd_transcribe/main.py index ec06ab8..3d1f471 100644 --- a/dnd_transcribe/main.py +++ b/dnd_transcribe/main.py @@ -15,11 +15,21 @@ def main(): logging.basicConfig(level=logging.INFO, format=logging_format) inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu) if args.audio_file is not None: - (transcription, score, duration) = inference.file_to_text(args.audio_file) - print(transcription) - if score is not None: - print(f"[Confidence: {score:.1%} in {duration} seconds]") - else: - print(f"[Confidence -unknown- in {duration} seconds]") + _print_inferred_transcript(inference.file_to_text(args.audio_file)) + elif args.stream_audio_file is not None: + _print_inferred_transcript( + inference.stream_file_to_text(args.stream_audio_file) + ) else: print("Live transcription is a WIP") + + +def _print_inferred_transcript( + transcript: dnd_transcribe.inference.InferredTranscript, +) -> None: + (transcription, score, duration) = transcript + print(transcription) + if score is not None: + print(f"[Confidence: {score:.1%} in {duration} seconds]") + else: + print(f"[Confidence -unknown- in {duration} seconds]") diff --git a/requirements.txt b/requirements.txt index 55d6b56..225425e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -soundfile>=0.13.1 +librosa>=0.11.0 +numpy>=2.2.5 torch>=2.6.0 transformers>=4.51.3