vad.py

import bisect
import functools
import os

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np

from faster_whisper.utils import get_assets_path


# The code below is adapted from https://github.com/snakers4/silero-vad.
@dataclass
class VadOptions:
    """VAD options.

    Attributes:
      threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
      neg_threshold: Silence threshold for determining the end of speech. If a probability is lower
        than neg_threshold, it is always considered silence. Values higher than neg_threshold
        are only considered speech if the previous sample was classified as speech; otherwise,
        they are treated as silence. This parameter helps refine the detection of speech
        transitions, ensuring smoother segment boundaries.
      min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
      max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.
      min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
        before separating it.
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
    """

    threshold: float = 0.5
    neg_threshold: Optional[float] = None
    min_speech_duration_ms: int = 0
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 2000
    speech_pad_ms: int = 400


def get_speech_timestamps(
    audio: np.ndarray,
    vad_options: Optional[VadOptions] = None,
    sampling_rate: int = 16000,
    **kwargs,
) -> List[dict]:
    """Splits long audio into speech chunks using Silero VAD.

    Args:
      audio: One dimensional float array.
      vad_options: Options for VAD processing.
      sampling_rate: Sampling rate of the audio.
      kwargs: VAD options passed as keyword arguments for backward compatibility.

    Returns:
      List of dicts containing begin and end samples of each speech chunk.
    """
    if vad_options is None:
        vad_options = VadOptions(**kwargs)

    threshold = vad_options.threshold
    neg_threshold = vad_options.neg_threshold
    min_speech_duration_ms = vad_options.min_speech_duration_ms
    max_speech_duration_s = vad_options.max_speech_duration_s
    min_silence_duration_ms = vad_options.min_silence_duration_ms
    window_size_samples = 512
    speech_pad_ms = vad_options.speech_pad_ms
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = (
        sampling_rate * max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    model = get_vad_model()

    padded_audio = np.pad(
        audio, (0, window_size_samples - audio.shape[0] % window_size_samples)
    )
    speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)

    triggered = False
    speeches = []
    current_speech = {}
    if neg_threshold is None:
        neg_threshold = max(threshold - 0.15, 0.01)

    # to save potential segment end (and tolerate some silence)
    temp_end = 0
    # to save potential segment limits in case of maximum segment size reached
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = window_size_samples * i
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (
                    current_speech["end"] - current_speech["start"]
                ) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
                )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
                )
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )

    return speeches
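

# Example usage (a minimal sketch; assumes a mono float32 waveform already
# resampled to 16 kHz, e.g. loaded with faster_whisper.audio.decode_audio,
# and a hypothetical file name "speech.wav"):
#
#     from faster_whisper.audio import decode_audio
#
#     audio = decode_audio("speech.wav", sampling_rate=16000)
#     speech_chunks = get_speech_timestamps(audio, VadOptions(min_silence_duration_ms=500))
#     for chunk in speech_chunks:
#         print(chunk["start"] / 16000, chunk["end"] / 16000)  # boundaries in seconds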


def collect_chunks(
    audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000
) -> Tuple[List[np.ndarray], List[Dict[str, int]]]:
    """Collects audio chunks."""
    if not chunks:
        chunk_metadata = {
            "start_time": 0,
            "end_time": 0,
        }
        return [np.array([], dtype=np.float32)], [chunk_metadata]

    audio_chunks = []
    chunks_metadata = []
    for chunk in chunks:
        chunk_metadata = {
            "start_time": chunk["start"] / sampling_rate,
            "end_time": chunk["end"] / sampling_rate,
        }
        audio_chunks.append(audio[chunk["start"] : chunk["end"]])
        chunks_metadata.append(chunk_metadata)
    return audio_chunks, chunks_metadata
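

# Example usage (sketch): cut the detected speech regions out of the waveform,
# together with their original start/end times in seconds:
#
#     audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
#     # audio_chunks[i] spans chunks_metadata[i]["start_time"]..["end_time"] of the input.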


class SpeechTimestampsMap:
    """Helper class to restore original speech timestamps."""

    def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
        self.sampling_rate = sampling_rate
        self.time_precision = time_precision
        self.chunk_end_sample = []
        self.total_silence_before = []

        previous_end = 0
        silent_samples = 0

        for chunk in chunks:
            silent_samples += chunk["start"] - previous_end
            previous_end = chunk["end"]

            self.chunk_end_sample.append(chunk["end"] - silent_samples)
            self.total_silence_before.append(silent_samples / sampling_rate)

    def get_original_time(
        self,
        time: float,
        chunk_index: Optional[int] = None,
    ) -> float:
        if chunk_index is None:
            chunk_index = self.get_chunk_index(time)

        total_silence_before = self.total_silence_before[chunk_index]
        return round(total_silence_before + time, self.time_precision)

    def get_chunk_index(self, time: float) -> int:
        sample = int(time * self.sampling_rate)
        return min(
            bisect.bisect(self.chunk_end_sample, sample),
            len(self.chunk_end_sample) - 1,
        )
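

# Example usage (sketch): map a timestamp measured inside the concatenated,
# speech-only audio back to the original recording's timeline:
#
#     ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate=16000)
#     original_time = ts_map.get_original_time(1.25)  # 1.25 s into the VAD-filtered audio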


@functools.lru_cache
def get_vad_model():
    """Returns the VAD model instance."""
    encoder_path = os.path.join(get_assets_path(), "silero_encoder_v5.onnx")
    decoder_path = os.path.join(get_assets_path(), "silero_decoder_v5.onnx")
    return SileroVADModel(encoder_path, decoder_path)


class SileroVADModel:
    def __init__(self, encoder_path, decoder_path):
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        opts.enable_cpu_mem_arena = False
        opts.log_severity_level = 4

        self.encoder_session = onnxruntime.InferenceSession(
            encoder_path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )
        self.decoder_session = onnxruntime.InferenceSession(
            decoder_path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def __call__(
        self, audio: np.ndarray, num_samples: int = 512, context_size_samples: int = 64
    ):
        assert (
            audio.ndim == 2
        ), "Input should be a 2D array with size (batch_size, num_samples)"
        assert (
            audio.shape[1] % num_samples == 0
        ), "Input size should be a multiple of num_samples"

        batch_size = audio.shape[0]

        state = np.zeros((2, batch_size, 128), dtype="float32")
        context = np.zeros(
            (batch_size, context_size_samples),
            dtype="float32",
        )

        # Split the audio into 512-sample windows and prepend each window with the
        # last 64 samples of the previous window (zeros for the first window).
        batched_audio = audio.reshape(batch_size, -1, num_samples)
        context = batched_audio[..., -context_size_samples:]
        context[:, -1] = 0
        context = np.roll(context, 1, 1)
        batched_audio = np.concatenate([context, batched_audio], 2)

        batched_audio = batched_audio.reshape(-1, num_samples + context_size_samples)

        # Encode the windows in batches of at most encoder_batch_size.
        encoder_batch_size = 10000
        num_segments = batched_audio.shape[0]
        encoder_outputs = []
        for i in range(0, num_segments, encoder_batch_size):
            encoder_output = self.encoder_session.run(
                None, {"input": batched_audio[i : i + encoder_batch_size]}
            )[0]
            encoder_outputs.append(encoder_output)

        encoder_output = np.concatenate(encoder_outputs, axis=0)
        encoder_output = encoder_output.reshape(batch_size, -1, 128)

        # Run the recurrent decoder window by window, carrying its state across windows.
        decoder_outputs = []
        for window in np.split(encoder_output, encoder_output.shape[1], axis=1):
            out, state = self.decoder_session.run(
                None, {"input": window.squeeze(1), "state": state}
            )
            decoder_outputs.append(out)

        out = np.stack(decoder_outputs, axis=1).squeeze(-1)
        return out
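

# Example usage (sketch): `padded_audio` is assumed to be a float32 waveform
# zero-padded to a multiple of 512 samples; the result has shape
# (batch_size, num_windows), one speech probability per 512-sample window:
#
#     model = get_vad_model()
#     speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0)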


def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
    """Merges consecutive VAD segments into chunks bounded by max_speech_duration_s."""
    if not segments_list:
        return []

    curr_end = 0
    seg_idxs = []
    merged_segments = []
    edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
    chunk_length = vad_options.max_speech_duration_s * sampling_rate

    curr_start = segments_list[0]["start"]

    for idx, seg in enumerate(segments_list):
        # if any segment start timing is less than previous segment end timing,
        # reset the edge padding. Similarly for end timing.
        if idx > 0:
            if seg["start"] < segments_list[idx - 1]["end"]:
                seg["start"] += edge_padding
        if idx < len(segments_list) - 1:
            if seg["end"] > segments_list[idx + 1]["start"]:
                seg["end"] -= edge_padding

        if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
            merged_segments.append(
                {
                    "start": curr_start,
                    "end": curr_end,
                    "segments": seg_idxs,
                }
            )
            curr_start = seg["start"]
            seg_idxs = []

        curr_end = seg["end"]
        seg_idxs.append((seg["start"], seg["end"]))

    # add final
    merged_segments.append(
        {
            "start": curr_start,
            "end": curr_end,
            "segments": seg_idxs,
        }
    )
    return merged_segments
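

# Example usage (sketch): group VAD segments into chunks for batched
# transcription, here capped at a hypothetical 30 s per chunk:
#
#     options = VadOptions(max_speech_duration_s=30)
#     speech_chunks = get_speech_timestamps(audio, options)
#     merged = merge_segments(speech_chunks, options)
#     # each entry holds "start"/"end" in samples plus the merged (start, end) pairs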