diff --git a/README.md b/README.md
index fc4327d..2f30082 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,58 @@ DND games and transcribe them.
 Our initial approach is rather naive, using wav2vec 2.0 pre-trained models to
 perform automated speach recognition
 
-## Usage
+## Installation Instructions
 
-TODO
+### Optional: Install CUDA
+
+If you would like to make use of an Nvidia GPU through CUDA, make sure to
+install CUDA first.
+
+### Install with `pip`
+
+```
+virtualenv .env
+.env/bin/pip install .
+```
+
+## Usage Examples
+
+The following examples use a WAV file with a 16 kHz sample rate.
+
+### Run against the large Facebook wav2vec2 model using just the CPU
+
+This is the most versatile model and the most reliable at transcribing human
+speech in multiple languages. However, it is large and relatively slow.
+
+Note: this is the default model when `--model` is not provided.
+
+```
+./.env/bin/dnd_transcribe \
+    --audio-file example.wav \
+    --model "facebook/wav2vec2-large-960h-lv60-self" \
+    --no-gpu
+```
+
+### Run against the base wav2vec2 model
+
+This is the base Facebook wav2vec2 model, which is smaller and less precise,
+but small enough to run on a GPU with limited memory.
+
+```
+./.env/bin/dnd_transcribe \
+    --audio-file example.wav \
+    --model "facebook/wav2vec2-base-960h"
+```
+
+### Run against the small OpenAI whisper model
+
+See more about the whisper models: https://huggingface.co/openai/whisper-large-v3
+
+This is a smaller model provided by OpenAI that only supports English, but it
+is able to run on a GPU with limited memory.
+
+```
+./.env/bin/dnd_transcribe \
+    --audio-file example.wav \
+    --model "openai/whisper-small.en"
+```
diff --git a/dnd_transcribe/argparse.py b/dnd_transcribe/argparse.py
index 4f14af0..b86be73 100644
--- a/dnd_transcribe/argparse.py
+++ b/dnd_transcribe/argparse.py
@@ -12,6 +12,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-q", "--quiet", action="store_true", help="Only display errors"
     )
+    parser.add_argument(
+        "--no-gpu",
+        dest="use_gpu",
+        action="store_false",
+        help="Disable using the GPU with CUDA",
+    )
     parser.add_argument(
         "-m",
         "--model",
@@ -20,7 +26,8 @@
         default=DEFAULT_MODEL,
     )
     parser.add_argument(
-        "audio_file",
+        "-f",
+        "--audio-file",
         type=argparse.FileType(mode="rb"),
         help="Audio file to process",
     )
diff --git a/dnd_transcribe/inference.py b/dnd_transcribe/inference.py
index f83a696..15bebfb 100644
--- a/dnd_transcribe/inference.py
+++ b/dnd_transcribe/inference.py
@@ -4,6 +4,9 @@ import time
 import torch
 import typing
 from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+    pipeline,
     Wav2Vec2Model,
     Wav2Vec2ForCTC,
     Wav2Vec2Processor,
@@ -17,17 +20,32 @@ AUDIO_SAMPLE_RATE = 16_000
 class InferredTranscript(typing.NamedTuple):
     transcript: str
-    confidence_score: float
+    confidence_score: typing.Optional[float]
     processing_time_sec: float
 
 
 class Inference:
     def __init__(self, model_name: str, use_gpu: bool = True) -> None:
-        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
-        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
-        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        self.model_name = model_name
+        cuda_available = use_gpu and torch.cuda.is_available()
+        self.device = "cuda" if cuda_available else "cpu"
+        self.torch_dtype = torch.float16 if cuda_available else torch.float32
+        if model_name.startswith("facebook/wav2vec2"):
+            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        else:
+            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                model_name,
+                torch_dtype=self.torch_dtype,
+                low_cpu_mem_usage=True,
+                use_safetensors=True,
+            )
+            self.processor = AutoProcessor.from_pretrained(model_name)
         self.model.to(self.device)
 
+    def is_wav2vec2(self) -> bool:
+        return self.model_name.startswith("facebook/wav2vec2")
+
     def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
         audio_input, samplerate = soundfile.read(audio_file)
         if samplerate != AUDIO_SAMPLE_RATE:
@@ -40,6 +58,17 @@ class Inference:
         if len(audio_buffer) == 0:
             InferredTranscript("", 1, 0)
         timer_start = time.perf_counter()
+        if self.is_wav2vec2():
+            transcription, confidence = self._wav2vec_buffer_to_text(audio_buffer)
+        else:
+            transcription = self._pipeline_buffer_to_text(audio_buffer)
+            confidence = None
+        timer_end = time.perf_counter()
+        return InferredTranscript(transcription, confidence, timer_end - timer_start)
+
+    def _wav2vec_buffer_to_text(
+        self, audio_buffer: numpy.typing.ArrayLike
+    ) -> typing.Tuple[str, float]:
         inputs = self.processor(
             torch.tensor(audio_buffer),
             sampling_rate=AUDIO_SAMPLE_RATE,
@@ -47,15 +76,31 @@
             padding=True,
         )
         with torch.no_grad():
+            if hasattr(inputs, "attention_mask"):
+                attention_mask = inputs.attention_mask.to(self.device)
+            else:
+                attention_mask = None
             logits = self.model(
                 inputs.input_values.to(self.device),
-                attention_mask=inputs.attention_mask.to(self.device),
+                attention_mask=attention_mask,
             ).logits
         predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = self.processor.batch_decode(predicted_ids)[0]
-        timer_end = time.perf_counter()
         confidence = self.confidence_score(logits, predicted_ids)
-        return InferredTranscript(transcription, confidence, timer_end - timer_start)
+        return (self.processor.batch_decode(predicted_ids)[0], confidence)
+
+    def _pipeline_buffer_to_text(self, audio_buffer: numpy.typing.ArrayLike) -> str:
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model=self.model,
+            tokenizer=self.processor.tokenizer,
+            feature_extractor=self.processor.feature_extractor,
+            chunk_length_s=30,
+            batch_size=16,  # batch size for inference - set based on your device
+            torch_dtype=self.torch_dtype,
+            device=self.device,
+        )
+        result = pipe(audio_buffer)
+        return result["text"]
 
     def confidence_score(
         self, logits: torch.Tensor, predicted_ids: torch.Tensor
diff --git a/dnd_transcribe/main.py b/dnd_transcribe/main.py
index 6370ecf..ec06ab8 100644
--- a/dnd_transcribe/main.py
+++ b/dnd_transcribe/main.py
@@ -13,7 +13,13 @@ def main():
         logging.basicConfig(level=logging.ERROR, format=logging_format)
     else:
         logging.basicConfig(level=logging.INFO, format=logging_format)
-    inference = dnd_transcribe.inference.Inference(args.model)
-    (transcription, score, duration) = inference.file_to_text(args.audio_file)
-    print(transcription)
-    print(f"[Confidence: {score:.1%} in {duration:.2} seconds]")
+    inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
+    if args.audio_file is not None:
+        (transcription, score, duration) = inference.file_to_text(args.audio_file)
+        print(transcription)
+        if score is not None:
+            print(f"[Confidence: {score:.1%} in {duration:.2f} seconds]")
+        else:
+            print(f"[Confidence -unknown- in {duration:.2f} seconds]")
+    else:
+        print("Live transcription is a WIP")
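
For reference, a minimal sketch (not part of the patch itself) of how the new `Inference` API could be driven directly from Python, mirroring what `main.py` does above. It assumes the package is installed as described in the README and that `example.wav` is a 16 kHz WAV file; the model name is just one of the options shown above.

```
# Hypothetical usage sketch for the Inference class added in this patch.
# Assumes dnd_transcribe is installed and example.wav is a 16 kHz WAV file.
import dnd_transcribe.inference

inference = dnd_transcribe.inference.Inference(
    "facebook/wav2vec2-base-960h",  # any of the models listed in the README
    use_gpu=False,  # equivalent to passing --no-gpu on the command line
)
with open("example.wav", "rb") as audio_file:
    result = inference.file_to_text(audio_file)

print(result.transcript)
if result.confidence_score is not None:  # None on the whisper/pipeline path
    print(f"Confidence: {result.confidence_score:.1%}")
print(f"Took {result.processing_time_sec:.2f} seconds")
```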