Ejemplo n.º 1
0
def analyze_experiments(experiment_data_entries, hyperparams_selection=None, plot_seq_acc=True,
                        plot_seq_forgetting=False, save_img_parent_dir=None, img_extention='png', legend_location='top',
                        all_diff_color_force=False, ylim=None, taskcount=10):
    """Run the analysis pipeline: collect data, pad it, optionally plot it, then print statistics.

    :param experiment_data_entries: experiment entries handed to ``collect_dataframe``.
    :param hyperparams_selection: optional selection forwarded to ``collect_dataframe``.
    :param plot_seq_acc: forwarded to ``plot_multigraphs``.
    :param plot_seq_forgetting: forwarded to ``plot_multigraphs``.
    :param save_img_parent_dir: when not None, plots are written under
        ``<dir of this file>/imgs/<save_img_parent_dir>/<img_extention>/``; when None no plot is made.
    :param img_extention: image file extension; also used as the last sub-directory name.
    :param legend_location: forwarded to ``plot_multigraphs``.
    :param all_diff_color_force: forwarded to ``plot_multigraphs``.
    :param ylim: forwarded to ``plot_multigraphs``.
    :param taskcount: forwarded to ``collect_dataframe`` and ``plot_multigraphs``.
    """
    entries, hyperparams_counts, max_task_count = collect_dataframe(
        experiment_data_entries, hyperparams_selection, taskcount)
    pad_dataframe(entries, hyperparams_counts)

    if save_img_parent_dir is not None:
        img_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'imgs',
                               save_img_parent_dir, img_extention)
        utils.create_dir(img_dir)
        # "{}" is left unformatted for plot_multigraphs (presumably filled with a task index); spaces are stripped.
        file_template = (save_img_parent_dir + "_TASK{}." + img_extention).replace(" ", "")
        plot_multigraphs(entries, os.path.join(img_dir, file_template), max_task_count,
                         plot_seq_forgetting=plot_seq_forgetting,
                         plot_seq_acc=plot_seq_acc,
                         legend_location=legend_location,
                         all_diff_color_force=all_diff_color_force,
                         ylim=ylim,
                         taskcount=taskcount)

    print_exp_statistics(entries)
Ejemplo n.º 2
0
def saved_for_eval(eval_loader, result, output_path, epoch=None, vs=False):
    """Save predictions as a JSON results file in the format accepted by the submission server.

    :param eval_loader: loader whose ``dataset.label2ans`` maps predicted label indices to answers.
    :param result: indexable run output. Entries read here: 0 = predicted answer indices,
        2 = question ids and, only when ``vs`` is True, 3 = attention weights,
        4 = spatials, 6 = hints (so at least 7 entries are required in that case).
    :param output_path: directory the results file is written to.
    :param epoch: embedded in the file name when ``config.mode == 'trainval'``.
    :param vs: when True, write into ``<output_path>/vs/`` and additionally dump weights,
        spatials and hints into ``att_weights.h5`` there.
    """
    label2ans = eval_loader.dataset.label2ans
    results = [{'question_id': int(q_id), 'answer': label2ans[a]}
               for a, q_id in zip(result[0], result[2])]

    if config.mode != 'trainval':
        name = 'eval_results.json'
    else:
        name = 'eval_results_epoch{}.json'.format(epoch)

    if vs:
        out_dir = os.path.join(output_path, "vs/")
        utils.create_dir(out_dir)
    else:
        out_dir = output_path

    with open(os.path.join(out_dir, name), 'w') as fd:
        json.dump(results, fd)

    if vs:
        with h5py.File(os.path.join(out_dir, 'att_weights.h5'), 'w') as f:
            f.create_dataset('weights', data=result[3])
            f.create_dataset('spatials', data=result[4])
            f.create_dataset('hints', data=result[6])
Ejemplo n.º 3
0
 def save_hyperparams(self, output_dir, hyperparams):
     """Serialize a hyperparams dictionary into output_dir with torch.save.

     ``output_dir`` is first passed to ``utils.create_dir``; the file name comes from
     ``utils.get_hyperparams_output_filename()``. The dictionary is saved exactly as given
     (no extra statistics are added here).

     :param output_dir: Directory to export the dictionary to.
     :param hyperparams: Dictionary with hyperparams to save.
     """
     utils.create_dir(output_dir)
     target_path = os.path.join(output_dir, utils.get_hyperparams_output_filename())
     torch.save(hyperparams, target_path)
     print("Saved hyperparams to: ", target_path)
Ejemplo n.º 4
0
def _fetch_inat_archive(path, link, extracted_name, description):
    """Download the ``.tar.gz`` at ``link`` into ``path`` and extract it, skipping steps already done.

    :param path: dataset root directory.
    :param link: URL of the archive; its last URL component is the local tar file name.
    :param extracted_name: name (inside ``path``) whose existence means extraction already happened.
    :param description: label used in the progress messages.
    """
    tarname = link.split('/')[-1]
    tarpath = os.path.join(path, tarname)
    if not os.path.exists(tarpath):
        download(path, link)
        print("Succesfully downloaded {} iNaturalist.".format(description))
    if not os.path.exists(os.path.join(path, extracted_name)):
        extract(path, tarpath)
        print("Succesfully extracted {} iNaturalist.".format(description))


def download_dset(path, location="eu"):
    """Download and extract the iNaturalist 2018 train/val images and the train/val annotation JSONs.

    Europe links are used by default; pass another ``location`` for the Asia or
    North America mirrors.

    :param path: dataset root directory (passed to ``utils.create_dir`` first).
    :param location: storage mirror: "eu" (Europe), "asia" or "us" (North America).
    """
    assert location in ["eu", "asia", "us"], "Unknown location: {}".format(location)
    utils.create_dir(path)

    base_url = "https://storage.googleapis.com/inat_data_2018_{}/".format(location)

    # TRAIN/VAL IMAGES: train_val2018.tar.gz -> train_val2018/
    _fetch_inat_archive(path, base_url + "train_val2018.tar.gz", "train_val2018",
                        "train+val dataset")
    # TRAIN JSON: assumed to extract to train2018.json (confirm against the archive contents)
    _fetch_inat_archive(path, base_url + "train2018.json.tar.gz", "train2018.json",
                        "train json")
    # VAL JSON: assumed to extract to val2018.json (confirm against the archive contents)
    _fetch_inat_archive(path, base_url + "val2018.json.tar.gz", "val2018.json",
                        "val json")
Ejemplo n.º 5
0
    def __init__(self, crop=False, create=False):
        """Expose the parent dataset with tasks renumbered according to ``self.task_ordering``.

        After the parent constructor runs, ``self.dataset_root`` is redirected to
        ``<original root>/<self.suffix>``. Inside it, each ``<task>`` (1-based) is a symbolic
        link to ``<original root>/<self.task_ordering[task - 1]>``.

        :param crop: forwarded to the parent constructor.
        :param create: forwarded to the parent constructor.
        """
        super().__init__(crop=crop, create=create)
        self.original_dataset_root = self.dataset_root
        self.dataset_root = os.path.join(self.original_dataset_root, self.suffix)
        utils.create_dir(self.dataset_root)
        print(self.dataset_root)

        # Create symbolic links if non-existing
        for task in range(1, self.task_count + 1):
            src_taskdir = os.path.join(self.original_dataset_root, str(self.task_ordering[task - 1]))
            dst_tasklink = os.path.join(self.dataset_root, str(task))
            # lexists, not exists: exists() is False for a dangling symlink, and
            # os.symlink would then fail with FileExistsError.
            if not os.path.lexists(dst_tasklink):
                os.symlink(src_taskdir, dst_tasklink)
                print("CREATE LINK: {} -> {}".format(dst_tasklink, src_taskdir))
            else:
                print("EXISTING LINK: {} -> {}".format(dst_tasklink, src_taskdir))
Ejemplo n.º 6
0
    def load_chkpt(self, manager):
        """Load checkpoint hyperparams from convergence and restore the framework state.

        :param manager: object providing ``heuristic_exp_dir``, the directory holding the
            hyperparams file (passed to ``utils.create_dir`` first).
        :return: True when the file was loaded and ``self._restore_state`` was called with its
            ``'state'`` entry; False when loading failed (the caller starts from scratch).
        """
        utils.create_dir(manager.heuristic_exp_dir)
        hyperparams_path = os.path.join(
            manager.heuristic_exp_dir, utils.get_hyperparams_output_filename())
        try:
            print("Initiating framework chkpt:{}".format(hyperparams_path))
            chkpt = torch.load(hyperparams_path)
        except Exception as err:  # e.g. file missing/unreadable; Ctrl-C and SystemExit still propagate
            print(
                "CHECKPOINT LOAD FAILED: No state to restore, starting from scratch."
            )
            print("Reason: {}".format(err))
            return False

        self._restore_state(chkpt['state'])
        print("SUCCESSFUL loading framework chkpt:{}".format(hyperparams_path))
        return True
Ejemplo n.º 7
0
def download_dset(path):
    """Download and unzip the Tiny ImageNet (tiny-imagenet-200) archive into ``path``.

    Each step is skipped when its output (``tiny-imagenet-200.zip`` or the
    ``tiny-imagenet-200`` directory) already exists in ``path``.

    :param path: target directory (passed to ``utils.create_dir`` first).
    :raises subprocess.CalledProcessError: if ``wget`` or ``unzip`` exits with a non-zero status.
    """
    utils.create_dir(path)
    zip_path = os.path.join(path, 'tiny-imagenet-200.zip')
    extract_dir = os.path.join(path, 'tiny-imagenet-200')

    if not os.path.exists(zip_path):
        # Argument list without a shell: paths containing spaces or shell metacharacters
        # are passed verbatim, and a failed download raises instead of reporting success.
        subprocess.check_call(
            ["wget", "-P", path, "http://cs231n.stanford.edu/tiny-imagenet-200.zip"])
        print("Succesfully downloaded TinyImgnet dataset.")
    else:
        print("Already downloaded TinyImgnet dataset in {}".format(path))

    if not os.path.exists(extract_dir):
        subprocess.check_call(["unzip", zip_path, "-d", path])
        print("Succesfully extracted TinyImgnet dataset.")
    else:
        print("Already extracted TinyImgnet dataset in {}".format(extract_dir))
Ejemplo n.º 8
0
def main():
    """Train the baseline model and print the results of the best validation epoch.

    Epochs are compared by ``r[4].mean()`` of the evaluation ``run`` on the 'test' loader
    (the name ``f1_val_best`` suggests this is F1 — confirm against ``run``).
    """
    args = parse_args()
    seed_torch(args.seed)
    print("epochs", args.epochs)
    print("lr", args.lr)
    print("optimizer", args.optimizer)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    train_loader = data.get_loader('train', args.batch_size)
    val_loader = data.get_loader('test', args.batch_size)

    output_path = 'saved_models/{}/{}'.format(config.dataset, args.output)
    utils.create_dir(output_path)
    torch.backends.cudnn.benchmark = True

    embeddings = np.load(os.path.join(config.dataroot, 'glove6b_init_300d.npy'))
    constructor = 'build_baseline_with_dl'
    model = getattr(base_model, constructor)(embeddings).cuda()
    model = nn.DataParallel(model).cuda()

    if args.optimizer == 'adadelta':
        # NOTE(review): uses Adadelta's default lr, not args.lr.
        optimizer = optim.Adadelta(model.parameters(), rho=0.95, eps=1e-6, weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        # NOTE(review): lr is hard-coded to 0.01 and ignores args.lr (which is printed above).
        optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08,
                                  weight_decay=args.weight_decay, momentum=0, centered=False)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adamax(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    tracker = utils.Tracker()
    f1_val_best = 0
    best_results = None  # stays None if no epoch beats the initial score (or epochs == 0)
    for epoch in range(args.epochs):
        run(model, train_loader, optimizer, tracker, train=True, prefix='train', epoch=epoch)
        r = run(model, val_loader, optimizer, tracker, train=False, prefix='test', epoch=epoch)
        f1_val = r[4].mean()
        if f1_val > f1_val_best:
            f1_val_best = f1_val
            best_results = r

    if best_results is None:
        print("No epoch improved on the initial best score of 0; nothing to report.")
        return
    utils.print_results(best_results[2], best_results[3])
Ejemplo n.º 9
0
def compute_target(answers_dset, ans2label, name):
    """Augment answers_dset with soft score as label and cache the result as JSON.

    ***answers_dset should be preprocessed***

    For every entry, answers are normalized with ``preprocess_answer`` and counted.
    Answers missing from ``ans2label`` are dropped. Entries whose answers are all
    dropped are still kept, with empty ``labels``/``scores``/``label_counts``.

    :param answers_dset: iterable of dicts with keys 'answers' (list of dicts holding an
        'answer' key), 'question_id', 'question_type' and 'image_id'.
    :param ans2label: dict mapping answer string -> label index.
    :param name: prefix of the cache file ``<config.cache_root>/<name>_target.json``.
    :return: list of target dicts (also written to the cache file).
    """
    target = []
    for ans_entry in answers_dset:
        answer_count = {}
        for answer in ans_entry['answers']:
            answer_ = preprocess_answer(answer['answer'])
            answer_count[answer_] = answer_count.get(answer_, 0) + 1

        labels = []
        scores = []
        label_counts = {}
        for answer, count in answer_count.items():
            if answer not in ans2label:
                continue
            label = ans2label[answer]
            labels.append(label)
            scores.append(get_score(count))
            label_counts[label] = count

        target.append({
            'question_id': ans_entry['question_id'],
            'question_type': ans_entry['question_type'],
            'image_id': ans_entry['image_id'],
            'labels': labels,
            'label_counts': label_counts,
            'scores': scores
        })

    utils.create_dir(config.cache_root)
    cache_file = os.path.join(config.cache_root, name + '_target.json')
    with open(cache_file, 'w') as fd:  # closed deterministically, unlike a bare open()
        json.dump(target, fd)
    return target
def create_ans2label(occurence, name):
    """Assign consecutive label ids to answers and cache both mappings as JSON.

    Writes ``<config.cache_root>/<name>_ans2label.json`` (answer -> label) and
    ``<config.cache_root>/<name>_label2ans.json`` (list indexed by label).

    :param occurence: iterable of answers, e.g. a dict {answer -> whatever}; labels follow
        its iteration order starting at 0.
    :param name: prefix of the output files.
    :return: dict mapping answer -> label.
    """
    label2ans = list(occurence)
    ans2label = {answer: label for label, answer in enumerate(label2ans)}

    utils.create_dir(config.cache_root)

    cache_file = os.path.join(config.cache_root, name + '_ans2label.json')
    with open(cache_file, 'w') as fd:
        json.dump(ans2label, fd)
    cache_file = os.path.join(config.cache_root, name + '_label2ans.json')
    with open(cache_file, 'w') as fd:
        json.dump(label2ans, fd)
    return ans2label
Ejemplo n.º 11
0
def create_train_val_test_imagefolder_dict_joint(dataset_root, img_paths, outfile, no_crop=True):
    """Build and save the train/val/test ImageFolder wrapper dictionary for JOINT training.

    For JOINT training all 10 tasks live in one data folder, so a single dictionary is built
    from ``img_paths[1]`` and written with ``torch.save``.

    :param dataset_root: dataset root directory.
    :param img_paths: indexable of image paths; only entry 1 is used here.
    :param outfile: file name of the saved dictionary.
    :param no_crop: if True the output goes to ``<dataset_root>/no_crop``, else to ``dataset_root``.
    """
    out_dir = os.path.join(dataset_root, "no_crop") if no_crop else dataset_root

    # Standard ImageNet per-channel mean/std (the usual torchvision normalization values).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    dsets = create_train_test_val_imagefolders(img_paths[1], dataset_root, normalize, True, no_crop=no_crop)

    ################ SAVE ##################
    utils.create_dir(out_dir)
    torch.save(dsets, os.path.join(out_dir, outfile))
    split_sizes = [len(dsets[split]) for split in ('train', 'val', 'test')]
    print("JOINT SIZES: train={}, val={}, test={}".format(*split_sizes))
    print("JOINT: Saved dictionary format of train/val/test dataset Imagefolders.")
Ejemplo n.º 12
0
def create_train_val_test_imagefolder_dict(dataset_root, img_paths, task_count, outfile, no_crop=True, transform=False):
    """Build and save one train/val/test ImageFolder wrapper dictionary per task.

    Each task's dictionary is written to ``<out_dir>/<task>/<outfile>``, where ``out_dir`` is
    ``<dataset_root>/no_crop/<task_count>tasks`` if ``no_crop`` else ``<dataset_root>/<task_count>tasks``.

    :param dataset_root: dataset root directory.
    :param img_paths: indexable by task number 1..task_count.
    :param task_count: number of tasks.
    :param outfile: file name of each saved dictionary.
    :param no_crop: selects the output directory (see above) and is forwarded.
    :param transform: forwarded to ``create_train_test_val_imagefolders``.
    """
    if no_crop:
        out_dir = os.path.join(dataset_root, "no_crop", "{}tasks".format(task_count))
    else:
        out_dir = os.path.join(dataset_root, "{}tasks".format(task_count))

    # Standard ImageNet per-channel mean/std. Normalize keeps no per-call state,
    # so a single instance built once (outside the loop) serves every task.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    for task in range(1, task_count + 1):
        print("\nTASK ", task)

        dsets = create_train_test_val_imagefolders(img_paths[task], dataset_root, normalize, transform, no_crop)
        task_dir = os.path.join(out_dir, str(task))
        utils.create_dir(task_dir)
        torch.save(dsets, os.path.join(task_dir, outfile))
        print("SIZES: train={}, val={}, test={}".format(len(dsets['train']), len(dsets['val']),
                                                        len(dsets['test'])))
        print("Saved dictionary format of train/val/test dataset Imagefolders.")
Ejemplo n.º 13
0
def _remove_exp_dirs(dirs):
    """Recursively delete each directory in ``dirs`` that exists, logging every removal."""
    for out_dir in dirs:
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir, ignore_errors=True)
            print("[CLEANUP] removing {}".format(out_dir))


def lr_grid_single_task(args, manager, save_models_mode='keep_none'):
    """
    Finetunes from starting model, acquire best lr and acc. LR gridsearch, with #finetune_iterations per LR.

    For every lr in ``args.lrs``, ``args.finetune_iterations`` runs are trained (or restored from
    ``grid_checkpoint.pth`` in ``manager.ft_parent_exp_dir``). The lr with the highest average
    accuracy wins. The directory of its best single iteration is stored in
    ``manager.best_exp_grid_node_dirname``; any linking to it is presumably done by the caller or
    ``grid_poststep``. Run directories are removed according to the ``StoragePolicy`` built from
    ``save_models_mode``.

    :return: (best_lr, best_acc); best_lr stays None if no lr averages above 0.
    :raises ValueError: if ``args.finetune_iterations`` < 1 (the average accuracy would be undefined).
    """
    if args.finetune_iterations < 1:
        raise ValueError("finetune_iterations must be >= 1, got {}".format(args.finetune_iterations))

    # Init
    manager.store_policy = StoragePolicy(save_models_mode)
    args.task_name = manager.dataset.get_taskname(args.task_counter)
    manager.ft_parent_exp_dir = os.path.join(manager.parent_exp_dir,
                                             'task_' + str(args.task_counter),
                                             'FT_LR_GRIDSEARCH')
    utils.create_dir(manager.ft_parent_exp_dir)
    print("FINETUNE LR GRIDSEARCH: Task ", args.task_name)

    # Logfile
    logfile_parent_dir = os.path.join(manager.ft_parent_exp_dir, 'log')
    utils.create_dir(logfile_parent_dir)
    logfile = os.path.join(logfile_parent_dir,
                           utils.get_now() + '_finetune_grid.log')
    utils.append_to_file(logfile, "FINETUNE GRIDSEARCH LOG: Processed LRs")

    # Load Checkpoint: {lr: {'acc': [acc of iteration 0, 1, ...]}} saved by an earlier run
    processed_lrs = {}
    grid_checkpoint_file = os.path.join(manager.ft_parent_exp_dir,
                                        'grid_checkpoint.pth')
    if os.path.exists(grid_checkpoint_file):
        checkpoint = torch.load(grid_checkpoint_file)
        processed_lrs = checkpoint['processed_lrs']

        print("STARTING FROM CHECKPOINT: ", checkpoint)
        utils.append_to_file(logfile, "STARTING FROM CHECKPOINT")

    ########################################################
    # PRESTEPS
    args.presteps_elapsed_time = 0
    if hasattr(manager.method, 'grid_prestep'):
        manager.method.grid_prestep(args, manager)

    ########################################################
    # LR GRIDSEARCH
    best_acc = 0
    best_lr = None
    manager.best_exp_grid_node_dirname = None
    best_iteration_batch_dirs = []
    for lr in args.lrs:
        print("\n", "<" * 20, "LR ", lr, ">" * 20)
        accum_acc = 0
        best_iteration_dir = None
        best_iteration_acc = 0
        iteration_batch_dirs = []
        if lr not in processed_lrs:
            processed_lrs[lr] = {'acc': []}

        for finetune_iteration in range(0, args.finetune_iterations):
            print("\n", "-" * 20, "FT ITERATION ", finetune_iteration,
                  "-" * 20)
            start_time = time.time()

            # Paths
            exp_grid_node_dirname = "lr=" + str(
                utils.float_to_scientific_str(lr))
            if args.finetune_iterations > 1:
                exp_grid_node_dirname += "_it" + str(finetune_iteration)
            manager.gridsearch_exp_dir = os.path.join(
                manager.ft_parent_exp_dir, exp_grid_node_dirname)
            iteration_batch_dirs.append(manager.gridsearch_exp_dir)

            # Seed with the iteration index for reproducibility (also for restored iterations).
            utils.set_random(finetune_iteration)

            if finetune_iteration < len(processed_lrs[lr]['acc']):
                acc = processed_lrs[lr]['acc'][finetune_iteration]
                print("RESTORING FROM CHECKPOINT: ITERATION = ",
                      finetune_iteration, "ACC = ", acc)
            else:
                # Only actually saved when in save_model mode
                utils.create_dir(manager.gridsearch_exp_dir)

                # TRAIN (only the accuracy is needed here)
                _, acc = manager.method.grid_train(args, manager, lr)

                # Append results
                processed_lrs[lr]['acc'].append(acc)
                msg = "LR = {}, FT Iteration {}/{}, Acc = {}".format(
                    lr, finetune_iteration + 1, args.finetune_iterations, acc)
                print(msg)
                utils.append_to_file(logfile, msg)

            # New best
            if acc > best_iteration_acc:
                if args.finetune_iterations > 1:
                    msg = "=> NEW BEST FT ITERATION {}/{}: (Attempt '{}': Acc '{}' > best attempt Acc '{}')" \
                        .format(finetune_iteration + 1,
                                args.finetune_iterations,
                                finetune_iteration,
                                acc,
                                best_iteration_acc)
                    print(msg)
                    utils.append_to_file(logfile, msg)

                best_iteration_acc = acc
                best_iteration_dir = manager.gridsearch_exp_dir

            accum_acc = accum_acc + acc

            # update logfile/checkpoint
            torch.save({'processed_lrs': processed_lrs}, grid_checkpoint_file)

            # Save iteration hyperparams
            if getattr(manager.method, "grid_chkpt", False):
                it_elapsed_time = time.time() - start_time
                hyperparams = {
                    'val_acc': acc,
                    'lr': lr,
                    'iteration_elapsed_time': it_elapsed_time,
                    'args': vars(args),
                    'manager': vars(manager)
                }
                utils.print_timing(it_elapsed_time, 'TRAIN')
                manager.save_hyperparams(manager.gridsearch_exp_dir,
                                         hyperparams)
        avg_acc = accum_acc / args.finetune_iterations
        print("Done FT iterations\n")
        print("LR AVG ACC = ", avg_acc, ", BEST OF LRs ACC = ", best_acc)

        # New it-avg best
        if avg_acc > best_acc:
            best_lr = lr
            best_acc = avg_acc
            manager.best_exp_grid_node_dirname = best_iteration_dir  # Keep ref to best in all attempts
            print("UPDATE best lr = {}".format(best_lr))
            print("UPDATE best lr acc= {}".format(best_acc))

            utils.append_to_file(logfile,
                                 "UPDATE best lr = {}".format(best_lr))
            utils.append_to_file(logfile,
                                 "UPDATE best lr acc= {}\n".format(best_acc))

            # Clean all from previous best
            if manager.store_policy.only_keep_best:
                _remove_exp_dirs(best_iteration_batch_dirs)
            best_iteration_batch_dirs = iteration_batch_dirs
        elif manager.store_policy.only_keep_best:
            # Not a new best: this lr's runs are not kept
            _remove_exp_dirs(iteration_batch_dirs)
        if manager.store_policy.keep_none:
            _remove_exp_dirs(iteration_batch_dirs)
    print("FINETUNE DONE: best_lr={}, best_acc={}".format(best_lr, best_acc))

    ########################################################
    # POSTPROCESS
    if hasattr(manager.method, 'grid_poststep'):
        manager.method.grid_poststep(args, manager)

    return best_lr, best_acc
Ejemplo n.º 14
0
        name = 'eval_results.json'
    else:
        name = 'eval_results_epoch{}.json'.format(epoch)
    results_path = os.path.join(output_path, name)
    with open(results_path, 'w') as fd:
        json.dump(results, fd)


if __name__ == '__main__':
    # Script entry point: parse args, seed, select GPUs, build data loaders and load embeddings.
    # NOTE(review): this section appears truncated; train_loader, val_loader and embeddings are
    # presumably consumed by code that is not shown here.
    args = parse_args()
    seed_torch(args.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # A test-only run implies resuming: args.resume is forced on.
    if args.test_only:
        args.resume = True
    output_path = 'saved_models/{}/{}'.format(config.type + config.version, args.output)
    utils.create_dir(output_path)
    torch.backends.cudnn.benchmark = True

    ######################################### DATASET PREPARATION #######################################
    # config.mode == 'train': train on 'train', evaluate on 'val'.
    # NOTE(review): in 'train' mode, args.test_only does not change the loaders selected here.
    # Otherwise with test_only: no training loader, evaluate on 'test'.
    # Otherwise: train on 'trainval', evaluate on 'test'.
    if config.mode == 'train':
        train_loader = data.get_loader('train')
        val_loader = data.get_loader('val')
    elif args.test_only:
        train_loader = None
        val_loader = data.get_loader('test')
    else:
        train_loader = data.get_loader('trainval')
        val_loader = data.get_loader('test')

    ######################################### MODEL PREPARATION #######################################
    embeddings = np.load(os.path.join(config.cache_root, 'glove6b_init_300d.npy'))
Ejemplo n.º 15
0
def _build_arg_parser():
    """Return the argparse parser holding every option understood by ``main``."""
    parser = argparse.ArgumentParser(description='Continuum learning')

    parser.add_argument('--task_name', type=str, help='name of the task')
    parser.add_argument('--task_count',
                        type=int,
                        help='count of the task, STARTING FROM 1')
    parser.add_argument('--prev_model_path',
                        type=str,
                        help='path to prev model where to start from')
    parser.add_argument('--save_path',
                        type=str,
                        default='results/',
                        help='save models during and at the end of training')
    parser.add_argument('--n_outputs',
                        type=int,
                        default=200,
                        help='total number of outputs for ALL tasks')
    parser.add_argument('--method',
                        choices=[
                            'gem', 'baseline_rehearsal_full_mem', 'icarl',
                            'baseline_rehearsal_partial_mem'
                        ],
                        type=str,
                        default='gem',
                        help='method to use for train')
    parser.add_argument(
        '--postprocess',
        action="store_true",
        help='Do datamanagement (e.g. update buffers) after task is learned')
    # implemented in separate step to only perform once on THE best model in pipeline
    parser.add_argument('--debug', action="store_true", help='Debug mode')
    # model parameters
    parser.add_argument('--n_hiddens',
                        type=int,
                        default=100,
                        help='number of hidden neurons at each layer')
    parser.add_argument('--n_layers',
                        type=int,
                        default=2,
                        help='number of hidden layers')
    parser.add_argument('--n_inputs',
                        type=int,
                        default=-1,
                        help='number of inputs of the network')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0,
                        help='weight_decay')
    parser.add_argument('--is_scratch_model',
                        action='store_true',
                        help='is this the first task you train for?')

    # memory parameters
    parser.add_argument('--n_memories',
                        type=int,
                        default=0,
                        help='number of memories per task')
    parser.add_argument('--memory_strength',
                        default=0,
                        type=float,
                        help='memory strength (meaning depends on memory)')
    parser.add_argument('--finetune',
                        action="store_true",
                        help='whether to initialize nets in indep. nets')

    # optimizer parameters
    parser.add_argument('--n_epochs',
                        type=int,
                        default=1,
                        help='Number of epochs per task')
    parser.add_argument('--batch_size',
                        type=int,
                        default=70,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        help='SGD learning rate')

    # experiment parameters
    parser.add_argument('--cuda', action="store_true", help='Use GPU?')
    parser.add_argument('--log_every',
                        type=int,
                        default=100,
                        help='frequency of logs, in minibatches')

    # data parameters
    # NOTE(review): positional and required, so sys.argv must supply a value even when
    # overwrite_args later replaces dataset_path.
    parser.add_argument('dataset_path',
                        type=str,
                        help='path to imgfolders of datasets')
    parser.add_argument('--shuffle_tasks',
                        type=str,
                        default='no',
                        help="whether to shuffle the task order ('no' presents tasks in order)")
    parser.add_argument('--n_tasks',
                        type=int,
                        default=10,
                        help='number of tasks')
    return parser


def main(overwrite_args, nc_per_task):
    """
    Train (or only postprocess) the current task for a rehearsal-based method.

    For quick implementation: args are parsed from the command line, then overwritten by
    overwrite_args (if specified). Additional params are passed in the main function.
    The task contains multiple classes (class-incremental in this setup).

    :param overwrite_args: dict {arg name -> value} applied on top of the parsed args.
    :param nc_per_task: array with the amount of classes for each task
    :return: (model, best_val_acc) after training; (None, None) when ``postprocess`` is set,
        in which case the postprocessed model is saved to ``args.save_path``.
    :raises ValueError: for a 'baseline_rehearsal' method other than the full/partial memory variants.
    """
    #####################################
    # ARGS PARSING
    #####################################
    args = _build_arg_parser().parse_known_args()[0]

    # Overwrite with specified method args
    args.nc_per_task = nc_per_task  # Array with nr of outputs (classes) per task
    for key_arg, val_arg in overwrite_args.items():
        setattr(args, key_arg, val_arg)

    # Index starting from 0 (for array)
    args.task_idx = args.task_count - 1
    args.n_exemplars_to_append_per_batch = 0

    if 'baseline_rehearsal' in args.method:
        args.finetune = True
        if args.method == 'baseline_rehearsal_full_mem':
            args.full_mem_mode = True
        elif args.method == 'baseline_rehearsal_partial_mem':
            args.full_mem_mode = False
        else:
            raise ValueError("UNKNOWN BASELINE METHOD: {}".format(args.method))

    # Input checks
    assert args.n_outputs == sum(args.nc_per_task)
    assert args.n_tasks == len(nc_per_task)

    # Baselines from scratch, others from SI model
    if args.task_count == 1 and 'baseline' not in args.method:
        assert 'SI' in args.prev_model_path, "FIRST TASK NOT STARTING FROM SCRATCH, BUT FROM SI: " \
                                             "ONLY STORING WRAPPER WITH EXEMPLARS, path = {}" \
            .format(args.prev_model_path)
        assert args.postprocess, "FIRST TASK WE DO ONLY POSTPROCESSING"

    # Checking for None first gives the assertion message instead of a TypeError from isfile/concatenation.
    assert args.prev_model_path is not None and os.path.isfile(
        args.prev_model_path
    ), "Must specify existing prev_model_path, got: {}".format(args.prev_model_path)

    print("RUNNING WITH ARGS: ", overwrite_args)
    #####################################
    # DATASET
    #####################################
    # load data: We consider 1 task, not class-incremental
    dsets = torch.load(args.dataset_path)
    args.task_imgfolders = dsets
    args.dset_loaders = {
        x:
        torch.utils.data.DataLoader(ImageFolder_Subset_PathRetriever(dsets[x]),
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=8,
                                    pin_memory=True)
        for x in ['train', 'val']
    }
    dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']}

    # Assign random part of batch to exemplars
    if 'baseline' in args.method or args.method == 'icarl':
        if args.method == 'baseline_rehearsal_partial_mem':  # Only trains on exemplar sets, no validation
            # Based on lengths of tr datasets: random sampling, but guaranteed at each batch from all exemplar sets
            n_mem_samples = args.n_memories * args.task_idx
        elif args.method == 'baseline_rehearsal_full_mem' or args.method == 'icarl':
            n_mem_samples = args.n_memories * args.n_tasks

        n_total_samples = float(dset_sizes['train']) + n_mem_samples
        ratio = float(n_mem_samples) / n_total_samples
        if not args.debug:
            args.n_exemplars_to_append_per_batch = int(
                np.ceil(args.batch_size * ratio))  # Ceil: at least 1 per task
        else:
            args.n_exemplars_to_append_per_batch = int(args.batch_size / 2 +
                                                       17)
        args.total_batch_size = args.batch_size
        args.batch_size = args.batch_size - args.n_exemplars_to_append_per_batch

        print("BATCH CONSISTS OF: {} new samples, {} exemplars".format(
            args.batch_size, args.n_exemplars_to_append_per_batch))
        print("mem length = {}, tr_dset_size={}, ratio={}".format(
            float(n_mem_samples), float(dset_sizes['train']), ratio))

    #####################################
    # LOADING MODEL
    #####################################
    if args.is_scratch_model:  # make model with self.net
        assert args.task_idx == 0  # Has to start from 0
        Model = importlib.import_module('model.' + args.method)
        model = Model.Net(args.n_inputs, args.n_outputs, args.n_tasks, args)
    else:
        # load prev saved model
        print("Loading prev model from path: ", args.prev_model_path)
        model = torch.load(args.prev_model_path)
    print("MODEL LOADED")
    model.init_setup(args)

    # Checks
    assert model.n_tasks == args.n_tasks, "model tasks={}, args tasks={}".format(
        model.n_tasks, args.n_tasks)
    assert model.n_outputs == args.n_outputs
    print("MODEL SETUP FOR CURRENT TASK")

    if args.postprocess:
        #####################################
        # POSTPROCESS
        #####################################
        # iCARL: each task.
        # GEM: only the first-task model (SI), since no exemplars were gathered during its training.
        print("POSTPROCESSING")
        model.manage_memory(args.task_idx, args)

        utils.create_dir(os.path.dirname(args.save_path),
                         "Postprocessed model dir")
        torch.save(model, args.save_path)
        print("SAVED POSTPROCESSED MODEL TO: {}".format(args.save_path))
        return None, None

    #####################################
    # TRAIN
    #####################################
    # if there is a checkpoint to be resumed, in case where the training has stopped before on a given task
    resume = os.path.join(args.save_path, 'epoch.pth.tar')
    model, best_val_acc = methods.rehearsal.train_rehearsal.train_model(
        model, args, dset_sizes, resume=resume)
    print("FINISHED TRAINING")

    return model, best_val_acc