utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import logging
  2. import os
  3. import re
  4. from typing import List, Optional
  5. import huggingface_hub
  6. import requests
  7. from tqdm.auto import tqdm
  8. _MODELS = {
  9. "tiny.en": "Systran/faster-whisper-tiny.en",
  10. "tiny": "Systran/faster-whisper-tiny",
  11. "base.en": "Systran/faster-whisper-base.en",
  12. "base": "Systran/faster-whisper-base",
  13. "small.en": "Systran/faster-whisper-small.en",
  14. "small": "Systran/faster-whisper-small",
  15. "medium.en": "Systran/faster-whisper-medium.en",
  16. "medium": "Systran/faster-whisper-medium",
  17. "large-v1": "Systran/faster-whisper-large-v1",
  18. "large-v2": "Systran/faster-whisper-large-v2",
  19. "large-v3": "Systran/faster-whisper-large-v3",
  20. "large": "Systran/faster-whisper-large-v3",
  21. "distil-large-v2": "Systran/faster-distil-whisper-large-v2",
  22. "distil-medium.en": "Systran/faster-distil-whisper-medium.en",
  23. "distil-small.en": "Systran/faster-distil-whisper-small.en",
  24. "distil-large-v3": "Systran/faster-distil-whisper-large-v3",
  25. "large-v3-turbo": "mobiuslabsgmbh/faster-whisper-large-v3-turbo",
  26. "turbo": "mobiuslabsgmbh/faster-whisper-large-v3-turbo",
  27. }
  28. def available_models() -> List[str]:
  29. """Returns the names of available models."""
  30. return list(_MODELS.keys())
  31. def get_assets_path():
  32. """Returns the path to the assets directory."""
  33. return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
  34. def get_logger():
  35. """Returns the module logger."""
  36. return logging.getLogger("faster_whisper")
  37. def download_model(
  38. size_or_id: str,
  39. output_dir: Optional[str] = None,
  40. local_files_only: bool = False,
  41. cache_dir: Optional[str] = None,
  42. revision: Optional[str] = None,
  43. ):
  44. """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
  45. Args:
  46. size_or_id: Size of the model to download from https://huggingface.co/Systran
  47. (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
  48. distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
  49. distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
  50. (e.g. Systran/faster-whisper-large-v3).
  51. output_dir: Directory where the model should be saved. If not set, the model is saved in
  52. the cache directory.
  53. local_files_only: If True, avoid downloading the file and return the path to the local
  54. cached file if it exists.
  55. cache_dir: Path to the folder where cached files are stored.
  56. revision: An optional Git revision id which can be a branch name, a tag, or a
  57. commit hash.
  58. Returns:
  59. The path to the downloaded model.
  60. Raises:
  61. ValueError: if the model size is invalid.
  62. """
  63. if re.match(r".*/.*", size_or_id):
  64. repo_id = size_or_id
  65. else:
  66. repo_id = _MODELS.get(size_or_id)
  67. if repo_id is None:
  68. raise ValueError(
  69. "Invalid model size '%s', expected one of: %s"
  70. % (size_or_id, ", ".join(_MODELS.keys()))
  71. )
  72. allow_patterns = [
  73. "config.json",
  74. "preprocessor_config.json",
  75. "model.bin",
  76. "tokenizer.json",
  77. "vocabulary.*",
  78. ]
  79. kwargs = {
  80. "local_files_only": local_files_only,
  81. "allow_patterns": allow_patterns,
  82. "tqdm_class": disabled_tqdm,
  83. "revision": revision,
  84. }
  85. if output_dir is not None:
  86. kwargs["local_dir"] = output_dir
  87. kwargs["local_dir_use_symlinks"] = False
  88. if cache_dir is not None:
  89. kwargs["cache_dir"] = cache_dir
  90. try:
  91. return huggingface_hub.snapshot_download(repo_id, **kwargs)
  92. except (
  93. huggingface_hub.utils.HfHubHTTPError,
  94. requests.exceptions.ConnectionError,
  95. ) as exception:
  96. logger = get_logger()
  97. logger.warning(
  98. "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
  99. repo_id,
  100. exception,
  101. )
  102. logger.warning(
  103. "Trying to load the model directly from the local cache, if it exists."
  104. )
  105. kwargs["local_files_only"] = True
  106. return huggingface_hub.snapshot_download(repo_id, **kwargs)
  107. def format_timestamp(
  108. seconds: float,
  109. always_include_hours: bool = False,
  110. decimal_marker: str = ".",
  111. ) -> str:
  112. assert seconds >= 0, "non-negative timestamp expected"
  113. milliseconds = round(seconds * 1000.0)
  114. hours = milliseconds // 3_600_000
  115. milliseconds -= hours * 3_600_000
  116. minutes = milliseconds // 60_000
  117. milliseconds -= minutes * 60_000
  118. seconds = milliseconds // 1_000
  119. milliseconds -= seconds * 1_000
  120. hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
  121. return (
  122. f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
  123. )
  124. class disabled_tqdm(tqdm):
  125. def __init__(self, *args, **kwargs):
  126. kwargs["disable"] = True
  127. super().__init__(*args, **kwargs)
  128. def get_end(segments: List[dict]) -> Optional[float]:
  129. return next(
  130. (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
  131. segments[-1]["end"] if segments else None,
  132. )