Compare commits

..

No commits in common. "460ec426371b4ae61cc0addbd313f2614012915e" and "db2b6d956b63713d2f7f6698fcf3b5697fa0d2ed" have entirely different histories.

3 changed files with 12 additions and 29 deletions

View File

@ -19,12 +19,6 @@ def build_argument_parser() -> argparse.ArgumentParser:
action="store_false", action="store_false",
help="Disable using the GPU with CUDA", help="Disable using the GPU with CUDA",
) )
parser.add_argument(
"--block-len",
type=int,
default=30,
help="Block length in seconds of audio sent when streaming or to whisper",
)
parser.add_argument( parser.add_argument(
"-m", "-m",
"--model", "--model",

View File

@ -25,11 +25,8 @@ class InferredTranscript(typing.NamedTuple):
class Inference: class Inference:
def __init__( def __init__(self, model_name: str, use_gpu: bool = True) -> None:
self, model_name: str, block_len: int = 20, use_gpu: bool = True
) -> None:
self.model_name = model_name self.model_name = model_name
self.block_len = block_len
cuda_available = use_gpu and torch.cuda.is_available() cuda_available = use_gpu and torch.cuda.is_available()
self.device = "cuda" if cuda_available else "cpu" self.device = "cuda" if cuda_available else "cpu"
self.torch_dtype = torch.float16 if cuda_available else torch.float32 self.torch_dtype = torch.float16 if cuda_available else torch.float32
@ -55,7 +52,7 @@ class Inference:
raise Exception(f"Unsupported sample rate {samplerate}") raise Exception(f"Unsupported sample rate {samplerate}")
stream = librosa.stream( stream = librosa.stream(
audio_file_path, audio_file_path,
block_length=self.block_len, block_length=20,
frame_length=AUDIO_SAMPLE_RATE, frame_length=AUDIO_SAMPLE_RATE,
hop_length=AUDIO_SAMPLE_RATE, hop_length=AUDIO_SAMPLE_RATE,
) )
@ -65,7 +62,6 @@ class Inference:
for block in stream: for block in stream:
if len(block.shape) > 1: if len(block.shape) > 1:
block = speech[:, 0] + speech[:, 1] block = speech[:, 0] + speech[:, 1]
try:
block_inference = self.buffer_to_text(block) block_inference = self.buffer_to_text(block)
transcript += block_inference.transcript + " " transcript += block_inference.transcript + " "
processing_time += block_inference.processing_time_sec processing_time += block_inference.processing_time_sec
@ -74,9 +70,6 @@ class Inference:
confidence = block_inference.confidence_score confidence = block_inference.confidence_score
else: else:
confidence *= block_inference.confidence_score confidence *= block_inference.confidence_score
except torch.OutOfMemoryError as e:
print(e)
break
return InferredTranscript(transcript.strip(), confidence, processing_time) return InferredTranscript(transcript.strip(), confidence, processing_time)
def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript: def file_to_text(self, audio_file: typing.BinaryIO) -> InferredTranscript:
@ -129,7 +122,7 @@ class Inference:
model=self.model, model=self.model,
tokenizer=self.processor.tokenizer, tokenizer=self.processor.tokenizer,
feature_extractor=self.processor.feature_extractor, feature_extractor=self.processor.feature_extractor,
chunk_length_s=self.block_len, chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device batch_size=16, # batch size for inference - set based on your device
torch_dtype=self.torch_dtype, torch_dtype=self.torch_dtype,
device=self.device, device=self.device,

View File

@ -13,11 +13,7 @@ def main():
logging.basicConfig(level=logging.ERROR, format=logging_format) logging.basicConfig(level=logging.ERROR, format=logging_format)
else: else:
logging.basicConfig(level=logging.INFO, format=logging_format) logging.basicConfig(level=logging.INFO, format=logging_format)
inference = dnd_transcribe.inference.Inference( inference = dnd_transcribe.inference.Inference(args.model, use_gpu=args.use_gpu)
args.model,
block_len=args.block_len,
use_gpu=args.use_gpu,
)
if args.audio_file is not None: if args.audio_file is not None:
_print_inferred_transcript(inference.file_to_text(args.audio_file)) _print_inferred_transcript(inference.file_to_text(args.audio_file))
elif args.stream_audio_file is not None: elif args.stream_audio_file is not None: