# Example #1
# 0
def test(embedding_matrix, entity_num, entity_embedding_dim, rnn_hidden_size,
         vocab_size, start_token, max_sent_num, p1, p1_mask, entity_keys,
         keys_mask, encoder_path, decoder_path, eos_ind, p2=None, p2_mask=None):
    """Restore a trained encoder/decoder pair and generate from ``p1``.

    Both models are first run once (the decoder in training mode) so that all
    of their variables are created. Their weights are then restored from the
    ``ckpt`` checkpoint prefix inside ``encoder_path`` / ``decoder_path``.
    The restored encoder is run on ``p1``, the decoder is run with
    ``training=False``, and the generated output plus the shape of the
    returned entities are printed.

    Args:
        embedding_matrix: Embedding matrix passed to both encoder and decoder.
        entity_num: Passed to the encoder as ``max_entity_num``.
        entity_embedding_dim: Entity embedding size (encoder and decoder).
        rnn_hidden_size: Decoder ``rnn_hidden_size``.
        vocab_size: Vocabulary size given to the decoder.
        start_token: Start token passed to the decoder.
        max_sent_num: ``max_sent_num`` used to construct the decoder.
        p1, p1_mask: Input paragraph and its mask fed to the encoder.
        entity_keys, keys_mask: Entity keys and their mask.
        encoder_path, decoder_path: Directories holding the ``ckpt`` checkpoints.
        eos_ind: End-of-sentence index passed to the decoder at test time.
        p2, p2_mask: Labels for the one variable-initialising training step.
            If omitted, module-level globals named ``p2`` / ``p2_mask`` are
            used instead, which matches how the original version behaved.

    Raises:
        NameError: If ``p2``/``p2_mask`` are neither passed nor defined as
            module-level globals.
    """
    # Backward compatibility: earlier versions of this function read the
    # labels from module-level globals instead of taking them as arguments.
    if p2 is None:
        p2 = globals().get('p2')
    if p2_mask is None:
        p2_mask = globals().get('p2_mask')
    if p2 is None or p2_mask is None:
        raise NameError(
            "test() needs p2 and p2_mask: pass them as arguments or define "
            "them at module level")

    encoder = Model.BasicRecurrentEntityEncoder(
        embedding_matrix=embedding_matrix,
        max_entity_num=entity_num,
        entity_embedding_dim=entity_embedding_dim)

    # First encoder call creates its variables and yields the entity cell the
    # decoder is built around.
    temp_entity_cell, temp_entities = encoder([p1, p1_mask], entity_keys)
    print("temp_entities shape", temp_entities.shape)

    decoder = Model.RNNRecurrentEntitiyDecoder(
        embedding_matrix=embedding_matrix,
        rnn_hidden_size=rnn_hidden_size,
        entity_cell=temp_entity_cell,
        vocab_size=vocab_size,
        max_sent_num=max_sent_num,
        entity_embedding_dim=entity_embedding_dim)

    # Run the decoder for one training step only so that all of its
    # variables exist before the checkpoint is restored into them.
    decoder(inputs=[True, temp_entities, vocab_size, start_token],
            keys=entity_keys,
            keys_mask=keys_mask,
            training=True,
            labels=[p2, p2_mask])

    # Test-time sizes come from the actual input tensor (axes 1 and 2 of p1);
    # kept separate from the `max_sent_num` argument used to build the decoder.
    test_max_sent_num = tf.shape(p1)[1]
    test_max_sent_len = tf.shape(p1)[2]

    # Restore saved weights. The checkpoint directories are deliberately NOT
    # created here: a missing checkpoint should make the restore fail rather
    # than leave empty directories behind.
    tfe.Saver(encoder.variables).restore(os.path.join(encoder_path, 'ckpt'))
    tfe.Saver(decoder.variables).restore(os.path.join(decoder_path, 'ckpt'))

    # Re-run the encoder with the restored weights, then decode in
    # inference mode.
    _, entity_hiddens = encoder([p1, p1_mask], entity_keys)
    decoder_inputs_test = [
        entity_hiddens, test_max_sent_num, test_max_sent_len, eos_ind,
        start_token
    ]
    generated_prgrph, second_prgrph_entities = decoder(
        inputs=decoder_inputs_test,
        keys=entity_keys,
        keys_mask=keys_mask,
        training=False,
        return_last=False)
    print(generated_prgrph)
    print(second_prgrph_entities.shape)