tokenizer.py

import string
from functools import cached_property
from typing import List, Optional, Tuple

import tokenizers


class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer."""

    def __init__(
        self,
        tokenizer: tokenizers.Tokenizer,
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
        self.tokenizer = tokenizer

        if multilingual:
            if task not in _TASKS:
                raise ValueError(
                    "'%s' is not a valid task (accepted tasks: %s)"
                    % (task, ", ".join(_TASKS))
                )

            if language not in _LANGUAGE_CODES:
                raise ValueError(
                    "'%s' is not a valid language code (accepted language codes: %s)"
                    % (language, ", ".join(_LANGUAGE_CODES))
                )

            self.task = self.tokenizer.token_to_id("<|%s|>" % task)
            self.language = self.tokenizer.token_to_id("<|%s|>" % language)
            self.language_code = language
        else:
            self.task = None
            self.language = None
            self.language_code = "en"
    @cached_property
    def transcribe(self) -> int:
        return self.tokenizer.token_to_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        return self.tokenizer.token_to_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_lm(self) -> int:
        return self.tokenizer.token_to_id("<|startoflm|>")

    @cached_property
    def sot_prev(self) -> int:
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        return self.no_timestamps + 1
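    # Note: the timestamp tokens start immediately after <|notimestamps|> and advance in
    # 0.02 s steps, so the token id timestamp_begin + k decodes as <|k * 0.02|>;
    # e.g. timestamp_begin + 75 renders as <|1.50|> in decode_with_timestamps below.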
    @property
    def sot_sequence(self) -> List[int]:
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence
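    # Example: for a multilingual model initialized with language="en" and task="transcribe",
    # sot_sequence is the ids of [<|startoftranscript|>, <|en|>, <|transcribe|>]; for the
    # non-multilingual case it is just [<|startoftranscript|>].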
    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)
    def decode_with_timestamps(self, tokens: List[int]) -> str:
        outputs = [[]]

        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)

        return "".join(
            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        )
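    # Example: a sequence like [timestamp_begin, ...text tokens..., timestamp_begin + 60]
    # decodes to something like "<|0.00|> Hello there.<|1.20|>", i.e. the text is interleaved
    # with a timestamp marker each time a timestamp token is encountered.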
    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encode(" -")[0], self.encode(" '")[0]}

        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encode(symbol),
                self.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))
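    # Usage sketch (illustrative only; the suppression hook is not part of this class, and
    # `logits` is an assumed per-token score array produced by the decoding loop):
    #
    #     for token_id in tokenizer.non_speech_tokens:
    #         logits[token_id] = -float("inf")
    #
    # Masking these ids before sampling keeps the decoder from emitting speaker tags,
    # bracketed annotations, and music symbols.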
    def split_to_word_tokens(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)
    def split_tokens_on_unicode(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            try:
                replacement_char_index = decoded.index(replacement_char)
                replacement_char_index += unicode_offset
            except ValueError:
                replacement_char_index = None

            if replacement_char_index is None or (
                replacement_char_index < len(decoded_full)
                and decoded_full[replacement_char_index] == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens
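    # Note: decoding a partial multi-byte UTF-8 sequence yields U+FFFD, so a word boundary is
    # only committed once the incrementally decoded prefix contains no spurious replacement
    # character (or the replacement character at that offset is genuinely part of the full
    # decoded text).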
    def split_tokens_on_spaces(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation

            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
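    # Example (illustrative; the exact subword split depends on the vocabulary): subwords such
    # as [" incompre", "hensible", "!"] would be merged into the words [" incomprehensible", "!"],
    # since "hensible" neither starts with a space nor is punctuation, while "!" is split off
    # as punctuation.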


_TASKS = (
    "transcribe",
    "translate",
)

_LANGUAGE_CODES = (
    "af",
    "am",
    "ar",
    "as",
    "az",
    "ba",
    "be",
    "bg",
    "bn",
    "bo",
    "br",
    "bs",
    "ca",
    "cs",
    "cy",
    "da",
    "de",
    "el",
    "en",
    "es",
    "et",
    "eu",
    "fa",
    "fi",
    "fo",
    "fr",
    "gl",
    "gu",
    "ha",
    "haw",
    "he",
    "hi",
    "hr",
    "ht",
    "hu",
    "hy",
    "id",
    "is",
    "it",
    "ja",
    "jw",
    "ka",
    "kk",
    "km",
    "kn",
    "ko",
    "la",
    "lb",
    "ln",
    "lo",
    "lt",
    "lv",
    "mg",
    "mi",
    "mk",
    "ml",
    "mn",
    "mr",
    "ms",
    "mt",
    "my",
    "ne",
    "nl",
    "nn",
    "no",
    "oc",
    "pa",
    "pl",
    "ps",
    "pt",
    "ro",
    "ru",
    "sa",
    "sd",
    "si",
    "sk",
    "sl",
    "sn",
    "so",
    "sq",
    "sr",
    "su",
    "sv",
    "sw",
    "ta",
    "te",
    "tg",
    "th",
    "tk",
    "tl",
    "tr",
    "tt",
    "uk",
    "ur",
    "uz",
    "vi",
    "yi",
    "yo",
    "zh",
    "yue",
)
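

# Minimal usage sketch (illustrative only): it assumes a Hugging Face `tokenizer.json`
# exported from a multilingual Whisper model is available at the hypothetical path below;
# the exact token ids printed depend entirely on that file.
if __name__ == "__main__":
    hf_tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")  # hypothetical path
    tok = Tokenizer(hf_tokenizer, multilingual=True, task="transcribe", language="en")

    ids = tok.encode(" Hello world.")
    print(tok.sot_sequence)  # ids of <|startoftranscript|>, <|en|>, <|transcribe|>
    print(tok.decode(ids))  # " Hello world."
    print(tok.decode_with_timestamps([tok.timestamp_begin] + ids))  # "<|0.00|> Hello world."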