def test_end_to_end(self):
    """End-to-end pipeline test: unpack the dataset, start Grakn, then build
    and train/eval a KGCN classifier against the TRAIN and EVAL keyspaces.

    Note: The ground-truth attribute labels haven't been removed from Grakn,
    so the results found here are invalid, and used as an end-to-end test only.
    """
    # Unzip the Grakn distribution containing our data.
    # check=True fails the test immediately if setup breaks, rather than
    # letting it cascade into confusing downstream errors.
    sub.run(['unzip', 'external/animaltrade_dist/file/downloaded', '-d',
             'external/animaltrade_dist/file/downloaded-unzipped'],
            check=True)

    # Start Grakn
    sub.run(['external/animaltrade_dist/file/downloaded-unzipped/grakn-animaltrade/grakn',
             'server', 'start'],
            check=True)

    modes = (TRAIN, EVAL)

    client = grakn.Grakn(uri=URI)
    sessions = server_mgmt.get_sessions(client, KEYSPACES)
    transactions = server_mgmt.get_transactions(sessions)

    batch_size = NUM_PER_CLASS * FLAGS.num_classes
    kgcn = model.KGCN(NEIGHBOUR_SAMPLE_SIZES,
                      FLAGS.features_size,
                      FLAGS.starting_concepts_features_size,
                      FLAGS.aggregated_size,
                      FLAGS.embedding_size,
                      transactions[TRAIN],
                      batch_size,
                      neighbour_sampling_method=random_sampling.random_sample,
                      neighbour_sampling_limit_factor=4)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    # log_dir is None here: the end-to-end test does not persist summaries.
    classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes, None,
                                                   max_training_steps=FLAGS.max_training_steps)

    feed_dicts = {}
    sampling_params = {
        TRAIN: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
        EVAL: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
        PREDICT: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
    }
    concepts, labels = thing_mgmt.compile_labelled_concepts(EXAMPLES_QUERY,
                                                            EXAMPLE_CONCEPT_TYPE,
                                                            LABEL_ATTRIBUTE_TYPE,
                                                            ATTRIBUTE_VALUES,
                                                            transactions[TRAIN],
                                                            transactions[PREDICT],
                                                            sampling_params)

    for mode in modes:
        feed_dicts[mode] = classifier.get_feed_dict(sessions[mode],
                                                    concepts[mode],
                                                    labels=labels[mode])

    # Train
    if TRAIN in modes:
        print("\n\n********** TRAIN Keyspace **********")
        classifier.train(feed_dicts[TRAIN])

    # Eval
    if EVAL in modes:
        print("\n\n********** EVAL Keyspace **********")
        # Presently, eval keyspace is the same as the TRAIN keyspace
        classifier.eval(feed_dicts[EVAL])

    # Tear down in reverse acquisition order: transactions belong to sessions,
    # so close them before closing their parent sessions.
    server_mgmt.close(transactions)
    server_mgmt.close(sessions)
def main(modes=(TRAIN, EVAL, PREDICT)):
    """Build a KGCN classifier over the configured keyspaces and run the
    requested phases.

    Feed dicts are obtained via a three-level fallback:
      1. previously stored feed dicts (fastest),
      2. previously saved labelled concepts, from which feed dicts are rebuilt,
      3. fresh compilation of labelled concepts from Grakn, which are saved,
         their labels deleted from the keyspaces, and then re-loaded against
         fresh transactions.

    Args:
        modes: iterable drawn from (TRAIN, EVAL, PREDICT) selecting which
            phases to run. Defaults to all three.
    """
    client = grakn.Grakn(uri=URI)
    sessions = grakn_mgmt.get_sessions(client, KEYSPACES)
    transactions = grakn_mgmt.get_transactions(sessions)

    batch_size = NUM_PER_CLASS * FLAGS.num_classes
    kgcn = model.KGCN(NEIGHBOUR_SAMPLE_SIZES,
                      FLAGS.features_size,
                      FLAGS.starting_concepts_features_size,
                      FLAGS.aggregated_size,
                      FLAGS.embedding_size,
                      transactions[TRAIN],
                      batch_size,
                      neighbour_sampling_method=random_sampling.random_sample,
                      neighbour_sampling_limit_factor=4)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes,
                                                   FLAGS.log_dir,
                                                   max_training_steps=FLAGS.max_training_steps)

    feed_dicts = {}
    feed_dict_storer = persistence.FeedDictStorer(BASE_PATH + 'input/')

    # Overwrites any saved data
    try:
        # Fast path: reuse feed dicts stored by a previous run.
        for mode in modes:
            feed_dicts[mode] = feed_dict_storer.retrieve_feed_dict(mode)

    except FileNotFoundError:
        # Check if saved concepts and labels exist, and act accordingly
        # if check_for_saved_labelled_concepts(SAVED_LABELS_PATH, modes):
        try:
            concepts, labels = prs.load_saved_labelled_concepts(KEYSPACES,
                                                                transactions,
                                                                SAVED_LABELS_PATH)
        except FileNotFoundError:
            # Slowest path: compile labelled concepts from Grakn itself.
            sampling_params = {
                TRAIN: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
                EVAL: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
                PREDICT: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
            }
            concepts, labels = thing_mgmt.compile_labelled_concepts(EXAMPLES_QUERY,
                                                                    EXAMPLE_CONCEPT_TYPE,
                                                                    LABEL_ATTRIBUTE_TYPE,
                                                                    ATTRIBUTE_VALUES,
                                                                    transactions[TRAIN],
                                                                    transactions[PREDICT],
                                                                    sampling_params)
            prs.save_labelled_concepts(KEYSPACES, concepts, labels, SAVED_LABELS_PATH)

            thing_mgmt.delete_all_labels_from_keyspaces(transactions, LABEL_ATTRIBUTE_TYPE)

            # Get new transactions since deleting labels requires committing and
            # therefore closes transactions
            transactions = grakn_mgmt.get_transactions(sessions)

            # We need to re-fetch the sample concepts, since we need live
            # transactions where the labels are removed
            concepts, labels = prs.load_saved_labelled_concepts(KEYSPACES,
                                                                transactions,
                                                                SAVED_LABELS_PATH)

        for mode in modes:
            feed_dicts[mode] = classifier.get_feed_dict(sessions[mode],
                                                        concepts[mode],
                                                        labels=labels[mode])
            feed_dict_storer.store_feed_dict(mode, feed_dicts[mode])

    # Train
    if TRAIN in modes:
        print("\n\n********** TRAIN Keyspace **********")
        classifier.train(feed_dicts[TRAIN])

    # Eval
    if EVAL in modes:
        print("\n\n********** EVAL Keyspace **********")
        # Presently, eval keyspace is the same as the TRAIN keyspace
        classifier.eval(feed_dicts[EVAL])

    # Predict
    if PREDICT in modes:
        print("\n\n********** PREDICT Keyspace **********")
        # We're using unseen data, but since we have labels we can use
        # classifier.eval rather than classifier.predict
        classifier.eval(feed_dicts[PREDICT])

    # Tear down in reverse acquisition order: transactions belong to sessions,
    # so close them before closing their parent sessions.
    grakn_mgmt.close(transactions)
    grakn_mgmt.close(sessions)