Allow using more models like whisper

This commit is contained in:
David Kruger 2025-04-25 22:57:48 -07:00
parent 0d6e46d65a
commit 5f71e3da8a
4 changed files with 125 additions and 15 deletions

View File

@ -6,6 +6,58 @@ DND games and transcribe them.
Our initial approach is rather naive, using wav2vec 2.0 pre-trained models to
perform automatic speech recognition.
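For a sense of what that approach involves, here is a minimal sketch using the
Hugging Face `transformers` API directly (assuming a 16 kHz mono WAV file; the
model name matches the default used in the examples below):
```
# Minimal wav2vec 2.0 transcription sketch, independent of this project's CLI.
import soundfile
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "facebook/wav2vec2-large-960h-lv60-self"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

audio, sample_rate = soundfile.read("example.wav")  # expected to be 16 kHz mono
inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])
```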
## Usage
## Installation Instructions
TODO
### Optional: Install CUDA
If you would like to make use of an Nvidia GPU through CUDA, make sure to
install CUDA first.
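A quick way to confirm that PyTorch can see the GPU afterwards (a sanity check,
not part of this project):
```
# Prints True when CUDA is installed and a usable GPU is visible to PyTorch.
import torch
print(torch.cuda.is_available())
```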
### Install with `pip`
```
virtualenv .env
.env/bin/pip install .
```
## Usage Examples
The following examples use a WAV file with a 16 kHz sample rate.
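If your recording is at a different sample rate, one way to convert it first is
sketched below (this assumes `scipy` is installed; it is not a dependency of
this project):
```
# Downmix to mono and resample a WAV file to 16 kHz (scipy assumed available).
import soundfile
from scipy.signal import resample_poly

audio, sample_rate = soundfile.read("session.wav")
if audio.ndim == 2:  # stereo -> mono
    audio = audio.mean(axis=1)
if sample_rate != 16_000:
    audio = resample_poly(audio, up=16_000, down=sample_rate)
soundfile.write("example.wav", audio, 16_000)
```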
### Run against the large Facebook wav2vec2 model using just the CPU
This is the most versatile model and the most reliable at transcribing human
speech in multiple languages; however, it is large and relatively slow.
Note: this is the default model when `--model` is not provided.
```
./.env/bin/dnd_transcribe \
--audio-file example.wav \
--model "facebook/wav2vec2-large-960h-lv60-self" \
--no-gpu
```
### Run against the base wav2vec2 model
This is the base Facebook wav2vec2 model, which is smaller and less precise,
but small enough to run on a GPU with limited memory.
```
./.env/bin/dnd_transcribe \
--audio-file example.wav \
--model "facebook/wav2vec2-base-960h"
```
### Run against the small OpenAI whisper model
See more about the Whisper models: https://huggingface.co/openai/whisper-large-v3
This model, provided by OpenAI, is smaller and only supports English; its
smaller size lets it run on a GPU with limited memory.
```
./.env/bin/dnd_transcribe \
--audio-file example.wav \
--model "openai/whisper-small.en"
```
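For comparison, the same Whisper checkpoint can also be driven directly through
the `transformers` ASR pipeline, independent of this CLI (a minimal sketch;
decoding a file path this way relies on ffmpeg being installed):
```
# Minimal Whisper transcription via the transformers pipeline API.
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small.en")
result = pipe("example.wav")
print(result["text"])
```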

View File

@ -12,6 +12,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Only display errors"
    )
    parser.add_argument(
        "--no-gpu",
        dest="use_gpu",
        action="store_false",
        help="Disable using the GPU with CUDA",
    )
    parser.add_argument(
        "-m",
        "--model",
@ -20,7 +26,8 @@ def build_argument_parser() -> argparse.ArgumentParser:
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "audio_file",
        "-f",
        "--audio-file",
        type=argparse.FileType(mode="rb"),
        help="Audio file to process",
    )
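As a small, self-contained illustration (not the project's actual parser) of how
the new `--no-gpu` flag maps onto `args.use_gpu` through `dest` and
`action="store_false"`:
```
# Hypothetical stand-alone example of the --no-gpu / use_gpu pattern above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-gpu",
    dest="use_gpu",
    action="store_false",
    help="Disable using the GPU with CUDA",
)

print(parser.parse_args([]).use_gpu)            # True: GPU enabled by default
print(parser.parse_args(["--no-gpu"]).use_gpu)  # False: flag disables the GPU
```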

View File

@ -4,6 +4,9 @@ import time
import torch
import typing
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    Wav2Vec2Model,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
@ -17,17 +20,32 @@ AUDIO_SAMPLE_RATE = 16_000
class InferredTranscript(typing.NamedTuple):
    transcript: str
    confidence_score: float
    confidence_score: typing.Optional[float]
    processing_time_sec: float
class Inference:
    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        cuda_available = use_gpu and torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        if model_name.startswith("facebook/wav2vec2"):
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        else:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.to(self.device)
    def is_wav2vec2(self) -> bool:
        return self.model_name.startswith("facebook/wav2vec2")
    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
        audio_input, samplerate = soundfile.read(audio_file)
        if samplerate != AUDIO_SAMPLE_RATE:
@ -40,6 +58,17 @@ class Inference:
        if len(audio_buffer) == 0:
            InferredTranscript("", 1, 0)
        timer_start = time.perf_counter()
        if self.is_wav2vec2():
            transcription, confidence = self._wav2vec_buffer_to_text(audio_buffer)
        else:
            transcription = self._pipeline_buffer_to_text(audio_buffer)
            confidence = None
        timer_end = time.perf_counter()
        return InferredTranscript(transcription, confidence, timer_end - timer_start)
    def _wav2vec_buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> typing.Tuple[str, float]:
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=AUDIO_SAMPLE_RATE,
@ -47,15 +76,31 @@ class Inference:
            padding=True,
        )
        with torch.no_grad():
            if hasattr(inputs, "attention_mask"):
                attention_mask = inputs.attention_mask.to(self.device)
            else:
                attention_mask = None
            logits = self.model(
                inputs.input_values.to(self.device),
                attention_mask=inputs.attention_mask.to(self.device),
                attention_mask=attention_mask,
            ).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        timer_end = time.perf_counter()
        confidence = self.confidence_score(logits, predicted_ids)
        return InferredTranscript(transcription, confidence, timer_end - timer_start)
        return (self.processor.batch_decode(predicted_ids)[0], confidence)
    def _pipeline_buffer_to_text(self, audio_buffer: numpy.typing.ArrayLike) -> str:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            chunk_length_s=30,
            batch_size=16,  # batch size for inference - set based on your device
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        result = pipe(audio_buffer)
        return result["text"]
    def confidence_score(
        self, logits: torch.Tensor, predicted_ids: torch.Tensor
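The body of `confidence_score` is cut off by the hunk above. Purely as an
illustration (not this repository's actual implementation), a score like the
one printed by the CLI could be derived from the CTC logits as the average
softmax probability of the predicted tokens:
```
# Hypothetical confidence sketch: mean softmax probability of predicted tokens.
import torch

def confidence_score(logits: torch.Tensor, predicted_ids: torch.Tensor) -> float:
    probs = torch.softmax(logits, dim=-1)
    picked = probs.gather(-1, predicted_ids.unsqueeze(-1)).squeeze(-1)
    return picked.mean().item()
```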

View File

@ -13,7 +13,13 @@ def main():
        logging.basicConfig(level=logging.ERROR, format=logging_format)
    else:
        logging.basicConfig(level=logging.INFO, format=logging_format)
    inference = dnd_transcribe.inference.Inference(args.model)
    inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
    if args.audio_file is not None:
        (transcription, score, duration) = inference.file_to_text(args.audio_file)
        print(transcription)
        print(f"[Confidence: {score:.1%} in {duration:.2} seconds]")
        if score is not None:
            print(f"[Confidence: {score:.1%} in {duration} seconds]")
        else:
            print(f"[Confidence -unknown- in {duration} seconds]")
    else:
        print("Live transcription is a WIP")