def transfer_learn():
    """Transfer-learn the C3D network on the 85 UCF-101 training classes.

    Starts from Sports-1M pretrained weights (last layer re-initialized,
    since the class count differs) and runs TRANSFER_LEARNING_ITERATIONS
    inner-loop steps.  Each 85-class batch is split into BATCH_SPLIT_NUM
    chunks so a chunk fits on the GPUs.  Every 50 iterations it logs
    summaries (with a full trace) and prints validation accuracy averaged
    over all chunks; every 100 iterations, and once at the end, it
    checkpoints the model.
    """
    # Held-out test actions; the commented entries were excluded in this run.
    test_actions = [
        'CleanAndJerk',
        'MoppingFloor',
        'FrontCrawl',
        # 'Surfing',
        'Bowling',
        'SoccerPenalty',
        'SumoWrestling',
        'Shotput',
        'PlayingSitar',
        'FloorGymnastics',
        # 'Typing',
        'JumpingJack',
        'ShavingBeard',
        'FrisbeeCatch',
        'WritingOnBoard',
        'JavelinThrow',
        'Fencing',
        # 'FieldHockeyPenalty',
        # 'BaseballPitch',
        'CuttingInKitchen',
        # 'Kayaking',
    ]
    train_dataset, test_dataset = get_traditional_dataset(
        base_address=BASE_ADDRESS,
        class_sample_size=CLASS_SAMPLE_SIZE,
        test_actions=test_actions)

    # Placeholders for 16-frame 112x112 RGB clips and one-hot labels (85 classes).
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 85])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)

    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 85])
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=25)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]
    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        gpu_devices=gpu_devices,
        meta_learn_rate=0.00001,
        learning_rate=0.001,
        train=TRAIN,
        log_device_placement=False
    )
    # Last layer is re-initialized: pretrained head has a different class count.
    maml.load_model(path='MAML/sports1m_pretrained.model', load_last_layer=False)

    for it in range(TRANSFER_LEARNING_ITERATIONS):
        print(it)
        data = train_dataset.next_batch(num_classes=85, real_labels=True)
        batch_test_data, batch_test_labels = data['train']
        batch_test_val_data, batch_test_val_labels = data['validation']

        # Split the batch so each chunk fits in GPU memory.
        batch_split_size = int(NUM_CLASSES / BATCH_SPLIT_NUM)
        for batch_split_index in range(BATCH_SPLIT_NUM):
            start = batch_split_index * batch_split_size
            end = start + batch_split_size  # simplified from start-expression repeated
            test_data = batch_test_data[start:end, :, :, :, :]
            test_labels = batch_test_labels[start:end, :]
            test_val_data = batch_test_val_data[start:end, :, :, :, :]
            test_val_labels = batch_test_val_labels[start:end, :]

            if it % 50 == 0:
                if batch_split_index == 0:
                    val_accs = []
                if it % 100 == 0:
                    maml.save_model(
                        'saved_models/transfer_learning_85/model', step=it)

                # Full-trace run metadata for profiling the summary pass.
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: test_data,
                    input_labels_ph: test_labels,
                    val_data_ph: test_val_data,
                    val_labels_ph: test_val_labels,
                }, options=run_options, run_metadata=run_metadata)
                maml.file_writer.add_summary(merged_summary, global_step=it + batch_split_index)

                outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                    maml.input_data: test_val_data,
                    maml.input_labels: test_val_labels,
                })
                # BUG FIX: the original had a stray trailing backslash here,
                # line-continuing this call into the following `if` statement
                # (a syntax error).
                val_accs.append(print_accuracy(outputs, test_val_labels))
                if batch_split_index == BATCH_SPLIT_NUM - 1:
                    print('iteration: {}'.format(it))
                    print('Validation accuracy on all batches: ')
                    print(np.mean(val_accs))

            # One inner-loop training step on this chunk.
            maml.sess.run(maml.inner_train_ops, feed_dict={
                input_data_ph: test_data,
                input_labels_ph: test_labels,
            })

    # Final checkpoint after the last iteration.
    maml.save_model('saved_models/transfer_learning_85/model', step=it)
def train_maml():
    """Meta-train the C3D network with MAML, or adapt and evaluate it.

    Behavior is selected by the module-level TRAIN flag:
      * TRAIN: meta-trains for 10001 iterations on 600 training actions,
        logging summaries every 50 steps and checkpointing every 200 steps
        (plus a final save).
      * not TRAIN: restores the kinetics400 checkpoint at step 8000, runs 5
        inner-loop gradient steps on one batch of 20 UCF-101 test actions,
        and reports validation accuracy after each step.
    """
    # Placeholders for 16-frame 112x112 RGB clips and one-hot labels.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=NUM_CLASSES)
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=NUM_CLASSES)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]
    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        gpu_devices=gpu_devices,
        meta_learn_rate=0.0001,
        learning_rate=0.001,
        train=TRAIN,
        log_device_placement=False,
        num_classes=101
    )

    if TRAIN:
        train_dataset, test_dataset = get_traditional_dataset(
            num_train_actions=600,
            base_address=BASE_ADDRESS,
            class_sample_size=CLASS_SAMPLE_SIZE,
        )
        # Start from Sports-1M pretrained weights; the last layer is skipped
        # because its class count differs from this model's.
        maml.load_model(path='MAML/sports1m_pretrained.model', load_last_layer=False)
        print('start meta training.')
        it = 0  # keep defined for the final save even if the loop is empty
        for it in range(10001):
            train_dataset.sample_k_samples()
            data = train_dataset.next_batch(num_classes=NUM_CLASSES)
            tr_data, tr_labels = data['train']
            val_data, val_labels = data['validation']
            if it % 50 == 0:
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: tr_data,
                    input_labels_ph: tr_labels,
                    val_data_ph: val_data,
                    val_labels_ph: val_labels,
                })
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)
            # One meta-training (outer-loop) step.
            maml.sess.run(maml.train_op, feed_dict={
                input_data_ph: tr_data,
                input_labels_ph: tr_labels,
                val_data_ph: val_data,
                val_labels_ph: val_labels,
            })
            if it % 200 == 0:
                maml.save_model(path='saved_models/kinetics400/model', step=it)
        if it != 0:
            maml.save_model(path='saved_models/kinetics400/model', step=it)
    else:
        # test_actions = [
        #     'CleanAndJerk',
        #     'MoppingFloor',
        #     'FrontCrawl',
        #     'Surfing',
        #     'Bowling',
        #     'SoccerPenalty',
        #     'SumoWrestling',
        #     'Shotput',
        #     'PlayingSitar',
        #     'FloorGymnastics',
        #     'Typing',
        #     'JumpingJack',
        #     'ShavingBeard',
        #     'FrisbeeCatch',
        #     'WritingOnBoard',
        #     'JavelinThrow',
        #     'Fencing',
        #     'FieldHockeyPenalty',
        #     'BaseballPitch',
        #     'CuttingInKitchen',
        #     'Kayaking',
        # ]
        # 20 held-out actions used for meta-test adaptation.
        test_actions = [
            'ApplyEyeMakeup',
            'Archery',
            'BabyCrawling',
            'BandMarching',
            'Bowling',
            'Basketball',
            'Biking',
            'Billiards',
            'BlowDryHair',
            'FloorGymnastics',
            'Typing',
            'CliffDiving',
            'ShavingBeard',
            'FrisbeeCatch',
            'WritingOnBoard',
            'JavelinThrow',
            'Fencing',
            'FieldHockeyPenalty',
            'PlayingPiano',
            'CuttingInKitchen',
        ]
        train_dataset, test_dataset = get_traditional_dataset(
            base_address='/home/siavash/UCF-101/',
            class_sample_size=CLASS_SAMPLE_SIZE,
            test_actions=test_actions
        )
        maml.load_model(path='saved_models/kinetics400/model-8000')
        print('Start testing the network')
        data = test_dataset.next_batch(num_classes=NUM_CLASSES)
        test_data, test_labels = data['train']
        test_val_data, test_val_labels = data['validation']
        print(test_dataset.actions)
        # Adapt for 5 inner-loop gradient steps, reporting after every step.
        for it in range(5):
            maml.sess.run(maml.inner_train_ops, feed_dict={
                input_data_ph: test_data,
                input_labels_ph: test_labels,
            })
            if it % 1 == 0:  # effectively every step
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: test_data,
                    input_labels_ph: test_labels,
                    val_data_ph: test_val_data,
                    val_labels_ph: test_val_labels,
                })
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print('gradient step: ')
                print(it)
                outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                    maml.input_data: test_val_data,
                    maml.input_labels: test_val_labels,
                })
                print_accuracy(outputs, test_val_labels)
        maml.save_model('saved_models/ucf101-fit/model-kinetics-trained', step=it)
def initialize():
    """Build the data feed and the MAML model, and return the initialized model.

    Uses module-level configuration (RANDOM_SEED, DATASET, N, K, BATCH_SIZE,
    NUM_GPUS, META_TRAIN, learning rates, and `test_actions`) to pick the
    dataset address, log/checkpoint directories, and the appropriate input
    pipeline for meta-train vs. meta-test.

    Returns:
        The constructed ModelAgnosticMetaLearning instance, with its lookup
        tables and dataset iterator already initialized.
    """
    # -1 means "no fixed seed"; otherwise seed both Python and TF RNGs.
    if RANDOM_SEED != -1:
        random.seed(RANDOM_SEED)
        tf.set_random_seed(RANDOM_SEED)

    # Encode the full hyper-parameter configuration into the directory name
    # so runs with different settings never collide.
    model_dir = os.path.join(
        DATASET,
        'meta-train',
        '{}-way-classifier'.format(N),
        '{}-shot'.format(K),
        'batch-size-{}'.format(BATCH_SIZE),
        'num-gpus-{}'.format(NUM_GPUS),
        'random-seed-{}'.format(RANDOM_SEED),
        'num-iterations-{}'.format(NUM_ITERATIONS),
        'meta-learning-rate-{}'.format(META_LEARNING_RATE),
        'learning-rate-{}'.format(LEARNING_RATE),
    )
    if META_TRAIN:
        log_dir = os.path.join(settings.BASE_LOG_ADDRESS, model_dir)
        saving_path = os.path.join(settings.SAVED_MODELS_ADDRESS, model_dir)
    else:
        log_dir = os.path.join(settings.BASE_LOG_ADDRESS, 'meta-test')
        saving_path = os.path.join(settings.SAVED_MODELS_ADDRESS, 'meta-test', 'model')

    if DATASET == 'ucf-101':
        base_address = settings.UCF101_TF_RECORDS_ADDRESS
        # '/home/siavash/programming/FewShotLearning/ucf101_tfrecords/'
    elif DATASET == 'diva':
        base_address = settings.DIVA_TRAIN_TF_RECORDS_ADDRESS
    else:
        base_address = settings.KINETICS_TF_RECORDS_ADDRESS

    if META_TRAIN:
        # NOTE(review): `test_actions` is a module-level name defined outside
        # this excerpt.
        input_data_ph, input_labels_ph, val_data_ph, val_labels_ph, iterator = create_data_feed_for_train(
            base_address=base_address,
            test_actions=test_actions,
            batch_size=BATCH_SIZE * NUM_GPUS,
            k=K,
            n=N,
            random_labels=False)
    else:
        if DATASET == 'ucf-101' or DATASET == 'kinetics':
            print(test_actions[:BATCH_SIZE * NUM_GPUS])
            input_data_ph, input_labels_ph, iterator, table = \
                create_ucf101_data_feed_for_k_sample_per_action_iterative_dataset(
                    dataset_address=base_address,
                    k=K,
                    batch_size=BATCH_SIZE * NUM_GPUS,
                    actions_include=test_actions[:BATCH_SIZE * NUM_GPUS],
                )
            # At meta-test time the same feed serves as both train and
            # validation input.
            val_data_ph = input_data_ph
            val_labels_ph = input_labels_ph
        else:
            # input_data_ph, input_labels_ph, iterator = create_diva_data_feed_for_k_sample_per_action_iterative_dataset(
            #     dataset_address=base_address,
            #     k=K,
            #     batch_size=BATCH_SIZE * NUM_GPUS,
            # )
            input_data_ph, input_labels_ph, iterator, table = \
                create_diva_data_feed_for_k_sample_per_action_iterative_dataset_unique_class_each_batch(
                    dataset_address=base_address,
                    actions_include=None
                )
            val_data_ph = input_data_ph
            val_labels_ph = input_labels_ph

    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=log_dir,
        saving_path=saving_path,
        num_gpu_devices=NUM_GPUS,
        meta_learn_rate=META_LEARNING_RATE,
        learning_rate=LEARNING_RATE,
        log_device_placement=False,
        num_classes=N
    )
    # Initialize lookup tables and the dataset iterator before use.
    maml.sess.run(tf.tables_initializer())
    maml.sess.run(iterator.initializer)
    if not META_TRAIN:
        # Dump the label lookup table for inspection at meta-test time.
        print(maml.sess.run(table.export()))
    return maml
# NOTE(review): top-level evaluation-script fragment.  `input_data_ph`,
# `input_labels_ph`, and `action_labels` are defined earlier in the file,
# outside this excerpt.
with tf.variable_scope('validation_data'):
    val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
    val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, len(action_labels)])
    tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=len(action_labels))

maml = ModelAgnosticMetaLearning(
    C3DNetwork,
    input_data_ph,
    input_labels_ph,
    val_data_ph,
    val_labels_ph,
    log_dir=settings.BASE_LOG_ADDRESS + '/logs/diva/',
    saving_path=None,
    num_gpu_devices=1,
    meta_learn_rate=0.00001,
    learning_rate=0.001,
    log_device_placement=False,
    num_classes=len(action_labels)
)
# Restore the meta-test checkpoint saved at step 120.
maml.load_model(path=settings.SAVED_MODELS_ADDRESS + '/meta-test/model/-120')

class_labels_counters = []
# 5x5 matrix — presumably one row/column per action hierarchy group; confirm
# against the code that fills it (beyond this excerpt).
hierarchy_confusion_matrix = np.zeros((5, 5))
# Per-action evaluation loop; its body continues beyond this excerpt.
for action in sorted(action_labels.keys()):
    correct = 0
    total = 0
def evaluate():
    """Evaluate a saved model over all videos of the TEST_ACTIONS classes.

    Restores the checkpoint at SAVED_MODEL_ADDRESS, classifies every video
    directory (with at least 16 frames) under BASE_ADDRESS/<action>/, prints
    overall accuracy, the confusion matrix, and per-class F1 scores.
    """
    # Placeholders for 16-frame 112x112 RGB clips and one-hot labels (20 classes).
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        num_gpu_devices=1,
        log_dir=LOG_DIR,
        learning_rate=0.001,
        log_device_placement=False,
        saving_path=None,
        num_classes=len(TEST_ACTIONS),
    )
    maml.load_model(path=SAVED_MODEL_ADDRESS)

    correct = 0
    count = 0
    class_labels_counters = []  # fixed typo: was `class_labels_couners`
    for action in sorted(TEST_ACTIONS.keys()):
        class_label_counter = [0] * len(TEST_ACTIONS)
        print(action)
        for file_address in os.listdir(BASE_ADDRESS + action):
            video_address = BASE_ADDRESS + action + '/' + file_address
            # Skip clips that are too short for a 16-frame sample.
            if len(os.listdir(video_address)) < 16:
                continue
            video, _ = TraditionalDataset.get_data_and_labels(
                None, [[video_address]], num_classes=len(TEST_ACTIONS))
            outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                maml.input_data: video,
            })
            # inner_model_out appears to be 3-D (batch, sample, class);
            # argmax over the class axis.
            label = np.argmax(outputs, 2)
            if label == TEST_ACTIONS[action]:
                correct += 1
            count += 1
            class_label_counter[label[0][0]] += 1
        print(class_label_counter)
        print(np.argmax(class_label_counter))
        class_labels_counters.append(class_label_counter)

    print('Accuracy: ')
    # BUG FIX: guard against ZeroDivisionError when no video qualified.
    print(float(correct) / count if count else 0.0)
    print(count)
    print(correct)

    # Rows = true class, columns = predicted class after the transpose.
    confusion_matrix = np.array(class_labels_counters, dtype=np.float32).transpose()
    print('\n\n')
    print('confusion matrix')
    print(confusion_matrix)
    print('\n\n')
    columns_sum = np.sum(confusion_matrix, axis=0)
    rows_sum = np.sum(confusion_matrix, axis=1)
    for counter, action in enumerate(sorted(TEST_ACTIONS.keys())):
        print(action)
        recall = confusion_matrix[counter][counter] / rows_sum[counter]
        precision = confusion_matrix[counter][counter] / columns_sum[counter]
        # BUG FIX: avoid NaN / division by zero when the class was never
        # predicted correctly (precision + recall == 0).
        if precision + recall > 0:
            f1_score = 2 * precision * recall / (precision + recall)
        else:
            f1_score = 0.0
        print('F1 Score: ')
        print(f1_score)
def transfer_learn():
    """Fine-tune a transfer-learned C3D model on 20 held-out UCF-101 classes.

    Restores the 400-class transfer-learning checkpoint (last layer skipped,
    since the class count differs), then runs TRANSFER_LEARNING_ITERATIONS
    inner-loop steps on one fixed batch of the 20 test actions.  Every 50
    iterations it logs summaries and prints validation accuracy; every 100
    iterations it checkpoints the model.
    """
    # The numeric trailing comments look like markers from a previous
    # experiment grouping; meaning not evident from this file.
    test_actions = [
        'CleanAndJerk',
        'MoppingFloor',
        'FrontCrawl',
        'Surfing',  # 0
        'Bowling',
        'SoccerPenalty',
        'SumoWrestling',
        'Shotput',
        'PlayingSitar',
        'FloorGymnastics',
        'Typing',  # 1
        'JumpingJack',
        'ShavingBeard',
        'FrisbeeCatch',
        'WritingOnBoard',
        'JavelinThrow',
        'Fencing',
        'FieldHockeyPenalty',  # 3
        'BaseballPitch',  # 4
        'CuttingInKitchen',
        'Kayaking',  # 2
    ]
    train_dataset, test_dataset = get_traditional_dataset(
        base_address=BASE_ADDRESS,
        class_sample_size=CLASS_SAMPLE_SIZE,
        test_actions=test_actions)

    # Placeholders for 16-frame 112x112 RGB clips and one-hot labels (20 classes).
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=25)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]
    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        gpu_devices=gpu_devices,
        meta_learn_rate=0.00001,
        learning_rate=0.001,
        train=TRAIN,
        log_device_placement=False
    )
    maml.load_model(path='saved_models/transfer_learning/model-400', load_last_layer=False)

    # One fixed batch of the test split is reused for every iteration.
    data = test_dataset.next_batch(num_classes=20)
    test_data, test_labels = data['train']
    test_val_data, test_val_labels = data['validation']
    print(test_dataset.actions)

    for it in range(TRANSFER_LEARNING_ITERATIONS):
        print(it)
        if it % 50 == 0:
            if it % 100 == 0:
                maml.save_model(
                    'saved_models/transfer_learning_to_20_classes/model', step=it)
            merged_summary = maml.sess.run(maml.merged, feed_dict={
                input_data_ph: test_data,
                input_labels_ph: test_labels,
                val_data_ph: test_val_data,
                val_labels_ph: test_val_labels,
            })
            maml.file_writer.add_summary(merged_summary, global_step=it)
            outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                maml.input_data: test_val_data,
                maml.input_labels: test_val_labels,
            })
            val_acc = print_accuracy(outputs, test_val_labels)
            print('iteration: {}'.format(it))
            print('Validation accuracy on all batches: ')
            print(val_acc)
        # One inner-loop training step.
        maml.sess.run(maml.inner_train_ops, feed_dict={
            input_data_ph: test_data,
            input_labels_ph: test_labels,
        })
def train_maml():
    """Meta-train the Omniglot network with MAML, or adapt and evaluate it.

    Behavior is selected by the module-level TRAIN flag:
      * TRAIN: meta-trains for MAML_TRAIN_ITERATIONS, logging summaries every
        20 steps and checkpointing every 200 steps (plus a final save).
      * not TRAIN: restores the omniglot checkpoint at step 1000, adapts for
        MAML_ADAPTATION_ITERATIONS inner-loop steps on one sampled task, and
        prints validation accuracy after each step.
    """
    data_generator = DataGenerator(UPDATE_BATCH_SIZE * 2, META_BATCH_SIZE)
    with tf.variable_scope('data_reader'):
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=TRAIN, binary_classification=BINARY_CLASSIFICATION)

    # First NUM_CLASSES * UPDATE_BATCH_SIZE samples of each meta-batch are the
    # inner-loop train split; the rest are the validation split.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.slice(image_tensor, [0, 0, 0], [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1], name='train')
        input_labels_ph = tf.slice(label_tensor, [0, 0, 0], [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1], name='labels')
        input_data_ph = tf.reshape(input_data_ph, (-1, 28, 28, 1))
        input_labels_ph = tf.reshape(input_labels_ph, (-1, NUM_CLASSES))
        tf.summary.image('train', input_data_ph, max_outputs=25)
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.slice(image_tensor, [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0], [-1, -1, -1], name='validation')
        val_labels_ph = tf.slice(label_tensor, [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0], [-1, -1, -1], name='val_labels')
        val_data_ph = tf.reshape(val_data_ph, (-1, 28, 28, 1))
        val_labels_ph = tf.reshape(val_labels_ph, (-1, NUM_CLASSES))
        tf.summary.image('validation', val_data_ph, max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        NeuralNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        learning_rate=0.001,
        meta_learn_rate=0.0001,
        log_device_placement=False,
        saving_path=SAVING_PATH,
    )
    # The data tensors are queue-fed; start the runners before any sess.run.
    tf.train.start_queue_runners(maml.sess)

    if TRAIN:
        print('Start meta training.')
        it = 0  # keep defined for the final save even if the loop is empty
        for it in range(MAML_TRAIN_ITERATIONS):
            maml.sess.run(maml.train_op)
            if it % 20 == 0:
                merged_summary = maml.sess.run(maml.merged)
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)
            if it % 200 == 0:
                maml.save_model(path=SAVING_PATH, step=it)
        if it != 0:
            maml.save_model(path=SAVING_PATH, step=it)
    else:
        maml.load_model('saved_models/omniglot/model-1000')
        print('Start testing the network')
        # Materialize one task's train/validation batches so the same data is
        # fed at every adaptation step.
        test_batch, test_batch_labels, test_val_batch, test_val_batch_labels = maml.sess.run(
            [
                maml.input_data,
                maml.input_labels,
                maml.input_validation,
                maml.input_validation_labels
            ])
        for it in range(MAML_ADAPTATION_ITERATIONS):
            maml.sess.run(maml.inner_train_ops, feed_dict={
                maml.input_data: test_batch,
                maml.input_labels: test_batch_labels,
            })
            if it % 1 == 0:  # effectively every step
                print(it)
                summary = maml.sess.run(maml.merged, feed_dict={
                    maml.input_data: test_batch,
                    maml.input_labels: test_batch_labels,
                    maml.input_validation: test_val_batch,
                    maml.input_validation_labels: test_val_batch_labels,
                })
                maml.file_writer.add_summary(summary, global_step=it)
                # BUG FIX: the original fetched `[maml.inner_model_out, maml]`
                # — the MAML object itself is not a fetchable tensor, so
                # sess.run would raise at runtime (and argmax would then run
                # over a 2-element list).  Fetch only the model output, as the
                # sibling train_maml functions in this file do.
                outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                    maml.input_data: test_val_batch,
                    maml.input_labels: test_val_batch_labels,
                })
                print('model output:')
                outputs_np = np.argmax(outputs, axis=1)
                print(outputs_np)
                print('labels output:')
                # Generalized: use NUM_CLASSES instead of the hard-coded 5.
                labels_np = np.argmax(test_val_batch_labels.reshape(-1, NUM_CLASSES), axis=1)
                print(labels_np)
                print('accuracy:')
                acc_num = np.sum(outputs_np == labels_np)
                # Generalized: divide by the actual sample count, not a
                # hard-coded 25.
                acc = acc_num / float(labels_np.size)
                print(acc_num)
                print(acc)
        print('done')
def train_maml():
    """Meta-train the C3D network on 80 UCF-101 actions, or adapt on the rest.

    Behavior is selected by the module-level TRAIN flag:
      * TRAIN: meta-trains for 1001 iterations on the first 80 actions
        (alphabetical), logging summaries every 20 steps and saving at the end.
      * not TRAIN: restores the backup checkpoint at step 1000, runs 5
        inner-loop gradient steps on one batch of the remaining actions, and
        reports validation accuracy after each step.
    """
    train_actions = sorted(os.listdir(BASE_ADDRESS))[:80]
    # Dataset-fed tensors (no placeholders): the examples come straight from
    # the input pipeline.
    train_example, val_example = get_fast_dataset(train_actions)
    with tf.variable_scope('train_data'):
        input_data_ph = train_example['video']
        input_labels_ph = train_example['task']
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)
    with tf.variable_scope('validation_data'):
        val_data_ph = val_example['video']
        val_labels_ph = val_example['task']
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        meta_learn_rate=0.00001,
        learning_rate=0.001,
        train=TRAIN
    )

    if TRAIN:
        # maml.load_model(path='MAML/sports1m_pretrained.model', load_last_layer=False)
        print('start meta training.')
        it = 0  # keep defined for the final save even if the loop is empty
        for it in range(1001):
            maml.sess.run(maml.train_op)
            if it % 20 == 0 and it != 0:
                merged_summary = maml.sess.run(maml.merged)
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)
        if it != 0:
            maml.save_model(path='saved_models/ucf101/model', step=it)
    else:
        # Meta-test on the actions not used for training.
        test_actions = sorted(os.listdir(BASE_ADDRESS))[80:]
        test_example, test_val_example = get_fast_dataset(test_actions)
        maml.load_model(path='saved_models/backups/ucf101/model-1000')
        print('Start testing the network')
        # Materialize one task's batches so the same data is fed at every step.
        test_data, test_labels, test_val_data, test_val_labels = maml.sess.run(
            (test_example['video'], test_example['task'],
             test_val_example['video'], test_val_example['task'])
        )
        for it in range(5):
            maml.sess.run(maml.inner_train_ops, feed_dict={
                input_data_ph: test_data,
                input_labels_ph: test_labels,
            })
            if it % 1 == 0:  # effectively every step
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: test_data,
                    input_labels_ph: test_labels,
                    val_data_ph: test_val_data,
                    val_labels_ph: test_val_labels,
                })
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)
                outputs, loss = maml.sess.run([maml.model_out_train, maml.train_loss], feed_dict={
                    maml.input_data: test_val_data,
                    maml.input_labels: test_val_labels,
                })
                print_accuracy(outputs, test_val_labels)
def transfer_learn():
    """Transfer-learn the C3D network on 400 training classes.

    Starts from Sports-1M pretrained weights (last layer re-initialized,
    since the class count differs) and runs TRANSFER_LEARNING_ITERATIONS
    simple-batch training steps.  Every 20 iterations it logs summaries and
    prints training accuracy; every 100 iterations it checkpoints the model.
    """
    train_dataset, test_dataset = get_traditional_dataset(
        base_address=BASE_ADDRESS,
        class_sample_size=CLASS_SAMPLE_SIZE,
        num_train_actions=400)

    # Placeholders for 16-frame 112x112 RGB clips and one-hot labels (400 classes).
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 400])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 400])
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=25)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]
    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        gpu_devices=gpu_devices,
        meta_learn_rate=0.00001,
        learning_rate=0.001,
        train=TRAIN,
        log_device_placement=False
    )
    maml.load_model(path='MAML/sports1m_pretrained.model', load_last_layer=False)

    for it in range(TRANSFER_LEARNING_ITERATIONS):
        print(it)
        data, labels = train_dataset.next_simple_batch(batch_size=BATCH_SIZE)
        if it % 100 == 0:
            maml.save_model('saved_models/transfer_learning/model', step=it)
        if it % 20 == 0:
            # Summaries feed the same batch as both train and validation input.
            merged_summary = maml.sess.run(maml.merged, feed_dict={
                input_data_ph: data,
                input_labels_ph: labels,
                val_data_ph: data,
                val_labels_ph: labels,
            })
            maml.file_writer.add_summary(merged_summary, global_step=it)
            outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                maml.input_data: data,
                maml.input_labels: labels,
            })
            # Accuracy on the training batch itself (no held-out data here).
            print_accuracy(outputs, labels)
        # One inner-loop training step.
        maml.sess.run(maml.inner_train_ops, feed_dict={
            input_data_ph: data,
            input_labels_ph: labels,
        })
def train_maml():
    """Meta-train the Omniglot network with MAML and a learned loss function.

    Behavior is selected by the module-level TRAIN flag:
      * TRAIN: alternates 10 model-update steps with 10 loss-network-update
        steps per iteration for MAML_TRAIN_ITERATIONS iterations, logging
        every 20 steps and saving at the end.
      * not TRAIN: restores the neural-loss checkpoint at step 5000, adapts
        for MAML_ADAPTATION_ITERATIONS inner-loop steps on one sampled task,
        and prints validation accuracy after each step.
    """
    data_generator = DataGenerator(UPDATE_BATCH_SIZE * 2, META_BATCH_SIZE)
    with tf.variable_scope('data_reader'):
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=TRAIN)

    # First NUM_CLASSES * UPDATE_BATCH_SIZE samples of each meta-batch are the
    # inner-loop train split; the rest are the validation split.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.slice(image_tensor, [0, 0, 0], [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1], name='train')
        input_labels_ph = tf.slice(label_tensor, [0, 0, 0], [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1], name='labels')
        input_data_ph = tf.reshape(input_data_ph, (-1, 28, 28, 1))
        input_labels_ph = tf.reshape(input_labels_ph, (-1, 5))
        tf.summary.image('train', input_data_ph, max_outputs=25)
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.slice(image_tensor, [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0], [-1, -1, -1], name='validation')
        val_labels_ph = tf.slice(label_tensor, [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0], [-1, -1, -1], name='val_labels')
        val_data_ph = tf.reshape(val_data_ph, (-1, 28, 28, 1))
        val_labels_ph = tf.reshape(val_labels_ph, (-1, 5))
        tf.summary.image('validation', val_data_ph, max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        NeuralNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        learning_rate=0.001,
        neural_loss_learning_rate=0.001,
        meta_learn_rate=0.0001,
        learn_the_loss_function=True,
        train=TRAIN,
    )
    # The data tensors are queue-fed; start the runners before any sess.run.
    tf.train.start_queue_runners(maml.sess)

    if TRAIN:
        print('Start meta training.')
        it = 0  # keep defined for the final save even if the loop is empty
        for it in range(MAML_TRAIN_ITERATIONS):
            # Alternate model updates and learned-loss-function updates.
            for k in range(10):
                maml.sess.run(maml.train_op)
            for k in range(10):
                maml.sess.run(maml.loss_func_op)
            if it % 20 == 0:
                # Fetches the summary alongside one extra train step.
                merged_summary, _ = maml.sess.run((maml.merged, maml.train_op))
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)
        if it != 0:
            maml.save_model(path='../saved_models/omniglot_neural_loss/model', step=it)
    else:
        maml.load_model('../saved_models/omniglot_neural_loss/model-5000')
        print('Start testing the network')
        # Materialize one task's train/validation batches so the same data is
        # fed at every adaptation step.
        test_batch, test_batch_labels, test_val_batch, test_val_batch_labels = maml.sess.run(
            (maml.input_data,
             maml.input_labels,
             maml.input_validation,
             maml.input_validation_labels))
        for it in range(MAML_ADAPTATION_ITERATIONS):
            maml.sess.run(maml.inner_train_ops, feed_dict={
                maml.input_data: test_batch,
                maml.input_labels: test_batch_labels,
            })
            if it % 1 == 0:  # effectively every step
                print(it)
                summary = maml.sess.run(maml.merged, feed_dict={
                    maml.input_data: test_batch,
                    maml.input_labels: test_batch_labels,
                    maml.input_validation: test_val_batch,
                    maml.input_validation_labels: test_val_batch_labels,
                })
                maml.file_writer.add_summary(summary, global_step=it)
                outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                    maml.input_data: test_val_batch,
                })
                print_accuracy(outputs, test_val_batch_labels)
        print('done')