# transcribe.py

import itertools
import json
import logging
import os
import zlib
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from warnings import warn

import ctranslate2
import numpy as np
import tokenizers
from tqdm import tqdm

from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
from faster_whisper.vad import (
    SpeechTimestampsMap,
    VadOptions,
    collect_chunks,
    get_speech_timestamps,
    merge_segments,
)
@dataclass
class Word:
    start: float
    end: float
    word: str
    probability: float

    def _asdict(self):
        warn(
            "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
            DeprecationWarning,
            2,
        )
        return asdict(self)


@dataclass
class Segment:
    id: int
    seek: int
    start: float
    end: float
    text: str
    tokens: List[int]
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
    words: Optional[List[Word]]
    temperature: Optional[float]

    def _asdict(self):
        warn(
            "Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
            DeprecationWarning,
            2,
        )
        return asdict(self)
@dataclass
class TranscriptionOptions:
    beam_size: int
    best_of: int
    patience: float
    length_penalty: float
    repetition_penalty: float
    no_repeat_ngram_size: int
    log_prob_threshold: Optional[float]
    no_speech_threshold: Optional[float]
    compression_ratio_threshold: Optional[float]
    condition_on_previous_text: bool
    prompt_reset_on_temperature: float
    temperatures: List[float]
    initial_prompt: Optional[Union[str, Iterable[int]]]
    prefix: Optional[str]
    suppress_blank: bool
    suppress_tokens: Optional[List[int]]
    without_timestamps: bool
    max_initial_timestamp: float
    word_timestamps: bool
    prepend_punctuations: str
    append_punctuations: str
    multilingual: bool
    max_new_tokens: Optional[int]
    clip_timestamps: Union[str, List[float]]
    hallucination_silence_threshold: Optional[float]
    hotwords: Optional[str]


@dataclass
class TranscriptionInfo:
    language: str
    language_probability: float
    duration: float
    duration_after_vad: float
    all_language_probs: Optional[List[Tuple[str, float]]]
    transcription_options: TranscriptionOptions
    vad_options: VadOptions
class BatchedInferencePipeline:
    def __init__(
        self,
        model,
    ):
        self.model: WhisperModel = model
        self.last_speech_timestamp = 0.0

    def forward(self, features, tokenizer, chunks_metadata, options):
        encoder_output, outputs = self.generate_segment_batched(
            features, tokenizer, options
        )

        segmented_outputs = []
        segment_sizes = []
        for chunk_metadata, output in zip(chunks_metadata, outputs):
            duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
            segment_size = int(ceil(duration) * self.model.frames_per_second)
            segment_sizes.append(segment_size)
            (
                subsegments,
                seek,
                single_timestamp_ending,
            ) = self.model._split_segments_by_timestamps(
                tokenizer=tokenizer,
                tokens=output["tokens"],
                time_offset=chunk_metadata["start_time"],
                segment_size=segment_size,
                segment_duration=duration,
                seek=0,
            )
            segmented_outputs.append(
                [
                    dict(
                        text=tokenizer.decode(subsegment["tokens"]),
                        avg_logprob=output["avg_logprob"],
                        no_speech_prob=output["no_speech_prob"],
                        tokens=subsegment["tokens"],
                        start=subsegment["start"],
                        end=subsegment["end"],
                        compression_ratio=get_compression_ratio(
                            tokenizer.decode(subsegment["tokens"])
                        ),
                        seek=int(
                            chunk_metadata["start_time"] * self.model.frames_per_second
                        ),
                    )
                    for subsegment in subsegments
                ]
            )
        if options.word_timestamps:
            self.last_speech_timestamp = self.model.add_word_timestamps(
                segmented_outputs,
                tokenizer,
                encoder_output,
                segment_sizes,
                options.prepend_punctuations,
                options.append_punctuations,
                self.last_speech_timestamp,
            )

        return segmented_outputs
    def generate_segment_batched(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
    ):
        batch_size = features.shape[0]

        prompt = self.model.get_prompt(
            tokenizer,
            previous_tokens=(
                tokenizer.encode(options.initial_prompt)
                if options.initial_prompt is not None
                else []
            ),
            without_timestamps=options.without_timestamps,
            hotwords=options.hotwords,
        )

        if options.max_new_tokens is not None:
            max_length = len(prompt) + options.max_new_tokens
        else:
            max_length = self.model.max_length

        if max_length > self.model.max_length:
            raise ValueError(
                f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
                f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
                f"and `max_new_tokens` is: {max_length}. This exceeds the "
                f"`max_length` of the Whisper model: {self.model.max_length}. "
                "You should either reduce the length of your prompt, or "
                "reduce the value of `max_new_tokens`, "
                f"so that their combined length is less than {self.model.max_length}."
            )

        encoder_output = self.model.encode(features)
        prompts = [prompt.copy() for _ in range(batch_size)]
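        # In multilingual mode, detect the language of each sample from the encoder
        # output and overwrite the language token in that sample's copy of the prompt.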
        if options.multilingual:
            language_tokens = [
                tokenizer.tokenizer.token_to_id(segment_langs[0][0])
                for segment_langs in self.model.model.detect_language(encoder_output)
            ]
            language_token_index = prompt.index(tokenizer.language)

            for i, language_token in enumerate(language_tokens):
                prompts[i][language_token_index] = language_token

        results = self.model.model.generate(
            encoder_output,
            prompts,
            beam_size=options.beam_size,
            patience=options.patience,
            length_penalty=options.length_penalty,
            max_length=max_length,
            suppress_blank=options.suppress_blank,
            suppress_tokens=options.suppress_tokens,
            return_scores=True,
            return_no_speech_prob=True,
            sampling_temperature=options.temperatures[0],
            repetition_penalty=options.repetition_penalty,
            no_repeat_ngram_size=options.no_repeat_ngram_size,
        )

        output = []
        for result in results:
            # Recover the average log probability from the length-normalized score.
            seq_len = len(result.sequences_ids[0])
            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)

            output.append(
                dict(
                    avg_logprob=cum_logprob / (seq_len + 1),
                    no_speech_prob=result.no_speech_prob,
                    tokens=result.sequences_ids[0],
                )
            )

        return encoder_output, output
    def transcribe(
        self,
        audio: Union[str, BinaryIO, np.ndarray],
        language: Optional[str] = None,
        task: str = "transcribe",
        log_progress: bool = False,
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1,
        length_penalty: float = 1,
        repetition_penalty: float = 1,
        no_repeat_ngram_size: int = 0,
        temperature: Union[float, List[float], Tuple[float, ...]] = [
            0.0,
            0.2,
            0.4,
            0.6,
            0.8,
            1.0,
        ],
        compression_ratio_threshold: Optional[float] = 2.4,
        log_prob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        prompt_reset_on_temperature: float = 0.5,
        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
        prefix: Optional[str] = None,
        suppress_blank: bool = True,
        suppress_tokens: Optional[List[int]] = [-1],
        without_timestamps: bool = True,
        max_initial_timestamp: float = 1.0,
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        multilingual: bool = False,
        vad_filter: bool = True,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
        max_new_tokens: Optional[int] = None,
        chunk_length: Optional[int] = None,
        clip_timestamps: Optional[List[dict]] = None,
        hallucination_silence_threshold: Optional[float] = None,
        batch_size: int = 8,
        hotwords: Optional[str] = None,
        language_detection_threshold: Optional[float] = 0.5,
        language_detection_segments: int = 1,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
  270. """transcribe audio in chunks in batched fashion and return with language info.
  271. Arguments:
  272. audio: Path to the input file (or a file-like object), or the audio waveform.
  273. language: The language spoken in the audio. It should be a language code such
  274. as "en" or "fr". If not set, the language will be detected in the first 30 seconds
  275. of audio.
  276. task: Task to execute (transcribe or translate).
  277. log_progress: whether to show progress bar or not.
  278. beam_size: Beam size to use for decoding.
  279. best_of: Number of candidates when sampling with non-zero temperature.
  280. patience: Beam search patience factor.
  281. length_penalty: Exponential length penalty constant.
  282. repetition_penalty: Penalty applied to the score of previously generated tokens
  283. (set > 1 to penalize).
  284. no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
  285. temperature: Temperature for sampling. If a list or tuple is passed,
  286. only the first value is used.
  287. initial_prompt: Optional text string or iterable of token ids to provide as a
  288. prompt for the each window.
  289. suppress_blank: Suppress blank outputs at the beginning of the sampling.
  290. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
  291. of symbols as defined in `tokenizer.non_speech_tokens()`.
  292. without_timestamps: Only sample text tokens.
  293. word_timestamps: Extract word-level timestamps using the cross-attention pattern
  294. and dynamic time warping, and include the timestamps for each word in each segment.
  295. Set as False.
  296. prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
  297. with the next word
  298. append_punctuations: If word_timestamps is True, merge these punctuation symbols
  299. with the previous word
  300. multilingual: Perform language detection on every segment.
  301. vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
  302. without speech. This step is using the Silero VAD model
  303. https://github.com/snakers4/silero-vad.
  304. vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
  305. parameters and default values in the class `VadOptions`).
  306. max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
  307. the maximum will be set by the default max_length.
  308. chunk_length: The length of audio segments. If it is not None, it will overwrite the
  309. default chunk_length of the FeatureExtractor.
  310. clip_timestamps: Optionally provide list of dictionaries each containing "start" and
  311. "end" keys that specify the start and end of the voiced region within
  312. `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
  313. batch_size: the maximum number of parallel requests to model for decoding.
  314. hotwords:
  315. Hotwords/hint phrases to the model. Has no effect if prefix is not None.
  316. language_detection_threshold: If the maximum probability of the language tokens is
  317. higher than this value, the language is detected.
  318. language_detection_segments: Number of segments to consider for the language detection.
  319. Unused Arguments
  320. compression_ratio_threshold: If the gzip compression ratio is above this value,
  321. treat as failed.
  322. log_prob_threshold: If the average log probability over sampled tokens is
  323. below this value, treat as failed.
  324. no_speech_threshold: If the no_speech probability is higher than this value AND
  325. the average log probability over sampled tokens is below `log_prob_threshold`,
  326. consider the segment as silent.
  327. condition_on_previous_text: If True, the previous output of the model is provided
  328. as a prompt for the next window; disabling may make the text inconsistent across
  329. windows, but the model becomes less prone to getting stuck in a failure loop,
  330. such as repetition looping or timestamps going out of sync. Set as False
  331. prompt_reset_on_temperature: Resets prompt if temperature is above this value.
  332. Arg has effect only if condition_on_previous_text is True. Set at 0.5
  333. prefix: Optional text to provide as a prefix at the beginning of each window.
  334. max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
  335. hallucination_silence_threshold: Optional[float]
  336. When word_timestamps is True, skip silent periods longer than this threshold
  337. (in seconds) when a possible hallucination is detected. set as None.
  338. Returns:
  339. A tuple with:
  340. - a generator over transcribed segments
  341. - an instance of TranscriptionInfo
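
        Example:
          A minimal usage sketch; the model size, device, compute type, and audio file
          name below are illustrative, not prescriptive:

              model = WhisperModel("small", device="cpu", compute_type="int8")
              batched_model = BatchedInferencePipeline(model=model)
              segments, info = batched_model.transcribe("audio.wav", batch_size=8)
              for segment in segments:
                  print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))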
  342. """
  343. sampling_rate = self.model.feature_extractor.sampling_rate
  344. if multilingual and not self.model.model.is_multilingual:
  345. self.model.logger.warning(
  346. "The current model is English-only but the multilingual parameter is set to"
  347. "True; setting to False instead."
  348. )
  349. multilingual = False
  350. if not isinstance(audio, np.ndarray):
  351. audio = decode_audio(audio, sampling_rate=sampling_rate)
  352. duration = audio.shape[0] / sampling_rate
  353. self.model.logger.info(
  354. "Processing audio with duration %s", format_timestamp(duration)
  355. )
  356. chunk_length = chunk_length or self.model.feature_extractor.chunk_length
  357. # if no segment split is provided, use vad_model and generate segments
  358. if not clip_timestamps:
  359. if vad_filter:
  360. if vad_parameters is None:
  361. vad_parameters = VadOptions(
  362. max_speech_duration_s=chunk_length,
  363. min_silence_duration_ms=160,
  364. )
  365. elif isinstance(vad_parameters, dict):
  366. if "max_speech_duration_s" in vad_parameters.keys():
  367. vad_parameters.pop("max_speech_duration_s")
  368. vad_parameters = VadOptions(
  369. **vad_parameters, max_speech_duration_s=chunk_length
  370. )
  371. active_segments = get_speech_timestamps(audio, vad_parameters)
  372. clip_timestamps = merge_segments(active_segments, vad_parameters)
  373. # run the audio if it is less than 30 sec even without clip_timestamps
  374. elif duration < chunk_length:
  375. clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
  376. else:
  377. raise RuntimeError(
  378. "No clip timestamps found. "
  379. "Set 'vad_filter' to True or provide 'clip_timestamps'."
  380. )
  381. duration_after_vad = (
  382. sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
  383. / sampling_rate
  384. )
  385. self.model.logger.info(
  386. "VAD filter removed %s of audio",
  387. format_timestamp(duration - duration_after_vad),
  388. )
  389. audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
  390. features = (
  391. [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
  392. if duration_after_vad
  393. else []
  394. )
  395. all_language_probs = None
  396. # detecting the language if not provided
  397. if language is None:
  398. if not self.model.model.is_multilingual:
  399. language = "en"
  400. language_probability = 1
  401. else:
  402. (
  403. language,
  404. language_probability,
  405. all_language_probs,
  406. ) = self.model.detect_language(
  407. features=np.concatenate(
  408. features
  409. + [
  410. np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
  411. ],
  412. axis=1,
  413. ), # add a dummy feature to account for empty audio
  414. language_detection_segments=language_detection_segments,
  415. language_detection_threshold=language_detection_threshold,
  416. )
  417. self.model.logger.info(
  418. "Detected language '%s' with probability %.2f",
  419. language,
  420. language_probability,
  421. )
  422. else:
  423. if not self.model.model.is_multilingual and language != "en":
  424. self.model.logger.warning(
  425. "The current model is English-only but the language parameter is set to '%s'; "
  426. "using 'en' instead." % language
  427. )
  428. language = "en"
  429. language_probability = 1
  430. tokenizer = Tokenizer(
  431. self.model.hf_tokenizer,
  432. self.model.model.is_multilingual,
  433. task=task,
  434. language=language,
  435. )
        features = (
            np.stack([pad_or_trim(feature) for feature in features]) if features else []
        )

        options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            temperatures=(
                temperature[:1]
                if isinstance(temperature, (list, tuple))
                else [temperature]
            ),
            initial_prompt=initial_prompt,
            prefix=prefix,
            suppress_blank=suppress_blank,
            suppress_tokens=(
                get_suppressed_tokens(tokenizer, suppress_tokens)
                if suppress_tokens
                else suppress_tokens
            ),
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
            max_new_tokens=max_new_tokens,
            hotwords=hotwords,
            word_timestamps=word_timestamps,
            hallucination_silence_threshold=None,
            condition_on_previous_text=False,
            clip_timestamps=clip_timestamps,
            prompt_reset_on_temperature=0.5,
            multilingual=multilingual,
            without_timestamps=without_timestamps,
            max_initial_timestamp=0.0,
        )

        info = TranscriptionInfo(
            language=language,
            language_probability=language_probability,
            duration=duration,
            duration_after_vad=duration_after_vad,
            transcription_options=options,
            vad_options=vad_parameters,
            all_language_probs=all_language_probs,
        )

        segments = self._batched_segments_generator(
            features,
            tokenizer,
            chunks_metadata,
            batch_size,
            options,
            log_progress,
        )

        return segments, info
    def _batched_segments_generator(
        self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
    ):
        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
        seg_idx = 0
        for i in range(0, len(features), batch_size):
            results = self.forward(
                features[i : i + batch_size],
                tokenizer,
                chunks_metadata[i : i + batch_size],
                options,
            )

            for result in results:
                for segment in result:
                    seg_idx += 1
                    yield Segment(
                        seek=segment["seek"],
                        id=seg_idx,
                        text=segment["text"],
                        start=round(segment["start"], 3),
                        end=round(segment["end"], 3),
                        words=(
                            None
                            if not options.word_timestamps
                            else [Word(**word) for word in segment["words"]]
                        ),
                        tokens=segment["tokens"],
                        avg_logprob=segment["avg_logprob"],
                        no_speech_prob=segment["no_speech_prob"],
                        compression_ratio=segment["compression_ratio"],
                        temperature=options.temperatures[0],
                    )

                pbar.update(1)

        pbar.close()
        self.last_speech_timestamp = 0.0
class WhisperModel:
    def __init__(
        self,
        model_size_or_path: str,
        device: str = "auto",
        device_index: Union[int, List[int]] = 0,
        compute_type: str = "default",
        cpu_threads: int = 0,
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
        files: dict = None,
        revision: Optional[str] = None,
        **model_kwargs,
    ):
        """Initializes the Whisper model.

        Args:
          model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
            small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
            large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo),
            a path to a converted model directory, or a CTranslate2-converted Whisper model ID from
            the HF Hub. When a size or a model ID is configured, the converted model is downloaded
            from the Hugging Face Hub.
          device: Device to use for computation ("cpu", "cuda", "auto").
          device_index: Device ID to use.
            The model can also be loaded on multiple GPUs by passing a list of IDs
            (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
            when transcribe() is called from multiple Python threads (see also num_workers).
          compute_type: Type to use for computation.
            See https://opennmt.net/CTranslate2/quantization.html.
          cpu_threads: Number of threads to use when running on CPU (4 by default).
            A non-zero value overrides the OMP_NUM_THREADS environment variable.
          num_workers: When transcribe() is called from multiple Python threads,
            having multiple workers enables true parallelism when running the model
            (concurrent calls to self.model.generate() will run in parallel).
            This can improve the global throughput at the cost of increased memory usage.
          download_root: Directory where the models should be saved. If not set, the models
            are saved in the standard Hugging Face cache directory.
          local_files_only: If True, avoid downloading the file and return the path to the
            local cached file if it exists.
          files: Load model files from memory. This argument is a dictionary mapping file names
            to file contents as file-like or bytes objects. If this is set, model_path acts as an
            identifier for this model.
          revision:
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash.
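
        Example:
          A minimal loading sketch; the model size, device, and compute type below are
          illustrative, not prescriptive:

              model = WhisperModel("small", device="cpu", compute_type="int8")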
  574. """
  575. self.logger = get_logger()
  576. tokenizer_bytes, preprocessor_bytes = None, None
  577. if files:
  578. model_path = model_size_or_path
  579. tokenizer_bytes = files.pop("tokenizer.json", None)
  580. preprocessor_bytes = files.pop("preprocessor_config.json", None)
  581. elif os.path.isdir(model_size_or_path):
  582. model_path = model_size_or_path
  583. else:
  584. model_path = download_model(
  585. model_size_or_path,
  586. local_files_only=local_files_only,
  587. cache_dir=download_root,
  588. revision=revision,
  589. )
  590. self.model = ctranslate2.models.Whisper(
  591. model_path,
  592. device=device,
  593. device_index=device_index,
  594. compute_type=compute_type,
  595. intra_threads=cpu_threads,
  596. inter_threads=num_workers,
  597. files=files,
  598. **model_kwargs,
  599. )
  600. tokenizer_file = os.path.join(model_path, "tokenizer.json")
  601. if tokenizer_bytes:
  602. self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
  603. elif os.path.isfile(tokenizer_file):
  604. self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
  605. else:
  606. self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
  607. "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
  608. )
  609. self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
  610. self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
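        # Timing constants derived from the feature extractor: with 16 kHz audio and a
        # hop length of 160 there are 100 mel frames per second; the decoder advances
        # 2 frames per timestamp step (input_stride), i.e. 50 timestamp positions per
        # second, so each timestamp token covers 0.02 s.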
        self.input_stride = 2
        self.num_samples_per_token = (
            self.feature_extractor.hop_length * self.input_stride
        )
        self.frames_per_second = (
            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
        )
        self.tokens_per_second = (
            self.feature_extractor.sampling_rate // self.num_samples_per_token
        )
        self.time_precision = 0.02
        self.max_length = 448

    @property
    def supported_languages(self) -> List[str]:
        """The languages supported by the model."""
        return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]

    def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
        config = {}
        try:
            config_path = os.path.join(model_path, "preprocessor_config.json")
            if preprocessor_bytes:
                config = json.loads(preprocessor_bytes)
            elif os.path.isfile(config_path):
                with open(config_path, "r", encoding="utf-8") as file:
                    config = json.load(file)
            else:
                return config
            valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
            return {k: v for k, v in config.items() if k in valid_keys}
        except json.JSONDecodeError as e:
            self.logger.warning("Could not load preprocessor config: %s", e)

        return config
    def transcribe(
        self,
        audio: Union[str, BinaryIO, np.ndarray],
        language: Optional[str] = None,
        task: str = "transcribe",
        log_progress: bool = False,
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1,
        length_penalty: float = 1,
        repetition_penalty: float = 1,
        no_repeat_ngram_size: int = 0,
        temperature: Union[float, List[float], Tuple[float, ...]] = [
            0.0,
            0.2,
            0.4,
            0.6,
            0.8,
            1.0,
        ],
        compression_ratio_threshold: Optional[float] = 2.4,
        log_prob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        prompt_reset_on_temperature: float = 0.5,
        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
        prefix: Optional[str] = None,
        suppress_blank: bool = True,
        suppress_tokens: Optional[List[int]] = [-1],
        without_timestamps: bool = False,
        max_initial_timestamp: float = 1.0,
        word_timestamps: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        multilingual: bool = False,
        vad_filter: bool = False,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
        max_new_tokens: Optional[int] = None,
        chunk_length: Optional[int] = None,
        clip_timestamps: Union[str, List[float]] = "0",
        hallucination_silence_threshold: Optional[float] = None,
        hotwords: Optional[str] = None,
        language_detection_threshold: Optional[float] = 0.5,
        language_detection_segments: int = 1,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
  688. """Transcribes an input file.
  689. Arguments:
  690. audio: Path to the input file (or a file-like object), or the audio waveform.
  691. language: The language spoken in the audio. It should be a language code such
  692. as "en" or "fr". If not set, the language will be detected in the first 30 seconds
  693. of audio.
  694. task: Task to execute (transcribe or translate).
  695. log_progress: whether to show progress bar or not.
  696. beam_size: Beam size to use for decoding.
  697. best_of: Number of candidates when sampling with non-zero temperature.
  698. patience: Beam search patience factor.
  699. length_penalty: Exponential length penalty constant.
  700. repetition_penalty: Penalty applied to the score of previously generated tokens
  701. (set > 1 to penalize).
  702. no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
  703. temperature: Temperature for sampling. It can be a tuple of temperatures,
  704. which will be successively used upon failures according to either
  705. `compression_ratio_threshold` or `log_prob_threshold`.
  706. compression_ratio_threshold: If the gzip compression ratio is above this value,
  707. treat as failed.
  708. log_prob_threshold: If the average log probability over sampled tokens is
  709. below this value, treat as failed.
  710. no_speech_threshold: If the no_speech probability is higher than this value AND
  711. the average log probability over sampled tokens is below `log_prob_threshold`,
  712. consider the segment as silent.
  713. condition_on_previous_text: If True, the previous output of the model is provided
  714. as a prompt for the next window; disabling may make the text inconsistent across
  715. windows, but the model becomes less prone to getting stuck in a failure loop,
  716. such as repetition looping or timestamps going out of sync.
  717. prompt_reset_on_temperature: Resets prompt if temperature is above this value.
  718. Arg has effect only if condition_on_previous_text is True.
  719. initial_prompt: Optional text string or iterable of token ids to provide as a
  720. prompt for the first window.
  721. prefix: Optional text to provide as a prefix for the first window.
  722. suppress_blank: Suppress blank outputs at the beginning of the sampling.
  723. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
  724. of symbols as defined in `tokenizer.non_speech_tokens()`.
  725. without_timestamps: Only sample text tokens.
  726. max_initial_timestamp: The initial timestamp cannot be later than this.
  727. word_timestamps: Extract word-level timestamps using the cross-attention pattern
  728. and dynamic time warping, and include the timestamps for each word in each segment.
  729. prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
  730. with the next word
  731. append_punctuations: If word_timestamps is True, merge these punctuation symbols
  732. with the previous word
  733. multilingual: Perform language detection on every segment.
  734. vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
  735. without speech. This step is using the Silero VAD model
  736. https://github.com/snakers4/silero-vad.
  737. vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
  738. parameters and default values in the class `VadOptions`).
  739. max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
  740. the maximum will be set by the default max_length.
  741. chunk_length: The length of audio segments. If it is not None, it will overwrite the
  742. default chunk_length of the FeatureExtractor.
  743. clip_timestamps:
  744. Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
  745. process. The last end timestamp defaults to the end of the file.
  746. vad_filter will be ignored if clip_timestamps is used.
  747. hallucination_silence_threshold:
  748. When word_timestamps is True, skip silent periods longer than this threshold
  749. (in seconds) when a possible hallucination is detected
  750. hotwords:
  751. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
  752. language_detection_threshold: If the maximum probability of the language tokens is higher
  753. than this value, the language is detected.
  754. language_detection_segments: Number of segments to consider for the language detection.
  755. Returns:
  756. A tuple with:
  757. - a generator over transcribed segments
  758. - an instance of TranscriptionInfo
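
        Example:
          A minimal usage sketch; the model size, device, compute type, and audio file
          name below are illustrative, not prescriptive:

              model = WhisperModel("small", device="cpu", compute_type="int8")
              segments, info = model.transcribe("audio.wav", beam_size=5)
              print("Detected language '%s' (p=%.2f)" % (info.language, info.language_probability))
              for segment in segments:
                  print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))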
  759. """
  760. sampling_rate = self.feature_extractor.sampling_rate
  761. if multilingual and not self.model.is_multilingual:
  762. self.logger.warning(
  763. "The current model is English-only but the multilingual parameter is set to"
  764. "True; setting to False instead."
  765. )
  766. multilingual = False
  767. if not isinstance(audio, np.ndarray):
  768. audio = decode_audio(audio, sampling_rate=sampling_rate)
  769. duration = audio.shape[0] / sampling_rate
  770. duration_after_vad = duration
  771. self.logger.info(
  772. "Processing audio with duration %s", format_timestamp(duration)
  773. )
  774. if vad_filter and clip_timestamps == "0":
  775. if vad_parameters is None:
  776. vad_parameters = VadOptions()
  777. elif isinstance(vad_parameters, dict):
  778. vad_parameters = VadOptions(**vad_parameters)
  779. speech_chunks = get_speech_timestamps(audio, vad_parameters)
  780. audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
  781. audio = np.concatenate(audio_chunks, axis=0)
  782. duration_after_vad = audio.shape[0] / sampling_rate
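            # Segments are generated on the concatenated speech-only audio; their
            # timestamps are mapped back to the original timeline later via
            # restore_speech_timestamps().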
            self.logger.info(
                "VAD filter removed %s of audio",
                format_timestamp(duration - duration_after_vad),
            )

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "VAD filter kept the following audio segments: %s",
                    ", ".join(
                        "[%s -> %s]"
                        % (
                            format_timestamp(chunk["start"] / sampling_rate),
                            format_timestamp(chunk["end"] / sampling_rate),
                        )
                        for chunk in speech_chunks
                    ),
                )
        else:
            speech_chunks = None

        features = self.feature_extractor(audio, chunk_length=chunk_length)

        encoder_output = None
        all_language_probs = None

        # detecting the language if not provided
        if language is None:
            if not self.model.is_multilingual:
                language = "en"
                language_probability = 1
            else:
                start_timestamp = (
                    float(clip_timestamps.split(",")[0])
                    if isinstance(clip_timestamps, str)
                    else clip_timestamps[0]
                )
                content_frames = features.shape[-1] - 1
                seek = (
                    int(start_timestamp * self.frames_per_second)
                    if start_timestamp * self.frames_per_second < content_frames
                    else 0
                )
                (
                    language,
                    language_probability,
                    all_language_probs,
                ) = self.detect_language(
                    features=features[..., seek:],
                    language_detection_segments=language_detection_segments,
                    language_detection_threshold=language_detection_threshold,
                )

                self.logger.info(
                    "Detected language '%s' with probability %.2f",
                    language,
                    language_probability,
                )
        else:
            if not self.model.is_multilingual and language != "en":
                self.logger.warning(
                    "The current model is English-only but the language parameter is set to '%s'; "
                    "using 'en' instead." % language
                )
                language = "en"
            language_probability = 1

        tokenizer = Tokenizer(
            self.hf_tokenizer,
            self.model.is_multilingual,
            task=task,
            language=language,
        )
        options = TranscriptionOptions(
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            log_prob_threshold=log_prob_threshold,
            no_speech_threshold=no_speech_threshold,
            compression_ratio_threshold=compression_ratio_threshold,
            condition_on_previous_text=condition_on_previous_text,
            prompt_reset_on_temperature=prompt_reset_on_temperature,
            temperatures=(
                temperature if isinstance(temperature, (list, tuple)) else [temperature]
            ),
            initial_prompt=initial_prompt,
            prefix=prefix,
            suppress_blank=suppress_blank,
            suppress_tokens=(
                get_suppressed_tokens(tokenizer, suppress_tokens)
                if suppress_tokens
                else suppress_tokens
            ),
            without_timestamps=without_timestamps,
            max_initial_timestamp=max_initial_timestamp,
            word_timestamps=word_timestamps,
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
            multilingual=multilingual,
            max_new_tokens=max_new_tokens,
            clip_timestamps=clip_timestamps,
            hallucination_silence_threshold=hallucination_silence_threshold,
            hotwords=hotwords,
        )

        segments = self.generate_segments(
            features, tokenizer, options, log_progress, encoder_output
        )

        if speech_chunks:
            segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)

        info = TranscriptionInfo(
            language=language,
            language_probability=language_probability,
            duration=duration,
            duration_after_vad=duration_after_vad,
            transcription_options=options,
            vad_options=vad_parameters,
            all_language_probs=all_language_probs,
        )

        return segments, info
    def _split_segments_by_timestamps(
        self,
        tokenizer: Tokenizer,
        tokens: List[int],
        time_offset: float,
        segment_size: int,
        segment_duration: float,
        seek: int,
    ) -> Tuple[List[dict], int, bool]:
        current_segments = []
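        # The window ends with a lone timestamp token (the token before it is not a
        # timestamp), which means there is no speech after the last timestamp.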
        single_timestamp_ending = (
            len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
        )

        consecutive_timestamps = [
            i
            for i in range(len(tokens))
            if i > 0
            and tokens[i] >= tokenizer.timestamp_begin
            and tokens[i - 1] >= tokenizer.timestamp_begin
        ]

        if len(consecutive_timestamps) > 0:
            slices = list(consecutive_timestamps)
            if single_timestamp_ending:
                slices.append(len(tokens))
            last_slice = 0
            for current_slice in slices:
                sliced_tokens = tokens[last_slice:current_slice]
                start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin
                end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin
                start_time = (
                    time_offset + start_timestamp_position * self.time_precision
                )
                end_time = time_offset + end_timestamp_position * self.time_precision

                current_segments.append(
                    dict(
                        seek=seek,
                        start=start_time,
                        end=end_time,
                        tokens=sliced_tokens,
                    )
                )
                last_slice = current_slice

            if single_timestamp_ending:
                # single timestamp at the end means no speech after the last timestamp.
                seek += segment_size
            else:
                # otherwise, ignore the unfinished segment and seek to the last timestamp
                last_timestamp_position = (
                    tokens[last_slice - 1] - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * self.input_stride
        else:
            duration = segment_duration
            timestamps = [
                token for token in tokens if token >= tokenizer.timestamp_begin
            ]
            if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
                last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
                duration = last_timestamp_position * self.time_precision

            current_segments.append(
                dict(
                    seek=seek,
                    start=time_offset,
                    end=time_offset + duration,
                    tokens=tokens,
                )
            )
            seek += segment_size

        return current_segments, seek, single_timestamp_ending
    def generate_segments(
        self,
        features: np.ndarray,
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
        log_progress,
        encoder_output: Optional[ctranslate2.StorageView] = None,
    ) -> Iterable[Segment]:
        content_frames = features.shape[-1] - 1
        content_duration = float(content_frames * self.feature_extractor.time_per_frame)

        if isinstance(options.clip_timestamps, str):
            options.clip_timestamps = [
                float(ts)
                for ts in (
                    options.clip_timestamps.split(",")
                    if options.clip_timestamps
                    else []
                )
            ]
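        # Convert the clip timestamps from seconds to frame indices and pair them up as
        # (start, end) clips; a missing final end defaults to the end of the features.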
        seek_points: List[int] = [
            round(ts * self.frames_per_second) for ts in options.clip_timestamps
        ]
        if len(seek_points) == 0:
            seek_points.append(0)
        if len(seek_points) % 2 == 1:
            seek_points.append(content_frames)
        seek_clips: List[Tuple[int, int]] = list(
            zip(seek_points[::2], seek_points[1::2])
        )

        punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

        idx = 0
        clip_idx = 0
        seek = seek_clips[clip_idx][0]
        all_tokens = []
        prompt_reset_since = 0

        if options.initial_prompt is not None:
            if isinstance(options.initial_prompt, str):
                initial_prompt = " " + options.initial_prompt.strip()
                initial_prompt_tokens = tokenizer.encode(initial_prompt)
                all_tokens.extend(initial_prompt_tokens)
            else:
                all_tokens.extend(options.initial_prompt)

        pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress)
        last_speech_timestamp = 0.0
        # NOTE: This loop is obscurely flattened to make the diff readable.
        # A later commit should turn this into a simpler nested loop.
        # for seek_clip_start, seek_clip_end in seek_clips:
        #     while seek < seek_clip_end
        while clip_idx < len(seek_clips):
            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
            if seek_clip_end > content_frames:
                seek_clip_end = content_frames
            if seek < seek_clip_start:
                seek = seek_clip_start
            if seek >= seek_clip_end:
                clip_idx += 1
                if clip_idx < len(seek_clips):
                    seek = seek_clips[clip_idx][0]
                continue
            time_offset = seek * self.feature_extractor.time_per_frame
            window_end_time = float(
                (seek + self.feature_extractor.nb_max_frames)
                * self.feature_extractor.time_per_frame
            )
            segment_size = min(
                self.feature_extractor.nb_max_frames,
                content_frames - seek,
                seek_clip_end - seek,
            )
            segment = features[:, seek : seek + segment_size]
            segment_duration = segment_size * self.feature_extractor.time_per_frame
            segment = pad_or_trim(segment)

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "Processing segment at %s", format_timestamp(time_offset)
                )

            previous_tokens = all_tokens[prompt_reset_since:]

            if seek > 0 or encoder_output is None:
                encoder_output = self.encode(segment)

            if options.multilingual:
                results = self.model.detect_language(encoder_output)
                language_token, language_probability = results[0][0]
                language = language_token[2:-2]

                tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
                tokenizer.language_code = language

            prompt = self.get_prompt(
                tokenizer,
                previous_tokens,
                without_timestamps=options.without_timestamps,
                prefix=options.prefix if seek == 0 else None,
                hotwords=options.hotwords,
            )

            (
                result,
                avg_logprob,
                temperature,
                compression_ratio,
            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)

            if options.no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > options.no_speech_threshold

                if (
                    options.log_prob_threshold is not None
                    and avg_logprob > options.log_prob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    self.logger.debug(
                        "No speech threshold is met (%f > %f)",
                        result.no_speech_prob,
                        options.no_speech_threshold,
                    )

                    # fast-forward to the next segment boundary
                    seek += segment_size
                    continue

            tokens = result.sequences_ids[0]

            previous_seek = seek

            # anomalous words are very long/short/improbable
            def word_anomaly_score(word: dict) -> float:
                probability = word.get("probability", 0.0)
                duration = word["end"] - word["start"]
                score = 0.0
                if probability < 0.15:
                    score += 1.0
                if duration < 0.133:
                    score += (0.133 - duration) * 15
                if duration > 2.0:
                    score += duration - 2.0
                return score
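            # A segment looks like a hallucination when its first few non-punctuation
            # words are collectively improbable or have implausible durations.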
            def is_segment_anomaly(segment: Optional[dict]) -> bool:
                if segment is None or not segment["words"]:
                    return False
                words = [w for w in segment["words"] if w["word"] not in punctuation]
                words = words[:8]
                score = sum(word_anomaly_score(w) for w in words)
                return score >= 3 or score + 0.01 >= len(words)

            def next_words_segment(segments: List[dict]) -> Optional[dict]:
                return next((s for s in segments if s["words"]), None)

            (
                current_segments,
                seek,
                single_timestamp_ending,
            ) = self._split_segments_by_timestamps(
                tokenizer=tokenizer,
                tokens=tokens,
                time_offset=time_offset,
                segment_size=segment_size,
                segment_duration=segment_duration,
                seek=seek,
            )

            if options.word_timestamps:
                self.add_word_timestamps(
                    [current_segments],
                    tokenizer,
                    encoder_output,
                    segment_size,
                    options.prepend_punctuations,
                    options.append_punctuations,
                    last_speech_timestamp=last_speech_timestamp,
                )

                if not single_timestamp_ending:
                    last_word_end = get_end(current_segments)
                    if last_word_end is not None and last_word_end > time_offset:
                        seek = round(last_word_end * self.frames_per_second)

                # skip silence before possible hallucinations
                if options.hallucination_silence_threshold is not None:
                    threshold = options.hallucination_silence_threshold

                    # if first segment might be a hallucination, skip leading silence
                    first_segment = next_words_segment(current_segments)
                    if first_segment is not None and is_segment_anomaly(first_segment):
                        gap = first_segment["start"] - time_offset
                        if gap > threshold:
                            seek = previous_seek + round(gap * self.frames_per_second)
                            continue

                    # skip silence before any possible hallucination that is surrounded
                    # by silence or more hallucinations
                    hal_last_end = last_speech_timestamp
                    for si in range(len(current_segments)):
                        segment = current_segments[si]
                        if not segment["words"]:
                            continue
                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(
                                current_segments[si + 1 :]
                            )
                            if next_segment is not None:
                                hal_next_start = next_segment["words"][0]["start"]
                            else:
                                hal_next_start = time_offset + segment_duration
                            silence_before = (
                                segment["start"] - hal_last_end > threshold
                                or segment["start"] < threshold
                                or segment["start"] - time_offset < 2.0
                            )
                            silence_after = (
                                hal_next_start - segment["end"] > threshold
                                or is_segment_anomaly(next_segment)
                                or window_end_time - segment["end"] < 2.0
                            )
                            if silence_before and silence_after:
                                seek = round(
                                    max(time_offset + 1, segment["start"])
                                    * self.frames_per_second
                                )
                                if content_duration - segment["end"] < threshold:
                                    seek = content_frames
                                current_segments[si:] = []
                                break
                        hal_last_end = segment["end"]

                last_word_end = get_end(current_segments)
                if last_word_end is not None:
                    last_speech_timestamp = last_word_end
            for segment in current_segments:
                tokens = segment["tokens"]
                text = tokenizer.decode(tokens)

                if segment["start"] == segment["end"] or not text.strip():
                    continue

                all_tokens.extend(tokens)
                idx += 1

                yield Segment(
                    id=idx,
                    seek=previous_seek,
                    start=segment["start"],
                    end=segment["end"],
                    text=text,
                    tokens=tokens,
                    temperature=temperature,
                    avg_logprob=avg_logprob,
                    compression_ratio=compression_ratio,
                    no_speech_prob=result.no_speech_prob,
                    words=(
                        [Word(**word) for word in segment["words"]]
                        if options.word_timestamps
                        else None
                    ),
                )

            if (
                not options.condition_on_previous_text
                or temperature > options.prompt_reset_on_temperature
            ):
                if options.condition_on_previous_text:
                    self.logger.debug(
                        "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
                        temperature,
                        options.prompt_reset_on_temperature,
                    )

                prompt_reset_since = len(all_tokens)

            pbar.update(
                (min(content_frames, seek) - previous_seek)
                * self.feature_extractor.time_per_frame,
            )
        pbar.close()

    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1

        if features.ndim == 2:
            features = np.expand_dims(features, 0)
        features = get_ctranslate2_storage(features)

        return self.model.encode(features, to_cpu=to_cpu)

    def generate_with_fallback(
        self,
        encoder_output: ctranslate2.StorageView,
        prompt: List[int],
        tokenizer: Tokenizer,
        options: TranscriptionOptions,
    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
        decode_result = None
        all_results = []
        below_cr_threshold_results = []
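        # Convert the maximum initial timestamp from seconds to a timestamp-token index
        # (one timestamp token every self.time_precision seconds).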
  1238. max_initial_timestamp_index = int(
  1239. round(options.max_initial_timestamp / self.time_precision)
  1240. )

        if options.max_new_tokens is not None:
            max_length = len(prompt) + options.max_new_tokens
        else:
            max_length = self.max_length

        if max_length > self.max_length:
            raise ValueError(
                f"The length of the prompt is {len(prompt)}, and `max_new_tokens` is "
                f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
                f"and `max_new_tokens` is {max_length}. This exceeds the "
                f"`max_length` of the Whisper model: {self.max_length}. "
                "You should either reduce the length of your prompt, or "
                "reduce the value of `max_new_tokens`, "
                f"so that their combined length is less than {self.max_length}."
            )

        for temperature in options.temperatures:
            if temperature > 0:
                kwargs = {
                    "beam_size": 1,
                    "num_hypotheses": options.best_of,
                    "sampling_topk": 0,
                    "sampling_temperature": temperature,
                }
            else:
                kwargs = {
                    "beam_size": options.beam_size,
                    "patience": options.patience,
                }

            result = self.model.generate(
                encoder_output,
                [prompt],
                length_penalty=options.length_penalty,
                repetition_penalty=options.repetition_penalty,
                no_repeat_ngram_size=options.no_repeat_ngram_size,
                max_length=max_length,
                return_scores=True,
                return_no_speech_prob=True,
                suppress_blank=options.suppress_blank,
                suppress_tokens=options.suppress_tokens,
                max_initial_timestamp_index=max_initial_timestamp_index,
                **kwargs,
            )[0]

            tokens = result.sequences_ids[0]

            # Recover the average log prob from the returned score.
            seq_len = len(tokens)
            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
            avg_logprob = cum_logprob / (seq_len + 1)
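            # Worked example: with length_penalty=1.0, a returned score of -0.25 over
            # 20 tokens gives cum_logprob = -0.25 * 20 = -5.0 and
            # avg_logprob = -5.0 / 21 ≈ -0.24 (the +1 mirrors openai/whisper, which
            # counts the end-of-text token in the average).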

            text = tokenizer.decode(tokens).strip()
            compression_ratio = get_compression_ratio(text)

            decode_result = (
                result,
                avg_logprob,
                temperature,
                compression_ratio,
            )
            all_results.append(decode_result)

            needs_fallback = False

            if options.compression_ratio_threshold is not None:
                if compression_ratio > options.compression_ratio_threshold:
                    needs_fallback = True  # too repetitive

                    self.logger.debug(
                        "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
                        temperature,
                        compression_ratio,
                        options.compression_ratio_threshold,
                    )
                else:
                    below_cr_threshold_results.append(decode_result)

            if (
                options.log_prob_threshold is not None
                and avg_logprob < options.log_prob_threshold
            ):
                needs_fallback = True  # average log probability is too low

                self.logger.debug(
                    "Log probability threshold is not met with temperature %.1f (%f < %f)",
                    temperature,
                    avg_logprob,
                    options.log_prob_threshold,
                )

            if (
                options.no_speech_threshold is not None
                and result.no_speech_prob > options.no_speech_threshold
                and options.log_prob_threshold is not None
                and avg_logprob < options.log_prob_threshold
            ):
                needs_fallback = False  # silence

            if not needs_fallback:
                break
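
        # Note: the `else` below belongs to the `for` loop above; it runs only when
        # every temperature was tried without an early `break`, i.e. all candidates
        # failed the quality thresholds.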
        else:
            # all failed, select the result with the highest average log probability
            decode_result = max(
                below_cr_threshold_results or all_results, key=lambda x: x[1]
            )
            # to pass final temperature for prompt_reset_on_temperature
            decode_result = (
                decode_result[0],
                decode_result[1],
                temperature,
                decode_result[3],
            )

        return decode_result

    def get_prompt(
        self,
        tokenizer: Tokenizer,
        previous_tokens: List[int],
        without_timestamps: bool = False,
        prefix: Optional[str] = None,
        hotwords: Optional[str] = None,
    ) -> List[int]:
        prompt = []

        if previous_tokens or (hotwords and not prefix):
            prompt.append(tokenizer.sot_prev)
            if hotwords and not prefix:
                hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
                if len(hotwords_tokens) >= self.max_length // 2:
                    hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
                prompt.extend(hotwords_tokens)
            if previous_tokens:
                prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

        prompt.extend(tokenizer.sot_sequence)

        if without_timestamps:
            prompt.append(tokenizer.no_timestamps)

        if prefix:
            prefix_tokens = tokenizer.encode(" " + prefix.strip())
            if len(prefix_tokens) >= self.max_length // 2:
                prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
            if not without_timestamps:
                prompt.append(tokenizer.timestamp_begin)
            prompt.extend(prefix_tokens)

        return prompt
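
    # Rough layout of the returned prompt (a sketch; bracketed parts are optional):
    #   [<|startofprev|>  hotwords or previous-context tokens]
    #   <|startoftranscript|> <|language|> <|task|>
    #   [<|notimestamps|>] [<|0.00|>] [prefix tokens]
    # The previous context, hotwords, and prefix are each truncated to at most
    # max_length // 2 - 1 tokens.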

    def add_word_timestamps(
        self,
        segments: List[dict],
        tokenizer: Tokenizer,
        encoder_output: ctranslate2.StorageView,
        num_frames: int,
        prepend_punctuations: str,
        append_punctuations: str,
        last_speech_timestamp: float,
    ) -> float:
        if len(segments) == 0:
            return

        text_tokens = []
        text_tokens_per_segment = []
        for segment in segments:
            segment_tokens = [
                [token for token in subsegment["tokens"] if token < tokenizer.eot]
                for subsegment in segment
            ]
            text_tokens.append(list(itertools.chain.from_iterable(segment_tokens)))
            text_tokens_per_segment.append(segment_tokens)

        alignments = self.find_alignment(
            tokenizer, text_tokens, encoder_output, num_frames
        )
        median_max_durations = []
        for alignment in alignments:
            word_durations = np.array(
                [word["end"] - word["start"] for word in alignment]
            )
            word_durations = word_durations[word_durations.nonzero()]
            median_duration = (
                np.median(word_durations) if len(word_durations) > 0 else 0.0
            )
            median_duration = min(0.7, float(median_duration))
            max_duration = median_duration * 2
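            # e.g. nonzero word durations of [0.18, 0.22, 0.35, 1.40] seconds give a
            # median of 0.285s (capped at 0.7s), so max_duration below is 0.57s.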

            # hack: truncate long words at sentence boundaries.
            # a better segmentation algorithm based on VAD should be able to replace this.
            if len(word_durations) > 0:
                sentence_end_marks = ".。!!??"
                # ensure words at sentence boundaries
                # are not longer than twice the median word duration.
                for i in range(1, len(alignment)):
                    if alignment[i]["end"] - alignment[i]["start"] > max_duration:
                        if alignment[i]["word"] in sentence_end_marks:
                            alignment[i]["end"] = alignment[i]["start"] + max_duration
                        elif alignment[i - 1]["word"] in sentence_end_marks:
                            alignment[i]["start"] = alignment[i]["end"] - max_duration

            merge_punctuations(alignment, prepend_punctuations, append_punctuations)
            median_max_durations.append((median_duration, max_duration))

        for segment_idx, segment in enumerate(segments):
            word_index = 0
            time_offset = segment[0]["seek"] / self.frames_per_second
            median_duration, max_duration = median_max_durations[segment_idx]
            for subsegment_idx, subsegment in enumerate(segment):
                saved_tokens = 0
                words = []

                while word_index < len(alignments[segment_idx]) and saved_tokens < len(
                    text_tokens_per_segment[segment_idx][subsegment_idx]
                ):
                    timing = alignments[segment_idx][word_index]

                    if timing["word"]:
                        words.append(
                            dict(
                                word=timing["word"],
                                start=round(time_offset + timing["start"], 2),
                                end=round(time_offset + timing["end"], 2),
                                probability=timing["probability"],
                            )
                        )

                    saved_tokens += len(timing["tokens"])
                    word_index += 1

                # hack: truncate long words at segment boundaries.
                # a better segmentation algorithm based on VAD should be able to replace this.
                if len(words) > 0:
                    # ensure the first and second word after a pause is not longer than
                    # twice the median word duration.
                    if words[0][
                        "end"
                    ] - last_speech_timestamp > median_duration * 4 and (
                        words[0]["end"] - words[0]["start"] > max_duration
                        or (
                            len(words) > 1
                            and words[1]["end"] - words[0]["start"] > max_duration * 2
                        )
                    ):
                        if (
                            len(words) > 1
                            and words[1]["end"] - words[1]["start"] > max_duration
                        ):
                            boundary = max(
                                words[1]["end"] / 2, words[1]["end"] - max_duration
                            )
                            words[0]["end"] = words[1]["start"] = boundary
                        words[0]["start"] = max(0, words[0]["end"] - max_duration)

                    # prefer the segment-level start timestamp if the first word is too long.
                    if (
                        subsegment["start"] < words[0]["end"]
                        and subsegment["start"] - 0.5 > words[0]["start"]
                    ):
                        words[0]["start"] = max(
                            0,
                            min(words[0]["end"] - median_duration, subsegment["start"]),
                        )
                    else:
                        subsegment["start"] = words[0]["start"]

                    # prefer the segment-level end timestamp if the last word is too long.
                    if (
                        subsegment["end"] > words[-1]["start"]
                        and subsegment["end"] + 0.5 < words[-1]["end"]
                    ):
                        words[-1]["end"] = max(
                            words[-1]["start"] + median_duration, subsegment["end"]
                        )
                    else:
                        subsegment["end"] = words[-1]["end"]

                    last_speech_timestamp = subsegment["end"]
                segments[segment_idx][subsegment_idx]["words"] = words
        return last_speech_timestamp

    def find_alignment(
        self,
        tokenizer: Tokenizer,
        text_tokens: List[List[int]],
        encoder_output: ctranslate2.StorageView,
        num_frames: int,
        median_filter_width: int = 7,
    ) -> List[dict]:
        if len(text_tokens) == 0:
            return []

        results = self.model.align(
            encoder_output,
            tokenizer.sot_sequence,
            text_tokens,
            num_frames,
            median_filter_width=median_filter_width,
        )
        return_list = []
        for result, text_token in zip(results, text_tokens):
            text_token_probs = result.text_token_probs
            alignments = result.alignments
            text_indices = np.array([pair[0] for pair in alignments])
            time_indices = np.array([pair[1] for pair in alignments])

            words, word_tokens = tokenizer.split_to_word_tokens(
                text_token + [tokenizer.eot]
            )
            if len(word_tokens) <= 1:
                # return on eot only
                # >>> np.pad([], (1, 0))
                # array([0.])
                # This results in crashes when we lookup jump_times with float, like
                # IndexError: arrays used as indices must be of integer (or boolean) type
                return_list.append([])
                continue
            word_boundaries = np.pad(
                np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)
            )
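            # e.g. word token groups of lengths [2, 1, 3] (excluding the final eot group)
            # give word_boundaries = [0, 2, 3, 6]; consecutive pairs delimit each word's
            # tokens and index its start/end jump times below.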
            if len(word_boundaries) <= 1:
                return_list.append([])
                continue

            jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(
                bool
            )
            jump_times = time_indices[jumps] / self.tokens_per_second
            start_times = jump_times[word_boundaries[:-1]]
            end_times = jump_times[word_boundaries[1:]]
            word_probabilities = [
                np.mean(text_token_probs[i:j])
                for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
            ]

            return_list.append(
                [
                    dict(
                        word=word,
                        tokens=tokens,
                        start=start,
                        end=end,
                        probability=probability,
                    )
                    for word, tokens, start, end, probability in zip(
                        words, word_tokens, start_times, end_times, word_probabilities
                    )
                ]
            )
        return return_list

    def detect_language(
        self,
        audio: Optional[np.ndarray] = None,
        features: Optional[np.ndarray] = None,
        vad_filter: bool = False,
        vad_parameters: Union[dict, VadOptions] = None,
        language_detection_segments: int = 1,
        language_detection_threshold: float = 0.5,
    ) -> Tuple[str, float, List[Tuple[str, float]]]:
  1562. """
  1563. Use Whisper to detect the language of the input audio or features.
  1564. Arguments:
  1565. audio: Input audio signal, must be a 1D float array sampled at 16khz.
  1566. features: Input Mel spectrogram features, must be a float array with
  1567. shape (n_mels, n_frames), if `audio` is provided, the features will be ignored.
  1568. Either `audio` or `features` must be provided.
  1569. vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
  1570. without speech. This step is using the Silero VAD model.
  1571. vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
  1572. parameters and default values in the class `VadOptions`).
  1573. language_detection_threshold: If the maximum probability of the language tokens is
  1574. higher than this value, the language is detected.
  1575. language_detection_segments: Number of segments to consider for the language detection.
  1576. Returns:
  1577. language: Detected language.
  1578. languege_probability: Probability of the detected language.
  1579. all_language_probs: List of tuples with all language names and probabilities.
  1580. """
        assert (
            audio is not None or features is not None
        ), "Either `audio` or `features` must be provided."

        if audio is not None:
            if vad_filter:
                speech_chunks = get_speech_timestamps(audio, vad_parameters)
                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
                audio = np.concatenate(audio_chunks, axis=0)

            audio = audio[
                : language_detection_segments * self.feature_extractor.n_samples
            ]
            features = self.feature_extractor(audio)

        features = features[
            ..., : language_detection_segments * self.feature_extractor.nb_max_frames
        ]

        detected_language_info = {}
        for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
            encoder_output = self.encode(
                pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
            )
            # results is a list of tuple[str, float] with language names and probabilities.
            results = self.model.detect_language(encoder_output)[0]

            # Parse language names to strip out markers
            all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
            # Get top language token and probability
            language, language_probability = all_language_probs[0]
            if language_probability > language_detection_threshold:
                break
            detected_language_info.setdefault(language, []).append(language_probability)
        else:
            # If no language detected for all segments, the majority vote of the highest
            # projected languages for all segments is used to determine the language.
            language = max(
                detected_language_info,
                key=lambda lang: len(detected_language_info[lang]),
            )
            language_probability = max(detected_language_info[language])

        return language, language_probability, all_language_probs
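
    # Usage sketch (assumes `model` is a loaded WhisperModel and `audio` is a 16 kHz
    # float waveform; the names are illustrative):
    # >>> language, probability, all_probs = model.detect_language(audio)
    # With the defaults, only the first 30-second window is scored; raising
    # `language_detection_segments` scores more windows and falls back to a majority
    # vote when no window clears `language_detection_threshold`.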


def restore_speech_timestamps(
    segments: Iterable[Segment],
    speech_chunks: List[dict],
    sampling_rate: int,
) -> Iterable[Segment]:
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

    for segment in segments:
        if segment.words:
            words = []
            for word in segment.words:
                # Ensure the word start and end times are resolved to the same chunk.
                middle = (word.start + word.end) / 2
                chunk_index = ts_map.get_chunk_index(middle)
                word.start = ts_map.get_original_time(word.start, chunk_index)
                word.end = ts_map.get_original_time(word.end, chunk_index)
                words.append(word)

            segment.start = words[0].start
            segment.end = words[-1].end
            segment.words = words
        else:
            segment.start = ts_map.get_original_time(segment.start)
            segment.end = ts_map.get_original_time(segment.end)

        yield segment
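
# Usage sketch: after transcribing VAD-trimmed audio, remap the timestamps back to the
# original timeline (assumes `speech_chunks` comes from get_speech_timestamps and
# `sampling_rate` matches the audio that was fed to the VAD):
# >>> segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)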


def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
    segment = np.ascontiguousarray(segment)
    segment = ctranslate2.StorageView.from_array(segment)
    return segment


def get_compression_ratio(text: str) -> float:
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
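
# For example, a highly repetitive string such as "ha " * 100 compresses very well and
# yields a ratio far above the default compression_ratio_threshold of 2.4, while natural
# sentences typically stay below it; the ratio is used above to flag degenerate decodes.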


def get_suppressed_tokens(
    tokenizer: Tokenizer,
    suppress_tokens: Tuple[int],
) -> Tuple[int, ...]:
    if -1 in suppress_tokens:
        suppress_tokens = [t for t in suppress_tokens if t >= 0]
        suppress_tokens.extend(tokenizer.non_speech_tokens)
    elif suppress_tokens is None or len(suppress_tokens) == 0:
        suppress_tokens = []  # interpret empty string as an empty list
    else:
        assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

    suppress_tokens.extend(
        [
            tokenizer.transcribe,
            tokenizer.translate,
            tokenizer.sot,
            tokenizer.sot_prev,
            tokenizer.sot_lm,
        ]
    )

    return tuple(sorted(set(suppress_tokens)))
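
# For example, passing suppress_tokens=[-1] expands to the tokenizer's non-speech token
# ids plus the special task/start-of-text tokens added above, deduplicated and sorted;
# an empty list suppresses only those special tokens.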


def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
            # prepend it to the following word
            following["word"] = previous["word"] + following["word"]
            following["tokens"] = previous["tokens"] + following["tokens"]
            previous["word"] = ""
            previous["tokens"] = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous["word"].endswith(" ") and following["word"] in appended:
            # append it to the previous word
            previous["word"] = previous["word"] + following["word"]
            previous["tokens"] = previous["tokens"] + following["tokens"]
            following["word"] = ""
            following["tokens"] = []
        else:
            i = j
        j += 1
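
# Example (a sketch, assuming the default punctuation sets): an alignment whose words are
# [" ¿", "Qué", " pasa", ","] becomes [" ¿Qué", " pasa,"]; the absorbed entries are left
# with word="" and tokens=[] so that downstream code (e.g. add_word_timestamps) skips them.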