def train_maml():
    """Meta-train or meta-test a C3D network with MAML on video clips.

    With ``TRAIN`` truthy: restores the sports1m-pretrained backbone
    (without its last layer) and runs 10001 meta-training iterations over
    sampled task batches, writing summaries every 50 steps and a checkpoint
    every 200 steps.  Otherwise: restores a meta-trained checkpoint and runs
    5 inner-loop adaptation steps on a fixed set of UCF-101 test actions,
    printing validation accuracy after each step.

    Depends on module-level names defined elsewhere: NUM_CLASSES, NUM_GPUS,
    TRAIN, LOG_DIR, BASE_ADDRESS, CLASS_SAMPLE_SIZE, tf, C3DNetwork,
    ModelAgnosticMetaLearning, get_traditional_dataset, print_accuracy.
    """
    # Inner-loop ("support") placeholders: clips of 16 RGB frames at 112x112.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])
        # Only the first frame of each clip is logged as an image summary.
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=NUM_CLASSES)

    # Meta-update ("query") placeholders, same shapes as the support set.
    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32, shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=NUM_CLASSES)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]

    # NOTE(review): num_classes is hard-coded to 101 here while the
    # placeholders above use NUM_CLASSES -- confirm these agree.
    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        gpu_devices=gpu_devices,
        meta_learn_rate=0.0001,
        learning_rate=0.001,
        train=TRAIN,
        log_device_placement=False,
        num_classes=101
    )

    if TRAIN:
        train_dataset, test_dataset = get_traditional_dataset(
            num_train_actions=600,
            base_address=BASE_ADDRESS,
            class_sample_size=CLASS_SAMPLE_SIZE,
        )

        # Start from the sports1m backbone; the classification head is re-learned.
        maml.load_model(path='MAML/sports1m_pretrained.model', load_last_layer=False)
        print('start meta training.')

        it = 0
        for it in range(10001):
            train_dataset.sample_k_samples()
            data = train_dataset.next_batch(num_classes=NUM_CLASSES)
            tr_data, tr_labels = data['train']
            val_data, val_labels = data['validation']

            # Periodically write merged summaries for TensorBoard.
            if it % 50 == 0:
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: tr_data,
                    input_labels_ph: tr_labels,
                    val_data_ph: val_data,
                    val_labels_ph: val_labels,
                })
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)

            # One meta-training step on this task batch.
            maml.sess.run(maml.train_op, feed_dict={
                input_data_ph: tr_data,
                input_labels_ph: tr_labels,
                val_data_ph: val_data,
                val_labels_ph: val_labels,
            })

            if it % 200 == 0:
                maml.save_model(path='saved_models/kinetics400/model', step=it)

        if it != 0:
            maml.save_model(path='saved_models/kinetics400/model', step=it)

    else:
        # Earlier, alternative hold-out set -- kept for reference.
        # test_actions = [
        #     'CleanAndJerk',
        #     'MoppingFloor',
        #     'FrontCrawl',
        #     'Surfing',
        #     'Bowling',
        #     'SoccerPenalty',
        #     'SumoWrestling',
        #     'Shotput',
        #     'PlayingSitar',
        #     'FloorGymnastics',
        #     'Typing',
        #     'JumpingJack',
        #     'ShavingBeard',
        #     'FrisbeeCatch',
        #     'WritingOnBoard',
        #     'JavelinThrow',
        #     'Fencing',
        #     'FieldHockeyPenalty',
        #     'BaseballPitch',
        #     'CuttingInKitchen',
        #     'Kayaking',
        # ]
        # Actions used for meta-testing.
        test_actions = [
            'ApplyEyeMakeup',
            'Archery',
            'BabyCrawling',
            'BandMarching',
            'Bowling',
            'Basketball',
            'Biking',
            'Billiards',
            'BlowDryHair',
            'FloorGymnastics',
            'Typing',
            'CliffDiving',
            'ShavingBeard',
            'FrisbeeCatch',
            'WritingOnBoard',
            'JavelinThrow',
            'Fencing',
            'FieldHockeyPenalty',
            'PlayingPiano',
            'CuttingInKitchen',
        ]
        train_dataset, test_dataset = get_traditional_dataset(
            base_address='/home/siavash/UCF-101/',
            class_sample_size=CLASS_SAMPLE_SIZE,
            test_actions=test_actions
        )

        maml.load_model(path='saved_models/kinetics400/model-8000')
        print('Start testing the network')
        # One fixed support/query batch is used for the whole adaptation run.
        data = test_dataset.next_batch(num_classes=NUM_CLASSES)
        test_data, test_labels = data['train']
        test_val_data, test_val_labels = data['validation']

        print(test_dataset.actions)

        # Five inner-loop adaptation steps on the support set.
        for it in range(5):
            maml.sess.run(maml.inner_train_ops, feed_dict={
                input_data_ph: test_data,
                input_labels_ph: test_labels,
            })

            if it % 1 == 0:  # log every step
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: test_data,
                    input_labels_ph: test_labels,
                    val_data_ph: test_val_data,
                    val_labels_ph: test_val_labels,
                })
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print('gradient step: ')
                print(it)

                # Evaluate the adapted model on the query set.
                outputs = maml.sess.run(maml.inner_model_out, feed_dict={
                    maml.input_data: test_val_data,
                    maml.input_labels: test_val_labels,
                })

                print_accuracy(outputs, test_val_labels)

            # NOTE(review): saves once per adaptation step -- presumably
            # intentional so each step's weights are kept; confirm.
            maml.save_model('saved_models/ucf101-fit/model-kinetics-trained', step=it)
# ---- Example 2 (scraped-sample separator; "示例" means "example") ----
def transfer_learn():
    """Fine-tune the sports1m-pretrained C3D backbone on an 85-way split.

    Restores the pretrained weights (without the final layer), then for each
    iteration draws an 85-way batch, slices it into ``BATCH_SPLIT_NUM``
    splits, and runs the inner-loop training op on each split.  Every 50
    iterations it writes fully-traced summaries plus the collected run
    metadata, and every 100 iterations it checkpoints.  Validation accuracy
    is averaged over all splits of the logged iteration.

    Depends on module-level names defined elsewhere: BASE_ADDRESS,
    CLASS_SAMPLE_SIZE, NUM_GPUS, NUM_CLASSES, BATCH_SPLIT_NUM, LOG_DIR,
    TRAIN, TRANSFER_LEARNING_ITERATIONS, tf, np, C3DNetwork,
    ModelAgnosticMetaLearning, get_traditional_dataset, print_accuracy.
    """
    # Actions held out for testing; commented entries were excluded from
    # this particular run.
    test_actions = [
        'CleanAndJerk',
        'MoppingFloor',
        'FrontCrawl',
        # 'Surfing',
        'Bowling',
        'SoccerPenalty',
        'SumoWrestling',
        'Shotput',
        'PlayingSitar',
        'FloorGymnastics',
        # 'Typing',
        'JumpingJack',
        'ShavingBeard',
        'FrisbeeCatch',
        'WritingOnBoard',
        'JavelinThrow',
        'Fencing',
        # 'FieldHockeyPenalty',
        # 'BaseballPitch',
        'CuttingInKitchen',
        # 'Kayaking',
    ]

    train_dataset, test_dataset = get_traditional_dataset(
        base_address=BASE_ADDRESS,
        class_sample_size=CLASS_SAMPLE_SIZE,
        test_actions=test_actions)

    # Support placeholders: clips of 16 RGB frames at 112x112, 85-way labels.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32,
                                       shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 85])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)

    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32,
                                     shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 85])
        tf.summary.image('validation',
                         val_data_ph[:, 0, :, :, :],
                         max_outputs=25)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]

    maml = ModelAgnosticMetaLearning(C3DNetwork,
                                     input_data_ph,
                                     input_labels_ph,
                                     val_data_ph,
                                     val_labels_ph,
                                     log_dir=LOG_DIR,
                                     gpu_devices=gpu_devices,
                                     meta_learn_rate=0.00001,
                                     learning_rate=0.001,
                                     train=TRAIN,
                                     log_device_placement=False)

    # Start from the pretrained backbone; the classification head is re-learned.
    maml.load_model(path='MAML/sports1m_pretrained.model',
                    load_last_layer=False)

    for it in range(TRANSFER_LEARNING_ITERATIONS):
        print(it)

        data = train_dataset.next_batch(num_classes=85, real_labels=True)
        batch_test_data, batch_test_labels = data['train']
        batch_test_val_data, batch_test_val_labels = data['validation']
        batch_split_size = int(NUM_CLASSES / BATCH_SPLIT_NUM)

        for batch_split_index in range(BATCH_SPLIT_NUM):
            # Slice this split out of the full batch.
            start = batch_split_index * batch_split_size
            end = batch_split_index * batch_split_size + batch_split_size
            test_data = batch_test_data[start:end, :, :, :, :]
            test_labels = batch_test_labels[start:end, :]
            test_val_data = batch_test_val_data[start:end, :, :, :, :]
            test_val_labels = batch_test_val_labels[start:end, :]

            if it % 50 == 0:
                if batch_split_index == 0:
                    # Accumulate per-split validation accuracies for this
                    # logged iteration.
                    val_accs = []

                    if it % 100 == 0:
                        maml.save_model(
                            'saved_models/transfer_learning_85/model', step=it)

                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                merged_summary = maml.sess.run(maml.merged,
                                               feed_dict={
                                                   input_data_ph: test_data,
                                                   input_labels_ph:
                                                   test_labels,
                                                   val_data_ph: test_val_data,
                                                   val_labels_ph:
                                                   test_val_labels,
                                               },
                                               options=run_options,
                                               run_metadata=run_metadata)
                maml.file_writer.add_summary(merged_summary,
                                             global_step=it +
                                             batch_split_index)
                # BUG FIX: the FULL_TRACE run metadata was collected but
                # never written anywhere; record it so the trace actually
                # appears in TensorBoard.  Tags must be unique per call.
                maml.file_writer.add_run_metadata(
                    run_metadata, 'step_{}_{}'.format(it, batch_split_index))

                outputs = maml.sess.run(maml.inner_model_out,
                                        feed_dict={
                                            maml.input_data: test_val_data,
                                            maml.input_labels: test_val_labels,
                                        })

                # BUG FIX: removed a stray trailing line-continuation
                # backslash that dangled after this statement.
                val_accs.append(print_accuracy(outputs, test_val_labels))

                if batch_split_index == BATCH_SPLIT_NUM - 1:
                    print('iteration: {}'.format(it))
                    print('Validation accuracy on all batches: ')
                    print(np.mean(val_accs))

            # One inner-loop gradient step on this split.
            maml.sess.run(maml.inner_train_ops,
                          feed_dict={
                              input_data_ph: test_data,
                              input_labels_ph: test_labels,
                          })

    # NOTE(review): `it` is undefined here if TRANSFER_LEARNING_ITERATIONS
    # is 0 -- assumed to always be positive in practice.
    maml.save_model('saved_models/transfer_learning_85/model', step=it)
# NOTE(review): top-level fragment of a DIVA evaluation script; it appears to
# have been truncated by the scraper (the per-file loop body ends abruptly).
# `input_data_ph`, `input_labels_ph`, `val_data_ph`, `val_labels_ph`,
# `action_labels`, `base_address`, `settings`, and `extract_video` are
# defined elsewhere.
maml = ModelAgnosticMetaLearning(C3DNetwork,
                                 input_data_ph,
                                 input_labels_ph,
                                 val_data_ph,
                                 val_labels_ph,
                                 log_dir=settings.BASE_LOG_ADDRESS +
                                 '/logs/diva/',
                                 saving_path=None,
                                 num_gpu_devices=1,
                                 meta_learn_rate=0.00001,
                                 learning_rate=0.001,
                                 log_device_placement=False,
                                 num_classes=len(action_labels))

maml.load_model(path=settings.SAVED_MODELS_ADDRESS + '/meta-test/model/-120')

# Per-class prediction counters and a 5x5 confusion matrix at the action
# hierarchy level (5 presumably the number of hierarchy groups -- confirm).
class_labels_counters = []
hierarchy_confusion_matrix = np.zeros((5, 5))

# Evaluate at most 50 TFRecord files per action class.
for action in sorted(action_labels.keys()):
    correct = 0
    total = 0
    guess_table = [0] * len(action_labels)
    print(action)
    action_labels_sum = np.zeros((1, 20))
    # Flush so per-action progress is visible in buffered logs.
    sys.stdout.flush()
    for file_address in os.listdir(os.path.join(base_address, action))[:50]:
        tf_record_address = os.path.join(base_address, action, file_address)
        dataset = tf.data.TFRecordDataset([tf_record_address])
        # `extract_video` (defined elsewhere) parses each serialized example.
        dataset = dataset.map(extract_video)
# ---- Example 4 (scraped-sample separator; "示例" means "example") ----
def evaluate():
    """Evaluate a restored MAML/C3D model on every clip of TEST_ACTIONS.

    Walks ``BASE_ADDRESS/<action>/<video>/`` directories, classifies each
    video having at least 16 frames, and prints per-class prediction
    histograms, overall accuracy, the confusion matrix, and per-class F1
    scores.

    Depends on module-level names defined elsewhere: TEST_ACTIONS,
    BASE_ADDRESS, SAVED_MODEL_ADDRESS, LOG_DIR, tf, np, os, C3DNetwork,
    ModelAgnosticMetaLearning, TraditionalDataset.
    """
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32,
                                       shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)

    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32,
                                     shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('validation',
                         val_data_ph[:, 0, :, :, :],
                         max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        num_gpu_devices=1,
        log_dir=LOG_DIR,
        learning_rate=0.001,
        log_device_placement=False,
        saving_path=None,
        num_classes=len(TEST_ACTIONS),
    )

    maml.load_model(path=SAVED_MODEL_ADDRESS)

    correct = 0
    count = 0

    # One histogram of predicted labels per true class (fixes the original
    # misspelling `class_labels_couners`).
    class_labels_counters = []

    for action in sorted(TEST_ACTIONS.keys()):
        class_label_counter = [0] * len(TEST_ACTIONS)
        print(action)
        for file_address in os.listdir(BASE_ADDRESS + action):
            video_address = BASE_ADDRESS + action + '/' + file_address
            # Skip clips too short to form a 16-frame input.
            if len(os.listdir(video_address)) < 16:
                continue

            video, _ = TraditionalDataset.get_data_and_labels(
                None, [[video_address]], num_classes=len(TEST_ACTIONS))

            outputs = maml.sess.run(maml.inner_model_out,
                                    feed_dict={
                                        maml.input_data: video,
                                    })

            # Assumes outputs is rank-3 with classes on axis 2, so `label`
            # is a (1, 1) array -- TODO confirm against the model.
            label = np.argmax(outputs, 2)

            if label == TEST_ACTIONS[action]:
                correct += 1

            count += 1
            class_label_counter[label[0][0]] += 1

        print(class_label_counter)
        print(np.argmax(class_label_counter))
        class_labels_counters.append(class_label_counter)

    print('Accuracy: ')
    # Guard against the degenerate case of no evaluable clips, which
    # previously raised ZeroDivisionError.
    print(float(correct) / count if count else 0.0)
    print(count)
    print(correct)

    # Rows: predicted class; columns: true class (after transpose).
    confusion_matrix = np.array(class_labels_counters,
                                dtype=np.float32).transpose()
    print('\n\n')
    print('confusion matrix')
    print(confusion_matrix)
    print('\n\n')
    columns_sum = np.sum(confusion_matrix, axis=0)
    rows_sum = np.sum(confusion_matrix, axis=1)

    for counter, action in enumerate(sorted(TEST_ACTIONS.keys())):
        print(action)
        # Guard zero denominators (class never predicted / never present),
        # which previously produced NaN F1 scores with runtime warnings.
        recall = (confusion_matrix[counter][counter] / rows_sum[counter]
                  if rows_sum[counter] else 0.0)
        precision = (confusion_matrix[counter][counter] / columns_sum[counter]
                     if columns_sum[counter] else 0.0)
        denominator = precision + recall
        f1_score = 2 * precision * recall / denominator if denominator else 0.0
        print('F1 Score: ')
        print(f1_score)
# ---- Example 5 (scraped-sample separator; "示例" means "example") ----
def transfer_learn():
    """Fine-tune a transfer-learned C3D model on a fixed 20-way UCF-101 task.

    Restores 'saved_models/transfer_learning/model-400' without its last
    layer, then repeatedly runs the inner-loop training op on a single fixed
    support batch, logging summaries and query-set accuracy every 50 steps
    and checkpointing every 100 steps.

    Depends on module-level names defined elsewhere: BASE_ADDRESS,
    CLASS_SAMPLE_SIZE, NUM_GPUS, LOG_DIR, TRAIN,
    TRANSFER_LEARNING_ITERATIONS, tf, C3DNetwork,
    ModelAgnosticMetaLearning, get_traditional_dataset, print_accuracy.
    """
    # Held-out actions; the numeric trailing comments look like experiment
    # tags whose meaning is not recoverable from this file.
    test_actions = [
        'CleanAndJerk',
        'MoppingFloor',
        'FrontCrawl',
        'Surfing',  # 0
        'Bowling',
        'SoccerPenalty',
        'SumoWrestling',
        'Shotput',
        'PlayingSitar',
        'FloorGymnastics',
        'Typing',  # 1
        'JumpingJack',
        'ShavingBeard',
        'FrisbeeCatch',
        'WritingOnBoard',
        'JavelinThrow',
        'Fencing',
        'FieldHockeyPenalty',  # 3
        'BaseballPitch',  # 4
        'CuttingInKitchen',
        'Kayaking',  # 2
    ]

    train_dataset, test_dataset = get_traditional_dataset(
        base_address=BASE_ADDRESS,
        class_sample_size=CLASS_SAMPLE_SIZE,
        test_actions=test_actions)

    # Support placeholders: clips of 16 RGB frames at 112x112, 20-way labels.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32,
                                       shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)

    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32,
                                     shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 20])
        tf.summary.image('validation',
                         val_data_ph[:, 0, :, :, :],
                         max_outputs=25)

    gpu_devices = ['/gpu:{}'.format(gpu_id) for gpu_id in range(NUM_GPUS)]

    maml = ModelAgnosticMetaLearning(C3DNetwork,
                                     input_data_ph,
                                     input_labels_ph,
                                     val_data_ph,
                                     val_labels_ph,
                                     log_dir=LOG_DIR,
                                     gpu_devices=gpu_devices,
                                     meta_learn_rate=0.00001,
                                     learning_rate=0.001,
                                     train=TRAIN,
                                     log_device_placement=False)

    # Resume from the 400-way transfer-learning run; the 20-way head is new.
    maml.load_model(path='saved_models/transfer_learning/model-400',
                    load_last_layer=False)

    # A single fixed support/query batch is reused for every iteration.
    data = test_dataset.next_batch(num_classes=20)
    test_data, test_labels = data['train']
    test_val_data, test_val_labels = data['validation']
    print(test_dataset.actions)

    for it in range(TRANSFER_LEARNING_ITERATIONS):
        print(it)
        if it % 50 == 0:
            if it % 100 == 0:
                maml.save_model(
                    'saved_models/transfer_learning_to_20_classes/model',
                    step=it)

            merged_summary = maml.sess.run(maml.merged,
                                           feed_dict={
                                               input_data_ph: test_data,
                                               input_labels_ph: test_labels,
                                               val_data_ph: test_val_data,
                                               val_labels_ph: test_val_labels,
                                           })
            maml.file_writer.add_summary(merged_summary, global_step=it)

            # Query-set accuracy for monitoring adaptation progress.
            outputs = maml.sess.run(maml.inner_model_out,
                                    feed_dict={
                                        maml.input_data: test_val_data,
                                        maml.input_labels: test_val_labels,
                                    })

            val_acc = print_accuracy(outputs, test_val_labels)

            print('iteration: {}'.format(it))
            print('Validation accuracy on all batches: ')
            print(val_acc)

        # One inner-loop gradient step on the fixed support batch.
        maml.sess.run(maml.inner_train_ops,
                      feed_dict={
                          input_data_ph: test_data,
                          input_labels_ph: test_labels,
                      })
def train_maml():
    """Meta-train or meta-test a small NeuralNetwork with MAML (Omniglot).

    With ``TRAIN`` truthy: runs ``MAML_TRAIN_ITERATIONS`` meta-updates over
    a queued data pipeline, logging summaries every 20 steps and
    checkpointing every 200.  Otherwise: restores a checkpoint, runs
    ``MAML_ADAPTATION_ITERATIONS`` inner-loop steps on one fixed batch, and
    prints the final query-set accuracy.

    Depends on module-level names defined elsewhere: UPDATE_BATCH_SIZE,
    META_BATCH_SIZE, NUM_CLASSES, TRAIN, BINARY_CLASSIFICATION, LOG_DIR,
    SAVING_PATH, MAML_TRAIN_ITERATIONS, MAML_ADAPTATION_ITERATIONS, tf, np,
    DataGenerator, NeuralNetwork, ModelAgnosticMetaLearning.
    """
    data_generator = DataGenerator(UPDATE_BATCH_SIZE * 2, META_BATCH_SIZE)
    with tf.variable_scope('data_reader'):
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=TRAIN, binary_classification=BINARY_CLASSIFICATION)

    with tf.variable_scope('train_data'):
        # First NUM_CLASSES * UPDATE_BATCH_SIZE examples per meta-batch are
        # the inner-loop (support) set.
        input_data_ph = tf.slice(image_tensor, [0, 0, 0],
                                 [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1],
                                 name='train')
        input_labels_ph = tf.slice(label_tensor, [0, 0, 0],
                                   [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1],
                                   name='labels')
        input_data_ph = tf.reshape(input_data_ph, (-1, 28, 28, 1))
        input_labels_ph = tf.reshape(input_labels_ph, (-1, NUM_CLASSES))
        tf.summary.image('train', input_data_ph, max_outputs=25)

    with tf.variable_scope('validation_data'):
        # Remaining examples form the meta-update (query) set.
        val_data_ph = tf.slice(image_tensor,
                               [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0],
                               [-1, -1, -1],
                               name='validation')
        val_labels_ph = tf.slice(label_tensor,
                                 [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0],
                                 [-1, -1, -1],
                                 name='val_labels')
        val_data_ph = tf.reshape(val_data_ph, (-1, 28, 28, 1))
        val_labels_ph = tf.reshape(val_labels_ph, (-1, NUM_CLASSES))
        tf.summary.image('validation', val_data_ph, max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        NeuralNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        learning_rate=0.001,
        meta_learn_rate=0.0001,
        log_device_placement=False,
        saving_path=SAVING_PATH,
    )
    # The data pipeline is queue-based; runners must be started explicitly.
    tf.train.start_queue_runners(maml.sess)

    if TRAIN:
        print('Start meta training.')

        it = 0
        for it in range(MAML_TRAIN_ITERATIONS):
            maml.sess.run(maml.train_op)

            if it % 20 == 0:
                merged_summary = maml.sess.run(maml.merged)
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)
            if it % 200 == 0:
                maml.save_model(path=SAVING_PATH, step=it)

        if it != 0:
            maml.save_model(path=SAVING_PATH, step=it)

    else:
        maml.load_model('saved_models/omniglot/model-1000')
        print('Start testing the network')
        # Materialize one support/query batch from the pipeline.
        test_batch, test_batch_labels, test_val_batch, test_val_batch_labels = maml.sess.run(
            [
                maml.input_data, maml.input_labels, maml.input_validation,
                maml.input_validation_labels
            ])

        for it in range(MAML_ADAPTATION_ITERATIONS):
            maml.sess.run(maml.inner_train_ops,
                          feed_dict={
                              maml.input_data: test_batch,
                              maml.input_labels: test_batch_labels,
                          })

            if it % 1 == 0:  # log every adaptation step
                print(it)
                summary = maml.sess.run(maml.merged,
                                        feed_dict={
                                            maml.input_data:
                                            test_batch,
                                            maml.input_labels:
                                            test_batch_labels,
                                            maml.input_validation:
                                            test_val_batch,
                                            maml.input_validation_labels:
                                            test_val_batch_labels,
                                        })
                maml.file_writer.add_summary(summary, global_step=it)

        # BUG FIX: the original fetched `[maml.inner_model_out, maml]`;
        # the `maml` object itself is not a valid Session.run fetch, and the
        # extra list element would also break the argmax below.  Fetch only
        # the model output.
        outputs = maml.sess.run(
            maml.inner_model_out,
            feed_dict={
                maml.input_data: test_val_batch,
                maml.input_labels: test_val_batch_labels,
            })

        print('model output:')
        outputs_np = np.argmax(outputs, axis=1)
        print(outputs_np)

        print('labels output:')
        labels_np = np.argmax(test_val_batch_labels.reshape(-1, 5), axis=1)
        print(labels_np)

        print('accuracy:')
        acc_num = np.sum(outputs_np == labels_np)
        # Derive the denominator from the data rather than hard-coding 25.
        acc = acc_num / float(labels_np.size)
        print(acc_num)
        print(acc)

    print('done')
# ---- Example 7 (scraped-sample separator; "示例" means "example") ----
def train_maml():
    """Meta-train a C3D MAML model on the first 80 actions in BASE_ADDRESS.

    Uses the `get_fast_dataset` input pipeline (dataset tensors, not
    placeholders).  With ``TRAIN`` truthy: runs 1001 meta-iterations with
    periodic summaries and saves a final checkpoint.  Otherwise: restores a
    checkpoint and runs 5 adaptation steps on the remaining held-out
    actions, printing query-set accuracy each step.

    Depends on module-level names defined elsewhere: BASE_ADDRESS, LOG_DIR,
    TRAIN, os, tf, C3DNetwork, ModelAgnosticMetaLearning, get_fast_dataset,
    print_accuracy.
    """
    train_actions = sorted(os.listdir(BASE_ADDRESS))[:80]
    train_example, val_example = get_fast_dataset(train_actions)

    with tf.variable_scope('train_data'):
        # Pipeline tensors; the summary logs the first frame of each clip.
        input_data_ph = train_example['video']
        input_labels_ph = train_example['task']
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)

    with tf.variable_scope('validation_data'):
        val_data_ph = val_example['video']
        val_labels_ph = val_example['task']
        tf.summary.image('validation', val_data_ph[:, 0, :, :, :], max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        C3DNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        meta_learn_rate=0.00001,
        learning_rate=0.001,
        train=TRAIN
    )

    if TRAIN:
        # maml.load_model(path='MAML/sports1m_pretrained.model', load_last_layer=False)
        print('start meta training.')

        it = 0
        for it in range(1001):
            maml.sess.run(maml.train_op)

            if it % 20 == 0 and it != 0:
                merged_summary = maml.sess.run(maml.merged)
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)

        if it != 0:
            maml.save_model(path='saved_models/ucf101/model', step=it)

    else:
        test_actions = sorted(os.listdir(BASE_ADDRESS))[80:]
        test_example, test_val_example = get_fast_dataset(test_actions)
        maml.load_model(path='saved_models/backups/ucf101/model-1000')
        print('Start testing the network')

        # Materialize one support/query batch from the test pipeline.
        test_data, test_labels, test_val_data, test_val_labels = maml.sess.run(
            (test_example['video'], test_example['task'], test_val_example['video'], test_val_example['task'])
        )

        # Five adaptation steps on the fixed support batch.  NOTE(review):
        # the feed_dict keys are pipeline tensors rather than placeholders;
        # TF1 permits overriding arbitrary tensors with fed values.
        for it in range(5):
            maml.sess.run(maml.inner_train_ops, feed_dict={
                input_data_ph: test_data,
                input_labels_ph: test_labels,
            })

            if it % 1 == 0:  # log every step
                merged_summary = maml.sess.run(maml.merged, feed_dict={
                    input_data_ph: test_data,
                    input_labels_ph: test_labels,
                    val_data_ph: test_val_data,
                    val_labels_ph: test_val_labels,
                })
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)

                # Query-set outputs after this adaptation step.
                outputs, loss = maml.sess.run([maml.model_out_train, maml.train_loss], feed_dict={
                    maml.input_data: test_val_data,
                    maml.input_labels: test_val_labels,
                })

                print_accuracy(outputs, test_val_labels)
def transfer_learn():
    """Fine-tune the sports1m-pretrained C3D backbone on 400-way batches.

    Restores the pretrained weights (without the final layer), then
    repeatedly applies the inner-loop training op to simple mini-batches,
    writing summaries every 20 iterations and a checkpoint every 100.
    """
    train_dataset, test_dataset = get_traditional_dataset(
        base_address=BASE_ADDRESS,
        class_sample_size=CLASS_SAMPLE_SIZE,
        num_train_actions=400)

    # Support placeholders: clips of 16 RGB frames at 112x112, 400-way labels.
    with tf.variable_scope('train_data'):
        input_data_ph = tf.placeholder(dtype=tf.float32,
                                       shape=[None, 16, 112, 112, 3])
        input_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 400])
        tf.summary.image('train', input_data_ph[:, 0, :, :, :], max_outputs=25)

    with tf.variable_scope('validation_data'):
        val_data_ph = tf.placeholder(dtype=tf.float32,
                                     shape=[None, 16, 112, 112, 3])
        val_labels_ph = tf.placeholder(dtype=tf.float32, shape=[None, 400])
        tf.summary.image('validation',
                         val_data_ph[:, 0, :, :, :],
                         max_outputs=25)

    devices = ['/gpu:{}'.format(idx) for idx in range(NUM_GPUS)]

    maml = ModelAgnosticMetaLearning(C3DNetwork,
                                     input_data_ph,
                                     input_labels_ph,
                                     val_data_ph,
                                     val_labels_ph,
                                     log_dir=LOG_DIR,
                                     gpu_devices=devices,
                                     meta_learn_rate=0.00001,
                                     learning_rate=0.001,
                                     train=TRAIN,
                                     log_device_placement=False)

    # Start from the pretrained backbone; the classification head is re-learned.
    maml.load_model(path='MAML/sports1m_pretrained.model',
                    load_last_layer=False)

    for step in range(TRANSFER_LEARNING_ITERATIONS):
        print(step)

        batch_data, batch_labels = train_dataset.next_simple_batch(
            batch_size=BATCH_SIZE)

        if step % 100 == 0:
            maml.save_model('saved_models/transfer_learning/model', step=step)

        if step % 20 == 0:
            # Train and validation inputs are deliberately fed the same
            # batch here; this loop has no separate held-out split.
            summary_feed = {
                input_data_ph: batch_data,
                input_labels_ph: batch_labels,
                val_data_ph: batch_data,
                val_labels_ph: batch_labels,
            }
            merged_summary = maml.sess.run(maml.merged, feed_dict=summary_feed)
            maml.file_writer.add_summary(merged_summary, global_step=step)

            predictions = maml.sess.run(maml.inner_model_out,
                                        feed_dict={
                                            maml.input_data: batch_data,
                                            maml.input_labels: batch_labels,
                                        })
            print_accuracy(predictions, batch_labels)

        # One inner-loop gradient step on this mini-batch.
        maml.sess.run(maml.inner_train_ops,
                      feed_dict={
                          input_data_ph: batch_data,
                          input_labels_ph: batch_labels,
                      })
def train_maml():
    """Meta-train a NeuralNetwork on Omniglot data with a learned loss.

    Alternates MAML meta-updates with updates of a neural loss function
    (``learn_the_loss_function=True``).  With ``TRAIN`` falsy, restores a
    checkpoint and runs adaptation steps on one fixed batch, printing
    query-set accuracy after every step.

    Depends on module-level names defined elsewhere: UPDATE_BATCH_SIZE,
    META_BATCH_SIZE, NUM_CLASSES, TRAIN, LOG_DIR, MAML_TRAIN_ITERATIONS,
    MAML_ADAPTATION_ITERATIONS, tf, DataGenerator, NeuralNetwork,
    ModelAgnosticMetaLearning, print_accuracy.
    """
    data_generator = DataGenerator(UPDATE_BATCH_SIZE * 2, META_BATCH_SIZE)
    with tf.variable_scope('data_reader'):
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=TRAIN)

    with tf.variable_scope('train_data'):
        # First NUM_CLASSES * UPDATE_BATCH_SIZE examples form the support set.
        input_data_ph = tf.slice(image_tensor, [0, 0, 0],
                                 [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1],
                                 name='train')
        input_labels_ph = tf.slice(label_tensor, [0, 0, 0],
                                   [-1, NUM_CLASSES * UPDATE_BATCH_SIZE, -1],
                                   name='labels')
        input_data_ph = tf.reshape(input_data_ph, (-1, 28, 28, 1))
        # NOTE(review): label width hard-coded to 5 here versus NUM_CLASSES
        # elsewhere -- confirm these agree.
        input_labels_ph = tf.reshape(input_labels_ph, (-1, 5))
        tf.summary.image('train', input_data_ph, max_outputs=25)

    with tf.variable_scope('validation_data'):
        # Remaining examples form the query (meta-update) set.
        val_data_ph = tf.slice(image_tensor,
                               [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0],
                               [-1, -1, -1],
                               name='validation')
        val_labels_ph = tf.slice(label_tensor,
                                 [0, NUM_CLASSES * UPDATE_BATCH_SIZE, 0],
                                 [-1, -1, -1],
                                 name='val_labels')
        val_data_ph = tf.reshape(val_data_ph, (-1, 28, 28, 1))
        val_labels_ph = tf.reshape(val_labels_ph, (-1, 5))
        tf.summary.image('validation', val_data_ph, max_outputs=25)

    maml = ModelAgnosticMetaLearning(
        NeuralNetwork,
        input_data_ph,
        input_labels_ph,
        val_data_ph,
        val_labels_ph,
        log_dir=LOG_DIR,
        learning_rate=0.001,
        neural_loss_learning_rate=0.001,
        meta_learn_rate=0.0001,
        learn_the_loss_function=True,
        train=TRAIN,
    )
    # The data pipeline is queue-based; runners must be started explicitly.
    tf.train.start_queue_runners(maml.sess)

    if TRAIN:
        print('Start meta training.')

        it = 0
        for it in range(MAML_TRAIN_ITERATIONS):
            # Alternate 10 meta-updates with 10 loss-function updates.
            for k in range(10):
                maml.sess.run(maml.train_op)
            for k in range(10):
                maml.sess.run(maml.loss_func_op)

            if it % 20 == 0:
                # NOTE(review): fetching `merged` together with `train_op`
                # performs an extra training step purely for logging.
                merged_summary, _ = maml.sess.run((maml.merged, maml.train_op))
                maml.file_writer.add_summary(merged_summary, global_step=it)
                print(it)

        if it != 0:
            maml.save_model(path='../saved_models/omniglot_neural_loss/model',
                            step=it)

    else:
        maml.load_model('../saved_models/omniglot_neural_loss/model-5000')
        print('Start testing the network')
        # Materialize one support/query batch from the pipeline.
        test_batch, test_batch_labels, test_val_batch, test_val_batch_labels = maml.sess.run(
            (maml.input_data, maml.input_labels, maml.input_validation,
             maml.input_validation_labels))

        for it in range(MAML_ADAPTATION_ITERATIONS):
            maml.sess.run(maml.inner_train_ops,
                          feed_dict={
                              maml.input_data: test_batch,
                              maml.input_labels: test_batch_labels,
                          })

            if it % 1 == 0:  # log every adaptation step
                print(it)
                summary = maml.sess.run(maml.merged,
                                        feed_dict={
                                            maml.input_data:
                                            test_batch,
                                            maml.input_labels:
                                            test_batch_labels,
                                            maml.input_validation:
                                            test_val_batch,
                                            maml.input_validation_labels:
                                            test_val_batch_labels,
                                        })
                maml.file_writer.add_summary(summary, global_step=it)

            # Query-set accuracy after this adaptation step.
            outputs = maml.sess.run(maml.inner_model_out,
                                    feed_dict={
                                        maml.input_data: test_val_batch,
                                    })

            print_accuracy(outputs, test_val_batch_labels)

    print('done')