Example #1
0
def _awa_hyperparam_grid(imp_method, args, cv_learning_rates):
    """
    Return the ``(synap_stgth_list, learning_rate_list)`` sweep for a method.

    When ``args.online_cross_val`` or ``args.cross_validate_mode`` is set, the
    regularization strength is swept over a method-specific grid and the
    learning rate over ``cv_learning_rates``. Otherwise the single tuned
    values below are used. S-GEM is the exception: outside online
    cross-validation it always uses ``args.learning_rate``.

    Fresh lists are returned on every call, so the choice made for one method
    can never leak into the sweep of the next method.

    Raises:
        ValueError: if no grid is defined for ``imp_method``.
    """
    sweep = args.online_cross_val or args.cross_validate_mode

    if imp_method == 'S-GEM':
        if args.online_cross_val:
            return [0], list(cv_learning_rates)
        return [0], [args.learning_rate]

    # method -> (strength grid when sweeping, (tuned strength, tuned lr))
    grids = {
        'VAN': ([0], (0, 0.003)),
        'PI': ([0.1, 1, 10], (10, 0.003)),
        'EWC': ([0.1, 1, 10, 100], (100, 0.003)),
        'M-EWC': ([0.1, 1, 10, 100], (100, 0.003)),
        'MAS': ([0.1, 1, 10, 100], (0.1, 0.001)),
        'RWALK': ([0.1, 1, 10, 100], (10, 0.003)),     # Check again!
        'A-GEM': ([0], (0, 0.003)),
    }
    if imp_method not in grids:
        raise ValueError("No hyper-parameter grid defined for importance measure %s!" % (imp_method))

    sweep_stgths, (tuned_stgth, tuned_lr) = grids[imp_method]
    if sweep:
        return list(sweep_stgths), list(cv_learning_rates)
    return [tuned_stgth], [tuned_lr]


def main():
    """
    Create the model and start the training on the split AWA benchmark.

    Parses and validates the command-line arguments and loads the split AWA
    dataset. Then, for every (importance method, synaptic strength, learning
    rate) combination chosen by ``_awa_hyperparam_grid``, it does the
    following:

    - builds a fresh TF graph;
    - trains it with ``train_task_sequence``;
    - stores the result either in a cross-validation text file or as
      experiment snapshots in ``args.log_dir``.

    Raises:
        ValueError: if the architecture, importance method or optimizer is not
            supported, or if no hyper-parameter grid exists for a method.
    """

    # Get the CL arguments
    args = get_arguments()

    # Initialize the random seed of numpy
    np.random.seed(args.random_seed)

    # Check if the network architecture is valid
    if args.arch not in VALID_ARCHS:
        raise ValueError("Network architecture %s is not supported!"%(args.arch))

    # Check if the method to compute importance is valid
    if args.imp_method not in MODELS:
        raise ValueError("Importance measure %s is undefined!"%(args.imp_method))

    # Check if the optimizer is valid
    if args.optim not in VALID_OPTIMS:
        raise ValueError("Optimizer %s is undefined!"%(args.optim))

    # Create log directories to store the results
    if not os.path.exists(args.log_dir):
        print('Log directory %s created!'%(args.log_dir))
        os.makedirs(args.log_dir)

    if args.online_cross_val:
        num_tasks = K_FOR_CROSS_VAL
    else:
        num_tasks = NUM_TASKS - K_FOR_CROSS_VAL

    # Load the split AWA dataset
    data_labs = [np.arange(TOTAL_CLASSES)]
    datasets, AWA_attr = construct_split_awa(data_labs, args.data_dir, AWA_TRAIN_LIST, AWA_VAL_LIST, AWA_TEST_LIST, IMG_HEIGHT, IMG_WIDTH, attr_file=AWA_ATTR_LIST)
    # Zero the attribute rows outside the task split in use. Rows are
    # presumably indexed by class -- verify against construct_split_awa.
    if args.online_cross_val:
        AWA_attr[K_FOR_CROSS_VAL*CLASSES_PER_TASK:] = 0
    else:
        AWA_attr[:K_FOR_CROSS_VAL*CLASSES_PER_TASK] = 0

    print('Attributes: {}'.format(np.sum(AWA_attr, axis=1)))

    # Learning-rate grid used whenever hyper-parameters are being swept
    cv_learning_rates = [0.1, 0.03, 0.01, 0.001, 0.0003]

    if args.cross_validate_mode:
        models_list = MODELS
    else:
        models_list = [args.imp_method]

    for imp_method in models_list:
        # Per-method grids; fresh lists so one method cannot clobber another's
        synap_stgth_list, learning_rate_list = _awa_hyperparam_grid(imp_method, args, cv_learning_rates)

        for synap_stgth in synap_stgth_list:
            for lr in learning_rate_list:
                # Generate the experiment key and store the meta data in a file
                exper_meta_data = {'ARCH': args.arch,
                    'DATASET': 'SPLIT_AWA',
                    'HYBRID': args.set_hybrid,
                    'NUM_RUNS': args.num_runs,
                    'EVAL_SINGLE_HEAD': args.eval_single_head,
                    'TRAIN_SINGLE_EPOCH': args.train_single_epoch,
                    'IMP_METHOD': imp_method,
                    'SYNAP_STGTH': synap_stgth,
                    'FISHER_EMA_DECAY': args.fisher_ema_decay,
                    'FISHER_UPDATE_AFTER': args.fisher_update_after,
                    'OPTIM': args.optim,
                    'LR': lr,
                    'BATCH_SIZE': args.batch_size,
                    'EPS_MEMORY': args.do_sampling,
                    'MEM_SIZE': args.mem_size,
                    'IS_HERDING': args.is_herding}
                experiment_id = "SPLIT_AWA_HERDING_%r_HYB_%r_%s_%r_%r_%s_%s_%s_%r_%s-"%(args.is_herding, args.set_hybrid, args.arch, args.eval_single_head, args.train_single_epoch, imp_method,
                        str(synap_stgth).replace('.', '_'),
                        str(args.batch_size), args.do_sampling, str(args.mem_size)) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M")
                snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data)

                # Reset the default graph
                tf.reset_default_graph()
                graph = tf.Graph()
                with graph.as_default():

                    # Set the random seed
                    tf.set_random_seed(args.random_seed)

                    # Define Input and Output of the model
                    x = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
                    y_ = tf.placeholder(tf.float32, shape=[None, num_tasks*TOTAL_CLASSES])
                    attr = tf.placeholder(tf.float32, shape=[num_tasks*TOTAL_CLASSES, ATTR_DIMS])

                    if not args.train_single_epoch:
                        # Define ops for data augmentation
                        x_aug = image_scaling(x)
                        x_aug = random_crop_and_pad_image(x_aug, IMG_HEIGHT, IMG_WIDTH)

                    # Define the optimizer
                    if args.optim == 'ADAM':
                        opt = tf.train.AdamOptimizer(learning_rate=lr)

                    elif args.optim == 'SGD':
                        opt = tf.train.GradientDescentOptimizer(learning_rate=lr)

                    elif args.optim == 'MOMENTUM':
                        # Constant learning rate. The former decay schedule
                        # referenced undefined names (train_step,
                        # training_iters) and its result was never used.
                        opt = tf.train.MomentumOptimizer(lr, OPT_MOMENTUM)

                    # Create the Model/ construct the graph
                    if args.train_single_epoch:
                        # When training using a single epoch then there is no need for data augmentation
                        model = Model(x, y_, num_tasks, opt, imp_method, synap_stgth, args.fisher_update_after,
                                args.fisher_ema_decay, network_arch=args.arch, is_ATT_DATASET=True, attr=attr)
                    else:
                        model = Model(x_aug, y_, num_tasks, opt, imp_method, synap_stgth, args.fisher_update_after,
                                args.fisher_ema_decay, network_arch=args.arch, is_ATT_DATASET=True, x_test=x, attr=attr)

                    # Set up tf session and initialize variables.
                    config = tf.ConfigProto()
                    config.gpu_options.allow_growth = True

                    time_start = time.time()
                    # The context manager closes the session on exit
                    with tf.Session(config=config, graph=graph) as sess:
                        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=100)
                        runs, task_labels_dataset = train_task_sequence(model, sess, saver, datasets, AWA_attr, CLASSES_PER_TASK, args.cross_validate_mode,
                                args.train_single_epoch, args.eval_single_head, args.do_sampling, args.is_herding, args.mem_size*CLASSES_PER_TASK*num_tasks, args.train_iters,
                                args.batch_size, args.num_runs, args.init_checkpoint, args.online_cross_val, args.random_seed)
                    time_end = time.time()
                    time_spent = time_end - time_start
                    print('Time spent: {}'.format(time_spent))

                # Clean up
                del model

                if args.cross_validate_mode:
                    # If cross-validation flag is enabled, store the stuff in a text file
                    cross_validate_dump_file = os.path.join(args.log_dir, 'SPLIT_AWA_HYBRID_%s_%s'%(imp_method, args.optim) + '.txt')
                    with open(cross_validate_dump_file, 'a') as f:
                        f.write('HERDING: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.format(args.is_herding, args.arch, lr, synap_stgth, runs))
                else:
                    # Store all the results in one dictionary to process later
                    exper_acc = dict(mean=runs)
                    exper_labels = dict(labels=task_labels_dataset)
                    # Store the experiment output to a file
                    snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc)
                    snapshot_task_labels(args.log_dir, experiment_id, exper_labels)
def main():
    """
    Create the model and start the training on the permuted MNIST benchmark.

    The steps are:

    - parse and validate the command-line arguments;
    - snapshot the experiment meta data;
    - build a single TF graph for ``args.imp_method`` and train it with
      ``train_task_sequence``;
    - store the results in ``args.log_dir``, plus a summary line in a text
      file when ``args.cross_validate_mode`` is set.

    Raises:
        ValueError: if the architecture, importance method or optimizer is not
            supported.
    """

    # Get the CL arguments
    args = get_arguments()

    # Check if the network architecture is valid
    if args.arch not in VALID_ARCHS:
        raise ValueError("Network architecture %s is not supported!" %
                         (args.arch))

    # Check if the method to compute importance is valid
    if args.imp_method not in MODELS:
        raise ValueError("Importance measure %s is undefined!" %
                         (args.imp_method))

    # Check if the optimizer is valid
    if args.optim not in VALID_OPTIMS:
        raise ValueError("Optimizer %s is undefined!" % (args.optim))

    # Create log directories to store the results
    if not os.path.exists(args.log_dir):
        print('Log directory %s created!' % (args.log_dir))
        os.makedirs(args.log_dir)

    # Generate the experiment key and store the meta data in a file
    exper_meta_data = {
        'DATASET': 'PERMUTE_MNIST',
        'NUM_RUNS': args.num_runs,
        'TRAIN_SINGLE_EPOCH': args.train_single_epoch,
        'IMP_METHOD': args.imp_method,
        'SYNAP_STGTH': args.synap_stgth,
        'FISHER_EMA_DECAY': args.fisher_ema_decay,
        'FISHER_UPDATE_AFTER': args.fisher_update_after,
        'OPTIM': args.optim,
        'LR': args.learning_rate,
        'BATCH_SIZE': args.batch_size,
        'MEM_SIZE': args.mem_size
    }
    # Every field is an already-stringified value, so use %s throughout.
    # %r on str(batch_size) used to embed literal quotes in the id/filenames.
    experiment_id = "PERMUTE_MNIST_META_%s_%s_%s_%s-" % (
        args.imp_method, str(args.synap_stgth).replace('.', '_'),
        str(args.batch_size), str(args.mem_size)
    ) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M")
    snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data)

    # Get the subset of data depending on training or cross-validation mode
    if args.online_cross_val:
        num_tasks = K_FOR_CROSS_VAL
    else:
        num_tasks = args.num_tasks - K_FOR_CROSS_VAL

    # Reset the default graph
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():

        # Set the random seed
        tf.set_random_seed(args.random_seed)

        # Define Input and Output of the model
        x = tf.placeholder(tf.float32, shape=[None, INPUT_FEATURE_SIZE])
        # The learning rate is a scalar placeholder shared by all optimizers
        learning_rate = tf.placeholder(dtype=tf.float32, shape=())
        if args.imp_method == 'PNN':
            # PNN gets one label placeholder per task
            y_ = [tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES])
                  for _ in range(num_tasks)]
        else:
            y_ = tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES])

        # Define the optimizer
        if args.optim == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

        elif args.optim == 'SGD':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)

        elif args.optim == 'MOMENTUM':
            # No decay schedule: the rate comes from the placeholder above
            opt = tf.train.MomentumOptimizer(learning_rate, OPT_MOMENTUM)

        # Create the Model/ construct the graph
        model = Model(x,
                      y_,
                      num_tasks,
                      opt,
                      args.imp_method,
                      args.synap_stgth,
                      args.fisher_update_after,
                      args.fisher_ema_decay,
                      learning_rate,
                      network_arch=args.arch)

        # Set up tf session and initialize variables.
        if USE_GPU:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
        else:
            config = tf.ConfigProto(device_count={'GPU': 0})

        time_start = time.time()
        # The context manager closes the session on exit
        with tf.Session(config=config, graph=graph) as sess:
            runs = train_task_sequence(model, sess, args)
        time_end = time.time()
        time_spent = time_end - time_start

    # Store all the results in one dictionary to process later
    exper_acc = dict(mean=runs)

    # If cross-validation flag is enabled, store the stuff in a text file
    if args.cross_validate_mode:
        acc_mean, acc_std = average_acc_stats_across_runs(
            runs, model.imp_method)
        fgt_mean, fgt_std = average_fgt_stats_across_runs(
            runs, model.imp_method)
        cross_validate_dump_file = os.path.join(
            args.log_dir,
            'PERMUTE_MNIST_%s_%s' % (args.imp_method, args.optim) + '.txt')
        with open(cross_validate_dump_file, 'a') as f:
            if MULTI_TASK:
                f.write(
                    'GPU:{} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.
                    format(USE_GPU, args.arch, args.learning_rate,
                           args.synap_stgth, acc_mean[-1, :].mean()))
            else:
                f.write(
                    'ORTHO:{} \t NUM_TASKS: {} \t EXAMPLES_PER_TASK: {} \t MEM_SIZE: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {} (+-{})\t Fgt: {} (+-{})\t QR: {}\t Time: {}\n'
                    .format(args.maintain_orthogonality, args.num_tasks,
                            args.examples_per_task, args.mem_size, args.arch,
                            args.learning_rate, args.synap_stgth, acc_mean,
                            acc_std, fgt_mean, fgt_std, QR, str(time_spent)))

    # Store the experiment output to a file
    snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc)