import string

from functools import cached_property
from typing import List, Optional, Tuple

import tokenizers


class Tokenizer:
    """Simple wrapper around a tokenizers.Tokenizer."""

    def __init__(
        self,
        tokenizer: tokenizers.Tokenizer,
        multilingual: bool,
        task: Optional[str] = None,
        language: Optional[str] = None,
    ):
        self.tokenizer = tokenizer

        if multilingual:
            if task not in _TASKS:
                raise ValueError(
                    "'%s' is not a valid task (accepted tasks: %s)"
                    % (task, ", ".join(_TASKS))
                )

            if language not in _LANGUAGE_CODES:
                raise ValueError(
                    "'%s' is not a valid language code (accepted language codes: %s)"
                    % (language, ", ".join(_LANGUAGE_CODES))
                )

            self.task = self.tokenizer.token_to_id("<|%s|>" % task)
            self.language = self.tokenizer.token_to_id("<|%s|>" % language)
            self.language_code = language
        else:
            self.task = None
            self.language = None
            self.language_code = "en"
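
    # Note: for multilingual models, ``task`` and ``language`` are resolved to
    # vocabulary IDs (via token_to_id) so they can be placed directly in the
    # decoder prompt built by ``sot_sequence``; for English-only models both
    # stay None and the prompt reduces to ``[sot]``.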

    @cached_property
    def transcribe(self) -> int:
        return self.tokenizer.token_to_id("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        return self.tokenizer.token_to_id("<|translate|>")

    @cached_property
    def sot(self) -> int:
        return self.tokenizer.token_to_id("<|startoftranscript|>")

    @cached_property
    def sot_lm(self) -> int:
        return self.tokenizer.token_to_id("<|startoflm|>")

    @cached_property
    def sot_prev(self) -> int:
        return self.tokenizer.token_to_id("<|startofprev|>")

    @cached_property
    def eot(self) -> int:
        return self.tokenizer.token_to_id("<|endoftext|>")

    @cached_property
    def no_timestamps(self) -> int:
        return self.tokenizer.token_to_id("<|notimestamps|>")

    @property
    def timestamp_begin(self) -> int:
        return self.no_timestamps + 1

    @property
    def sot_sequence(self) -> List[int]:
        sequence = [self.sot]

        if self.language is not None:
            sequence.append(self.language)

        if self.task is not None:
            sequence.append(self.task)

        return sequence
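
    # For a multilingual model configured with language="en" and
    # task="transcribe", ``sot_sequence`` is the usual Whisper decoder prompt
    # [<|startoftranscript|>, <|en|>, <|transcribe|>]; the concrete IDs depend
    # on the vocabulary shipped with the converted model.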

    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text, add_special_tokens=False).ids

    def decode(self, tokens: List[int]) -> str:
        text_tokens = [token for token in tokens if token < self.eot]
        return self.tokenizer.decode(text_tokens)

    def decode_with_timestamps(self, tokens: List[int]) -> str:
        outputs = [[]]

        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)

        return "".join(
            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        )
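
    # Timestamp tokens encode 20 ms steps: a token at ``timestamp_begin + 100``
    # decodes to "<|2.00|>" (100 * 0.02 s), so a segment such as
    # [timestamp_begin, ...text tokens..., timestamp_begin + 100] renders as
    # "<|0.00|>...text...<|2.00|>".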

    @cached_property
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuation like commas, periods, question marks, and exclamation points.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # Symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # these are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # Allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word.
        result = {self.encode(" -")[0], self.encode(" '")[0]}

        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encode(symbol),
                self.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))
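
    # These IDs are typically handed to the decoding options as the set of
    # suppressed tokens (for example through a ``suppress_tokens`` argument in
    # the calling transcription loop); this property only computes the
    # candidate set and does not apply the suppression itself.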

    def split_to_word_tokens(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any position where
            # the tokens decode to valid Unicode code points.
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)
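
    # E.g. for Chinese output the fallback yields roughly one "word" per group
    # of decoded characters, whereas space-delimited languages are further
    # merged into whole words by split_tokens_on_spaces below.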

    def split_tokens_on_unicode(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            try:
                replacement_char_index = decoded.index(replacement_char)
                replacement_char_index += unicode_offset
            except ValueError:
                replacement_char_index = None

            if replacement_char_index is None or (
                replacement_char_index < len(decoded_full)
                and decoded_full[replacement_char_index] == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens
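
    # A word boundary is emitted only once the accumulated tokens decode to a
    # string without U+FFFD (the replacement character), i.e. once a partial
    # multi-byte UTF-8 sequence has been completed; the check against
    # ``decoded_full`` keeps a U+FFFD that genuinely appears in the full
    # transcript from blocking a split.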

    def split_tokens_on_spaces(
        self, tokens: List[int]
    ) -> Tuple[List[str], List[List[int]]]:
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation

            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


_TASKS = (
    "transcribe",
    "translate",
)

_LANGUAGE_CODES = (
    "af",
    "am",
    "ar",
    "as",
    "az",
    "ba",
    "be",
    "bg",
    "bn",
    "bo",
    "br",
    "bs",
    "ca",
    "cs",
    "cy",
    "da",
    "de",
    "el",
    "en",
    "es",
    "et",
    "eu",
    "fa",
    "fi",
    "fo",
    "fr",
    "gl",
    "gu",
    "ha",
    "haw",
    "he",
    "hi",
    "hr",
    "ht",
    "hu",
    "hy",
    "id",
    "is",
    "it",
    "ja",
    "jw",
    "ka",
    "kk",
    "km",
    "kn",
    "ko",
    "la",
    "lb",
    "ln",
    "lo",
    "lt",
    "lv",
    "mg",
    "mi",
    "mk",
    "ml",
    "mn",
    "mr",
    "ms",
    "mt",
    "my",
    "ne",
    "nl",
    "nn",
    "no",
    "oc",
    "pa",
    "pl",
    "ps",
    "pt",
    "ro",
    "ru",
    "sa",
    "sd",
    "si",
    "sk",
    "sl",
    "sn",
    "so",
    "sq",
    "sr",
    "su",
    "sv",
    "sw",
    "ta",
    "te",
    "tg",
    "th",
    "tk",
    "tl",
    "tr",
    "tt",
    "uk",
    "ur",
    "uz",
    "vi",
    "yi",
    "yo",
    "zh",
    "yue",
)
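

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: it assumes a
    # multilingual Whisper-style "tokenizer.json" vocabulary file is available
    # in the current directory (the path and the printed IDs depend entirely
    # on the converted model you use, so treat the values as illustrative).
    hf_tokenizer = tokenizers.Tokenizer.from_file("tokenizer.json")
    tokenizer = Tokenizer(
        hf_tokenizer, multilingual=True, task="transcribe", language="en"
    )

    # The decoder prompt and a round trip through encode/decode.
    print("sot sequence:", tokenizer.sot_sequence)
    ids = tokenizer.encode(" Hello world")
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))

    # Wrap the text tokens in two timestamp tokens (0.00 s and 1.00 s).
    print(
        "with timestamps:",
        tokenizer.decode_with_timestamps(
            [tokenizer.timestamp_begin] + ids + [tokenizer.timestamp_begin + 50]
        ),
    )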