test_transcribe.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import inspect
  2. import os
  3. import numpy as np
  4. from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
  5. def test_supported_languages():
  6. model = WhisperModel("tiny.en")
  7. assert model.supported_languages == ["en"]
  8. def test_transcribe(jfk_path):
  9. model = WhisperModel("tiny")
  10. segments, info = model.transcribe(jfk_path, word_timestamps=True)
  11. assert info.all_language_probs is not None
  12. assert info.language == "en"
  13. assert info.language_probability > 0.9
  14. assert info.duration == 11
  15. # Get top language info from all results, which should match the
  16. # already existing metadata
  17. top_lang, top_lang_score = info.all_language_probs[0]
  18. assert info.language == top_lang
  19. assert abs(info.language_probability - top_lang_score) < 1e-16
  20. segments = list(segments)
  21. assert len(segments) == 1
  22. segment = segments[0]
  23. assert segment.text == (
  24. " And so my fellow Americans, ask not what your country can do for you, "
  25. "ask what you can do for your country."
  26. )
  27. assert segment.text == "".join(word.word for word in segment.words)
  28. assert segment.start == segment.words[0].start
  29. assert segment.end == segment.words[-1].end
  30. batched_model = BatchedInferencePipeline(model=model)
  31. result, info = batched_model.transcribe(
  32. jfk_path, word_timestamps=True, vad_filter=False
  33. )
  34. assert info.language == "en"
  35. assert info.language_probability > 0.7
  36. segments = []
  37. for segment in result:
  38. segments.append(
  39. {"start": segment.start, "end": segment.end, "text": segment.text}
  40. )
  41. assert len(segments) == 1
  42. assert segment.text == (
  43. " And so my fellow Americans ask not what your country can do for you, "
  44. "ask what you can do for your country."
  45. )
  46. def test_batched_transcribe(physcisworks_path):
  47. model = WhisperModel("tiny")
  48. batched_model = BatchedInferencePipeline(model=model)
  49. result, info = batched_model.transcribe(physcisworks_path, batch_size=16)
  50. assert info.language == "en"
  51. assert info.language_probability > 0.7
  52. segments = []
  53. for segment in result:
  54. segments.append(
  55. {"start": segment.start, "end": segment.end, "text": segment.text}
  56. )
  57. # number of near 30 sec segments
  58. assert len(segments) == 7
  59. result, info = batched_model.transcribe(
  60. physcisworks_path,
  61. batch_size=16,
  62. without_timestamps=False,
  63. word_timestamps=True,
  64. )
  65. segments = []
  66. for segment in result:
  67. assert segment.words is not None
  68. segments.append(
  69. {"start": segment.start, "end": segment.end, "text": segment.text}
  70. )
  71. assert len(segments) > 7
  72. def test_empty_audio():
  73. audio = np.asarray([], dtype="float32")
  74. model = WhisperModel("tiny")
  75. pipeline = BatchedInferencePipeline(model=model)
  76. assert list(model.transcribe(audio)[0]) == []
  77. assert list(pipeline.transcribe(audio)[0]) == []
  78. model.detect_language(audio)
  79. def test_prefix_with_timestamps(jfk_path):
  80. model = WhisperModel("tiny")
  81. segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
  82. segments = list(segments)
  83. assert len(segments) == 1
  84. segment = segments[0]
  85. assert segment.text == (
  86. " And so my fellow Americans, ask not what your country can do for you, "
  87. "ask what you can do for your country."
  88. )
  89. assert segment.start == 0
  90. assert 10 < segment.end <= 11
  91. def test_vad(jfk_path):
  92. model = WhisperModel("tiny")
  93. segments, info = model.transcribe(
  94. jfk_path,
  95. vad_filter=True,
  96. vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
  97. )
  98. segments = list(segments)
  99. assert len(segments) == 1
  100. segment = segments[0]
  101. assert segment.text == (
  102. " And so my fellow Americans ask not what your country can do for you, "
  103. "ask what you can do for your country."
  104. )
  105. assert 0 < segment.start < 1
  106. assert 10 < segment.end < 11
  107. assert info.vad_options.min_silence_duration_ms == 500
  108. assert info.vad_options.speech_pad_ms == 200
  109. def test_stereo_diarization(data_dir):
  110. model = WhisperModel("tiny")
  111. audio_path = os.path.join(data_dir, "stereo_diarization.wav")
  112. left, right = decode_audio(audio_path, split_stereo=True)
  113. segments, _ = model.transcribe(left)
  114. transcription = "".join(segment.text for segment in segments).strip()
  115. assert transcription == (
  116. "He began a confused complaint against the wizard, "
  117. "who had vanished behind the curtain on the left."
  118. )
  119. segments, _ = model.transcribe(right)
  120. transcription = "".join(segment.text for segment in segments).strip()
  121. assert transcription == "The horizon seems extremely distant."
  122. def test_multilingual_transcription(data_dir):
  123. model = WhisperModel("tiny")
  124. pipeline = BatchedInferencePipeline(model)
  125. audio_path = os.path.join(data_dir, "multilingual.mp3")
  126. audio = decode_audio(audio_path)
  127. segments, info = model.transcribe(
  128. audio,
  129. multilingual=True,
  130. without_timestamps=True,
  131. condition_on_previous_text=False,
  132. )
  133. segments = list(segments)
  134. assert (
  135. segments[0].text
  136. == " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
  137. " software and associated documentation files to deal in the software without restriction,"
  138. " including without limitation the rights to use, copy, modify, merge, publish, distribute"
  139. ", sublicence, and or cell copies of the software, and to permit persons to whom the "
  140. "software is furnished to do so, subject to the following conditions. The above copyright"
  141. " notice and this permission notice, shall be included in all copies or substantial "
  142. "portions of the software."
  143. )
  144. assert (
  145. segments[1].text
  146. == " Jedem, der dieses Software und die dazu gehöregen Dokumentationsdatein erhält, wird "
  147. "hiermit unengeltlich die Genehmigung erteilt, wird der Software und eingeschränkt zu "
  148. "verfahren. Dies umfasst insbesondere das Recht, die Software zu verwenden, zu "
  149. "vervielfältigen, zu modifizieren, zu Samenzofügen, zu veröffentlichen, zu verteilen, "
  150. "unterzulizenzieren und oder kopieren der Software zu verkaufen und diese Rechte "
  151. "unterfolgen den Bedingungen anderen zu übertragen."
  152. )
  153. segments, info = pipeline.transcribe(audio, multilingual=True)
  154. segments = list(segments)
  155. assert (
  156. segments[0].text
  157. == " Permission is hereby granted, free of charge, to any person obtaining a copy of the"
  158. " software and associated documentation files to deal in the software without restriction,"
  159. " including without limitation the rights to use, copy, modify, merge, publish, distribute"
  160. ", sublicence, and or cell copies of the software, and to permit persons to whom the "
  161. "software is furnished to do so, subject to the following conditions. The above copyright"
  162. " notice and this permission notice, shall be included in all copies or substantial "
  163. "portions of the software."
  164. )
  165. assert (
  166. "Dokumentationsdatein erhält, wird hiermit unengeltlich die Genehmigung erteilt,"
  167. " wird der Software und eingeschränkt zu verfahren. Dies umfasst insbesondere das Recht,"
  168. " die Software zu verwenden, zu vervielfältigen, zu modifizieren"
  169. in segments[1].text
  170. )
  171. def test_hotwords(data_dir):
  172. model = WhisperModel("tiny")
  173. pipeline = BatchedInferencePipeline(model)
  174. audio_path = os.path.join(data_dir, "hotwords.mp3")
  175. audio = decode_audio(audio_path)
  176. segments, info = model.transcribe(audio, hotwords="ComfyUI")
  177. segments = list(segments)
  178. assert "ComfyUI" in segments[0].text
  179. assert info.transcription_options.hotwords == "ComfyUI"
  180. segments, info = pipeline.transcribe(audio, hotwords="ComfyUI")
  181. segments = list(segments)
  182. assert "ComfyUI" in segments[0].text
  183. assert info.transcription_options.hotwords == "ComfyUI"
  184. def test_transcribe_signature():
  185. model_transcribe_args = set(inspect.getargs(WhisperModel.transcribe.__code__).args)
  186. pipeline_transcribe_args = set(
  187. inspect.getargs(BatchedInferencePipeline.transcribe.__code__).args
  188. )
  189. pipeline_transcribe_args.remove("batch_size")
  190. assert model_transcribe_args == pipeline_transcribe_args
  191. def test_monotonic_timestamps(physcisworks_path):
  192. model = WhisperModel("tiny")
  193. pipeline = BatchedInferencePipeline(model=model)
  194. segments, info = model.transcribe(physcisworks_path, word_timestamps=True)
  195. segments = list(segments)
  196. for i in range(len(segments) - 1):
  197. assert segments[i].start <= segments[i].end
  198. assert segments[i].end <= segments[i + 1].start
  199. for word in segments[i].words:
  200. assert word.start <= word.end
  201. assert word.end <= segments[i].end
  202. assert segments[-1].end <= info.duration
  203. segments, info = pipeline.transcribe(physcisworks_path, word_timestamps=True)
  204. segments = list(segments)
  205. for i in range(len(segments) - 1):
  206. assert segments[i].start <= segments[i].end
  207. assert segments[i].end <= segments[i + 1].start
  208. for word in segments[i].words:
  209. assert word.start <= word.end
  210. assert word.end <= segments[i].end
  211. assert segments[-1].end <= info.duration