Allow using more models like whisper
parent 0d6e46d65a
commit 5f71e3da8a

README.md (56 changed lines)
@@ -6,6 +6,58 @@ DND games and transcribe them.

Our initial approach is rather naive, using wav2vec 2.0 pre-trained models to
perform automated speech recognition.

## Installation Instructions

### Optional: Install CUDA

If you would like to make use of an Nvidia GPU through CUDA, make sure to
install CUDA first.
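
A quick way to check the driver side of a CUDA setup is `nvidia-smi`, which
ships with the NVIDIA driver rather than with this project; it should list your
GPU:

```
nvidia-smi
```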

### Install with `pip`

```
virtualenv .env
.env/bin/pip install .
```
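
As a quick post-install sanity check (not part of the original instructions),
you can run the installed entry point's built-in argparse help and ask PyTorch
whether it can see a CUDA device from inside the virtualenv:

```
.env/bin/dnd_transcribe --help
.env/bin/python -c "import torch; print(torch.cuda.is_available())"
```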

## Usage Examples

The following examples use a WAV file with a 16 kHz sample rate.
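
If your recordings are not already 16 kHz WAV files, one common way to convert
them (to 16 kHz and, to be safe, mono) is with `ffmpeg`, which is not a
dependency of this project; the input filename below is only an example:

```
ffmpeg -i session-recording.mp3 -ar 16000 -ac 1 example.wav
```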

### Run against the large Facebook wav2vec2 model using just the CPU

This is the most versatile model and the most reliable at transcribing human
speech in multiple languages. However, it is large and relatively slow.

Note: this is the default model when `--model` is not provided.

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "facebook/wav2vec2-large-960h-lv60-self" \
    --no-gpu
```
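
The `--no-gpu` flag above forces CPU inference; GPU use is the default when
CUDA is available, so the same large model should presumably also run on CUDA
by omitting that flag, assuming the GPU has enough memory for it:

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "facebook/wav2vec2-large-960h-lv60-self"
```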

### Run against the base wav2vec2 model

This is the base Facebook wav2vec2 model, which is smaller and less precise,
but small enough to run on a GPU with limited memory.

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "facebook/wav2vec2-base-960h"
```

### Run against the small OpenAI whisper model

See more about the whisper models at https://huggingface.co/openai/whisper-large-v3

This model, provided by OpenAI, is smaller and only supports English, but
because it is smaller it is able to run on a GPU with limited memory.

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "openai/whisper-small.en"
```
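
Any model name that does not start with `facebook/wav2vec2` is loaded through
the generic `AutoModelForSpeechSeq2Seq`/`AutoProcessor` path added in this
commit, so other whisper checkpoints, such as the multilingual whisper-large-v3
linked above, should presumably work the same way, though this is untested here
and needs considerably more memory:

```
./.env/bin/dnd_transcribe \
    --audio-file example.wav \
    --model "openai/whisper-large-v3"
```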

@@ -12,6 +12,12 @@ def build_argument_parser() -> argparse.ArgumentParser:

    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Only display errors"
    )
    parser.add_argument(
        "--no-gpu",
        dest="use_gpu",
        action="store_false",
        help="Disable using the GPU with CUDA",
    )
    parser.add_argument(
        "-m",
        "--model",

@@ -20,7 +26,8 @@ def build_argument_parser() -> argparse.ArgumentParser:

        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "-f",
        "--audio-file",
        type=argparse.FileType(mode="rb"),
        help="Audio file to process",
    )

@@ -4,6 +4,9 @@ import time

import torch
import typing
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    Wav2Vec2Model,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,

@@ -17,17 +20,32 @@ AUDIO_SAMPLE_RATE = 16_000

class InferredTranscript(typing.NamedTuple):
    transcript: str
    confidence_score: typing.Optional[float]
    processing_time_sec: float


class Inference:
    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
        self.model_name = model_name
        cuda_available = use_gpu and torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        if model_name.startswith("facebook/wav2vec2"):
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        else:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.to(self.device)

    def is_wav2vec2(self) -> bool:
        return self.model_name.startswith("facebook/wav2vec2")

    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
        audio_input, samplerate = soundfile.read(audio_file)
        if samplerate != AUDIO_SAMPLE_RATE:

@@ -40,6 +58,17 @@ class Inference:

        if len(audio_buffer) == 0:
            return InferredTranscript("", 1, 0)
        timer_start = time.perf_counter()
        if self.is_wav2vec2():
            transcription, confidence = self._wav2vec_buffer_to_text(audio_buffer)
        else:
            transcription = self._pipeline_buffer_to_text(audio_buffer)
            confidence = None
        timer_end = time.perf_counter()
        return InferredTranscript(transcription, confidence, timer_end - timer_start)

    def _wav2vec_buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> typing.Tuple[str, float]:
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=AUDIO_SAMPLE_RATE,

@@ -47,15 +76,31 @@ class Inference:

            padding=True,
        )
        with torch.no_grad():
            if hasattr(inputs, "attention_mask"):
                attention_mask = inputs.attention_mask.to(self.device)
            else:
                attention_mask = None
            logits = self.model(
                inputs.input_values.to(self.device),
                attention_mask=attention_mask,
            ).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        confidence = self.confidence_score(logits, predicted_ids)
        return (self.processor.batch_decode(predicted_ids)[0], confidence)

    def _pipeline_buffer_to_text(self, audio_buffer: numpy.typing.ArrayLike) -> str:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            chunk_length_s=30,
            batch_size=16,  # batch size for inference - set based on your device
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        result = pipe(audio_buffer)
        return result["text"]

    def confidence_score(
        self, logits: torch.Tensor, predicted_ids: torch.Tensor

@@ -13,7 +13,13 @@ def main():

        logging.basicConfig(level=logging.ERROR, format=logging_format)
    else:
        logging.basicConfig(level=logging.INFO, format=logging_format)
    inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
    if args.audio_file is not None:
        (transcription, score, duration) = inference.file_to_text(args.audio_file)
        print(transcription)
        if score is not None:
            print(f"[Confidence: {score:.1%} in {duration} seconds]")
        else:
            print(f"[Confidence -unknown- in {duration} seconds]")
    else:
        print("Live transcription is a WIP")