Very basic utility that converts a 16 kHz WAV file to text using wav2vec

David Kruger 2025-04-22 15:27:59 -07:00
parent 936229bd6e
commit 0d6e46d65a
3 changed files with 85 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
 import argparse
+from dnd_transcribe.inference import DEFAULT_MODEL


 def build_argument_parser() -> argparse.ArgumentParser:
@@ -11,9 +12,16 @@ def build_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-q", "--quiet", action="store_true", help="Only display errors"
     )
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        help="ASR model",
+        default=DEFAULT_MODEL,
+    )
     parser.add_argument(
         "audio_file",
-        type=argparse.FileType(mode="r"),
+        type=argparse.FileType(mode="rb"),
         help="Audio file to process",
     )
     return parser
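
A quick sanity check of the new flag (not part of the commit; it assumes the parser module is importable as dnd_transcribe.argparse, which the imports in main below suggest):

import tempfile

from dnd_transcribe.argparse import build_argument_parser

# argparse.FileType opens the positional argument immediately, so point it
# at a file that actually exists.
with tempfile.NamedTemporaryFile(suffix=".wav") as tmp:
    args = build_argument_parser().parse_args([tmp.name])
    print(args.model)            # falls back to DEFAULT_MODEL
    print(args.audio_file.mode)  # "rb" after this change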

View File

@@ -0,0 +1,71 @@
import soundfile
import numpy.typing
import time
import torch
import typing

from transformers import (
    Wav2Vec2Model,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
)

DEFAULT_MODEL = "facebook/wav2vec2-large-960h-lv60-self"
AUDIO_SAMPLE_RATE = 16_000


class InferredTranscript(typing.NamedTuple):
    transcript: str
    confidence_score: float
    processing_time_sec: float


class Inference:
    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        self.model.to(self.device)

    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
        audio_input, samplerate = soundfile.read(audio_file)
        if samplerate != AUDIO_SAMPLE_RATE:
            raise ValueError(f"Unsupported sample rate {samplerate}")
        return self.buffer_to_text(audio_input)

    def buffer_to_text(
        self, audio_buffer: numpy.typing.ArrayLike
    ) -> InferredTranscript:
        if len(audio_buffer) == 0:
            return InferredTranscript("", 1.0, 0.0)
        timer_start = time.perf_counter()
        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=AUDIO_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = self.model(
                inputs.input_values.to(self.device),
                attention_mask=inputs.attention_mask.to(self.device),
            ).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        timer_end = time.perf_counter()
        confidence = self.confidence_score(logits, predicted_ids)
        return InferredTranscript(transcription, confidence, timer_end - timer_start)

    def confidence_score(
        self, logits: torch.Tensor, predicted_ids: torch.Tensor
    ) -> float:
        # Mean softmax probability of the emitted characters, ignoring
        # padding and word-delimiter tokens.
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        mask = torch.logical_and(
            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        total_average = torch.sum(character_scores) / len(character_scores)
        return float(total_average)
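
For reference, a minimal usage sketch of the class above (not part of the commit; session.wav is a placeholder and must already be 16 kHz mono):

from dnd_transcribe.inference import DEFAULT_MODEL, Inference

inference = Inference(DEFAULT_MODEL, use_gpu=False)  # CPU keeps the sketch portable
with open("session.wav", "rb") as audio_file:
    transcript, confidence, seconds = inference.file_to_text(audio_file)
print(transcript)
print(f"[Confidence: {confidence:.1%} in {seconds:.2f} seconds]")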

View File

@@ -1,4 +1,5 @@
 import dnd_transcribe.argparse
+import dnd_transcribe.inference
 import logging

@@ -12,4 +13,7 @@ def main():
         logging.basicConfig(level=logging.ERROR, format=logging_format)
     else:
         logging.basicConfig(level=logging.INFO, format=logging_format)
-    print("WIP")
+    inference = dnd_transcribe.inference.Inference(args.model)
+    (transcription, score, duration) = inference.file_to_text(args.audio_file)
+    print(transcription)
+    print(f"[Confidence: {score:.1%} in {duration:.2f} seconds]")