コード例 #1
0
ファイル: end_to_end_test.py プロジェクト: LivingMatrix/kglib
    def test_end_to_end(self):
        """Smoke-test the KGCN pipeline (TRAIN then EVAL) against a freshly started Grakn.

        Unzips the packaged Grakn distribution containing the animal-trade data, starts its
        server, builds a supervised KGCN classifier, trains it on the TRAIN keyspace and
        evaluates it on the EVAL keyspace. Transactions and sessions are closed, and the
        server this test started is stopped, even when a step fails.

        Note: the ground-truth attribute labels are not removed from Grakn here, so the
        reported results are invalid; this is only an end-to-end (does-it-run) test.
        """
        zipped_dist = 'external/animaltrade_dist/file/downloaded'
        unzipped_dist = 'external/animaltrade_dist/file/downloaded-unzipped'
        grakn_bin = unzipped_dist + '/grakn-animaltrade/grakn'

        # Unzip the Grakn distribution containing our data. `-o` overwrites existing files
        # instead of prompting (a prompt would block or fail on a re-run); check=True makes
        # a failed unzip fail the test here rather than later with a confusing error.
        sub.run(['unzip', '-o', zipped_dist, '-d', unzipped_dist], check=True)

        # Start Grakn; fail immediately if the server cannot be started
        sub.run([grakn_bin, 'server', 'start'], check=True)
        try:
            modes = (TRAIN, EVAL)

            client = grakn.Grakn(uri=URI)
            sessions = server_mgmt.get_sessions(client, KEYSPACES)
            transactions = server_mgmt.get_transactions(sessions)
            try:
                batch_size = NUM_PER_CLASS * FLAGS.num_classes
                kgcn = model.KGCN(NEIGHBOUR_SAMPLE_SIZES,
                                  FLAGS.features_size,
                                  FLAGS.starting_concepts_features_size,
                                  FLAGS.aggregated_size,
                                  FLAGS.embedding_size,
                                  transactions[TRAIN],
                                  batch_size,
                                  neighbour_sampling_method=random_sampling.random_sample,
                                  neighbour_sampling_limit_factor=4)

                optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
                # 4th argument passed as None (presumably: no log directory for this test)
                classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes, None,
                                                               max_training_steps=FLAGS.max_training_steps)

                # Every mode samples with the same parameters; each gets its own dict
                sampling_params = {
                    mode: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS}
                    for mode in (TRAIN, EVAL, PREDICT)
                }
                concepts, labels = thing_mgmt.compile_labelled_concepts(EXAMPLES_QUERY, EXAMPLE_CONCEPT_TYPE,
                                                                        LABEL_ATTRIBUTE_TYPE, ATTRIBUTE_VALUES,
                                                                        transactions[TRAIN], transactions[PREDICT],
                                                                        sampling_params)

                feed_dicts = {
                    mode: classifier.get_feed_dict(sessions[mode], concepts[mode], labels=labels[mode])
                    for mode in modes
                }

                # Note: The ground-truth attribute labels haven't been removed from Grakn, so the results
                # found here are invalid, and used as an end-to-end test only

                # Train
                if TRAIN in modes:
                    print("\n\n********** TRAIN Keyspace **********")
                    classifier.train(feed_dicts[TRAIN])

                # Eval
                if EVAL in modes:
                    print("\n\n********** EVAL Keyspace **********")
                    # Presently, eval keyspace is the same as the TRAIN keyspace
                    classifier.eval(feed_dicts[EVAL])
            finally:
                # Close the transactions before the sessions they were opened from
                server_mgmt.close(transactions)
                server_mgmt.close(sessions)
        finally:
            # This test started the server, so it stops it. No check=True here so that a
            # failing stop cannot mask an exception raised by the test body above.
            sub.run([grakn_bin, 'server', 'stop'])
コード例 #2
0
def main(modes=(TRAIN, EVAL, PREDICT)):
    """Run the supervised KGCN classifier for the requested modes.

    Builds a KGCN-based classifier and obtains one feed dict per requested mode.
    Feed dicts stored by a previous run are reused when every requested mode has
    one; otherwise they are rebuilt from Grakn and stored, overwriting what was
    saved before.

    When feed dicts are rebuilt, labelled example concepts are loaded from
    ``SAVED_LABELS_PATH`` if present. Otherwise they are sampled from Grakn,
    saved to ``SAVED_LABELS_PATH``, and the labels are deleted from the keyspaces
    via ``thing_mgmt.delete_all_labels_from_keyspaces``. That deletion commits,
    so fresh transactions are opened afterwards.

    Then, for whichever modes are requested, trains on TRAIN and evaluates on
    EVAL and PREDICT. Grakn transactions and sessions are always closed on exit,
    including when an exception is raised.

    Args:
        modes: Iterable of mode keys (any of TRAIN, EVAL, PREDICT) to run.
    """
    # Iterated several times below; materialise so one-shot iterables also work
    modes = tuple(modes)

    client = grakn.Grakn(uri=URI)
    sessions = grakn_mgmt.get_sessions(client, KEYSPACES)
    transactions = grakn_mgmt.get_transactions(sessions)

    try:
        batch_size = NUM_PER_CLASS * FLAGS.num_classes
        # NOTE(review): the KGCN is given transactions[TRAIN]. If labels are deleted below,
        # that transaction is closed by the commit and replaced; confirm the KGCN does not
        # rely on it afterwards (feed dicts are built from `sessions`).
        kgcn = model.KGCN(NEIGHBOUR_SAMPLE_SIZES,
                          FLAGS.features_size,
                          FLAGS.starting_concepts_features_size,
                          FLAGS.aggregated_size,
                          FLAGS.embedding_size,
                          transactions[TRAIN],
                          batch_size,
                          neighbour_sampling_method=random_sampling.random_sample,
                          neighbour_sampling_limit_factor=4)

        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=FLAGS.learning_rate)
        classifier = classify.SupervisedKGCNClassifier(
            kgcn,
            optimizer,
            FLAGS.num_classes,
            FLAGS.log_dir,
            max_training_steps=FLAGS.max_training_steps)

        feed_dict_storer = persistence.FeedDictStorer(BASE_PATH + 'input/')

        # Reuse feed dicts stored by a previous run if every requested mode has one
        try:
            feed_dicts = {
                mode: feed_dict_storer.retrieve_feed_dict(mode) for mode in modes
            }
        except FileNotFoundError:
            feed_dicts = None

        # The rebuild happens outside the `except` clauses so that errors raised while
        # rebuilding are not reported as occurring during handling of FileNotFoundError
        if feed_dicts is None:
            try:
                concepts, labels = prs.load_saved_labelled_concepts(
                    KEYSPACES, transactions, SAVED_LABELS_PATH)
                labels_were_saved = True
            except FileNotFoundError:
                labels_were_saved = False

            if not labels_were_saved:
                # Every mode samples with the same parameters; each gets its own dict
                sampling_params = {
                    mode: {
                        'sample_size': NUM_PER_CLASS,
                        'population_size': POPULATION_SIZE_PER_CLASS
                    }
                    for mode in (TRAIN, EVAL, PREDICT)
                }
                concepts, labels = thing_mgmt.compile_labelled_concepts(
                    EXAMPLES_QUERY, EXAMPLE_CONCEPT_TYPE, LABEL_ATTRIBUTE_TYPE,
                    ATTRIBUTE_VALUES, transactions[TRAIN], transactions[PREDICT],
                    sampling_params)
                prs.save_labelled_concepts(KEYSPACES, concepts, labels,
                                           SAVED_LABELS_PATH)

                thing_mgmt.delete_all_labels_from_keyspaces(
                    transactions, LABEL_ATTRIBUTE_TYPE)

                # Get new transactions since deleting labels requires committing and therefore
                # closes transactions
                transactions = grakn_mgmt.get_transactions(sessions)
                # Re-fetch the sample concepts, since we need live transactions where the
                # labels are removed
                concepts, labels = prs.load_saved_labelled_concepts(
                    KEYSPACES, transactions, SAVED_LABELS_PATH)

            feed_dicts = {}
            for mode in modes:
                feed_dicts[mode] = classifier.get_feed_dict(sessions[mode],
                                                            concepts[mode],
                                                            labels=labels[mode])
                # Overwrites any previously stored feed dict for this mode
                feed_dict_storer.store_feed_dict(mode, feed_dicts[mode])

        # Train
        if TRAIN in modes:
            print("\n\n********** TRAIN Keyspace **********")
            classifier.train(feed_dicts[TRAIN])

        # Eval
        if EVAL in modes:
            print("\n\n********** EVAL Keyspace **********")
            # Presently, eval keyspace is the same as the TRAIN keyspace
            classifier.eval(feed_dicts[EVAL])

        # Predict
        if PREDICT in modes:
            print("\n\n********** PREDICT Keyspace **********")
            # We're using unseen data, but since we have labels we can use classifier.eval
            # rather than classifier.predict
            classifier.eval(feed_dicts[PREDICT])
    finally:
        # Close the transactions before the sessions they were opened from, and still
        # close the sessions if closing the transactions fails
        try:
            grakn_mgmt.close(transactions)
        finally:
            grakn_mgmt.close(sessions)