evaluate_yt_commons.py

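"""Benchmark faster-whisper on the YouTube-Commons ASR evaluation set.

Streams mobiuslabsgmbh/youtube-commons-asr-eval, downloads each video's audio
track with pytubefix, transcribes it with a batched Whisper large-v3 pipeline,
and reports the word error rate (WER) against the reference transcripts.
"""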
import argparse
import json
import os
from io import BytesIO

from datasets import load_dataset
from jiwer import wer
from pytubefix import YouTube
from pytubefix.exceptions import VideoUnavailable
from tqdm import tqdm
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
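

# Download a video's audio track into memory and decode it to a waveform.
# Note that ".order_by('bitrate').desc().last()" selects the lowest-bitrate
# audio-only MP4 stream (the smallest download). Unavailable videos get an
# empty "audio" field so the evaluation loop below can skip them.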
def url_to_audio(row):
    buffer = BytesIO()
    yt = YouTube(row["link"])
    try:
        video = (
            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
            .order_by("bitrate")
            .desc()
            .last()
        )
        video.stream_to_buffer(buffer)
        buffer.seek(0)
        row["audio"] = decode_audio(buffer)
    except VideoUnavailable:
        print(f'Failed to download: {row["link"]}')
        row["audio"] = []
    return row
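

# CLI: optionally cap how many audio files are evaluated.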
parser = argparse.ArgumentParser(description="WER benchmark")
parser.add_argument(
    "--audio_numb",
    type=int,
    default=None,
    help="Number of validation audio files to evaluate."
    " Defaults to None, which evaluates all audio files.",
)
args = parser.parse_args()
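
# The Whisper text normalizer is built from the English spelling mapping,
# expected to sit next to this script as normalizer.json.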
with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
    normalizer = EnglishTextNormalizer(json.load(f))
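
# Stream the evaluation set and attach decoded audio to each row on the fly.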
dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
    url_to_audio
)
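
# Load Whisper large-v3 on GPU and wrap it in a batched inference pipeline.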
model = WhisperModel("large-v3", device="cuda")
pipeline = BatchedInferencePipeline(model, device="cuda")
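
# Accumulators for predicted and reference transcripts across the test split.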
all_transcriptions = []
all_references = []

# iterate over the dataset and run inference
for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
    if not row["audio"]:
        continue
    result, info = pipeline.transcribe(
        row["audio"][0],
        batch_size=8,
        word_timestamps=False,
        without_timestamps=True,
    )
    all_transcriptions.append("".join(segment.text for segment in result))
    all_references.append(row["text"][0])
    if args.audio_numb and i == (args.audio_numb - 1):
        break
# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric
word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
print("WER: %.3f" % word_error_rate)
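
# Example invocation (assumes a CUDA-capable GPU and the dependencies imported
# above), capping the run at the first 50 dataset rows:
#   python evaluate_yt_commons.py --audio_numb 50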