- import itertools
- import json
- import logging
- import os
- import zlib
- from dataclasses import asdict, dataclass
- from inspect import signature
- from math import ceil
- from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
- from warnings import warn
- import ctranslate2
- import numpy as np
- import tokenizers
- from tqdm import tqdm
- from faster_whisper.audio import decode_audio, pad_or_trim
- from faster_whisper.feature_extractor import FeatureExtractor
- from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
- from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
- from faster_whisper.vad import (
- SpeechTimestampsMap,
- VadOptions,
- collect_chunks,
- get_speech_timestamps,
- merge_segments,
- )
- @dataclass
- class Word:
- start: float
- end: float
- word: str
- probability: float
- def _asdict(self):
- warn(
- "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead",
- DeprecationWarning,
- 2,
- )
- return asdict(self)
- @dataclass
- class Segment:
- id: int
- seek: int
- start: float
- end: float
- text: str
- tokens: List[int]
- avg_logprob: float
- compression_ratio: float
- no_speech_prob: float
- words: Optional[List[Word]]
- temperature: Optional[float]
- def _asdict(self):
- warn(
- "Segment._asdict() method is deprecated, use dataclasses.asdict(Segment) instead",
- DeprecationWarning,
- 2,
- )
- return asdict(self)
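- # Migration sketch (illustrative): the deprecated `_asdict()` helpers map
- # directly onto the standard library, e.g.
- # >>> from dataclasses import asdict
- # >>> word = Word(start=0.0, end=0.5, word=" hello", probability=0.9)
- # >>> asdict(word)
- # {'start': 0.0, 'end': 0.5, 'word': ' hello', 'probability': 0.9}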
- @dataclass
- class TranscriptionOptions:
- beam_size: int
- best_of: int
- patience: float
- length_penalty: float
- repetition_penalty: float
- no_repeat_ngram_size: int
- log_prob_threshold: Optional[float]
- no_speech_threshold: Optional[float]
- compression_ratio_threshold: Optional[float]
- condition_on_previous_text: bool
- prompt_reset_on_temperature: float
- temperatures: List[float]
- initial_prompt: Optional[Union[str, Iterable[int]]]
- prefix: Optional[str]
- suppress_blank: bool
- suppress_tokens: Optional[List[int]]
- without_timestamps: bool
- max_initial_timestamp: float
- word_timestamps: bool
- prepend_punctuations: str
- append_punctuations: str
- multilingual: bool
- max_new_tokens: Optional[int]
- clip_timestamps: Union[str, List[float]]
- hallucination_silence_threshold: Optional[float]
- hotwords: Optional[str]
- @dataclass
- class TranscriptionInfo:
- language: str
- language_probability: float
- duration: float
- duration_after_vad: float
- all_language_probs: Optional[List[Tuple[str, float]]]
- transcription_options: TranscriptionOptions
- vad_options: VadOptions
- class BatchedInferencePipeline:
- def __init__(
- self,
- model,
- ):
- self.model: WhisperModel = model
- self.last_speech_timestamp = 0.0
- def forward(self, features, tokenizer, chunks_metadata, options):
- encoder_output, outputs = self.generate_segment_batched(
- features, tokenizer, options
- )
- segmented_outputs = []
- segment_sizes = []
- for chunk_metadata, output in zip(chunks_metadata, outputs):
- duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
- segment_size = int(ceil(duration) * self.model.frames_per_second)
- segment_sizes.append(segment_size)
- (
- subsegments,
- seek,
- single_timestamp_ending,
- ) = self.model._split_segments_by_timestamps(
- tokenizer=tokenizer,
- tokens=output["tokens"],
- time_offset=chunk_metadata["start_time"],
- segment_size=segment_size,
- segment_duration=duration,
- seek=0,
- )
- segmented_outputs.append(
- [
- dict(
- text=tokenizer.decode(subsegment["tokens"]),
- avg_logprob=output["avg_logprob"],
- no_speech_prob=output["no_speech_prob"],
- tokens=subsegment["tokens"],
- start=subsegment["start"],
- end=subsegment["end"],
- compression_ratio=get_compression_ratio(
- tokenizer.decode(subsegment["tokens"])
- ),
- seek=int(
- chunk_metadata["start_time"] * self.model.frames_per_second
- ),
- )
- for subsegment in subsegments
- ]
- )
- if options.word_timestamps:
- self.last_speech_timestamp = self.model.add_word_timestamps(
- segmented_outputs,
- tokenizer,
- encoder_output,
- segment_sizes,
- options.prepend_punctuations,
- options.append_punctuations,
- self.last_speech_timestamp,
- )
- return segmented_outputs
- def generate_segment_batched(
- self,
- features: np.ndarray,
- tokenizer: Tokenizer,
- options: TranscriptionOptions,
- ):
- batch_size = features.shape[0]
- prompt = self.model.get_prompt(
- tokenizer,
- previous_tokens=(
- tokenizer.encode(options.initial_prompt)
- if options.initial_prompt is not None
- else []
- ),
- without_timestamps=options.without_timestamps,
- hotwords=options.hotwords,
- )
- if options.max_new_tokens is not None:
- max_length = len(prompt) + options.max_new_tokens
- else:
- max_length = self.model.max_length
- if max_length > self.model.max_length:
- raise ValueError(
- f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
- f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
- f"and `max_new_tokens` is: {max_length}. This exceeds the "
- f"`max_length` of the Whisper model: {self.model.max_length}. "
- "You should either reduce the length of your prompt, or "
- "reduce the value of `max_new_tokens`, "
- f"so that their combined length is less that {self.model.max_length}."
- )
- encoder_output = self.model.encode(features)
- prompts = [prompt.copy() for _ in range(batch_size)]
- if options.multilingual:
- language_tokens = [
- tokenizer.tokenizer.token_to_id(segment_langs[0][0])
- for segment_langs in self.model.model.detect_language(encoder_output)
- ]
- language_token_index = prompt.index(tokenizer.language)
- for i, language_token in enumerate(language_tokens):
- prompts[i][language_token_index] = language_token
- results = self.model.model.generate(
- encoder_output,
- prompts,
- beam_size=options.beam_size,
- patience=options.patience,
- length_penalty=options.length_penalty,
- max_length=max_length,
- suppress_blank=options.suppress_blank,
- suppress_tokens=options.suppress_tokens,
- return_scores=True,
- return_no_speech_prob=True,
- sampling_temperature=options.temperatures[0],
- repetition_penalty=options.repetition_penalty,
- no_repeat_ngram_size=options.no_repeat_ngram_size,
- )
- output = []
- for result in results:
- # CTranslate2 returns length-normalized scores; recover the average log prob.
- seq_len = len(result.sequences_ids[0])
- cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
- output.append(
- dict(
- avg_logprob=cum_logprob / (seq_len + 1),
- no_speech_prob=result.no_speech_prob,
- tokens=result.sequences_ids[0],
- )
- )
- return encoder_output, output
- def transcribe(
- self,
- audio: Union[str, BinaryIO, np.ndarray],
- language: Optional[str] = None,
- task: str = "transcribe",
- log_progress: bool = False,
- beam_size: int = 5,
- best_of: int = 5,
- patience: float = 1,
- length_penalty: float = 1,
- repetition_penalty: float = 1,
- no_repeat_ngram_size: int = 0,
- temperature: Union[float, List[float], Tuple[float, ...]] = [
- 0.0,
- 0.2,
- 0.4,
- 0.6,
- 0.8,
- 1.0,
- ],
- compression_ratio_threshold: Optional[float] = 2.4,
- log_prob_threshold: Optional[float] = -1.0,
- no_speech_threshold: Optional[float] = 0.6,
- condition_on_previous_text: bool = True,
- prompt_reset_on_temperature: float = 0.5,
- initial_prompt: Optional[Union[str, Iterable[int]]] = None,
- prefix: Optional[str] = None,
- suppress_blank: bool = True,
- suppress_tokens: Optional[List[int]] = [-1],
- without_timestamps: bool = True,
- max_initial_timestamp: float = 1.0,
- word_timestamps: bool = False,
- prepend_punctuations: str = "\"'“¿([{-",
- append_punctuations: str = "\"'.。,,!!??::”)]}、",
- multilingual: bool = False,
- vad_filter: bool = True,
- vad_parameters: Optional[Union[dict, VadOptions]] = None,
- max_new_tokens: Optional[int] = None,
- chunk_length: Optional[int] = None,
- clip_timestamps: Optional[List[dict]] = None,
- hallucination_silence_threshold: Optional[float] = None,
- batch_size: int = 8,
- hotwords: Optional[str] = None,
- language_detection_threshold: Optional[float] = 0.5,
- language_detection_segments: int = 1,
- ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
- """transcribe audio in chunks in batched fashion and return with language info.
- Arguments:
- audio: Path to the input file (or a file-like object), or the audio waveform.
- language: The language spoken in the audio. It should be a language code such
- as "en" or "fr". If not set, the language will be detected in the first 30 seconds
- of audio.
- task: Task to execute (transcribe or translate).
- log_progress: Whether to show a progress bar.
- beam_size: Beam size to use for decoding.
- best_of: Number of candidates when sampling with non-zero temperature.
- patience: Beam search patience factor.
- length_penalty: Exponential length penalty constant.
- repetition_penalty: Penalty applied to the score of previously generated tokens
- (set > 1 to penalize).
- no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
- temperature: Temperature for sampling. If a list or tuple is passed,
- only the first value is used.
- initial_prompt: Optional text string or iterable of token ids to provide as a
- prompt for each window.
- suppress_blank: Suppress blank outputs at the beginning of the sampling.
- suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
- of symbols as defined in `tokenizer.non_speech_tokens()`.
- without_timestamps: Only sample text tokens.
- word_timestamps: Extract word-level timestamps using the cross-attention pattern
- and dynamic time warping, and include the timestamps for each word in each segment.
- Disabled by default.
- prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
- with the next word.
- append_punctuations: If word_timestamps is True, merge these punctuation symbols
- with the previous word.
- multilingual: Perform language detection on every segment.
- vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
- without speech. This step uses the Silero VAD model
- https://github.com/snakers4/silero-vad.
- vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
- parameters and default values in the class `VadOptions`).
- max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
- the maximum will be set by the default max_length.
- chunk_length: The length of audio segments. If it is not None, it will override the
- default chunk_length of the FeatureExtractor.
- clip_timestamps: Optionally provide list of dictionaries each containing "start" and
- "end" keys that specify the start and end of the voiced region within
- `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
- batch_size: The maximum number of parallel requests to the model for decoding.
- hotwords:
- Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
- language_detection_threshold: If the maximum probability of the language tokens is
- higher than this value, the language is detected.
- language_detection_segments: Number of segments to consider for the language detection.
- Unused arguments (accepted for API compatibility; fixed internally by the batched pipeline):
- compression_ratio_threshold: If the gzip compression ratio is above this value,
- treat as failed.
- log_prob_threshold: If the average log probability over sampled tokens is
- below this value, treat as failed.
- no_speech_threshold: If the no_speech probability is higher than this value AND
- the average log probability over sampled tokens is below `log_prob_threshold`,
- consider the segment as silent.
- condition_on_previous_text: If True, the previous output of the model is provided
- as a prompt for the next window; disabling may make the text inconsistent across
- windows, but the model becomes less prone to getting stuck in a failure loop,
- such as repetition looping or timestamps going out of sync. Fixed to False here.
- prompt_reset_on_temperature: Resets prompt if temperature is above this value.
- Arg has effect only if condition_on_previous_text is True. Fixed to 0.5 here.
- prefix: Optional text to provide as a prefix at the beginning of each window.
- max_initial_timestamp: The initial timestamp cannot be later than this. Fixed to 0.0 here.
- hallucination_silence_threshold:
- When word_timestamps is True, skip silent periods longer than this threshold
- (in seconds) when a possible hallucination is detected. Fixed to None here.
- Returns:
- A tuple with:
- - a generator over transcribed segments
- - an instance of TranscriptionInfo
- """
- sampling_rate = self.model.feature_extractor.sampling_rate
- if multilingual and not self.model.model.is_multilingual:
- self.model.logger.warning(
- "The current model is English-only but the multilingual parameter is set to"
- "True; setting to False instead."
- )
- multilingual = False
- if not isinstance(audio, np.ndarray):
- audio = decode_audio(audio, sampling_rate=sampling_rate)
- duration = audio.shape[0] / sampling_rate
- self.model.logger.info(
- "Processing audio with duration %s", format_timestamp(duration)
- )
- chunk_length = chunk_length or self.model.feature_extractor.chunk_length
- # if no segment split is provided, use vad_model and generate segments
- if not clip_timestamps:
- if vad_filter:
- if vad_parameters is None:
- vad_parameters = VadOptions(
- max_speech_duration_s=chunk_length,
- min_silence_duration_ms=160,
- )
- elif isinstance(vad_parameters, dict):
- if "max_speech_duration_s" in vad_parameters.keys():
- vad_parameters.pop("max_speech_duration_s")
- vad_parameters = VadOptions(
- **vad_parameters, max_speech_duration_s=chunk_length
- )
- active_segments = get_speech_timestamps(audio, vad_parameters)
- clip_timestamps = merge_segments(active_segments, vad_parameters)
- # process audio shorter than one chunk even without clip_timestamps
- elif duration < chunk_length:
- clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
- else:
- raise RuntimeError(
- "No clip timestamps found. "
- "Set 'vad_filter' to True or provide 'clip_timestamps'."
- )
- duration_after_vad = (
- sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
- / sampling_rate
- )
- self.model.logger.info(
- "VAD filter removed %s of audio",
- format_timestamp(duration - duration_after_vad),
- )
- audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
- features = (
- [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
- if duration_after_vad
- else []
- )
- all_language_probs = None
- # detecting the language if not provided
- if language is None:
- if not self.model.model.is_multilingual:
- language = "en"
- language_probability = 1
- else:
- (
- language,
- language_probability,
- all_language_probs,
- ) = self.model.detect_language(
- features=np.concatenate(
- features
- + [
- np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
- ],
- axis=1,
- ), # add a dummy feature to account for empty audio
- language_detection_segments=language_detection_segments,
- language_detection_threshold=language_detection_threshold,
- )
- self.model.logger.info(
- "Detected language '%s' with probability %.2f",
- language,
- language_probability,
- )
- else:
- if not self.model.model.is_multilingual and language != "en":
- self.model.logger.warning(
- "The current model is English-only but the language parameter is set to '%s'; "
- "using 'en' instead." % language
- )
- language = "en"
- language_probability = 1
- tokenizer = Tokenizer(
- self.model.hf_tokenizer,
- self.model.model.is_multilingual,
- task=task,
- language=language,
- )
- features = (
- np.stack([pad_or_trim(feature) for feature in features]) if features else []
- )
- options = TranscriptionOptions(
- beam_size=beam_size,
- best_of=best_of,
- patience=patience,
- length_penalty=length_penalty,
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- log_prob_threshold=log_prob_threshold,
- no_speech_threshold=no_speech_threshold,
- compression_ratio_threshold=compression_ratio_threshold,
- temperatures=(
- temperature[:1]
- if isinstance(temperature, (list, tuple))
- else [temperature]
- ),
- initial_prompt=initial_prompt,
- prefix=prefix,
- suppress_blank=suppress_blank,
- suppress_tokens=(
- get_suppressed_tokens(tokenizer, suppress_tokens)
- if suppress_tokens
- else suppress_tokens
- ),
- prepend_punctuations=prepend_punctuations,
- append_punctuations=append_punctuations,
- max_new_tokens=max_new_tokens,
- hotwords=hotwords,
- word_timestamps=word_timestamps,
- hallucination_silence_threshold=None,
- condition_on_previous_text=False,
- clip_timestamps=clip_timestamps,
- prompt_reset_on_temperature=0.5,
- multilingual=multilingual,
- without_timestamps=without_timestamps,
- max_initial_timestamp=0.0,
- )
- info = TranscriptionInfo(
- language=language,
- language_probability=language_probability,
- duration=duration,
- duration_after_vad=duration_after_vad,
- transcription_options=options,
- vad_options=vad_parameters,
- all_language_probs=all_language_probs,
- )
- segments = self._batched_segments_generator(
- features,
- tokenizer,
- chunks_metadata,
- batch_size,
- options,
- log_progress,
- )
- return segments, info
- def _batched_segments_generator(
- self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
- ):
- pbar = tqdm(total=len(features), disable=not log_progress, position=0)
- seg_idx = 0
- for i in range(0, len(features), batch_size):
- results = self.forward(
- features[i : i + batch_size],
- tokenizer,
- chunks_metadata[i : i + batch_size],
- options,
- )
- for result in results:
- for segment in result:
- seg_idx += 1
- yield Segment(
- seek=segment["seek"],
- id=seg_idx,
- text=segment["text"],
- start=round(segment["start"], 3),
- end=round(segment["end"], 3),
- words=(
- None
- if not options.word_timestamps
- else [Word(**word) for word in segment["words"]]
- ),
- tokens=segment["tokens"],
- avg_logprob=segment["avg_logprob"],
- no_speech_prob=segment["no_speech_prob"],
- compression_ratio=segment["compression_ratio"],
- temperature=options.temperatures[0],
- )
- pbar.update(1)
- pbar.close()
- self.last_speech_timestamp = 0.0
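- # Usage sketch (illustrative; assumes a converted "tiny" model is available):
- # >>> model = WhisperModel("tiny")
- # >>> pipeline = BatchedInferencePipeline(model)
- # >>> segments, info = pipeline.transcribe("audio.wav", batch_size=8)
- # >>> for segment in segments:  # `segments` is a lazy generator
- # ...     print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))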
- class WhisperModel:
- def __init__(
- self,
- model_size_or_path: str,
- device: str = "auto",
- device_index: Union[int, List[int]] = 0,
- compute_type: str = "default",
- cpu_threads: int = 0,
- num_workers: int = 1,
- download_root: Optional[str] = None,
- local_files_only: bool = False,
- files: dict = None,
- revision: Optional[str] = None,
- **model_kwargs,
- ):
- """Initializes the Whisper model.
- Args:
- model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
- small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
- large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo),
- a path to a converted model directory, or a CTranslate2-converted Whisper model ID from
- the HF Hub. When a size or a model ID is configured, the converted model is downloaded
- from the Hugging Face Hub.
- device: Device to use for computation ("cpu", "cuda", "auto").
- device_index: Device ID to use.
- The model can also be loaded on multiple GPUs by passing a list of IDs
- (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
- when transcribe() is called from multiple Python threads (see also num_workers).
- compute_type: Type to use for computation.
- See https://opennmt.net/CTranslate2/quantization.html.
- cpu_threads: Number of threads to use when running on CPU (4 by default).
- A non-zero value overrides the OMP_NUM_THREADS environment variable.
- num_workers: When transcribe() is called from multiple Python threads,
- having multiple workers enables true parallelism when running the model
- (concurrent calls to self.model.generate() will run in parallel).
- This can improve the global throughput at the cost of increased memory usage.
- download_root: Directory where the models should be saved. If not set, the models
- are saved in the standard Hugging Face cache directory.
- local_files_only: If True, avoid downloading the file and return the path to the
- local cached file if it exists.
- files: Load model files from memory. This argument is a dictionary mapping file names
- to file contents as file-like or bytes objects. If this is set, model_size_or_path
- acts as an identifier for this model.
- revision:
- An optional Git revision id which can be a branch name, a tag, or a
- commit hash.
- """
- self.logger = get_logger()
- tokenizer_bytes, preprocessor_bytes = None, None
- if files:
- model_path = model_size_or_path
- tokenizer_bytes = files.pop("tokenizer.json", None)
- preprocessor_bytes = files.pop("preprocessor_config.json", None)
- elif os.path.isdir(model_size_or_path):
- model_path = model_size_or_path
- else:
- model_path = download_model(
- model_size_or_path,
- local_files_only=local_files_only,
- cache_dir=download_root,
- revision=revision,
- )
- self.model = ctranslate2.models.Whisper(
- model_path,
- device=device,
- device_index=device_index,
- compute_type=compute_type,
- intra_threads=cpu_threads,
- inter_threads=num_workers,
- files=files,
- **model_kwargs,
- )
- tokenizer_file = os.path.join(model_path, "tokenizer.json")
- if tokenizer_bytes:
- self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
- elif os.path.isfile(tokenizer_file):
- self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
- else:
- self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
- "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
- )
- self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
- self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
- self.input_stride = 2
- self.num_samples_per_token = (
- self.feature_extractor.hop_length * self.input_stride
- )
- self.frames_per_second = (
- self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
- )
- self.tokens_per_second = (
- self.feature_extractor.sampling_rate // self.num_samples_per_token
- )
- self.time_precision = 0.02
- self.max_length = 448
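- # With the default 16 kHz feature extractor (hop_length=160), the derived
- # constants are frames_per_second=100, num_samples_per_token=320, and
- # tokens_per_second=50, matching a time_precision of 0.02 s per timestamp token.
- #
- # Construction sketch (illustrative; model files are fetched from the Hugging
- # Face Hub unless a local directory is passed):
- # >>> model = WhisperModel("large-v3", device="cuda", compute_type="float16")
- # >>> model = WhisperModel("/path/to/converted/model", device="cpu", cpu_threads=4)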
- @property
- def supported_languages(self) -> List[str]:
- """The languages supported by the model."""
- return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
- def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
- config = {}
- try:
- config_path = os.path.join(model_path, "preprocessor_config.json")
- if preprocessor_bytes:
- config = json.loads(preprocessor_bytes)
- elif os.path.isfile(config_path):
- with open(config_path, "r", encoding="utf-8") as file:
- config = json.load(file)
- else:
- return config
- valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
- return {k: v for k, v in config.items() if k in valid_keys}
- except json.JSONDecodeError as e:
- self.logger.warning("Could not load preprocessor config: %s", e)
- return config
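- # For reference, a preprocessor_config.json is plain JSON; only the keys that
- # FeatureExtractor.__init__ accepts are kept, e.g. (illustrative):
- # {
- #     "feature_size": 80,
- #     "hop_length": 160,
- #     "n_fft": 400,
- #     "sampling_rate": 16000
- # }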
- def transcribe(
- self,
- audio: Union[str, BinaryIO, np.ndarray],
- language: Optional[str] = None,
- task: str = "transcribe",
- log_progress: bool = False,
- beam_size: int = 5,
- best_of: int = 5,
- patience: float = 1,
- length_penalty: float = 1,
- repetition_penalty: float = 1,
- no_repeat_ngram_size: int = 0,
- temperature: Union[float, List[float], Tuple[float, ...]] = [
- 0.0,
- 0.2,
- 0.4,
- 0.6,
- 0.8,
- 1.0,
- ],
- compression_ratio_threshold: Optional[float] = 2.4,
- log_prob_threshold: Optional[float] = -1.0,
- no_speech_threshold: Optional[float] = 0.6,
- condition_on_previous_text: bool = True,
- prompt_reset_on_temperature: float = 0.5,
- initial_prompt: Optional[Union[str, Iterable[int]]] = None,
- prefix: Optional[str] = None,
- suppress_blank: bool = True,
- suppress_tokens: Optional[List[int]] = [-1],
- without_timestamps: bool = False,
- max_initial_timestamp: float = 1.0,
- word_timestamps: bool = False,
- prepend_punctuations: str = "\"'“¿([{-",
- append_punctuations: str = "\"'.。,,!!??::”)]}、",
- multilingual: bool = False,
- vad_filter: bool = False,
- vad_parameters: Optional[Union[dict, VadOptions]] = None,
- max_new_tokens: Optional[int] = None,
- chunk_length: Optional[int] = None,
- clip_timestamps: Union[str, List[float]] = "0",
- hallucination_silence_threshold: Optional[float] = None,
- hotwords: Optional[str] = None,
- language_detection_threshold: Optional[float] = 0.5,
- language_detection_segments: int = 1,
- ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
- """Transcribes an input file.
- Arguments:
- audio: Path to the input file (or a file-like object), or the audio waveform.
- language: The language spoken in the audio. It should be a language code such
- as "en" or "fr". If not set, the language will be detected in the first 30 seconds
- of audio.
- task: Task to execute (transcribe or translate).
- log_progress: Whether to show a progress bar.
- beam_size: Beam size to use for decoding.
- best_of: Number of candidates when sampling with non-zero temperature.
- patience: Beam search patience factor.
- length_penalty: Exponential length penalty constant.
- repetition_penalty: Penalty applied to the score of previously generated tokens
- (set > 1 to penalize).
- no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
- temperature: Temperature for sampling. It can be a tuple of temperatures,
- which will be successively used upon failures according to either
- `compression_ratio_threshold` or `log_prob_threshold`.
- compression_ratio_threshold: If the gzip compression ratio is above this value,
- treat as failed.
- log_prob_threshold: If the average log probability over sampled tokens is
- below this value, treat as failed.
- no_speech_threshold: If the no_speech probability is higher than this value AND
- the average log probability over sampled tokens is below `log_prob_threshold`,
- consider the segment as silent.
- condition_on_previous_text: If True, the previous output of the model is provided
- as a prompt for the next window; disabling may make the text inconsistent across
- windows, but the model becomes less prone to getting stuck in a failure loop,
- such as repetition looping or timestamps going out of sync.
- prompt_reset_on_temperature: Resets prompt if temperature is above this value.
- Arg has effect only if condition_on_previous_text is True.
- initial_prompt: Optional text string or iterable of token ids to provide as a
- prompt for the first window.
- prefix: Optional text to provide as a prefix for the first window.
- suppress_blank: Suppress blank outputs at the beginning of the sampling.
- suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
- of symbols as defined in `tokenizer.non_speech_tokens()`.
- without_timestamps: Only sample text tokens.
- max_initial_timestamp: The initial timestamp cannot be later than this.
- word_timestamps: Extract word-level timestamps using the cross-attention pattern
- and dynamic time warping, and include the timestamps for each word in each segment.
- prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
- with the next word.
- append_punctuations: If word_timestamps is True, merge these punctuation symbols
- with the previous word.
- multilingual: Perform language detection on every segment.
- vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
- without speech. This step uses the Silero VAD model
- https://github.com/snakers4/silero-vad.
- vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
- parameters and default values in the class `VadOptions`).
- max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
- the maximum will be set by the default max_length.
- chunk_length: The length of audio segments. If it is not None, it will override the
- default chunk_length of the FeatureExtractor.
- clip_timestamps:
- Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
- process. The last end timestamp defaults to the end of the file.
- vad_filter will be ignored if clip_timestamps is used.
- hallucination_silence_threshold:
- When word_timestamps is True, skip silent periods longer than this threshold
- (in seconds) when a possible hallucination is detected.
- hotwords:
- Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
- language_detection_threshold: If the maximum probability of the language tokens is higher
- than this value, the language is detected.
- language_detection_segments: Number of segments to consider for the language detection.
- Returns:
- A tuple with:
- - a generator over transcribed segments
- - an instance of TranscriptionInfo
- """
- sampling_rate = self.feature_extractor.sampling_rate
- if multilingual and not self.model.is_multilingual:
- self.logger.warning(
- "The current model is English-only but the multilingual parameter is set to"
- "True; setting to False instead."
- )
- multilingual = False
- if not isinstance(audio, np.ndarray):
- audio = decode_audio(audio, sampling_rate=sampling_rate)
- duration = audio.shape[0] / sampling_rate
- duration_after_vad = duration
- self.logger.info(
- "Processing audio with duration %s", format_timestamp(duration)
- )
- if vad_filter and clip_timestamps == "0":
- if vad_parameters is None:
- vad_parameters = VadOptions()
- elif isinstance(vad_parameters, dict):
- vad_parameters = VadOptions(**vad_parameters)
- speech_chunks = get_speech_timestamps(audio, vad_parameters)
- audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
- audio = np.concatenate(audio_chunks, axis=0)
- duration_after_vad = audio.shape[0] / sampling_rate
- self.logger.info(
- "VAD filter removed %s of audio",
- format_timestamp(duration - duration_after_vad),
- )
- if self.logger.isEnabledFor(logging.DEBUG):
- self.logger.debug(
- "VAD filter kept the following audio segments: %s",
- ", ".join(
- "[%s -> %s]"
- % (
- format_timestamp(chunk["start"] / sampling_rate),
- format_timestamp(chunk["end"] / sampling_rate),
- )
- for chunk in speech_chunks
- ),
- )
- else:
- speech_chunks = None
- features = self.feature_extractor(audio, chunk_length=chunk_length)
- encoder_output = None
- all_language_probs = None
- # detecting the language if not provided
- if language is None:
- if not self.model.is_multilingual:
- language = "en"
- language_probability = 1
- else:
- start_timestamp = (
- float(clip_timestamps.split(",")[0])
- if isinstance(clip_timestamps, str)
- else clip_timestamps[0]
- )
- content_frames = features.shape[-1] - 1
- seek = (
- int(start_timestamp * self.frames_per_second)
- if start_timestamp * self.frames_per_second < content_frames
- else 0
- )
- (
- language,
- language_probability,
- all_language_probs,
- ) = self.detect_language(
- features=features[..., seek:],
- language_detection_segments=language_detection_segments,
- language_detection_threshold=language_detection_threshold,
- )
- self.logger.info(
- "Detected language '%s' with probability %.2f",
- language,
- language_probability,
- )
- else:
- if not self.model.is_multilingual and language != "en":
- self.logger.warning(
- "The current model is English-only but the language parameter is set to '%s'; "
- "using 'en' instead." % language
- )
- language = "en"
- language_probability = 1
- tokenizer = Tokenizer(
- self.hf_tokenizer,
- self.model.is_multilingual,
- task=task,
- language=language,
- )
- options = TranscriptionOptions(
- beam_size=beam_size,
- best_of=best_of,
- patience=patience,
- length_penalty=length_penalty,
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- log_prob_threshold=log_prob_threshold,
- no_speech_threshold=no_speech_threshold,
- compression_ratio_threshold=compression_ratio_threshold,
- condition_on_previous_text=condition_on_previous_text,
- prompt_reset_on_temperature=prompt_reset_on_temperature,
- temperatures=(
- temperature if isinstance(temperature, (list, tuple)) else [temperature]
- ),
- initial_prompt=initial_prompt,
- prefix=prefix,
- suppress_blank=suppress_blank,
- suppress_tokens=(
- get_suppressed_tokens(tokenizer, suppress_tokens)
- if suppress_tokens
- else suppress_tokens
- ),
- without_timestamps=without_timestamps,
- max_initial_timestamp=max_initial_timestamp,
- word_timestamps=word_timestamps,
- prepend_punctuations=prepend_punctuations,
- append_punctuations=append_punctuations,
- multilingual=multilingual,
- max_new_tokens=max_new_tokens,
- clip_timestamps=clip_timestamps,
- hallucination_silence_threshold=hallucination_silence_threshold,
- hotwords=hotwords,
- )
- segments = self.generate_segments(
- features, tokenizer, options, log_progress, encoder_output
- )
- if speech_chunks:
- segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
- info = TranscriptionInfo(
- language=language,
- language_probability=language_probability,
- duration=duration,
- duration_after_vad=duration_after_vad,
- transcription_options=options,
- vad_options=vad_parameters,
- all_language_probs=all_language_probs,
- )
- return segments, info
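- # Usage sketch (illustrative):
- # >>> model = WhisperModel("base", device="auto")
- # >>> segments, info = model.transcribe("audio.wav", vad_filter=True)
- # >>> print(info.language, info.language_probability)
- # >>> text = "".join(segment.text for segment in segments)  # consumes the generator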
- def _split_segments_by_timestamps(
- self,
- tokenizer: Tokenizer,
- tokens: List[int],
- time_offset: float,
- segment_size: int,
- segment_duration: float,
- seek: int,
- ) -> Tuple[List[dict], int, bool]:
- current_segments = []
- single_timestamp_ending = (
- len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
- )
- consecutive_timestamps = [
- i
- for i in range(len(tokens))
- if i > 0
- and tokens[i] >= tokenizer.timestamp_begin
- and tokens[i - 1] >= tokenizer.timestamp_begin
- ]
- if len(consecutive_timestamps) > 0:
- slices = list(consecutive_timestamps)
- if single_timestamp_ending:
- slices.append(len(tokens))
- last_slice = 0
- for current_slice in slices:
- sliced_tokens = tokens[last_slice:current_slice]
- start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin
- end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin
- start_time = (
- time_offset + start_timestamp_position * self.time_precision
- )
- end_time = time_offset + end_timestamp_position * self.time_precision
- current_segments.append(
- dict(
- seek=seek,
- start=start_time,
- end=end_time,
- tokens=sliced_tokens,
- )
- )
- last_slice = current_slice
- if single_timestamp_ending:
- # single timestamp at the end means no speech after the last timestamp.
- seek += segment_size
- else:
- # otherwise, ignore the unfinished segment and seek to the last timestamp
- last_timestamp_position = (
- tokens[last_slice - 1] - tokenizer.timestamp_begin
- )
- seek += last_timestamp_position * self.input_stride
- else:
- duration = segment_duration
- timestamps = [
- token for token in tokens if token >= tokenizer.timestamp_begin
- ]
- if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
- last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
- duration = last_timestamp_position * self.time_precision
- current_segments.append(
- dict(
- seek=seek,
- start=time_offset,
- end=time_offset + duration,
- tokens=tokens,
- )
- )
- seek += segment_size
- return current_segments, seek, single_timestamp_ending
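- # Worked example (hypothetical token ids): with timestamp_begin = T and
- # time_precision = 0.02, the sequence
- #     [T+0, <text...>, T+100, T+100, <text...>, T+250]
- # has one consecutive timestamp pair at the repeated T+100, so it splits into
- # a subsegment covering time_offset+0.0 .. +2.0 s and another covering
- # +2.0 .. +5.0 s; the lone trailing timestamp makes single_timestamp_ending
- # True, so seek advances by the full segment_size.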
- def generate_segments(
- self,
- features: np.ndarray,
- tokenizer: Tokenizer,
- options: TranscriptionOptions,
- log_progress,
- encoder_output: Optional[ctranslate2.StorageView] = None,
- ) -> Iterable[Segment]:
- content_frames = features.shape[-1] - 1
- content_duration = float(content_frames * self.feature_extractor.time_per_frame)
- if isinstance(options.clip_timestamps, str):
- options.clip_timestamps = [
- float(ts)
- for ts in (
- options.clip_timestamps.split(",")
- if options.clip_timestamps
- else []
- )
- ]
- seek_points: List[int] = [
- round(ts * self.frames_per_second) for ts in options.clip_timestamps
- ]
- if len(seek_points) == 0:
- seek_points.append(0)
- if len(seek_points) % 2 == 1:
- seek_points.append(content_frames)
- seek_clips: List[Tuple[int, int]] = list(
- zip(seek_points[::2], seek_points[1::2])
- )
- punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
- idx = 0
- clip_idx = 0
- seek = seek_clips[clip_idx][0]
- all_tokens = []
- prompt_reset_since = 0
- if options.initial_prompt is not None:
- if isinstance(options.initial_prompt, str):
- initial_prompt = " " + options.initial_prompt.strip()
- initial_prompt_tokens = tokenizer.encode(initial_prompt)
- all_tokens.extend(initial_prompt_tokens)
- else:
- all_tokens.extend(options.initial_prompt)
- pbar = tqdm(total=content_duration, unit="seconds", disable=not log_progress)
- last_speech_timestamp = 0.0
- # NOTE: This loop is obscurely flattened to make the diff readable.
- # A later commit should turn this into a simpler nested loop.
- # for seek_clip_start, seek_clip_end in seek_clips:
- # while seek < seek_clip_end
- while clip_idx < len(seek_clips):
- seek_clip_start, seek_clip_end = seek_clips[clip_idx]
- if seek_clip_end > content_frames:
- seek_clip_end = content_frames
- if seek < seek_clip_start:
- seek = seek_clip_start
- if seek >= seek_clip_end:
- clip_idx += 1
- if clip_idx < len(seek_clips):
- seek = seek_clips[clip_idx][0]
- continue
- time_offset = seek * self.feature_extractor.time_per_frame
- window_end_time = float(
- (seek + self.feature_extractor.nb_max_frames)
- * self.feature_extractor.time_per_frame
- )
- segment_size = min(
- self.feature_extractor.nb_max_frames,
- content_frames - seek,
- seek_clip_end - seek,
- )
- segment = features[:, seek : seek + segment_size]
- segment_duration = segment_size * self.feature_extractor.time_per_frame
- segment = pad_or_trim(segment)
- if self.logger.isEnabledFor(logging.DEBUG):
- self.logger.debug(
- "Processing segment at %s", format_timestamp(time_offset)
- )
- previous_tokens = all_tokens[prompt_reset_since:]
- if seek > 0 or encoder_output is None:
- encoder_output = self.encode(segment)
- if options.multilingual:
- results = self.model.detect_language(encoder_output)
- language_token, language_probability = results[0][0]
- language = language_token[2:-2]
- tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
- tokenizer.language_code = language
- prompt = self.get_prompt(
- tokenizer,
- previous_tokens,
- without_timestamps=options.without_timestamps,
- prefix=options.prefix if seek == 0 else None,
- hotwords=options.hotwords,
- )
- (
- result,
- avg_logprob,
- temperature,
- compression_ratio,
- ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
- if options.no_speech_threshold is not None:
- # no-speech check: skip this window when voice activity is unlikely
- should_skip = result.no_speech_prob > options.no_speech_threshold
- if (
- options.log_prob_threshold is not None
- and avg_logprob > options.log_prob_threshold
- ):
- # don't skip if the logprob is high enough, despite the no_speech_prob
- should_skip = False
- if should_skip:
- self.logger.debug(
- "No speech threshold is met (%f > %f)",
- result.no_speech_prob,
- options.no_speech_threshold,
- )
- # fast-forward to the next segment boundary
- seek += segment_size
- continue
- tokens = result.sequences_ids[0]
- previous_seek = seek
- # anomalous words are very long/short/improbable
- def word_anomaly_score(word: dict) -> float:
- probability = word.get("probability", 0.0)
- duration = word["end"] - word["start"]
- score = 0.0
- if probability < 0.15:
- score += 1.0
- if duration < 0.133:
- score += (0.133 - duration) * 15
- if duration > 2.0:
- score += duration - 2.0
- return score
- def is_segment_anomaly(segment: Optional[dict]) -> bool:
- if segment is None or not segment["words"]:
- return False
- words = [w for w in segment["words"] if w["word"] not in punctuation]
- words = words[:8]
- score = sum(word_anomaly_score(w) for w in words)
- return score >= 3 or score + 0.01 >= len(words)
- def next_words_segment(segments: List[dict]) -> Optional[dict]:
- return next((s for s in segments if s["words"]), None)
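- # Worked example (hypothetical numbers): a word with probability 0.10 and
- # duration 0.05 s scores 1.0 + (0.133 - 0.05) * 15 = 2.245, so two such
- # words among the first eight already exceed the anomaly cutoff of 3.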
- (
- current_segments,
- seek,
- single_timestamp_ending,
- ) = self._split_segments_by_timestamps(
- tokenizer=tokenizer,
- tokens=tokens,
- time_offset=time_offset,
- segment_size=segment_size,
- segment_duration=segment_duration,
- seek=seek,
- )
- if options.word_timestamps:
- self.add_word_timestamps(
- [current_segments],
- tokenizer,
- encoder_output,
- segment_size,
- options.prepend_punctuations,
- options.append_punctuations,
- last_speech_timestamp=last_speech_timestamp,
- )
- if not single_timestamp_ending:
- last_word_end = get_end(current_segments)
- if last_word_end is not None and last_word_end > time_offset:
- seek = round(last_word_end * self.frames_per_second)
- # skip silence before possible hallucinations
- if options.hallucination_silence_threshold is not None:
- threshold = options.hallucination_silence_threshold
- # if first segment might be a hallucination, skip leading silence
- first_segment = next_words_segment(current_segments)
- if first_segment is not None and is_segment_anomaly(first_segment):
- gap = first_segment["start"] - time_offset
- if gap > threshold:
- seek = previous_seek + round(gap * self.frames_per_second)
- continue
- # skip silence before any possible hallucination that is surrounded
- # by silence or more hallucinations
- hal_last_end = last_speech_timestamp
- for si in range(len(current_segments)):
- segment = current_segments[si]
- if not segment["words"]:
- continue
- if is_segment_anomaly(segment):
- next_segment = next_words_segment(
- current_segments[si + 1 :]
- )
- if next_segment is not None:
- hal_next_start = next_segment["words"][0]["start"]
- else:
- hal_next_start = time_offset + segment_duration
- silence_before = (
- segment["start"] - hal_last_end > threshold
- or segment["start"] < threshold
- or segment["start"] - time_offset < 2.0
- )
- silence_after = (
- hal_next_start - segment["end"] > threshold
- or is_segment_anomaly(next_segment)
- or window_end_time - segment["end"] < 2.0
- )
- if silence_before and silence_after:
- seek = round(
- max(time_offset + 1, segment["start"])
- * self.frames_per_second
- )
- if content_duration - segment["end"] < threshold:
- seek = content_frames
- current_segments[si:] = []
- break
- hal_last_end = segment["end"]
- last_word_end = get_end(current_segments)
- if last_word_end is not None:
- last_speech_timestamp = last_word_end
- for segment in current_segments:
- tokens = segment["tokens"]
- text = tokenizer.decode(tokens)
- if segment["start"] == segment["end"] or not text.strip():
- continue
- all_tokens.extend(tokens)
- idx += 1
- yield Segment(
- id=idx,
- seek=previous_seek,
- start=segment["start"],
- end=segment["end"],
- text=text,
- tokens=tokens,
- temperature=temperature,
- avg_logprob=avg_logprob,
- compression_ratio=compression_ratio,
- no_speech_prob=result.no_speech_prob,
- words=(
- [Word(**word) for word in segment["words"]]
- if options.word_timestamps
- else None
- ),
- )
- if (
- not options.condition_on_previous_text
- or temperature > options.prompt_reset_on_temperature
- ):
- if options.condition_on_previous_text:
- self.logger.debug(
- "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
- temperature,
- options.prompt_reset_on_temperature,
- )
- prompt_reset_since = len(all_tokens)
- pbar.update(
- (min(content_frames, seek) - previous_seek)
- * self.feature_extractor.time_per_frame,
- )
- pbar.close()
- def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
- # When the model is running on multiple GPUs, the encoder output should be moved
- # to the CPU since we don't know which GPU will handle the next job.
- to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
- if features.ndim == 2:
- features = np.expand_dims(features, 0)
- features = get_ctranslate2_storage(features)
- return self.model.encode(features, to_cpu=to_cpu)
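- # Sketch (illustrative): encoding one padded 30 s window of log-Mel features.
- # >>> features = model.feature_extractor(audio)   # shape (n_mels, n_frames)
- # >>> segment = pad_or_trim(features[:, :3000])   # 3000 frames = 30 s at 100 fps
- # >>> encoder_output = model.encode(segment)      # batch dim is added internally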
- def generate_with_fallback(
- self,
- encoder_output: ctranslate2.StorageView,
- prompt: List[int],
- tokenizer: Tokenizer,
- options: TranscriptionOptions,
- ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
- decode_result = None
- all_results = []
- below_cr_threshold_results = []
- max_initial_timestamp_index = int(
- round(options.max_initial_timestamp / self.time_precision)
- )
- if options.max_new_tokens is not None:
- max_length = len(prompt) + options.max_new_tokens
- else:
- max_length = self.max_length
- if max_length > self.max_length:
- raise ValueError(
- f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
- f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
- f"and `max_new_tokens` is: {max_length}. This exceeds the "
- f"`max_length` of the Whisper model: {self.max_length}. "
- "You should either reduce the length of your prompt, or "
- "reduce the value of `max_new_tokens`, "
- f"so that their combined length is less that {self.max_length}."
- )
- for temperature in options.temperatures:
- if temperature > 0:
- kwargs = {
- "beam_size": 1,
- "num_hypotheses": options.best_of,
- "sampling_topk": 0,
- "sampling_temperature": temperature,
- }
- else:
- kwargs = {
- "beam_size": options.beam_size,
- "patience": options.patience,
- }
- result = self.model.generate(
- encoder_output,
- [prompt],
- length_penalty=options.length_penalty,
- repetition_penalty=options.repetition_penalty,
- no_repeat_ngram_size=options.no_repeat_ngram_size,
- max_length=max_length,
- return_scores=True,
- return_no_speech_prob=True,
- suppress_blank=options.suppress_blank,
- suppress_tokens=options.suppress_tokens,
- max_initial_timestamp_index=max_initial_timestamp_index,
- **kwargs,
- )[0]
- tokens = result.sequences_ids[0]
- # Recover the average log prob from the returned score.
- seq_len = len(tokens)
- cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
- avg_logprob = cum_logprob / (seq_len + 1)
- text = tokenizer.decode(tokens).strip()
- compression_ratio = get_compression_ratio(text)
- decode_result = (
- result,
- avg_logprob,
- temperature,
- compression_ratio,
- )
- all_results.append(decode_result)
- needs_fallback = False
- if options.compression_ratio_threshold is not None:
- if compression_ratio > options.compression_ratio_threshold:
- needs_fallback = True # too repetitive
- self.logger.debug(
- "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
- temperature,
- compression_ratio,
- options.compression_ratio_threshold,
- )
- else:
- below_cr_threshold_results.append(decode_result)
- if (
- options.log_prob_threshold is not None
- and avg_logprob < options.log_prob_threshold
- ):
- needs_fallback = True # average log probability is too low
- self.logger.debug(
- "Log probability threshold is not met with temperature %.1f (%f < %f)",
- temperature,
- avg_logprob,
- options.log_prob_threshold,
- )
- if (
- options.no_speech_threshold is not None
- and result.no_speech_prob > options.no_speech_threshold
- and options.log_prob_threshold is not None
- and avg_logprob < options.log_prob_threshold
- ):
- needs_fallback = False # silence
- if not needs_fallback:
- break
- else:
- # all failed, select the result with the highest average log probability
- decode_result = max(
- below_cr_threshold_results or all_results, key=lambda x: x[1]
- )
- # to pass final temperature for prompt_reset_on_temperature
- decode_result = (
- decode_result[0],
- decode_result[1],
- temperature,
- decode_result[3],
- )
- return decode_result
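- # Note on the score recovery above: CTranslate2 returns a score normalized by
- # seq_len**length_penalty, so multiplying by that factor restores the
- # cumulative log probability, and dividing by (seq_len + 1) averages it over
- # the generated tokens plus the end-of-text token.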
- def get_prompt(
- self,
- tokenizer: Tokenizer,
- previous_tokens: List[int],
- without_timestamps: bool = False,
- prefix: Optional[str] = None,
- hotwords: Optional[str] = None,
- ) -> List[int]:
- prompt = []
- if previous_tokens or (hotwords and not prefix):
- prompt.append(tokenizer.sot_prev)
- if hotwords and not prefix:
- hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
- if len(hotwords_tokens) >= self.max_length // 2:
- hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
- prompt.extend(hotwords_tokens)
- if previous_tokens:
- prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
- prompt.extend(tokenizer.sot_sequence)
- if without_timestamps:
- prompt.append(tokenizer.no_timestamps)
- if prefix:
- prefix_tokens = tokenizer.encode(" " + prefix.strip())
- if len(prefix_tokens) >= self.max_length // 2:
- prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
- if not without_timestamps:
- prompt.append(tokenizer.timestamp_begin)
- prompt.extend(prefix_tokens)
- return prompt
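- # Prompt layout sketch (illustrative):
- #     [sot_prev, *hotwords_or_previous_tokens,  # only when context exists
- #      *sot_sequence,                           # <|startoftranscript|>, language, task
- #      no_timestamps?,                          # when without_timestamps=True
- #      timestamp_begin?, *prefix_tokens]        # when a prefix is given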
- def add_word_timestamps(
- self,
- segments: List[dict],
- tokenizer: Tokenizer,
- encoder_output: ctranslate2.StorageView,
- num_frames: int,
- prepend_punctuations: str,
- append_punctuations: str,
- last_speech_timestamp: float,
- ) -> float:
- if len(segments) == 0:
- return last_speech_timestamp
- text_tokens = []
- text_tokens_per_segment = []
- for segment in segments:
- segment_tokens = [
- [token for token in subsegment["tokens"] if token < tokenizer.eot]
- for subsegment in segment
- ]
- text_tokens.append(list(itertools.chain.from_iterable(segment_tokens)))
- text_tokens_per_segment.append(segment_tokens)
- alignments = self.find_alignment(
- tokenizer, text_tokens, encoder_output, num_frames
- )
- median_max_durations = []
- for alignment in alignments:
- word_durations = np.array(
- [word["end"] - word["start"] for word in alignment]
- )
- word_durations = word_durations[word_durations.nonzero()]
- median_duration = (
- np.median(word_durations) if len(word_durations) > 0 else 0.0
- )
- median_duration = min(0.7, float(median_duration))
- max_duration = median_duration * 2
- # hack: truncate long words at sentence boundaries.
- # a better segmentation algorithm based on VAD should be able to replace this.
- if len(word_durations) > 0:
- sentence_end_marks = ".。!!??"
- # ensure words at sentence boundaries
- # are not longer than twice the median word duration.
- for i in range(1, len(alignment)):
- if alignment[i]["end"] - alignment[i]["start"] > max_duration:
- if alignment[i]["word"] in sentence_end_marks:
- alignment[i]["end"] = alignment[i]["start"] + max_duration
- elif alignment[i - 1]["word"] in sentence_end_marks:
- alignment[i]["start"] = alignment[i]["end"] - max_duration
- merge_punctuations(alignment, prepend_punctuations, append_punctuations)
- median_max_durations.append((median_duration, max_duration))

        for segment_idx, segment in enumerate(segments):
            word_index = 0
            time_offset = segment[0]["seek"] / self.frames_per_second
            median_duration, max_duration = median_max_durations[segment_idx]
            for subsegment_idx, subsegment in enumerate(segment):
                saved_tokens = 0
                words = []

                while word_index < len(alignments[segment_idx]) and saved_tokens < len(
                    text_tokens_per_segment[segment_idx][subsegment_idx]
                ):
                    timing = alignments[segment_idx][word_index]

                    if timing["word"]:
                        words.append(
                            dict(
                                word=timing["word"],
                                start=round(time_offset + timing["start"], 2),
                                end=round(time_offset + timing["end"], 2),
                                probability=timing["probability"],
                            )
                        )

                    saved_tokens += len(timing["tokens"])
                    word_index += 1

                # hack: truncate long words at segment boundaries.
                # a better segmentation algorithm based on VAD should be able to replace this.
                if len(words) > 0:
                    # ensure the first and second words after a pause are not longer
                    # than twice the median word duration.
                    if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                        words[0]["end"] - words[0]["start"] > max_duration
                        or (
                            len(words) > 1
                            and words[1]["end"] - words[0]["start"] > max_duration * 2
                        )
                    ):
                        if (
                            len(words) > 1
                            and words[1]["end"] - words[1]["start"] > max_duration
                        ):
                            boundary = max(
                                words[1]["end"] / 2, words[1]["end"] - max_duration
                            )
                            words[0]["end"] = words[1]["start"] = boundary
                        words[0]["start"] = max(0, words[0]["end"] - max_duration)

                    # prefer the segment-level start timestamp if the first word is too long.
                    if (
                        subsegment["start"] < words[0]["end"]
                        and subsegment["start"] - 0.5 > words[0]["start"]
                    ):
                        words[0]["start"] = max(
                            0,
                            min(words[0]["end"] - median_duration, subsegment["start"]),
                        )
                    else:
                        subsegment["start"] = words[0]["start"]

                    # prefer the segment-level end timestamp if the last word is too long.
                    if (
                        subsegment["end"] > words[-1]["start"]
                        and subsegment["end"] + 0.5 < words[-1]["end"]
                    ):
                        words[-1]["end"] = max(
                            words[-1]["start"] + median_duration, subsegment["end"]
                        )
                    else:
                        subsegment["end"] = words[-1]["end"]

                    last_speech_timestamp = subsegment["end"]

                segments[segment_idx][subsegment_idx]["words"] = words

        return last_speech_timestamp
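
    # Resulting shape (illustrative values): each subsegment gains a "words"
    # list such as
    #
    #   [{"word": " Hello,", "start": 0.0, "end": 0.42, "probability": 0.91},
    #    {"word": " world.", "start": 0.46, "end": 0.88, "probability": 0.87}]
    #
    # with punctuation already folded into neighbouring words by
    # merge_punctuations; entries whose "word" is empty are skipped above.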

    def find_alignment(
        self,
        tokenizer: Tokenizer,
        text_tokens: List[List[int]],
        encoder_output: ctranslate2.StorageView,
        num_frames: int,
        median_filter_width: int = 7,
    ) -> List[dict]:
        if len(text_tokens) == 0:
            return []

        results = self.model.align(
            encoder_output,
            tokenizer.sot_sequence,
            text_tokens,
            num_frames,
            median_filter_width=median_filter_width,
        )
        return_list = []
        for result, text_token in zip(results, text_tokens):
            text_token_probs = result.text_token_probs
            alignments = result.alignments
            text_indices = np.array([pair[0] for pair in alignments])
            time_indices = np.array([pair[1] for pair in alignments])

            words, word_tokens = tokenizer.split_to_word_tokens(
                text_token + [tokenizer.eot]
            )
            if len(word_tokens) <= 1:
                # return on eot only
                # >>> np.pad([], (1, 0))
                # array([0.])
                # This results in crashes when we look up jump_times with a float,
                # e.g. IndexError: arrays used as indices must be of integer (or boolean) type
                return_list.append([])
                continue

            word_boundaries = np.pad(
                np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)
            )
            if len(word_boundaries) <= 1:
                return_list.append([])
                continue

            jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(
                bool
            )
            jump_times = time_indices[jumps] / self.tokens_per_second
            start_times = jump_times[word_boundaries[:-1]]
            end_times = jump_times[word_boundaries[1:]]
            word_probabilities = [
                np.mean(text_token_probs[i:j])
                for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
            ]

            return_list.append(
                [
                    dict(
                        word=word,
                        tokens=tokens,
                        start=start,
                        end=end,
                        probability=probability,
                    )
                    for word, tokens, start, end, probability in zip(
                        words, word_tokens, start_times, end_times, word_probabilities
                    )
                ]
            )
        return return_list
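
    # Boundary arithmetic sketch (token ids are illustrative; the last "word"
    # is the appended eot and falls off the end of the zip above):
    #
    #   word_tokens     = [[2425], [11], [50257]]            # " Hey", ",", eot
    #   word_boundaries = np.pad(np.cumsum([1, 1]), (1, 0))  # -> [0, 1, 2]
    #
    # `jumps` marks alignment steps where the text index advances, so
    # jump_times[k] is the frame time at which token k starts; indexing it with
    # consecutive boundaries yields one (start, end) pair per word.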

    def detect_language(
        self,
        audio: Optional[np.ndarray] = None,
        features: Optional[np.ndarray] = None,
        vad_filter: bool = False,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
        language_detection_segments: int = 1,
        language_detection_threshold: float = 0.5,
    ) -> Tuple[str, float, List[Tuple[str, float]]]:
        """Use Whisper to detect the language of the input audio or features.

        Arguments:
            audio: Input audio signal, must be a 1D float array sampled at 16 kHz.
            features: Input Mel spectrogram features, must be a float array with
                shape (n_mels, n_frames). If `audio` is provided, `features` is
                ignored. Either `audio` or `features` must be provided.
            vad_filter: Enable the voice activity detection (VAD) to filter out
                parts of the audio without speech. This step uses the Silero VAD model.
            vad_parameters: Dictionary of Silero VAD parameters or VadOptions class
                (see available parameters and default values in the class `VadOptions`).
            language_detection_segments: Number of segments to consider for the
                language detection.
            language_detection_threshold: If the maximum probability of the language
                tokens is higher than this value, the language is detected.

        Returns:
            language: Detected language.
            language_probability: Probability of the detected language.
            all_language_probs: List of tuples with all language names and probabilities.
        """
        assert (
            audio is not None or features is not None
        ), "Either `audio` or `features` must be provided."

        if audio is not None:
            if vad_filter:
                speech_chunks = get_speech_timestamps(audio, vad_parameters)
                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
                audio = np.concatenate(audio_chunks, axis=0)

            audio = audio[
                : language_detection_segments * self.feature_extractor.n_samples
            ]
            features = self.feature_extractor(audio)

        features = features[
            ..., : language_detection_segments * self.feature_extractor.nb_max_frames
        ]

        detected_language_info = {}
        for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
            encoder_output = self.encode(
                pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
            )
            # results is a list of (language token, probability) pairs sorted by
            # decreasing probability.
            results = self.model.detect_language(encoder_output)[0]

            # Parse language names to strip out markers, e.g. "<|en|>" -> "en".
            all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
            # Get the top language token and probability.
            language, language_probability = all_language_probs[0]
            if language_probability > language_detection_threshold:
                break
            detected_language_info.setdefault(language, []).append(language_probability)
        else:
            # No segment cleared the threshold: fall back to a majority vote over
            # the top language of each segment.
            language = max(
                detected_language_info,
                key=lambda lang: len(detected_language_info[lang]),
            )
            language_probability = max(detected_language_info[language])

        return language, language_probability, all_language_probs
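
    # Usage sketch (hypothetical instance and waveform names):
    #
    #   language, prob, all_probs = model.detect_language(audio=waveform)
    #   # e.g. ("en", 0.97, [("en", 0.97), ("de", 0.01), ...])
    #
    # With language_detection_segments > 1, each 30-second window is scored in
    # turn; the loop exits early once one window clears
    # language_detection_threshold, otherwise the majority vote decides.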


def restore_speech_timestamps(
    segments: Iterable[Segment],
    speech_chunks: List[dict],
    sampling_rate: int,
) -> Iterable[Segment]:
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

    for segment in segments:
        if segment.words:
            words = []
            for word in segment.words:
                # Ensure the word start and end times are resolved to the same chunk.
                middle = (word.start + word.end) / 2
                chunk_index = ts_map.get_chunk_index(middle)
                word.start = ts_map.get_original_time(word.start, chunk_index)
                word.end = ts_map.get_original_time(word.end, chunk_index)
                words.append(word)

            segment.start = words[0].start
            segment.end = words[-1].end
            segment.words = words
        else:
            segment.start = ts_map.get_original_time(segment.start)
            segment.end = ts_map.get_original_time(segment.end)

        yield segment
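
# Illustrative mapping (hypothetical numbers): if VAD dropped silence covering
# 0-5s of the original recording, a word at 1.0s in the filtered audio maps
# back to ~6.0s. Anchoring both word.start and word.end to the chunk that
# contains their midpoint prevents a single word from straddling two chunks.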


def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
    segment = np.ascontiguousarray(segment)
    segment = ctranslate2.StorageView.from_array(segment)
    return segment
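
# Note (added): np.ascontiguousarray guards against sliced or strided inputs,
# since ctranslate2.StorageView.from_array expects a C-contiguous buffer.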


def get_compression_ratio(text: str) -> float:
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
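
# Intuition sketch: repetitive, looping output compresses very well, so its
# ratio is high, which is what compression_ratio_threshold catches upstream:
#
#   get_compression_ratio("ha " * 100)            # large ratio -> degenerate loop
#   get_compression_ratio("An ordinary phrase.")  # ratio near 1 -> looks fine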


def get_suppressed_tokens(
    tokenizer: Tokenizer,
    suppress_tokens: Optional[List[int]],
) -> Tuple[int, ...]:
    if suppress_tokens is None or len(suppress_tokens) == 0:
        suppress_tokens = []  # interpret empty string as an empty list
    elif -1 in suppress_tokens:
        suppress_tokens = [t for t in suppress_tokens if t >= 0]
        suppress_tokens.extend(tokenizer.non_speech_tokens)
    else:
        assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

    suppress_tokens.extend(
        [
            tokenizer.transcribe,
            tokenizer.translate,
            tokenizer.sot,
            tokenizer.sot_prev,
            tokenizer.sot_lm,
        ]
    )

    return tuple(sorted(set(suppress_tokens)))
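
# Behaviour sketch (token ids are illustrative):
#
#   get_suppressed_tokens(tokenizer, [-1])  # non-speech tokens + the specials
#   get_suppressed_tokens(tokenizer, [])    # the specials only
#   get_suppressed_tokens(tokenizer, [42])  # {42} + the specials
#
# The transcribe/translate/sot specials are always suppressed so the decoder
# cannot emit control tokens as text.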


def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
            # prepend it to the following word
            following["word"] = previous["word"] + following["word"]
            following["tokens"] = previous["tokens"] + following["tokens"]
            previous["word"] = ""
            previous["tokens"] = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous["word"].endswith(" ") and following["word"] in appended:
            # append it to the previous word
            previous["word"] = previous["word"] + following["word"]
            previous["tokens"] = previous["tokens"] + following["tokens"]
            following["word"] = ""
            following["tokens"] = []
        else:
            i = j
        j += 1
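
# Before/after sketch ("tokens" fields elided; entries keep their slots so
# indices stay stable, and emptied entries are skipped by callers):
#
#   [{"word": " ¿"}, {"word": "Qué"}, {"word": "?"}]
#   # merge_punctuations(alignment, prepended="¿", appended="?") gives
#   [{"word": ""}, {"word": " ¿Qué?"}, {"word": ""}]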