Example #1
from __future__ import division
from __future__ import print_function


# internal
from load_data import Data


if __name__ == '__main__':

    # Accumulates one {'name': ..., 'data': ...} record per dataset split.
    dataset = []
    # Load data (WN18 with reverse relations added for each triple).
    data = Data(dataset='WN18', reverse=True)

    # Prepare train input and targets: convert (head, relation, tail)
    # triples into integer-index form via the entity/relation vocabularies.
    train_data_idxs = data.get_data_idxs(
        data.train_data, data.entity_idxs, data.relation_idxs)

    print('Number of training data points: {}'.format(len(train_data_idxs)))

    dataset.append({'name': 'train', 'data': train_data_idxs})

    # Same index conversion for the validation split.
    valid_data_idxs = data.get_data_idxs(
        data.valid_data, data.entity_idxs, data.relation_idxs)

    print('Number of validation data points: {}'.format(len(valid_data_idxs)))

    dataset.append({'name': 'val', 'data': valid_data_idxs})

    # Test split is converted here as well.
    # NOTE(review): unlike train/val, test_data_idxs is neither printed nor
    # appended to `dataset` in this excerpt — the snippet appears truncated;
    # confirm against the full source.
    test_data_idxs = data.get_data_idxs(
        data.test_data, data.entity_idxs, data.relation_idxs)
Example #2
def metric_fn(features, labels, logits):
    """Compute a filtered Hits@10-style metric from prediction logits.

    Intended for use as a TF Estimator `metric_fn`: scores are masked so
    that other known-true tail entities for the same (head, relation) pair
    do not depress the target's rank (the "filtered" evaluation setting),
    then the target's rank is derived from a descending sort of the scores.

    NOTE(review): `features` is never used. `tf` is referenced but not
    imported in this excerpt — presumably imported at file top (TF 1.x,
    given `tf.metrics.mean`).

    Args:
        features: Unused; kept to match the Estimator metric_fn signature.
        labels: Target tail-entity indices, one per row of `logits`.
        logits: Per-entity scores; assumed shape (num_rows, num_entities)
            — TODO confirm against the caller.

    Returns:
        Dict with a single entry: "accuracy" -> `tf.metrics.mean` over the
        Hits@10 indicator values (`hits[9]`).
    """
    # hits[k] collects 0/1 indicators for Hits@(k+1); ranks collects
    # 1-based ranks (accumulated but never returned — see NOTE at bottom).
    hits = []
    ranks = []

    for i in range(10):
        hits.append([])

    samples = []

    # Rebuild the dataset inside the metric to obtain the filtered-eval
    # vocabulary: er_vocab maps (head, relation) -> all known tail entities.
    data = Data(dataset='WN18', reverse=True)

    valid_data_idxs = data.get_data_idxs(data.valid_data, data.entity_idxs,
                                         data.relation_idxs)

    # test_data_idxs = data.get_data_idxs(data)
    er_vocab = data.get_er_vocab(
        data.get_data_idxs(data.valid_data, data.entity_idxs,
                           data.relation_idxs))

    # Walk the validation triples in batches of 128.
    # NOTE(review): only the *last* batch's data_batch/e2_idx survive this
    # loop — the two loops below are NOT nested inside it, so everything
    # after operates on the final batch alone. This looks like a lost
    # indentation level (scraping artifact?); confirm against the original
    # implementation before relying on this metric.
    for i in range(0, len(valid_data_idxs), 128):

        data_batch, _ = data.get_batch(er_vocab, valid_data_idxs, i)

        # Column 2 of a batch row is the target (tail-entity) index.
        e2_idx = data_batch[:, 2]

    # Filtered masking: for each row, zero the logit of every known-true
    # tail for (head, relation), keeping only the actual target's score.
    for j in range(data_batch.shape[0]):

        logits_list = []

        filt = er_vocab[(data_batch[j][0], data_batch[j][1])]
        target_value = logits[j][e2_idx[j]]

        # Unstack the row into per-entity scalar tensors so individual
        # entries can be overwritten below.
        logits_unstacked = tf.unstack(logits[j])

        for it in range(len(logits_unstacked)):
            logits_list.append(logits_unstacked[it])

        for k in range(len(filt)):
            # Keep the target's own score; mask every other known-true
            # tail to 0.0.
            result = tf.cond(tf.equal(e2_idx[j], filt[k]),
                             lambda: target_value, lambda: 0.0)
            logits_list[filt[k]] = result

        samples.append(logits_list)

    # From here on, rank against the Estimator-provided `labels` rather
    # than the batch-local e2_idx computed above.
    e2_idx = labels
    logits = tf.stack(samples)

    # Order entities per row by descending (filtered) score.
    sort_idxs = tf.argsort(logits, axis=1, direction='DESCENDING')

    for j in range(logits.shape[0]):

        # 0-based position of the target in the sorted order; +1 gives the
        # conventional 1-based rank.
        rank = tf.where(tf.equal(sort_idxs[j], e2_idx[j]))
        ranks.append(rank + 1)

        for hits_level in range(10):

            # 1.0 iff the target ranked within the top (hits_level + 1).
            result = tf.cond(
                tf.squeeze(rank) <= hits_level, lambda: 1.0, lambda: 0.0)
            hits[hits_level].append(result)

    # hits[9] holds the Hits@10 indicators; report their running mean.
    # NOTE(review): `ranks` is computed but unused in the return value.
    accuracy = tf.metrics.mean(hits[9])

    return {"accuracy": accuracy}