def _class_variety_mask(sample_amt, class_num, device):
    '''Build the [sample_amt]x[sample_amt] float mask used by the "curated_classVariety" sample-method.

    Entry (r, c) is 1. if r % [class_num] == c % [class_num] and 0. otherwise, i.e. every candidate-index is
    linked to all candidate-indices that share its position modulo [class_num]. Built with broadcasting
    (identical result to the former nested Python loops, without their O([sample_amt]^2) interpreted work).
    '''
    residues = torch.arange(sample_amt) % class_num
    return (residues.unsqueeze(1) == residues.unsqueeze(0)).to(dtype=torch.float, device=device)


def _softmax_class_probs(previous_model, x, n_old_classes, not_hidden):
    '''Return class-sampling probabilities for the "softmax"-style sample-methods.

    [previous_model] classifies the current-task batch [x]; the softmax over its first [n_old_classes] outputs is
    averaged over the batch. Those averages are returned in a (CPU) vector as long as the model's output layer,
    with all remaining entries 0.
    '''
    with torch.no_grad():
        all_scores = previous_model.classify(previous_model.input_to_hidden(x), not_hidden=not_hidden)
        old_class_probs = nn.Softmax(dim=1)(all_scores[:, :n_old_classes])
        class_probs = torch.zeros(all_scores.shape[1])
        class_probs[:n_old_classes] = torch.mean(old_class_probs, dim=0)[:n_old_classes]
    return class_probs


def _update_si_running_importance(model, W, p_old):
    '''Add the SI path-integral contribution (-grad * parameter-change) of the last update to [W].

    Visits the parameters of [model.convE], [model.fcE] and [model.classifier] (in that order), using the
    "<module>__<param>"-style names that key [W] and [p_old]; [p_old] is refreshed in place to the current values.
    '''
    for module_name in ("convE", "fcE", "classifier"):
        for n, p in getattr(model, module_name).named_parameters():
            if p.requires_grad:
                n = "{}.{}".format(module_name, n).replace('.', '__')
                if p.grad is not None:
                    W[n].add_(-p.grad * (p.detach() - p_old[n]))
                p_old[n] = p.detach().clone()


def _balanced_top_indices(indices, labels, n_classes, n_total):
    '''Select [n_total] entries of [indices] (given in priority order), balanced over [n_classes] classes.

    Class i gets n_total // n_classes entries, plus one extra for the first n_total % n_classes classes
    (e.g. n_total=7, n_classes=4 -> [2, 2, 2, 1]); within a class, the earliest entries of [indices] are kept.
    [labels] maps a candidate-index to its class. Classes can receive 0 entries when n_total < n_classes.
    '''
    base, extra = divmod(n_total, n_classes)
    per_class = [indices[labels[indices] == i][:base + (1 if i < extra else 0)] for i in range(n_classes)]
    return torch.cat(per_class)


def train_cl(model, train_datasets, replay_mode="none", scenario="task", rnt=None, classes_per_task=None,
             iters=2000, batch_size=32, batch_size_replay=None, loss_cbs=list(), eval_cbs=list(), sample_cbs=list(),
             generator=None, gen_iters=0, gen_loss_cbs=list(), feedback=False, reinit=False, args=None, only_last=False,
             sample_method='random', curated_multiplier=4):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "generative", "current", "offline" and "none"
    [scenario]          <str>, choice from "task", "domain", "class" and "all"
    [rnt]               <float>, indicating relative importance of new task (if None, relative to # tasks so far)
    [classes_per_task]  <int>, # classes per task
    [iters]             <int>, # optimization-steps (=batches) per task for the main model
    [batch_size]        <int>, # current-task samples per batch
    [batch_size_replay] <int>, number of samples to replay per batch (if None, equal to [batch_size])
    [generator]         None or <nn.Module>, if a separate generative model should be trained (for [gen_iters] per task)
    [feedback]          <bool>, if True and [replay_mode]="generative", the main model is used for generating replay
    [reinit]            <bool>, if True, re-initialize [model] (and [generator]) with [init_params] before each task
    [args]              settings-object; read for "freeze_convE"/"depth" and passed to [init_params] when [reinit]
    [only_last]         <bool>, only train on final task / episode
    [*_cbs]             <list> of call-back functions to evaluate training-progress
    [sample_method]     <str> how generative replay is sampled (not used when a conditional generator is used in the
                            "task"-scenario). Direct sampling: 'random', 'uniform', 'softmax'. Sampling a pool of
                            [curated_multiplier]*[batch_size_replay] candidates and selecting [batch_size_replay]
                            of them: 'curated', 'curated_softmax', 'curated_variety', 'curated_classVariety',
                            'interfered', 'misclassified', 'uniform_large', 'random_large'
    [curated_multiplier]<int> size of the candidate pool for the curated sample-methods, as a multiple of
                            [batch_size_replay]

    Raises ValueError if generative replay reaches the curated branch with an unrecognized [sample_method].
    '''

    # Should convolutional layers be frozen?
    freeze_convE = (utils.checkattr(args, "freeze_convE") and hasattr(args, "depth") and args.depth>0)

    # Use cuda?
    device = model._device()
    cuda = model._is_on_cuda()

    # Set default-values if not specified
    batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay

    # Initiate indicators for replay (no replay for 1st task)
    Generative = Current = Offline_TaskIL = False
    previous_model = None

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and model.si_c>0:
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n), p.detach().clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):

        # Settings that stay fixed within this task ([Generative] is only switched on after a task has finished)
        # -replayed inputs [x_] are at the "hidden"-level only with generative replay
        replay_not_hidden = not Generative
        # -relative importance of the new task
        rnt_to_use = (1. if task==1 else 1./task) if rnt is None else rnt

        # If offline replay-setting, create large database of all tasks so far
        if replay_mode=="offline" and scenario!="task":
            train_dataset = ConcatDataset(train_datasets[:task])
        # -but if "offline"+"task": all tasks so far should be visited separately (i.e., separate data-loader per task)
        if replay_mode=="offline" and scenario=="task":
            Offline_TaskIL = True
            data_loader = [None]*task

        # Initialize # iters left on data-loader(s)
        iters_left = [1]*task if Offline_TaskIL else 1

        # Prepare <dicts> to store running importance estimates and parameter-values before update
        if isinstance(model, ContinualLearner) and model.si_c>0:
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Find [active_classes] (=classes in current task)
        active_classes = None  #-> for "domain"- or "all"-scenarios, always all classes are active
        if scenario=="task":
            # -for "task"-scenario, create <list> with for all tasks so far a <list> with the active classes
            active_classes = [list(range(classes_per_task*i, classes_per_task*(i+1))) for i in range(task)]
        elif scenario=="class":
            # -for "class"-scenario, create one <list> with active classes of all tasks so far
            active_classes = list(range(classes_per_task*task))

        # Reinitialize the model's parameters (if requested)
        if reinit:
            from define_models import init_params
            init_params(model, args)
            if generator is not None:
                init_params(generator, args)

        # Define a tqdm progress bar(s)
        iters_main = iters
        progress = tqdm.tqdm(range(1, iters_main+1))
        if generator is not None:
            iters_gen = gen_iters
            progress_gen = tqdm.tqdm(range(1, iters_gen+1))

        # Loop over all iterations
        iters_to_use = (iters_main if (generator is None) else max(iters_main, iters_gen))
        # -if only the final task should be trained on:
        if only_last and task!=len(train_datasets):
            iters_to_use = 0

        # Pre-compute the "curated_classVariety" mask once per task (it only depends on pool-size and # old classes)
        mask = None
        if sample_method=="curated_classVariety" and task>1:
            mask = _class_variety_mask(batch_size_replay*curated_multiplier, classes_per_task*(task-1), device)

        for batch_index in range(1, iters_to_use+1):

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            if not Offline_TaskIL:
                iters_left -= 1
                if iters_left==0:
                    data_loader = iter(utils.get_data_loader(train_dataset, batch_size, cuda=cuda, drop_last=True))
                    iters_left = len(data_loader)
            else:
                # -with "offline replay" in Task-IL scenario, there is a separate data-loader for each task
                batch_size_to_use = int(np.ceil(batch_size/task))
                for task_id in range(task):
                    iters_left[task_id] -= 1
                    if iters_left[task_id]==0:
                        data_loader[task_id] = iter(utils.get_data_loader(
                            train_datasets[task_id], batch_size_to_use, cuda=cuda, drop_last=True
                        ))
                        iters_left[task_id] = len(data_loader[task_id])

            #-----------------Collect data------------------#

            #####-----CURRENT BATCH-----#####
            if not Offline_TaskIL:
                x, y = next(data_loader)                                    #--> sample training data of current task
                y = y-classes_per_task*(task-1) if scenario=="task" else y  #--> ITL: adjust y-targets to 'active range'
                x, y = x.to(device), y.to(device)                           #--> transfer them to correct device
            else:
                x = y = task_used = None  #--> all tasks are "treated as replay"
                # NOTE(review): [scores_] is not assigned on this path, so with model.replay_targets=="soft" the
                #               "only keep predicted y_/scores_" step below reads it before assignment -- confirm
                #               that offline replay is only used with "hard" replay-targets.
                # -sample training data for all tasks so far, move to correct device and store in lists
                x_, y_ = list(), list()
                for task_id in range(task):
                    x_temp, y_temp = next(data_loader[task_id])
                    x_.append(x_temp.to(device))
                    y_temp = y_temp - (classes_per_task * task_id) #--> adjust y-targets to 'active range'
                    if batch_size_to_use == 1:
                        y_temp = torch.tensor([y_temp])            #--> correct dimensions if batch-size is 1
                    y_.append(y_temp.to(device))

            #####-----REPLAYED BATCH-----#####
            if not Offline_TaskIL and not Generative and not Current:
                x_ = y_ = scores_ = task_used = None   #-> if no replay

            #--------------------------------------------INPUTS----------------------------------------------------#

            ##-->> Current Replay <<--##
            if Current:
                x_ = x[:batch_size_replay]  #--> use current task inputs
                task_used = None

            ##-->> Generative Replay <<--##
            if Generative:
                #---> Only with generative replay, the resulting [x_] will be at the "hidden"-level
                conditional_gen = bool(
                    (previous_generator.per_class and previous_generator.prior=="GMM") or
                    utils.checkattr(previous_generator, 'dg_gates')
                )

                # Sample [x_]
                if conditional_gen and scenario=="task":
                    # -if a conditional generator is used with task-IL scenario, generate data per previous task
                    x_ = list()
                    task_used = list()
                    batch_size_replay_to_use = int(np.ceil(batch_size_replay / (task-1)))
                    for task_id in range(task-1):
                        allowed_classes = list(range(classes_per_task*task_id, classes_per_task*(task_id+1)))
                        x_temp_ = previous_generator.sample(batch_size_replay_to_use, allowed_classes=allowed_classes,
                                                            only_x=False)
                        x_.append(x_temp_[0])
                        task_used.append(x_temp_[2])
                else:
                    # -which classes are allowed to be generated? (relevant if conditional generator / decoder-gates)
                    allowed_classes = None if scenario=="domain" else list(range(classes_per_task*(task-1)))
                    # -which tasks/domains are allowed to be generated? (only relevant if "Domain-IL" with task-gates)
                    allowed_domains = list(range(task-1))
                    # -number of output units belonging to the previous tasks
                    n_old_classes = classes_per_task*(task-1)

                    # --- Softmax sampling: use previous model to score images from this new task, generate those classes
                    if sample_method == 'softmax':
                        with torch.no_grad():
                            sampleProbs = _softmax_class_probs(previous_model, x, n_old_classes, replay_not_hidden)
                            x_, y_used, task_used = previous_generator.sample(
                                batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                                only_x=False, class_probs=sampleProbs, uniform_sampling=False)

                    # --- Uniformly random sampling (baseline) ---
                    elif sample_method == 'random':
                        x_, y_used, task_used = previous_generator.sample(
                            batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                            only_x=False, class_probs=None, uniform_sampling=False)

                    # --- Uniform sampling: balanced numbers of samples from each class ---
                    elif sample_method == 'uniform':
                        x_, y_used, task_used = previous_generator.sample(
                            batch_size_replay, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                            only_x=False, class_probs=None, uniform_sampling=True)

                    # --- Sample curation: generate a larger pool, score it, replay the best [batch_size_replay] ---
                    else:
                        n_candidates = batch_size_replay * curated_multiplier

                        if sample_method == "curated_variety":
                            x_, y_used, task_used, varietyVector = previous_generator.sample(
                                n_candidates, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                                only_x=False, class_probs=None, uniform_sampling=False, varietyVector=True)

                        # Curated using class variety: per class, the candidates that are "most different" according
                        # to the generator's variety calculation
                        elif sample_method == "curated_classVariety":
                            x_, y_used, task_used, varietyVector = previous_generator.sample(
                                n_candidates, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                                only_x=False, class_probs=None, uniform_sampling=True, varietyVector=True,
                                classVariety=True, classVarietyMask=mask)

                        elif sample_method == "curated_softmax":
                            sampleProbs = _softmax_class_probs(previous_model, x, n_old_classes, replay_not_hidden)
                            x_, y_used, task_used = previous_generator.sample(
                                n_candidates, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                                only_x=False, class_probs=sampleProbs, uniform_sampling=False)

                        else:
                            x_, y_used, task_used = previous_generator.sample(
                                n_candidates, allowed_classes=allowed_classes, allowed_domains=allowed_domains,
                                only_x=False, class_probs=None, uniform_sampling=False)

                        y_used = torch.as_tensor(y_used, dtype=torch.long, device=device)
                        cross_entropy = nn.CrossEntropyLoss(reduction='none')  # per-candidate loss (not averaged)

                        # --- Score the candidates with the current model (before the update below) ---
                        # NOTE: [nn.CrossEntropyLoss] applies log-softmax itself, so it receives the raw logits
                        #       (feeding it softmax-probabilities would apply the softmax twice)
                        with torch.no_grad():
                            scores_before = model.classify(x_, not_hidden=replay_not_hidden)[:, :n_old_classes]
                            probs_before = nn.Softmax(dim=1)(scores_before)
                            ce_before = cross_entropy(scores_before, y_used)

                        # --- Copy the model and update it on only the new incoming data (no replayed data) ---
                        # This copy forgets freely; how much each candidate suffers from it guides the selection
                        model_tmp = copy.deepcopy(model)
                        model_tmp.train_a_batch(x, y=y, x_=None, y_=None, scores_=None, tasks_=task_used,
                                                active_classes=active_classes, task=task, rnt=rnt_to_use,
                                                freeze_convE=freeze_convE, replay_not_hidden=replay_not_hidden)

                        # --- Score the candidates again with the updated copy (old + current classes) ---
                        with torch.no_grad():
                            scores_after = model_tmp.classify(x_, not_hidden=replay_not_hidden)[:, :classes_per_task*task]
                            probs_after = nn.Softmax(dim=1)(scores_after)

                            if sample_method in ('curated', 'curated_softmax'):
                                # Increase in per-candidate loss caused by the update
                                metric = cross_entropy(scores_after, y_used) - ce_before

                            elif sample_method in ('curated_variety', 'curated_classVariety'):
                                # Mix of the (softmax-normalized) loss increase and the generator's variety-score
                                diff = cross_entropy(scores_after, y_used) - ce_before
                                varietyWeight = 0.5
                                metric = ((1-varietyWeight) * nn.Softmax(dim=0)(diff)) + \
                                         (varietyWeight * nn.Softmax(dim=0)(varietyVector))

                            elif sample_method == 'interfered':
                                # KL divergence between predictions before and after the update (MIR-inspired);
                                # the cross-entropy term is a placeholder with weight 0.
                                # NOTE(review): [probs_before] covers n_old_classes outputs while [probs_after]
                                #               covers classes_per_task*task, and [nn.KLDivLoss] expects
                                #               log-probabilities as input; this method is known to hit shape
                                #               problems -- revisit before relying on it.
                                KLDiv = nn.KLDivLoss(reduction='none')(probs_before, probs_after)
                                metric = KLDiv.T - 0 * ce_before

                            elif sample_method in ('misclassified', 'uniform_large', 'random_large'):
                                # Probability the updated copy assigns to the current task's classes (the last
                                # [classes_per_task] outputs), i.e. how strongly a candidate is confused with them
                                metric = probs_after[:, -classes_per_task:].sum(dim=1)

                            else:
                                raise ValueError("Unknown sample_method: {}".format(sample_method))

                            # --- Sort based on the metric (descending), then select ---
                            _, indices = torch.sort(metric, descending=True)

                            # The "_large" baselines choose randomly from the larger pool instead
                            if sample_method in ('uniform_large', 'random_large'):
                                shuffled = indices.cpu().numpy()
                                np.random.shuffle(shuffled)
                                indices = torch.from_numpy(shuffled).to(device)

                            if sample_method not in ('random_large', 'curated_softmax'):
                                # Top candidates per class, balanced over the allowed classes
                                indices_to_replay = _balanced_top_indices(indices, y_used, len(allowed_classes),
                                                                          batch_size_replay)
                            else:
                                # Unbalanced: simply the first [batch_size_replay] candidates
                                indices_to_replay = indices[:batch_size_replay]
                            # NOTE(review): [task_used] is not subset alongside [x_]; confirm it is None here
                            x_ = x_[indices_to_replay]

            #--------------------------------------------OUTPUTS----------------------------------------------------#

            if Generative or Current:
                # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad()
                if scenario in ("domain", "class") and previous_model.mask_dict is None:
                    # -if replay does not need to be evaluated for each task (ie, not Task-IL and no task-specific mask)
                    with torch.no_grad():
                        all_scores_ = previous_model.classify(x_, not_hidden=replay_not_hidden)
                    scores_ = all_scores_[:, :(classes_per_task*(task-1))] if (
                            scenario=="class"
                    ) else all_scores_ # -> when scenario=="class", zero probs will be added in [loss_fn_kd]-function
                    # -also get the 'hard target'
                    _, y_ = torch.max(scores_, dim=1)
                else:
                    # -[x_] needs to be evaluated according to each previous task, so make list with entry per task
                    scores_ = list()
                    y_ = list()
                    # -if no task-mask and no conditional generator, all scores can be calculated in one go
                    if previous_model.mask_dict is None and not isinstance(x_, list):
                        with torch.no_grad():
                            all_scores_ = previous_model.classify(x_, not_hidden=replay_not_hidden)
                    for task_id in range(task-1):
                        # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately
                        if previous_model.mask_dict is not None:
                            previous_model.apply_XdGmask(task=task_id+1)
                        if previous_model.mask_dict is not None or isinstance(x_, list):
                            with torch.no_grad():
                                all_scores_ = previous_model.classify(x_[task_id] if isinstance(x_, list) else x_,
                                                                      not_hidden=replay_not_hidden)
                        if scenario=="domain":
                            # NOTE: if scenario=domain with task-mask, it's of course actually the Task-IL scenario!
                            #       this can be used as trick to run the Task-IL scenario with singlehead output layer
                            temp_scores_ = all_scores_
                        else:
                            temp_scores_ = all_scores_[:, (classes_per_task*task_id):(classes_per_task*(task_id+1))]
                        scores_.append(temp_scores_)
                        # - also get hard target
                        _, temp_y_ = torch.max(temp_scores_, dim=1)
                        y_.append(temp_y_)
            # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done)
            y_ = y_ if (model.replay_targets=="hard") else None
            scores_ = scores_ if (model.replay_targets=="soft") else None

            #-----------------Train model(s)------------------#

            #---> Train MAIN MODEL
            if batch_index <= iters_main:

                # Train the main model with this batch
                loss_dict = model.train_a_batch(x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used,
                                                active_classes=active_classes, task=task, rnt=rnt_to_use,
                                                freeze_convE=freeze_convE, replay_not_hidden=replay_not_hidden)

                # Update running parameter importance estimates in W
                if isinstance(model, ContinualLearner) and model.si_c>0:
                    _update_si_running_importance(model, W, p_old)

                # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
                for loss_cb in loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress, batch_index, loss_dict, task=task)
                for eval_cb in eval_cbs:
                    if eval_cb is not None:
                        eval_cb(model, batch_index, task=task)
                if model.label=="VAE":
                    for sample_cb in sample_cbs:
                        if sample_cb is not None:
                            sample_cb(model, batch_index, task=task, allowed_classes=None if (
                                    scenario=="domain"
                            ) else list(range(classes_per_task*task)))

            #---> Train GENERATOR
            if generator is not None and batch_index <= iters_gen:

                loss_dict = generator.train_a_batch(x, y=y, x_=x_, y_=y_, scores_=scores_, tasks_=task_used,
                                                    active_classes=active_classes, rnt=rnt_to_use, task=task,
                                                    freeze_convE=freeze_convE, replay_not_hidden=replay_not_hidden)

                # Fire callbacks on each iteration
                for loss_cb in gen_loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress_gen, batch_index, loss_dict, task=task)
                for sample_cb in sample_cbs:
                    if sample_cb is not None:
                        sample_cb(generator, batch_index, task=task, allowed_classes=None if (
                                    scenario=="domain"
                            ) else list(range(classes_per_task*task)))

        # Close progress-bar(s)
        progress.close()
        if generator is not None:
            progress_gen.close()

        ##----------> UPON FINISHING EACH TASK...

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and model.ewc_lambda>0:
            # -find allowed classes
            allowed_classes = list(
                range(classes_per_task*(task-1), classes_per_task*task)
            ) if scenario=="task" else (list(range(classes_per_task*task)) if scenario=="class" else None)
            # -if needed, apply correct task-specific mask
            if model.mask_dict is not None:
                model.apply_XdGmask(task=task)
            # -estimate FI-matrix
            model.estimate_fisher(train_dataset, allowed_classes=allowed_classes)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and model.si_c>0:
            model.update_omega(W, model.epsilon)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model).eval()
        if replay_mode=="generative":
            Generative = True
            previous_generator = previous_model if feedback else copy.deepcopy(generator).eval()
        elif replay_mode=='current':
            Current = True
# Example #2
# 0
def run(args, verbose=False):
    '''Run one complete continual-learning experiment as configured by [args].

    Steps: prepare the data, define the main model (and, if needed, a separate generator), set the requested
    CL-strategies and callbacks, train (or load previously trained models), evaluate the classifier on every
    test-set, evaluate the generator (only for CIFAR100 in the "class"-scenario) and optionally create a pdf.

    [args]      object holding all experiment options as attributes (presumably an <argparse.Namespace>), e.g.
                [experiment], [scenario], [tasks], [iters], [batch], [r_dir], [p_dir], [m_dir], [train], [pdf]
    [verbose]   <bool>, if True, progress and results are printed on screen

    Results are written as text files into [args.r_dir]. If [args.get_stamp] is set, only the parameter-stamp
    is printed and the process exits (via SystemExit).
    '''

    # Create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only want param-stamp, get it and exit
    if args.get_stamp:
        import sys
        from param_stamp import get_param_stamp_from_args
        print(get_param_stamp_from_args(args=args))
        # -sys.exit() instead of the site-provided exit(), which is meant for interactive use only
        sys.exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Report whether cuda is used
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets,
     test_datasets), config, classes_per_task = get_multitask_experiment(
         name=args.experiment,
         scenario=args.scenario,
         tasks=args.tasks,
         data_dir=args.d_dir,
         normalize=True if utils.checkattr(args, "normalize") else False,
         augment=True if utils.checkattr(args, "augment") else False,
         verbose=verbose,
         exception=True if args.seed < 10 else False,
         only_test=(not args.train))

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- MAIN MODEL -----#
    #----------------------#

    # Define main model (i.e., classifier, if requested with feedback connections)
    if verbose and (utils.checkattr(args, "pre_convE") or utils.checkattr(args, "pre_convD")) and \
            (hasattr(args, "depth") and args.depth>0):
        print("\nDefining the model...")
    if utils.checkattr(args, 'feedback'):
        model = define.define_autoencoder(args=args,
                                          config=config,
                                          device=device)
    else:
        model = define.define_classifier(args=args,
                                         config=config,
                                         device=device)

    # Initialize / use pre-trained / freeze model-parameters
    # - initialize (pre-trained) parameters
    model = define.init_params(model, args)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "freeze_convE"):
        for param in model.convE.parameters():
            param.requires_grad = False
    if utils.checkattr(args, 'feedback') and utils.checkattr(
            args, "freeze_convD"):
        for param in model.convD.parameters():
            param.requires_grad = False

    # Define optimizer (only optimize parameters that "requires_grad")
    model.optim_list = [
        {
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': args.lr
        },
    ]
    model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------------------------#
    #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----#
    #----------------------------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        model.fisher_n = args.fisher_n
        model.online = utils.checkattr(args, 'online')
        if model.online:
            model.gamma = args.gamma

    # Synpatic Intelligence (SI)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'):
        model.si_c = args.si_c if args.si else 0
        model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and utils.checkattr(
            args, 'xdg') and args.xdg_prop > 0:
        model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks)

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, ContinualLearner) and hasattr(
            args, 'replay') and not args.replay == "none":
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp

    # If needed, specify separate model for the generator
    train_gen = (hasattr(args, 'replay') and args.replay == "generative"
                 and not utils.checkattr(args, 'feedback'))
    if train_gen:
        # Specify architecture
        generator = define.define_autoencoder(args,
                                              config,
                                              device,
                                              generator=True)

        # Initialize parameters
        generator = define.init_params(generator, args)
        # -freeze weights of conv-layers?
        if utils.checkattr(args, "freeze_convE"):
            for param in generator.convE.parameters():
                param.requires_grad = False
        if utils.checkattr(args, "freeze_convD"):
            for param in generator.convD.parameters():
                param.requires_grad = False

        # Set optimizer(s)
        generator.optim_list = [
            {
                'params': filter(lambda p: p.requires_grad,
                                 generator.parameters()),
                'lr': args.lr_gen if hasattr(args, 'lr_gen') else args.lr
            },
        ]
        generator.optimizer = optim.Adam(generator.optim_list,
                                         betas=(0.9, 0.999))
    else:
        generator = None

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp = get_param_stamp(
        args,
        model.name,
        verbose=verbose,
        replay=True if
        (hasattr(args, 'replay') and not args.replay == "none") else False,
        # -a separate generator exists exactly when [train_gen] is True (the previous check
        #  `args.replay in ("generative")` was a substring test on a plain string, not a tuple membership)
        replay_model_name=generator.name if train_gen else None,
    )

    # Print some model-characteristics on the screen
    if verbose:
        # -main model
        utils.print_model_info(model, title="MAIN MODEL")
        # -generator
        if generator is not None:
            utils.print_model_info(generator, title="GENERATOR")

    # Define [progress_dicts] to keep track of performance during training for storing and for later plotting in pdf
    precision_dict = evaluate.initiate_precision_dict(args.tasks)

    # Prepare for plotting in visdom
    visdom = None
    if args.visdom:
        env_name = "{exp}{tasks}-{scenario}".format(exp=args.experiment,
                                                    tasks=args.tasks,
                                                    scenario=args.scenario)
        replay_statement = "{mode}{fb}{con}{gat}{int}{dis}{b}{u}".format(
            mode=args.replay,
            fb="Rtf" if utils.checkattr(args, "feedback") else "",
            con="Con" if (hasattr(args, "prior") and args.prior == "GMM"
                          and utils.checkattr(args, "per_class")) else "",
            gat="Gat{}".format(args.dg_prop) if
            (utils.checkattr(args, "dg_gates") and hasattr(args, "dg_prop")
             and args.dg_prop > 0) else "",
            int="Int" if utils.checkattr(args, "hidden") else "",
            dis="Dis" if args.replay == "generative" and args.distill else "",
            b="" if
            (args.batch_replay is None or args.batch_replay == args.batch) else
            "-br{}".format(args.batch_replay),
            u="" if args.g_fc_uni == args.fc_units else "-gu{}".format(
                args.g_fc_uni)) if (hasattr(args, "replay")
                                    and not args.replay == "none") else "NR"
        graph_name = "{replay}{syn}{ewc}{xdg}".format(
            replay=replay_statement,
            syn="-si{}".format(args.si_c)
            if utils.checkattr(args, 'si') else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if utils.checkattr(args, "online") else "") if utils.checkattr(
                    args, 'ewc') else "",
            xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0
            else "-XdG{}".format(args.xdg_prop),
        )
        visdom = {'env': env_name, 'graph': graph_name}

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    g_iters = args.g_iters if hasattr(args, 'g_iters') else args.iters

    # Callbacks for reporting on and visualizing loss
    generator_loss_cbs = [
        cb._VAE_loss_cb(
            log=args.loss_log,
            visdom=visdom,
            replay=(hasattr(args, "replay") and not args.replay == "none"),
            model=model if utils.checkattr(args, 'feedback') else generator,
            tasks=args.tasks,
            iters_per_task=args.iters
            if utils.checkattr(args, 'feedback') else g_iters)
    ] if (train_gen or utils.checkattr(args, 'feedback')) else [None]
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           iters_per_task=args.iters,
                           tasks=args.tasks,
                           replay=(hasattr(args, "replay")
                                   and not args.replay == "none"))
    ] if (not utils.checkattr(args, 'feedback')) else [None]

    # Callbacks for evaluating and plotting generated / reconstructed samples
    no_samples = (utils.checkattr(args, "no_samples")
                  or (utils.checkattr(args, "hidden")
                      and hasattr(args, 'depth') and args.depth > 0))
    sample_cbs = [
        cb._sample_cb(log=args.sample_log,
                      visdom=visdom,
                      config=config,
                      test_datasets=test_datasets,
                      sample_size=args.sample_n,
                      iters_per_task=g_iters)
    ] if ((train_gen or utils.checkattr(args, 'feedback'))
          and not no_samples) else [None]

    # Callbacks for reporting and visualizing accuracy, and visualizing representation extracted by main model
    # -visdom (i.e., after each [prec_log]
    eval_cb = cb._eval_cb(
        log=args.prec_log,
        test_datasets=test_datasets,
        visdom=visdom,
        precision_dict=None,
        iters_per_task=args.iters,
        test_size=args.prec_n,
        classes_per_task=classes_per_task,
        scenario=args.scenario,
    )
    # -pdf / reporting: summary plots (i.e, only after each task)
    eval_cb_full = cb._eval_cb(
        log=args.iters,
        test_datasets=test_datasets,
        precision_dict=precision_dict,
        iters_per_task=args.iters,
        classes_per_task=classes_per_task,
        scenario=args.scenario,
    )
    # -visualize feature space
    latent_space_cb = cb._latent_space_cb(
        log=args.iters,
        datasets=test_datasets,
        visdom=visdom,
        iters_per_task=args.iters,
        sample_size=400,
    )
    # -collect them in <lists>
    eval_cbs = [eval_cb, eval_cb_full, latent_space_cb]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    if args.train:
        if verbose:
            print("\nTraining...")
        # Train model
        train_cl(
            model,
            train_datasets,
            replay_mode=args.replay if hasattr(args, 'replay') else "none",
            scenario=args.scenario,
            classes_per_task=classes_per_task,
            iters=args.iters,
            batch_size=args.batch,
            batch_size_replay=args.batch_replay if hasattr(
                args, 'batch_replay') else None,
            generator=generator,
            gen_iters=g_iters,
            gen_loss_cbs=generator_loss_cbs,
            feedback=utils.checkattr(args, 'feedback'),
            sample_cbs=sample_cbs,
            eval_cbs=eval_cbs,
            loss_cbs=generator_loss_cbs
            if utils.checkattr(args, 'feedback') else solver_loss_cbs,
            args=args,
            reinit=utils.checkattr(args, 'reinit'),
            only_last=utils.checkattr(args, 'only_last'))
        # Save evaluation metrics measured throughout training
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        utils.save_object(precision_dict, file_name)
        # Save trained model(s), if requested
        if args.save:
            save_name = "mM-{}".format(param_stamp) if (
                not hasattr(args, 'full_stag')
                or args.full_stag == "none") else "{}-{}".format(
                    model.name, args.full_stag)
            utils.save_checkpoint(model,
                                  args.m_dir,
                                  name=save_name,
                                  verbose=verbose)
            if generator is not None:
                save_name = "gM-{}".format(param_stamp) if (
                    not hasattr(args, 'full_stag')
                    or args.full_stag == "none") else "{}-{}".format(
                        generator.name, args.full_stag)
                utils.save_checkpoint(generator,
                                      args.m_dir,
                                      name=save_name,
                                      verbose=verbose)

    else:
        # Load previously trained model(s) (if goal is to only evaluate previously trained model)
        if verbose:
            print("\nLoading parameters of the previously trained models...")
        load_name = "mM-{}".format(param_stamp) if (
            not hasattr(args, 'full_ltag')
            or args.full_ltag == "none") else "{}-{}".format(
                model.name, args.full_ltag)
        utils.load_checkpoint(
            model,
            args.m_dir,
            name=load_name,
            verbose=verbose,
            add_si_buffers=(isinstance(model, ContinualLearner)
                            and utils.checkattr(args, 'si')))
        if generator is not None:
            load_name = "gM-{}".format(param_stamp) if (
                not hasattr(args, 'full_ltag')
                or args.full_ltag == "none") else "{}-{}".format(
                    generator.name, args.full_ltag)
            utils.load_checkpoint(generator,
                                  args.m_dir,
                                  name=load_name,
                                  verbose=verbose)

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of CLASSIFIER-----#
    #-----------------------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [
        evaluate.validate(
            model,
            test_datasets[i],
            verbose=False,
            test_size=None,
            task=i + 1,
            allowed_classes=list(
                range(classes_per_task * i, classes_per_task *
                      (i + 1))) if args.scenario == "task" else None)
        for i in range(args.tasks)
    ]
    average_precs = sum(precs) / args.tasks
    # -print on screen
    if verbose:
        print("\n Accuracy of final model on test-set:")
        for i in range(args.tasks):
            print(" - {} {}: {:.4f}".format(
                "For classes from task"
                if args.scenario == "class" else "Task", i + 1, precs[i]))
        print('=> Average accuracy over all {} {}: {:.4f}\n'.format(
            args.tasks *
            classes_per_task if args.scenario == "class" else args.tasks,
            "classes" if args.scenario == "class" else "tasks", average_precs))
    # -write out to text file (context manager: the file is closed even if writing fails)
    with open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w') as output_file:
        output_file.write('{}\n'.format(average_precs))

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of GENERATOR -----#
    #-----------------------------------#

    # -set here so the pdf-section below never depends on this name being bound by the try/except further down
    FileFound = False

    if (utils.checkattr(args, 'feedback') or train_gen
        ) and args.experiment == "CIFAR100" and args.scenario == "class":

        # Dataset and model to be used
        test_set = ConcatDataset(test_datasets)
        gen_model = model if utils.checkattr(args, 'feedback') else generator
        gen_model.eval()

        # Evaluate log-likelihood of generative model on combined test-set (with S=100 importance samples per datapoint)
        ll_per_datapoint = gen_model.estimate_loglikelihood(
            test_set, S=100, batch_size=args.batch)
        if verbose:
            print('=> Log-likelihood on test set: {:.4f} +/- {:.4f}\n'.format(
                np.mean(ll_per_datapoint), np.sqrt(np.var(ll_per_datapoint))))
        # -write out to text file
        with open("{}/ll-{}.txt".format(args.r_dir, param_stamp), 'w') as output_file:
            output_file.write('{}\n'.format(np.mean(ll_per_datapoint)))

        # Evaluate reconstruction error (averaged over number of input units)
        re_per_datapoint = gen_model.calculate_recon_error(
            test_set, batch_size=args.batch, average=True)
        if verbose:
            print(
                '=> Reconstruction error (per input unit) on test set: {:.4f} +/- {:.4f}\n'
                .format(np.mean(re_per_datapoint),
                        np.sqrt(np.var(re_per_datapoint))))
        # -write out to text file
        with open("{}/re-{}.txt".format(args.r_dir, param_stamp), 'w') as output_file:
            output_file.write('{}\n'.format(np.mean(re_per_datapoint)))

        # Try loading the classifier (our substitute for InceptionNet) for calculating IS, FID and Recall & Precision
        # -define model
        config['classes'] = 100
        pretrained_classifier = define.define_classifier(args=args,
                                                         config=config,
                                                         device=device)
        pretrained_classifier.hidden = False
        # -load pretrained weights
        eval_tag = "" if args.eval_tag == "none" else "-{}".format(
            args.eval_tag)
        try:
            utils.load_checkpoint(pretrained_classifier,
                                  args.m_dir,
                                  verbose=True,
                                  name="{}{}".format(
                                      pretrained_classifier.name, eval_tag))
            FileFound = True
        except FileNotFoundError:
            if verbose:
                print("= Could not find model {}{} in {}".format(
                    pretrained_classifier.name, eval_tag, args.m_dir))
                print("= IS, FID and Precision & Recall not computed!")
            FileFound = False
        pretrained_classifier.eval()

        # Only continue with computing these measures if the requested classifier network (using --eval-tag) was found
        if FileFound:
            # Preparations
            total_n = len(test_set)
            n_repeats = int(np.ceil(total_n / args.batch))
            # -same "hidden" check as used elsewhere in this function (does not fail if [args] lacks the attribute)
            use_hidden = utils.checkattr(args, "hidden")
            # -sample data from generator (for IS, FID and Precision & Recall)
            gen_x = gen_model.sample(size=total_n, only_x=True)
            # -per batch of generated data: predictions (for IS) and embeddings (for FID and Precision & Recall)
            gen_pred = []
            gen_emb = []
            for i in range(n_repeats):
                x = gen_x[(i * args.batch):min((i + 1) * args.batch, total_n)]
                with torch.no_grad():
                    gen_pred.append(
                        F.softmax(pretrained_classifier.hidden_to_output(x)
                                  if use_hidden else pretrained_classifier(x),
                                  dim=1).cpu().numpy())
                    gen_emb.append(
                        pretrained_classifier.feature_extractor(
                            x, from_hidden=use_hidden).cpu().numpy())
            gen_pred = np.concatenate(gen_pred)
            gen_emb = np.concatenate(gen_emb)
            # -generate embeddings for test data (for FID and Precision & Recall)
            data_loader = utils.get_data_loader(test_set,
                                                batch_size=args.batch,
                                                cuda=cuda)
            real_emb = []
            for real_x, _ in data_loader:
                with torch.no_grad():
                    real_emb.append(
                        pretrained_classifier.feature_extractor(
                            real_x.to(device)).cpu().numpy())
            real_emb = np.concatenate(real_emb)

            # Calculate "Inception Score" (IS): exp of the mean KL-divergence between p(y|x) and the marginal p(y)
            py = gen_pred.mean(axis=0)
            is_per_datapoint = [entropy(pyx, py) for pyx in gen_pred]
            IS = np.exp(np.mean(is_per_datapoint))
            if verbose:
                print('=> Inception Score = {:.4f}\n'.format(IS))
            # -write out to text file
            with open("{}/is{}-{}.txt".format(args.r_dir, eval_tag, param_stamp),
                      'w') as output_file:
                output_file.write('{}\n'.format(IS))

            ## Calculate "Frechet Inception Distance" (FID)
            FID = fid.calculate_fid_from_embedding(gen_emb, real_emb)
            if verbose:
                print('=> Frechet Inception Distance = {:.4f}\n'.format(FID))
            # -write out to text file
            with open("{}/fid{}-{}.txt".format(args.r_dir, eval_tag, param_stamp),
                      'w') as output_file:
                output_file.write('{}\n'.format(FID))

            # Calculate "Precision & Recall"-curves
            precision, recall = pr.compute_prd_from_embedding(
                gen_emb, real_emb)
            # -write out to text files
            file_name = "{}/precision{}-{}.txt".format(args.r_dir, eval_tag,
                                                       param_stamp)
            with open(file_name, 'w') as f:
                for item in precision:
                    f.write("%s\n" % item)
            file_name = "{}/recall{}-{}.txt".format(args.r_dir, eval_tag,
                                                    param_stamp)
            with open(file_name, 'w') as f:
                for item in recall:
                    f.write("%s\n" % item)

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = evaluate.visual.plt.open_pdf(plot_name)

        # -show metrics reflecting progression during training
        if args.train and (not utils.checkattr(args, 'only_last')):
            # -create list to store all figures to be plotted.
            figure_list = []
            # -generate figures (and store them in [figure_list])
            figure = evaluate.visual.plt.plot_lines(
                precision_dict["all_tasks"],
                x_axes=[
                    i * classes_per_task for i in precision_dict["x_task"]
                ] if args.scenario == "class" else precision_dict["x_task"],
                line_names=[
                    '{} {}'.format(
                        "episode / task"
                        if args.scenario == "class" else "task", i + 1)
                    for i in range(args.tasks)
                ],
                xlabel="# of {}s so far".format("classe" if args.scenario ==
                                                "class" else "task"),
                ylabel="Test accuracy")
            figure_list.append(figure)
            figure = evaluate.visual.plt.plot_lines(
                [precision_dict["average"]],
                x_axes=[
                    i * classes_per_task for i in precision_dict["x_task"]
                ] if args.scenario == "class" else precision_dict["x_task"],
                line_names=[
                    # -label by classes only in the "class"-scenario (a bare truthiness test on the
                    #  non-empty [scenario]-string made the "task" label unreachable)
                    'Average based on all {}s so far'.format((
                        "digit" if args.experiment == "splitMNIST" else
                        "classe") if args.scenario == "class" else "task")
                ],
                xlabel="# of {}s so far".format("classe" if args.scenario ==
                                                "class" else "task"),
                ylabel="Test accuracy")
            figure_list.append(figure)
            # -add figures to pdf
            for figure in figure_list:
                pp.savefig(figure)

        gen_eval = (utils.checkattr(args, 'feedback') or train_gen)
        # -show samples (from main model or separate generator)
        if gen_eval and not no_samples:
            evaluate.show_samples(
                model if utils.checkattr(args, 'feedback') else generator,
                config,
                size=args.sample_n,
                pdf=pp,
                title="Generated samples (by final model)")

        # -plot "Precision & Recall"-curve
        if gen_eval and args.experiment == "CIFAR100" and args.scenario == "class" and FileFound:
            figure = evaluate.visual.plt.plot_pr_curves([[precision]],
                                                        [[recall]])
            pp.savefig(figure)

        # -close pdf
        pp.close()

        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))
def run(args, model_name, shift, slot, verbose=False):
    '''Prepare data, model and continual-learning strategy from [args], then train (or load) and evaluate the model.

    [args]        options-object whose attributes select experiment, model, CL-strategy, logging, etc.
    [model_name]  passed on unchanged to [train_cl]
    [shift]       passed on unchanged to [train_cl]
    [slot]        passed on unchanged to [train_cl]
    [verbose]     <bool>, if True progress and evaluation results are printed on screen

    NOTE(review): the data are prepared with [args.slot] / [args.shift], while [train_cl] receives the
                  [slot] / [shift] parameters; presumably callers keep these identical -- confirm.'''

    # Create plots- and results-directories if needed (also creates missing parent directories)
    os.makedirs(args.r_dir, exist_ok=True)
    if args.pdf:
        os.makedirs(args.p_dir, exist_ok=True)

    # If only want param-stamp, get it and exit
    if args.get_stamp:
        from param_stamp import get_param_stamp_from_args
        print(get_param_stamp_from_args(args=args))
        # (builtin [exit()] is only available when the [site]-module is loaded)
        raise SystemExit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Report whether cuda is used
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    # Is some form of replay requested? (used at several places below)
    replay_used = hasattr(args, 'replay') and args.replay != "none"

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets, test_datasets), config, classes_per_task = get_multitask_experiment(
        name=args.experiment,
        tasks=args.tasks,
        slot=args.slot,
        shift=args.shift,
        data_dir=args.d_dir,
        normalize=bool(utils.checkattr(args, "normalize")),
        augment=bool(utils.checkattr(args, "augment")),
        verbose=verbose,
        exception=args.seed < 10,
        only_test=(not args.train),
        max_samples=args.max_samples)

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- MAIN MODEL -----#
    #----------------------#

    # Define main model (i.e., classifier, if requested with feedback connections)
    if verbose and utils.checkattr(
            args, "pre_convE") and (hasattr(args, "depth") and args.depth > 0):
        print("\nDefining the model...")
    model = define.define_classifier(args=args, config=config, device=device)

    # Initialize / use pre-trained / freeze model-parameters
    # - initialize (pre-trained) parameters
    model = define.init_params(model, args)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "freeze_convE"):
        for param in model.convE.parameters():
            param.requires_grad = False

    # Define optimizer (only optimize parameters that "requires_grad")
    # -a list (not a one-shot [filter]-iterator), so [optim_list] can safely be re-used to re-create the optimizer
    model.optim_list = [
        {
            'params': [p for p in model.parameters() if p.requires_grad],
            'lr': args.lr
        },
    ]
    model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------#
    #----- CL-STRATEGY: EXEMPLARS -----#
    #----------------------------------#

    # Store in model whether, how many and in what way to store exemplars
    if isinstance(model, ExemplarHandler) and (args.use_exemplars
                                               or args.replay == "exemplars"):
        model.memory_budget = args.budget
        model.herding = args.herding
        model.norm_exemplars = args.herding

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------------------------#
    #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----#
    #----------------------------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        model.fisher_n = args.fisher_n
        model.online = utils.checkattr(args, 'online')
        if model.online:
            model.gamma = args.gamma

    # Synpatic Intelligence (SI)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'):
        model.si_c = args.si_c if args.si else 0
        model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and utils.checkattr(
            args, 'xdg') and args.xdg_prop > 0:
        model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks)

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, ContinualLearner) and replay_used:
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp, reinit_param_stamp = get_param_stamp(
        args,
        model.name,
        verbose=verbose,
        replay=replay_used,
    )

    # Print some model-characteristics on the screen
    if verbose:
        # -main model
        utils.print_model_info(model, title="MAIN MODEL")

    # Prepare for keeping track of statistics required for metrics (also used for plotting in pdf)
    if args.pdf or args.metrics:
        # -define [metrics_dict] to keep track of performance during training for storing & for later plotting in pdf
        metrics_dict = evaluate.initiate_metrics_dict(n_tasks=args.tasks)
        # -evaluate randomly initiated model on all tasks & store accuracies in [metrics_dict] (for calculating metrics)
        if not args.use_exemplars:
            metrics_dict = evaluate.intial_accuracy(
                model,
                test_datasets,
                metrics_dict,
                no_task_mask=False,
                classes_per_task=classes_per_task,
                test_size=None)
    else:
        metrics_dict = None

    # Prepare for plotting in visdom
    visdom = None
    if args.visdom:
        env_name = "{exp}-{tasks}".format(exp=args.experiment,
                                          tasks=args.tasks)
        replay_statement = "{mode}{b}".format(
            mode=args.replay,
            b="" if
            (args.batch_replay is None or args.batch_replay == args.batch) else
            "-br{}".format(args.batch_replay),
        ) if replay_used else "NR"
        graph_name = "{replay}{syn}{ewc}{xdg}".format(
            replay=replay_statement,
            syn="-si{}".format(args.si_c)
            if utils.checkattr(args, 'si') else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if utils.checkattr(args, "online") else "") if utils.checkattr(
                    args, 'ewc') else "",
            xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0
            else "-XdG{}".format(args.xdg_prop),
        )
        visdom = {'env': env_name, 'graph': graph_name}

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Callbacks for reporting on and visualizing loss
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           iters_per_task=args.iters,
                           tasks=args.tasks,
                           replay=replay_used)
    ]

    # Callbacks for reporting and visualizing accuracy
    # -visdom (i.e., after each [prec_log]
    eval_cbs = [
        cb._eval_cb(log=args.prec_log,
                    test_datasets=test_datasets,
                    visdom=visdom,
                    iters_per_task=args.iters,
                    test_size=args.prec_n,
                    classes_per_task=classes_per_task,
                    with_exemplars=False)
    ] if (not args.use_exemplars) else [None]
    #--> during training on a task, evaluation cannot be with exemplars as those are only selected after training
    #    (instead, evaluation for visdom is only done after each task, by including callback-function into [metric_cbs])

    # Callbacks for calculating statists required for metrics
    # -pdf / reporting: summary plots (i.e, only after each task) (when using exemplars, also for visdom)
    metric_cbs = [
        cb._metric_cb(log=args.iters,
                      test_datasets=test_datasets,
                      classes_per_task=classes_per_task,
                      metrics_dict=metrics_dict,
                      iters_per_task=args.iters,
                      with_exemplars=args.use_exemplars),
        cb._eval_cb(log=args.iters,
                    test_datasets=test_datasets,
                    visdom=visdom,
                    iters_per_task=args.iters,
                    test_size=args.prec_n,
                    classes_per_task=classes_per_task,
                    with_exemplars=True) if args.use_exemplars else None
    ]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    if args.train:
        if verbose:
            print("\nTraining...")
        # Train model
        train_cl(
            model,
            train_datasets,
            model_name=model_name,
            shift=shift,
            slot=slot,
            replay_mode=args.replay if hasattr(args, 'replay') else "none",
            classes_per_task=classes_per_task,
            iters=args.iters,
            args=args,
            batch_size=args.batch,
            batch_size_replay=args.batch_replay if hasattr(
                args, 'batch_replay') else None,
            eval_cbs=eval_cbs,
            loss_cbs=solver_loss_cbs,
            reinit=utils.checkattr(args, 'reinit'),
            only_last=utils.checkattr(args, 'only_last'),
            metric_cbs=metric_cbs,
            use_exemplars=args.use_exemplars,
        )
        # Save trained model(s), if requested
        if args.save:
            save_name = "mM-{}".format(param_stamp) if (
                not hasattr(args, 'full_stag')
                or args.full_stag == "none") else "{}-{}".format(
                    model.name, args.full_stag)
            utils.save_checkpoint(model,
                                  args.m_dir,
                                  name=save_name,
                                  verbose=verbose)
    else:
        # Load previously trained model(s) (if goal is to only evaluate previously trained model)
        if verbose:
            print("\nLoading parameters of the previously trained models...")
        load_name = "mM-{}".format(param_stamp) if (
            not hasattr(args, 'full_ltag')
            or args.full_ltag == "none") else "{}-{}".format(
                model.name, args.full_ltag)
        utils.load_checkpoint(
            model,
            args.m_dir,
            name=load_name,
            verbose=verbose,
            add_si_buffers=(isinstance(model, ContinualLearner)
                            and utils.checkattr(args, 'si')))
        # Load previously created metrics-dict
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        metrics_dict = utils.load_object(file_name)

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of CLASSIFIER-----#
    #-----------------------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    def _test_precisions(with_exemplars):
        # Precision of the final model on the full test-set of every task (only that task's classes allowed)
        return [
            evaluate.validate(model,
                              test_datasets[i],
                              verbose=False,
                              test_size=None,
                              task=i + 1,
                              with_exemplars=with_exemplars,
                              allowed_classes=list(
                                  range(classes_per_task * i,
                                        classes_per_task * (i + 1))))
            for i in range(args.tasks)
        ]

    # Evaluate precision of final model on full test-set
    precs = _test_precisions(with_exemplars=False)
    average_precs = sum(precs) / args.tasks
    # -print on screen
    if verbose:
        print("\n Precision on test-set{}:".format(
            " (softmax classification)" if args.use_exemplars else ""))
        for i in range(args.tasks):
            print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
        print('=> Average precision over all {} tasks: {:.4f}\n'.format(
            args.tasks, average_precs))

    # -with exemplars
    if args.use_exemplars:
        precs = _test_precisions(with_exemplars=True)
        average_precs_ex = sum(precs) / args.tasks
        # -print on screen
        if verbose:
            print(" Precision on test-set (classification using exemplars):")
            for i in range(args.tasks):
                print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
            print('=> Average precision over all {} tasks: {:.4f}\n'.format(
                args.tasks, average_precs_ex))
    '''if args.metrics:
Exemple #4
0
def train_cl(model,
             train_datasets,
             replay_mode="none",
             rnt=None,
             classes_per_task=None,
             iters=2000,
             batch_size=32,
             batch_size_replay=None,
             loss_cbs=None,
             eval_cbs=None,
             reinit=False,
             args=None,
             only_last=False,
             use_exemplars=False,
             metric_cbs=None):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    Tasks are visited one after another; labels of each task are shifted to its 'active range' (task-incremental).

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "current", "offline", "exact", "exemplars" and "none"
    [rnt]               <float>, indicating relative importance of new task (if None: 1 for 1st task, 1/task otherwise)
    [classes_per_task]  <int>, # classes per task
    [iters]             <int>, # optimization-steps (=batches) per task
    [batch_size]        <int>, # samples per batch of the current task
    [batch_size_replay] <int>, number of samples to replay per batch (if None, equal to [batch_size])
    [reinit]            <bool>, if True, parameters and optimizer are re-initialized before each task
    [args]              options-object (e.g. "freeze_convE", "depth"; also passed to [init_params] if [reinit])
    [only_last]         <bool>, only train on final task / episode
    [use_exemplars]     <bool>, if True, the model's exemplar-sets are updated after each task
    [*_cbs]             <list> of call-back functions to evaluate training-progress (None-entries are skipped)'''

    # Use fresh empty lists for unspecified call-backs (avoids shared mutable default arguments)
    loss_cbs = [] if loss_cbs is None else loss_cbs
    eval_cbs = [] if eval_cbs is None else eval_cbs
    metric_cbs = [] if metric_cbs is None else metric_cbs

    # Should convolutional layers be frozen?
    freeze_convE = (utils.checkattr(args, "freeze_convE")
                    and hasattr(args, "depth") and args.depth > 0)

    # Use cuda?
    device = model._device()
    cuda = model._is_on_cuda()

    # Set default-values if not specified
    batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay

    # Initiate indicators for replay (no replay for 1st task)
    Exact = Current = Offline_TaskIL = False
    previous_model = None

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and model.si_c > 0:
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n),
                                      p.detach().clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):

        # In offline replay-setting, all tasks so far should be visited separately (i.e., separate data-loader per task)
        if replay_mode == "offline":
            Offline_TaskIL = True
            data_loader = [None] * task

        # Initialize # iters left on data-loader(s)
        iters_left = 1 if (not Offline_TaskIL) else [1] * task
        if Exact:
            iters_left_previous = [1] * (task - 1)
            data_loader_previous = [None] * (task - 1)

        # Prepare <dicts> to store running importance estimates and parameter-values before update
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Find [active_classes] (=for each task so far, the list of its classes)
        active_classes = [
            list(range(classes_per_task * i, classes_per_task * (i + 1)))
            for i in range(task)
        ]

        # Reinitialize the model's parameters and the optimizer (if requested)
        if reinit:
            from define_models import init_params
            init_params(model, args)
            model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

        # Define a tqdm progress bar(s)
        progress = tqdm.tqdm(range(1, iters + 1))

        # Loop over all iterations
        iters_to_use = iters
        # -if only the final task should be trained on:
        if only_last and not task == len(train_datasets):
            iters_to_use = 0
        for batch_index in range(1, iters_to_use + 1):

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            if not Offline_TaskIL:
                iters_left -= 1
                if iters_left == 0:
                    # NOTE: [train_dataset] is the training-set of the current task
                    data_loader = iter(
                        utils.get_data_loader(train_dataset,
                                              batch_size,
                                              cuda=cuda,
                                              drop_last=True))
                    iters_left = len(data_loader)
            else:
                # -with "offline replay", there is a separate data-loader for each task
                batch_size_to_use = batch_size
                for task_id in range(task):
                    iters_left[task_id] -= 1
                    if iters_left[task_id] == 0:
                        data_loader[task_id] = iter(
                            utils.get_data_loader(train_datasets[task_id],
                                                  batch_size_to_use,
                                                  cuda=cuda,
                                                  drop_last=True))
                        iters_left[task_id] = len(data_loader[task_id])

            # Update # iters left on data-loader(s) of the previous task(s) and, if needed, create new one(s)
            if Exact:
                up_to_task = task - 1
                batch_size_replay_pt = int(
                    np.floor(
                        batch_size_replay /
                        up_to_task)) if (up_to_task > 1) else batch_size_replay
                # -need separate replay for each task
                for task_id in range(up_to_task):
                    batch_size_to_use = min(batch_size_replay_pt,
                                            len(previous_datasets[task_id]))
                    iters_left_previous[task_id] -= 1
                    if iters_left_previous[task_id] == 0:
                        data_loader_previous[task_id] = iter(
                            utils.get_data_loader(previous_datasets[task_id],
                                                  batch_size_to_use,
                                                  cuda=cuda,
                                                  drop_last=True))
                        iters_left_previous[task_id] = len(
                            data_loader_previous[task_id])

            #-----------------Collect data------------------#

            #####-----CURRENT BATCH-----#####
            if not Offline_TaskIL:
                x, y = next(
                    data_loader)  #--> sample training data of current task
                y = y - classes_per_task * (
                    task - 1)  #--> ITL: adjust y-targets to 'active range'
                x, y = x.to(device), y.to(
                    device)  #--> transfer them to correct device
            else:
                x = y = task_used = None  #--> all tasks are "treated as replay"
                # -sample training data for all tasks so far, move to correct device and store in lists
                x_, y_ = list(), list()
                for task_id in range(task):
                    x_temp, y_temp = next(data_loader[task_id])
                    x_.append(x_temp.to(device))
                    y_temp = y_temp - (
                        classes_per_task * task_id
                    )  #--> adjust y-targets to 'active range'
                    if batch_size_to_use == 1:
                        y_temp = torch.tensor([
                            y_temp
                        ])  #--> correct dimensions if batch-size is 1
                    y_.append(y_temp.to(device))

            #####-----REPLAYED BATCH-----#####
            if Offline_TaskIL:
                # -offline replay trains on the true labels of all tasks; there are no soft targets
                #  (previously [scores_] was never assigned here, giving a NameError with soft targets)
                scores_ = None
            elif not Exact and not Current:
                x_ = y_ = scores_ = task_used = None  #-> if no replay

            #--------------------------------------------INPUTS----------------------------------------------------#

            ##-->> Exact Replay <<--##
            if Exact:
                # -must be set here: with [only_last], no earlier (no-replay) iteration has assigned it
                task_used = None
                # Sample replayed training data, move to correct device and store in lists
                x_ = list()
                y_ = list()
                up_to_task = task - 1
                for task_id in range(up_to_task):
                    x_temp, y_temp = next(data_loader_previous[task_id])
                    x_.append(x_temp.to(device))
                    # -only keep [y_] if required (as otherwise unnecessary computations will be done)
                    if model.replay_targets == "hard":
                        y_temp = y_temp - (
                            classes_per_task * task_id
                        )  #-> adjust y-targets to 'active range'
                        y_.append(y_temp.to(device))
                    else:
                        y_.append(None)
                # If required, get target scores (i.e, [scores_])        -- using previous model, with no_grad()
                if (model.replay_targets == "soft") and (previous_model
                                                         is not None):
                    scores_ = list()
                    for task_id in range(up_to_task):
                        with torch.no_grad():
                            scores_temp = previous_model(x_[task_id])
                        scores_temp = scores_temp[:,
                                                  (classes_per_task *
                                                   task_id):(classes_per_task *
                                                             (task_id + 1))]
                        scores_.append(scores_temp)
                else:
                    scores_ = None

            ##-->> Current Replay <<--##
            if Current:
                x_ = x[:batch_size_replay]  #--> use current task inputs
                task_used = None

            #--------------------------------------------OUTPUTS----------------------------------------------------#

            if Current:
                # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad()
                # -[x_] needs to be evaluated according to each previous task, so make list with entry per task
                scores_ = list()
                y_ = list()
                # -if no task-mask and no conditional generator, all scores can be calculated in one go
                if previous_model.mask_dict is None and not type(x_) == list:
                    with torch.no_grad():
                        all_scores_ = previous_model.classify(x_)
                for task_id in range(task - 1):
                    # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately
                    if previous_model.mask_dict is not None:
                        previous_model.apply_XdGmask(task=task_id + 1)
                    if previous_model.mask_dict is not None or type(
                            x_) == list:
                        with torch.no_grad():
                            all_scores_ = previous_model.classify(
                                x_[task_id] if type(x_) == list else x_)
                    temp_scores_ = all_scores_[:, (classes_per_task *
                                                   task_id):(classes_per_task *
                                                             (task_id + 1))]
                    scores_.append(temp_scores_)
                    # - also get hard target
                    _, temp_y_ = torch.max(temp_scores_, dim=1)
                    y_.append(temp_y_)
            # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done)
            #  (not with offline replay: there [y_] holds the true labels, which must always be kept)
            if not Offline_TaskIL:
                y_ = y_ if (model.replay_targets == "hard") else None
                scores_ = scores_ if (model.replay_targets == "soft") else None

            #-----------------Train model------------------#

            # Train the main model with this batch
            loss_dict = model.train_a_batch(x,
                                            y=y,
                                            x_=x_,
                                            y_=y_,
                                            scores_=scores_,
                                            tasks_=task_used,
                                            active_classes=active_classes,
                                            task=task,
                                            rnt=(1. if task == 1 else 1. /
                                                 task) if rnt is None else rnt,
                                            freeze_convE=freeze_convE)

            # Update running parameter importance estimates in W
            if isinstance(model, ContinualLearner) and model.si_c > 0:
                for n, p in model.named_parameters():
                    if p.requires_grad:
                        n = n.replace('.', '__')
                        if p.grad is not None:
                            W[n].add_(-p.grad * (p.detach() - p_old[n]))
                        p_old[n] = p.detach().clone()

            # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
            for loss_cb in loss_cbs:
                if loss_cb is not None:
                    loss_cb(progress, batch_index, loss_dict, task=task)
            for eval_cb in eval_cbs:
                if eval_cb is not None:
                    eval_cb(model, batch_index, task=task)

        # Close progres-bar
        progress.close()

        ##----------> UPON FINISHING EACH TASK...

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and model.ewc_lambda > 0:
            # -find allowed classes
            allowed_classes = list(
                range(classes_per_task * (task - 1), classes_per_task * task))
            # -if needed, apply correct task-specific mask
            if model.mask_dict is not None:
                model.apply_XdGmask(task=task)
            # -estimate FI-matrix
            model.estimate_fisher(train_dataset,
                                  allowed_classes=allowed_classes)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            model.update_omega(W, model.epsilon)

        # EXEMPLARS: update exemplar sets
        if use_exemplars or replay_mode == "exemplars":
            exemplars_per_class = int(
                np.floor(model.memory_budget / (classes_per_task * task)))
            # reduce examplar-sets
            model.reduce_exemplar_sets(exemplars_per_class)
            # for each new class trained on, construct examplar-set
            new_classes = list(
                range(classes_per_task * (task - 1), classes_per_task * task))
            for class_id in new_classes:
                # create new dataset containing only all examples of this class
                class_dataset = SubDataset(original_dataset=train_dataset,
                                           sub_labels=[class_id])
                # based on this dataset, construct new exemplar-set for this class
                model.construct_exemplar_set(dataset=class_dataset,
                                             n=exemplars_per_class)
            model.compute_means = True

        # Calculate statistics required for metrics
        for metric_cb in metric_cbs:
            if metric_cb is not None:
                metric_cb(model, iters, task=task)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model).eval()
        if replay_mode == 'current':
            Current = True
        elif replay_mode in ('exemplars', 'exact'):
            Exact = True
            if replay_mode == "exact":
                previous_datasets = train_datasets[:task]
            else:
                previous_datasets = []
                for task_id in range(task):
                    # -default-argument binds the current offset (avoids late-binding of [task_id] in the lambda)
                    previous_datasets.append(
                        ExemplarDataset(
                            model.exemplar_sets[(classes_per_task *
                                                 task_id):(classes_per_task *
                                                           (task_id + 1))],
                            target_transform=lambda y, x=classes_per_task *
                            task_id: y + x))
def train_cl(model,
             train_datasets,
             replay_mode="none",
             scenario="task",
             rnt=None,
             classes_per_task=None,
             iters=2000,
             batch_size=32,
             batch_size_replay=None,
             loss_cbs=list(),
             eval_cbs=list(),
             sample_cbs=list(),
             generator=None,
             gen_iters=0,
             gen_loss_cbs=list(),
             feedback=False,
             reinit=False,
             args=None,
             only_last=False):
    '''Train a model (with a "train_a_batch" method) on multiple tasks, with replay-strategy specified by [replay_mode].

    [model]             <nn.Module> main model to optimize across all tasks
    [train_datasets]    <list> with for each task the training <DataSet>
    [replay_mode]       <str>, choice from "generative", "current", "offline" and "none"
    [scenario]          <str>, choice from "task", "domain", "class" and "all"
    [rnt]               <float>, indicating relative importance of new task (if None, relative to # old tasks)
    [classes_per_task]  <int>, # classes per task
    [iters]             <int>, # optimization-steps (=batches) per task for the main model
    [batch_size]        <int>, # samples of the current task per batch
    [batch_size_replay] <int>, number of samples to replay per batch (if None, equal to [batch_size])
    [generator]         None or <nn.Module>, if a separate generative model should be trained (for [gen_iters] per task)
    [gen_iters]         <int>, # optimization-steps per task for [generator]
    [feedback]          <bool>, if True and [replay_mode]="generative", the main model is used for generating replay
    [reinit]            <bool>, if True, re-initialize the parameters of [model] (and [generator]) before each task
    [args]              settings-object; read for "freeze_convE" / "depth" and passed on to [init_params]
    [only_last]         <bool>, only train on final task / episode
    [*_cbs]             <list> of call-back functions to evaluate training-progress
                        ([loss_cbs] / [eval_cbs] for the main model, [gen_loss_cbs] for the generator, [sample_cbs]
                        for the generator and, if its label is "VAE", also for the main model)

    Note: with [replay_mode]="offline" and [scenario]="task", every task seen so far gets its own data-loader and all
    of these batches (with their ground-truth labels) are passed to the model as "replay".'''

    # Should convolutional layers be frozen?
    freeze_convE = (utils.checkattr(args, "freeze_convE")
                    and hasattr(args, "depth") and args.depth > 0)

    # Use cuda?
    device = model._device()
    cuda = model._is_on_cuda()

    # Set default-values if not specified
    batch_size_replay = batch_size if batch_size_replay is None else batch_size_replay

    # Initiate indicators for replay (no replay for 1st task)
    Generative = Current = Offline_TaskIL = False
    previous_model = None

    # Register starting param-values (needed for "intelligent synapses").
    if isinstance(model, ContinualLearner) and model.si_c > 0:
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace('.', '__')
                model.register_buffer('{}_SI_prev_task'.format(n),
                                      p.detach().clone())

    # Loop over all tasks.
    for task, train_dataset in enumerate(train_datasets, 1):

        # If offline replay-setting, create large database of all tasks so far
        if replay_mode == "offline" and (not scenario == "task"):
            train_dataset = ConcatDataset(train_datasets[:task])
        # -but if "offline"+"task": all tasks so far should be visited separately (i.e., separate data-loader per task)
        if replay_mode == "offline" and scenario == "task":
            Offline_TaskIL = True
            data_loader = [None] * task

        # Initialize # iters left on data-loader(s)
        iters_left = 1 if (not Offline_TaskIL) else [1] * task

        # Prepare <dicts> to store running importance estimates and parameter-values before update
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            W = {}
            p_old = {}
            for n, p in model.named_parameters():
                if p.requires_grad:
                    n = n.replace('.', '__')
                    W[n] = p.data.clone().zero_()
                    p_old[n] = p.data.clone()

        # Find [active_classes] (=classes in current task)
        active_classes = None  #-> for "domain"- or "all"-scenarios, always all classes are active
        if scenario == "task":
            # -for "task"-scenario, create <list> with for all tasks so far a <list> with the active classes
            active_classes = [
                list(range(classes_per_task * i, classes_per_task * (i + 1)))
                for i in range(task)
            ]
        elif scenario == "class":
            # -for "class"-scenario, create one <list> with active classes of all tasks so far
            active_classes = list(range(classes_per_task * task))

        # Reinitialize the model's parameters (if requested)
        if reinit:
            from define_models import init_params
            init_params(model, args)
            if generator is not None:
                init_params(generator, args)

        # Define a tqdm progress bar(s)
        iters_main = iters
        progress = tqdm.tqdm(range(1, iters_main + 1))
        if generator is not None:
            iters_gen = gen_iters
            progress_gen = tqdm.tqdm(range(1, iters_gen + 1))

        # Loop over all iterations
        iters_to_use = iters_main if (generator is None) else max(iters_main, iters_gen)
        # -if only the final task should be trained on:
        if only_last and not task == len(train_datasets):
            iters_to_use = 0
        for batch_index in range(1, iters_to_use + 1):

            # Update # iters left on current data-loader(s) and, if needed, create new one(s)
            if not Offline_TaskIL:
                iters_left -= 1
                if iters_left == 0:
                    data_loader = iter(
                        utils.get_data_loader(train_dataset,
                                              batch_size,
                                              cuda=cuda,
                                              drop_last=True))
                    iters_left = len(data_loader)
            else:
                # -with "offline replay" in Task-IL scenario, there is a separate data-loader for each task
                batch_size_to_use = int(np.ceil(batch_size / task))
                for task_id in range(task):
                    iters_left[task_id] -= 1
                    if iters_left[task_id] == 0:
                        data_loader[task_id] = iter(
                            utils.get_data_loader(train_datasets[task_id],
                                                  batch_size_to_use,
                                                  cuda=cuda,
                                                  drop_last=True))
                        iters_left[task_id] = len(data_loader[task_id])

            #-----------------Collect data------------------#

            #####-----CURRENT BATCH-----#####
            if not Offline_TaskIL:
                x, y = next(data_loader)  #--> sample training data of current task
                #--> ITL: adjust y-targets to 'active range'
                y = y - classes_per_task * (task - 1) if scenario == "task" else y
                x, y = x.to(device), y.to(device)  #--> transfer them to correct device
            else:
                # -all tasks so far are "treated as replay": sample ground-truth training data for each of them,
                #  move it to the correct device and store it in lists. No soft targets are used here, so
                #  [scores_] must be defined as None (it is read further below).
                x = y = task_used = scores_ = None
                x_, y_ = list(), list()
                for task_id in range(task):
                    x_temp, y_temp = next(data_loader[task_id])
                    x_.append(x_temp.to(device))
                    y_temp = y_temp - (classes_per_task * task_id)  #--> adjust y-targets to 'active range'
                    if batch_size_to_use == 1:
                        y_temp = torch.tensor([y_temp])  #--> correct dimensions if batch-size is 1
                    y_.append(y_temp.to(device))

            #####-----REPLAYED BATCH-----#####
            if not Offline_TaskIL and not Generative and not Current:
                x_ = y_ = scores_ = task_used = None  #-> if no replay

            #--------------------------------------------INPUTS----------------------------------------------------#

            ##-->> Current Replay <<--##
            if Current:
                x_ = x[:batch_size_replay]  #--> use current task inputs
                task_used = None

            ##-->> Generative Replay <<--##
            if Generative:
                #---> Only with generative replay, the resulting [x_] will be at the "hidden"-level
                conditional_gen = True if (
                    (previous_generator.per_class and previous_generator.prior == "GMM")
                    or utils.checkattr(previous_generator, 'dg_gates')) else False

                # Sample [x_]
                if conditional_gen and scenario == "task":
                    # -if a conditional generator is used with task-IL scenario, generate data per previous task
                    x_ = list()
                    task_used = list()
                    for task_id in range(task - 1):
                        allowed_classes = list(
                            range(classes_per_task * task_id,
                                  classes_per_task * (task_id + 1)))
                        batch_size_replay_to_use = int(
                            np.ceil(batch_size_replay / (task - 1)))
                        x_temp_ = previous_generator.sample(
                            batch_size_replay_to_use,
                            allowed_classes=allowed_classes,
                            only_x=False)
                        x_.append(x_temp_[0])
                        task_used.append(x_temp_[2])
                else:
                    # -which classes are allowed to be generated? (relevant if conditional generator / decoder-gates)
                    allowed_classes = None if scenario == "domain" else list(
                        range(classes_per_task * (task - 1)))
                    # -which tasks/domains are allowed to be generated? (only relevant if "Domain-IL" with task-gates)
                    allowed_domains = list(range(task - 1))
                    # -generate inputs representative of previous tasks
                    x_temp_ = previous_generator.sample(
                        batch_size_replay,
                        allowed_classes=allowed_classes,
                        allowed_domains=allowed_domains,
                        only_x=False,
                    )
                    x_ = x_temp_[0]
                    task_used = x_temp_[2]

            #--------------------------------------------OUTPUTS----------------------------------------------------#

            if Generative or Current:
                # Get target scores & possibly labels (i.e., [scores_] / [y_]) -- use previous model, with no_grad()
                if scenario in ("domain", "class") and previous_model.mask_dict is None:
                    # -if replay does not need to be evaluated for each task (ie, not Task-IL and no task-specific mask)
                    with torch.no_grad():
                        all_scores_ = previous_model.classify(
                            x_, not_hidden=False if Generative else True)
                    # -when scenario=="class", zero probs will be added in [loss_fn_kd]-function
                    scores_ = all_scores_[:, :(classes_per_task * (task - 1))] if (
                        scenario == "class") else all_scores_
                    # -also get the 'hard target'
                    _, y_ = torch.max(scores_, dim=1)
                else:
                    # -[x_] needs to be evaluated according to each previous task, so make list with entry per task
                    scores_ = list()
                    y_ = list()
                    # -if no task-mask and no conditional generator, all scores can be calculated in one go
                    if previous_model.mask_dict is None and not type(x_) == list:
                        with torch.no_grad():
                            all_scores_ = previous_model.classify(
                                x_, not_hidden=False if Generative else True)
                    for task_id in range(task - 1):
                        # -if there is a task-mask (i.e., XdG is used), obtain predicted scores for each task separately
                        if previous_model.mask_dict is not None:
                            previous_model.apply_XdGmask(task=task_id + 1)
                        if previous_model.mask_dict is not None or type(x_) == list:
                            with torch.no_grad():
                                all_scores_ = previous_model.classify(
                                    x_[task_id] if type(x_) == list else x_,
                                    not_hidden=False if Generative else True)
                        if scenario == "domain":
                            # NOTE: if scenario=domain with task-mask, it's of course actually the Task-IL scenario!
                            #       this can be used as trick to run the Task-IL scenario with singlehead output layer
                            temp_scores_ = all_scores_
                        else:
                            temp_scores_ = all_scores_[:, (classes_per_task * task_id):(classes_per_task *
                                                                                          (task_id + 1))]
                        scores_.append(temp_scores_)
                        # - also get hard target
                        _, temp_y_ = torch.max(temp_scores_, dim=1)
                        y_.append(temp_y_)
                # -only keep predicted y_/scores_ if required (as otherwise unnecessary computations will be done)
                #  NOTE: this filter only applies to targets predicted by [previous_model]; with "offline" replay
                #        [y_] holds ground-truth labels that must be kept, and without replay both are None anyway.
                y_ = y_ if (model.replay_targets == "hard") else None
                scores_ = scores_ if (model.replay_targets == "soft") else None

            #-----------------Train model(s)------------------#

            #---> Train MAIN MODEL
            if batch_index <= iters_main:

                # Train the main model with this batch
                loss_dict = model.train_a_batch(
                    x,
                    y=y,
                    x_=x_,
                    y_=y_,
                    scores_=scores_,
                    tasks_=task_used,
                    active_classes=active_classes,
                    task=task,
                    rnt=(1. if task == 1 else 1. / task) if rnt is None else rnt,
                    freeze_convE=freeze_convE,
                    replay_not_hidden=False if Generative else True)

                # Update running parameter importance estimates in W
                # (only the parameters of [convE], [fcE] and [classifier] are tracked; their buffer-names are
                #  prefixed with the sub-module name, with '.' replaced by '__' as when registering them above)
                if isinstance(model, ContinualLearner) and model.si_c > 0:
                    for prefix, module in (("convE", model.convE),
                                           ("fcE", model.fcE),
                                           ("classifier", model.classifier)):
                        for n, p in module.named_parameters():
                            if p.requires_grad:
                                n = "{}.{}".format(prefix, n).replace('.', '__')
                                if p.grad is not None:
                                    W[n].add_(-p.grad * (p.detach() - p_old[n]))
                                p_old[n] = p.detach().clone()

                # Fire callbacks (for visualization of training-progress / evaluating performance after each task)
                for loss_cb in loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress, batch_index, loss_dict, task=task)
                for eval_cb in eval_cbs:
                    if eval_cb is not None:
                        eval_cb(model, batch_index, task=task)
                if model.label == "VAE":
                    for sample_cb in sample_cbs:
                        if sample_cb is not None:
                            sample_cb(model,
                                      batch_index,
                                      task=task,
                                      allowed_classes=None if (scenario == "domain") else list(
                                          range(classes_per_task * task)))

            #---> Train GENERATOR
            if generator is not None and batch_index <= iters_gen:

                loss_dict = generator.train_a_batch(
                    x,
                    y=y,
                    x_=x_,
                    y_=y_,
                    scores_=scores_,
                    tasks_=task_used,
                    active_classes=active_classes,
                    rnt=(1. if task == 1 else 1. / task) if rnt is None else rnt,
                    task=task,
                    freeze_convE=freeze_convE,
                    replay_not_hidden=False if Generative else True)

                # Fire callbacks on each iteration
                for loss_cb in gen_loss_cbs:
                    if loss_cb is not None:
                        loss_cb(progress_gen, batch_index, loss_dict, task=task)
                for sample_cb in sample_cbs:
                    if sample_cb is not None:
                        sample_cb(generator,
                                  batch_index,
                                  task=task,
                                  allowed_classes=None if (scenario == "domain") else list(
                                      range(classes_per_task * task)))

        # Close progress-bar(s)
        progress.close()
        if generator is not None:
            progress_gen.close()

        ##----------> UPON FINISHING EACH TASK...

        # EWC: estimate Fisher Information matrix (FIM) and update term for quadratic penalty
        if isinstance(model, ContinualLearner) and model.ewc_lambda > 0:
            # -find allowed classes
            allowed_classes = list(
                range(classes_per_task * (task - 1), classes_per_task * task)) if scenario == "task" else (
                    list(range(classes_per_task * task)) if scenario == "class" else None)
            # -if needed, apply correct task-specific mask
            if model.mask_dict is not None:
                model.apply_XdGmask(task=task)
            # -estimate FI-matrix
            model.estimate_fisher(train_dataset, allowed_classes=allowed_classes)

        # SI: calculate and update the normalized path integral
        if isinstance(model, ContinualLearner) and model.si_c > 0:
            model.update_omega(W, model.epsilon)

        # REPLAY: update source for replay
        previous_model = copy.deepcopy(model).eval()
        if replay_mode == "generative":
            Generative = True
            previous_generator = previous_model if feedback else copy.deepcopy(
                generator).eval()
        elif replay_mode == 'current':
            Current = True
def run(args):
    '''(Pre-)train a classifier on a single task, as specified by [args]; optionally save it and plot results to pdf.

    [args]  settings-object (e.g., parsed command-line arguments); read among others for the experiment name,
            data-/model-/plot-directories, random seed, batch-size, # iterations or epochs, learning-rate,
            logging-intervals and the "freeze_convE" / "freeze_full" / "save" / "pdf" / "visdom" options
    '''

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    # Report whether cuda is used
    print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Create plots-directory if needed (-> [os.makedirs], so that also nested directories can be created)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.makedirs(args.p_dir)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    print("\nPreparing the data...")
    (trainset, testset), config = get_singletask_experiment(
        name=args.experiment, data_dir=args.d_dir, verbose=True,
        normalize = True if utils.checkattr(args, "normalize") else False,
        augment = True if utils.checkattr(args, "augment") else False,
    )

    # Specify "data-loader" (among others for easy random shuffling and 'batchifying')
    train_loader = utils.get_data_loader(trainset, batch_size=args.batch, cuda=cuda, drop_last=True)

    # Determine number of iterations / epochs:
    iters = args.iters if args.iters else args.epochs*len(train_loader)
    epochs = ((args.iters-1) // len(train_loader)) + 1 if args.iters else args.epochs


    #-------------------------------------------------------------------------------------------------#

    #-----------------#
    #----- MODEL -----#
    #-----------------#

    # Specify model
    # -the message is only printed when pre-trained layers are requested (and the model has conv-layers)
    if (utils.checkattr(args, "pre_convE") or utils.checkattr(args, "pre_convD")) and \
            (hasattr(args, "depth") and args.depth>0):
        print("\nDefining the model...")
    cnn = define.define_classifier(args=args, config=config, device=device)

    # Initialize (pre-trained) parameters
    cnn = define.init_params(cnn, args)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "freeze_convE"):
        for param in cnn.convE.parameters():
            param.requires_grad = False
        cnn.convE.eval()  #--> needed to ensure batchnorm-layers also do not change
    # - freeze weights of representation-learning layers?
    if utils.checkattr(args, "freeze_full"):
        for param in cnn.parameters():
            param.requires_grad = False
        for param in cnn.classifier.parameters():
            param.requires_grad = True

    # Set optimizer (only parameters that "require_grad" are optimized)
    optim_list = [{'params': filter(lambda p: p.requires_grad, cnn.parameters()), 'lr': args.lr}]
    cnn.optimizer = torch.optim.Adam(optim_list, betas=(0.9, 0.999))


    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp
    print("\nParameter-stamp...")
    param_stamp = get_param_stamp(args, cnn.name, verbose=True)

    # Print some model-characteristics on the screen
    utils.print_model_info(cnn, title="CLASSIFIER")

    # Define [progress_dicts] to keep track of performance during training for storing and for later plotting in pdf
    precision_dict = evaluate.initiate_precision_dict(n_tasks=1)

    # Prepare for plotting in visdom
    graph_name = cnn.name
    visdom = None if (not args.visdom) else {'env': args.experiment, 'graph': graph_name}

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Determine after how many iterations to evaluate the model (default: once per epoch)
    eval_log = args.prec_log if (args.prec_log is not None) else len(train_loader)

    # Define callback-functions to evaluate during training
    # -loss
    loss_cbs = [cb._solver_loss_cb(log=args.loss_log, visdom=visdom, epochs=epochs)]
    # -precision
    eval_cb = cb._eval_cb(log=eval_log, test_datasets=[testset], visdom=visdom, precision_dict=precision_dict)
    # -visualize extracted representation
    latent_space_cb = cb._latent_space_cb(log=min(5*eval_log, iters), datasets=[testset], visdom=visdom,
                                          sample_size=400)


    #-------------------------------------------------------------------------------------------------#

    #--------------------------#
    #----- (PRE-)TRAINING -----#
    #--------------------------#

    # (Pre)train model
    print("\nTraining...")
    train.train(cnn, train_loader, iters, loss_cbs=loss_cbs, eval_cbs=[eval_cb, latent_space_cb],
                save_every=1000 if args.save else None, m_dir=args.m_dir, args=args)

    # Save (pre)trained model
    if args.save:
        # -conv-layers
        save_name = cnn.convE.name if (
            not hasattr(args, 'convE_stag') or args.convE_stag=="none"
        ) else "{}-{}".format(cnn.convE.name, args.convE_stag)
        utils.save_checkpoint(cnn.convE, args.m_dir, name=save_name)
        # -full model
        save_name = cnn.name if (
            not hasattr(args, 'full_stag') or args.full_stag=="none"
        ) else "{}-{}".format(cnn.name, args.full_stag)
        utils.save_checkpoint(cnn, args.m_dir, name=save_name)


    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # if requested, generate pdf.
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = plt.open_pdf(plot_name)
        try:
            # -Fig1: show some images
            images, _ = next(iter(train_loader))            #--> get a mini-batch of random training images
            plt.plot_images_from_tensor(images, pp, title="example input images", config=config)
            # -Fig2: precision
            figure = plt.plot_lines(precision_dict["all_tasks"], x_axes=precision_dict["x_iteration"],
                                    line_names=['ave precision'], xlabel="Iterations", ylabel="Test accuracy")
            pp.savefig(figure)
        finally:
            # -close pdf (also if plotting fails, so the file-handle is not leaked)
            pp.close()
        # -print name of generated plot on screen
        print("\nGenerated plot: {}\n".format(plot_name))
# Example #7
# 0
def run(args, verbose=False):

    # Create plots- and results-directories if needed
    if not os.path.isdir(args.r_dir):
        os.mkdir(args.r_dir)
    if args.pdf and not os.path.isdir(args.p_dir):
        os.mkdir(args.p_dir)

    # If only want param-stamp, get it and exit
    if args.get_stamp:
        from param_stamp import get_param_stamp_from_args
        print(get_param_stamp_from_args(args=args))
        exit()

    # Use cuda?
    cuda = torch.cuda.is_available() and args.cuda
    device = torch.device("cuda" if cuda else "cpu")

    # Report whether cuda is used
    if verbose:
        print("CUDA is {}used".format("" if cuda else "NOT(!!) "))

    # Set random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    #-------------------------------------------------------------------------------------------------#

    #----------------#
    #----- DATA -----#
    #----------------#

    # Prepare data for chosen experiment
    if verbose:
        print("\nPreparing the data...")
    (train_datasets,
     test_datasets), config, classes_per_task = get_multitask_experiment(
         name=args.experiment,
         tasks=args.tasks,
         data_dir=args.d_dir,
         normalize=True if utils.checkattr(args, "normalize") else False,
         augment=True if utils.checkattr(args, "augment") else False,
         verbose=verbose,
         exception=True if args.seed < 10 else False,
         only_test=(not args.train),
         max_samples=args.max_samples)

    #-------------------------------------------------------------------------------------------------#

    #----------------------#
    #----- MAIN MODEL -----#
    #----------------------#

    # Define main model (i.e., classifier, if requested with feedback connections)
    if verbose and utils.checkattr(
            args, "pre_convE") and (hasattr(args, "depth") and args.depth > 0):
        print("\nDefining the model...")
    model = define.define_classifier(args=args, config=config, device=device)

    # Initialize / use pre-trained / freeze model-parameters
    # - initialize (pre-trained) parameters
    model = define.init_params(model, args)
    # - freeze weights of conv-layers?
    if utils.checkattr(args, "freeze_convE"):
        for param in model.convE.parameters():
            param.requires_grad = False

    # Define optimizer (only optimize parameters that "requires_grad")
    model.optim_list = [
        {
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': args.lr
        },
    ]
    model.optimizer = optim.Adam(model.optim_list, betas=(0.9, 0.999))

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------#
    #----- CL-STRATEGY: EXEMPLARS -----#
    #----------------------------------#

    # Store in model whether, how many and in what way to store exemplars
    if isinstance(model, ExemplarHandler) and (args.use_exemplars
                                               or args.replay == "exemplars"):
        model.memory_budget = args.budget
        model.herding = args.herding
        model.norm_exemplars = args.herding

    #-------------------------------------------------------------------------------------------------#

    #----------------------------------------------------#
    #----- CL-STRATEGY: REGULARIZATION / ALLOCATION -----#
    #----------------------------------------------------#

    # Elastic Weight Consolidation (EWC)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'ewc'):
        model.ewc_lambda = args.ewc_lambda if args.ewc else 0
        model.fisher_n = args.fisher_n
        model.online = utils.checkattr(args, 'online')
        if model.online:
            model.gamma = args.gamma

    # Synpatic Intelligence (SI)
    if isinstance(model, ContinualLearner) and utils.checkattr(args, 'si'):
        model.si_c = args.si_c if args.si else 0
        model.epsilon = args.epsilon

    # XdG: create for every task a "mask" for each hidden fully connected layer
    if isinstance(model, ContinualLearner) and utils.checkattr(
            args, 'xdg') and args.xdg_prop > 0:
        model.define_XdGmask(gating_prop=args.xdg_prop, n_tasks=args.tasks)

    #-------------------------------------------------------------------------------------------------#

    #-------------------------------#
    #----- CL-STRATEGY: REPLAY -----#
    #-------------------------------#

    # Use distillation loss (i.e., soft targets) for replayed data? (and set temperature)
    if isinstance(model, ContinualLearner) and hasattr(
            args, 'replay') and not args.replay == "none":
        model.replay_targets = "soft" if args.distill else "hard"
        model.KD_temp = args.temp

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- REPORTING -----#
    #---------------------#

    # Get parameter-stamp (and print on screen)
    if verbose:
        print("\nParameter-stamp...")
    param_stamp, reinit_param_stamp = get_param_stamp(
        args,
        model.name,
        verbose=verbose,
        replay=True if
        (hasattr(args, 'replay') and not args.replay == "none") else False,
    )

    # Print some model-characteristics on the screen
    if verbose:
        # -main model
        utils.print_model_info(model, title="MAIN MODEL")

    # Prepare for keeping track of statistics required for metrics (also used for plotting in pdf)
    if args.pdf or args.metrics:
        # -define [metrics_dict] to keep track of performance during training for storing & for later plotting in pdf
        metrics_dict = evaluate.initiate_metrics_dict(n_tasks=args.tasks)
        # -evaluate randomly initialized model on all tasks & store accuracies in [metrics_dict] (for calculating metrics)
        if not args.use_exemplars:
            metrics_dict = evaluate.intial_accuracy(
                model,
                test_datasets,
                metrics_dict,
                no_task_mask=False,
                classes_per_task=classes_per_task,
                test_size=None)
    else:
        metrics_dict = None

    # Prepare for plotting in visdom
    visdom = None
    if args.visdom:
        env_name = "{exp}-{tasks}".format(exp=args.experiment,
                                          tasks=args.tasks)
        replay_statement = "{mode}{b}".format(
            mode=args.replay,
            b="" if
            (args.batch_replay is None or args.batch_replay == args.batch) else
            "-br{}".format(args.batch_replay),
        ) if (hasattr(args, "replay") and not args.replay == "none") else "NR"
        graph_name = "{replay}{syn}{ewc}{xdg}".format(
            replay=replay_statement,
            syn="-si{}".format(args.si_c)
            if utils.checkattr(args, 'si') else "",
            ewc="-ewc{}{}".format(
                args.ewc_lambda, "-O{}".format(args.gamma)
                if utils.checkattr(args, "online") else "") if utils.checkattr(
                    args, 'ewc') else "",
            xdg="" if (not utils.checkattr(args, 'xdg')) or args.xdg_prop == 0
            else "-XdG{}".format(args.xdg_prop),
        )
        visdom = {'env': env_name, 'graph': graph_name}

    #-------------------------------------------------------------------------------------------------#

    #---------------------#
    #----- CALLBACKS -----#
    #---------------------#

    # Callbacks for reporting on and visualizing loss
    solver_loss_cbs = [
        cb._solver_loss_cb(log=args.loss_log,
                           visdom=visdom,
                           model=model,
                           iters_per_task=args.iters,
                           tasks=args.tasks,
                           replay=(hasattr(args, "replay")
                                   and not args.replay == "none"))
    ]

    # Callbacks for reporting and visualizing accuracy
    # -visdom (i.e., after each [prec_log])
    eval_cbs = [
        cb._eval_cb(log=args.prec_log,
                    test_datasets=test_datasets,
                    visdom=visdom,
                    iters_per_task=args.iters,
                    test_size=args.prec_n,
                    classes_per_task=classes_per_task,
                    with_exemplars=False)
    ] if (not args.use_exemplars) else [None]
    #--> during training on a task, evaluation cannot be with exemplars as those are only selected after training
    #    (instead, evaluation for visdom is only done after each task, by including callback-function into [metric_cbs])

    # Callbacks for calculating statistics required for metrics
    # -pdf / reporting: summary plots (i.e., only after each task) (when using exemplars, also for visdom)
    metric_cbs = [
        cb._metric_cb(log=args.iters,
                      test_datasets=test_datasets,
                      classes_per_task=classes_per_task,
                      metrics_dict=metrics_dict,
                      iters_per_task=args.iters,
                      with_exemplars=args.use_exemplars),
        cb._eval_cb(log=args.iters,
                    test_datasets=test_datasets,
                    visdom=visdom,
                    iters_per_task=args.iters,
                    test_size=args.prec_n,
                    classes_per_task=classes_per_task,
                    with_exemplars=True) if args.use_exemplars else None
    ]

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- TRAINING -----#
    #--------------------#

    if args.train:
        if verbose:
            print("\nTraining...")
        # Train model
        train_cl(
            model,
            train_datasets,
            replay_mode=args.replay if hasattr(args, 'replay') else "none",
            classes_per_task=classes_per_task,
            iters=args.iters,
            args=args,
            batch_size=args.batch,
            batch_size_replay=args.batch_replay if hasattr(
                args, 'batch_replay') else None,
            eval_cbs=eval_cbs,
            loss_cbs=solver_loss_cbs,
            reinit=utils.checkattr(args, 'reinit'),
            only_last=utils.checkattr(args, 'only_last'),
            metric_cbs=metric_cbs,
            use_exemplars=args.use_exemplars,
        )
        # Save trained model(s), if requested
        if args.save:
            save_name = "mM-{}".format(param_stamp) if (
                not hasattr(args, 'full_stag')
                or args.full_stag == "none") else "{}-{}".format(
                    model.name, args.full_stag)
            utils.save_checkpoint(model,
                                  args.m_dir,
                                  name=save_name,
                                  verbose=verbose)
    else:
        # Load previously trained model(s) (if goal is to only evaluate previously trained model)
        if verbose:
            print("\nLoading parameters of the previously trained models...")
        load_name = "mM-{}".format(param_stamp) if (
            not hasattr(args, 'full_ltag')
            or args.full_ltag == "none") else "{}-{}".format(
                model.name, args.full_ltag)
        utils.load_checkpoint(
            model,
            args.m_dir,
            name=load_name,
            verbose=verbose,
            add_si_buffers=(isinstance(model, ContinualLearner)
                            and utils.checkattr(args, 'si')))
        # Load previously created metrics-dict
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        metrics_dict = utils.load_object(file_name)

    #-------------------------------------------------------------------------------------------------#

    #-----------------------------------#
    #----- EVALUATION of CLASSIFIER-----#
    #-----------------------------------#

    if verbose:
        print("\n\nEVALUATION RESULTS:")

    # Evaluate precision of final model on full test-set
    precs = [
        evaluate.validate(model,
                          test_datasets[i],
                          verbose=False,
                          test_size=None,
                          task=i + 1,
                          with_exemplars=False,
                          allowed_classes=list(
                              range(classes_per_task * i,
                                    classes_per_task * (i + 1))))
        for i in range(args.tasks)
    ]
    average_precs = sum(precs) / args.tasks
    # -print on screen
    if verbose:
        print("\n Precision on test-set{}:".format(
            " (softmax classification)" if args.use_exemplars else ""))
        for i in range(args.tasks):
            print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
        print('=> Average precision over all {} tasks: {:.4f}\n'.format(
            args.tasks, average_precs))

    # -with exemplars
    if args.use_exemplars:
        precs = [
            evaluate.validate(model,
                              test_datasets[i],
                              verbose=False,
                              test_size=None,
                              task=i + 1,
                              with_exemplars=True,
                              allowed_classes=list(
                                  range(classes_per_task * i,
                                        classes_per_task * (i + 1))))
            for i in range(args.tasks)
        ]
        average_precs_ex = sum(precs) / args.tasks
        # -print on screen
        if verbose:
            print(" Precision on test-set (classification using exemplars):")
            for i in range(args.tasks):
                print(" - Task {}: {:.4f}".format(i + 1, precs[i]))
            print('=> Average precision over all {} tasks: {:.4f}\n'.format(
                args.tasks, average_precs_ex))

    # If requested, compute metrics
    if args.metrics:
        # Load accuracy matrix of "reinit"-experiment (i.e., each task's accuracy when only trained on that task)
        if not utils.checkattr(args, 'reinit'):
            file_name = "{}/dict-{}".format(args.r_dir, reinit_param_stamp)
            if not os.path.isfile("{}.pkl".format(file_name)):
                raise FileNotFoundError(
                    "Need to run the correct 'reinit' experiment (with --metrics) first!!"
                )
            reinit_metrics_dict = utils.load_object(file_name)
        # Accuracy matrix
        R = pd.DataFrame(
            data=metrics_dict['acc per task'],
            index=['after task {}'.format(i + 1) for i in range(args.tasks)])
        R = R[["task {}".format(task_id + 1) for task_id in range(args.tasks)]]
        R.loc['at start'] = metrics_dict['initial acc per task'] if (
            not args.use_exemplars) else ['NA' for _ in range(args.tasks)]
        if not utils.checkattr(args, 'reinit'):
            R.loc['only trained on itself'] = [
                reinit_metrics_dict['acc per task']['task {}'.format(
                    task_id + 1)][task_id] for task_id in range(args.tasks)
            ]
        R = R.reindex(
            ['at start'] +
            ['after task {}'.format(i + 1)
             for i in range(args.tasks)] + ['only trained on itself'])
        BWTs = [(R.loc['after task {}'.format(args.tasks), 'task {}'.format(i + 1)] - \
                 R.loc['after task {}'.format(i + 1), 'task {}'.format(i + 1)]) for i in range(args.tasks - 1)]
        FWTs = [
            0. if args.use_exemplars else
            (R.loc['after task {}'.format(i + 1), 'task {}'.format(i + 2)] -
             R.loc['at start', 'task {}'.format(i + 2)])
            for i in range(args.tasks - 1)
        ]
        forgetting = []
        for i in range(args.tasks - 1):
            forgetting.append(
                max(R.iloc[1:args.tasks, i]) - R.iloc[args.tasks, i])
        R.loc['FWT (per task)'] = ['NA'] + FWTs
        R.loc['BWT (per task)'] = BWTs + ['NA']
        R.loc['F (per task)'] = forgetting + ['NA']
        BWT = sum(BWTs) / (args.tasks - 1)
        F = sum(forgetting) / (args.tasks - 1)
        FWT = sum(FWTs) / (args.tasks - 1)
        metrics_dict['BWT'] = BWT
        metrics_dict['F'] = F
        metrics_dict['FWT'] = FWT
        # -Vogelstein et al's measures of transfer efficiency
        if not utils.checkattr(args, 'reinit'):
            TEs = [((1 - R.loc['only trained on itself',
                               'task {}'.format(task_id + 1)]) /
                    (1 - R.loc['after task {}'.format(args.tasks),
                               'task {}'.format(task_id + 1)]))
                   for task_id in range(args.tasks)]
            BTEs = [((1 - R.loc['after task {}'.format(task_id + 1),
                                'task {}'.format(task_id + 1)]) /
                     (1 - R.loc['after task {}'.format(args.tasks),
                                'task {}'.format(task_id + 1)]))
                    for task_id in range(args.tasks)]
            FTEs = [((1 - R.loc['only trained on itself',
                                'task {}'.format(task_id + 1)]) /
                     (1 - R.loc['after task {}'.format(task_id + 1),
                                'task {}'.format(task_id + 1)]))
                    for task_id in range(args.tasks)]
            # -TEs and BTEs after each task
            TEs_all = []
            BTEs_all = []
            for after_task_id in range(args.tasks):
                TEs_all.append([
                    ((1 - R.loc['only trained on itself',
                                'task {}'.format(task_id + 1)]) /
                     (1 - R.loc['after task {}'.format(after_task_id + 1),
                                'task {}'.format(task_id + 1)]))
                    for task_id in range(after_task_id + 1)
                ])
                BTEs_all.append([
                    ((1 - R.loc['after task {}'.format(task_id + 1),
                                'task {}'.format(task_id + 1)]) /
                     (1 - R.loc['after task {}'.format(after_task_id + 1),
                                'task {}'.format(task_id + 1)]))
                    for task_id in range(after_task_id + 1)
                ])
            R.loc['TEs (per task, after all 10 tasks)'] = TEs
            for after_task_id in range(args.tasks):
                R.loc['TEs (per task, after {} tasks)'.format(
                    after_task_id +
                    1)] = TEs_all[after_task_id] + ['NA'] * (args.tasks -
                                                             after_task_id - 1)
            R.loc['BTEs (per task, after all 10 tasks)'] = BTEs
            for after_task_id in range(args.tasks):
                R.loc['BTEs (per task, after {} tasks)'.format(
                    after_task_id + 1)] = BTEs_all[after_task_id] + ['NA'] * (
                        args.tasks - after_task_id - 1)
            R.loc['FTEs (per task)'] = FTEs
            metrics_dict['R'] = R
        # -print on screen
        if verbose:
            print("Accuracy matrix")
            print(R)
            print("\nFWT = {:.4f}".format(FWT))
            print("BWT = {:.4f}".format(BWT))
            print("  F = {:.4f}\n\n".format(F))

    #-------------------------------------------------------------------------------------------------#

    #------------------#
    #----- OUTPUT -----#
    #------------------#

    # Average precision on full test set
    output_file = open("{}/prec-{}.txt".format(args.r_dir, param_stamp), 'w')
    output_file.write('{}\n'.format(
        average_precs_ex if args.use_exemplars else average_precs))
    output_file.close()
    # -metrics-dict
    if args.metrics:
        file_name = "{}/dict-{}".format(args.r_dir, param_stamp)
        utils.save_object(metrics_dict, file_name)

    #-------------------------------------------------------------------------------------------------#

    #--------------------#
    #----- PLOTTING -----#
    #--------------------#

    # If requested, generate pdf
    if args.pdf:
        # -open pdf
        plot_name = "{}/{}.pdf".format(args.p_dir, param_stamp)
        pp = evaluate.visual.plt.open_pdf(plot_name)

        # -plot TEs
        if not utils.checkattr(args, 'reinit'):
            BTEs = []
            for task_id in range(args.tasks):
                BTEs.append([
                    R.loc['BTEs (per task, after {} tasks)'.
                          format(after_task_id + 1),
                          'task {}'.format(task_id + 1)]
                    for after_task_id in range(task_id, args.tasks)
                ])
            figure = visual_plt.plot_TEs([FTEs], [BTEs], [TEs], ["test"])
            pp.savefig(figure)

        # -show metrics reflecting progression during training
        if args.train and (not utils.checkattr(args, 'only_last')):
            # -create list to store all figures to be plotted.
            figure_list = []
            # -generate all figures (and store them in [figure_list])
            key = "acc per task"
            plot_list = []
            for i in range(args.tasks):
                plot_list.append(metrics_dict[key]["task {}".format(i + 1)])
            figure = visual_plt.plot_lines(plot_list,
                                           x_axes=metrics_dict["x_task"],
                                           line_names=[
                                               'task {}'.format(i + 1)
                                               for i in range(args.tasks)
                                           ])
            figure_list.append(figure)
            figure = visual_plt.plot_lines(
                [metrics_dict["average"]],
                x_axes=metrics_dict["x_task"],
                line_names=['average all tasks so far'])
            figure_list.append(figure)
            # -add figures to pdf
            for figure in figure_list:
                pp.savefig(figure)

        # -close pdf
        pp.close()

        # -print name of generated plot on screen
        if verbose:
            print("\nGenerated plot: {}\n".format(plot_name))