Example #1
    def test_insert_accurate_when_buffer_under_full(self):
        seq = RotatingSequence(3)
        seq.insert("a")
        seq.insert("b")

        self.assertSequenceEqual(["a", "b", None], seq._buffer)
        self.assertEqual(2, seq._num_items)
        self.assertEqual(2, seq._i_next)
Example #2
    def test_retrieve_accurate_when_buffer_circled_around_and_random_num_insertions(
            self):
        chars = [chr(c) for c in range(ord('a'), ord('z') + 1)]
        for i in range(20):  # 20 random test runs
            seq = RotatingSequence(3)
            num_insertions = random.randint(4, 26)
            for n in range(num_insertions):
                seq.insert(chars[n])

            self.assertSequenceEqual(chars[num_insertions - 3:num_insertions],
                                     seq.retrieve())
Example #3
    def __init__(self, tokenizer, model, device=None, mem_length=512):
        self._tokenizer = tokenizer
        self._lm = model
        self._device = device
        self._lm_output_states = [None] * (self._NUM_APPENDED_MASKS_MAX - self._NUM_APPENDED_MASKS_MIN + 1)

        # mem_length - _NUM_APPENDED_MASKS_MAX because we will always append up to _NUM_APPENDED_MASKS_MAX [MASK]
        # tokens in each forward feed.
        self._token_ids_memory = RotatingSequence(mem_length - self._NUM_APPENDED_MASKS_MAX)
        self._token_ids_memory.insert(tokenizer.sep_token_id)

        self._re_sep_required = re.compile(r"([\.\?;])(\s|$)")
        self._re_sep_repl_str = r"\1 {} ".format(tokenizer.sep_token)
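For reference, the _re_sep_required substitution above rewrites each sentence terminator so that a separator token follows it before tokenization. A small standalone illustration, assuming a BERT-style tokenizer whose sep_token is "[SEP]":

import re

sep_token = "[SEP]"  # assumed value of tokenizer.sep_token
re_sep_required = re.compile(r"([\.\?;])(\s|$)")
re_sep_repl_str = r"\1 {} ".format(sep_token)

print(re_sep_required.sub(re_sep_repl_str, "it rained today. tomorrow looks dry"))
# -> "it rained today. [SEP] tomorrow looks dry"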
Example #4
    def test_retrieve_accurate_when_buffer_full(self):
        seq = RotatingSequence(3)
        seq.insert("a")
        seq.insert("b")
        seq.insert("c")

        self.assertSequenceEqual(["a", "b", "c"], seq.retrieve())
Example #5
    def test_retrieve_accurate_when_buffer_circled_around(self):
        seq = RotatingSequence(3)
        seq.insert("a")
        seq.insert("b")
        seq.insert("c")
        seq.insert("d")

        self.assertSequenceEqual(["b", "c", "d"], seq.retrieve())
Example #6
    def __init__(self,
                 tokenizer,
                 model,
                 device=None,
                 use_past=False,
                 mem_length=256):
        """
        use_past=True will speed up inference at the cost of higher memory consumption.  Set to False to avoid out of
        memory errors occurring after repeated forward feeds to the model.
        """
        self._tokenizer = tokenizer
        self._lm = model
        self._device = device
        self._past = None
        self._model_output = None

        self._use_past = use_past
        if not use_past:
            self._tokens_history = RotatingSequence(mem_length)
Example #7
    def test_insert_accurate_when_buffer_reaches_full(self):
        seq = RotatingSequence(3)
        seq.insert("a")
        seq.insert("b")
        seq.insert("c")

        self.assertSequenceEqual(["a", "b", "c"], seq._buffer)
        self.assertEqual(3, seq._num_items)
        self.assertEqual(0, seq._i_next)
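The RotatingSequence tests above assert against internals (_buffer, _num_items, _i_next) of a class that is not itself shown in this listing. A minimal ring-buffer sketch consistent with those assertions could look like the following; it is reconstructed from the tests and is not necessarily the project's actual implementation.

class RotatingSequence:
    """Fixed-capacity ring buffer; retrieve() returns the stored items oldest-first."""

    def __init__(self, capacity):
        self._buffer = [None] * capacity
        self._num_items = 0
        self._i_next = 0  # index at which the next item will be written

    def insert(self, item):
        self._buffer[self._i_next] = item
        self._i_next = (self._i_next + 1) % len(self._buffer)
        self._num_items = min(self._num_items + 1, len(self._buffer))

    def retrieve(self):
        if self._num_items < len(self._buffer):
            return self._buffer[:self._num_items]
        # Buffer is full (and possibly wrapped): the oldest item sits at _i_next.
        return self._buffer[self._i_next:] + self._buffer[:self._i_next]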
class Gpt2WordPredictor(IWordPredictor):
    def __init__(self,
                 tokenizer,
                 model,
                 device=None,
                 use_past=False,
                 mem_length=256):
        """
        use_past=True will speed up inference at the cost of higher memory consumption.  Set to False to avoid out of
        memory errors occurring after repeated forward feeds to the model.
        """
        self._tokenizer = tokenizer
        self._lm = model
        self._device = device
        self._past = None
        self._model_output = None

        self._use_past = use_past
        if not use_past:
            self._tokens_history = RotatingSequence(mem_length)

    @property
    def prepends_spaces(self):
        return True

    def feed(self, text, **kwargs):
        if not text:
            return

        # todo: how to handle new lines?
        prefix_space = text.startswith(" ")

        tokens = self._tokenizer.encode(text, add_prefix_space=prefix_space)

        if self._use_past:
            tokens_tensor = torch.tensor([tokens]).to(self._device)
            self._model_output, self._past = self._lm(tokens_tensor,
                                                      past=self._past)
        else:
            for t in tokens:
                self._tokens_history.insert(t)
            with torch.no_grad():
                tokens_tensor = torch.tensor([self._tokens_history.retrieve()
                                              ]).to(self._device)
                self._model_output = self._lm(tokens_tensor)

    def top_n_next(self, n):
        i_layer = 0
        i_batch = 0
        i_final_token = -1

        if self._use_past:
            if len(self._model_output.size()) == 2:
                p = self._model_output[i_layer]
            else:
                p = self._model_output[i_layer][i_final_token]
        else:
            p = self._model_output[i_layer][i_batch][-1]

        top_n = torch.topk(torch.softmax(p, 0, torch.float32), n)

        scores = [v.item() for v in top_n.values]
        tokens = [self._tokenizer.decode(t) for t in top_n.indices.tolist()]

        return zip(scores, tokens)
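A minimal usage sketch for the class above, assuming the Hugging Face transformers GPT-2 tokenizer and LM-head model (model names and keyword behaviour are version-dependent; the past= keyword used in feed() follows the older tuple-returning transformers API, so use_past=False is the safer choice here). The class also relies on torch and a RotatingSequence implementation being importable.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

predictor = Gpt2WordPredictor(tokenizer, model, device="cpu", use_past=False)
predictor.feed("The weather today is")
for score, token in predictor.top_n_next(5):
    print("{:.4f}  {!r}".format(score, token))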
class BertWordPredictor(IWordPredictor):
    _INDEX_FINAL_LAYER = 0
    _INDEX_BATCH = 0
    _INDEX_FINAL_TOKEN = -1

    _NUM_APPENDED_MASKS_MIN = 5
    _NUM_APPENDED_MASKS_MAX = 5

    def __init__(self, tokenizer, model, device=None, mem_length=512):
        self._tokenizer = tokenizer
        self._lm = model
        self._device = device
        self._lm_output_states = [None] * (self._NUM_APPENDED_MASKS_MAX - self._NUM_APPENDED_MASKS_MIN + 1)

        # mem_length - _NUM_APPENDED_MASKS_MAX because we will always append up to _NUM_APPENDED_MASKS_MAX [MASK]
        # tokens in each forward feed.
        self._token_ids_memory = RotatingSequence(mem_length - self._NUM_APPENDED_MASKS_MAX)
        self._token_ids_memory.insert(tokenizer.sep_token_id)

        self._re_sep_required = re.compile(r"([\.\?;])(\s|$)")
        self._re_sep_repl_str = r"\1 {} ".format(tokenizer.sep_token)

    @property
    def prepends_spaces(self):
        return False

    def feed(self, text, **kwargs):
        try:
            segment_id = kwargs["segment_id"]
        except KeyError:
            segment_id = 0

        # Tokenize the text, inserting a SEP token after each sentence-terminating character.
        tokenized_text = self._tokenizer.tokenize(self._re_sep_required.sub(self._re_sep_repl_str, text))

        # Convert tokens to their ids and append to memory
        text_token_ids = self._tokenizer.convert_tokens_to_ids(tokenized_text)
        for token_id in text_token_ids:  # todo: performance: better to create and use an insert_iterable method
            self._token_ids_memory.insert(token_id)

        # The complete tokenized text fed to the model each time is the concatenation of:
        # previously fed tokens + tokenized text + m x MASK tokens; total number of tokens fed is limited to mem_length
        # We perform n iterations of feeding to the model, changing m at each iteration.
        # We average the word predictions across the n iterations.
        for n in range(self._NUM_APPENDED_MASKS_MAX - self._NUM_APPENDED_MASKS_MIN + 1):
            num_masks = self._NUM_APPENDED_MASKS_MIN + n
            token_ids_to_feed = (self._token_ids_memory.retrieve()
                                 + [self._tokenizer.mask_token_id] * num_masks)

            # Convert to tensor and push to GPU
            tokens_tensor = torch.tensor([token_ids_to_feed]).to(self._device)

            # BERT uses A/B segments.  We have to assign the tokens to the A or B segment.
            segment_ids = [segment_id] * len(token_ids_to_feed)
            segments_tensor = torch.tensor([segment_ids]).to(self._device)

            # Feed forward all the tokens and keep the logits at the first appended [MASK] position.
            with torch.no_grad():   # todo: performance: repeated context creation/destruction
                output = self._lm(tokens_tensor, token_type_ids=segments_tensor)
                self._lm_output_states[n] = output[self._INDEX_FINAL_LAYER][
                    self._INDEX_BATCH][-num_masks]

    def top_n_next(self, n):
        # debugging
        #print(self._tokenizer.convert_ids_to_tokens(self._token_ids_memory.retrieve()))

        top_n = torch.topk(
            torch.softmax(
                torch.sum(torch.stack(self._lm_output_states), dim=0),
                0,
                torch.float32),
            n)

        scores = [v.item() for v in top_n.values]
        tokens = self._tokenizer.convert_ids_to_tokens([i.item() for i in top_n.indices])

        return zip(scores, tokens)
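Likewise, a minimal usage sketch for BertWordPredictor, assuming the transformers BERT tokenizer and masked-LM model (the tuple-style output indexing used above also works with the indexable output objects returned by recent versions):

from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

predictor = BertWordPredictor(tokenizer, model, device="cpu")
predictor.feed("the weather today is", segment_id=0)
for score, token in predictor.top_n_next(5):
    print("{:.4f}  {}".format(score, token))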