# feature_extractor.py (7.4 KB)
  1. import numpy as np
  class FeatureExtractor:
      """Turn raw audio samples into log-Mel spectrogram features using NumPy.

      Calling an instance on a waveform runs a Hann-windowed STFT, projects the
      power spectrum onto a Slaney-style mel filter bank, then applies log
      compression and a fixed rescaling.

      Attributes set by ``__init__``:
          n_fft, hop_length, chunk_length, sampling_rate: the constructor
              arguments, stored unchanged.
          n_samples: ``chunk_length * sampling_rate``.
          nb_max_frames: ``n_samples // hop_length``.
          time_per_frame: ``hop_length / sampling_rate``.
          mel_filters: float32 array of shape ``(feature_size, n_fft // 2 + 1)``.
      """

      def __init__(
          self,
          feature_size=80,
          sampling_rate=16000,
          hop_length=160,
          chunk_length=30,
          n_fft=400,
      ):
          """Store the framing parameters and precompute the mel filter bank.

          Args:
              feature_size: number of mel bands (rows of ``mel_filters``).
              sampling_rate: samples per second of the audio to be processed.
              hop_length: number of samples between successive STFT frames.
              chunk_length: chunk duration; multiplied by ``sampling_rate`` to
                  obtain ``n_samples``.
              n_fft: FFT size, also used as the analysis window length.
          """
          self.n_fft = n_fft
          self.hop_length = hop_length
          self.chunk_length = chunk_length
          self.n_samples = chunk_length * sampling_rate
          self.nb_max_frames = self.n_samples // hop_length
          self.time_per_frame = hop_length / sampling_rate
          self.sampling_rate = sampling_rate
          self.mel_filters = self.get_mel_filters(
              sampling_rate, n_fft, n_mels=feature_size
          ).astype("float32")

      @staticmethod
      def get_mel_filters(sr, n_fft, n_mels=128):
          """Build a bank of triangular mel filters with Slaney-style scaling.

          Args:
              sr: sampling rate; determines the centre frequency of each FFT bin.
              n_fft: FFT size; the bank has ``n_fft // 2 + 1`` columns.
              n_mels: number of mel bands (coerced with ``int``).

          Returns:
              Array of shape ``(n_mels, n_fft // 2 + 1)`` holding non-negative
              filter weights.

          NOTE: the upper band edge is the fixed mel value below, which the
          conversion in this function maps to 8000 Hz (the Nyquist frequency
          for ``sr=16000``). It is not derived from ``sr``, so for any other
          sampling rate the filters still span 0-8000 Hz.
          """
          n_mels = int(n_mels)

          # Center freqs of each FFT bin
          fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

          # 'Center freqs' of mel bands - uniformly spaced between limits
          min_mel = 0.0
          max_mel = 45.245640471924965
          mels = np.linspace(min_mel, max_mel, n_mels + 2)

          # Mel -> Hz, linear part of the scale
          f_min = 0.0
          f_sp = 200.0 / 3
          freqs = f_min + f_sp * mels

          # Mel -> Hz, logarithmic part of the scale
          min_log_hz = 1000.0  # beginning of log region (Hz)
          min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
          logstep = np.log(6.4) / 27.0  # step size for log region

          # Overwrite the entries that fall in the log region
          log_t = mels >= min_log_mel
          freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))

          # Rising and falling slope of every triangle, for all bands at once
          fdiff = np.diff(freqs)
          ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1)
          lower = -ramps[:-2] / np.expand_dims(fdiff[:-1], axis=1)
          upper = ramps[2:] / np.expand_dims(fdiff[1:], axis=1)

          # Intersect them with each other and zero, vectorized across all i
          weights = np.maximum(np.zeros_like(lower), np.minimum(lower, upper))

          # Slaney-style mel is scaled to be approx constant energy per channel
          enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels])
          weights *= np.expand_dims(enorm, axis=1)

          return weights

      @staticmethod
      def stft(
          input_array: np.ndarray,
          n_fft: int,
          hop_length: int = None,
          win_length: int = None,
          window: np.ndarray = None,
          center: bool = True,
          mode: str = "reflect",
          normalized: bool = False,
          onesided: bool = None,
          return_complex: bool = None,
      ):
          """Short-time Fourier transform of a 1-D signal or a 2-D batch.

          Args:
              input_array: floating point or complex array of shape ``(time,)``
                  or ``(batch, time)``.
              n_fft: FFT size and frame length.
              hop_length: samples between frame starts; defaults to
                  ``n_fft // 4``.
              win_length: window length; defaults to ``n_fft``. A shorter
                  window is centred inside the frame and zero-padded.
              window: 1-D array of length ``win_length``. ``None`` means a
                  rectangular window (all ones).
              center: if true, pad ``n_fft // 2`` samples on both sides of the
                  time axis using ``mode`` before framing.
              mode: padding mode handed to ``np.pad``.
              normalized: if true, use the ``"ortho"`` FFT normalisation.
              onesided: return only the ``n_fft // 2 + 1`` non-negative
                  frequency bins. Defaults to true for real data and false
                  when the input or the window is complex.
              return_complex: return the complex spectrum. Defaults to true
                  when the input or the window is complex, false otherwise;
                  when false only the real part of the spectrum is returned.

          Returns:
              Array of shape ``(freq, n_frames)`` for 1-D input or
              ``(batch, freq, n_frames)`` for 2-D input, where
              ``n_frames = 1 + (padded_length - n_fft) // hop_length``.

          Raises:
              ValueError: on a non-float/non-complex input, wrong input rank,
                  out-of-range ``n_fft`` / ``hop_length`` / ``win_length``,
                  a window of the wrong shape, or ``onesided=True`` combined
                  with complex data.
          """
          if hop_length is None:
              hop_length = n_fft // 4
          if win_length is None:
              win_length = n_fft

          input_is_complex = np.iscomplexobj(input_array)
          window_is_complex = window is not None and np.iscomplexobj(window)

          # An omitted return_complex is inferred from the data rather than
          # rejected, so real input without the flag yields a real result.
          if return_complex is None:
              return_complex = input_is_complex or window_is_complex

          # Input checks
          if not np.issubdtype(input_array.dtype, np.floating) and not input_is_complex:
              raise ValueError(
                  "stft: expected an array of floating point or complex values,"
                  f" got {input_array.dtype}"
              )

          if input_array.ndim > 2 or input_array.ndim < 1:
              raise ValueError(
                  f"stft: expected a 1D or 2D array, but got {input_array.ndim}D array"
              )

          # Work on (batch, time) internally; remember to undo this at the end
          input_array_1d = input_array.ndim == 1
          if input_array_1d:
              input_array = np.expand_dims(input_array, axis=0)

          # Center padding if required
          if center:
              pad_amount = n_fft // 2
              input_array = np.pad(
                  input_array, ((0, 0), (pad_amount, pad_amount)), mode=mode
              )

          batch, length = input_array.shape

          # Additional input checks (they depend on the padded length)
          if n_fft <= 0 or n_fft > length:
              raise ValueError(
                  f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}"
              )

          if hop_length <= 0:
              raise ValueError(
                  f"stft: expected hop_length > 0, but got hop_length={hop_length}"
              )

          if win_length <= 0 or win_length > n_fft:
              raise ValueError(
                  f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
              )

          if window is not None and (window.ndim != 1 or window.shape[0] != win_length):
              raise ValueError(
                  f"stft: expected a 1D window array of size equal to win_length={win_length}, "
                  f"but got window with size {window.shape}"
              )

          # Centre a short window inside the frame. Without an explicit window
          # a rectangular one is used, so a short win_length still restricts
          # each frame to its middle win_length samples instead of crashing.
          if win_length < n_fft:
              if window is None:
                  taps = np.ones(win_length, dtype=input_array.dtype)
              else:
                  taps = window
              left = (n_fft - win_length) // 2
              window_ = np.zeros(n_fft, dtype=taps.dtype)
              window_[left : left + win_length] = taps
          else:
              window_ = window

          # Calculate the number of frames
          n_frames = 1 + (length - n_fft) // hop_length

          # Time to columns: overlapping frames as a strided view. The view
          # aliases the same memory many times, so it is made read-only.
          frames = np.lib.stride_tricks.as_strided(
              input_array,
              shape=(batch, n_frames, n_fft),
              strides=(
                  input_array.strides[0],
                  hop_length * input_array.strides[1],
                  input_array.strides[1],
              ),
              writeable=False,
          )

          if window_ is not None:
              frames = frames * window_

          # The windowed frames are complex if either the input or the window
          # is; rfft is only valid for purely real data.
          complex_fft = np.iscomplexobj(frames)
          if onesided is None:
              onesided = not complex_fft

          norm = "ortho" if normalized else None

          if complex_fft and onesided:
              raise ValueError(
                  "Cannot have onesided output if window or input is complex"
              )

          if onesided:
              output = np.fft.rfft(frames, n=n_fft, axis=-1, norm=norm)
          else:
              output = np.fft.fft(frames, n=n_fft, axis=-1, norm=norm)

          # (batch, frames, freq) -> (batch, freq, frames)
          output = output.transpose((0, 2, 1))

          if input_array_1d:
              output = output.squeeze(0)

          return output if return_complex else np.real(output)

      def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None):
          """
          Compute the log-Mel spectrogram of the provided audio.

          Args:
              waveform: audio samples of shape ``(time,)`` or ``(batch, time)``;
                  converted to float32 if it has another dtype.
              padding: number of zero samples appended to the end of the time
                  axis; a falsy value disables the padding.
              chunk_length: if given, overwrites ``self.n_samples`` and
                  ``self.nb_max_frames`` on the instance for this and later
                  calls.

          Returns:
              Array of shape ``(..., n_mels, n_frames - 1)``: the last STFT
              frame is dropped, values are ``log10`` of the mel power clipped
              at ``1e-10``, floored at 8.0 below the maximum, then mapped
              through ``(x + 4) / 4``.
          """
          if chunk_length is not None:
              self.n_samples = chunk_length * self.sampling_rate
              self.nb_max_frames = self.n_samples // self.hop_length

          # Compare dtypes by value; an identity test against np.float32 is
          # never true and would force a copy on every call.
          if waveform.dtype != np.float32:
              waveform = waveform.astype(np.float32)

          if padding:
              # Pad only the sample axis so batched input keeps its batch size
              pad_width = [(0, 0)] * (waveform.ndim - 1) + [(0, padding)]
              waveform = np.pad(waveform, pad_width)

          # Hann window of length n_fft taken from an (n_fft + 1)-point one
          window = np.hanning(self.n_fft + 1)[:-1].astype("float32")

          stft = self.stft(
              waveform,
              self.n_fft,
              self.hop_length,
              window=window,
              return_complex=True,
          ).astype("complex64")

          # Power spectrum without the final frame
          magnitudes = np.abs(stft[..., :-1]) ** 2

          mel_spec = self.mel_filters @ magnitudes

          log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
          log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
          log_spec = (log_spec + 4.0) / 4.0

          return log_spec