diff --git a/dnd_transcribe/argparse.py b/dnd_transcribe/argparse.py
index babe7cd..81133b1 100644
--- a/dnd_transcribe/argparse.py
+++ b/dnd_transcribe/argparse.py
@@ -19,6 +19,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
         action="store_false",
         help="Disable using the GPU with CUDA",
     )
+    parser.add_argument(
+        "--block-len",
+        type=int,
+        default=30,
+        help="Block length in seconds for streaming audio or Whisper chunking",
+    )
     parser.add_argument(
         "-m",
         "--model",
diff --git a/dnd_transcribe/inference.py b/dnd_transcribe/inference.py
index 4ffe5cf..6691e3d 100644
--- a/dnd_transcribe/inference.py
+++ b/dnd_transcribe/inference.py
@@ -25,8 +25,11 @@ class InferredTranscript(typing.NamedTuple):
 
 
 class Inference:
-    def __init__(self, model_name: str, use_gpu: bool = True) -> None:
+    def __init__(
+        self, model_name: str, block_len: int = 20, use_gpu: bool = True
+    ) -> None:
         self.model_name = model_name
+        self.block_len = block_len
         cuda_available = use_gpu and torch.cuda.is_available()
         self.device = "cuda" if cuda_available else "cpu"
         self.torch_dtype = torch.float16 if cuda_available else torch.float32
@@ -52,7 +55,7 @@ class Inference:
             raise Exception(f"Unsupported sample rate {samplerate}")
         stream = librosa.stream(
             audio_file_path,
-            block_length=20,
+            block_length=self.block_len,
             frame_length=AUDIO_SAMPLE_RATE,
             hop_length=AUDIO_SAMPLE_RATE,
         )
@@ -126,7 +129,7 @@ class Inference:
             model=self.model,
             tokenizer=self.processor.tokenizer,
             feature_extractor=self.processor.feature_extractor,
-            chunk_length_s=30,
+            chunk_length_s=self.block_len,
             batch_size=16,  # batch size for inference - set based on your device
             torch_dtype=self.torch_dtype,
             device=self.device,
diff --git a/dnd_transcribe/main.py b/dnd_transcribe/main.py
index 3d1f471..56d2081 100644
--- a/dnd_transcribe/main.py
+++ b/dnd_transcribe/main.py
@@ -13,7 +13,11 @@ def main():
         logging.basicConfig(level=logging.ERROR, format=logging_format)
     else:
         logging.basicConfig(level=logging.INFO, format=logging_format)
-    inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
+    inference = dnd_transcribe.inference.Inference(
+        args.model,
+        block_len=args.block_len,
+        use_gpu=args.use_gpu,
+    )
     if args.audio_file is not None:
         _print_inferred_transcript(inference.file_to_text(args.audio_file))
     elif args.stream_audio_file is not None:
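
A minimal usage sketch (not part of the patch) of how the new block_len parameter threads through: it now drives both librosa.stream's block_length in streaming mode and the transformers pipeline's chunk_length_s for whole-file inference. The model name below is a hypothetical placeholder:

    import dnd_transcribe.inference

    inference = dnd_transcribe.inference.Inference(
        "openai/whisper-small",  # hypothetical model name for illustration
        block_len=45,  # 45-second blocks for both streaming and pipeline chunking
        use_gpu=True,
    )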