Example #1
    def __update_state(self, in_tokens):
        """
        Update the model's recurrent state with newly observed tokens.
        """
        # Parse and encode the raw tokens, then let the processor
        # accumulate them and transform them into model-ready ids.
        data = tokenizer.parse(in_tokens)
        data = np.array(tokenizer.encode(data))
        self.processor.gather(data)
        data = self.processor.transform(data)

        # Shift by one: the input at step t is the token observed at
        # step t-1, starting from the last token of the previous call.
        x_source = np.concatenate(([self._last_token], data[:-1]), axis=0)
        self._last_token = data[-1]

        assert len(x_source) > 0

        # Reshape to (seq_len, batch=1), the layout the batch builder expects.
        x_source = np.array([x_source], dtype=np.int32)
        x_source = np.transpose(x_source, (1, 0))

        batch = self.builder.build_infer_batch(x_source)

        # Move each batch field onto the target device as an int64 tensor.
        x = torch.tensor(batch.x, device=device, dtype=torch.int64)
        x_type = torch.tensor(batch.x_type, device=device, dtype=torch.int64)
        tokens = torch.tensor(batch.tokens, device=device, dtype=torch.int64)
        ids = torch.tensor(batch.ids, device=device, dtype=torch.int64)
        nums = torch.tensor(batch.nums, device=device, dtype=torch.int64)

        # Run the model, carrying the LSTM hidden/cell state across calls.
        out: train_id.ModelOutput = self.__model(x, None, x_type, None, tokens,
                                                 ids, nums, self.h0, self.c0)

        # Detach so the stored state does not keep the autograd graph alive.
        self.h0 = out.hn.detach()
        self.c0 = out.cn.detach()
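
The shift-by-one concatenation above is the core trick: the model predicts token t from token t-1, so the input sequence starts with the previously seen token and drops the newest one, which is carried into the next call. A minimal self-contained sketch of just that step, using plain numpy and hypothetical token values:

import numpy as np

last_token = 7                        # last token from the previous call
data = np.array([3, 1, 4, 1, 5])      # newly encoded tokens

# Input at step t is the token observed at step t-1.
x_source = np.concatenate(([last_token], data[:-1]), axis=0)
last_token = data[-1]                  # carried into the next call

print(x_source)  # [7 3 1 4 1]

# Reshape to (seq_len, batch=1), the layout the batch builder expects.
x_source = np.transpose(np.array([x_source], dtype=np.int32), (1, 0))
print(x_source.shape)  # (5, 1)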
Example #2
from pathlib import Path
from typing import List, Optional


def _read_file(path: Path) -> Optional[List[int]]:
    """
    Read a source file and return it as a list of encoded tokens,
    or None if reading or parsing fails.
    """
    try:
        with open(str(path)) as f:
            content = f.read()

        # Normalize the source before encoding.
        parsed = parse_string(content)
        parsed = _remove_comments(parsed)
        parsed = _remove_empty_lines(parsed)
        parsed = _fix_indentation(parsed)
        serialized = encode(parsed)

        # Debugging sanity check: a round trip through the tokenizer
        # should reproduce the parsed source exactly.
        # deserialized = tokenizer.deserialize(serialized)
        # for i in range(len(serialized)):
        #     assert deserialized[i] == parsed[i]
        #
        # res = to_text(deserialized)
        # print(res)

        return serialized
    except Exception:
        # Catch Exception rather than a bare except, which would also
        # swallow KeyboardInterrupt and SystemExit.
        logger.log(f'Failed to read {path}')
        return None
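
A hedged usage sketch: since the function returns None on failure, callers should skip those files rather than index into the result. The directory path here is hypothetical, and `_read_file` is assumed to live in the same module as the helpers it calls:

from pathlib import Path

encoded_files = []
for source_path in Path('data/source').glob('**/*.py'):  # hypothetical corpus directory
    serialized = _read_file(source_path)
    if serialized is None:  # reading or parsing failed; skip this file
        continue
    encoded_files.append(serialized)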