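"""Speech-to-text inference on top of Hugging Face transformers.

Wav2Vec2 CTC checkpoints are decoded greedily with a per-transcript
confidence score; any other checkpoint falls back to the generic
automatic-speech-recognition pipeline.
"""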

import time
import typing

import numpy.typing
import soundfile
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    pipeline,
)

DEFAULT_MODEL = "facebook/wav2vec2-large-960h-lv60-self"
AUDIO_SAMPLE_RATE = 16_000


class InferredTranscript(typing.NamedTuple):
    transcript: str
    confidence_score: typing.Optional[float]  # None when the backend cannot score
    processing_time_sec: float


class Inference:
    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
        self.model_name = model_name
        cuda_available = use_gpu and torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        # fp16 is only safe on GPU; fall back to fp32 on CPU.
        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        if model_name.startswith("facebook/wav2vec2"):
            # CTC models are decoded manually so we can compute a confidence.
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        else:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.to(self.device)

    def is_wav2vec2(self) -> bool:
        return self.model_name.startswith("facebook/wav2vec2")

    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
        audio_input, samplerate = soundfile.read(audio_file)
        if samplerate != AUDIO_SAMPLE_RATE:
            raise ValueError(
                f"Unsupported sample rate {samplerate}; expected {AUDIO_SAMPLE_RATE}"
            )
        return self.buffer_to_text(audio_input)

    def buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> InferredTranscript:
        if len(audio_buffer) == 0:
            return InferredTranscript("", 1.0, 0.0)
        timer_start = time.perf_counter()
        if self.is_wav2vec2():
            transcription, confidence = self._wav2vec_buffer_to_text(audio_buffer)
        else:
            # The generic pipeline does not expose per-token scores.
            transcription = self._pipeline_buffer_to_text(audio_buffer)
            confidence = None
        timer_end = time.perf_counter()
        return InferredTranscript(transcription, confidence, timer_end - timer_start)

    def _wav2vec_buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> typing.Tuple[str, float]:
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=AUDIO_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            if hasattr(inputs, "attention_mask"):
                attention_mask = inputs.attention_mask.to(self.device)
            else:
                attention_mask = None
            logits = self.model(
                inputs.input_values.to(self.device),
                attention_mask=attention_mask,
            ).logits
        # Greedy CTC decoding: take the highest-scoring token at each frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        confidence = self.confidence_score(logits, predicted_ids)
        return (self.processor.batch_decode(predicted_ids)[0], confidence)

    def _pipeline_buffer_to_text(self, audio_buffer: numpy.typing.ArrayLike) -> str:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            chunk_length_s=30,  # split long audio into 30-second chunks
            batch_size=16,  # inference batch size; tune for your device
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        result = pipe(audio_buffer)
        return result["text"]

    def confidence_score(
        self, logits: torch.Tensor, predicted_ids: torch.Tensor
    ) -> float:
        scores = torch.nn.functional.softmax(logits, dim=-1)
        # Probability assigned to each greedily-chosen token at its frame.
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        # Ignore padding and word-delimiter frames when averaging.
        mask = torch.logical_and(
            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        return (character_scores.sum() / len(character_scores)).item()
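

if __name__ == "__main__":
    # Minimal usage sketch: transcribe a local 16 kHz mono WAV file.
    # "sample.wav" is a placeholder path, assumed for illustration only.
    engine = Inference(DEFAULT_MODEL)
    with open("sample.wav", "rb") as audio_file:
        result = engine.file_to_text(audio_file)
    print(f"Transcript: {result.transcript}")
    print(f"Confidence: {result.confidence_score}")
    print(f"Inference time: {result.processing_time_sec:.2f}s")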