wer_benchmark.py

import argparse
import json
import os

from datasets import load_dataset
from jiwer import wer
from tqdm import tqdm
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

from faster_whisper import WhisperModel

parser = argparse.ArgumentParser(description="WER benchmark")
parser.add_argument(
    "--audio_numb",
    type=int,
    default=None,
    help="Specify the number of validation audio files in the dataset."
    " Set to None to retrieve all audio files.",
)
args = parser.parse_args()
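
# "large-v3" resolves to a converted Whisper large-v3 checkpoint, which is
# downloaded from the Hugging Face Hub on first use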
model_path = "large-v3"
model = WhisperModel(model_path, device="cuda")

# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
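
# build Whisper's English text normalizer; normalizer.json holds its
# normalization mapping and is expected to sit next to this script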
with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
    normalizer = EnglishTextNormalizer(json.load(f))
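

# transcribe every audio array in the batch and carry the reference text along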
def inference(batch):
    batch["transcription"] = []
    for sample in batch["audio"]:
        segments, info = model.transcribe(sample["array"], language="en")
        batch["transcription"].append("".join([segment.text for segment in segments]))
    batch["reference"] = batch["text"]
    return batch


dataset = dataset.map(function=inference, batched=True, batch_size=16)
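# map() is lazy on a streaming dataset, so inference only runs while the
# loop below consumes the examples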

all_transcriptions = []
all_references = []

# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])
    if args.audio_numb and i == (args.audio_numb - 1):
        break

# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric
word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
print("WER: %.3f" % word_error_rate)
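
# Example invocation (assumes a CUDA-capable GPU and a normalizer.json file
# next to this script), limiting the run to the first 100 validation clips:
#   python wer_benchmark.py --audio_numb 100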