diff --git a/dnd_transcribe/argparse.py b/dnd_transcribe/argparse.py
index b86be73..81133b1 100644
--- a/dnd_transcribe/argparse.py
+++ b/dnd_transcribe/argparse.py
@@ -1,4 +1,5 @@
 import argparse
+import os
 
 from dnd_transcribe.inference import DEFAULT_MODEL
 
@@ -18,6 +19,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
         action="store_false",
         help="Disable using the GPU with CUDA",
     )
+    parser.add_argument(
+        "--block-len",
+        type=int,
+        default=30,
+        help="Length in seconds of each audio block when streaming, also used as the Whisper chunk length",
+    )
     parser.add_argument(
         "-m",
         "--model",
@@ -29,6 +36,19 @@
         "-f",
         "--audio-file",
         type=argparse.FileType(mode="rb"),
-        help="Audio file to process",
+        help="Audio file to process; for long audio see --stream-audio-file",
+    )
+    parser.add_argument(
+        "-s",
+        "--stream-audio-file",
+        type=valid_file_path,
+        help="Audio file to process by streaming",
     )
     return parser
+
+
+def valid_file_path(path: str) -> str:
+    path = os.path.realpath(path)
+    if os.path.isfile(path):
+        return path
+    raise argparse.ArgumentTypeError("{} is not a valid file".format(path))
diff --git a/dnd_transcribe/inference.py b/dnd_transcribe/inference.py
index 15bebfb..6691e3d 100644
--- a/dnd_transcribe/inference.py
+++ b/dnd_transcribe/inference.py
@@ -1,4 +1,4 @@
-import soundfile
+import librosa
 import numpy.typing
 import time
 import torch
@@ -25,8 +25,11 @@ class InferredTranscript(typing.NamedTuple):
 
 
 class Inference:
-    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
+    def __init__(
+        self, model_name: str, block_len: int = 20, use_gpu: bool = True
+    ) -> None:
         self.model_name = model_name
+        self.block_len = block_len
         cuda_available = use_gpu and torch.cuda.is_available()
         self.device = "cuda" if cuda_available else "cpu"
         self.torch_dtype = torch.float16 if cuda_available else torch.float32
@@ -46,10 +49,43 @@ class Inference:
     def is_wav2vec2(self) -> bool:
         return self.model_name.startswith("facebook/wav2vec2")
 
-    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
-        audio_input, samplerate = soundfile.read(audio_file)
+    def stream_file_to_text(self, audio_file_path: str) -> InferredTranscript:
+        samplerate = librosa.get_samplerate(audio_file_path)
         if samplerate != AUDIO_SAMPLE_RATE:
             raise Exception(f"Unsupported sample rate {samplerate}")
+        stream = librosa.stream(
+            audio_file_path,
+            block_length=self.block_len,
+            frame_length=AUDIO_SAMPLE_RATE,
+            hop_length=AUDIO_SAMPLE_RATE,
+        )
+        transcript = ""
+        confidence = None
+        processing_time = 0.0
+        for block in stream:
+            if len(block.shape) > 1:
+                # Downmix multi-channel audio to mono
+                block = block[:, 0] + block[:, 1]
+            try:
+                block_inference = self.buffer_to_text(block)
+                transcript += block_inference.transcript + " "
+                processing_time += block_inference.processing_time_sec
+                if block_inference.confidence_score is not None:
+                    if confidence is None:
+                        confidence = block_inference.confidence_score
+                    else:
+                        confidence *= block_inference.confidence_score
+            except torch.OutOfMemoryError as e:
+                print(e)
+                break
+        return InferredTranscript(transcript.strip(), confidence, processing_time)
+
+    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
+        audio_input, samplerate = librosa.load(audio_file, sr=None)
+        if samplerate != AUDIO_SAMPLE_RATE:
+            audio_input = librosa.resample(
+                audio_input, orig_sr=samplerate, target_sr=AUDIO_SAMPLE_RATE
+            )
         return self.buffer_to_text(audio_input)
 
     def buffer_to_text(
@@ -94,7 +130,7 @@ class Inference:
             model=self.model,
             tokenizer=self.processor.tokenizer,
             feature_extractor=self.processor.feature_extractor,
-            chunk_length_s=30,
+            chunk_length_s=self.block_len,
             batch_size=16,  # batch size for inference - set based on your device
             torch_dtype=self.torch_dtype,
             device=self.device,
diff --git a/dnd_transcribe/main.py b/dnd_transcribe/main.py
index ec06ab8..56d2081 100644
--- a/dnd_transcribe/main.py
+++ b/dnd_transcribe/main.py
@@ -13,13 +13,27 @@ def main():
         logging.basicConfig(level=logging.ERROR, format=logging_format)
     else:
         logging.basicConfig(level=logging.INFO, format=logging_format)
-    inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
+    inference = dnd_transcribe.inference.Inference(
+        args.model,
+        block_len=args.block_len,
+        use_gpu=args.use_gpu,
+    )
     if args.audio_file is not None:
-        (transcription, score, duration) = inference.file_to_text(args.audio_file)
-        print(transcription)
-        if score is not None:
-            print(f"[Confidence: {score:.1%} in {duration} seconds]")
-        else:
-            print(f"[Confidence -unknown- in {duration} seconds]")
+        _print_inferred_transcript(inference.file_to_text(args.audio_file))
+    elif args.stream_audio_file is not None:
+        _print_inferred_transcript(
+            inference.stream_file_to_text(args.stream_audio_file)
+        )
     else:
         print("Live transcription is a WIP")
+
+
+def _print_inferred_transcript(
+    transcript: dnd_transcribe.inference.InferredTranscript,
+) -> None:
+    (transcription, score, duration) = transcript
+    print(transcription)
+    if score is not None:
+        print(f"[Confidence: {score:.1%} in {duration} seconds]")
+    else:
+        print(f"[Confidence -unknown- in {duration} seconds]")
diff --git a/requirements.txt b/requirements.txt
index 55d6b56..225425e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
-soundfile>=0.13.1
+librosa>=0.11.0
+numpy>=2.2.5
 torch>=2.6.0
 transformers>=4.51.3