Ejemplo n.º 1
0
def _predict_in_batches(sess, pred_op, images_ph, mask_ph, images, mask_val,
                        batch_size, preprocess=None):
    """Evaluate ``pred_op`` over ``images`` chunk by chunk.

    Splitting the input avoids feeding (and allocating GPU memory for) the
    whole array in a single ``sess.run`` call.

    Args:
        sess: session used to run ``pred_op``.
        pred_op: tensor to fetch for every chunk.
        images_ph: placeholder that receives the image chunk.
        mask_ph: placeholder that receives ``mask_val``.
        images: sliceable collection of inputs (sliced along the first axis).
        mask_val: value fed to ``mask_ph`` for every chunk.
        batch_size: maximum number of inputs per ``sess.run`` call.
        preprocess: optional callable applied to each chunk before feeding.

    Returns:
        A flat list with the fetched values of all chunks, in input order.
    """
    predictions = []
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        if preprocess is not None:
            chunk = preprocess(chunk)
        predictions.extend(
            sess.run(pred_op, feed_dict={images_ph: chunk,
                                         mask_ph: mask_val}))
    return predictions


def main(_):
    """Class-incremental classifier training with WGAN-generated replay data.

    Classes are processed in groups of ``FLAGS.nb_cl`` following the order
    read from ``imagenet_64x64_dogs_<order_file>.txt``. For every group:

    1. Unless ``FLAGS.only_gen_no_cls`` is set, the classifier is trained on
       the group's real images mixed with samples drawn from the previously
       saved generator; the generated samples are labelled with the
       predictions of the current test network.
    2. The classifier is evaluated on all classes seen so far and the
       accumulated accuracies / confusion matrices are pickled to
       ``acc_over_time.pkl`` inside the result folder.
    3. The generator for the group is trained unless
       ``acwgan_obj.check_model`` reports that a model already exists.

    Finally the classifier checkpoint is saved (when a classifier was built).

    Args:
        _: unused positional argument.
            NOTE(review): presumably argv passed by ``tf.app.run`` - confirm.

    Raises:
        AssertionError: if ``FLAGS.mode`` is not ``'wgan-gp'`` or
            ``FLAGS.from_class_idx`` is not a multiple of ``FLAGS.nb_cl``.
        Exception: if the network architecture is not ``'resnet'`` or the
            generator of the previous group cannot be loaded.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Class order: one integer class id per line of the order file.
    with open('imagenet_64x64_dogs_%s.txt' % FLAGS.order_file) as file_in:
        order = np.array([int(line) for line in file_in])

    assert FLAGS.mode == 'wgan-gp'

    NUM_CLASSES = 120
    NUM_TEST_SAMPLES_PER_CLASS = 50
    NUM_TRAIN_SAMPLES_PER_CLASS = 1300  # around 1300

    if not FLAGS.only_gen_no_cls:

        def build_cnn(inputs, is_training):
            """Build the train or test copy of the classifier network."""
            train_or_test = {True: 'train', False: 'test'}
            if FLAGS.network_arch != 'resnet':
                raise Exception('Invalid network architecture')
            logits, end_points = utils_resnet_64x64.ResNet(
                inputs,
                train_or_test[is_training],
                num_outputs=NUM_CLASSES,
                alpha=0.0,
                scope=('ResNet-' + train_or_test[is_training]))
            return logits, end_points

        # Save all intermediate result in the result_folder
        method_name = '_'.join(
            os.path.basename(__file__).split('.')[0].split('_')[4:])
        if FLAGS.gen_more_and_select:
            method_name += '_gen_%d_and_select' % FLAGS.gen_how_many

        cls_func = '' if FLAGS.use_softmax else '_sigmoid'
        result_folder = os.path.join(
            FLAGS.result_dir, FLAGS.dataset + ('_flip' if FLAGS.flip else '') +
            '_' + FLAGS.order_file, 'nb_cl_' + str(FLAGS.nb_cl),
            'non_truncated' if FLAGS.no_truncate else 'truncated',
            FLAGS.network_arch + cls_func + '_init_' + FLAGS.init_strategy,
            'weight_decay_' + str(FLAGS.weight_decay),
            'base_lr_' + str(FLAGS.base_lr), 'adam_lr_' + str(FLAGS.adam_lr),
            method_name)

        # Add a "_run-i" suffix to the folder name if the folder exists
        if os.path.exists(result_folder):
            temp_i = 2
            while os.path.exists(result_folder + '_run-' + str(temp_i)):
                temp_i += 1
            result_folder = result_folder + '_run-' + str(temp_i)
        os.makedirs(result_folder)
        print('Result folder: %s' % result_folder)

        graph_cls = tf.Graph()
        with graph_cls.as_default():
            # ---- Define variables ----
            batch_images = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
            batch = tf.Variable(0, trainable=False)
            learning_rate = tf.placeholder(tf.float32, shape=[])

            # ---- Network output mask ----
            mask_output = tf.placeholder(tf.bool, shape=[NUM_CLASSES])

            # ---- Old and new ground truth ----
            one_hot_labels_truncated = tf.placeholder(tf.float32,
                                                      shape=[None, None])

            # ---- Define the training network ----
            train_logits, _ = build_cnn(batch_images, True)
            train_masked_logits = tf.gather(train_logits,
                                            tf.squeeze(tf.where(mask_output)),
                                            axis=1)  # masking operation
            # Convert to (N, 1) if the shape is (N,), otherwise softmax would
            # output wrong values.
            train_masked_logits = tf.cond(
                tf.equal(tf.rank(train_masked_logits), 1),
                lambda: tf.expand_dims(train_masked_logits, 1),
                lambda: train_masked_logits)
            # Train accuracy (since there is only one class excluding the old
            # recorded responses, this accuracy is not very meaningful)
            train_pred = tf.argmax(train_masked_logits, 1)
            train_ground_truth = tf.argmax(one_hot_labels_truncated, 1)
            correct_prediction = tf.equal(train_pred, train_ground_truth)
            train_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            # Per-sample loss weights (real vs. generated samples).
            train_batch_weights = tf.placeholder(tf.float32, shape=[None])

            reg_weights = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            regularization_loss = FLAGS.weight_decay * tf.add_n(reg_weights)

            # ---- Loss and optimizer ----
            if FLAGS.use_softmax:
                empirical_loss = tf.losses.softmax_cross_entropy(
                    onehot_labels=one_hot_labels_truncated,
                    logits=train_masked_logits,
                    weights=train_batch_weights)
            else:
                empirical_loss = tf.losses.sigmoid_cross_entropy(
                    multi_class_labels=one_hot_labels_truncated,
                    logits=train_masked_logits,
                    weights=train_batch_weights)

            loss = empirical_loss + regularization_loss
            if FLAGS.use_momentum:
                opt = tf.train.MomentumOptimizer(
                    learning_rate, FLAGS.momentum).minimize(loss,
                                                            global_step=batch)
            else:
                opt = tf.train.GradientDescentOptimizer(
                    learning_rate).minimize(loss, global_step=batch)

            # ---- Define the testing network ----
            test_logits, _ = build_cnn(batch_images, False)
            test_masked_logits = tf.gather(test_logits,
                                           tf.squeeze(tf.where(mask_output)),
                                           axis=1)
            test_masked_logits = tf.cond(
                tf.equal(tf.rank(test_masked_logits), 1),
                lambda: tf.expand_dims(test_masked_logits, 1),
                lambda: test_masked_logits)
            # Not fetched anywhere in this function; kept so that the graph
            # definition stays unchanged.
            test_masked_prob = tf.nn.softmax(test_masked_logits)
            test_pred = tf.argmax(test_masked_logits, 1)
            test_accuracy = tf.placeholder(tf.float32)

            # ---- Copy network (define the copying op) ----
            if FLAGS.network_arch == 'resnet':
                all_variables = tf.get_collection(tf.GraphKeys.WEIGHTS)
            else:
                raise Exception('Invalid network architecture')
            # The first half of the collection is assigned onto the second
            # half (the train network is built before the test network).
            half = len(all_variables) // 2
            copy_ops = [
                all_variables[ix + half].assign(var.value())
                for ix, var in enumerate(all_variables[0:half])
            ]

            # ---- Init certain layers when new classes are added ----
            init_ops = tf.no_op()
            if FLAGS.init_strategy == 'all':
                init_ops = tf.global_variables_initializer()
            elif FLAGS.init_strategy == 'last':
                if FLAGS.network_arch == 'resnet':
                    init_vars = [
                        var for var in tf.global_variables()
                        if 'fc' in var.name and 'train' in var.name
                    ]
                init_ops = tf.initialize_variables(init_vars)

            # ---- Create session ----
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config, graph=graph_cls)
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver()

            # ---- Summary ----
            # Built inside the classifier graph context so the summary ops do
            # not depend on graph inference from their input tensors.
            train_loss_summary = tf.summary.scalar('train_loss', loss)
            train_acc_summary = tf.summary.scalar('train_accuracy',
                                                  train_accuracy)
            test_acc_summary = tf.summary.scalar('test_accuracy',
                                                 test_accuracy)

        summary_dir = os.path.join(result_folder, 'summary')
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        train_summary_writer = tf.summary.FileWriter(
            os.path.join(summary_dir, 'train'), sess.graph)
        test_summary_writer = tf.summary.FileWriter(
            os.path.join(summary_dir, 'test'))

        # Global training step counter, shared by all class groups.
        iteration = 0

        # ---- Declaration of other vars ----
        # Average accuracy on seen classes
        aver_acc_over_time = dict()
        aver_acc_per_class_over_time = dict()
        conf_mat_over_time = dict()

        # Network mask
        mask_output_val = np.zeros([NUM_CLASSES], dtype=bool)
        mask_output_val_prev = np.zeros([NUM_CLASSES], dtype=bool)
        mask_output_test = np.zeros([NUM_CLASSES], dtype=bool)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    run_config = tf.ConfigProto(gpu_options=gpu_options,
                                allow_soft_placement=True)
    run_config.gpu_options.allow_growth = True

    # ---- Build the generative model wrapper (own graph and session) ----
    graph_gen = tf.Graph()
    sess_wgan = tf.Session(config=run_config, graph=graph_gen)

    acwgan_obj = WGAN64x64(sess_wgan,
                           graph_gen,
                           dataset_name=(FLAGS.dataset + '_' +
                                         FLAGS.order_file),
                           mode=FLAGS.mode,
                           batch_size=FLAGS.batch_size,
                           dim=FLAGS.dim,
                           output_dim=FLAGS.output_dim,
                           lambda_param=FLAGS.lambda_param,
                           critic_iters=FLAGS.critic_iters,
                           iters=FLAGS.iters,
                           result_dir=FLAGS.result_dir_cwgan,
                           checkpoint_interval=FLAGS.gan_save_interval,
                           adam_lr=FLAGS.adam_lr,
                           use_decay=FLAGS.use_decay,
                           conditional=FLAGS.conditional,
                           acgan=FLAGS.acgan,
                           acgan_scale=FLAGS.acgan_scale,
                           acgan_scale_g=FLAGS.acgan_scale_g,
                           normalization_g=FLAGS.normalization_g,
                           normalization_d=FLAGS.normalization_d,
                           gen_bs_multiple=FLAGS.gen_bs_multiple,
                           nb_cl=FLAGS.nb_cl,
                           n_gpus=FLAGS.n_gpus)

    (test_images, test_labels, test_one_hot_labels,
     raw_images_test) = imagenet_64x64.load_test_data()

    # ---- Class Incremental Learning ----
    print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' +
          str(FLAGS.to_class_idx + 1))
    print('Adding %d categories every time' % FLAGS.nb_cl)
    assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0)
    for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1,
                              FLAGS.nb_cl):

        to_category_idx = category_idx + FLAGS.nb_cl - 1
        if FLAGS.nb_cl == 1:
            print('Adding Category ' + str(category_idx + 1))
        else:
            print('Adding Category %d-%d' %
                  (category_idx + 1, to_category_idx + 1))

        # sess_idx starts from 0 (integer division: it is a group index)
        sess_idx = category_idx // FLAGS.nb_cl

        train_x_gan = np.zeros([0, FLAGS.output_dim], dtype=np.uint8)
        train_y_gan = np.zeros([0], dtype=float)
        test_x_gan = np.zeros([0, FLAGS.output_dim], dtype=np.uint8)
        test_y_gan = np.zeros([0], dtype=float)

        if not FLAGS.only_gen_no_cls:
            # train and test data of seen classes
            train_y_one_hot = np.zeros([0, NUM_CLASSES], dtype=np.float32)
            test_x = np.zeros([0, 64, 64, 3], dtype=np.float32)
            test_y = np.zeros([0], dtype=np.float32)

        # Real training data of the classes added in this group.
        for category_idx_in_group in range(category_idx, to_category_idx + 1):
            real_category_idx = order[category_idx_in_group]
            (real_images_train_cur_cls,
             raw_images_train_cur_cls) = imagenet_64x64.load_train_data(
                 real_category_idx, flip=FLAGS.flip)
            num_train_cur_cls = len(raw_images_train_cur_cls)

            # GAN: labels are relative to the current group.
            train_x_gan = np.concatenate(
                (train_x_gan, raw_images_train_cur_cls))
            train_y_gan_cur_cls = np.ones([num_train_cur_cls]) * (
                category_idx_in_group % FLAGS.nb_cl)
            train_y_gan = np.concatenate((train_y_gan, train_y_gan_cur_cls))

            if not FLAGS.only_gen_no_cls:
                train_y_one_hot_cur_cls = np.zeros(
                    [num_train_cur_cls, NUM_CLASSES])
                train_y_one_hot_cur_cls[:, category_idx_in_group] = 1
                train_y_one_hot = np.concatenate(
                    (train_y_one_hot, train_y_one_hot_cur_cls))

        # Test data of every class seen so far (rebuilt for each group).
        for category_idx_seen in range(to_category_idx + 1):

            real_category_idx = order[category_idx_seen]
            test_indices_cur_cls = [
                idx for idx, label in enumerate(test_labels)
                if label == real_category_idx
            ]
            test_x_gan_cur_cls = raw_images_test[test_indices_cur_cls, :]
            test_y_gan_cur_cls = np.ones([len(test_indices_cur_cls)]) * (
                category_idx_seen % FLAGS.nb_cl)
            test_x_gan = np.concatenate((test_x_gan, test_x_gan_cur_cls))
            test_y_gan = np.concatenate((test_y_gan, test_y_gan_cur_cls))

            # Classification network (reuses the indices computed above)
            if not FLAGS.only_gen_no_cls:
                test_x_cur_cls = test_images[test_indices_cur_cls, :]
                test_y_cur_cls = np.ones([len(test_indices_cur_cls)
                                          ]) * category_idx_seen

                test_x = np.concatenate((test_x, test_x_cur_cls))
                test_y = np.concatenate((test_y, test_y_cur_cls))

        # ---- Train classification model ----
        # No need to train the classifier if there is only one class (unless
        # the sigmoid loss is used), and never when the classifier was not
        # built at all (only_gen_no_cls): its session and ops do not exist.
        train_classifier = (not FLAGS.only_gen_no_cls and
                            (to_category_idx > 0 or not FLAGS.use_softmax))
        if train_classifier:

            # init certain layers
            sess.run(init_ops)

            if FLAGS.no_truncate:
                mask_output_val[:] = True
            else:
                mask_output_val[:to_category_idx + 1] = True

            # Test on all seen classes
            mask_output_test[:to_category_idx + 1] = True

            # ---- Generate samples of old classes ----
            train_x = np.copy(train_x_gan)
            if FLAGS.no_truncate:
                train_y_truncated = train_y_one_hot[:, :]
            else:
                train_y_truncated = train_y_one_hot[:, :to_category_idx + 1]

            # Load old class model
            if sess_idx > 0:
                if not acwgan_obj.load(category_idx - 1)[0]:
                    raise Exception(
                        "[!] Train a model first, then run test mode")
                # One draw of NUM_TRAIN_SAMPLES_PER_CLASS samples per old
                # class; the empty int array keeps the original dtype
                # promotion of the concatenation.
                gen_sample_parts = [
                    np.zeros((0, FLAGS.output_dim), dtype=int)
                ]
                for _ in range(category_idx):
                    gen_samples_x_frac, _, _ = acwgan_obj.test(
                        NUM_TRAIN_SAMPLES_PER_CLASS, label=None)
                    gen_sample_parts.append(gen_samples_x_frac)
                gen_samples_x = np.concatenate(gen_sample_parts)

                # Get the output y. The label width follows the real labels
                # so that both can be concatenated in truncated and
                # non-truncated mode alike.
                gen_samples_y = np.zeros(
                    (len(gen_samples_x), train_y_truncated.shape[1]))
                if category_idx == 1:
                    # Only one old class: every generated sample belongs to it.
                    gen_samples_y[:, 0] = 1
                else:
                    # Label the generated samples with the predictions of the
                    # current test network restricted to the old classes.
                    mask_output_val_prev[:category_idx] = True
                    test_pred_val = _predict_in_batches(
                        sess, test_pred, batch_images, mask_output,
                        gen_samples_x, mask_output_val_prev,
                        FLAGS.test_batch_size,
                        preprocess=imagenet_64x64.convert_images)
                    gen_samples_y[
                        np.arange(len(gen_samples_x)),
                        np.asarray(test_pred_val, dtype=np.int64)] = 1

                train_weights_val = np.concatenate(
                    (np.ones(len(train_x)) * FLAGS.ratio,
                     np.ones(len(gen_samples_x)) * (1 - FLAGS.ratio)))
                train_x = np.concatenate((train_x, gen_samples_x))
                train_y_truncated = np.concatenate(
                    (train_y_truncated, gen_samples_y))
            else:
                train_weights_val = np.ones(len(train_x)) * FLAGS.ratio

            # Training set
            # Convert the raw images from the data-files to floating-points.
            train_x = imagenet_64x64.convert_images(train_x)

            epoch_idx = 0
            lr = FLAGS.base_lr

            # ---- Training with mixed data ----
            while epoch_idx < FLAGS.epochs_per_category:
                if epoch_idx in lr_strat:
                    lr /= FLAGS.lr_factor
                    print("NEW LEARNING RATE: %f" % lr)
                epoch_idx = epoch_idx + 1

                # Shuffle the indices and walk through them in mini-batches.
                # (np.random.shuffle cannot shuffle a Python 3 range object.)
                shuffled_indices = np.random.permutation(train_x.shape[0])
                for start in range(0, len(shuffled_indices),
                                   FLAGS.train_batch_size):
                    batch_idx = shuffled_indices[start:start +
                                                 FLAGS.train_batch_size]

                    # Use the random index to select random images and labels.
                    train_x_batch = train_x[batch_idx, :, :, :]
                    train_y_batch = train_y_truncated[batch_idx]
                    train_weights_batch_val = train_weights_val[batch_idx]

                    # Train
                    (train_loss_summary_str, train_acc_summary_str,
                     train_accuracy_val, train_loss_val,
                     train_empirical_loss_val, train_reg_loss_val,
                     _) = sess.run(
                         [train_loss_summary, train_acc_summary,
                          train_accuracy, loss, empirical_loss,
                          regularization_loss, opt],
                         feed_dict={
                             batch_images: train_x_batch,
                             one_hot_labels_truncated: train_y_batch,
                             mask_output: mask_output_val,
                             learning_rate: lr,
                             train_batch_weights: train_weights_batch_val
                         })

                    # Test
                    if iteration % FLAGS.test_interval == 0:
                        sess.run(copy_ops)

                        # Divide and conquer: to avoid allocating too much
                        # GPU memory
                        test_pred_val = _predict_in_batches(
                            sess, test_pred, batch_images, mask_output,
                            test_x, mask_output_test, FLAGS.test_batch_size)

                        test_accuracy_val = 1. * np.sum(
                            np.equal(test_pred_val,
                                     test_y)) / (len(test_pred_val))
                        # Correct predictions per class scaled to a
                        # percentage; the factor is 2 because there are 50
                        # samples per class in the test set.
                        test_per_class_accuracy_val = np.diag(
                            confusion_matrix(test_y, test_pred_val)) * (
                                100 // NUM_TEST_SAMPLES_PER_CLASS)

                        test_acc_summary_str = sess.run(
                            test_acc_summary,
                            feed_dict={test_accuracy: test_accuracy_val})

                        test_summary_writer.add_summary(test_acc_summary_str,
                                                        iteration)

                        print("TEST: step %d, lr %.4f, accuracy %g" %
                              (iteration, lr, test_accuracy_val))
                        print("PER CLASS ACCURACY: " + " | ".join(
                            str(o) + '%'
                            for o in test_per_class_accuracy_val))

                    # Print the training logs
                    if iteration % FLAGS.display_interval == 0:
                        train_summary_writer.add_summary(
                            train_loss_summary_str, iteration)
                        train_summary_writer.add_summary(
                            train_acc_summary_str, iteration)
                        print(
                            "TRAIN: epoch %d, step %d, lr %.4f, accuracy %g, loss %g, empirical %g, reg %g"
                            % (epoch_idx, iteration, lr, train_accuracy_val,
                               train_loss_val, train_empirical_loss_val,
                               train_reg_loss_val))

                    iteration = iteration + 1

            # ---- Final test (before the next class is added) ----
            sess.run(copy_ops)
            # Divide and conquer: to avoid allocating too much GPU memory
            test_pred_val = _predict_in_batches(
                sess, test_pred, batch_images, mask_output, test_x,
                mask_output_test, FLAGS.test_batch_size)

            test_accuracy_val = 1. * np.sum(np.equal(
                test_pred_val, test_y)) / (len(test_pred_val))
            conf_mat = confusion_matrix(test_y, test_pred_val)
            # Raw per-class counts of correct predictions (not scaled).
            test_per_class_accuracy_val = np.diag(conf_mat)

            # Record and save the cumulative accuracy
            aver_acc_over_time[to_category_idx] = test_accuracy_val
            aver_acc_per_class_over_time[
                to_category_idx] = test_per_class_accuracy_val
            conf_mat_over_time[to_category_idx] = conf_mat

            dump_obj = dict()
            dump_obj['flags'] = flags.FLAGS.__flags
            dump_obj['aver_acc_over_time'] = aver_acc_over_time
            dump_obj[
                'aver_acc_per_class_over_time'] = aver_acc_per_class_over_time
            dump_obj['conf_mat_over_time'] = conf_mat_over_time

            np_file_result = os.path.join(result_folder, 'acc_over_time.pkl')
            with open(np_file_result, 'wb') as file_out:
                pickle.dump(dump_obj, file_out)

            visualize_result.vis(np_file_result, 'ImageNetDogs')

        # ---- Train generative model (W-GAN) ----
        if acwgan_obj.check_model(to_category_idx):
            print(
                " [*] Model of Class %d-%d exists. Skip the training process" %
                (category_idx + 1, to_category_idx + 1))
        else:
            print(
                " [*] Model of Class %d-%d does not exist. Start the training process"
                % (category_idx + 1, to_category_idx + 1))
            # Start from the model of the previous group; the result of the
            # load is intentionally not checked here.
            # NOTE(review): for the first group this asks for index -1 -
            # confirm that acwgan_obj.load tolerates a missing model.
            acwgan_obj.load(to_category_idx - FLAGS.nb_cl)
            # Replay: add generated samples of the old classes (label 0) to
            # the generator's training set.
            for _ in range(category_idx):
                gen_samples_x, _, _ = acwgan_obj.test(
                    NUM_TRAIN_SAMPLES_PER_CLASS, label=None)
                gen_samples_x = np.uint8(gen_samples_x)
                train_x_gan = np.concatenate((train_x_gan, gen_samples_x))
                train_y_gan = np.concatenate(
                    (train_y_gan, np.zeros(len(gen_samples_x))))
            acwgan_obj.train(train_x_gan, train_y_gan, test_x_gan, test_y_gan,
                             to_category_idx)

    # Save the final model
    if not FLAGS.only_gen_no_cls:
        checkpoint_dir = os.path.join(result_folder, 'checkpoints')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
        # Flush buffered summary events to disk before shutting down.
        train_summary_writer.close()
        test_summary_writer.close()
        sess.close()
Ejemplo n.º 2
0
def main(_):

    pp.pprint(flags.FLAGS.__flags)

    order = []
    with open('imagenet_64x64_dogs_%s.txt' % FLAGS.order_file) as file_in:
        for line in file_in.readlines():
            order.append(int(line))
    order = np.array(order)

    NUM_CLASSES = 120
    NUM_TEST_SAMPLES_PER_CLASS = 50

    def build_cnn(inputs, is_training):
        train_or_test = {True: 'train', False: 'test'}
        if FLAGS.network_arch == 'resnet':
            logits, end_points = utils_resnet_64x64.ResNet(
                inputs,
                train_or_test[is_training],
                num_outputs=NUM_CLASSES,
                alpha=0.0,
                scope=('ResNet-' + train_or_test[is_training]))
        else:
            raise Exception()
        return logits, end_points

    # Save all intermediate result in the result_folder
    method_name = '_'.join(
        os.path.basename(__file__).split('.')[0].split('_')[4:])
    cls_func = '' if FLAGS.use_softmax else '_sigmoid'
    result_folder = os.path.join(
        FLAGS.result_dir, FLAGS.dataset + ('_flip' if FLAGS.flip else '') +
        '_' + FLAGS.order_file, 'nb_cl_' + str(FLAGS.nb_cl),
        'non_truncated' if FLAGS.no_truncate else 'truncated',
        FLAGS.network_arch + cls_func + '_init_' + FLAGS.init_strategy,
        'weight_decay_' + str(FLAGS.weight_decay),
        'base_lr_' + str(FLAGS.base_lr), method_name)

    # Add a "_run-i" suffix to the folder name if the folder exists
    if os.path.exists(result_folder):
        temp_i = 2
        while True:
            result_folder_mod = result_folder + '_run-' + str(temp_i)
            if not os.path.exists(result_folder_mod):
                result_folder = result_folder_mod
                break
            temp_i += 1
    os.makedirs(result_folder)
    print('Result folder: %s' % result_folder)
    '''
    Define variables
    '''
    batch_images = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
    batch = tf.Variable(0, trainable=False)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    '''
    Network output mask
    '''
    mask_output = tf.placeholder(tf.bool, shape=[NUM_CLASSES])
    '''
    Old and new ground truth
    '''
    one_hot_labels_truncated = tf.placeholder(tf.float32, shape=[None, None])
    '''
    Define the training network
    '''
    train_logits, _ = build_cnn(batch_images, True)
    train_masked_logits = tf.gather(train_logits,
                                    tf.squeeze(tf.where(mask_output)),
                                    axis=1)
    train_masked_logits = tf.cond(
        tf.equal(tf.rank(train_masked_logits),
                 1), lambda: tf.expand_dims(train_masked_logits, 1),
        lambda: train_masked_logits)
    train_pred = tf.argmax(train_masked_logits, 1)
    train_ground_truth = tf.argmax(one_hot_labels_truncated, 1)
    correct_prediction = tf.equal(train_pred, train_ground_truth)
    train_accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    reg_weights = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    regularization_loss = FLAGS.weight_decay * tf.add_n(reg_weights)
    '''
    More Settings
    '''
    if FLAGS.use_softmax:
        empirical_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels_truncated, logits=train_masked_logits)
    else:
        empirical_loss = tf.losses.sigmoid_cross_entropy(
            multi_class_labels=one_hot_labels_truncated,
            logits=train_masked_logits)

    loss = empirical_loss + regularization_loss
    if FLAGS.use_momentum:
        opt = tf.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum).minimize(loss, global_step=batch)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, global_step=batch)
    '''
    Define the testing network
    '''
    test_logits, _ = build_cnn(batch_images, False)
    test_masked_logits = tf.gather(test_logits,
                                   tf.squeeze(tf.where(mask_output)),
                                   axis=1)
    test_masked_logits = tf.cond(tf.equal(tf.rank(test_masked_logits), 1),
                                 lambda: tf.expand_dims(test_masked_logits, 1),
                                 lambda: test_masked_logits)
    test_pred = tf.argmax(test_masked_logits, 1)
    test_accuracy = tf.placeholder(tf.float32)
    '''
    Copy network (define the copying op)
    '''
    if FLAGS.network_arch == 'resnet':
        all_variables = tf.get_collection(tf.GraphKeys.WEIGHTS)
    else:
        raise Exception('Invalid network architecture')
    copy_ops = [
        all_variables[ix + len(all_variables) // 2].assign(var.value())
        for ix, var in enumerate(all_variables[0:len(all_variables) // 2])
    ]
    '''
    Init certain layers when new classes added
    '''
    init_ops = tf.no_op()
    if FLAGS.init_strategy == 'all':
        init_ops = tf.global_variables_initializer()
    elif FLAGS.init_strategy == 'last':
        if FLAGS.network_arch == 'resnet':
            init_vars = [
                var for var in tf.global_variables()
                if 'fc' in var.name and 'train' in var.name
            ]
        init_ops = tf.initialize_variables(init_vars)
    '''
    Create session
    '''
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    '''
    Summary
    '''
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    train_acc_summary = tf.summary.scalar('train_accuracy', train_accuracy)
    test_acc_summary = tf.summary.scalar('test_accuracy', test_accuracy)

    summary_dir = os.path.join(result_folder, 'summary')
    if not os.path.exists(summary_dir):
        os.makedirs(summary_dir)
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(summary_dir, 'train'), sess.graph)
    test_summary_writer = tf.summary.FileWriter(
        os.path.join(summary_dir, 'test'))

    iteration = 0
    '''
    Declaration of other vars
    '''
    # Average accuracy on seen classes
    aver_acc_over_time = dict()
    aver_acc_per_class_over_time = dict()
    conf_mat_over_time = dict()

    # Network mask
    mask_output_val = np.zeros([NUM_CLASSES], dtype=bool)
    mask_output_test = np.zeros([NUM_CLASSES], dtype=bool)

    # train and test data of seen classes
    train_x = np.zeros([0, 64, 64, 3], dtype=np.float32)
    train_y = np.zeros([0, NUM_CLASSES], dtype=np.float32)
    test_x = np.zeros([0, 64, 64, 3], dtype=np.float32)
    test_y = np.zeros([0], dtype=np.float32)

    test_images, test_labels, test_one_hot_labels, _ = imagenet_64x64.load_test_data(
    )
    '''
    Class Incremental Learning
    '''
    print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' +
          str(FLAGS.to_class_idx + 1))
    print('Adding %d categories every time' % FLAGS.nb_cl)
    assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0)
    for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1,
                              FLAGS.nb_cl):

        to_category_idx = category_idx + FLAGS.nb_cl - 1
        if FLAGS.nb_cl == 1:
            print('Adding Category ' + str(category_idx + 1))
        else:
            print('Adding Category %d-%d' %
                  (category_idx + 1, to_category_idx + 1))

        if FLAGS.no_truncate:
            mask_output_val[:] = True
        else:
            mask_output_val[:to_category_idx + 1] = True

        # Test on all seen classes
        mask_output_test[:to_category_idx + 1] = True

        for category_idx_in_group in range(category_idx, to_category_idx + 1):
            real_category_idx = order[category_idx_in_group]
            real_images_train_cur_cls, _ = imagenet_64x64.load_train_data(
                real_category_idx, flip=FLAGS.flip)
            train_y_cur_cls = np.zeros(
                [len(real_images_train_cur_cls), NUM_CLASSES])
            train_y_cur_cls[:, category_idx_in_group] = np.ones(
                [len(real_images_train_cur_cls)])

            train_x = np.concatenate((train_x, real_images_train_cur_cls))
            train_y = np.concatenate((train_y, train_y_cur_cls))

            test_indices_cur_cls = [
                idx for idx in range(len(test_labels))
                if test_labels[idx] == real_category_idx
            ]
            test_x_cur_cls = test_images[test_indices_cur_cls, :]
            test_y_cur_cls = np.ones([len(test_indices_cur_cls)
                                      ]) * category_idx_in_group

            test_x = np.concatenate((test_x, test_x_cur_cls))
            test_y = np.concatenate((test_y, test_y_cur_cls))

        if FLAGS.no_truncate:
            train_y_truncated = train_y[:, :]
        else:
            train_y_truncated = train_y[:, :to_category_idx + 1]

        # With softmax, skip training and evaluation while only a single class
        # has been seen; when FLAGS.use_softmax is off this block always runs.
        if to_category_idx > 0 or not FLAGS.use_softmax:

            # init certain layers
            sess.run(init_ops)

            # Queue of mini-batch index lists; it is refilled from a fresh
            # shuffle of the training indices at the start of every epoch.
            batch_indices_perm = []

            epoch_idx = 0
            lr = FLAGS.base_lr

            while True:
                # Generate mini-batch
                if len(batch_indices_perm) == 0:
                    if epoch_idx >= FLAGS.epochs_per_category:
                        break
                    if epoch_idx in lr_strat:
                        lr /= FLAGS.lr_factor
                        print("NEW LEARNING RATE: %f" % lr)
                    epoch_idx = epoch_idx + 1

                    shuffled_indices = range(len(train_x))
                    np.random.shuffle(shuffled_indices)
                    for i in range(0, len(shuffled_indices),
                                   FLAGS.train_batch_size):
                        batch_indices_perm.append(
                            shuffled_indices[i:i + FLAGS.train_batch_size])
                    batch_indices_perm.reverse()

                popped_batch_idx = batch_indices_perm.pop()

                # Use the random index to select random images and labels.
                train_x_batch = train_x[popped_batch_idx, :, :, :]
                train_y_batch = [
                    train_y_truncated[k] for k in popped_batch_idx
                ]

                # Train
                train_loss_summary_str, train_acc_summary_str, train_accuracy_val, \
                train_loss_val, train_empirical_loss_val, train_reg_loss_val, _ = sess.run(
                    [train_loss_summary, train_acc_summary, train_accuracy, loss, empirical_loss,
                     regularization_loss, opt], feed_dict={batch_images: train_x_batch,
                                                           one_hot_labels_truncated: train_y_batch,
                                                           mask_output: mask_output_val,
                                                           learning_rate: lr})

                # Test
                if iteration % FLAGS.test_interval == 0:
                    sess.run(copy_ops)

                    # Predict in chunks of FLAGS.test_batch_size to avoid allocating too much GPU memory
                    test_pred_val = []
                    for i in range(0, len(test_x), FLAGS.test_batch_size):
                        test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                        test_pred_val_batch = sess.run(test_pred,
                                                       feed_dict={
                                                           batch_images:
                                                           test_x_batch,
                                                           mask_output:
                                                           mask_output_test
                                                       })
                        test_pred_val.extend(test_pred_val_batch)

                    test_accuracy_val = 1. * np.sum(
                        np.equal(test_pred_val, test_y)) / (len(test_pred_val))
                    test_per_class_accuracy_val = np.diag(
                        confusion_matrix(test_y, test_pred_val)) * 2
                    # The confusion-matrix diagonal holds the number of correct predictions per class;
                    # multiplying by 2 converts it to a percentage because the test set has 50 samples per class.

                    test_acc_summary_str = sess.run(
                        test_acc_summary,
                        feed_dict={test_accuracy: test_accuracy_val})

                    test_summary_writer.add_summary(test_acc_summary_str,
                                                    iteration)

                    print("TEST: step %d, lr %.4f, accuracy %g" %
                          (iteration, lr, test_accuracy_val))
                    print("PER CLASS ACCURACY: " + " | ".join(
                        str(o) + '%' for o in test_per_class_accuracy_val))

                # Print the training logs
                if iteration % FLAGS.display_interval == 0:
                    train_summary_writer.add_summary(train_loss_summary_str,
                                                     iteration)
                    train_summary_writer.add_summary(train_acc_summary_str,
                                                     iteration)
                    print(
                        "TRAIN: epoch %d, step %d, lr %.4f, accuracy %g, loss %g, empirical %g, reg %g"
                        % (epoch_idx, iteration, lr, train_accuracy_val,
                           train_loss_val, train_empirical_loss_val,
                           train_reg_loss_val))

                iteration = iteration + 1
            '''
            Final test(before the next class is added)
            '''
            sess.run(copy_ops)
            # Predict in chunks of FLAGS.test_batch_size to avoid allocating too much GPU memory
            test_pred_val = []
            for i in range(0, len(test_x), FLAGS.test_batch_size):
                test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                test_pred_val_batch = sess.run(test_pred,
                                               feed_dict={
                                                   batch_images: test_x_batch,
                                                   mask_output:
                                                   mask_output_test
                                               })
                test_pred_val.extend(test_pred_val_batch)

            test_accuracy_val = 1. * np.sum(np.equal(
                test_pred_val, test_y)) / (len(test_pred_val))
            conf_mat = confusion_matrix(test_y, test_pred_val)
            test_per_class_accuracy_val = np.diag(conf_mat)

            # Record and save the cumulative accuracy
            aver_acc_over_time[to_category_idx] = test_accuracy_val
            aver_acc_per_class_over_time[
                to_category_idx] = test_per_class_accuracy_val
            conf_mat_over_time[to_category_idx] = conf_mat

            dump_obj = dict()
            dump_obj['flags'] = flags.FLAGS.__flags
            dump_obj['aver_acc_over_time'] = aver_acc_over_time
            dump_obj[
                'aver_acc_per_class_over_time'] = aver_acc_per_class_over_time
            dump_obj['conf_mat_over_time'] = conf_mat_over_time

            np_file_result = os.path.join(result_folder, 'acc_over_time.pkl')
            with open(np_file_result, 'wb') as file:
                pickle.dump(dump_obj, file)

            visualize_result.vis(np_file_result, 'ImageNetDogs')

    # Save the final model
    checkpoint_dir = os.path.join(result_folder, 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
    sess.close()