SimilarSentenceSplitter.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from typing import List
  2. from .Splitter import Splitter
  3. class SimilarSentenceSplitter(Splitter):
  4. def __init__(self, similarity_model, sentence_splitter: Splitter):
  5. self.model = similarity_model
  6. self.sentence_splitter = sentence_splitter
  7. def split(self, text: str, group_max_sentences=5) -> List[str]:
  8. '''
  9. group_max_sentences: The maximum number of sentences in a group.
  10. '''
  11. sentences = self.sentence_splitter.split(text)
  12. if len(sentences) == 0:
  13. return []
  14. similarities = self.model.similarities(sentences)
  15. # The first sentence is always in the first group.
  16. groups = [[sentences[0]]]
  17. # Using the group min/max sentences contraints,
  18. # group together the rest of the sentences.
  19. for i in range(1, len(sentences)):
  20. if len(groups[-1]) >= group_max_sentences:
  21. groups.append([sentences[i]])
  22. elif similarities[i-1] >= self.model.similarity_threshold:
  23. groups[-1].append(sentences[i])
  24. else:
  25. groups.append([sentences[i]])
  26. return groups