コード例 #1
0
 def _prepare_buffer(self, file_idx: int) -> None:
     """Load the buffered file at ``file_idx`` and reset the per-file sample cursor.

     Replaces ``self._cur_buffered_path_context`` with the result of
     ``BufferedPathContext.load`` on ``self._buffered_files_paths[file_idx]``,
     rebuilds ``self._order`` as an index array over the loaded samples
     (randomly permuted when ``self.shuffle`` is truthy) and resets
     ``self._cur_sample_idx`` to 0.

     :param file_idx: position in ``self._buffered_files_paths``; must satisfy
         ``0 <= file_idx < len(self._buffered_files_paths)``.
     :raises IndexError: if ``file_idx`` is outside that range.
     """
     n_files = len(self._buffered_files_paths)
     # Explicit check rather than ``assert``: asserts disappear under ``python -O``,
     # and the previous ``file_idx < n_files`` test let negative indices through,
     # silently loading a file counted from the end of the list.
     if not 0 <= file_idx < n_files:
         raise IndexError(f"file_idx {file_idx} is out of range for {n_files} buffered files")
     self._cur_buffered_path_context = BufferedPathContext.load(
         self._buffered_files_paths[file_idx])
     self._order = numpy.arange(len(self._cur_buffered_path_context))
     if self.shuffle:
         # permutation() returns a shuffled copy; self._order is rebound to it.
         self._order = numpy.random.permutation(self._order)
     self._cur_sample_idx = 0
コード例 #2
0
    def test_forward(self):
        """PathEncoder should emit one ``hidden_size`` vector per path context in the batch."""
        encoder_config = EncoderConfig(self._hidden_size, self._hidden_size, True, 0.5, 1, 0.5)

        loaded = BufferedPathContext.load(self._test_data_path)
        samples = [loaded[sample_idx] for sample_idx in range(self._batch_size)]
        batch = PathContextBatch(samples)

        # Vocabulary sizes come from the largest id present in the batch; the +1 below
        # presumably makes that largest id a valid index for the model — TODO confirm.
        contexts = batch.context
        max_from_token = contexts[FROM_TOKEN].max().item()
        max_to_token = contexts[TO_TOKEN].max().item()
        max_token_id = max(max_from_token, max_to_token)
        max_type_id = contexts[PATH_TYPES].max().item()

        encoder = PathEncoder(encoder_config, self._hidden_size, max_token_id + 1, 0, max_type_id + 1, 0)

        encoded = encoder(contexts)
        expected_shape = (sum(batch.contexts_per_label), self._hidden_size)
        self.assertTupleEqual(expected_shape, encoded.shape)
コード例 #3
0
    def test_forward(self):
        """PathDecoder output should have shape (target_length, number of labels, out_size)."""
        decoder_config = DecoderConfig(self._hidden_size, self._hidden_size, 1, 0.5, 1)
        # Build the decoder before drawing the random encoder stand-in below, so torch's
        # RNG is consumed in a fixed order.
        decoder = PathDecoder(decoder_config, self._out_size, 0, 0)

        loaded = BufferedPathContext.load(self._test_data_path)
        all_samples = [loaded[pos] for pos in range(len(loaded))]
        batch = PathContextBatch(all_samples)

        label_sizes = batch.contexts_per_label
        total_paths = sum(label_sizes)
        # Random stand-in for encoder output: one hidden_size vector per path context.
        encoded_paths = torch.rand(total_paths, self._hidden_size)

        predicted = decoder(encoded_paths, label_sizes, self._target_length)

        expected_shape = (self._target_length, len(label_sizes), self._out_size)
        self.assertTupleEqual(expected_shape, predicted.shape)