Allow specifying the block length used when streaming audio and when chunking audio for Whisper
This commit is contained in:
parent
72347d5d47
commit
460ec42637
@ -19,6 +19,12 @@ def build_argument_parser() -> argparse.ArgumentParser:
|
||||
action="store_false",
|
||||
help="Disable using the GPU with CUDA",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block-len",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Block length in seconds of audio sent when streaming or to whisper",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
|
@ -25,8 +25,11 @@ class InferredTranscript(typing.NamedTuple):
|
||||
|
||||
|
||||
class Inference:
|
||||
def __init__(self, model_name: str, use_gpu: bool = True) -> None:
|
||||
def __init__(
|
||||
self, model_name: str, block_len: int = 20, use_gpu: bool = True
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.block_len = block_len
|
||||
cuda_available = use_gpu and torch.cuda.is_available()
|
||||
self.device = "cuda" if cuda_available else "cpu"
|
||||
self.torch_dtype = torch.float16 if cuda_available else torch.float32
|
||||
@ -52,7 +55,7 @@ class Inference:
|
||||
raise Exception(f"Unsupported sample rate {samplerate}")
|
||||
stream = librosa.stream(
|
||||
audio_file_path,
|
||||
block_length=20,
|
||||
block_length=self.block_len,
|
||||
frame_length=AUDIO_SAMPLE_RATE,
|
||||
hop_length=AUDIO_SAMPLE_RATE,
|
||||
)
|
||||
@ -126,7 +129,7 @@ class Inference:
|
||||
model=self.model,
|
||||
tokenizer=self.processor.tokenizer,
|
||||
feature_extractor=self.processor.feature_extractor,
|
||||
chunk_length_s=30,
|
||||
chunk_length_s=self.block_len,
|
||||
batch_size=16, # batch size for inference - set based on your device
|
||||
torch_dtype=self.torch_dtype,
|
||||
device=self.device,
|
||||
|
@ -13,7 +13,11 @@ def main():
|
||||
logging.basicConfig(level=logging.ERROR, format=logging_format)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO, format=logging_format)
|
||||
inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
|
||||
inference = dnd_transcribe.inference.Inference(
|
||||
args.model,
|
||||
block_len=args.block_len,
|
||||
use_gpu=args.use_gpu,
|
||||
)
|
||||
if args.audio_file is not None:
|
||||
_print_inferred_transcript(inference.file_to_text(args.audio_file))
|
||||
elif args.stream_audio_file is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user