def test(embedding_matrix, entity_num, entity_embedding_dim, rnn_hidden_size,
         vocab_size, start_token, max_sent_num, p1, p1_mask, entity_keys,
         keys_mask, encoder_path, decoder_path, eos_ind):
    """Restore a trained encoder/decoder pair and generate a paragraph.

    Builds the entity encoder and decoder, runs one training step purely to
    create all model variables (so the savers have something to restore
    into), restores both checkpoints, re-encodes ``p1``, and runs the
    decoder in inference mode to generate a paragraph.

    Args:
        embedding_matrix: word-embedding matrix shared by encoder and decoder.
        entity_num: maximum number of entity slots in the entity memory.
        entity_embedding_dim: dimensionality of each entity embedding.
        rnn_hidden_size: hidden size of the decoder RNN.
        vocab_size: output vocabulary size.
        start_token: index of the start-of-sentence token.
        max_sent_num: maximum sentence count used when constructing the
            decoder. NOTE(review): this name is later shadowed by the runtime
            value ``tf.shape(p1)[1]`` — confirm that is intentional.
        p1: first-paragraph token ids; assumes shape
            (batch, sent_num, sent_len) — TODO confirm against caller.
        p1_mask: mask matching ``p1``.
        entity_keys: initial entity key embeddings.
        keys_mask: mask over ``entity_keys``.
        encoder_path: checkpoint directory for the encoder.
        decoder_path: checkpoint directory for the decoder.
        eos_ind: index of the end-of-sentence token.

    Side effects:
        Creates the checkpoint directories if missing, restores model
        variables from them, and prints intermediate shapes plus the
        generated paragraph.
    """
    encoder = Model.BasicRecurrentEntityEncoder(
        embedding_matrix=embedding_matrix,
        max_entity_num=entity_num,
        entity_embedding_dim=entity_embedding_dim)
    temp_entity_cell, temp_entities = encoder([p1, p1_mask], entity_keys)
    print("temp_entities shape", temp_entities.shape)

    decoder = Model.RNNRecurrentEntitiyDecoder(
        embedding_matrix=embedding_matrix,
        rnn_hidden_size=rnn_hidden_size,
        entity_cell=temp_entity_cell,
        vocab_size=vocab_size,
        max_sent_num=max_sent_num,
        entity_embedding_dim=entity_embedding_dim)

    # Run one training step just to initialize all variables; in eager mode
    # variables do not exist until the model has been called once, and
    # tfe.Saver can only restore into existing variables.
    decoder_inputs_train = [True, temp_entities, vocab_size, start_token]
    # NOTE(review): p2 / p2_mask are not parameters of this function and are
    # not defined here — presumably module-level globals holding the second
    # paragraph; verify, or add them to the signature.
    labels = [p2, p2_mask]
    decoder(inputs=decoder_inputs_train, keys=entity_keys,
            keys_mask=keys_mask, training=True, labels=labels)

    # Runtime sentence dimensions taken from p1; this shadows the
    # max_sent_num argument used to build the decoder above.
    max_sent_num = tf.shape(p1)[1]
    max_sent_len = tf.shape(p1)[2]

    # Restore saved models.
    checkpoint_dir_encoder = encoder_path
    os.makedirs(checkpoint_dir_encoder, exist_ok=True)
    checkpoint_prefix_encoder = os.path.join(checkpoint_dir_encoder, 'ckpt')
    tfe.Saver(encoder.variables).restore(checkpoint_prefix_encoder)

    checkpoint_dir_decoder = decoder_path
    os.makedirs(checkpoint_dir_decoder, exist_ok=True)
    checkpoint_prefix_decoder = os.path.join(checkpoint_dir_decoder, 'ckpt')
    tfe.Saver(decoder.variables).restore(checkpoint_prefix_decoder)

    # Re-encode p1 with the restored weights.
    entity_cell, entity_hiddens = encoder([p1, p1_mask], entity_keys)
    # print("entity_hiddens shape:", entity_hiddens)

    decoder_inputs_test = [
        entity_hiddens, max_sent_num, max_sent_len, eos_ind, start_token
    ]
    generated_prgrph, second_prgrph_entities = decoder(
        inputs=decoder_inputs_test, keys=entity_keys, keys_mask=keys_mask,
        training=False, return_last=False)
    print(generated_prgrph)
    print(second_prgrph_entities.shape)