audio.py

  1. """We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
  2. The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
  3. system dependencies. FFmpeg does not need to be installed on the system.
  4. However, the API is quite low-level so we need to manipulate audio frames directly.
  5. """
import gc
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np


def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decodes the audio.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Resample the audio to this sample rate.
      split_stereo: Return separate left and right channels.

    Returns:
      A float32 Numpy array.

      If `split_stereo` is enabled, the function returns a 2-tuple with the
      separated left and right channels.
    """
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout="mono" if not split_stereo else "stereo",
        rate=sampling_rate,
    )

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(input_file, mode="r", metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    # It appears that some objects related to the resampler are not freed
    # unless the garbage collector is manually run.
    # https://github.com/SYSTRAN/faster-whisper/issues/390
    # Note that this slows down loading the audio a little bit.
    # If that is a concern, please use ffmpeg directly as in here:
    # https://github.com/openai/whisper/blob/25639fc/whisper/audio.py#L25-L62
    del resampler
    gc.collect()

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    audio = audio.astype(np.float32) / 32768.0

    if split_stereo:
        left_channel = audio[0::2]
        right_channel = audio[1::2]
        return left_channel, right_channel

    return audio


def _ignore_invalid_frames(frames):
    # Yield decoded frames, skipping any frame that raises InvalidDataError.
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    # Buffer incoming frames in an AudioFifo and yield them back as larger
    # consolidated frames of at least `num_samples` samples.
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)


def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
    """Pad or trim the Mel features array to `length` frames (3000 by default),
    as expected by the encoder.
    """
    if array.shape[axis] > length:
        array = array.take(indices=range(length), axis=axis)

    if array.shape[axis] < length:
        pad_widths = [(0, 0)] * array.ndim
        pad_widths[axis] = (0, length - array.shape[axis])
        array = np.pad(array, pad_widths)

    return array
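

A minimal usage sketch, not part of the original module: it assumes a local file named `example.wav` (a placeholder) and shows `decode_audio` with and without `split_stereo`, plus `pad_or_trim` on a dummy Mel array.

if __name__ == "__main__":
    # Illustrative only; "example.wav" is a placeholder path, not shipped with the repo.
    audio = decode_audio("example.wav", sampling_rate=16000)
    print(audio.shape, audio.dtype)  # e.g. (480000,) float32 for 30 s of 16 kHz audio

    left, right = decode_audio("example.wav", sampling_rate=16000, split_stereo=True)
    print(left.shape, right.shape)

    # pad_or_trim works on any array; here a dummy 80-bin Mel spectrogram with
    # 1200 frames is padded along the last axis to 3000 frames.
    mel = np.zeros((80, 1200), dtype=np.float32)
    print(pad_or_trim(mel).shape)  # (80, 3000)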