Allow using more models like whisper
parent 0d6e46d65a
commit 5f71e3da8a

README.md
@@ -6,6 +6,58 @@ DND games and transcribe them.

Our initial approach is rather naive, using wav2vec 2.0 pre-trained models to
perform automated speech recognition.
## Usage

## Installation Instructions

TODO
### Optional: Install CUDA

If you would like to make use of an Nvidia GPU through CUDA, make sure to
install CUDA first.
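A quick way to confirm that PyTorch can actually see the GPU (this assumes the install step below has already been run, so that `torch` is available in the virtualenv):

```
./.env/bin/python -c "import torch; print(torch.cuda.is_available())"
```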
### Install with `pip`

```
virtualenv .env
.env/bin/pip install .
```
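If the install succeeded, the `dnd_transcribe` console script used in the examples below should be available and respond to `--help`:

```
./.env/bin/dnd_transcribe --help
```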
## Usage Examples

The following examples use a WAV file with a 16 kHz sample rate.
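If you are unsure of a recording's sample rate, you can check it with the `soundfile` package this project already uses; converting to 16 kHz is left to an external tool such as `ffmpeg` (not installed by this project). A sketch:

```
./.env/bin/python -c "import soundfile; print(soundfile.info('example.wav').samplerate)"
ffmpeg -i example.wav -ar 16000 -ac 1 example-16k.wav
```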
### Run against the large Facebook wav2vec2 model using just the CPU

This is the most versatile model and the most reliable at transcribing human
speech in multiple languages. However, it is large and relatively slow.

Note: this is the default model when `--model` is not provided.

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "facebook/wav2vec2-large-960h-lv60-self" \
    --no-gpu
```
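Since this is the default, the same transcription can also be run without the `--model` flag:

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --no-gpu
```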
### Run against the base wav2vec2 model

This is the base Facebook wav2vec2 model. It is smaller and less precise, but
small enough to run on a GPU with limited memory.

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "facebook/wav2vec2-base-960h"
```
### Run against the small OpenAI whisper model

See more about the whisper models: https://huggingface.co/openai/whisper-large-v3

This is a smaller, English-only model provided by OpenAI; because it is
smaller, it is able to run on a GPU with limited memory.

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "openai/whisper-small.en"
```
@@ -12,6 +12,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Only display errors"
    )
    parser.add_argument(
        "--no-gpu",
        dest="use_gpu",
        action="store_false",
        help="Disable using the GPU with CUDA",
    )
    parser.add_argument(
        "-m",
        "--model",
@@ -20,7 +26,8 @@ def build_argument_parser() -> argparse.ArgumentParser:
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
-        "audio_file",
+        "-f",
+        "--audio-file",
        type=argparse.FileType(mode="rb"),
        help="Audio file to process",
    )
@@ -4,6 +4,9 @@ import time
import torch
import typing
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    Wav2Vec2Model,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
@@ -17,17 +20,32 @@ AUDIO_SAMPLE_RATE = 16_000


class InferredTranscript(typing.NamedTuple):
    transcript: str
-    confidence_score: float
+    confidence_score: typing.Optional[float]
    processing_time_sec: float


class Inference:
    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
-        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
+        self.model_name = model_name
+        cuda_available = use_gpu and torch.cuda.is_available()
+        self.device = "cuda" if cuda_available else "cpu"
+        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        if model_name.startswith("facebook/wav2vec2"):
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        else:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.to(self.device)

    def is_wav2vec2(self) -> bool:
        return self.model_name.startswith("facebook/wav2vec2")

    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
        audio_input, samplerate = soundfile.read(audio_file)
        if samplerate != AUDIO_SAMPLE_RATE:
@@ -40,6 +58,17 @@ class Inference:
        if len(audio_buffer) == 0:
            return InferredTranscript("", 1, 0)
        timer_start = time.perf_counter()
        if self.is_wav2vec2():
            transcription, confidence = self._wav2vec_buffer_to_text(audio_buffer)
        else:
            transcription = self._pipeline_buffer_to_text(audio_buffer)
            confidence = None
        timer_end = time.perf_counter()
        return InferredTranscript(transcription, confidence, timer_end - timer_start)

    def _wav2vec_buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> typing.Tuple[str, float]:
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=AUDIO_SAMPLE_RATE,
@@ -47,15 +76,31 @@
            padding=True,
        )
        with torch.no_grad():
            if hasattr(inputs, "attention_mask"):
                attention_mask = inputs.attention_mask.to(self.device)
            else:
                attention_mask = None
            logits = self.model(
                inputs.input_values.to(self.device),
-                attention_mask=inputs.attention_mask.to(self.device),
+                attention_mask=attention_mask,
            ).logits
        predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = self.processor.batch_decode(predicted_ids)[0]
-        timer_end = time.perf_counter()
        confidence = self.confidence_score(logits, predicted_ids)
-        return InferredTranscript(transcription, confidence, timer_end - timer_start)
+        return (self.processor.batch_decode(predicted_ids)[0], confidence)

    def _pipeline_buffer_to_text(self, audio_buffer: numpy.typing.ArrayLike) -> str:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            chunk_length_s=30,
            batch_size=16,  # batch size for inference - set based on your device
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        result = pipe(audio_buffer)
        return result["text"]

    def confidence_score(
        self, logits: torch.Tensor, predicted_ids: torch.Tensor
@@ -13,7 +13,13 @@ def main():
        logging.basicConfig(level=logging.ERROR, format=logging_format)
    else:
        logging.basicConfig(level=logging.INFO, format=logging_format)
-    inference = dnd_transcribe.inference.Inference(args.model)
+    inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
    if args.audio_file is not None:
        (transcription, score, duration) = inference.file_to_text(args.audio_file)
        print(transcription)
-        print(f"[Confidence: {score:.1%} in {duration:.2} seconds]")
+        if score is not None:
+            print(f"[Confidence: {score:.1%} in {duration} seconds]")
+        else:
+            print(f"[Confidence -unknown- in {duration} seconds]")
    else:
        print("Live transcription is a WIP")