Example 1
        lines.append(ave_ICARL)
        line_names.append("iCaRL")
        colors.append("brown")
        if args.n_seeds > 1:
            errors.append(sem_ICARL)

    # -plot
    figure = visual_plt.plot_lines(
        lines,
        x_axes=budget_list,
        ylabel="average precision (after all tasks)",
        title=title,
        x_log=True,
        ylim=y_lim,
        line_names=line_names,
        xlabel="Total memory budget",
        with_dots=True,
        colors=colors,
        list_with_errors=errors,
        h_lines=h_lines,
        h_errors=h_errors,
        h_labels=h_labels,
        h_colors=h_colors,
    )
    figure_list.append(figure)

    # add figures to pdf
    for figure in figure_list:
        pp.savefig(figure)

    # close the pdf
    pp.close()
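
The visual_plt.plot_lines helper called above belongs to this codebase and is not shown in these excerpts. A minimal matplotlib-based sketch of the interface the call relies on (assumed behavior only, not the actual implementation; [h_errors], the uncertainty on the baselines, is accepted but ignored for brevity):

import matplotlib.pyplot as plt

def plot_lines(lines, x_axes=None, line_names=None, colors=None, title=None,
               xlabel=None, ylabel=None, ylim=None, x_log=False, with_dots=False,
               list_with_errors=None, h_lines=None, h_errors=None, h_labels=None,
               h_colors=None):
    figure, axes = plt.subplots()
    for i, line in enumerate(lines):
        # one curve per entry of [lines], with optional error bars
        axes.errorbar(x_axes, line,
                      yerr=list_with_errors[i] if list_with_errors else None,
                      label=line_names[i] if line_names else None,
                      color=colors[i] if colors else None,
                      marker='o' if with_dots else None)
    if h_lines is not None:
        # horizontal baselines (e.g., upper-bound results of offline training)
        for j, h in enumerate(h_lines):
            axes.axhline(h, label=h_labels[j] if h_labels else None,
                         color=h_colors[j] if h_colors else None,
                         linestyle='dashed')
    if x_log:
        axes.set_xscale('log')
    if ylim is not None:
        axes.set_ylim(ylim)
    axes.set_title(title)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.legend()
    return figure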
Example 2
def run(args):

    # Set default arguments
    args.g_fc_lay = args.fc_lay if args.g_fc_lay is None else args.g_fc_lay
    args.g_fc_uni = args.fc_units if args.g_fc_uni is None else args.g_fc_uni
    args.g_iters = args.iters if args.g_iters is None else args.g_iters
    # -if [log_per_task], reset all logs
    if args.log_per_task:
        args.prec_log = args.iters
        args.loss_log = args.iters
        args.sample_log = args.iters
    # -if XdG is selected but not the incremental task learning scenario, give error
    if (not args.scenario == "task") and args.gating_prop > 0:
        raise ValueError(
            "'XdG' only works for the incremental task learning scenario.")
    # -if EWC, SI or XdG is selected together with 'feedback', give error
    if args.feedback and (args.ewc or args.si or args.gating_prop > 0):
        raise NotImplementedError(
            "EWC, SI and XdG are not supported with feedback connections.")
    # -if XdG is selected together with replay of any kind, give error
    if args.gating_prop > 0 and (not args.replay == "none"):
        raise NotImplementedError(
            "XdG is not supported with '{}' replay.".format(args.replay))
    # -create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    (train_datasets,
     test_datasets), config, classes_per_task = get_multitask_experiment(
         name=args.experiment,
         scenario=args.scenario,
         tasks=args.tasks,
         data_dir=args.d_dir,
         verbose=True,
         exception=(args.seed == 0),
     )

    #-------------------------------------------------------------------------------------------------#

    #------------------------------#
    #----- MODEL (CLASSIFIER) -----#
    #------------------------------#

    # Define main model (i.e., the classifier; with feedback connections if requested)
    if args.feedback:
        model = AutoEncoder(
            image_size=config['size'],
            image_channels=config['channels'],
            classes=config['classes'],
            fc_layers=args.fc_lay,
            fc_units=args.fc_units,
            z_dim=args.z_dim,
            fc_drop=args.fc_drop,
            fc_bn=(args.fc_bn == "yes"),
            fc_nl=args.fc_nl,
        ).to(device)
        model.lamda_pl = 1.  #--> so that this VAE is also trained to classify
    else:
        model = Classifier(
            image_size=config['size'],
            image_channels=config['channels'],
            classes=config['classes'],
            fc_layers=args.fc_lay,
            fc_units=args.fc_units,
            fc_drop=args.fc_drop,
            fc_nl=args.fc_nl,
            fc_bn=(args.fc_bn == "yes"),
            excit_buffer=(args.gating_prop > 0),
        ).to(device)

    # Define optimizer (only include parameters that require gradients)
    model.optim_list = [{
        'params': filter(lambda p: p.requires_grad, model.parameters()),
        'lr': args.lr
    }]
    model.optim_type = args.optimizer
    if model.optim_type in ("adam", "adam_reset"):
        model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))
    elif model.optim_type == "sgd":
        model.optimizer = optim.SGD(model.optim_list)
    else:
        raise ValueError(
            "Unrecognized optimizer, '{}' is not currently a valid option".
            format(args.optimizer))

    # Set loss-function for reconstruction
    if args.feedback:
        model.recon_criterion = nn.BCELoss(reduction='mean')

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- CL-STRATEGY: ALLOCATION -----#
    #-----------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        model.fisher_n = args.fisher_n
        model.gamma = args.gamma
        model.online = args.online
        model.emp_FI = args.emp_fi

    # Synaptic Intelligence (SI)
    if isinstance(model, ContinualLearner):
        model.si_c = args.si_c if args.si else 0
        model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and args.gating_prop > 0:
        mask_dict = {}
        excit_buffer_list = []
        for task_id in range(args.tasks):
            mask_dict[task_id + 1] = {}
            for i in range(model.fcE.layers):
                layer = getattr(model.fcE, "fcLayer{}".format(i + 1)).linear
                if task_id == 0:
                    excit_buffer_list.append(layer.excit_buffer)
                n_units = len(layer.excit_buffer)
                gated_units = np.random.choice(n_units,
                                               size=int(args.gating_prop *
                                                        n_units),
                                               replace=False)
                mask_dict[task_id + 1][i] = gated_units
        model.mask_dict = mask_dict
        model.excit_buffer_list = excit_buffer_list

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    model.replay_targets = "soft" if args.distill else "hard"
    model.KD_temp = args.temp
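
    # Illustrative sketch (not this repo's code): the loss that "soft" replay targets
    # with temperature [KD_temp] typically imply; assumes torch.nn.functional as F.
    def _distill_loss_sketch(student_logits, teacher_logits, T):
        log_p = F.log_softmax(student_logits / T, dim=1)   # softened student predictions
        q = F.softmax(teacher_logits / T, dim=1)           # softened (replayed) targets
        # the T**2 factor keeps gradient magnitudes comparable across temperatures
        return (-(q * log_p).sum(dim=1)).mean() * (T ** 2)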

    # If needed, specify separate model for the generator
    train_gen = (args.replay == "generative" and not args.feedback)
    if train_gen:
        # -specify architecture
        generator = AutoEncoder(
            image_size=config['size'],
            image_channels=config['channels'],
            fc_layers=args.g_fc_lay,
            fc_units=args.g_fc_uni,
            z_dim=args.z_dim,
            classes=config['classes'],
            fc_drop=args.fc_drop,
            fc_bn=(args.fc_bn == "yes"),
            fc_nl=args.fc_nl,
        ).to(device)
        # -set optimizer(s)
        generator.optim_list = [{
            'params': filter(lambda p: p.requires_grad, generator.parameters()),
            'lr': args.lr
        }]
        generator.optim_type = args.optimizer
        if generator.optim_type in ("adam", "adam_reset"):
            generator.optimizer = optim.Adam(generator.optim_list,
                                             betas=(0.9, 0.999))
        elif generator.optim_type == "sgd":
            generator.optimizer = optim.SGD(generator.optim_list)
        # -set reconstruction criterion
        generator.recon_criterion = nn.BCELoss(reduction='mean')
    else:
        generator = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    param_stamp = utils.get_param_stamp(
        args,
        model.name,
        verbose=True,
        replay=(args.replay != "none"),
        replay_model_name=generator.name if
        (args.replay == "generative" and not args.feedback) else None,
    )

    # Print some model-characteristics on the screen
    # -main model
    print("\n")
    utils.print_model_info(model, title="MAIN MODEL")
    # -generator
    if generator is not None:
        utils.print_model_info(generator, title="GENERATOR")

    # Prepare for plotting
    # -open pdf
    pp = visual_plt.open_pdf("{}/{}.pdf".format(
        args.p_dir, param_stamp)) if args.pdf else None
    # -define [precision_dict] to keep track of performance during training for later plotting
    precision_dict = evaluate.initiate_precision_dict(args.tasks)
    # -visdom-settings
    if args.visdom:
        env_name = "{exp}{tasks}-{scenario}".format(exp=args.experiment,
                                                    tasks=args.tasks,
                                                    scenario=args.scenario)
        graph_name = "{fb}{mode}{syn}{ewc}{XdG}".format(
            fb="1M-" if args.feedback else "",
            mode=args.replay,
            syn="-si{}".format(args.si_c) if args.si else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if args.online else "") if args.ewc else "",
            XdG=""
            if args.gating_prop == 0 else "-XdG{}".format(args.gating_prop))
        visdom = {'env': env_name, 'graph': graph_name}
    else:
        visdom = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Callbacks for reporting on and visualizing loss
    generator_loss_cbs = [
        cb._VAE_loss_cb(log=args.loss_log,
                        visdom=visdom,
                        model=model if args.feedback else generator,
                        tasks=args.tasks,
                        iters_per_task=args.g_iters,
                        replay=(args.replay != "none"))
    ] if (train_gen or args.feedback) else [None]
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           tasks=args.tasks,
                           iters_per_task=args.iters,
                           replay=(args.replay != "none"))
    ] if (not args.feedback) else [None]

    # Callbacks for evaluating and plotting generated / reconstructed samples
    sample_cbs = [
        cb._sample_cb(log=args.sample_log,
                      visdom=visdom,
                      config=config,
                      test_datasets=test_datasets,
                      sample_size=args.sample_n,
                      iters_per_task=args.g_iters)
    ] if (train_gen or args.feedback) else [None]

    # Callbacks for reporting and visualizing accuracy
    # -visdom (i.e., after each [prec_log])
    eval_cb = cb._eval_cb(
        log=args.prec_log,
        test_datasets=test_datasets,
        visdom=visdom,
        iters_per_task=args.iters,
        scenario=args.scenario,
        collate_fn=utils.label_squeezing_collate_fn,
        test_size=args.prec_n,
        classes_per_task=classes_per_task,
        task_mask=(isinstance(model, ContinualLearner)
                   and args.gating_prop > 0))
    # -pdf: for summary plots (i.e., only after each task)
    eval_cb_full = cb._eval_cb(
        log=args.iters,
        test_datasets=test_datasets,
        precision_dict=precision_dict,
        scenario=args.scenario,
        collate_fn=utils.label_squeezing_collate_fn,
        iters_per_task=args.iters,
        classes_per_task=classes_per_task,
        task_mask=(isinstance(model, ContinualLearner)
                   and args.gating_prop > 0))
    # -collect them in <lists>
    eval_cbs = [eval_cb, eval_cb_full]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    print("--> Training:")
    # Keep track of training-time
    start = time.time()
    # Train model
    train_cl(
        model,
        train_datasets,
        replay_mode=args.replay,
        scenario=args.scenario,
        classes_per_task=classes_per_task,
        iters=args.iters,
        batch_size=args.batch,
        collate_fn=utils.label_squeezing_collate_fn,
        visualize=bool(args.visdom),
        generator=generator,
        gen_iters=args.g_iters,
        gen_loss_cbs=generator_loss_cbs,
        sample_cbs=sample_cbs,
        eval_cbs=eval_cbs,
        loss_cbs=generator_loss_cbs if args.feedback else solver_loss_cbs,
    )
    # Get total training-time in seconds, and write to file
    training_time = time.time() - start
    with open("{}/time-{}.txt".format(args.r_dir, param_stamp), 'w') as time_file:
        time_file.write('{}\n'.format(training_time))

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- EVALUATION -----#
    #----------------------#

    print('\n\n--> Evaluation ("incremental {} learning scenario"):'.format(
        args.scenario))

    # Generation (plot in pdf)
    if (pp is not None) and train_gen:
        evaluate.show_samples(generator, config, size=args.sample_n, pdf=pp)
    if (pp is not None) and args.feedback:
        evaluate.show_samples(model, config, size=args.sample_n, pdf=pp)

    # Reconstruction (plot in pdf)
    if (pp is not None) and (train_gen or args.feedback):
        for i in range(args.tasks):
            if args.feedback:
                evaluate.show_reconstruction(model,
                                             test_datasets[i],
                                             config,
                                             pdf=pp,
                                             task=i + 1)
            else:
                evaluate.show_reconstruction(generator,
                                             test_datasets[i],
                                             config,
                                             pdf=pp,
                                             task=i + 1)

    # Classifier (print on screen & write to file)
    if args.scenario == "task":
        precs = [
            evaluate.validate(
                model,
                test_datasets[i],
                verbose=False,
                test_size=None,
                task_mask=(isinstance(model, ContinualLearner)
                           and args.gating_prop > 0),
                task=i + 1,
                allowed_classes=list(
                    range(classes_per_task * i, classes_per_task * (i + 1))))
            for i in range(args.tasks)
        ]
    else:
        precs = [
            evaluate.validate(model,
                              test_datasets[i],
                              verbose=False,
                              test_size=None,
                              task=i + 1) for i in range(args.tasks)
        ]
    print("\n Precision on test-set:")
    for i in range(args.tasks):
        print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
    average_precs = sum(precs) / args.tasks
    print('=> average precision over all {} tasks: {:.4f}\n'.format(
        args.tasks, average_precs))

    #-------------------------------------------------------------------------------------------------#

    #------------------#
    #----- OUTPUT -----#
    #------------------#

    # Average precision on full test set (no restrictions on which nodes can be predicted: "incremental" / "singlehead")
    with open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w') as output_file:
        output_file.write('{}\n'.format(average_precs))

    # Precision-dictionary
    file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
    utils.save_object(precision_dict, file_name)

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if pp is not None:
        # -create list to store all figures to be plotted.
        figure_list = []
        # -generate all figures (and store them in [figure_list])
        figure = visual_plt.plot_lines(
            precision_dict["all_tasks"],
            x_axes=precision_dict["x_task"],
            line_names=['task {}'.format(i + 1) for i in range(args.tasks)])
        figure_list.append(figure)
        figure = visual_plt.plot_lines([precision_dict["average"]],
                                       x_axes=precision_dict["x_task"],
                                       line_names=['average all tasks so far'])
        figure_list.append(figure)
        # -add figures to pdf (and close this pdf).
        for figure in figure_list:
            pp.savefig(figure)

    # Close pdf
    if pp is not None:
        pp.close()
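
Examples 2, 4 and 6 build [mask_dict] and [excit_buffer_list] for XdG, but the masks are applied elsewhere in the codebase. A minimal sketch of how such a task mask is typically applied (an illustration assuming each excit_buffer multiplicatively gates the units of its layer; not the actual implementation):

import numpy as np
import torch

def apply_task_mask(model, task_id):
    # for each hidden layer: silence the units gated for [task_id] (multiplier 0)
    # and leave all remaining units fully excitable (multiplier 1)
    for i, excit_buffer in enumerate(model.excit_buffer_list):
        gating_mask = np.ones(len(excit_buffer), dtype=np.float32)
        gating_mask[model.mask_dict[task_id][i]] = 0.
        excit_buffer.copy_(torch.from_numpy(gating_mask))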
Example 3
        new_ave_line = []
        new_sem_line = []
        for line_id in range(len(prec[args.seed][id])):
            all_entries = [prec[seed][id][line_id] for seed in seed_list]
            new_ave_line.append(np.mean(all_entries))
            if len(seed_list) > 1:
                # 95%-CI half-width: with numpy's (population) variance,
                # sqrt(var/(n-1)) equals the standard error of the mean
                new_sem_line.append(
                    1.96 *
                    np.sqrt(np.var(all_entries) / (len(all_entries) - 1)))
        ave_lines.append(new_ave_line)
        sem_lines.append(new_sem_line)
    figure = visual_plt.plot_lines(
        ave_lines,
        x_axes=x_axes,
        line_names=names,
        colors=colors,
        title=title,
        xlabel="tasks",
        ylabel="average precision (on tasks seen so far)",
        list_with_errors=sem_lines if len(seed_list) > 1 else None)
    figure_list.append(figure)

    # scatter-plot (accuracy vs training-time)
    accuracies = []
    times = []
    for id in ids[:-1]:
        accuracies.append([ave_prec[seed][id] for seed in seed_list])
        times.append([train_time[seed][id] / 60 for seed in seed_list])
    xmax = np.max(times)
    ylim = (0, 1.025)
    figure = visual_plt.plot_scatter_groups(
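
A note on the error bars computed in Example 3: the term 1.96 * np.sqrt(np.var(x) / (len(x) - 1)) is the half-width of a 95% confidence interval, since with numpy's default (population) variance, var(x) / (n - 1) equals the squared standard error of the mean. A quick self-contained check with illustrative values:

import numpy as np
from scipy import stats

x = np.array([0.91, 0.88, 0.93, 0.90, 0.89])            # e.g., precisions over 5 seeds
half_width = 1.96 * np.sqrt(np.var(x) / (len(x) - 1))   # as computed in Example 3
assert np.isclose(half_width, 1.96 * stats.sem(x))      # identical to 1.96 * SEM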
Example 4
def run(args):

    # Set default arguments & check for incompatible options
    args.lr_gen = args.lr if args.lr_gen is None else args.lr_gen
    args.g_iters = args.iters if args.g_iters is None else args.g_iters
    args.g_fc_lay = args.fc_lay if args.g_fc_lay is None else args.g_fc_lay
    args.g_fc_uni = args.fc_units if args.g_fc_uni is None else args.g_fc_uni
    # -if [log_per_task], reset all logs
    if args.log_per_task:
        args.prec_log = args.iters
        args.loss_log = args.iters
        args.sample_log = args.iters
    # -if [iCaRL] is selected, select all accompanying options
    if hasattr(args, "icarl") and args.icarl:
        args.use_exemplars = True
        args.add_exemplars = True
        args.bce = True
        args.bce_distill = True
    # -if XdG is selected but not the Task-IL scenario, give error
    if (not args.scenario == "task") and args.xdg:
        raise ValueError("'XdG' is only compatible with the Task-IL scenario.")
    # -if EWC, SI or XdG is selected together with 'feedback', give error
    if args.feedback and (args.ewc or args.si or args.xdg or args.icarl):
        raise NotImplementedError(
            "EWC, SI, XdG and iCaRL are not supported with feedback connections."
        )
    # -if binary classification loss is selected together with 'feedback', give error
    if args.feedback and args.bce:
        raise NotImplementedError(
            "Binary classification loss not supported with feedback connections."
        )
    # -if XdG is selected together with both replay and EWC, give error (either one of them alone with XdG is fine)
    if args.xdg and (not args.replay == "none") and (args.ewc or args.si):
        raise NotImplementedError(
            "XdG is not supported with both '{}' replay and EWC / SI.".format(
                args.replay))
        #--> problem is that applying different task-masks interferes with gradient calculation
        #    (should be possible to overcome by calculating backward step on EWC/SI-loss also for each mask separately)
    # -if 'BCEdistill' is selected for other than scenario=="class", give error
    if args.bce_distill and not args.scenario == "class":
        raise ValueError(
            "BCE-distill can only be used for class-incremental learning.")
    # -create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    scenario = args.scenario
    # If the Task-IL scenario is chosen with a single-headed output layer, the effective
    # scenario becomes "domain" (but note that when XdG is used, task-identity information
    # is still used, so the actual scenario remains Task-IL)
    if args.singlehead and args.scenario == "task":
        scenario = "domain"

    # If only want param-stamp, get it printed to screen and exit
    if hasattr(args, "get_stamp") and args.get_stamp:
        _ = get_param_stamp_from_args(args=args)
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    (train_datasets,
     test_datasets), config, classes_per_task = get_multitask_experiment(
         name=args.experiment,
         scenario=scenario,
         tasks=args.tasks,
         data_dir=args.d_dir,
         verbose=True,
         exception=(args.seed == 0),
     )

    #-------------------------------------------------------------------------------------------------#

    #------------------------------#
    #----- MODEL (CLASSIFIER) -----#
    #------------------------------#

    # Define main model (i.e., the classifier; with feedback connections if requested)
    if args.feedback:
        model = AutoEncoder(
            image_size=config['size'],
            image_channels=config['channels'],
            classes=config['classes'],
            fc_layers=args.fc_lay,
            fc_units=args.fc_units,
            z_dim=args.z_dim,
            fc_drop=args.fc_drop,
            fc_bn=(args.fc_bn == "yes"),
            fc_nl=args.fc_nl,
        ).to(device)
        model.lamda_pl = 1.  #--> so that this VAE is also trained to classify
    else:
        model = Classifier(
            image_size=config['size'],
            image_channels=config['channels'],
            classes=config['classes'],
            fc_layers=args.fc_lay,
            fc_units=args.fc_units,
            fc_drop=args.fc_drop,
            fc_nl=args.fc_nl,
            fc_bn=(args.fc_bn == "yes"),
            excit_buffer=(args.xdg and args.gating_prop > 0),
            binaryCE=args.bce,
            binaryCE_distill=args.bce_distill,
        ).to(device)

    # Define optimizer (only include parameters that require gradients)
    model.optim_list = [{
        'params': filter(lambda p: p.requires_grad, model.parameters()),
        'lr': args.lr
    }]
    model.optim_type = args.optimizer
    if model.optim_type in ("adam", "adam_reset"):
        model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))
    elif model.optim_type == "sgd":
        model.optimizer = optim.SGD(model.optim_list)
    else:
        raise ValueError(
            "Unrecognized optimizer, '{}' is not currently a valid option".
            format(args.optimizer))

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------#
    #----- CL-STRATEGY: EXEMPLARS -----#
    #----------------------------------#

    # Store in model whether, how many and in what way to store exemplars
    if isinstance(model, ExemplarHandler) and (args.use_exemplars
                                               or args.add_exemplars
                                               or args.replay == "exemplars"):
        model.memory_budget = args.budget
        model.norm_exemplars = args.norm_exemplars
        model.herding = args.herding
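
    # Illustrative sketch (not this repo's code) of the "herding" selection that
    # [model.herding] typically switches on: greedily pick exemplars whose running
    # mean best approximates the class-mean feature vector. (Simplified: a real
    # implementation also prevents re-selecting the same index.)
    def _herding_sketch(features, n_exemplars):
        # features: (N, d) tensor of feature vectors for one class
        class_mean = features.mean(dim=0)
        selected, running_sum = [], torch.zeros_like(class_mean)
        for k in range(1, n_exemplars + 1):
            candidate_means = (running_sum + features) / k  # mean if each candidate were added
            idx = ((candidate_means - class_mean) ** 2).sum(dim=1).argmin().item()
            selected.append(idx)
            running_sum = running_sum + features[idx]
        return selected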

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- CL-STRATEGY: ALLOCATION -----#
    #-----------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        if args.ewc:
            model.fisher_n = args.fisher_n
            model.gamma = args.gamma
            model.online = args.online
            model.emp_FI = args.emp_fi

    # Synaptic Intelligence (SI)
    if isinstance(model, ContinualLearner):
        model.si_c = args.si_c if args.si else 0
        if args.si:
            model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and (args.xdg
                                                and args.gating_prop > 0):
        mask_dict = {}
        excit_buffer_list = []
        for task_id in range(args.tasks):
            mask_dict[task_id + 1] = {}
            for i in range(model.fcE.layers):
                layer = getattr(model.fcE, "fcLayer{}".format(i + 1)).linear
                if task_id == 0:
                    excit_buffer_list.append(layer.excit_buffer)
                n_units = len(layer.excit_buffer)
                gated_units = np.random.choice(n_units,
                                               size=int(args.gating_prop *
                                                        n_units),
                                               replace=False)
                mask_dict[task_id + 1][i] = gated_units
        model.mask_dict = mask_dict
        model.excit_buffer_list = excit_buffer_list

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, Replayer):
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp

    # If needed, specify separate model for the generator
    train_gen = (args.replay == "generative" and not args.feedback)
    if train_gen:
        # -specify architecture
        generator = AutoEncoder(
            image_size=config['size'],
            image_channels=config['channels'],
            fc_layers=args.g_fc_lay,
            fc_units=args.g_fc_uni,
            z_dim=args.g_z_dim,
            classes=config['classes'],
            fc_drop=args.fc_drop,
            fc_bn=(args.fc_bn == "yes"),
            fc_nl=args.fc_nl,
        ).to(device)
        # -set optimizer(s)
        generator.optim_list = [{
            'params': filter(lambda p: p.requires_grad, generator.parameters()),
            'lr': args.lr_gen
        }]
        generator.optim_type = args.optimizer
        if generator.optim_type in ("adam", "adam_reset"):
            generator.optimizer = optim.Adam(generator.optim_list,
                                             betas=(0.9, 0.999))
        elif generator.optim_type == "sgd":
            generator.optimizer = optim.SGD(generator.optim_list)
    else:
        generator = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    param_stamp = get_param_stamp(
        args,
        model.name,
        verbose=True,
        replay=(args.replay != "none"),
        replay_model_name=generator.name if
        (args.replay == "generative" and not args.feedback) else None,
    )

    # Print some model-characteristics on the screen
    # -main model
    print("\n")
    utils.print_model_info(model, title="MAIN MODEL")
    # -generator
    if generator is not None:
        utils.print_model_info(generator, title="GENERATOR")

    # Prepare for plotting in visdom
    # -define [precision_dict] to keep track of performance during training for storing and for later plotting in pdf
    precision_dict = evaluate.initiate_precision_dict(args.tasks)
    precision_dict_exemplars = evaluate.initiate_precision_dict(
        args.tasks) if args.use_exemplars else None
    # -visdom-settings
    if args.visdom:
        env_name = "{exp}{tasks}-{scenario}".format(exp=args.experiment,
                                                    tasks=args.tasks,
                                                    scenario=args.scenario)
        graph_name = "{fb}{replay}{syn}{ewc}{xdg}{icarl}{bud}".format(
            fb="1M-" if args.feedback else "",
            replay="{}{}".format(args.replay, "D" if args.distill else ""),
            syn="-si{}".format(args.si_c) if args.si else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if args.online else "") if args.ewc else "",
            xdg="" if (not args.xdg) or args.gating_prop == 0 else
            "-XdG{}".format(args.gating_prop),
            icarl="-iCaRL" if (args.use_exemplars and args.add_exemplars
                               and args.bce and args.bce_distill) else "",
            bud="-bud{}".format(args.budget) if
            (args.use_exemplars or args.add_exemplars
             or args.replay == "exemplars") else "",
        )
        visdom = {'env': env_name, 'graph': graph_name}
        if args.use_exemplars:
            visdom_exemplars = {
                'env': env_name,
                'graph': "{}-EX".format(graph_name)
            }
    else:
        visdom = visdom_exemplars = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Callbacks for reporting on and visualizing loss
    generator_loss_cbs = [
        cb._VAE_loss_cb(
            log=args.loss_log,
            visdom=visdom,
            model=model if args.feedback else generator,
            tasks=args.tasks,
            iters_per_task=args.iters if args.feedback else args.g_iters,
            replay=(args.replay != "none"))
    ] if (train_gen or args.feedback) else [None]
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           tasks=args.tasks,
                           iters_per_task=args.iters,
                           replay=(args.replay != "none"))
    ] if (not args.feedback) else [None]

    # Callbacks for evaluating and plotting generated / reconstructed samples
    sample_cbs = [
        cb._sample_cb(
            log=args.sample_log,
            visdom=visdom,
            config=config,
            test_datasets=test_datasets,
            sample_size=args.sample_n,
            iters_per_task=args.iters if args.feedback else args.g_iters)
    ] if (train_gen or args.feedback) else [None]

    # Callbacks for reporting and visualizing accuracy
    # -visdom (i.e., after each [prec_log])
    eval_cb = cb._eval_cb(
        log=args.prec_log,
        test_datasets=test_datasets,
        visdom=visdom,
        precision_dict=None,
        iters_per_task=args.iters,
        test_size=args.prec_n,
        classes_per_task=classes_per_task,
        scenario=scenario,
    )
    # -pdf / reporting: summary plots (i.e., only after each task)
    eval_cb_full = cb._eval_cb(
        log=args.iters,
        test_datasets=test_datasets,
        precision_dict=precision_dict,
        iters_per_task=args.iters,
        classes_per_task=classes_per_task,
        scenario=scenario,
    )
    # -with exemplars (both for visdom & reporting / pdf)
    eval_cb_exemplars = cb._eval_cb(
        log=args.iters,
        test_datasets=test_datasets,
        visdom=visdom_exemplars,
        classes_per_task=classes_per_task,
        precision_dict=precision_dict_exemplars,
        scenario=scenario,
        iters_per_task=args.iters,
        with_exemplars=True,
    ) if args.use_exemplars else None
    # -collect them in <lists>
    eval_cbs = [eval_cb, eval_cb_full]
    eval_cbs_exemplars = [eval_cb_exemplars]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    print("--> Training:" + args.name)
    print("Total tasks:" + str(args.tasks_to_complete))
    # Keep track of training-time
    start = time.time()
    # Train model
    train_cl(
        args.tasks_to_complete,
        args.name,
        model,
        train_datasets,
        test_datasets,
        replay_mode=args.replay,
        scenario=scenario,
        classes_per_task=classes_per_task,
        iters=args.iters,
        batch_size=args.batch,
        generator=generator,
        gen_iters=args.g_iters,
        gen_loss_cbs=generator_loss_cbs,
        sample_cbs=sample_cbs,
        eval_cbs=eval_cbs,
        loss_cbs=generator_loss_cbs if args.feedback else solver_loss_cbs,
        eval_cbs_exemplars=eval_cbs_exemplars,
        use_exemplars=args.use_exemplars,
        add_exemplars=args.add_exemplars,
    )
    # Get total training-time in seconds, and write to file
    training_time = time.time() - start
    with open("{}/time-{}.txt".format(args.r_dir, param_stamp), 'w') as time_file:
        time_file.write('{}\n'.format(training_time))

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- EVALUATION -----#
    #----------------------#

    print("\n\n--> Evaluation ({}-incremental learning scenario):".format(
        args.scenario))

    # Evaluate precision of final model on full test-set
    precs = [
        evaluate.validate(
            model,
            test_datasets[i],
            verbose=False,
            test_size=None,
            task=i + 1,
            with_exemplars=False,
            allowed_classes=list(
                range(classes_per_task * i, classes_per_task *
                      (i + 1))) if scenario == "task" else None)
        for i in range(args.tasks)
    ]
    print("\n Precision on test-set (softmax classification):")
    for i in range(args.tasks):
        print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
    average_precs = sum(precs) / args.tasks
    print('=> average precision over all {} tasks: {:.4f}'.format(
        args.tasks, average_precs))

    # -with exemplars
    if args.use_exemplars:
        precs = [
            evaluate.validate(
                model,
                test_datasets[i],
                verbose=False,
                test_size=None,
                task=i + 1,
                with_exemplars=True,
                allowed_classes=list(
                    range(classes_per_task * i, classes_per_task *
                          (i + 1))) if scenario == "task" else None)
            for i in range(args.tasks)
        ]
        print("\n Precision on test-set (classification using exemplars):")
        for i in range(args.tasks):
            print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
        average_precs_ex = sum(precs) / args.tasks
        print('=> average precision over all {} tasks: {:.4f}'.format(
            args.tasks, average_precs_ex))
    print("\n")

    #-------------------------------------------------------------------------------------------------#

    #------------------#
    #----- OUTPUT -----#
    #------------------#

    # Average precision on full test set
    with open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w') as output_file:
        output_file.write('{}\n'.format(
            average_precs_ex if args.use_exemplars else average_precs))
    # -precision-dict
    file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
    utils.save_object(
        precision_dict_exemplars if args.use_exemplars else precision_dict,
        file_name)

    # Average precision on full test set not evaluated using exemplars (i.e., using softmax on final layer)
    if args.use_exemplars:
        with open("{}/prec_noex-{}.txt".format(args.r_dir, param_stamp),
                  'w') as output_file:
            output_file.write('{}\n'.format(average_precs))
        # -precision-dict:
        file_name = "{}/dict_noex-{}".format(args.r_dir, param_stamp)
        utils.save_object(precision_dict, file_name)

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        pp = visual_plt.open_pdf("{}/{}.pdf".format(args.p_dir, param_stamp))

        # -show samples and reconstructions (either from main model or from separate generator)
        if args.feedback or args.replay == "generative":
            evaluate.show_samples(model if args.feedback else generator,
                                  config,
                                  size=args.sample_n,
                                  pdf=pp)
            for i in range(args.tasks):
                evaluate.show_reconstruction(
                    model if args.feedback else generator,
                    test_datasets[i],
                    config,
                    pdf=pp,
                    task=i + 1)

        # -show metrics reflecting progression during training
        figure_list = []  #-> create list to store all figures to be plotted
        # -generate all figures (and store them in [figure_list])
        figure = visual_plt.plot_lines(
            precision_dict["all_tasks"],
            x_axes=precision_dict["x_task"],
            line_names=['task {}'.format(i + 1) for i in range(args.tasks)])
        figure_list.append(figure)
        figure = visual_plt.plot_lines([precision_dict["average"]],
                                       x_axes=precision_dict["x_task"],
                                       line_names=['average all tasks so far'])
        figure_list.append(figure)
        if args.use_exemplars:
            figure = visual_plt.plot_lines(
                precision_dict_exemplars["all_tasks"],
                x_axes=precision_dict_exemplars["x_task"],
                line_names=[
                    'task {}'.format(i + 1) for i in range(args.tasks)
                ])
            figure_list.append(figure)
        # -add figures to pdf (and close this pdf).
        for figure in figure_list:
            pp.savefig(figure)

        # -close pdf
        pp.close()
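
The "classification using exemplars" evaluated in Example 4 refers to iCaRL's nearest-mean-of-exemplars rule. A minimal sketch of that rule (illustration only; [feature_extractor] and [exemplar_means] are hypothetical names, not this codebase's API):

import torch

def classify_with_exemplars(model, x, exemplar_means):
    # exemplar_means: (n_classes, d) tensor of class-mean features of the stored exemplars
    features = model.feature_extractor(x)                      # assumed to return (batch, d)
    features = features / features.norm(dim=1, keepdim=True)   # L2-normalize (cf. norm_exemplars)
    # squared distance from every sample to every class mean
    dists = ((features.unsqueeze(1) - exemplar_means.unsqueeze(0)) ** 2).sum(dim=2)
    return dists.argmin(dim=1)                                 # predict the nearest class mean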
Example 5
def run(args):

    if not args.single_test:
        import pidfile
        resfile = pidfile.exclusive_dirfn(
            os.path.join(args.r_dir, args.save_dir))

    if args.log_per_task:
        args.prec_log = args.iters
        args.loss_log = args.iters

    # -create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # set cuda
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    scenario = args.scenario

    #-------------------------------------------------------------------------------------------------
    # DATA
    #-------------------------------------------------------------------------------------------------
    (train_datasets, test_datasets), config = get_multitask_experiment(
        args,
        name=args.experiment,
        scenario=scenario,
        tasks=args.tasks,
        data_dir=args.d_dir,
        verbose=True,
        exception=(args.seed == 0),
    )
    args.tasks = len(config['labels_per_task'])
    args.labels_per_task = config['labels_per_task']
    if not args.task_boundary:
        args.iterations_per_virtual_epc = config['iterations_per_virtual_epc']
        args.task_dict = config['task_dict']

    #-------------------------------------------------------------------------------------------------
    # MODEL
    #-------------------------------------------------------------------------------------------------
    if args.ebm:
        model = EBM(args,
                    image_size=config['size'],
                    image_channels=config['channels'],
                    classes=config['num_classes'],
                    fc_units=args.fc_units).to(device)
    else:
        model = Classifier(args,
                           image_size=config['size'],
                           image_channels=config['channels'],
                           classes=config['num_classes'],
                           fc_units=args.fc_units).to(device)

    if args.experiment == 'cifar100':
        model = utils.init_params(model, args)
        for param in model.convE.parameters():
            param.requires_grad = False

    if args.pretrain:
        checkpoint = torch.load(args.pretrain)
        best_acc = checkpoint['best_acc']
        checkpoint_state = checkpoint['state_dict']

        print('-' * 77)
        print('load pretrained model %s' % args.pretrain)
        print('best_acc', best_acc)
        print('-' * 77)

        model_dict = model.fcE.state_dict()
        checkpoint_state = {
            k[7:]: v
            for k, v in checkpoint_state.items() if k[7:] in model_dict
        }  # strip the 'module.' prefix that nn.DataParallel adds to parameter names
        del checkpoint_state['classifier.weight']
        del checkpoint_state['classifier.bias']
        if 'y_ebm.weight' in checkpoint_state:
            del checkpoint_state['y_ebm.weight']
        model_dict.update(checkpoint_state)
        model.fcE.load_state_dict(model_dict)

        for param in model.fcE.model.parameters():
            param.requires_grad = False

    model.optim_list = [{
        'params': filter(lambda p: p.requires_grad, model.parameters()),
        'lr': args.lr
    }]
    model.optim_type = args.optimizer

    if model.optim_type in ("adam", "adam_reset"):
        model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))
    elif model.optim_type == "sgd":
        model.optimizer = optim.SGD(model.optim_list)
    else:
        raise ValueError(
            "Unrecognized optimizer, '{}' is not currently a valid option".
            format(args.optimizer))

    #-------------------------------------------------------------------------------------------------
    # CL-STRATEGY: ALLOCATION
    #-------------------------------------------------------------------------------------------------

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        if args.ewc:
            model.fisher_n = args.fisher_n
            model.gamma = args.gamma
            model.online = args.online
            model.emp_FI = args.emp_fi

    # Synaptic Intelligence (SI)
    if isinstance(model, ContinualLearner):
        model.si_c = args.si_c if args.si else 0
        if args.si:
            model.epsilon = args.epsilon

    #-------------------------------------------------------------------------------------------------
    # Get parameter-stamp (and print on screen)
    #-------------------------------------------------------------------------------------------------
    param_stamp = get_param_stamp(args, model.name, verbose=True)
    param_stamp = param_stamp + '--' + args.model_name

    # -define [precision_dict] to keep track of performance during training for storing and for later plotting in pdf
    precision_dict = evaluate.initiate_precision_dict(args.tasks)

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           model=model,
                           tasks=args.tasks,
                           iters_per_task=args.iters)
    ]

    eval_cb = cb._eval_cb(log=args.prec_log,
                          test_datasets=test_datasets,
                          visdom=args.visdom,
                          precision_dict=None,
                          iters_per_task=args.iters,
                          test_size=args.prec_n,
                          labels_per_task=config['labels_per_task'],
                          scenario=scenario)
    eval_cb_full = cb._eval_cb(log=args.iters,
                               test_datasets=test_datasets,
                               precision_dict=precision_dict,
                               iters_per_task=args.iters,
                               labels_per_task=config['labels_per_task'],
                               scenario=scenario)
    eval_cbs = [eval_cb, eval_cb_full]

    #-------------------------------------------------------------------------------------------------
    # TRAINING
    #-------------------------------------------------------------------------------------------------
    print("--> Training:")
    start = time.time()

    if args.task_boundary:
        train_cl(args,
                 model,
                 train_datasets,
                 scenario=scenario,
                 labels_per_task=config['labels_per_task'],
                 iters=args.iters,
                 batch_size=args.batch,
                 eval_cbs=eval_cbs,
                 loss_cbs=solver_loss_cbs)
    else:
        train_cl_noboundary(args,
                            model,
                            train_datasets,
                            scenario=scenario,
                            labels_per_task=config['labels_per_task'],
                            iters=args.iters,
                            batch_size=args.batch,
                            eval_cbs=eval_cbs,
                            loss_cbs=solver_loss_cbs)

    training_time = time.time() - start

    #-------------------------------------------------------------------------------------------------
    # EVALUATION
    #-------------------------------------------------------------------------------------------------
    print("\n\n--> Evaluation ({}-incremental learning scenario):".format(
        args.scenario))
    if args.ebm:
        precs = [
            evaluate.validate_ebm(args,
                                  model,
                                  test_datasets[i],
                                  verbose=False,
                                  test_size=None,
                                  task=i + 1,
                                  with_exemplars=False,
                                  current_task=args.tasks)
            for i in range(args.tasks)
        ]
    else:
        precs = [
            evaluate.validate(args,
                              model,
                              test_datasets[i],
                              verbose=False,
                              test_size=None,
                              task=i + 1,
                              with_exemplars=False,
                              current_task=args.tasks)
            for i in range(args.tasks)
        ]

    print("\n Precision on test-set (softmax classification):")
    for i in range(args.tasks):
        print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
    average_precs = sum(precs) / args.tasks
    print('average precision over all {} tasks: {:.4f}'.format(
        args.tasks, average_precs))

    #-------------------------------------------------------------------------------------------------
    # OUTPUT
    #-------------------------------------------------------------------------------------------------
    if not os.path.exists(os.path.join(args.r_dir, args.save_dir)):
        os.makedirs(os.path.join(args.r_dir, args.save_dir))

    with open("{}/{}/{}.txt".format(args.r_dir, args.save_dir, param_stamp),
              'w') as output_file:
        output_file.write("Training time {} \n".format(training_time))
        for i in range(args.tasks):
            output_file.write(" - Task {}: {:.4f}\n".format(i + 1, precs[i]))
        output_file.write(' - Average {}\n'.format(average_precs))
    file_name = "{}/{}/{}".format(args.r_dir, args.save_dir, param_stamp)
    utils.save_object(precision_dict, file_name)

    if args.pdf:
        pp = visual_plt.open_pdf("{}/{}/{}.pdf".format(args.r_dir,
                                                       args.save_dir,
                                                       param_stamp))
        # -show metrics reflecting progression during training
        figure_list = []  #-> create list to store all figures to be plotted
        # -generate all figures (and store them in [figure_list])
        figure = visual_plt.plot_lines(
            precision_dict["all_tasks"],
            x_axes=precision_dict["x_task"],
            line_names=['task {}'.format(i + 1) for i in range(args.tasks)])
        figure_list.append(figure)
        figure = visual_plt.plot_lines([precision_dict["average"]],
                                       x_axes=precision_dict["x_task"],
                                       line_names=['average all tasks so far'])
        figure_list.append(figure)
        # -add figures to pdf (and close this pdf).
        for figure in figure_list:
            pp.savefig(figure)

        pp.close()

    if not args.single_test:
        resfile.done()
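
Examples 2 through 6 all configure EWC via [ewc_lambda], [fisher_n], [gamma] and [online], but the penalty itself is computed elsewhere in the codebase. As a reminder of what these settings parameterize, a minimal sketch of the quadratic EWC penalty (the buffer names below are hypothetical; SI's surrogate loss has the same form, with path-integral importance estimates in place of the Fisher values):

import torch

def ewc_penalty_sketch(model):
    # sum_i fisher_i * (theta_i - theta_i_star) ** 2, scaled by ewc_lambda / 2;
    # "online" EWC keeps a single running Fisher estimate, decayed by [gamma] per task
    losses = []
    for name, p in model.named_parameters():
        if p.requires_grad:
            prev_value = model.ewc_prev_values[name]       # hypothetical: stored parameter copy
            fisher = model.ewc_fisher_estimates[name]      # hypothetical: stored Fisher estimate
            losses.append((fisher * (p - prev_value) ** 2).sum())
    return (model.ewc_lambda / 2.) * sum(losses)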
Example 6
def run(args, verbose=False):
    # Set default arguments & check for incompatible options
    args.lr_gen = args.lr if args.lr_gen is None else args.lr_gen
    args.g_iters = args.iters if args.g_iters is None else args.g_iters
    args.g_fc_lay = args.fc_lay if args.g_fc_lay is None else args.g_fc_lay
    args.g_fc_uni = args.fc_units if args.g_fc_uni is None else args.g_fc_uni
    # -if [log_per_task], reset all logs
    if args.log_per_task:
        args.prec_log = args.iters
        args.loss_log = args.iters
        args.sample_log = args.iters
    # -if [iCaRL] is selected, select all accompanying options
    if hasattr(args, "icarl") and args.icarl:
        args.use_exemplars = True
        args.add_exemplars = True
        args.bce = True
        args.bce_distill = True
    # -if XdG is selected but not the Task-IL scenario, give error
    if (not args.scenario == "task") and args.xdg:
        raise ValueError("'XdG' is only compatible with the Task-IL scenario.")
    # -if EWC, SI, XdG, A-GEM or iCaRL is selected together with 'feedback', give error
    if args.feedback and (args.ewc or args.si or args.xdg or args.icarl
                          or args.agem):
        raise NotImplementedError(
            "EWC, SI, XdG, A-GEM and iCaRL are not supported with feedback connections."
        )
    # -if A-GEM is selected without any replay, give warning
    if args.agem and args.replay == "none":
        raise Warning(
            "The '--agem' flag is selected, but without any type of replay. "
            "For the original A-GEM method, also select --replay='exemplars'.")
    # -if EWC, SI, XdG, A-GEM or iCaRL is selected together with offline-replay, give error
    if args.replay == "offline" and (args.ewc or args.si or args.xdg
                                     or args.icarl or args.agem):
        raise NotImplementedError(
            "Offline replay cannot be combined with EWC, SI, XdG, A-GEM or iCaRL."
        )
    # -if binary classification loss is selected together with 'feedback', give error
    if args.feedback and args.bce:
        raise NotImplementedError(
            "Binary classification loss not supported with feedback connections."
        )
    # -if XdG is selected together with both replay and EWC, give error (either one of them alone with XdG is fine)
    if (args.xdg and args.gating_prop > 0) and (
            not args.replay == "none") and (args.ewc or args.si):
        raise NotImplementedError(
            "XdG is not supported with both '{}' replay and EWC / SI.".format(
                args.replay))
        # --> problem is that applying different task-masks interferes with gradient calculation
        #    (should be possible to overcome by calculating backward step on EWC/SI-loss also for each mask separately)
    # -if 'BCEdistill' is selected for other than scenario=="class", give error
    if args.bce_distill and not args.scenario == "class":
        raise ValueError(
            "BCE-distill can only be used for class-incremental learning.")
    # -create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    scenario = args.scenario
    # If the Task-IL scenario is chosen with a single-headed output layer, the effective
    # scenario becomes "domain" (but note that when XdG is used, task-identity information
    # is still used, so the actual scenario remains Task-IL)
    if args.singlehead and args.scenario == "task":
        scenario = "domain"

    # If only want param-stamp, get it printed to screen and exit
    if hasattr(args, "get_stamp") and args.get_stamp:
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    # -------------------------------------------------------------------------------------------------#

    # ----------------#
    # ----- DATA -----#
    # ----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets,
     test_datasets), config, classes_per_task = get_multitask_experiment(
         name=args.experiment,
         scenario=scenario,
         tasks=args.tasks,
         data_dir=args.d_dir,
         verbose=verbose,
         exception=(args.seed == 0),
     )

    # -------------------------------------------------------------------------------------------------#

    # ------------------------------#
    # ----- MODEL (CLASSIFIER) -----#
    # ------------------------------#

    # Define main model (i.e., classifier, if requested with feedback connections)
    if args.feedback:
        model = AutoEncoder(
            image_size=config['size'],
            image_channels=config['channels'],
            classes=config['classes'],
            fc_layers=args.fc_lay,
            fc_units=args.fc_units,
            z_dim=args.z_dim,
            fc_drop=args.fc_drop,
            fc_bn=True if args.fc_bn == "yes" else False,
            fc_nl=args.fc_nl,
        ).to(device)
        model.lamda_pl = 1.  # --> so that this VAE is also trained to classify
    else:
        model = Classifier(
            image_size=config['size'],
            image_channels=config['channels'],
            classes=config['classes'],
            fc_layers=args.fc_lay,
            fc_units=args.fc_units,
            fc_drop=args.fc_drop,
            fc_nl=args.fc_nl,
            fc_bn=True if args.fc_bn == "yes" else False,
            excit_buffer=True if args.xdg and args.gating_prop > 0 else False,
            binaryCE=args.bce,
            binaryCE_distill=args.bce_distill,
            AGEM=args.agem,
        ).to(device)

    # Define optimizer (only include parameters that "requires_grad")
    model.optim_list = [{
        'params':
        filter(lambda p: p.requires_grad, model.parameters()),
        'lr':
        args.lr
    }]
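    # --> [optim_list] uses the per-parameter-group format expected by the
    #     optimizers in torch.optim (a list of dicts, each with its own
    #     'params' and, optionally, its own 'lr')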
    model.optim_type = args.optimizer
    if model.optim_type in ("adam", "adam_reset"):
        model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))
    elif model.optim_type == "sgd":
        model.optimizer = optim.SGD(model.optim_list)
    else:
        raise ValueError(
            "Unrecognized optimizer, '{}' is not currently a valid option".
            format(args.optimizer))

    # -------------------------------------------------------------------------------------------------#

    # ----------------------------------#
    # ----- CL-STRATEGY: EXEMPLARS -----#
    # ----------------------------------#

    # Store in the model whether, how many, and in what way exemplars should be stored
    if isinstance(model, ExemplarHandler) and (args.use_exemplars
                                               or args.add_exemplars
                                               or args.replay == "exemplars"):
        model.memory_budget = args.budget
        model.norm_exemplars = args.norm_exemplars
        model.herding = args.herding
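        # --> with herding (as in iCaRL), exemplars are selected greedily so
        #     that their mean in feature space approximates the class mean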

    # -------------------------------------------------------------------------------------------------#

    # -----------------------------------#
    # ----- CL-STRATEGY: ALLOCATION -----#
    # -----------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        if args.ewc:
            model.fisher_n = args.fisher_n
            model.gamma = args.gamma
            model.online = args.online
            model.emp_FI = args.emp_fi

    # Synaptic Intelligence (SI)
    if isinstance(model, ContinualLearner):
        model.si_c = args.si_c if args.si else 0
        if args.si:
            model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and (args.xdg
                                                and args.gating_prop > 0):
        mask_dict = {}
        excit_buffer_list = []
        for task_id in range(args.tasks):
            mask_dict[task_id + 1] = {}
            for i in range(model.fcE.layers):
                layer = getattr(model.fcE, "fcLayer{}".format(i + 1)).linear
                if task_id == 0:
                    excit_buffer_list.append(layer.excit_buffer)
                n_units = len(layer.excit_buffer)
                gated_units = np.random.choice(n_units,
                                               size=int(args.gating_prop *
                                                        n_units),
                                               replace=False)
                mask_dict[task_id + 1][i] = gated_units
        model.mask_dict = mask_dict
        model.excit_buffer_list = excit_buffer_list
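        # --> for each task, [mask_dict] lists per hidden layer the indices of
        #     the units that are gated (i.e., set to zero) for that task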

    # -------------------------------------------------------------------------------------------------#

    # -------------------------------#
    # ----- CL-STRATEGY: REPLAY -----#
    # -------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, Replayer):
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp
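        # --> with soft targets, replayed inputs are trained to match the
        #     previous model's temperature-scaled output probabilities
        #     (i.e., a distillation loss with temperature [KD_temp])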

    # If needed, specify separate model for the generator
    train_gen = True if (args.replay == "generative"
                         and not args.feedback) else False
    if train_gen:
        # -specify architecture
        generator = AutoEncoder(
            image_size=config['size'],
            image_channels=config['channels'],
            fc_layers=args.g_fc_lay,
            fc_units=args.g_fc_uni,
            z_dim=args.g_z_dim,
            classes=config['classes'],
            fc_drop=args.fc_drop,
            fc_bn=True if args.fc_bn == "yes" else False,
            fc_nl=args.fc_nl,
        ).to(device)
        # -set optimizer(s)
        generator.optim_list = [{
            'params':
            filter(lambda p: p.requires_grad, generator.parameters()),
            'lr':
            args.lr_gen
        }]
        generator.optim_type = args.optimizer
        if generator.optim_type in ("adam", "adam_reset"):
            generator.optimizer = optim.Adam(generator.optim_list,
                                             betas=(0.9, 0.999))
        elif generator.optim_type == "sgd":
            generator.optimizer = optim.SGD(generator.optim_list)
    else:
        generator = None

    # -------------------------------------------------------------------------------------------------#

    # ---------------------#
    # ----- REPORTING -----#
    # ---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp = get_param_stamp(
        args,
        model.name,
        verbose=verbose,
        replay=True if (not args.replay == "none") else False,
        replay_model_name=generator.name if
        (args.replay == "generative" and not args.feedback) else None,
    )

    # Print some model-characteristics on the screen
    if verbose:
        # -main model
        utils.print_model_info(model, title="MAIN MODEL")
        # -generator
        if generator is not None:
            utils.print_model_info(generator, title="GENERATOR")

    # Prepare for keeping track of statistics required for metrics (also used for plotting in pdf)
    if args.pdf or args.metrics:
        # -define [metrics_dict] to keep track of performance during training for storing & for later plotting in pdf
        metrics_dict = evaluate.initiate_metrics_dict(n_tasks=args.tasks,
                                                      scenario=args.scenario)
        # -evaluate the randomly initialized model on all tasks & store accuracies in [metrics_dict] (for calculating metrics)
        if not args.use_exemplars:
            metrics_dict = evaluate.intial_accuracy(
                model,
                test_datasets,
                metrics_dict,
                classes_per_task=classes_per_task,
                scenario=scenario,
                test_size=None,
                no_task_mask=False)
    else:
        metrics_dict = None

    # Prepare for plotting in visdom
    # -visdom-settings
    if args.visdom:
        env_name = "{exp}{tasks}-{scenario}".format(exp=args.experiment,
                                                    tasks=args.tasks,
                                                    scenario=args.scenario)
        graph_name = "{fb}{replay}{syn}{ewc}{xdg}{icarl}{bud}".format(
            fb="1M-" if args.feedback else "",
            replay="{}{}{}".format(args.replay, "D" if args.distill else "",
                                   "-aGEM" if args.agem else ""),
            syn="-si{}".format(args.si_c) if args.si else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if args.online else "") if args.ewc else "",
            xdg="" if (not args.xdg) or args.gating_prop == 0 else
            "-XdG{}".format(args.gating_prop),
            icarl="-iCaRL" if (args.use_exemplars and args.add_exemplars
                               and args.bce and args.bce_distill) else "",
            bud="-bud{}".format(args.budget) if
            (args.use_exemplars or args.add_exemplars
             or args.replay == "exemplars") else "",
        )
        visdom = {'env': env_name, 'graph': graph_name}
    else:
        visdom = None

    # -------------------------------------------------------------------------------------------------#

    # ---------------------#
    # ----- CALLBACKS -----#
    # ---------------------#

    # Callbacks for reporting on and visualizing loss
    generator_loss_cbs = [
        cb._VAE_loss_cb(
            log=args.loss_log,
            visdom=visdom,
            model=model if args.feedback else generator,
            tasks=args.tasks,
            iters_per_task=args.iters if args.feedback else args.g_iters,
            replay=False if args.replay == "none" else True)
    ] if (train_gen or args.feedback) else [None]
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           tasks=args.tasks,
                           iters_per_task=args.iters,
                           replay=False if args.replay == "none" else True)
    ] if (not args.feedback) else [None]

    # Callbacks for evaluating and plotting generated / reconstructed samples
    sample_cbs = [
        cb._sample_cb(
            log=args.sample_log,
            visdom=visdom,
            config=config,
            test_datasets=test_datasets,
            sample_size=args.sample_n,
            iters_per_task=args.iters if args.feedback else args.g_iters)
    ] if (train_gen or args.feedback) else [None]

    # Callbacks for reporting and visualizing accuracy
    # -visdom (i.e., after each [prec_log] iterations)
    eval_cbs = [
        cb._eval_cb(log=args.prec_log,
                    test_datasets=test_datasets,
                    visdom=visdom,
                    iters_per_task=args.iters,
                    test_size=args.prec_n,
                    classes_per_task=classes_per_task,
                    scenario=scenario,
                    with_exemplars=False)
    ] if (not args.use_exemplars) else [None]
    # --> during training on a task, evaluation cannot be with exemplars as those are only selected after training
    #    (instead, evaluation for visdom is only done after each task, by including callback-function into [metric_cbs])

    # Callbacks for calculating statistics required for metrics
    # -pdf / reporting: summary plots (i.e., only after each task) (when using exemplars, also for visdom)
    metric_cbs = [
        cb._metric_cb(log=args.iters,
                      test_datasets=test_datasets,
                      classes_per_task=classes_per_task,
                      metrics_dict=metrics_dict,
                      scenario=scenario,
                      iters_per_task=args.iters,
                      with_exemplars=args.use_exemplars),
        cb._eval_cb(log=args.iters,
                    test_datasets=test_datasets,
                    visdom=visdom,
                    iters_per_task=args.iters,
                    test_size=args.prec_n,
                    classes_per_task=classes_per_task,
                    scenario=scenario,
                    with_exemplars=True) if args.use_exemplars else None
    ]

    # -------------------------------------------------------------------------------------------------#

    # --------------------#
    # ----- TRAINING -----#
    # --------------------#

    if verbose:
        print("\nTraining...")
    # Keep track of training-time
    start = time.time()
    # Train model
    train_cl(
        model,
        train_datasets,
        replay_mode=args.replay,
        scenario=scenario,
        classes_per_task=classes_per_task,
        iters=args.iters,
        batch_size=args.batch,
        generator=generator,
        gen_iters=args.g_iters,
        gen_loss_cbs=generator_loss_cbs,
        sample_cbs=sample_cbs,
        eval_cbs=eval_cbs,
        loss_cbs=generator_loss_cbs if args.feedback else solver_loss_cbs,
        metric_cbs=metric_cbs,
        use_exemplars=args.use_exemplars,
        add_exemplars=args.add_exemplars,
        param_stamp=param_stamp,
    )
    # Get total training-time in seconds, and write to file
    if args.time:
        training_time = time.time() - start
        time_file = open("{}/time-{}.txt".format(args.r_dir, param_stamp), 'w')
        time_file.write('{}\n'.format(training_time))
        time_file.close()

    # -------------------------------------------------------------------------------------------------#

    # ----------------------#
    # ----- EVALUATION -----#
    # ----------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [
        evaluate.validate(
            model,
            test_datasets[i],
            verbose=False,
            test_size=None,
            task=i + 1,
            with_exemplars=False,
            allowed_classes=list(
                range(classes_per_task * i, classes_per_task *
                      (i + 1))) if scenario == "task" else None)
        for i in range(args.tasks)
    ]
    average_precs = sum(precs) / args.tasks
    # -print on screen
    if verbose:
        print("\n Precision on test-set{}:".format(
            " (softmax classification)" if args.use_exemplars else ""))
        for i in range(args.tasks):
            print(" - Task {} [{}-{}]: {:.4f}".format(
                i + 1, classes_per_task * i,
                classes_per_task * (i + 1) - 1, precs[i]))
        print('=> Average precision over all {} tasks: {:.4f}\n'.format(
            args.tasks, average_precs))

    # -with exemplars
    if args.use_exemplars:
        precs = [
            evaluate.validate(
                model,
                test_datasets[i],
                verbose=False,
                test_size=None,
                task=i + 1,
                with_exemplars=True,
                allowed_classes=list(
                    range(classes_per_task * i, classes_per_task *
                          (i + 1))) if scenario == "task" else None)
            for i in range(args.tasks)
        ]
        average_precs_ex = sum(precs) / args.tasks
        # -print on screen
        if verbose:
            print(" Precision on test-set (classification using exemplars):")
            for i in range(args.tasks):
                print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
            print('=> Average precision over all {} tasks: {:.4f}\n'.format(
                args.tasks, average_precs_ex))

    if args.metrics:
        # Accuracy matrix
        if args.scenario in ('task', 'domain'):
            R = pd.DataFrame(data=metrics_dict['acc per task'],
                             index=[
                                 'after task {}'.format(i + 1)
                                 for i in range(args.tasks)
                             ])
            R.loc['at start'] = metrics_dict['initial acc per task'] if (
                not args.use_exemplars) else ['NA' for _ in range(args.tasks)]
            R = R.reindex(
                ['at start'] +
                ['after task {}'.format(i + 1) for i in range(args.tasks)])
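            # -metric definitions used below, with R[j, i] the accuracy on
            #  task i after training on task j and T the number of tasks:
            #    BWT_i = R[T, i] - R[i, i]            (backward transfer)
            #    FWT_i = R[i-1, i] - R[start, i]      (forward transfer)
            #    F_i   = max_{j<T} R[j, i] - R[T, i]  (forgetting)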
            BWTs = [(R.loc['after task {}'.format(args.tasks), 'task {}'.format(i + 1)] - \
                     R.loc['after task {}'.format(i + 1), 'task {}'.format(i + 1)]) for i in range(args.tasks - 1)]
            FWTs = [
                0. if args.use_exemplars else
                (R.loc['after task {}'.format(i + 1),
                       'task {}'.format(i + 2)] -
                 R.loc['at start', 'task {}'.format(i + 2)])
                for i in range(args.tasks - 1)
            ]
            forgetting = []
            for i in range(args.tasks - 1):
                forgetting.append(
                    max(R.iloc[1:args.tasks, i]) - R.iloc[args.tasks, i])
            R.loc['FWT (per task)'] = ['NA'] + FWTs
            R.loc['BWT (per task)'] = BWTs + ['NA']
            R.loc['F (per task)'] = forgetting + ['NA']
            BWT = sum(BWTs) / (args.tasks - 1)
            F = sum(forgetting) / (args.tasks - 1)
            FWT = sum(FWTs) / (args.tasks - 1)
            metrics_dict['BWT'] = BWT
            metrics_dict['F'] = F
            metrics_dict['FWT'] = FWT
            # -print on screen
            if verbose:
                print("Accuracy matrix")
                print(R)
                print("\nFWT = {:.4f}".format(FWT))
                print("BWT = {:.4f}".format(BWT))
                print("  F = {:.4f}\n\n".format(F))
        else:
            if verbose:
                # Accuracy matrix based only on classes in that task (i.e., evaluation as if Task-IL scenario)
                R = pd.DataFrame(
                    data=metrics_dict['acc per task (only classes in task)'],
                    index=[
                        'after task {}'.format(i + 1)
                        for i in range(args.tasks)
                    ])
                R.loc['at start'] = metrics_dict[
                    'initial acc per task (only classes in task)'] if not args.use_exemplars else [
                        'NA' for _ in range(args.tasks)
                    ]
                R = R.reindex(
                    ['at start'] +
                    ['after task {}'.format(i + 1) for i in range(args.tasks)])
                print(
                    "Accuracy matrix, based on only classes in that task ('as if Task-IL scenario')"
                )
                print(R)

                # Accuracy matrix, always based on all classes
                R = pd.DataFrame(
                    data=metrics_dict['acc per task (all classes)'],
                    index=[
                        'after task {}'.format(i + 1)
                        for i in range(args.tasks)
                    ])
                R.loc['at start'] = metrics_dict[
                    'initial acc per task (only classes in task)'] if not args.use_exemplars else [
                        'NA' for _ in range(args.tasks)
                    ]
                R = R.reindex(
                    ['at start'] +
                    ['after task {}'.format(i + 1) for i in range(args.tasks)])
                print("\nAccuracy matrix, always based on all classes")
                print(R)

                # Accuracy matrix, based on all classes thus far
                R = pd.DataFrame(data=metrics_dict[
                    'acc per task (all classes up to trained task)'],
                                 index=[
                                     'after task {}'.format(i + 1)
                                     for i in range(args.tasks)
                                 ])
                print(
                    "\nAccuracy matrix, based on all classes up to the trained task"
                )
                print(R)

            # Accuracy matrix, based on all classes up to the task being evaluated
            # (this is the accuracy-matrix used for calculating the metrics in the Class-IL scenario)
            R = pd.DataFrame(data=metrics_dict[
                'acc per task (all classes up to evaluated task)'],
                             index=[
                                 'after task {}'.format(i + 1)
                                 for i in range(args.tasks)
                             ])
            R.loc['at start'] = metrics_dict[
                'initial acc per task (only classes in task)'] if not args.use_exemplars else [
                    'NA' for _ in range(args.tasks)
                ]
            R = R.reindex(
                ['at start'] +
                ['after task {}'.format(i + 1) for i in range(args.tasks)])
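            # -same BWT / FWT / F definitions as in the Task-IL / Domain-IL
            #  branch above, but computed on this accuracy matrix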
            BWTs = [(R.loc['after task {}'.format(args.tasks), 'task {}'.format(i + 1)] - \
                     R.loc['after task {}'.format(i + 1), 'task {}'.format(i + 1)]) for i in range(args.tasks - 1)]
            FWTs = [
                0. if args.use_exemplars else
                (R.loc['after task {}'.format(i + 1),
                       'task {}'.format(i + 2)] -
                 R.loc['at start', 'task {}'.format(i + 2)])
                for i in range(args.tasks - 1)
            ]
            forgetting = []
            for i in range(args.tasks - 1):
                forgetting.append(
                    max(R.iloc[1:args.tasks, i]) - R.iloc[args.tasks, i])
            R.loc['FWT (per task)'] = ['NA'] + FWTs
            R.loc['BWT (per task)'] = BWTs + ['NA']
            R.loc['F (per task)'] = forgetting + ['NA']
            BWT = sum(BWTs) / (args.tasks - 1)
            F = sum(forgetting) / (args.tasks - 1)
            FWT = sum(FWTs) / (args.tasks - 1)
            metrics_dict['BWT'] = BWT
            metrics_dict['F'] = F
            metrics_dict['FWT'] = FWT
            # -print on screen
            if verbose:
                print(
                    "\nAccuracy matrix, based on all classes up to the evaluated task"
                )
                print(R)
                print("\n=> FWT = {:.4f}".format(FWT))
                print("=> BWT = {:.4f}".format(BWT))
                print("=>  F = {:.4f}\n".format(F))

    if verbose and args.time:
        print(
            "=> Total training time = {:.1f} seconds\n".format(training_time))

    # -------------------------------------------------------------------------------------------------#

    # ------------------#
    # ----- OUTPUT -----#
    # ------------------#

    # Average precision on full test set
    output_file = open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w')
    output_file.write('{}\n'.format(
        average_precs_ex if args.use_exemplars else average_precs))
    output_file.close()
    # -metrics-dict
    if args.metrics:
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        utils.save_object(metrics_dict, file_name)

    # -------------------------------------------------------------------------------------------------#

    # --------------------#
    # ----- PLOTTING -----#
    # --------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = visual_plt.open_pdf(plot_name)

        # -show samples and reconstructions (either from main model or from separate generator)
        if args.feedback or args.replay == "generative":
            evaluate.show_samples(model if args.feedback else generator,
                                  config,
                                  size=args.sample_n,
                                  pdf=pp)
            for i in range(args.tasks):
                evaluate.show_reconstruction(
                    model if args.feedback else generator,
                    test_datasets[i],
                    config,
                    pdf=pp,
                    task=i + 1)

        # -show metrics reflecting progression during training
        figure_list = []  # -> create list to store all figures to be plotted

        # -generate all figures (and store them in [figure_list])
        key = "acc per task ({} task)".format(
            "all classes up to trained" if scenario ==
            'class' else "only classes in")
        plot_list = []
        for i in range(args.tasks):
            plot_list.append(metrics_dict[key]["task {}".format(i + 1)])
        figure = visual_plt.plot_lines(
            plot_list,
            x_axes=metrics_dict["x_task"],
            line_names=['task {}'.format(i + 1) for i in range(args.tasks)])
        figure_list.append(figure)
        figure = visual_plt.plot_lines([metrics_dict["average"]],
                                       x_axes=metrics_dict["x_task"],
                                       line_names=['average all tasks so far'])
        figure_list.append(figure)

        # -add figures to pdf (and close this pdf).
        for figure in figure_list:
            pp.savefig(figure)

        # -close pdf
        pp.close()

        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
Example n. 7
def train(model, train_datasets, test_datasets, epochs_per_task=10,
          batch_size=64, test_size=1024, consolidate=True,
          fisher_estimation_sample_size=1024,
          lr=1e-3, weight_decay=1e-5, lamda=3,
          loss_log_interval=30,
          eval_log_interval=50,
          cuda=False,
          plot="pdf",
          pdf_file_name=None,
          epsilon=1e-3,
          c=1,
          intelligent_synapses=False):

    # number of tasks
    n_tasks = len(train_datasets)

    # prepare the loss criterion and the optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)

    # register starting param-values (needed for "intelligent synapses").
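    #  (these buffers hold the parameter values from before the current task;
    #   SI's surrogate loss penalizes moving away from them, weighted by the
    #   consolidated importance estimates)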
    if intelligent_synapses:
        for n, p in model.named_parameters():
            n = n.replace('.', '__')
            model.register_buffer('{}_prev_task'.format(n), p.data.clone())

    # if plotting, prepare task names and plot-titles
    if plot != "none":
        names = ['task {}'.format(i + 1) for i in range(n_tasks)]
        title_precision = 'precision (consolidated)' if consolidate else 'precision'
        title_loss = 'loss (consolidated)' if consolidate else 'loss'

    # if plotting in pdf, initiate lists for storing data
    if plot=="pdf":
        all_task_lists = [[] for _ in range(n_tasks)]
        x_list = []
        average_list = []
        all_loss_lists = [[] for _ in range(4)]  # total, ce, ewc, surrogate
        x_loss_list = []

    # training, ..looping over all tasks
    for task, train_dataset in enumerate(train_datasets, 1):

        # if requested, prepare dictionaries to store running importance
        #  estimates and parameter-values before update
        if intelligent_synapses:
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                n = n.replace('.', '__')
                W[n] = p.data.clone().zero_()
                p_old[n] = p.data.clone()

        # ..looping over all epochs
        for epoch in range(1, epochs_per_task+1):

            # prepare data-loader, and wrap in "tqdm"-object.
            data_loader = utils.get_data_loader(
                train_dataset, batch_size=batch_size, cuda=cuda
            )
            data_stream = tqdm(enumerate(data_loader, 1))

            # ..looping over all batches
            for batch_index, (x, y) in data_stream:

                # where are we?
                data_size = len(x)
                dataset_size = len(data_loader.dataset)
                dataset_batches = len(data_loader)
                previous_task_iteration = sum([
                    epochs_per_task * len(d) // batch_size for d in
                    train_datasets[:task-1]
                ])
                current_task_iteration = (epoch-1)*dataset_batches + batch_index
                iteration = previous_task_iteration + current_task_iteration

                # prepare the data.
                x = x.view(data_size, -1)
                x = x.cuda() if cuda else x
                y = y.cuda() if cuda else y

                # run model, backpropagate errors, update parameters.
                model.train()
                optimizer.zero_grad()
                scores = model(x)
                ce_loss = criterion(scores, y)
                ewc_loss = model.ewc_loss(lamda, cuda=cuda)
                surrogate_loss = model.surrogate_loss(c, cuda=cuda)
                loss = ce_loss + ewc_loss + surrogate_loss
                loss.backward()
                optimizer.step()

                # if requested, update importance estimates
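                #  (SI's running path integral: each parameter's contribution
                #   to the change in loss, approximated as -gradient * update)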
                if intelligent_synapses:
                    for n, p in model.named_parameters():
                        n = n.replace('.', '__')
                        W[n].add_(-p.grad.data*(p.data-p_old[n]))
                        p_old[n] = p.data.clone()

                # calculate the training precision.
                _, predicted = scores.max(1)
                precision = (predicted == y).sum().item() / len(x)

                # print progress to the screen using "tqdm"
                data_stream.set_description((
                    'task: {task}/{tasks} | '
                    'epoch: {epoch}/{epochs} | '
                    'progress: [{trained}/{total}] ({progress:.0f}%) | '
                    'prec: {prec:.4} | '
                    'loss => '
                    'ce: {ce_loss:.4} / '
                    'ewc: {ewc_loss:.4} / '
                    'total: {loss:.4}'
                ).format(
                    task=task,
                    tasks=n_tasks,
                    epoch=epoch,
                    epochs=epochs_per_task,
                    trained=batch_index*batch_size,
                    total=dataset_size,
                    progress=(100.*batch_index/dataset_batches),
                    prec=precision,
                    ce_loss=ce_loss.item(),
                    ewc_loss=ewc_loss.item(),
                    loss=loss.item(),
                ))

                # Send test precision to the visdom server,
                #  or store for later plotting to pdf.
                if plot != "none":
                    if iteration % eval_log_interval == 0:
                        precs = [
                            utils.validate(
                                model, test_datasets[i], test_size=test_size,
                                cuda=cuda, verbose=False,
                            ) if i+1 <= task else 0 for i in range(n_tasks)
                        ]
                        if plot=="visdom":
                            visual_visdom.visualize_scalars(
                                precs, names, title_precision,
                                iteration, env=model.name,
                            )
                            visual_visdom.visualize_scalars(
                                [sum([precs[task_id] for task_id in range(task)]) / task],
                                ["average precision"], title_precision+" (ave)",
                                iteration, env=model.name,
                            )
                        elif plot=="pdf":
                            for task_id, _ in enumerate(names):
                                all_task_lists[task_id].append(precs[task_id])
                            average_list.append(sum([precs[task_id] for task_id in range(task)])/task)
                            x_list.append(iteration)

                # Send losses to the visdom server,
                #  or store for later plotting to pdf.
                if plot != "none":
                    if iteration % loss_log_interval == 0:
                        if plot=="visdom":
                            visual_visdom.visualize_scalars(
                                [loss.data, ce_loss.data, ewc_loss.data, surrogate_loss.data],
                                ['total', 'cross entropy', 'ewc', 'surrogate loss'],
                                title_loss, iteration, env=model.name
                            )
                        elif plot=="pdf":
                            all_loss_lists[0].append(loss.data.cpu().numpy()[0])
                            all_loss_lists[1].append(ce_loss.data.cpu().numpy()[0])
                            all_loss_lists[2].append(ewc_loss.data.cpu().numpy()[0])
                            all_loss_lists[3].append(surrogate_loss.data.cpu().numpy()[0])
                            x_loss_list.append(iteration)

        if consolidate:
            # take [fisher_estimation_sample_size] random samples from the last task learned
            sample_ids = random.sample(range(len(train_dataset)), fisher_estimation_sample_size)
            selected_samples = [train_dataset[id] for id in sample_ids]
            # estimate the Fisher Information matrix and consolidate it in the network
            model.estimate_fisher(selected_samples)

        if intelligent_synapses:
            # update & consolidate normalized path integral in the network
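            #  (i.e., per parameter: omega += W / (squared parameter change
            #   over the task + epsilon), following Zenke et al., 2017)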
            model.update_omega(W, epsilon)

    # if requested, generate pdf.
    if plot=="pdf":
        # create list to store all figures to be plotted.
        figure_list = []

        # Fig1: precision
        figure = visual_plt.plot_lines(
            all_task_lists, x_axes=x_list, line_names=names
        )
        figure_list.append(figure)

        # Fig2: loss
        figure = visual_plt.plot_lines(
            all_loss_lists, x_axes=x_loss_list,
            line_names=['total', 'cross entropy', 'ewc', 'surrogate loss']
        )
        figure_list.append(figure)

        # create pdf containing all figures.
        pdf = PdfPages(pdf_file_name)
        for figure in figure_list:
            pdf.savefig(figure)
        pdf.close()
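
A minimal, self-contained sketch of the Synaptic Intelligence bookkeeping performed inline above: the running importance W accumulates -gradient * parameter-update during a task and is normalized into omega at the task boundary. The toy model, data, and variable names here are illustrative assumptions, not part of the snippet above.

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# running contribution to the loss change, plus parameter values at the
# previous update and at the start of the task
W = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
p_old = {n: p.detach().clone() for n, p in model.named_parameters()}
p_start = {n: p.detach().clone() for n, p in model.named_parameters()}

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    for n, p in model.named_parameters():
        # same update as in the training loop above: -grad * step
        W[n] += -p.grad.detach() * (p.detach() - p_old[n])
        p_old[n] = p.detach().clone()

# at the end of the task: consolidate the normalized path integral
epsilon = 1e-3
omega = {n: W[n] / ((p.detach() - p_start[n]).pow(2) + epsilon)
         for n, p in model.named_parameters()}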