Example #1
def main(_):

    pp.pprint(flags.FLAGS.__flags)

    order = []
    with open('imagenet_64x64_dogs_%s.txt' % FLAGS.order_file) as file_in:
        for line in file_in.readlines():
            order.append(int(line))
    order = np.array(order)
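    # order[i] maps the i-th class of the incremental sequence to its real
    # (dataset) class id.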

    assert FLAGS.mode == 'wgan-gp'

    NUM_CLASSES = 120
    NUM_TEST_SAMPLES_PER_CLASS = 50
    NUM_TRAIN_SAMPLES_PER_CLASS = 1300  # around 1300

    if not FLAGS.only_gen_no_cls:

        def build_cnn(inputs, is_training):
            train_or_test = {True: 'train', False: 'test'}
            if FLAGS.network_arch == 'resnet':
                logits, end_points = utils_resnet_64x64.ResNet(
                    inputs,
                    train_or_test[is_training],
                    num_outputs=NUM_CLASSES,
                    alpha=0.0,
                    scope=('ResNet-' + train_or_test[is_training]))
            else:
                raise Exception('Invalid network architecture')
            return logits, end_points

        # Save all intermediate results in result_folder
        method_name = '_'.join(
            os.path.basename(__file__).split('.')[0].split('_')[4:])
        method_name += '_gen_%d_and_select' % FLAGS.gen_how_many if FLAGS.gen_more_and_select else ''

        cls_func = '' if FLAGS.use_softmax else '_sigmoid'
        result_folder = os.path.join(
            FLAGS.result_dir, FLAGS.dataset + ('_flip' if FLAGS.flip else '') +
            '_' + FLAGS.order_file, 'nb_cl_' + str(FLAGS.nb_cl),
            'non_truncated' if FLAGS.no_truncate else 'truncated',
            FLAGS.network_arch + cls_func + '_init_' + FLAGS.init_strategy,
            'weight_decay_' + str(FLAGS.weight_decay),
            'base_lr_' + str(FLAGS.base_lr), 'adam_lr_' + str(FLAGS.adam_lr),
            method_name)

        # Add a "_run-i" suffix to the folder name if the folder exists
        if os.path.exists(result_folder):
            temp_i = 2
            while True:
                result_folder_mod = result_folder + '_run-' + str(temp_i)
                if not os.path.exists(result_folder_mod):
                    result_folder = result_folder_mod
                    break
                temp_i += 1
        os.makedirs(result_folder)
        print('Result folder: %s' % result_folder)

        graph_cls = tf.Graph()
        with graph_cls.as_default():
            '''
            Define variables
            '''
            batch_images = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
            batch = tf.Variable(0, trainable=False)
            learning_rate = tf.placeholder(tf.float32, shape=[])
            '''
            Network output mask
            '''
            mask_output = tf.placeholder(tf.bool, shape=[NUM_CLASSES])
            '''
            Old and new ground truth
            '''
            one_hot_labels_truncated = tf.placeholder(tf.float32,
                                                      shape=[None, None])
            '''
            Define the training network
            '''
            train_logits, _ = build_cnn(batch_images, True)
            train_masked_logits = tf.gather(train_logits,
                                            tf.squeeze(tf.where(mask_output)),
                                            axis=1)  # masking operation
            train_masked_logits = tf.cond(
                tf.equal(tf.rank(train_masked_logits),
                         1), lambda: tf.expand_dims(train_masked_logits, 1),
                lambda: train_masked_logits
            )  # convert to (N, 1) if the shape is (N,), otherwise softmax would output wrong values
            # Train accuracy (not very meaningful, since excluding the old recorded responses there is only one class)
            train_pred = tf.argmax(train_masked_logits, 1)
            train_ground_truth = tf.argmax(one_hot_labels_truncated, 1)
            correct_prediction = tf.equal(train_pred, train_ground_truth)
            train_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            train_batch_weights = tf.placeholder(tf.float32, shape=[None])

            reg_weights = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            regularization_loss = FLAGS.weight_decay * tf.add_n(reg_weights)
            '''
            More Settings
            '''
            if FLAGS.use_softmax:
                empirical_loss = tf.losses.softmax_cross_entropy(
                    onehot_labels=one_hot_labels_truncated,
                    logits=train_masked_logits,
                    weights=train_batch_weights)
            else:
                empirical_loss = tf.losses.sigmoid_cross_entropy(
                    multi_class_labels=one_hot_labels_truncated,
                    logits=train_masked_logits,
                    weights=train_batch_weights)

            loss = empirical_loss + regularization_loss
            if FLAGS.use_momentum:
                opt = tf.train.MomentumOptimizer(
                    learning_rate, FLAGS.momentum).minimize(loss,
                                                            global_step=batch)
            else:
                opt = tf.train.GradientDescentOptimizer(
                    learning_rate).minimize(loss, global_step=batch)
            '''
            Define the testing network
            '''
            test_logits, _ = build_cnn(batch_images, False)
            test_masked_logits = tf.gather(test_logits,
                                           tf.squeeze(tf.where(mask_output)),
                                           axis=1)
            test_masked_logits = tf.cond(
                tf.equal(tf.rank(test_masked_logits),
                         1), lambda: tf.expand_dims(test_masked_logits, 1),
                lambda: test_masked_logits)
            test_masked_prob = tf.nn.softmax(test_masked_logits)
            test_pred = tf.argmax(test_masked_logits, 1)
            test_accuracy = tf.placeholder(tf.float32)
            '''
            Copy network (define the copying op)
            '''
            if FLAGS.network_arch == 'resnet':
                all_variables = tf.get_collection(tf.GraphKeys.WEIGHTS)
            else:
                raise Exception('Invalid network architecture')
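            # copy_ops copies the weights of the training network into the testing
            # network; this assumes the WEIGHTS collection lists the ResNet-train
            # variables first and the ResNet-test variables second (they are built
            # in that order above).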
            copy_ops = [
                all_variables[ix + len(all_variables) // 2].assign(var.value())
                for ix, var in enumerate(all_variables[0:len(all_variables) //
                                                       2])
            ]
            '''
            Init certain layers when new classes are added
            '''
            init_ops = tf.no_op()
            if FLAGS.init_strategy == 'all':
                init_ops = tf.global_variables_initializer()
            elif FLAGS.init_strategy == 'last':
                if FLAGS.network_arch == 'resnet':
                    init_vars = [
                        var for var in tf.global_variables()
                        if 'fc' in var.name and 'train' in var.name
                    ]
                init_ops = tf.initialize_variables(init_vars)
            '''
            Create session
            '''
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config, graph=graph_cls)
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver()
        '''
        Summary
        '''
        train_loss_summary = tf.summary.scalar('train_loss', loss)
        train_acc_summary = tf.summary.scalar('train_accuracy', train_accuracy)
        test_acc_summary = tf.summary.scalar('test_accuracy', test_accuracy)

        summary_dir = os.path.join(result_folder, 'summary')
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        train_summary_writer = tf.summary.FileWriter(
            os.path.join(summary_dir, 'train'), sess.graph)
        test_summary_writer = tf.summary.FileWriter(
            os.path.join(summary_dir, 'test'))

        iteration = 0
        '''
        Declaration of other vars
        '''
        # Average accuracy on seen classes
        aver_acc_over_time = dict()
        aver_acc_per_class_over_time = dict()
        conf_mat_over_time = dict()

        # Network mask
        mask_output_val = np.zeros([NUM_CLASSES], dtype=bool)
        mask_output_val_prev = np.zeros([NUM_CLASSES], dtype=bool)
        mask_output_test = np.zeros([NUM_CLASSES], dtype=bool)
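        # mask_output_val masks the logits during training, mask_output_test during
        # testing, and mask_output_val_prev is used when labelling generated samples
        # with the previous model's predictions.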

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    run_config = tf.ConfigProto(gpu_options=gpu_options,
                                allow_soft_placement=True)
    run_config.gpu_options.allow_growth = True
    '''
    Train generative model (WGAN-GP)
    '''
    graph_gen = tf.Graph()
    sess_wgan = tf.Session(config=run_config, graph=graph_gen)

    acwgan_obj = WGAN64x64(sess_wgan,
                           graph_gen,
                           dataset_name=(FLAGS.dataset + '_' +
                                         FLAGS.order_file),
                           mode=FLAGS.mode,
                           batch_size=FLAGS.batch_size,
                           dim=FLAGS.dim,
                           output_dim=FLAGS.output_dim,
                           lambda_param=FLAGS.lambda_param,
                           critic_iters=FLAGS.critic_iters,
                           iters=FLAGS.iters,
                           result_dir=FLAGS.result_dir_cwgan,
                           checkpoint_interval=FLAGS.gan_save_interval,
                           adam_lr=FLAGS.adam_lr,
                           use_decay=FLAGS.use_decay,
                           conditional=FLAGS.conditional,
                           acgan=FLAGS.acgan,
                           acgan_scale=FLAGS.acgan_scale,
                           acgan_scale_g=FLAGS.acgan_scale_g,
                           normalization_g=FLAGS.normalization_g,
                           normalization_d=FLAGS.normalization_d,
                           gen_bs_multiple=FLAGS.gen_bs_multiple,
                           nb_cl=FLAGS.nb_cl,
                           n_gpus=FLAGS.n_gpus)

    test_images, test_labels, test_one_hot_labels, raw_images_test = imagenet_64x64.load_test_data(
    )
    '''
    Class Incremental Learning
    '''
    print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' +
          str(FLAGS.to_class_idx + 1))
    print('Adding %d categories every time' % FLAGS.nb_cl)
    assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0)
    for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1,
                              FLAGS.nb_cl):

        to_category_idx = category_idx + FLAGS.nb_cl - 1
        if FLAGS.nb_cl == 1:
            print('Adding Category ' + str(category_idx + 1))
        else:
            print('Adding Category %d-%d' %
                  (category_idx + 1, to_category_idx + 1))

        # sess_idx starts from 0
        sess_idx = category_idx // FLAGS.nb_cl

        train_x_gan = np.zeros([0, FLAGS.output_dim], dtype=np.uint8)
        train_y_gan = np.zeros([0], dtype=float)
        test_x_gan = np.zeros([0, FLAGS.output_dim], dtype=np.uint8)
        test_y_gan = np.zeros([0], dtype=float)

        if not FLAGS.only_gen_no_cls:
            # train and test data of seen classes
            train_y_one_hot = np.zeros([0, NUM_CLASSES], dtype=np.float32)
            test_x = np.zeros([0, 64, 64, 3], dtype=np.float32)
            test_y = np.zeros([0], dtype=np.float32)

        for category_idx_in_group in range(category_idx, to_category_idx + 1):
            real_category_idx = order[category_idx_in_group]
            real_images_train_cur_cls, raw_images_train_cur_cls = imagenet_64x64.load_train_data(
                real_category_idx, flip=FLAGS.flip)

            # GAN
            train_x_gan = np.concatenate(
                (train_x_gan, raw_images_train_cur_cls))
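            # GAN labels are within-group indices (0 .. nb_cl-1): one conditional
            # generator is trained per group of nb_cl classes.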
            train_y_gan_cur_cls = np.ones([len(raw_images_train_cur_cls)]) * (
                category_idx_in_group % FLAGS.nb_cl)
            train_y_gan = np.concatenate((train_y_gan, train_y_gan_cur_cls))

            if not FLAGS.only_gen_no_cls:
                train_y_one_hot_cur_cls = np.zeros(
                    [len(raw_images_train_cur_cls), NUM_CLASSES])
                train_y_one_hot_cur_cls[:, category_idx_in_group] = np.ones(
                    len(raw_images_train_cur_cls))
                train_y_one_hot = np.concatenate(
                    (train_y_one_hot, train_y_one_hot_cur_cls))

        for category_idx_seen in range(to_category_idx + 1):

            real_category_idx = order[category_idx_seen]
            test_indices_cur_cls = [
                idx for idx in range(len(test_labels))
                if test_labels[idx] == real_category_idx
            ]
            test_x_gan_cur_cls = raw_images_test[test_indices_cur_cls, :]
            test_y_gan_cur_cls = np.ones([len(test_indices_cur_cls)]) * (
                category_idx_seen % FLAGS.nb_cl)
            test_x_gan = np.concatenate((test_x_gan, test_x_gan_cur_cls))
            test_y_gan = np.concatenate((test_y_gan, test_y_gan_cur_cls))

            # Classification network
            if not FLAGS.only_gen_no_cls:
                test_indices_cur_cls = [
                    idx for idx in range(len(test_labels))
                    if test_labels[idx] == real_category_idx
                ]
                test_x_cur_cls = test_images[test_indices_cur_cls, :]
                test_y_cur_cls = np.ones([len(test_indices_cur_cls)
                                          ]) * category_idx_seen

                test_x = np.concatenate((test_x, test_x_cur_cls))
                test_y = np.concatenate((test_y, test_y_cur_cls))
        '''
        Train classification model
        '''
        # No need to train the classifier if there is only one class
        if (to_category_idx > 0
                and not FLAGS.only_gen_no_cls) or not FLAGS.use_softmax:

            # init certain layers
            sess.run(init_ops)

            if FLAGS.no_truncate:
                mask_output_val[:] = True
            else:
                mask_output_val[:to_category_idx + 1] = True

            # Test on all seen classes
            mask_output_test[:to_category_idx + 1] = True
            '''
            Generate samples of old classes
            '''
            train_x = np.copy(train_x_gan)
            if FLAGS.no_truncate:
                train_y_truncated = train_y_one_hot[:, :]
            else:
                train_y_truncated = train_y_one_hot[:, :to_category_idx + 1]

            # Load old class model
            if sess_idx > 0:
                if not acwgan_obj.load(category_idx - 1)[0]:
                    raise Exception(
                        "[!] Train a model first, then run test mode")
                gen_samples_x = np.zeros((0, FLAGS.output_dim), dtype=int)
                for _ in range(category_idx):
                    gen_samples_x_frac, _, _ = acwgan_obj.test(
                        NUM_TRAIN_SAMPLES_PER_CLASS, label=None)
                    gen_samples_x = np.concatenate(
                        (gen_samples_x, gen_samples_x_frac))
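                # category_idx * NUM_TRAIN_SAMPLES_PER_CLASS samples of the old
                # classes are generated (label=None, i.e. without conditioning);
                # their labels are assigned below from the previous classifier's
                # predictions.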

                # import wgan.tflib.save_images
                # wgan.tflib.save_images.save_images(gen_samples_x[:128].reshape((128, 3, 32, 32)),
                #                                    'test.jpg')

                # Get the labels: pseudo-label the generated samples with the previous model's predictions
                gen_samples_y = np.zeros(
                    (len(gen_samples_x), to_category_idx + 1))
                if category_idx == 1:
                    gen_samples_y[:, 0] = np.ones((len(gen_samples_x)))
                else:
                    test_pred_val = []
                    mask_output_val_prev[:category_idx] = True
                    for i in range(0, len(gen_samples_x),
                                   FLAGS.test_batch_size):
                        gen_samples_x_batch = gen_samples_x[i:i + FLAGS.
                                                            test_batch_size]
                        test_pred_val_batch = sess.run(
                            test_pred,
                            feed_dict={
                                batch_images:
                                imagenet_64x64.convert_images(
                                    gen_samples_x_batch),
                                mask_output:
                                mask_output_val_prev
                            })
                        test_pred_val.extend(test_pred_val_batch)
                    for i in range(len(gen_samples_x)):
                        gen_samples_y[i, test_pred_val[i]] = 1

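                # Per-sample loss weights: real samples of the new classes get
                # FLAGS.ratio, generated (replayed) samples get 1 - FLAGS.ratio.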
                train_weights_val = np.concatenate(
                    (np.ones(len(train_x)) * FLAGS.ratio,
                     np.ones(len(gen_samples_x)) * (1 - FLAGS.ratio)))
                train_x = np.concatenate((train_x, gen_samples_x))
                train_y_truncated = np.concatenate(
                    (train_y_truncated, gen_samples_y))
            else:
                train_weights_val = np.ones(len(train_x)) * FLAGS.ratio

            # # DEBUG:
            # train_indices = [idx for idx in range(NUM_SAMPLES_TOTAL) if train_labels[idx] <= category_idx]
            # train_x = raw_images_train[train_indices, :]
            # # Record the response of the new data using the old model(category_idx is consistent with the number of True in mask_output_val_prev)
            # train_y_truncated = train_one_hot_labels[train_indices, :category_idx + 1]

            # Training set
            # Convert the raw images from the data-files to floating-points.
            train_x = imagenet_64x64.convert_images(train_x)

            # Shuffle the indices and create mini-batch
            batch_indices_perm = []

            epoch_idx = 0
            lr = FLAGS.base_lr
            '''
            Training with mixed data
            '''
            while True:
                # Generate mini-batch
                if len(batch_indices_perm) == 0:
                    if epoch_idx >= FLAGS.epochs_per_category:
                        break
                    if epoch_idx in lr_strat:
                        lr /= FLAGS.lr_factor
                        print("NEW LEARNING RATE: %f" % lr)
                    epoch_idx = epoch_idx + 1

                    shuffled_indices = list(range(train_x.shape[0]))
                    np.random.shuffle(shuffled_indices)
                    for i in range(0, len(shuffled_indices),
                                   FLAGS.train_batch_size):
                        batch_indices_perm.append(
                            shuffled_indices[i:i + FLAGS.train_batch_size])
                    batch_indices_perm.reverse()

                popped_batch_idx = batch_indices_perm.pop()

                # Use the random index to select random images and labels.
                train_x_batch = train_x[popped_batch_idx, :, :, :]
                train_y_batch = [
                    train_y_truncated[k] for k in popped_batch_idx
                ]
                train_weights_batch_val = train_weights_val[popped_batch_idx]

                # Train
                train_loss_summary_str, train_acc_summary_str, train_accuracy_val, \
                train_loss_val, train_empirical_loss_val, train_reg_loss_val, _ = sess.run(
                    [train_loss_summary, train_acc_summary, train_accuracy, loss, empirical_loss,
                     regularization_loss, opt], feed_dict={batch_images: train_x_batch,
                                                           one_hot_labels_truncated: train_y_batch,
                                                           mask_output: mask_output_val,
                                                           learning_rate: lr,
                                                           train_batch_weights: train_weights_batch_val})

                # Test
                if iteration % FLAGS.test_interval == 0:
                    sess.run(copy_ops)

                    # Divide and conquer: to avoid allocating too much GPU memory
                    test_pred_val = []
                    for i in range(0, len(test_x), FLAGS.test_batch_size):
                        test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                        test_pred_val_batch = sess.run(test_pred,
                                                       feed_dict={
                                                           batch_images:
                                                           test_x_batch,
                                                           mask_output:
                                                           mask_output_test
                                                       })
                        test_pred_val.extend(test_pred_val_batch)

                    test_accuracy_val = 1. * np.sum(
                        np.equal(test_pred_val, test_y)) / (len(test_pred_val))
                    test_per_class_accuracy_val = np.diag(
                        confusion_matrix(test_y, test_pred_val)) * 2
                    # With 50 test samples per class, multiplying the per-class correct counts by 2 gives the per-class accuracy in percent

                    test_acc_summary_str = sess.run(
                        test_acc_summary,
                        feed_dict={test_accuracy: test_accuracy_val})

                    test_summary_writer.add_summary(test_acc_summary_str,
                                                    iteration)

                    print("TEST: step %d, lr %.4f, accuracy %g" %
                          (iteration, lr, test_accuracy_val))
                    print("PER CLASS ACCURACY: " + " | ".join(
                        str(o) + '%' for o in test_per_class_accuracy_val))

                # Print the training logs
                if iteration % FLAGS.display_interval == 0:
                    train_summary_writer.add_summary(train_loss_summary_str,
                                                     iteration)
                    train_summary_writer.add_summary(train_acc_summary_str,
                                                     iteration)
                    print(
                        "TRAIN: epoch %d, step %d, lr %.4f, accuracy %g, loss %g, empirical %g, reg %g"
                        % (epoch_idx, iteration, lr, train_accuracy_val,
                           train_loss_val, train_empirical_loss_val,
                           train_reg_loss_val))

                iteration = iteration + 1
            '''
            Final test (before the next class is added)
            '''
            sess.run(copy_ops)
            # Divide and conquer: to avoid allocating too much GPU memory
            test_pred_val = []
            for i in range(0, len(test_x), FLAGS.test_batch_size):
                test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                test_pred_val_batch = sess.run(test_pred,
                                               feed_dict={
                                                   batch_images: test_x_batch,
                                                   mask_output:
                                                   mask_output_test
                                               })
                test_pred_val.extend(test_pred_val_batch)

            test_accuracy_val = 1. * np.sum(np.equal(
                test_pred_val, test_y)) / (len(test_pred_val))
            conf_mat = confusion_matrix(test_y, test_pred_val)
            test_per_class_accuracy_val = np.diag(conf_mat)

            # Record and save the cumulative accuracy
            aver_acc_over_time[to_category_idx] = test_accuracy_val
            aver_acc_per_class_over_time[
                to_category_idx] = test_per_class_accuracy_val
            conf_mat_over_time[to_category_idx] = conf_mat

            dump_obj = dict()
            dump_obj['flags'] = flags.FLAGS.__flags
            dump_obj['aver_acc_over_time'] = aver_acc_over_time
            dump_obj[
                'aver_acc_per_class_over_time'] = aver_acc_per_class_over_time
            dump_obj['conf_mat_over_time'] = conf_mat_over_time

            np_file_result = os.path.join(result_folder, 'acc_over_time.pkl')
            with open(np_file_result, 'wb') as file:
                pickle.dump(dump_obj, file)

            visualize_result.vis(np_file_result, 'ImageNetDogs')
        '''
        Train generative model (W-GAN)
        '''
        if acwgan_obj.check_model(to_category_idx):
            print(
                " [*] Model of Class %d-%d exists. Skip the training process" %
                (category_idx + 1, to_category_idx + 1))
        else:
            print(
                " [*] Model of Class %d-%d does not exist. Start the training process"
                % (category_idx + 1, to_category_idx + 1))
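            # Replay for the generator: load the previous group's GAN and add its
            # samples of the old classes to the current GAN training set.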
            acwgan_obj.load(to_category_idx - FLAGS.nb_cl)
            for _ in range(category_idx):
                gen_samples_x, _, _ = acwgan_obj.test(
                    NUM_TRAIN_SAMPLES_PER_CLASS, label=None)
                gen_samples_x = np.uint8(gen_samples_x)
                train_x_gan = np.concatenate((train_x_gan, gen_samples_x))
                train_y_gan = np.concatenate(
                    (train_y_gan, np.zeros(len(gen_samples_x))))
            acwgan_obj.train(train_x_gan, train_y_gan, test_x_gan, test_y_gan,
                             to_category_idx)

    # Save the final model
    if not FLAGS.only_gen_no_cls:
        checkpoint_dir = os.path.join(result_folder, 'checkpoints')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
        sess.close()
Example #2

def main(_):

    assert FLAGS.balanced

    pp.pprint(flags.FLAGS.__flags)

    # Load the class order
    order = []
    with open('cifar-100_%s.txt' % FLAGS.order_file) as file_in:
        for line in file_in.readlines():
            order.append(int(line))
    order = np.array(order)

    assert FLAGS.mode == 'wgan-gp'

    import cifar100
    NUM_CLASSES = 100  # number of classes
    NUM_TRAIN_SAMPLES_PER_CLASS = 500  # number of training samples per class
    NUM_TEST_SAMPLES_PER_CLASS = 100  # number of test samples per class
    train_images, train_labels, train_one_hot_labels, \
        test_images, test_labels, test_one_hot_labels, \
        raw_images_train, raw_images_test, pixel_mean = cifar100.load_data(order, mean_subtraction=True)

    # Number of all training samples
    NUM_TRAIN_SAMPLES_TOTAL = NUM_CLASSES * NUM_TRAIN_SAMPLES_PER_CLASS
    NUM_TEST_SAMPLES_TOTAL = NUM_CLASSES * NUM_TEST_SAMPLES_PER_CLASS

    def build_cnn(inputs, is_training):
        train_or_test = {True: 'train', False: 'test'}
        if FLAGS.network_arch == 'lenet':
            logits, end_points = utils_lenet.lenet(
                inputs,
                num_classes=NUM_CLASSES,
                is_training=is_training,
                use_dropout=FLAGS.use_dropout,
                scope=('LeNet-' + train_or_test[is_training]))
        elif FLAGS.network_arch == 'resnet':
            logits, end_points = utils_resnet.ResNet(
                inputs,
                train_or_test[is_training],
                num_outputs=NUM_CLASSES,
                alpha=0.0,
                n=FLAGS.num_resblocks,
                scope=('ResNet-' + train_or_test[is_training]))
        elif FLAGS.network_arch == 'nin':
            logits, end_points = utils_nin.nin(
                inputs,
                is_training=is_training,
                num_classes=NUM_CLASSES,
                scope=('NIN-' + train_or_test[is_training]))
        else:
            raise Exception('Invalid network architecture')
        return logits, end_points

    '''
    Define variables
    '''
    if not FLAGS.only_gen_no_cls:

        # Save all intermediate results in result_folder
        method_name = '_'.join(
            os.path.basename(__file__).split('.')[0].split('_')[2:])
        method_name += '_gen_%d_and_select' % FLAGS.gen_how_many if FLAGS.gen_more_and_select else ''
        method_name += ('_auto-%.1f-%.1f' % (FLAGS.auto_param1, FLAGS.auto_param2)
                        if FLAGS.auto_choose_num_exemplars else
                        ('_%d' % FLAGS.num_exemplars_per_class
                         if not FLAGS.memory_constrained else ''))
        method_name += '_%s' % FLAGS.exemplar_select_criterion
        method_name += '_%.1f-%.1f' % (FLAGS.proto_weight, FLAGS.gen_weight)
        method_name += '_cache_%d' % FLAGS.cache_size_per_class if FLAGS.use_cache_for_gen_samples else ''
        method_name += '_icarl_%d' % FLAGS.memory_upperbound if FLAGS.memory_constrained else ''
        method_name += '_reorder' if FLAGS.reorder_exemplars else ''
        method_name += '' if FLAGS.label_smoothing == 1. else '_smoothing_%.1f' % FLAGS.label_smoothing

        cls_func = '' if FLAGS.use_softmax else '_sigmoid'
        result_folder = os.path.join(
            FLAGS.result_dir, 'cifar-100_' + FLAGS.order_file,
            'nb_cl_' + str(FLAGS.nb_cl),
            'non_truncated' if FLAGS.no_truncate else 'truncated',
            FLAGS.network_arch +
            ('_%d' %
             FLAGS.num_resblocks if FLAGS.network_arch == 'resnet' else '') +
            cls_func + '_init_' + FLAGS.init_strategy,
            'weight_decay_' + str(FLAGS.weight_decay),
            'base_lr_' + str(FLAGS.base_lr), 'adam_lr_' + str(FLAGS.adam_lr))
        if FLAGS.gan_finetune and 'gan' in method_name:
            result_folder = os.path.join(
                result_folder, method_name + '_finetune_' +
                FLAGS.pretrained_model_sub_dir.replace('/', '_'))
        else:
            result_folder = os.path.join(result_folder, method_name)

        # Add a "_run-i" suffix to the folder name if the folder exists
        if os.path.exists(result_folder):
            temp_i = 2
            while True:
                result_folder_mod = result_folder + '_run-' + str(temp_i)
                if not os.path.exists(result_folder_mod):
                    result_folder = result_folder_mod
                    break
                temp_i += 1
        os.makedirs(result_folder)
        print('Result folder: %s' % result_folder)

        graph_cls = tf.Graph()
        with graph_cls.as_default():
            '''
            Define variables
            '''
            batch_images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
            batch = tf.Variable(0,
                                trainable=False,
                                name='LeNet-train/iteration')
            learning_rate = tf.placeholder(tf.float32, shape=[])
            '''
            Network output mask
            '''
            mask_output = tf.placeholder(tf.bool, shape=[NUM_CLASSES])
            '''
            Old and new ground truth
            '''
            one_hot_labels_truncated = tf.placeholder(tf.float32,
                                                      shape=[None, None])
            '''
            Define the training network
            '''
            train_logits, _ = build_cnn(batch_images, True)
            train_masked_logits = tf.gather(train_logits,
                                            tf.squeeze(tf.where(mask_output)),
                                            axis=1)
            train_masked_logits = tf.cond(
                tf.equal(tf.rank(train_masked_logits),
                         1), lambda: tf.expand_dims(train_masked_logits, 1),
                lambda: train_masked_logits)
            train_pred = tf.argmax(train_masked_logits, 1)
            train_ground_truth = tf.argmax(one_hot_labels_truncated, 1)
            correct_prediction = tf.equal(train_pred, train_ground_truth)
            train_accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
            train_batch_weights = tf.placeholder(tf.float32, shape=[None])
            reg_weights = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            regularization_loss = FLAGS.weight_decay * tf.add_n(reg_weights)
            '''
            More Settings
            '''
            if FLAGS.use_softmax:
                empirical_loss = tf.losses.softmax_cross_entropy(
                    onehot_labels=one_hot_labels_truncated,
                    logits=train_masked_logits,
                    weights=train_batch_weights)
            else:
                empirical_loss = tf.losses.sigmoid_cross_entropy(
                    multi_class_labels=one_hot_labels_truncated,
                    logits=train_masked_logits,
                    weights=train_batch_weights)

            loss = empirical_loss + regularization_loss
            if FLAGS.use_momentum:
                opt = tf.train.MomentumOptimizer(
                    learning_rate, FLAGS.momentum).minimize(loss,
                                                            global_step=batch)
            else:
                opt = tf.train.GradientDescentOptimizer(
                    learning_rate).minimize(loss, global_step=batch)
            '''
            Define the testing network
            '''
            test_logits, _ = build_cnn(batch_images, False)
            test_masked_logits = tf.gather(test_logits,
                                           tf.squeeze(tf.where(mask_output)),
                                           axis=1)
            test_masked_logits = tf.cond(
                tf.equal(tf.rank(test_masked_logits),
                         1), lambda: tf.expand_dims(test_masked_logits, 1),
                lambda: test_masked_logits)
            test_masked_prob = tf.nn.softmax(test_masked_logits)
            test_pred = tf.argmax(test_masked_logits, 1)
            test_accuracy = tf.placeholder(tf.float32)
            '''
            Copy network (define the copying op)
            '''
            if FLAGS.network_arch == 'resnet':
                all_variables = tf.get_collection(tf.GraphKeys.WEIGHTS)
            else:
                all_variables = tf.trainable_variables()
            copy_ops = [
                all_variables[ix + len(all_variables) // 2].assign(var.value())
                for ix, var in enumerate(all_variables[0:len(all_variables) //
                                                       2])
            ]
            '''
            Init certain layers when new classes are added
            '''
            init_ops = tf.no_op()
            if FLAGS.init_strategy == 'all':
                init_ops = tf.global_variables_initializer()
            elif FLAGS.init_strategy == 'last':
                if FLAGS.network_arch == 'lenet':
                    init_vars = [
                        var for var in tf.global_variables()
                        if 'fc4' in var.name and 'train' in var.name
                    ]
                elif FLAGS.network_arch == 'resnet':
                    init_vars = [
                        var for var in tf.global_variables()
                        if 'fc' in var.name and 'train' in var.name
                    ]
                elif FLAGS.network_arch == 'nin':
                    init_vars = [
                        var for var in tf.global_variables()
                        if 'ccp6' in var.name and 'train' in var.name
                    ]
                init_ops = tf.initialize_variables(init_vars)
            '''
            Create session
            '''
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            sess = tf.Session(config=config, graph=graph_cls)
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver()
        '''
        Summary
        '''
        train_loss_summary = tf.summary.scalar('train_loss', loss)
        train_acc_summary = tf.summary.scalar('train_accuracy', train_accuracy)
        test_acc_summary = tf.summary.scalar('test_accuracy', test_accuracy)

        summary_dir = os.path.join(result_folder, 'summary')
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        train_summary_writer = tf.summary.FileWriter(
            os.path.join(summary_dir, 'train'), sess.graph)
        test_summary_writer = tf.summary.FileWriter(
            os.path.join(summary_dir, 'test'))

        iteration = 0
        '''
        Declaration of other vars
        '''
        # Average accuracy on seen classes
        aver_acc_over_time = dict()
        aver_acc_per_class_over_time = dict()
        conf_mat_over_time = dict()

        # Network mask
        mask_output_val = np.zeros([NUM_CLASSES], dtype=bool)
        '''
        Cache (accelerate)
        '''
        cache_dir = os.path.join(result_folder, 'cache')
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        '''
        Exemplars (for ablation study and other purposes)
        '''
        exemplars_dir = os.path.join(result_folder, 'exemplars')
        if not os.path.exists(exemplars_dir):
            os.makedirs(exemplars_dir)
    '''
    Train generative model (WGAN-GP)
    '''
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True
    graph_gen = tf.Graph()
    sess_wgan = tf.Session(config=run_config, graph=graph_gen)

    wgan_obj = GAN(sess_wgan,
                   graph_gen,
                   dataset_name='cifar-100',
                   mode=FLAGS.mode,
                   batch_size=FLAGS.batch_size,
                   dim=FLAGS.dim,
                   output_dim=FLAGS.output_dim,
                   lambda_param=FLAGS.lambda_param,
                   critic_iters=FLAGS.critic_iters,
                   iters=FLAGS.iters,
                   result_dir=FLAGS.result_dir_wgan,
                   checkpoint_interval=FLAGS.gan_save_interval,
                   adam_lr=FLAGS.adam_lr,
                   adam_beta1=FLAGS.adam_beta1,
                   adam_beta2=FLAGS.adam_beta2,
                   finetune=FLAGS.gan_finetune,
                   finetune_from=FLAGS.gan_finetune_from,
                   pretrained_model_base_dir=FLAGS.pretrained_model_base_dir,
                   pretrained_model_sub_dir=FLAGS.pretrained_model_sub_dir)

    exemplars = []
    '''
    Class Incremental Learning
    '''
    print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' +
          str(FLAGS.to_class_idx + 1))
    print('Adding %d categories every time' % FLAGS.nb_cl)
    assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0)
    for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1,
                              FLAGS.nb_cl):

        to_category_idx = category_idx + FLAGS.nb_cl - 1
        if FLAGS.nb_cl == 1:
            print('Adding Category ' + str(category_idx + 1))
        else:
            print('Adding Category %d-%d' %
                  (category_idx + 1, to_category_idx + 1))

        for category_idx_in_group in range(category_idx, to_category_idx + 1):
            # Training set (current category)
            train_indices_gan = [
                idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL)
                if train_labels[idx] == category_idx_in_group
            ]
            test_indices_cur_cls_gan = [
                idx for idx in range(NUM_TEST_SAMPLES_TOTAL)
                if test_labels[idx] == category_idx_in_group
            ]

            train_x_gan = raw_images_train[train_indices_gan, :]
            test_x_cur_cls_gan = raw_images_test[test_indices_cur_cls_gan, :]
            '''
            Train generative model (W-GAN)
            '''
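            # One W-GAN is trained per class in this variant (on that class's images
            # only); models are indexed by the real (unshuffled) class id.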
            real_class_idx = order[category_idx_in_group]
            if wgan_obj.check_model(real_class_idx):
                print(
                    " [*] Model of Class %d exists. Skip the training process"
                    % (real_class_idx + 1))
            else:
                print(
                    " [*] Model of Class %d does not exist. Start the training process"
                    % (real_class_idx + 1))
                wgan_obj.train(train_x_gan, test_x_cur_cls_gan, real_class_idx)
        '''
        Train classification model
        '''
        # Train the classification model (skipped when only the generative model is trained)
        if not FLAGS.only_gen_no_cls:

            if FLAGS.no_truncate:
                mask_output_val[:] = True
            else:
                mask_output_val[:to_category_idx + 1] = True

            if to_category_idx > 0:

                # init certain layers
                sess.run(init_ops)
                '''
                Generate samples of new classes
                '''
                train_indices_new = [
                    idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL)
                    if category_idx <= train_labels[idx] <= to_category_idx
                ]
                train_x_new = raw_images_train[train_indices_new]
                if FLAGS.no_truncate:
                    train_y_truncated_new = train_one_hot_labels[
                        train_indices_new, :]
                else:
                    train_y_truncated_new = train_one_hot_labels[
                        train_indices_new, :to_category_idx + 1]
                train_weights_val_new = np.ones(len(train_x_new))

                train_x = raw_images_train[[], :]
                if FLAGS.no_truncate:
                    train_y_truncated = train_one_hot_labels[[], :]
                else:
                    train_y_truncated = train_one_hot_labels[
                        [], :to_category_idx + 1]
                train_weights_val = np.zeros([0])
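                # Containers for the old-class training data, filled further below
                # from exemplars and generated samples.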

                for new_category_idx in range(category_idx,
                                              to_category_idx + 1):
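                    # Balanced replay (FLAGS.balanced is asserted above): a reading of
                    # the formula below is that each new class keeps the same
                    # generated-to-real ratio as the old classes, which retain
                    # E = len(exemplars[0]) real exemplars per class, hence
                    # N * (N - E) / E generated samples for N real ones.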
                    if len(exemplars) == 0:
                        num_gen_samples_x_needed = 0
                    else:
                        num_gen_samples_x_needed = NUM_TRAIN_SAMPLES_PER_CLASS * (
                            NUM_TRAIN_SAMPLES_PER_CLASS -
                            len(exemplars[0])) // len(exemplars[0])

                    if num_gen_samples_x_needed > 0:
                        gen_samples_x = []
                        packs, last_pack = divmod(num_gen_samples_x_needed,
                                                  500)

                        batch_size_gens = []
                        for _ in range(packs):
                            batch_size_gens.append(500)
                        if last_pack > 0:
                            batch_size_gens.append(last_pack)

                        wgan_obj.load(new_category_idx)
                        for pack_num in batch_size_gens:
                            gen_samples_x_batch, _, _ = wgan_obj.test(pack_num)
                            gen_samples_x.extend(gen_samples_x_batch)

                        train_x_new = np.concatenate(
                            (train_x_new, gen_samples_x))
                        train_weights_val_new = np.concatenate(
                            (train_weights_val_new,
                             np.ones(len(gen_samples_x)) * FLAGS.proto_weight))
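                        # Smoothed one-hot targets: the true class gets
                        # FLAGS.label_smoothing and the remaining probability mass is
                        # spread evenly over the other classes.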
                        if FLAGS.no_truncate:
                            gen_samples_y = np.ones(
                                (num_gen_samples_x_needed, NUM_CLASSES)) * (
                                    (1 - FLAGS.label_smoothing) /
                                    (NUM_CLASSES - 1))
                        else:
                            gen_samples_y = np.ones(
                                (num_gen_samples_x_needed, to_category_idx +
                                 1)) * ((1 - FLAGS.label_smoothing) /
                                        to_category_idx)
                        gen_samples_y[:, new_category_idx] = np.ones(
                            (num_gen_samples_x_needed)) * FLAGS.label_smoothing

                        train_y_truncated_new = np.concatenate(
                            (train_y_truncated_new, gen_samples_y))
                '''
                Generate samples of old classes
                '''

                for old_category_idx in range(0, category_idx):

                    # Load old class model
                    num_gen_samples_x_needed = NUM_TRAIN_SAMPLES_PER_CLASS - len(
                        exemplars[old_category_idx])
                    if num_gen_samples_x_needed > 0:

                        # if FLAGS.use_cache_for_gen_samples:
                        #     cache_file = os.path.join(cache_dir, 'class_%d.npy' % (old_category_idx + 1))
                        #     if os.path.exists(cache_file):
                        #         gen_samples_x = np.load(cache_file)
                        #     else:
                        #         if not wgan_obj.load(old_category_idx)[0]:
                        #             raise Exception("[!] Train a model first, then run test mode")
                        #         gen_samples_x, _, _ = wgan_obj.test(FLAGS.cache_size_per_class)
                        #         np.save(cache_file, gen_samples_x)
                        #
                        #     gen_samples_x_idx = np.random.choice(len(gen_samples_x),
                        #                                          num_gen_samples_x_needed,
                        #                                          replace=False)
                        #     gen_samples_x = gen_samples_x[gen_samples_x_idx]
                        # else:
                        #     if not wgan_obj.load(old_category_idx)[0]:
                        #         raise Exception("[!] Train a model first, then run test mode")
                        #     gen_samples_x, _, _ = wgan_obj.test(num_gen_samples_x_needed)

                        real_class_idx = order[old_category_idx]
                        if not wgan_obj.load(real_class_idx)[0]:
                            raise Exception(
                                "[!] Train a model first, then run test mode")
                        if FLAGS.gen_more_and_select:
                            gen_samples_x_more, _, _ = wgan_obj.test(
                                FLAGS.gen_how_many)
                            gen_samples_x_more_real = cifar100.convert_images(
                                gen_samples_x_more, pixel_mean=pixel_mean)
                            gen_samples_prob = sess.run(
                                test_masked_prob,
                                feed_dict={
                                    batch_images: gen_samples_x_more_real,
                                    mask_output: mask_output_val
                                })
                            gen_samples_scores_cur_cls = gen_samples_prob[:,
                                                                          old_category_idx]
                            top_k_indices = np.argsort(
                                -gen_samples_scores_cur_cls
                            )[:num_gen_samples_x_needed]
                            gen_samples_x = gen_samples_x_more[top_k_indices]
                        else:
                            gen_samples_x, _, _ = wgan_obj.test(
                                num_gen_samples_x_needed)

                        # import wgan.tflib.save_images
                        # wgan.tflib.save_images.save_images(gen_samples_x[:128].reshape((128, 3, 32, 32)),
                        #                                    'test.jpg')
                        train_x = np.concatenate((train_x, gen_samples_x,
                                                  exemplars[old_category_idx]))
                        train_weights_val = np.concatenate(
                            (train_weights_val,
                             np.ones(len(gen_samples_x)) * FLAGS.gen_weight,
                             np.ones(len(exemplars[old_category_idx])) *
                             FLAGS.proto_weight))
                    elif num_gen_samples_x_needed == 0:
                        train_x = np.concatenate(
                            (train_x, exemplars[old_category_idx]))
                        train_weights_val = np.concatenate(
                            (train_weights_val,
                             np.ones(len(exemplars[old_category_idx])) *
                             FLAGS.proto_weight))
                    # if FLAGS.no_truncate:
                    #     gen_samples_y = np.zeros((NUM_TRAIN_SAMPLES_PER_CLASS, NUM_CLASSES))
                    # else:
                    #     gen_samples_y = np.zeros((NUM_TRAIN_SAMPLES_PER_CLASS, to_category_idx+1))
                    # gen_samples_y[:, old_category_idx] = np.ones((NUM_TRAIN_SAMPLES_PER_CLASS))

                    if FLAGS.no_truncate:
                        gen_samples_y = np.ones(
                            (NUM_TRAIN_SAMPLES_PER_CLASS, NUM_CLASSES)) * (
                                (1 - FLAGS.label_smoothing) /
                                (NUM_CLASSES - 1))
                    else:
                        gen_samples_y = np.ones(
                            (NUM_TRAIN_SAMPLES_PER_CLASS,
                             to_category_idx + 1)) * (
                                 (1 - FLAGS.label_smoothing) / to_category_idx)
                    gen_samples_y[:, old_category_idx] = np.ones(
                        (NUM_TRAIN_SAMPLES_PER_CLASS)) * FLAGS.label_smoothing

                    train_y_truncated = np.concatenate(
                        (train_y_truncated, gen_samples_y))

                # Training set
                # Convert the raw images from the data-files to floating-points.
                train_x = cifar100.convert_images(train_x,
                                                  pixel_mean=pixel_mean)
                train_x_new = cifar100.convert_images(train_x_new,
                                                      pixel_mean=pixel_mean)

                # Testing set
                test_indices = [
                    idx for idx in range(len(test_labels))
                    if test_labels[idx] <= to_category_idx
                ]
                test_x = test_images[test_indices]
                test_y = test_labels[test_indices]

                # Shuffle the indices and create mini-batch
                batch_indices_perm = []

                epoch_idx = 0
                lr = FLAGS.base_lr
                '''
                Training with mixed data
                '''
                old_ratio = float(category_idx) / (to_category_idx + 1)
                old_batch_size = int(FLAGS.train_batch_size * old_ratio)
                new_batch_size = FLAGS.train_batch_size - old_batch_size
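                # Each mini-batch is split between old and new data in proportion to
                # the number of old vs. new classes seen so far.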

                while True:
                    # Generate mini-batch
                    if len(batch_indices_perm) == 0:
                        if epoch_idx >= FLAGS.epochs_per_category:
                            break
                        if epoch_idx in lr_strat:
                            lr /= FLAGS.lr_factor
                            print("NEW LEARNING RATE: %f" % lr)
                        epoch_idx = epoch_idx + 1

                        # print('Epoch %d' % epoch_idx)

                        if len(train_x) > 0:
                            shuffled_indices = list(range(train_x.shape[0]))
                            np.random.shuffle(shuffled_indices)
                            for i in range(0, len(shuffled_indices),
                                           old_batch_size):
                                batch_indices_perm.append(
                                    shuffled_indices[i:i + old_batch_size])
                            batch_indices_perm.reverse()
                        elif len(train_x) == 0:
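                            # No old-class data yet: push empty index lists so the
                            # epoch loop still runs, drawing mini-batches only from
                            # the new classes below.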
                            for i in range(0, len(train_x_new),
                                           new_batch_size):
                                batch_indices_perm.append([])

                    popped_batch_idx = batch_indices_perm.pop()

                    # Use the random index to select random images and labels.
                    train_weights_batch_val_old = train_weights_val[
                        popped_batch_idx]
                    train_x_batch_old = train_x[popped_batch_idx, :, :, :]
                    train_y_batch_old = np.array(
                        [train_y_truncated[k] for k in popped_batch_idx])

                    popped_batch_idx_new = np.random.choice(range(
                        len(train_x_new)),
                                                            new_batch_size,
                                                            replace=False)

                    train_weights_batch_val_new = train_weights_val_new[
                        popped_batch_idx_new]
                    train_x_batch_new = train_x_new[
                        popped_batch_idx_new, :, :, :]
                    train_y_batch_new = np.array([
                        train_y_truncated_new[k] for k in popped_batch_idx_new
                    ])

                    if len(train_y_batch_old) == 0:
                        train_y_batch_old.shape = (0,
                                                   train_y_batch_new.shape[1])

                    train_x_batch = np.concatenate(
                        (train_x_batch_old, train_x_batch_new))
                    train_y_batch = np.concatenate(
                        (train_y_batch_old, train_y_batch_new))
                    train_weights_batch_val = np.concatenate(
                        (train_weights_batch_val_old,
                         train_weights_batch_val_new))

                    # Train
                    train_loss_summary_str, train_acc_summary_str, train_accuracy_val, \
                        train_loss_val, train_empirical_loss_val, train_reg_loss_val, _ = sess.run(
                            [train_loss_summary, train_acc_summary, train_accuracy, loss, empirical_loss,
                             regularization_loss, opt], feed_dict={batch_images: train_x_batch,
                                                                   one_hot_labels_truncated: train_y_batch,
                                                                   mask_output: mask_output_val,
                                                                   learning_rate: lr,
                                                                   train_batch_weights: train_weights_batch_val})

                    # Test
                    if iteration % FLAGS.test_interval == 0:
                        sess.run(copy_ops)

                        # Divide and conquer: to avoid allocating too much GPU memory
                        test_pred_val = []
                        for i in range(0, len(test_x), FLAGS.test_batch_size):
                            test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                            test_pred_val_batch = sess.run(test_pred,
                                                           feed_dict={
                                                               batch_images:
                                                               test_x_batch,
                                                               mask_output:
                                                               mask_output_val
                                                           })
                            test_pred_val.extend(test_pred_val_batch)

                        test_accuracy_val = 1. * np.sum(
                            np.equal(test_pred_val,
                                     test_y)) / (len(test_pred_val))
                        test_per_class_accuracy_val = np.diag(
                            confusion_matrix(test_y, test_pred_val))

                        test_acc_summary_str = sess.run(
                            test_acc_summary,
                            feed_dict={test_accuracy: test_accuracy_val})

                        test_summary_writer.add_summary(
                            test_acc_summary_str, iteration)

                        print("TEST: step %d, lr %.4f, accuracy %g" %
                              (iteration, lr, test_accuracy_val))
                        print("PER CLASS ACCURACY: " + " | ".join(
                            str(o) + '%' for o in test_per_class_accuracy_val))

                    # Print the training logs
                    if iteration % FLAGS.display_interval == 0:
                        train_summary_writer.add_summary(
                            train_loss_summary_str, iteration)
                        train_summary_writer.add_summary(
                            train_acc_summary_str, iteration)
                        print(
                            "TRAIN: epoch %d, step %d, lr %.4f, accuracy %g, loss %g, empirical %g, reg %g"
                            % (epoch_idx, iteration, lr, train_accuracy_val,
                               train_loss_val, train_empirical_loss_val,
                               train_reg_loss_val))

                    iteration = iteration + 1
                '''
                Final test (before the next class is added)
                '''
                sess.run(copy_ops)
                # Divide and conquer: to avoid allocating too much GPU memory
                test_pred_val = []
                for i in range(0, len(test_x), FLAGS.test_batch_size):
                    test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                    test_pred_val_batch = sess.run(test_pred,
                                                   feed_dict={
                                                       batch_images:
                                                       test_x_batch,
                                                       mask_output:
                                                       mask_output_val
                                                   })
                    test_pred_val.extend(test_pred_val_batch)

                test_accuracy_val = 1. * np.sum(np.equal(
                    test_pred_val, test_y)) / (len(test_pred_val))
                conf_mat = confusion_matrix(test_y, test_pred_val)
                test_per_class_accuracy_val = np.diag(conf_mat)

                # Record and save the cumulative accuracy
                aver_acc_over_time[to_category_idx] = test_accuracy_val
                aver_acc_per_class_over_time[
                    to_category_idx] = test_per_class_accuracy_val
                conf_mat_over_time[to_category_idx] = conf_mat

                dump_obj = dict()
                dump_obj['flags'] = flags.FLAGS.__flags
                dump_obj['aver_acc_over_time'] = aver_acc_over_time
                dump_obj[
                    'aver_acc_per_class_over_time'] = aver_acc_per_class_over_time
                dump_obj['conf_mat_over_time'] = conf_mat_over_time

                np_file_result = os.path.join(result_folder,
                                              'acc_over_time.pkl')
                with open(np_file_result, 'wb') as file:
                    pickle.dump(dump_obj, file)

                visualize_result.vis(np_file_result)

            # reorder the exemplars
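            # (Re-sorting each old class's stored exemplars by descending confidence
            # means that any later truncation of the memory, which keeps only the
            # first entries, drops the least confident exemplars first.)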
            if FLAGS.reorder_exemplars:
                for old_category_idx in range(category_idx):

                    sess.run(copy_ops)
                    # divide and conquer: to avoid allocating too much GPU memory
                    train_prob_cur_cls_exemplars_val = sess.run(
                        test_masked_prob,
                        feed_dict={
                            batch_images:
                            cifar100.convert_images(
                                exemplars[old_category_idx]),
                            mask_output:
                            mask_output_val
                        })
                    train_prob_cur_cls_exemplars_val = train_prob_cur_cls_exemplars_val[:,
                                                                                        old_category_idx]
                    reorder_indices = np.argsort(
                        -train_prob_cur_cls_exemplars_val)
                    exemplars[old_category_idx] = exemplars[old_category_idx][
                        reorder_indices]

            # select the exemplars
            for category_idx_in_group in range(category_idx,
                                               to_category_idx + 1):
                train_indices_cur_cls = [
                    idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL)
                    if train_labels[idx] == category_idx_in_group
                ]
                train_x_cur_cls = raw_images_train[train_indices_cur_cls]
                train_x_cur_cls_normalized = cifar100.convert_images(
                    train_x_cur_cls, pixel_mean=pixel_mean)
                sess.run(copy_ops)
                # Divide and conquer: to avoid allocating too much GPU memory
                train_prob_cur_cls_val = sess.run(
                    test_masked_prob,
                    feed_dict={
                        batch_images: train_x_cur_cls_normalized,
                        mask_output: mask_output_val
                    })
                train_prob_cur_cls_val = train_prob_cur_cls_val[:,
                                                                category_idx_in_group]
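                # Probability that the test-mode network assigns to the ground-truth
                # class for every training image of this class; the selection below
                # ranks the images by this score (unless the random criterion is used)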

                # decide how many exemplars to keep: with an iCaRL-like fixed memory budget or without one
                if FLAGS.memory_constrained:

                    if FLAGS.auto_choose_num_exemplars:  # auto or fixed number of exemplars
                        # check if we can save all new samples as exemplars
                        if NUM_TRAIN_SAMPLES_PER_CLASS > FLAGS.memory_upperbound - sum(
                            [len(exemplars[i])
                             for i in range(len(exemplars))]):
                            # load the inception scores of all seen classes
                            save_exemplars_ratios = []
                            for i in range(category_idx_in_group + 1):
                                real_class_idx = order[i]
                                inception_score = wgan_obj.load_inception_score(
                                    real_class_idx)
                                save_exemplars_ratio = FLAGS.auto_param1 - FLAGS.auto_param2 * inception_score
                                save_exemplars_ratios.append(
                                    save_exemplars_ratio)

                            save_exemplars_ratios = np.array(
                                save_exemplars_ratios)
                            keep_exemplars_num = np.floor(
                                save_exemplars_ratios *
                                FLAGS.memory_upperbound /
                                sum(save_exemplars_ratios)).astype(int)
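                            # Split the memory budget proportionally to the
                            # ratios: a higher inception score (i.e. better
                            # generated samples for a class) yields a smaller
                            # ratio, so fewer real exemplars are kept for it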
                            for old_category_idx in range(
                                    category_idx_in_group):
                                exemplars[old_category_idx] = exemplars[
                                    old_category_idx][:keep_exemplars_num[
                                        old_category_idx]]
                            num_exemplars_cur_cls = keep_exemplars_num[-1]
                        else:
                            num_exemplars_cur_cls = NUM_TRAIN_SAMPLES_PER_CLASS

                    else:
                        num_exemplars_per_cls = int(
                            FLAGS.memory_upperbound //
                            (category_idx_in_group + 1))
                        num_exemplars_per_cls = min(
                            num_exemplars_per_cls, NUM_TRAIN_SAMPLES_PER_CLASS)
                        # remove redundant elements in the memory for previous classes
                        if category_idx_in_group > 0 and len(
                                exemplars[0]) > num_exemplars_per_cls:
                            for old_category_idx in range(
                                    category_idx_in_group):
                                exemplars[old_category_idx] = exemplars[
                                    old_category_idx][:num_exemplars_per_cls]

                        # number of new exemplars to add to the memory for the current class
                        num_exemplars_cur_cls = num_exemplars_per_cls
                        print(' [*] Store %d exemplars for each class' %
                              num_exemplars_cur_cls)

                else:
                    if FLAGS.auto_choose_num_exemplars:  # auto or fixed number of exemplars
                        real_class_idx = order[category_idx_in_group]
                        inception_score = wgan_obj.load_inception_score(
                            real_class_idx)
                        num_exemplars_cur_cls = int(
                            np.floor(FLAGS.auto_param1 -
                                     FLAGS.auto_param2 * inception_score))
                        print(' [*] Inception score %f, store %d exemplars' %
                              (inception_score, num_exemplars_cur_cls))
                    else:
                        num_exemplars_cur_cls = FLAGS.num_exemplars_per_class

                selected_indices = np.array(range(len(train_prob_cur_cls_val)))
                if FLAGS.exemplar_select_criterion == 'high':
                    selected_indices = train_prob_cur_cls_val.argsort()[:-(
                        num_exemplars_cur_cls + 1):-1]  # the num_exemplars_cur_cls most confident samples
                elif FLAGS.exemplar_select_criterion == 'low':
                    selected_indices = train_prob_cur_cls_val.argsort(
                    )[:num_exemplars_cur_cls]  # the num_exemplars_cur_cls least confident samples
                elif FLAGS.exemplar_select_criterion == 'random':
                    random_idx = list(range(len(train_prob_cur_cls_val)))
                    np.random.shuffle(random_idx)
                    selected_indices = random_idx[:num_exemplars_cur_cls]

                exemplars.append(train_x_cur_cls[selected_indices])

                np_file_exemplars = os.path.join(
                    exemplars_dir,
                    'exemplars_%d' % (category_idx_in_group + 1))
                np.save(np_file_exemplars, exemplars)
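                # `exemplars` is a list with one array of raw (un-normalized) images
                # per class; it is re-saved after every class so that the stored
                # memory can be reloaded by later experiments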

    # Save the final model
    if not FLAGS.only_gen_no_cls:
        checkpoint_dir = os.path.join(result_folder, 'checkpoints')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
        sess.close()
Example #3
def main(_):

    pp.pprint(flags.FLAGS.__flags)

    order = []
    with open('imagenet_64x64_dogs_%s.txt' % FLAGS.order_file) as file_in:
        for line in file_in.readlines():
            order.append(int(line))
    order = np.array(order)

    NUM_CLASSES = 120
    NUM_TEST_SAMPLES_PER_CLASS = 50

    def build_cnn(inputs, is_training):
        train_or_test = {True: 'train', False: 'test'}
        if FLAGS.network_arch == 'resnet':
            logits, end_points = utils_resnet_64x64.ResNet(
                inputs,
                train_or_test[is_training],
                num_outputs=NUM_CLASSES,
                alpha=0.0,
                scope=('ResNet-' + train_or_test[is_training]))
        else:
            raise Exception('Invalid network architecture')
        return logits, end_points

    # Save all intermediate result in the result_folder
    method_name = '_'.join(
        os.path.basename(__file__).split('.')[0].split('_')[4:])
    cls_func = '' if FLAGS.use_softmax else '_sigmoid'
    result_folder = os.path.join(
        FLAGS.result_dir, FLAGS.dataset + ('_flip' if FLAGS.flip else '') +
        '_' + FLAGS.order_file, 'nb_cl_' + str(FLAGS.nb_cl),
        'non_truncated' if FLAGS.no_truncate else 'truncated',
        FLAGS.network_arch + cls_func + '_init_' + FLAGS.init_strategy,
        'weight_decay_' + str(FLAGS.weight_decay),
        'base_lr_' + str(FLAGS.base_lr), method_name)

    # Add a "_run-i" suffix to the folder name if the folder exists
    if os.path.exists(result_folder):
        temp_i = 2
        while True:
            result_folder_mod = result_folder + '_run-' + str(temp_i)
            if not os.path.exists(result_folder_mod):
                result_folder = result_folder_mod
                break
            temp_i += 1
    os.makedirs(result_folder)
    print('Result folder: %s' % result_folder)
    '''
    Define variables
    '''
    batch_images = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
    batch = tf.Variable(0, trainable=False)
    learning_rate = tf.placeholder(tf.float32, shape=[])
    '''
    Network output mask
    '''
    mask_output = tf.placeholder(tf.bool, shape=[NUM_CLASSES])
    '''
    Old and new ground truth
    '''
    one_hot_labels_truncated = tf.placeholder(tf.float32, shape=[None, None])
    '''
    Define the training network
    '''
    train_logits, _ = build_cnn(batch_images, True)
    train_masked_logits = tf.gather(train_logits,
                                    tf.squeeze(tf.where(mask_output)),
                                    axis=1)
    train_masked_logits = tf.cond(
        tf.equal(tf.rank(train_masked_logits),
                 1), lambda: tf.expand_dims(train_masked_logits, 1),
        lambda: train_masked_logits)
    train_pred = tf.argmax(train_masked_logits, 1)
    train_ground_truth = tf.argmax(one_hot_labels_truncated, 1)
    correct_prediction = tf.equal(train_pred, train_ground_truth)
    train_accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    reg_weights = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    regularization_loss = FLAGS.weight_decay * tf.add_n(reg_weights)
    '''
    More Settings
    '''
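    # Loss choice: with use_softmax the masked logits are trained with a mutually
    # exclusive softmax cross-entropy, otherwise every class gets an independent
    # sigmoid (multi-label style) cross-entropy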
    if FLAGS.use_softmax:
        empirical_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels_truncated, logits=train_masked_logits)
    else:
        empirical_loss = tf.losses.sigmoid_cross_entropy(
            multi_class_labels=one_hot_labels_truncated,
            logits=train_masked_logits)

    loss = empirical_loss + regularization_loss
    if FLAGS.use_momentum:
        opt = tf.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum).minimize(loss, global_step=batch)
    else:
        opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, global_step=batch)
    '''
    Define the testing network
    '''
    test_logits, _ = build_cnn(batch_images, False)
    test_masked_logits = tf.gather(test_logits,
                                   tf.squeeze(tf.where(mask_output)),
                                   axis=1)
    test_masked_logits = tf.cond(tf.equal(tf.rank(test_masked_logits), 1),
                                 lambda: tf.expand_dims(test_masked_logits, 1),
                                 lambda: test_masked_logits)
    test_pred = tf.argmax(test_masked_logits, 1)
    test_accuracy = tf.placeholder(tf.float32)
    '''
    Copy network (define the copying op)
    '''
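    # The WEIGHTS collection lists the train-scope variables first (the train
    # network is built before the test network), so copy_ops assigns each
    # training variable's current value to its test-scope counterpart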
    if FLAGS.network_arch == 'resnet':
        all_variables = tf.get_collection(tf.GraphKeys.WEIGHTS)
    else:
        raise Exception('Invalid network architecture')
    copy_ops = [
        all_variables[ix + len(all_variables) // 2].assign(var.value())
        for ix, var in enumerate(all_variables[0:len(all_variables) // 2])
    ]
    '''
    Init certain layers when new classes added
    '''
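    # 'all' restarts every variable from scratch when a new group of classes
    # arrives, 'last' re-initializes only the final fully-connected layer of the
    # train network, and any other value keeps the current weights (tf.no_op)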
    init_ops = tf.no_op()
    if FLAGS.init_strategy == 'all':
        init_ops = tf.global_variables_initializer()
    elif FLAGS.init_strategy == 'last':
        if FLAGS.network_arch == 'resnet':
            init_vars = [
                var for var in tf.global_variables()
                if 'fc' in var.name and 'train' in var.name
            ]
        init_ops = tf.initialize_variables(init_vars)
    '''
    Create session
    '''
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    '''
    Summary
    '''
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    train_acc_summary = tf.summary.scalar('train_accuracy', train_accuracy)
    test_acc_summary = tf.summary.scalar('test_accuracy', test_accuracy)

    summary_dir = os.path.join(result_folder, 'summary')
    if not os.path.exists(summary_dir):
        os.makedirs(summary_dir)
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(summary_dir, 'train'), sess.graph)
    test_summary_writer = tf.summary.FileWriter(
        os.path.join(summary_dir, 'test'))

    iteration = 0
    '''
    Declaration of other vars
    '''
    # Average accuracy on seen classes
    aver_acc_over_time = dict()
    aver_acc_per_class_over_time = dict()
    conf_mat_over_time = dict()

    # Network mask
    mask_output_val = np.zeros([NUM_CLASSES], dtype=bool)
    mask_output_test = np.zeros([NUM_CLASSES], dtype=bool)

    # train and test data of seen classes
    train_x = np.zeros([0, 64, 64, 3], dtype=np.float32)
    train_y = np.zeros([0, NUM_CLASSES], dtype=np.float32)
    test_x = np.zeros([0, 64, 64, 3], dtype=np.float32)
    test_y = np.zeros([0], dtype=np.float32)

    test_images, test_labels, test_one_hot_labels, _ = imagenet_64x64.load_test_data(
    )
    '''
    Class Incremental Learning
    '''
    print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' +
          str(FLAGS.to_class_idx + 1))
    print('Adding %d categories every time' % FLAGS.nb_cl)
    assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0)
    for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1,
                              FLAGS.nb_cl):

        to_category_idx = category_idx + FLAGS.nb_cl - 1
        if FLAGS.nb_cl == 1:
            print('Adding Category ' + str(category_idx + 1))
        else:
            print('Adding Category %d-%d' %
                  (category_idx + 1, to_category_idx + 1))

        if FLAGS.no_truncate:
            mask_output_val[:] = True
        else:
            mask_output_val[:to_category_idx + 1] = True

        # Test on all seen classes
        mask_output_test[:to_category_idx + 1] = True
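        # mask_output_val is what the training loss sees (every class when
        # no_truncate is set, otherwise only the classes seen so far), while
        # mask_output_test always exposes all seen classes for evaluation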

        for category_idx_in_group in range(category_idx, to_category_idx + 1):
            real_category_idx = order[category_idx_in_group]
            real_images_train_cur_cls, _ = imagenet_64x64.load_train_data(
                real_category_idx, flip=FLAGS.flip)
            train_y_cur_cls = np.zeros(
                [len(real_images_train_cur_cls), NUM_CLASSES])
            train_y_cur_cls[:, category_idx_in_group] = np.ones(
                [len(real_images_train_cur_cls)])

            train_x = np.concatenate((train_x, real_images_train_cur_cls))
            train_y = np.concatenate((train_y, train_y_cur_cls))

            test_indices_cur_cls = [
                idx for idx in range(len(test_labels))
                if test_labels[idx] == real_category_idx
            ]
            test_x_cur_cls = test_images[test_indices_cur_cls, :]
            test_y_cur_cls = np.ones([len(test_indices_cur_cls)
                                      ]) * category_idx_in_group

            test_x = np.concatenate((test_x, test_x_cur_cls))
            test_y = np.concatenate((test_y, test_y_cur_cls))

        if FLAGS.no_truncate:
            train_y_truncated = train_y[:, :]
        else:
            train_y_truncated = train_y[:, :to_category_idx + 1]

        # No need to train the classifier if there is only one class
        if to_category_idx > 0 or not FLAGS.use_softmax:

            # init certain layers
            sess.run(init_ops)

            # Shuffle the indices and create mini-batch
            batch_indices_perm = []

            epoch_idx = 0
            lr = FLAGS.base_lr

            while True:
                # Generate mini-batch
                if len(batch_indices_perm) == 0:
                    if epoch_idx >= FLAGS.epochs_per_category:
                        break
                    if epoch_idx in lr_strat:
                        lr /= FLAGS.lr_factor
                        print("NEW LEARNING RATE: %f" % lr)
                    epoch_idx = epoch_idx + 1

                    shuffled_indices = list(range(len(train_x)))
                    np.random.shuffle(shuffled_indices)
                    for i in range(0, len(shuffled_indices),
                                   FLAGS.train_batch_size):
                        batch_indices_perm.append(
                            shuffled_indices[i:i + FLAGS.train_batch_size])
                    batch_indices_perm.reverse()
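                    # The shuffled indices are sliced into batches and reversed so
                    # that pop() below serves them in shuffled order; the queue is
                    # refilled at the start of every epoch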

                popped_batch_idx = batch_indices_perm.pop()

                # Use the random index to select random images and labels.
                train_x_batch = train_x[popped_batch_idx, :, :, :]
                train_y_batch = [
                    train_y_truncated[k] for k in popped_batch_idx
                ]

                # Train
                train_loss_summary_str, train_acc_summary_str, train_accuracy_val, \
                train_loss_val, train_empirical_loss_val, train_reg_loss_val, _ = sess.run(
                    [train_loss_summary, train_acc_summary, train_accuracy, loss, empirical_loss,
                     regularization_loss, opt], feed_dict={batch_images: train_x_batch,
                                                           one_hot_labels_truncated: train_y_batch,
                                                           mask_output: mask_output_val,
                                                           learning_rate: lr})

                # Test
                if iteration % FLAGS.test_interval == 0:
                    sess.run(copy_ops)

                    # Divide and conquer: to avoid allocating too much GPU memory
                    test_pred_val = []
                    for i in range(0, len(test_x), FLAGS.test_batch_size):
                        test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                        test_pred_val_batch = sess.run(test_pred,
                                                       feed_dict={
                                                           batch_images:
                                                           test_x_batch,
                                                           mask_output:
                                                           mask_output_test
                                                       })
                        test_pred_val.extend(test_pred_val_batch)

                    test_accuracy_val = 1. * np.sum(
                        np.equal(test_pred_val, test_y)) / (len(test_pred_val))
                    test_per_class_accuracy_val = np.diag(
                        confusion_matrix(test_y, test_pred_val)) * 2
                    # With 50 test samples per class, the per-class correct counts
                    # are doubled to express per-class accuracy as a percentage

                    test_acc_summary_str = sess.run(
                        test_acc_summary,
                        feed_dict={test_accuracy: test_accuracy_val})

                    test_summary_writer.add_summary(test_acc_summary_str,
                                                    iteration)

                    print("TEST: step %d, lr %.4f, accuracy %g" %
                          (iteration, lr, test_accuracy_val))
                    print("PER CLASS ACCURACY: " + " | ".join(
                        str(o) + '%' for o in test_per_class_accuracy_val))

                # Print the training logs
                if iteration % FLAGS.display_interval == 0:
                    train_summary_writer.add_summary(train_loss_summary_str,
                                                     iteration)
                    train_summary_writer.add_summary(train_acc_summary_str,
                                                     iteration)
                    print(
                        "TRAIN: epoch %d, step %d, lr %.4f, accuracy %g, loss %g, empirical %g, reg %g"
                        % (epoch_idx, iteration, lr, train_accuracy_val,
                           train_loss_val, train_empirical_loss_val,
                           train_reg_loss_val))

                iteration = iteration + 1
            '''
            Final test (before the next class is added)
            '''
            sess.run(copy_ops)
            # Divide and conquer: to avoid allocating too much GPU memory
            test_pred_val = []
            for i in range(0, len(test_x), FLAGS.test_batch_size):
                test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                test_pred_val_batch = sess.run(test_pred,
                                               feed_dict={
                                                   batch_images: test_x_batch,
                                                   mask_output:
                                                   mask_output_test
                                               })
                test_pred_val.extend(test_pred_val_batch)

            test_accuracy_val = 1. * np.sum(np.equal(
                test_pred_val, test_y)) / (len(test_pred_val))
            conf_mat = confusion_matrix(test_y, test_pred_val)
            test_per_class_accuracy_val = np.diag(conf_mat)

            # Record and save the cumulative accuracy
            aver_acc_over_time[to_category_idx] = test_accuracy_val
            aver_acc_per_class_over_time[
                to_category_idx] = test_per_class_accuracy_val
            conf_mat_over_time[to_category_idx] = conf_mat

            dump_obj = dict()
            dump_obj['flags'] = flags.FLAGS.__flags
            dump_obj['aver_acc_over_time'] = aver_acc_over_time
            dump_obj[
                'aver_acc_per_class_over_time'] = aver_acc_per_class_over_time
            dump_obj['conf_mat_over_time'] = conf_mat_over_time

            np_file_result = os.path.join(result_folder, 'acc_over_time.pkl')
            with open(np_file_result, 'wb') as file:
                pickle.dump(dump_obj, file)

            visualize_result.vis(np_file_result, 'ImageNetDogs')

    # Save the final model
    checkpoint_dir = os.path.join(result_folder, 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
    sess.close()
Example #4
def main(_):

    pp.pprint(flags.FLAGS.__flags)

    # Load the class order
    order = []
    with open('cifar-100_%s.txt' % FLAGS.order_file) as file_in:
        for line in file_in.readlines():
            order.append(int(line))
    order = np.array(order)

    import cifar100
    NUM_CLASSES = 100  # number of classes
    NUM_TRAIN_SAMPLES_PER_CLASS = 500  # number of training samples per class
    NUM_TEST_SAMPLES_PER_CLASS = 100  # number of test samples per class
    train_images, train_labels, train_one_hot_labels, \
        test_images, test_labels, test_one_hot_labels, \
        raw_images_train, raw_images_test, pixel_mean = cifar100.load_data(order, mean_subtraction=True)

    # Number of all training samples
    NUM_TRAIN_SAMPLES_TOTAL = NUM_CLASSES * NUM_TRAIN_SAMPLES_PER_CLASS
    NUM_TEST_SAMPLES_TOTAL = NUM_CLASSES * NUM_TEST_SAMPLES_PER_CLASS

    def build_cnn(inputs, is_training):
        train_or_test = {True: 'train', False: 'test'}
        if FLAGS.network_arch == 'lenet':
            logits, end_points = utils_lenet.lenet(
                inputs,
                num_classes=NUM_CLASSES,
                is_training=is_training,
                use_dropout=FLAGS.use_dropout,
                scope=('LeNet-' + train_or_test[is_training]))
        elif FLAGS.network_arch == 'resnet':
            logits, end_points = utils_resnet.ResNet(
                inputs,
                train_or_test[is_training],
                num_outputs=NUM_CLASSES,
                alpha=0.0,
                n=FLAGS.num_resblocks,
                scope=('ResNet-' + train_or_test[is_training]))
        elif FLAGS.network_arch == 'nin':
            logits, end_points = utils_nin.nin(
                inputs,
                is_training=is_training,
                num_classes=NUM_CLASSES,
                scope=('NIN-' + train_or_test[is_training]))
        else:
            raise Exception('Invalid network architecture')
        return logits, end_points

    '''
    Define variables
    '''
    # Save all intermediate result in the result_folder
    method_name = '_'.join(
        os.path.basename(__file__).split('.')[0].split('_')[2:])

    cls_func = '' if FLAGS.use_softmax else '_sigmoid'
    result_base_folder = os.path.join(
        FLAGS.result_dir, 'cifar-100_' + FLAGS.order_file,
        'nb_cl_' + str(FLAGS.nb_cl),
        'non_truncated' if FLAGS.no_truncate else 'truncated',
        FLAGS.network_arch +
        ('_%d' % FLAGS.num_resblocks if FLAGS.network_arch == 'resnet' else '')
        + cls_func + '_init_' + FLAGS.init_strategy,
        'weight_decay_' + str(FLAGS.weight_decay),
        'base_lr_' + str(FLAGS.base_lr), 'adam_lr_' + str(FLAGS.adam_lr))

    result_folder = os.path.join(
        result_base_folder,
        FLAGS.exemplars_base_folder + '_ablation_epoch_based')

    exemplars_folder = os.path.join(result_base_folder,
                                    FLAGS.exemplars_base_folder, 'exemplars')

    if not os.path.exists(exemplars_folder):
        raise Exception('Exemplars folder does not exist: %s' % exemplars_folder)
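    # This ablation variant never selects exemplars itself; it only replays the
    # exemplars saved into `exemplars_folder` by an earlier run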

    # Add a "_run-i" suffix to the folder name if the folder exists
    if os.path.exists(result_folder):
        temp_i = 2
        while True:
            result_folder_mod = result_folder + '_run-' + str(temp_i)
            if not os.path.exists(result_folder_mod):
                result_folder = result_folder_mod
                break
            temp_i += 1
    os.makedirs(result_folder)
    print('Result folder: %s' % result_folder)

    graph_cls = tf.Graph()
    with graph_cls.as_default():
        batch_images = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
        batch = tf.Variable(0, trainable=False, name='LeNet-train/iteration')
        learning_rate = tf.placeholder(tf.float32, shape=[])
        '''
        Network output mask
        '''
        mask_output = tf.placeholder(tf.bool, shape=[NUM_CLASSES])
        '''
        Old and new ground truth
        '''
        one_hot_labels_truncated = tf.placeholder(tf.float32,
                                                  shape=[None, None])
        '''
        Define the training network
        '''
        train_logits, _ = build_cnn(batch_images, True)
        train_masked_logits = tf.gather(train_logits,
                                        tf.squeeze(tf.where(mask_output)),
                                        axis=1)
        train_masked_logits = tf.cond(
            tf.equal(tf.rank(train_masked_logits),
                     1), lambda: tf.expand_dims(train_masked_logits, 1),
            lambda: train_masked_logits)
        train_pred = tf.argmax(train_masked_logits, 1)
        train_ground_truth = tf.argmax(one_hot_labels_truncated, 1)
        correct_prediction = tf.equal(train_pred, train_ground_truth)
        train_accuracy = tf.reduce_mean(tf.cast(correct_prediction,
                                                tf.float32))
        train_batch_weights = tf.placeholder(tf.float32, shape=[None])

        reg_weights = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        regularization_loss = FLAGS.weight_decay * tf.add_n(reg_weights)
        '''
        More Settings
        '''
        if FLAGS.use_softmax:
            empirical_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=one_hot_labels_truncated,
                logits=train_masked_logits,
                weights=train_batch_weights)
        else:
            empirical_loss = tf.losses.sigmoid_cross_entropy(
                multi_class_labels=one_hot_labels_truncated,
                logits=train_masked_logits,
                weights=train_batch_weights)

        loss = empirical_loss + regularization_loss
        if FLAGS.use_momentum:
            opt = tf.train.MomentumOptimizer(
                learning_rate, FLAGS.momentum).minimize(loss,
                                                        global_step=batch)
        else:
            opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(
                loss, global_step=batch)
        '''
        Define the testing network
        '''
        test_logits, _ = build_cnn(batch_images, False)
        test_masked_logits = tf.gather(test_logits,
                                       tf.squeeze(tf.where(mask_output)),
                                       axis=1)
        test_masked_logits = tf.cond(
            tf.equal(tf.rank(test_masked_logits),
                     1), lambda: tf.expand_dims(test_masked_logits, 1),
            lambda: test_masked_logits)
        test_masked_prob = tf.nn.softmax(test_masked_logits)
        test_pred = tf.argmax(test_masked_logits, 1)
        test_accuracy = tf.placeholder(tf.float32)
        '''
        Copy network (define the copying op)
        '''
        if FLAGS.network_arch == 'resnet':
            all_variables = tf.get_collection(tf.GraphKeys.WEIGHTS)
        else:
            all_variables = tf.trainable_variables()
        copy_ops = [
            all_variables[ix + len(all_variables) // 2].assign(var.value())
            for ix, var in enumerate(all_variables[0:len(all_variables) // 2])
        ]
        '''
        Init certain layers when new classes added
        '''
        init_ops = tf.no_op()
        if FLAGS.init_strategy == 'all':
            init_ops = tf.global_variables_initializer()
        elif FLAGS.init_strategy == 'last':
            if FLAGS.network_arch == 'lenet':
                init_vars = [
                    var for var in tf.global_variables()
                    if 'fc4' in var.name and 'train' in var.name
                ]
            elif FLAGS.network_arch == 'resnet':
                init_vars = [
                    var for var in tf.global_variables()
                    if 'fc' in var.name and 'train' in var.name
                ]
            elif FLAGS.network_arch == 'nin':
                init_vars = [
                    var for var in tf.global_variables()
                    if 'ccp6' in var.name and 'train' in var.name
                ]
            init_ops = tf.initialize_variables(init_vars)
        '''
        Create session
        '''
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config, graph=graph_cls)
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
    '''
    Summary
    '''
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    train_acc_summary = tf.summary.scalar('train_accuracy', train_accuracy)
    test_acc_summary = tf.summary.scalar('test_accuracy', test_accuracy)

    summary_dir = os.path.join(result_folder, 'summary')
    if not os.path.exists(summary_dir):
        os.makedirs(summary_dir)
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(summary_dir, 'train'), sess.graph)
    test_summary_writer = tf.summary.FileWriter(
        os.path.join(summary_dir, 'test'))

    iteration = 0
    '''
    Declaration of other vars
    '''
    # Average accuracy on seen classes
    aver_acc_over_time = dict()
    aver_acc_per_class_over_time = dict()
    conf_mat_over_time = dict()

    # Network mask
    mask_output_val = np.zeros([NUM_CLASSES], dtype=bool)
    '''
    Class Incremental Learning
    '''
    print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' +
          str(FLAGS.to_class_idx + 1))
    print('Adding %d categories every time' % FLAGS.nb_cl)
    assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0)
    for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1,
                              FLAGS.nb_cl):

        to_category_idx = category_idx + FLAGS.nb_cl - 1
        if FLAGS.nb_cl == 1:
            print('Adding Category ' + str(category_idx + 1))
        else:
            print('Adding Category %d-%d' %
                  (category_idx + 1, to_category_idx + 1))
        '''
        Train classification model
        '''
        if FLAGS.no_truncate:
            mask_output_val[:] = True
        else:
            mask_output_val[:to_category_idx + 1] = True

        if to_category_idx > 0:

            # init certain layers
            sess.run(init_ops)

            # Training set (current category)
            train_indices = [
                idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL)
                if category_idx <= train_labels[idx] <= to_category_idx
            ]
            train_x = raw_images_train[train_indices]

            if FLAGS.no_truncate:
                train_y_truncated = train_one_hot_labels[train_indices, :]
            else:
                train_y_truncated = train_one_hot_labels[
                    train_indices, :to_category_idx + 1]
            train_weights_val = np.ones(len(train_x))
            '''
            Replay stored exemplars of old classes
            '''
            if category_idx > 0:
                exemplars = np.load(
                    os.path.join(exemplars_folder,
                                 'exemplars_%d.npy' % category_idx))
                for old_category_idx in range(0, category_idx):

                    # Sample NUM_TRAIN_SAMPLES_PER_CLASS exemplar indices of this
                    # old class with replacement (oversampling the stored memory)
                    exemplars_idx_cur_cls = np.random.choice(
                        len(exemplars[old_category_idx]),
                        NUM_TRAIN_SAMPLES_PER_CLASS,
                        replace=True)
                    exemplars_cur_cls = exemplars[old_category_idx][
                        exemplars_idx_cur_cls]

                    train_x = np.concatenate((train_x, exemplars_cur_cls))
                    train_weights_val = np.concatenate(
                        (train_weights_val,
                         np.ones(NUM_TRAIN_SAMPLES_PER_CLASS)))

                    train_y_old_cls = np.zeros(
                        (NUM_TRAIN_SAMPLES_PER_CLASS, to_category_idx + 1))
                    train_y_old_cls[:, old_category_idx] = np.ones(
                        (NUM_TRAIN_SAMPLES_PER_CLASS))
                    train_y_truncated = np.concatenate(
                        (train_y_truncated, train_y_old_cls))
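            # At this point train_x mixes the new-class images with
            # NUM_TRAIN_SAMPLES_PER_CLASS oversampled exemplars per old class, so
            # every seen class contributes the same number of training samples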

            # # DEBUG:
            # train_indices = [idx for idx in range(NUM_SAMPLES_TOTAL) if train_labels[idx] <= category_idx]
            # train_x = raw_images_train[train_indices, :]
            # # Record the response of the new data using the old model(category_idx is consistent with the number of True in mask_output_val_prev)
            # train_y_truncated = train_one_hot_labels[train_indices, :category_idx + 1]

            # Training set
            # Convert the raw images from the data-files to floating-points.
            train_x = cifar100.convert_images(train_x, pixel_mean=pixel_mean)

            # Testing set
            test_indices = [
                idx for idx in range(len(test_labels))
                if test_labels[idx] <= to_category_idx
            ]
            test_x = test_images[test_indices]
            test_y = test_labels[test_indices]

            # Shuffle the indices and create mini-batch
            batch_indices_perm = []

            epoch_idx = 0
            lr = FLAGS.base_lr
            '''
            Training with mixed data
            '''
            while True:
                # Generate mini-batch
                if len(batch_indices_perm) == 0:
                    if epoch_idx >= FLAGS.epochs_per_category:
                        break
                    if epoch_idx in lr_strat:
                        lr /= FLAGS.lr_factor
                        print("NEW LEARNING RATE: %f" % lr)
                    epoch_idx = epoch_idx + 1

                    shuffled_indices = list(range(train_x.shape[0]))
                    np.random.shuffle(shuffled_indices)
                    for i in range(0, len(shuffled_indices),
                                   FLAGS.train_batch_size):
                        batch_indices_perm.append(
                            shuffled_indices[i:i + FLAGS.train_batch_size])
                    batch_indices_perm.reverse()

                popped_batch_idx = batch_indices_perm.pop()

                # Use the random index to select random images and labels.
                train_weights_batch_val = train_weights_val[popped_batch_idx]
                train_y_batch = [
                    train_y_truncated[k] for k in popped_batch_idx
                ]
                train_x_batch = train_x[popped_batch_idx, :, :, :]

                # Train
                train_loss_summary_str, train_acc_summary_str, train_accuracy_val, \
                    train_loss_val, train_empirical_loss_val, train_reg_loss_val, _ = sess.run(
                        [train_loss_summary, train_acc_summary, train_accuracy, loss, empirical_loss,
                            regularization_loss, opt], feed_dict={batch_images: train_x_batch,
                                                                  one_hot_labels_truncated: train_y_batch,
                                                                  mask_output: mask_output_val,
                                                                  learning_rate: lr,
                                                                  train_batch_weights: train_weights_batch_val})

                # Test
                if iteration % FLAGS.test_interval == 0:
                    sess.run(copy_ops)

                    # Divide and conquer: to avoid allocating too much GPU memory
                    test_pred_val = []
                    for i in range(0, len(test_x), FLAGS.test_batch_size):
                        test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                        test_pred_val_batch = sess.run(test_pred,
                                                       feed_dict={
                                                           batch_images:
                                                           test_x_batch,
                                                           mask_output:
                                                           mask_output_val
                                                       })
                        test_pred_val.extend(test_pred_val_batch)

                    test_accuracy_val = 1. * np.sum(
                        np.equal(test_pred_val, test_y)) / (len(test_pred_val))
                    test_per_class_accuracy_val = np.diag(
                        confusion_matrix(test_y, test_pred_val))

                    test_acc_summary_str = sess.run(
                        test_acc_summary,
                        feed_dict={test_accuracy: test_accuracy_val})

                    test_summary_writer.add_summary(test_acc_summary_str,
                                                    iteration)

                    print("TEST: step %d, lr %.4f, accuracy %g" %
                          (iteration, lr, test_accuracy_val))
                    print("PER CLASS ACCURACY: " + " | ".join(
                        str(o) + '%' for o in test_per_class_accuracy_val))

                # Print the training logs
                if iteration % FLAGS.display_interval == 0:
                    train_summary_writer.add_summary(train_loss_summary_str,
                                                     iteration)
                    train_summary_writer.add_summary(train_acc_summary_str,
                                                     iteration)
                    print(
                        "TRAIN: epoch %d, step %d, lr %.4f, accuracy %g, loss %g, empirical %g, reg %g"
                        % (epoch_idx, iteration, lr, train_accuracy_val,
                           train_loss_val, train_empirical_loss_val,
                           train_reg_loss_val))

                iteration = iteration + 1
            '''
            Final test (before the next class is added)
            '''
            sess.run(copy_ops)
            # Divide and conquer: to avoid allocating too much GPU memory
            test_pred_val = []
            for i in range(0, len(test_x), FLAGS.test_batch_size):
                test_x_batch = test_x[i:i + FLAGS.test_batch_size]
                test_pred_val_batch = sess.run(test_pred,
                                               feed_dict={
                                                   batch_images: test_x_batch,
                                                   mask_output: mask_output_val
                                               })
                test_pred_val.extend(test_pred_val_batch)

            test_accuracy_val = 1. * np.sum(np.equal(
                test_pred_val, test_y)) / (len(test_pred_val))
            conf_mat = confusion_matrix(test_y, test_pred_val)
            test_per_class_accuracy_val = np.diag(conf_mat)

            # Record and save the cumulative accuracy
            aver_acc_over_time[to_category_idx] = test_accuracy_val
            aver_acc_per_class_over_time[
                to_category_idx] = test_per_class_accuracy_val
            conf_mat_over_time[to_category_idx] = conf_mat

            dump_obj = dict()
            dump_obj['flags'] = flags.FLAGS.__flags
            dump_obj['aver_acc_over_time'] = aver_acc_over_time
            dump_obj[
                'aver_acc_per_class_over_time'] = aver_acc_per_class_over_time
            dump_obj['conf_mat_over_time'] = conf_mat_over_time

            np_file_result = os.path.join(result_folder, 'acc_over_time.pkl')
            with open(np_file_result, 'wb') as file:
                pickle.dump(dump_obj, file)

            visualize_result.vis(np_file_result)

    # Save the final model
    checkpoint_dir = os.path.join(result_folder, 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
    sess.close()