Merge branch 'stream'
This commit is contained in:
commit
69a9278536
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
from dnd_transcribe.inference import DEFAULT_MODEL
|
from dnd_transcribe.inference import DEFAULT_MODEL
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +19,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_false",
|
action="store_false",
|
||||||
help="Disable using the GPU with CUDA",
|
help="Disable using the GPU with CUDA",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--block-len",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="Block length in seconds of audio sent when streaming or to whisper",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-m",
|
"-m",
|
||||||
"--model",
|
"--model",
|
||||||
@ -29,6 +36,19 @@ def build_argument_parser() -> argparse.ArgumentParser:
|
|||||||
"-f",
|
"-f",
|
||||||
"--audio-file",
|
"--audio-file",
|
||||||
type=argparse.FileType(mode="rb"),
|
type=argparse.FileType(mode="rb"),
|
||||||
help="Audio file to process",
|
help="Audio file to process, for long audo see --stream-audio-file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--stream-audio-file",
|
||||||
|
type=valid_file_path,
|
||||||
|
help="Audio file to process by streaming",
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def valid_file_path(path: str) -> str:
|
||||||
|
path = os.path.realpath(path)
|
||||||
|
if os.path.isfile(path):
|
||||||
|
return path
|
||||||
|
raise argparse.ArgumentTypeError("{} is not a valid file".format(path))
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import soundfile
|
import librosa
|
||||||
import numpy.typing
|
import numpy.typing
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
@ -25,8 +25,11 @@ class InferredTranscript(typing.NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class Inference:
|
class Inference:
|
||||||
def __init__(self, model_name: str, use_gpu: bool = True) -> None:
|
def __init__(
|
||||||
|
self, model_name: str, block_len: int = 20, use_gpu: bool = True
|
||||||
|
) -> None:
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self.block_len = block_len
|
||||||
cuda_available = use_gpu and torch.cuda.is_available()
|
cuda_available = use_gpu and torch.cuda.is_available()
|
||||||
self.device = "cuda" if cuda_available else "cpu"
|
self.device = "cuda" if cuda_available else "cpu"
|
||||||
self.torch_dtype = torch.float16 if cuda_available else torch.float32
|
self.torch_dtype = torch.float16 if cuda_available else torch.float32
|
||||||
@ -46,10 +49,42 @@ class Inference:
|
|||||||
def is_wav2vec2(self) -> bool:
|
def is_wav2vec2(self) -> bool:
|
||||||
return self.model_name.startswith("facebook/wav2vec2")
|
return self.model_name.startswith("facebook/wav2vec2")
|
||||||
|
|
||||||
def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
|
def stream_file_to_text(self, audio_file_path: str) -> InferredTranscript:
|
||||||
audio_input, samplerate = soundfile.read(audio_file)
|
samplerate = librosa.get_samplerate(audio_file_path)
|
||||||
if samplerate != AUDIO_SAMPLE_RATE:
|
if samplerate != AUDIO_SAMPLE_RATE:
|
||||||
raise Exception(f"Unsupported sample rate {samplerate}")
|
raise Exception(f"Unsupported sample rate {samplerate}")
|
||||||
|
stream = librosa.stream(
|
||||||
|
audio_file_path,
|
||||||
|
block_length=self.block_len,
|
||||||
|
frame_length=AUDIO_SAMPLE_RATE,
|
||||||
|
hop_length=AUDIO_SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
transcript = ""
|
||||||
|
confidence = None
|
||||||
|
processing_time = 0.0
|
||||||
|
for block in stream:
|
||||||
|
if len(block.shape) > 1:
|
||||||
|
block = speech[:, 0] + speech[:, 1]
|
||||||
|
try:
|
||||||
|
block_inference = self.buffer_to_text(block)
|
||||||
|
transcript += block_inference.transcript + " "
|
||||||
|
processing_time += block_inference.processing_time_sec
|
||||||
|
if block_inference.confidence_score is not None:
|
||||||
|
if confidence is None:
|
||||||
|
confidence = block_inference.confidence_score
|
||||||
|
else:
|
||||||
|
confidence *= block_inference.confidence_score
|
||||||
|
except torch.OutOfMemoryError as e:
|
||||||
|
print(e)
|
||||||
|
break
|
||||||
|
return InferredTranscript(transcript.strip(), confidence, processing_time)
|
||||||
|
|
||||||
|
def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
|
||||||
|
audio_input, samplerate = librosa.load(audio_file)
|
||||||
|
if samplerate != AUDIO_SAMPLE_RATE:
|
||||||
|
audio_input = librosa.resample(
|
||||||
|
audio_input, orig_sr=samplerate, target_sr=AUDIO_SAMPLE_RATE
|
||||||
|
)
|
||||||
return self.buffer_to_text(audio_input)
|
return self.buffer_to_text(audio_input)
|
||||||
|
|
||||||
def buffer_to_text(
|
def buffer_to_text(
|
||||||
@ -94,7 +129,7 @@ class Inference:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.processor.tokenizer,
|
tokenizer=self.processor.tokenizer,
|
||||||
feature_extractor=self.processor.feature_extractor,
|
feature_extractor=self.processor.feature_extractor,
|
||||||
chunk_length_s=30,
|
chunk_length_s=self.block_len,
|
||||||
batch_size=16, # batch size for inference - set based on your device
|
batch_size=16, # batch size for inference - set based on your device
|
||||||
torch_dtype=self.torch_dtype,
|
torch_dtype=self.torch_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
@ -13,13 +13,27 @@ def main():
|
|||||||
logging.basicConfig(level=logging.ERROR, format=logging_format)
|
logging.basicConfig(level=logging.ERROR, format=logging_format)
|
||||||
else:
|
else:
|
||||||
logging.basicConfig(level=logging.INFO, format=logging_format)
|
logging.basicConfig(level=logging.INFO, format=logging_format)
|
||||||
inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
|
inference = dnd_transcribe.inference.Inference(
|
||||||
|
args.model,
|
||||||
|
block_len=args.block_len,
|
||||||
|
use_gpu=args.use_gpu,
|
||||||
|
)
|
||||||
if args.audio_file is not None:
|
if args.audio_file is not None:
|
||||||
(transcription, score, duration) = inference.file_to_text(args.audio_file)
|
_print_inferred_transcript(inference.file_to_text(args.audio_file))
|
||||||
print(transcription)
|
elif args.stream_audio_file is not None:
|
||||||
if score is not None:
|
_print_inferred_transcript(
|
||||||
print(f"[Confidence: {score:.1%} in {duration} seconds]")
|
inference.stream_file_to_text(args.stream_audio_file)
|
||||||
else:
|
)
|
||||||
print(f"[Confidence -unknown- in {duration} seconds]")
|
|
||||||
else:
|
else:
|
||||||
print("Live transcription is a WIP")
|
print("Live transcription is a WIP")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_inferred_transcript(
|
||||||
|
transcript: dnd_transcribe.inference.InferredTranscript,
|
||||||
|
) -> None:
|
||||||
|
(transcription, score, duration) = transcript
|
||||||
|
print(transcription)
|
||||||
|
if score is not None:
|
||||||
|
print(f"[Confidence: {score:.1%} in {duration} seconds]")
|
||||||
|
else:
|
||||||
|
print(f"[Confidence -unknown- in {duration} seconds]")
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
soundfile>=0.13.1
|
librosa>=0.11.0
|
||||||
|
numpy>=2.2.5
|
||||||
torch>=2.6.0
|
torch>=2.6.0
|
||||||
transformers>=4.51.3
|
transformers>=4.51.3
|
||||||
|
Loading…
x
Reference in New Issue
Block a user