def _prepare_buffer(self, file_idx: int) -> None:
    """Load the buffered file at ``file_idx`` and reset the sampling state.

    Loads ``self._buffered_files_paths[file_idx]`` with
    ``BufferedPathContext.load`` into ``self._cur_buffered_path_context``,
    sets ``self._order`` to the sample indices ``0 .. len - 1`` (randomly
    permuted when ``self.shuffle`` is truthy), and resets the
    ``self._cur_sample_idx`` cursor to 0.

    :param file_idx: index into ``self._buffered_files_paths``; must satisfy
        ``0 <= file_idx < len(self._buffered_files_paths)``.
    :raises IndexError: if ``file_idx`` is out of range. Negative indices are
        rejected instead of silently wrapping to the end of the list.
    """
    n_files = len(self._buffered_files_paths)
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and the previous check let negative indices through.
    if not 0 <= file_idx < n_files:
        raise IndexError(
            f"file_idx {file_idx} out of range for {n_files} buffered files"
        )
    self._cur_buffered_path_context = BufferedPathContext.load(
        self._buffered_files_paths[file_idx]
    )
    n_samples = len(self._cur_buffered_path_context)
    if self.shuffle:
        # permutation(n) draws the same random shuffle as permuting arange(n).
        self._order = numpy.random.permutation(n_samples)
    else:
        self._order = numpy.arange(n_samples)
    self._cur_sample_idx = 0
def test_forward(self):
    """Check that PathEncoder emits one ``hidden_size`` row per path context.

    Builds a batch from the first ``self._batch_size`` samples of the test
    data. The token and type vocabulary sizes are derived from the largest
    ids present in that batch. The test asserts the encoder output shape is
    ``(total number of paths in the batch, hidden_size)``.
    """
    hidden = self._hidden_size
    encoder_config = EncoderConfig(hidden, hidden, True, 0.5, 1, 0.5)
    dataset = BufferedPathContext.load(self._test_data_path)
    samples = [dataset[idx] for idx in range(self._batch_size)]
    batch = PathContextBatch(samples)
    contexts = batch.context

    # Vocabulary sizes are one more than the largest id seen in the batch
    # (presumably so every id is a valid embedding index).
    largest_from = contexts[FROM_TOKEN].max().item()
    largest_to = contexts[TO_TOKEN].max().item()
    largest_type = contexts[PATH_TYPES].max().item()
    token_vocab = max(largest_from, largest_to) + 1
    type_vocab = largest_type + 1
    encoder = PathEncoder(encoder_config, hidden, token_vocab, 0, type_vocab, 0)

    encoded = encoder(contexts)
    total_paths = sum(batch.contexts_per_label)
    self.assertTupleEqual((total_paths, hidden), encoded.shape)
def test_forward(self):
    """Check the PathDecoder output shape.

    Every sample in the test data is loaded as a single batch. Random vectors
    stand in for encoder output, one ``hidden_size`` row per path context.
    The test asserts the decoder returns a tensor of shape
    ``(target_length, number of labels in the batch, out_size)``.
    """
    hidden = self._hidden_size
    decoder = PathDecoder(
        DecoderConfig(hidden, hidden, 1, 0.5, 1), self._out_size, 0, 0
    )
    dataset = BufferedPathContext.load(self._test_data_path)
    batch = PathContextBatch([dataset[idx] for idx in range(len(dataset))])

    paths_per_label = batch.contexts_per_label
    total_paths = sum(paths_per_label)
    # Random stand-in for the encoder: one hidden-size row per path context.
    encoded_paths = torch.rand(total_paths, hidden)
    decoded = decoder(encoded_paths, paths_per_label, self._target_length)

    expected_shape = (self._target_length, len(paths_per_label), self._out_size)
    self.assertTupleEqual(expected_shape, decoded.shape)