diff --git a/dnd_transcribe/argparse.py b/dnd_transcribe/argparse.py
index 571d1d1..4f14af0 100644
--- a/dnd_transcribe/argparse.py
+++ b/dnd_transcribe/argparse.py
@@ -1,4 +1,5 @@
 import argparse
+from dnd_transcribe.inference import DEFAULT_MODEL
 
 
 def build_argument_parser() -> argparse.ArgumentParser:
@@ -11,9 +12,16 @@ def build_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-q", "--quiet", action="store_true", help="Only display errors"
     )
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        help="ASR model",
+        default=DEFAULT_MODEL,
+    )
     parser.add_argument(
         "audio_file",
-        type=argparse.FileType(mode="r"),
+        type=argparse.FileType(mode="rb"),
         help="Audio file to process",
     )
     return parser
diff --git a/dnd_transcribe/inference.py b/dnd_transcribe/inference.py
new file mode 100644
index 0000000..f83a696
--- /dev/null
+++ b/dnd_transcribe/inference.py
@@ -0,0 +1,71 @@
+import soundfile
+import numpy.typing
+import time
+import torch
+import typing
+from transformers import (
+    Wav2Vec2Model,
+    Wav2Vec2ForCTC,
+    Wav2Vec2Processor,
+    Wav2Vec2ProcessorWithLM,
+)
+
+
+DEFAULT_MODEL = "facebook/wav2vec2-large-960h-lv60-self"
+AUDIO_SAMPLE_RATE = 16_000
+
+
+class InferredTranscript(typing.NamedTuple):
+    transcript: str
+    confidence_score: float
+    processing_time_sec: float
+
+
+class Inference:
+    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
+        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
+        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
+        self.model.to(self.device)
+
+    def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
+        audio_input, samplerate = soundfile.read(audio_file)
+        if samplerate != AUDIO_SAMPLE_RATE:
+            raise ValueError(f"Unsupported sample rate {samplerate}")
+        return self.buffer_to_text(audio_input)
+
+    def buffer_to_text(
+        self, audio_buffer: numpy.typing.ArrayLike
+    ) -> InferredTranscript:
+        if len(audio_buffer) == 0:
+            return InferredTranscript("", 1.0, 0.0)
+        timer_start = time.perf_counter()
+        inputs = self.processor(
+            torch.tensor(audio_buffer),
+            sampling_rate=AUDIO_SAMPLE_RATE,
+            return_tensors="pt",
+            padding=True,
+        )
+        with torch.no_grad():
+            logits = self.model(
+                inputs.input_values.to(self.device),
+                attention_mask=inputs.attention_mask.to(self.device),
+            ).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = self.processor.batch_decode(predicted_ids)[0]
+        timer_end = time.perf_counter()
+        confidence = self.confidence_score(logits, predicted_ids)
+        return InferredTranscript(transcription, confidence, timer_end - timer_start)
+
+    def confidence_score(
+        self, logits: torch.Tensor, predicted_ids: torch.Tensor
+    ) -> float:
+        scores = torch.nn.functional.softmax(logits, dim=-1)
+        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
+        mask = torch.logical_and(
+            predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
+            predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
+        )
+        character_scores = pred_scores.masked_select(mask)
+        total_average = torch.sum(character_scores) / len(character_scores)
+        return total_average.item()
diff --git a/dnd_transcribe/main.py b/dnd_transcribe/main.py
index 36d2a56..6370ecf 100644
--- a/dnd_transcribe/main.py
+++ b/dnd_transcribe/main.py
@@ -1,4 +1,5 @@
 import dnd_transcribe.argparse
+import dnd_transcribe.inference
 import logging
 
 
@@ -12,4 +13,7 @@ def main():
         logging.basicConfig(level=logging.ERROR, format=logging_format)
     else:
         logging.basicConfig(level=logging.INFO, format=logging_format)
-    print("WIP")
+    inference = dnd_transcribe.inference.Inference(args.model)
+    (transcription, score, duration) = inference.file_to_text(args.audio_file)
+    print(transcription)
+    print(f"[Confidence: {score:.1%} in {duration:.2f} seconds]")
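
For reviewers, a minimal usage sketch of the API this patch introduces, outside the CLI path. It assumes a 16 kHz mono WAV file named "session.wav" (hypothetical); the class, constant, and return tuple come from the new dnd_transcribe/inference.py above.

    # Hypothetical usage sketch, not part of the patch.
    import dnd_transcribe.inference

    # Load the default wav2vec2 checkpoint; falls back to CPU when no GPU is available.
    inference = dnd_transcribe.inference.Inference(
        dnd_transcribe.inference.DEFAULT_MODEL, use_gpu=False
    )

    # file_to_text expects a binary file object, matching the "rb" FileType change.
    with open("session.wav", "rb") as audio_file:
        transcript, confidence, seconds = inference.file_to_text(audio_file)

    print(transcript)
    print(f"confidence={confidence:.1%} elapsed={seconds:.2f}s")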