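"""Speech-to-text inference helpers.

Loads either a wav2vec2 CTC checkpoint or a generic seq2seq ASR checkpoint
from Hugging Face transformers and exposes a single ``Inference`` class that
turns 16 kHz audio files or buffers into transcripts.
"""
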
import time
import typing

import numpy.typing
import soundfile
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    Wav2Vec2Model,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
)


DEFAULT_MODEL = "facebook/wav2vec2-large-960h-lv60-self"
AUDIO_SAMPLE_RATE = 16_000


class InferredTranscript(typing.NamedTuple):
    transcript: str
    confidence_score: typing.Optional[float]
    processing_time_sec: float


class Inference:
    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
        self.model_name = model_name
        cuda_available = use_gpu and torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        if model_name.startswith("facebook/wav2vec2"):
            # wav2vec2 checkpoints are CTC models; load them directly so the
            # raw logits stay available for confidence scoring.
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        else:
            # Any other checkpoint is treated as a seq2seq ASR model and run
            # through the transformers pipeline (see _pipeline_buffer_to_text).
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            self.processor = AutoProcessor.from_pretrained(model_name)
        self.model.to(self.device)

    def is_wav2vec2(self) -> bool:
        return self.model_name.startswith("facebook/wav2vec2")

    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
        audio_input, samplerate = soundfile.read(audio_file)
        if samplerate != AUDIO_SAMPLE_RATE:
            raise ValueError(
                f"Unsupported sample rate {samplerate}; expected {AUDIO_SAMPLE_RATE} Hz"
            )
        return self.buffer_to_text(audio_input)

    def buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> InferredTranscript:
        if len(audio_buffer) == 0:
            return InferredTranscript("", 1.0, 0.0)
        timer_start = time.perf_counter()
        if self.is_wav2vec2():
            transcription, confidence = self._wav2vec_buffer_to_text(audio_buffer)
        else:
            transcription = self._pipeline_buffer_to_text(audio_buffer)
            confidence = None
        timer_end = time.perf_counter()
        return InferredTranscript(transcription, confidence, timer_end - timer_start)

    def _wav2vec_buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> typing.Tuple[str, float]:
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=AUDIO_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            if hasattr(inputs, "attention_mask"):
                attention_mask = inputs.attention_mask.to(self.device)
            else:
                attention_mask = None
            logits = self.model(
                inputs.input_values.to(self.device),
                attention_mask=attention_mask,
            ).logits
        # Greedy CTC decoding: take the most likely token at every frame, then
        # let the processor collapse repeats and strip padding tokens.
        predicted_ids = torch.argmax(logits, dim=-1)
        confidence = self.confidence_score(logits, predicted_ids)
        return (self.processor.batch_decode(predicted_ids)[0], confidence)

    def _pipeline_buffer_to_text(self, audio_buffer: numpy.typing.ArrayLike) -> str:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            chunk_length_s=30,
            batch_size=16,  # batch size for inference - set based on your device
            torch_dtype=self.torch_dtype,
            device=self.device,
        )
        result = pipe(audio_buffer)
        return result["text"]

    def confidence_score(
        self, logits: torch.Tensor, predicted_ids: torch.Tensor
    ) -> float:
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        # Ignore padding and word-delimiter tokens so only real characters
        # contribute to the average confidence.
        mask = torch.logical_and(
            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        total_average = torch.sum(character_scores) / len(character_scores)
        return total_average.item()
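

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): transcribe a
    # local 16 kHz mono WAV file with the default wav2vec2 checkpoint. The
    # path "sample.wav" is a hypothetical placeholder.
    engine = Inference(DEFAULT_MODEL, use_gpu=True)
    with open("sample.wav", "rb") as audio_file:
        result = engine.file_to_text(audio_file)
    print(f"transcript: {result.transcript!r}")
    print(f"confidence: {result.confidence_score}")
    print(f"time (s):   {result.processing_time_sec:.2f}")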