def main(args):
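    # NOTE: this example assumes the enclosing module imports os, random,
    # numpy as np and tensorflow (v1.x) as tf, and defines the module-level
    # flag Ray_tune, the Bunch helper, the Ray Tune `track` API, and the
    # create_*_task_distribution loaders used below.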

    if (Ray_tune):
        for key in list(args.keys()):
            if str(args[key]).lower() == "true":
                args[key] = True
            elif str(args[key]).lower() == "false":
                args[key] = False
        print(args)
        args = Bunch(args)

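    # seed all random number generators and derive the support-set composition
    # from the class-imbalance ratio (cir): cir=1.0 trains one-class (OCC)
    # support sets, cir=0.5 balanced ones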
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    cir = args.cir_inner_loop
    K = args.K
    if (cir == 0.5):
        train_occ = False
        num_training_samples_per_class = int(K / 2)
    elif (cir == 1.0):
        train_occ = True
        num_training_samples_per_class = K
    else:
        raise NotImplementedError(
            'only cir values of 0.5 and 1.0 are implemented')
    test_occ = True

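    # for Ray Tune runs, encode the hyperparameters in the summary directory name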
    if (Ray_tune):
        args.summary_dir = args.summary_dir + '_K_' + str(K) + '_lr_' + str(
            args.lr) + '_updts_' + str(args.num_updates) + '_q_' + str(
                args.n_queries)
        args.summary_dir = args.summary_dir.replace(".", "_")
        if (args.bn):
            args.summary_dir += '_bn'
        args.summary_dir = args.dataset + '_' + args.summary_dir

    base_path = "/home/USER/Documents/"
    if not (os.path.exists(base_path)):
        base_path = "/home/ubuntu/Projects/"
    if not (os.path.exists(base_path)):
        base_path = "/home/USER/Projects/"
    basefolder = base_path + "MAML/raw_data/"

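    # build the meta-train / meta-val / meta-test task distributions for the
    # selected dataset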
    if (args.dataset == 'MIN'):
        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_miniimagenet_task_distribution(
            basefolder + "miniImageNet_data/miniimagenet.pkl",
            train_occ=train_occ,
            test_occ=test_occ,
            num_training_samples_per_class=num_training_samples_per_class,
            num_test_samples_per_class=int(args.n_queries / 2),
            num_training_classes=2,
            meta_batch_size=8,
            seq_length=0)
    elif (args.dataset == 'OMN'):
        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_omniglot_allcharacters_task_distribution(
            basefolder + "omniglot/omniglot.pkl",
            train_occ=train_occ,
            test_occ=test_occ,
            num_training_samples_per_class=num_training_samples_per_class,
            num_test_samples_per_class=int(args.n_queries / 2),
            num_training_classes=2,
            meta_batch_size=8,
            seq_length=0)
    elif (args.dataset == 'CIFAR_FS'):
        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_cifarfs_task_distribution(
            base_path + "MAML/cifar_fc100/data/CIFAR_FS/CIFAR_FS_train.pickle",
            base_path + "MAML/cifar_fc100/data/CIFAR_FS/CIFAR_FS_val.pickle",
            base_path + "MAML/cifar_fc100/data/CIFAR_FS/CIFAR_FS_test.pickle",
            train_occ=train_occ,
            test_occ=test_occ,
            num_training_samples_per_class=num_training_samples_per_class,
            num_test_samples_per_class=int(args.n_queries / 2),
            num_training_classes=2,
            meta_batch_size=8,
            seq_length=0)
    elif (args.dataset == 'FC100'):
        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_fc100_task_distribution(
            base_path + "MAML/cifar_fc100/data/FC100/FC100_train.pickle",
            base_path + "MAML/cifar_fc100/data/FC100/FC100_val.pickle",
            base_path + "MAML/cifar_fc100/data/FC100/FC100_test.pickle",
            train_occ=train_occ,
            test_occ=test_occ,
            num_training_samples_per_class=num_training_samples_per_class,
            num_test_samples_per_class=int(args.n_queries / 2),
            num_training_classes=2,
            meta_batch_size=8,
            seq_length=0)
    else:
        raise ValueError('unknown dataset: ' + str(args.dataset))

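    # create the session, infer the input shape from one sampled task, and
    # instantiate the meta-learner (first-order MAML when inner-loop gradients
    # are stopped, full MAML otherwise)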
    sess = tf.InteractiveSession()
    meta_batch = metatrain_task_distribution.sample_batch()
    input_shape = meta_batch[0].get_train_set()[0][0].shape
    if ('MAML' in args.summary_dir or 'ANIL' in args.summary_dir):
        if (args.stop_grad):
            from fomaml_class import FOMAML
            model = FOMAML(sess, args, seed, 64, input_shape)

        else:
            from maml_class import MAML
            model = MAML(sess, args, seed, 64, input_shape)

    else:
        raise ValueError(
            'unknown model: summary_dir must contain "MAML" or "ANIL"')

    summary = bool(args.summary_dir)

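    # create TensorBoard writers and tags for meta-train and meta-validation
    # summaries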
    if (summary):
        logdir_path = './summaries_CIMAML'
        if (not (os.path.exists(logdir_path))):
            os.mkdir(logdir_path)
        if (not (os.path.exists(os.path.join(logdir_path,
                                             model.summary_dir)))):
            os.mkdir(os.path.join(logdir_path, model.summary_dir))
        train_writer = tf.summary.FileWriter(
            os.path.join(logdir_path, model.summary_dir) + '/train')
        val_writer = tf.summary.FileWriter(
            os.path.join(logdir_path, model.summary_dir) + '/val')
        val_tags = [
            'val_loss_avg', 'val_acc_avg', 'val_precision_avg',
            'val_recall_avg', 'val_specificity_avg', 'val_f1_score_avg',
            'val_auc_pr_avg'
        ]

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # initialization
    n_val_tasks_sampled = 2
    val_test_loss = 0
    min_val_epoch = -1
    min_val_test_loss = 10000
    min_metatrain_epoch = -1
    min_metatrain_loss = 10000
    val_interval = 100

    max_val_acc = -1
    max_val_f1 = -1

    if (Ray_tune):
        track.init()
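    # meta-training loop: run a validation pass every val_interval epochs
    # (checkpointing on the best validation loss), then take one meta-update step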
    for epoch in range(args.meta_epochs):
        if ((epoch % val_interval == 0) or (epoch == args.meta_epochs - 1)):
            val_metrics_list = []
            for _ in range(n_val_tasks_sampled):
                val_meta_batch = metaval_task_distribution.sample_batch()

                X_val_a, Y_val_a, X_val_b, Y_val_b = [], [], [], []
                for task in val_meta_batch:
                    X_val_a.append(task.get_train_set()[0])
                    Y_val_a.append(np.expand_dims(task.get_train_set()[1], -1))
                    X_val_b.append(task.get_test_set()[0])
                    Y_val_b.append(np.expand_dims(task.get_test_set()[1], -1))

                for K_val_X, K_val_Y, test_val_X, test_val_Y in zip(
                        X_val_a, Y_val_a, X_val_b, Y_val_b):
                    val_summaries, val_test_loss, acc, precision, recall, specificity, f1_score, auc_pr = model.val_op(
                        K_val_X, K_val_Y, test_val_X, test_val_Y)
                    val_metrics_list.append([
                        val_test_loss, acc, precision, recall, specificity,
                        f1_score, auc_pr
                    ])

            avg_val_metrics = np.mean(val_metrics_list, axis=0)

            if (avg_val_metrics[0] < min_val_test_loss):
                model.saver.save(
                    model.sess, model.checkpoint_path + model.summary_dir +
                    "_restore_val_test_loss/model.ckpt")
                min_val_test_loss = avg_val_metrics[0]
                min_val_epoch = epoch
                print('model saved', ' epoch: ', epoch, ' val_test_loss: ',
                      avg_val_metrics[0], ' acc : ', avg_val_metrics[1],
                      ' prec : ', avg_val_metrics[2], ' recall : ',
                      avg_val_metrics[3], ' spec : ', avg_val_metrics[4],
                      ' F1 : ', avg_val_metrics[5], ' auc_pr : ',
                      avg_val_metrics[6])

            if (avg_val_metrics[1] > max_val_acc):
                max_val_acc = avg_val_metrics[1]
            if (avg_val_metrics[5] > max_val_f1):
                max_val_f1 = avg_val_metrics[5]
            if (Ray_tune):
                track.log(mean_loss=min_val_test_loss,
                          mean_accuracy=max_val_acc,
                          f1_score=max_val_f1,
                          training_iteration=epoch)

            if (summary):
                val_summaries = []
                for i in range(len(avg_val_metrics)):
                    val_summaries.append(
                        tf.Summary(value=[
                            tf.Summary.Value(tag=val_tags[i],
                                             simple_value=avg_val_metrics[i]),
                        ]))

                for smr in val_summaries:
                    val_writer.add_summary(smr, epoch)

                val_writer.flush()

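        # sample a meta-batch of training tasks and perform one meta-training step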
        meta_batch = metatrain_task_distribution.sample_batch()
        X_train_a, Y_train_a, X_train_b, Y_train_b = [], [], [], []
        for task in meta_batch:
            X_train_a.append(task.get_train_set()[0])
            Y_train_a.append(np.expand_dims(task.get_train_set()[1], -1))
            X_train_b.append(task.get_test_set()[0])
            Y_train_b.append(np.expand_dims(task.get_test_set()[1], -1))

        X_train_a = np.array(X_train_a)
        Y_train_a = np.array(Y_train_a)
        X_train_b = np.array(X_train_b)
        Y_train_b = np.array(Y_train_b)

        metatrain_loss, train_summaries = model.metatrain_op(
            epoch, X_train_a, Y_train_a, X_train_b, Y_train_b)

        if (min_metatrain_loss > metatrain_loss):
            min_metatrain_loss = metatrain_loss
        if (summary and (epoch % model.summary_interval == 0)):
            train_writer.add_summary(train_summaries, epoch)
            train_writer.flush()

    if (summary):
        train_writer.close()
        val_writer.close()

    if (not (Ray_tune)):
        # evaluate on n_test_tasks meta-batches of 8 test tasks each
        n_test_tasks = 100
        test_metrics_list = []
        model.saver.restore(
            model.sess, model.checkpoint_path + model.summary_dir +
            "_restore_val_test_loss/model.ckpt")
        print('training ended, restored best model')
        for _ in range(n_test_tasks):
            test_meta_batch = metatest_task_distribution.sample_batch()

            X_test_a, Y_test_a, X_test_b, Y_test_b = [], [], [], []
            for task in test_meta_batch:
                X_test_a.append(task.get_train_set()[0])
                Y_test_a.append(np.expand_dims(task.get_train_set()[1], -1))
                X_test_b.append(task.get_test_set()[0])
                Y_test_b.append(np.expand_dims(task.get_test_set()[1], -1))

            for K_test_X, K_test_Y, test_test_X, test_test_Y in zip(
                    X_test_a, Y_test_a, X_test_b, Y_test_b):
                test_summaries, test_test_loss, acc, precision, recall, specificity, f1_score, auc_pr = model.val_op(
                    K_test_X, K_test_Y, test_test_X, test_test_Y)
                test_metrics_list.append([
                    test_test_loss, acc, precision, recall, specificity,
                    f1_score, auc_pr
                ])

        avg_test_metrics = np.mean(test_metrics_list, axis=0)

        print('+++ Test metrics - loss: ', avg_test_metrics[0], ' acc : ',
              avg_test_metrics[1], ' prec : ', avg_test_metrics[2],
              ' recall : ', avg_test_metrics[3], ' spec : ',
              avg_test_metrics[4], ' F1 : ', avg_test_metrics[5], ' auc_pr : ',
              avg_test_metrics[6])

    sess.close()
Example #2
def main(args):
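    # NOTE: this example assumes the enclosing module imports os, numpy as np
    # and tensorflow (v1.x) as tf, and defines the encoder and
    # euclidean_distance helpers, the bn list of batchNorm layers, and the
    # create_*_task_distribution loaders used below.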

    # set the random seed
    seed = args.seed
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # hyperparameters
    n_episodes = 100
    n_way = 2  # because OCC is a binary classification problem
    h_dim = 64
    n_shot = args.n_shot
    n_query = args.n_query

    # load data
    base_path = "/home/USER/Documents/"
    if not (os.path.exists(base_path)):
        base_path = "/home/ubuntu/Projects/"
    if not (os.path.exists(base_path)):
        base_path = "/home/USER/Projects/"
    basefolder = base_path + "raw_data/"

    if (args.dataset == 'OMN'):
        n_epochs = 2000
        im_width, im_height, channels = 28, 28, 1
        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_omniglot_allcharacters_task_distribution(
            basefolder + "omniglot/omniglot.pkl",
            train_occ=True,
            test_occ=True,
            num_training_samples_per_class=n_shot,
            num_test_samples_per_class=n_query,
            num_training_classes=2,
            meta_batch_size=1,
            seq_length=0)

    elif (args.dataset == 'CIFAR_FS'):
        n_epochs = 4000
        im_width, im_height, channels = 32, 32, 3

        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_cifarfs_task_distribution(
            base_path + "data/CIFAR_FS/CIFAR_FS_train.pickle",
            base_path + "data/CIFAR_FS/CIFAR_FS_val.pickle",
            base_path + "data/CIFAR_FS/CIFAR_FS_test.pickle",
            train_occ=True,
            test_occ=True,
            num_training_samples_per_class=n_shot,
            num_test_samples_per_class=n_query,
            num_training_classes=2,
            meta_batch_size=1,
            seq_length=0)
    elif (args.dataset == 'FC100'):
        n_epochs = 4000
        im_width, im_height, channels = 32, 32, 3
        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_fc100_task_distribution(
            base_path + "data/FC100/FC100_train.pickle",
            base_path + "data/FC100/FC100_val.pickle",
            base_path + "data/FC100/FC100_test.pickle",
            train_occ=True,
            test_occ=True,
            num_training_samples_per_class=n_shot,
            num_test_samples_per_class=n_query,
            num_training_classes=2,
            meta_batch_size=1,
            seq_length=0)

    else:
        n_epochs = 4000
        im_width, im_height, channels = 84, 84, 3

        metatrain_task_distribution, metaval_task_distribution, metatest_task_distribution = create_miniimagenet_task_distribution(
            basefolder + "miniImageNet_data/miniimagenet.pkl",
            train_occ=True,
            test_occ=True,
            num_training_samples_per_class=n_shot,
            num_test_samples_per_class=n_query,
            num_training_classes=2,
            meta_batch_size=1,
            seq_length=0)

    # batchNorm behavior
    support_training = True
    query_training = False

    # whether to reorder the layers such that batchNorm comes last, as
    # mentioned in the original paper
    reorder_layers = (args.reorder == 'True')

    # create placeholders for support and query inputs
    x = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
    q = tf.placeholder(tf.float32, [None, None, im_height, im_width, channels])
    x_shape = tf.shape(x)
    q_shape = tf.shape(q)
    num_classes, num_support = x_shape[0], x_shape[1]
    num_classes_q, num_queries = q_shape[0], q_shape[1]
    y = tf.placeholder(tf.int64, [None, None])
    y_one_hot = tf.one_hot(y, depth=num_classes_q)

    # create the encoder network and feed inputs forward
    emb_x = encoder(tf.reshape(
        x, [num_classes * num_support, im_height, im_width, channels]),
                    h_dim,
                    h_dim,
                    reorder_layers=reorder_layers,
                    training=support_training)
    emb_dim = tf.shape(emb_x)[-1]

    # compute prototype for the normal class
    emb_x = tf.reduce_mean(tf.reshape(emb_x,
                                      [num_classes, num_support, emb_dim]),
                           axis=1)

    # encode the query set
    emb_q = encoder(tf.reshape(
        q, [num_classes_q * num_queries, im_height, im_width, channels]),
                    h_dim,
                    h_dim,
                    reuse=True,
                    reorder_layers=reorder_layers,
                    training=query_training)

    # compute Euclidean distances between the query embeddings and both the
    # normal class prototype and the origin (0)
    dists = euclidean_distance(emb_q, emb_x)

    # compute loss and accuracy
    log_p_y = tf.reshape(tf.nn.log_softmax(-dists),
                         [num_classes_q, num_queries, -1])
    ce_loss = -tf.reduce_mean(
        tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, log_p_y), axis=-1),
                   [-1]))
    acc = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(log_p_y, axis=-1), y)))

    # collect operations of batchNorm updates of moving mean and moving
    # variance
    bn_updates = []
    for i in [0, 1, 2, 3]:
        bn_updates += bn[i].updates

    # execute them before updating the model parameters
    with tf.control_dependencies(bn_updates):
        train_op = tf.train.AdamOptimizer().minimize(ce_loss)

    # summaries
    summary_dir = args.dataset + '_' + str(n_shot) + '_seed_' + str(seed)
    if (reorder_layers):
        summary_dir += '_R'
    summary_dir = summary_dir + '_Q_' + str(n_query)

    # model checkpoints
    saver = tf.train.Saver()
    checkpoint_path = base_path + 'MAML/OW_ProtoNets_checkpoints/'
    if (not (os.path.exists(checkpoint_path))):
        os.mkdir(checkpoint_path)
    # the saver writes to <checkpoint_path><summary_dir>_restore_val_test_loss/,
    # so make sure that directory exists before the first save
    save_dir = checkpoint_path + summary_dir + "_restore_val_test_loss"
    if (not (os.path.exists(save_dir))):
        os.mkdir(save_dir)

    # create session and initialize variables
    sess = tf.InteractiveSession()
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # meta-training
    losses = []
    accs = []
    min_val_loss = 10000
    n_val_tasks = 50
    min_val_loss_epoch = -1
    early_stopping = False

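    # meta-training loop: one task per episode; every 100 episodes report the
    # running averages and run a validation pass with checkpointing and early
    # stopping after 300 epochs without improvement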
    for ep in range(n_epochs):
        for epi in range(n_episodes):
            task = metatrain_task_distribution.sample_batch()[0]
            support = np.reshape(task.get_train_set()[0],
                                 (1, n_shot, im_height, im_width, channels))
            query = np.reshape(task.get_test_set()[0],
                               (n_way, n_query, im_height, im_width, channels))
            labels = np.tile(np.arange(n_way)[:, np.newaxis],
                             (1, n_query)).astype(np.uint8)
            _, ls, ac = sess.run([train_op, ce_loss, acc],
                                 feed_dict={
                                     x: support,
                                     q: query,
                                     y: labels
                                 })
            losses.append(ls)
            accs.append(ac)
            if (epi + 1) % 100 == 0:
                print(
                    'TR: [epoch {}/{}, episode {}/{}] => loss: {:.5f}, acc: {:.5f}'
                    .format(ep + 1, n_epochs, epi + 1, n_episodes,
                            np.mean(losses), np.mean(accs)))
                losses, accs = [], []
                # val episode
                val_losses, val_accs = [], []
                for i in range(n_val_tasks):
                    task = metaval_task_distribution.sample_batch()[0]
                    support = np.reshape(
                        task.get_train_set()[0],
                        (1, n_shot, im_height, im_width, channels))
                    query = np.reshape(
                        task.get_test_set()[0],
                        (n_way, n_query, im_height, im_width, channels))
                    labels = np.tile(
                        np.arange(n_way)[:, np.newaxis],
                        (1, n_query)).astype(np.uint8)
                    ls, ac = sess.run([ce_loss, acc],
                                      feed_dict={
                                          x: support,
                                          q: query,
                                          y: labels
                                      })
                    val_losses.append(ls)
                    val_accs.append(ac)
                mean_loss, mean_acc = np.mean(val_losses), np.mean(val_accs)
                if (mean_loss < min_val_loss):
                    min_val_loss = mean_loss
                    min_val_loss_epoch = ep
                    print('### model saved ###')
                    print(
                        'VAL: [epoch {}/{}, episode {}/{}] => loss: {:.5f}, acc: {:.5f}'
                        .format(ep + 1, n_epochs, epi + 1, n_episodes,
                                mean_loss, mean_acc))
                    saver.save(
                        sess, checkpoint_path + summary_dir +
                        "_restore_val_test_loss/model.ckpt")
                if (ep - min_val_loss_epoch > 300):
                    early_stopping = True
        if (early_stopping):
            print(
                '##### EARLY STOPPING - NO IMPROVEMENT IN THE LAST 300 EPOCHS #####'
            )
            break

    # restore best performing model on meta-validation set
    saver.restore(
        sess,
        checkpoint_path + summary_dir + "_restore_val_test_loss/model.ckpt")
    print('### restored best model ###')

    # meta-testing
    n_test_episodes = 20000
    n_test_way = 2
    n_test_shot = n_shot
    n_test_query = n_query

    print('Testing...')
    avg_acc = 0.
    for epi in range(n_test_episodes):
        task = metatest_task_distribution.sample_batch()[0]
        support = np.reshape(
            task.get_train_set()[0],
            (1, n_test_shot, im_height, im_width, channels))
        query = np.reshape(
            task.get_test_set()[0],
            (n_test_way, n_test_query, im_height, im_width, channels))
        labels = np.tile(
            np.arange(n_test_way)[:, np.newaxis],
            (1, n_test_query)).astype(np.uint8)
        ls, ac = sess.run([ce_loss, acc],
                          feed_dict={
                              x: support,
                              q: query,
                              y: labels
                          })
        avg_acc += ac
        if (epi + 1) % 50 == 0:
            print('[test episode {}/{}] => loss: {:.5f}, acc: {:.5f}'.format(
                epi + 1, n_test_episodes, ls, ac))
    avg_acc /= n_test_episodes
    print('Average Test Accuracy: {:.5f}'.format(avg_acc))