Very basic utility, will convert a 16khz wave file to text using wav2vec
This commit is contained in:
parent
936229bd6e
commit
0d6e46d65a
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
from dnd_transcribe.inference import DEFAULT_MODEL
|
||||
|
||||
|
||||
def build_argument_parser() -> argparse.ArgumentParser:
|
||||
@ -11,9 +12,16 @@ def build_argument_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"-q", "--quiet", action="store_true", help="Only display errors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
type=str,
|
||||
help="ASR model",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio_file",
|
||||
type=argparse.FileType(mode="r"),
|
||||
type=argparse.FileType(mode="rb"),
|
||||
help="Audio file to process",
|
||||
)
|
||||
return parser
|
||||
|
71
dnd_transcribe/inference.py
Normal file
71
dnd_transcribe/inference.py
Normal file
@ -0,0 +1,71 @@
|
||||
import soundfile
|
||||
import numpy.typing
|
||||
import time
|
||||
import torch
|
||||
import typing
|
||||
from transformers import (
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Processor,
|
||||
Wav2Vec2ProcessorWithLM,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_MODEL = "facebook/wav2vec2-large-960h-lv60-self"
|
||||
AUDIO_SAMPLE_RATE = 16_000
|
||||
|
||||
|
||||
class InferredTranscript(typing.NamedTuple):
|
||||
transcript: str
|
||||
confidence_score: float
|
||||
processing_time_sec: float
|
||||
|
||||
|
||||
class Inference:
|
||||
def __init__(self, model_name: str, use_gpu: bool = True) -> None:
|
||||
self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
||||
self.processor = Wav2Vec2Processor.from_pretrained(model_name)
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
||||
self.model.to(self.device)
|
||||
|
||||
def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
|
||||
audio_input, samplerate = soundfile.read(audio_file)
|
||||
if samplerate != AUDIO_SAMPLE_RATE:
|
||||
raise Exception(f"Unsupported sample rate {samplerate}")
|
||||
return self.buffer_to_text(audio_input)
|
||||
|
||||
def buffer_to_text(
|
||||
self, audio_buffer: numpy.typing.ArrayLike
|
||||
) -> InferredTranscript:
|
||||
if len(audio_buffer) == 0:
|
||||
InferredTranscript("", 1, 0)
|
||||
timer_start = time.perf_counter()
|
||||
inputs = self.processor(
|
||||
torch.tensor(audio_buffer),
|
||||
sampling_rate=AUDIO_SAMPLE_RATE,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
with torch.no_grad():
|
||||
logits = self.model(
|
||||
inputs.input_values.to(self.device),
|
||||
attention_mask=inputs.attention_mask.to(self.device),
|
||||
).logits
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
transcription = self.processor.batch_decode(predicted_ids)[0]
|
||||
timer_end = time.perf_counter()
|
||||
confidence = self.confidence_score(logits, predicted_ids)
|
||||
return InferredTranscript(transcription, confidence, timer_end - timer_start)
|
||||
|
||||
def confidence_score(
|
||||
self, logits: torch.Tensor, predicted_ids: torch.Tensor
|
||||
) -> float:
|
||||
scores = torch.nn.functional.softmax(logits, dim=-1)
|
||||
pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
|
||||
mask = torch.logical_and(
|
||||
predicted_ids.not_equal(self.processor.tokenizer.word_delimiter_token_id),
|
||||
predicted_ids.not_equal(self.processor.tokenizer.pad_token_id),
|
||||
)
|
||||
character_scores = pred_scores.masked_select(mask)
|
||||
total_average = torch.sum(character_scores) / len(character_scores)
|
||||
return total_average
|
@ -1,4 +1,5 @@
|
||||
import dnd_transcribe.argparse
|
||||
import dnd_transcribe.inference
|
||||
import logging
|
||||
|
||||
|
||||
@ -12,4 +13,7 @@ def main():
|
||||
logging.basicConfig(level=logging.ERROR, format=logging_format)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO, format=logging_format)
|
||||
print("WIP")
|
||||
inference = dnd_transcribe.inference.Inference(args.model)
|
||||
(transcription, score, duration) = inference.file_to_text(args.audio_file)
|
||||
print(transcription)
|
||||
print(f"[Confidence: {score:.1%} in {duration:.2} seconds]")
|
||||
|
Loading…
x
Reference in New Issue
Block a user