示例#1
0
class Classifier(ContinualLearner, ExemplarHandler):
    '''Model for encoding (i.e., feature extraction) and classifying images, enriched as "ContinualLearner"--object.

    Network: conv-layers ([convE]) -> flatten -> fully-connected layers ([fcE]) -> output layer ([classifier]).

    NOTE(review): [mask_dict], [replay_targets], [KD_temp], [si_c], [ewc_lambda], [apply_XdGmask], [surrogate_loss]
    and [ewc_loss], which are used in [train_a_batch], are not defined in this class; presumably they come from the
    [ContinualLearner] base class -- TODO confirm.'''
    def __init__(
            self,
            image_size,
            image_channels,
            classes,
            # -conv-layers
            conv_type="standard",
            depth=0,
            start_channels=64,
            reducing_layers=3,
            conv_bn=True,
            conv_nl="relu",
            num_blocks=2,
            global_pooling=False,
            no_fnl=True,
            conv_gated=False,
            # -fc-layers
            fc_layers=3,
            fc_units=1000,
            h_dim=400,
            fc_drop=0,
            fc_bn=True,
            fc_nl="relu",
            fc_gated=False,
            bias=True,
            excitability=False,
            excit_buffer=False,
            # -training-related parameters
            AGEM=False):
        '''Build the conv-layers, the fully-connected layers and the output layer.

        [fc_layers] counts the output layer too, so it must be >= 1 (otherwise a ValueError is raised). With 2
        fc-layers there is one hidden layer of [h_dim] units. With more, the hidden-layer sizes are spaced linearly
        (np.linspace, truncated to int) from [fc_units] to [h_dim].
        [AGEM]: if True, the gradient of replayed data is used as an inequality constraint on the gradient of the
        current data (as in A-GEM) instead of being added to it.'''

        # model configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.depth = depth
        self.fc_layers = fc_layers
        self.fc_drop = fc_drop

        # settings for training
        self.AGEM = AGEM  #-> use gradient of replayed data as inequality constraint for (instead of adding it to)
        #   the gradient of the current data (as in A-GEM, see Chaudry et al., 2019; ICLR)

        # optimizer (needs to be set before training starts)
        self.optimizer = None
        self.optim_list = []

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######
        #--> convolutional layers
        self.convE = ConvLayers(
            conv_type=conv_type,
            block_type="basic",
            num_blocks=num_blocks,
            image_channels=image_channels,
            depth=depth,
            start_channels=start_channels,
            reducing_layers=reducing_layers,
            batch_norm=conv_bn,
            nl=conv_nl,
            global_pooling=global_pooling,
            gated=conv_gated,
            output="none" if no_fnl else "normal",
        )
        self.flatten = modules.Flatten()
        #------------------------------calculate input/output-sizes--------------------------------#
        self.conv_out_units = self.convE.out_units(image_size)
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels
        if fc_layers < 2:
            self.fc_layer_sizes = [
                self.conv_out_units
            ]  #--> this results in self.fcE = modules.Identity()
        elif fc_layers == 2:
            self.fc_layer_sizes = [self.conv_out_units, h_dim]
        else:
            self.fc_layer_sizes = [self.conv_out_units] + [
                int(x) for x in np.linspace(fc_units, h_dim, num=fc_layers - 1)
            ]
        self.units_before_classifier = h_dim if fc_layers > 1 else self.conv_out_units
        #------------------------------------------------------------------------------------------#
        #--> fully connected layers
        self.fcE = MLP(
            size_per_layer=self.fc_layer_sizes,
            drop=fc_drop,
            batch_norm=fc_bn,
            nl=fc_nl,
            bias=bias,
            excitability=excitability,
            excit_buffer=excit_buffer,
            gated=fc_gated)  #, output="none") ## NOTE: temporary change!!!
        #--> classifier
        self.classifier = fc_layer(self.units_before_classifier,
                                   classes,
                                   excit_buffer=True,
                                   nl='none',
                                   drop=fc_drop)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        layers = self.convE.list_init_layers()
        layers += self.fcE.list_init_layers()
        layers += self.classifier.list_init_layers()
        return layers

    @property
    def name(self):
        '''String describing the architecture (built from the names of [convE] / [fcE] and the number of classes).'''
        if self.depth > 0 and self.fc_layers > 1:
            return "{}_{}_c{}".format(self.convE.name, self.fcE.name,
                                      self.classes)
        elif self.depth > 0:
            return "{}_{}c{}".format(
                self.convE.name,
                "drop{}-".format(self.fc_drop) if self.fc_drop > 0 else "",
                self.classes)
        elif self.fc_layers > 1:
            return "{}_c{}".format(self.fcE.name, self.classes)
        else:
            return "i{}_{}c{}".format(
                self.fc_layer_sizes[0],
                "drop{}-".format(self.fc_drop) if self.fc_drop > 0 else "",
                self.classes)

    def forward(self, x):
        '''Return the output of [classifier] ("scores"/"logits") for input [x].'''
        hidden_rep = self.convE(x)
        final_features = self.fcE(self.flatten(hidden_rep))
        return self.classifier(final_features)

    def feature_extractor(self, images):
        '''Return the (flattened) representation that is fed into the output layer.'''
        return self.fcE(self.flatten(self.convE(images)))

    def classify(self, x):
        '''For input [x], return all predicted "scores"/"logits" (same computation as [forward]).'''
        image_features = self.flatten(self.convE(x))
        hE = self.fcE(image_features)
        return self.classifier(hE)

    def _agem_flat_grad(self):
        '''Return the gradients of all trainable parameters concatenated into a single vector (A-GEM helper).

        Parameters without a gradient (p.grad is None, e.g. after "optimizer.zero_grad()" with set_to_none=True
        when they were not part of the backward pass) contribute zeros.'''
        return torch.cat([
            (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
            for p in self.parameters() if p.requires_grad
        ])

    def _agem_set_flat_grad(self, flat_grad):
        '''Write [flat_grad] (ordered as in [_agem_flat_grad]) back into the gradients of the trainable parameters.'''
        index = 0
        for p in self.parameters():
            if p.requires_grad:
                n_param = p.numel()  # number of parameters in [p]
                new_grad = flat_grad[index:index + n_param].view_as(p)
                if p.grad is None:
                    p.grad = new_grad.clone()
                else:
                    p.grad.copy_(new_grad)
                index += n_param

    def train_a_batch(self,
                      x,
                      y=None,
                      x_=None,
                      y_=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      task=1,
                      freeze_convE=False,
                      **kwargs):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_]).

        [x]                 <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]                 <tensor> batch of corresponding labels
        [x_]                None or (<list> of) <tensor> batch of replayed inputs
        [y_]                None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]           None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]               <number> in [0,1], relative importance of new task
        [active_classes]    None or (<list> of) <list> with "active" classes
        [task]              <int>, for setting task-specific mask
        [freeze_convE]      <bool>, if True the conv-layers are put in eval()-mode (keeps batch-norm statistics fixed)
        [kwargs]            accepted but ignored

        Returns a <dict> with the training losses (Python numbers) and the training precision.

        Raises ValueError if there is nothing to optimize: [x] is None and either [x_] is None or A-GEM is used
        (A-GEM only uses replay as a constraint on the gradient of the current data).'''

        # Fail early (before any optimizer/model state is touched) if no loss can be formed
        if x is None and (x_ is None or self.AGEM):
            raise ValueError(
                "train_a_batch() needs current data [x], unless replayed data [x_] is given and A-GEM is not used."
            )

        # Set model to training-mode
        self.train()
        if freeze_convE:
            # - if conv-layers are frozen, they should be set to eval() to prevent batch-norm layers from changing
            self.convE.eval()

        # Reset optimizer
        self.optimizer.zero_grad()

        # Should gradient be computed separately for each task? (needed when a task-mask is combined with replay)
        gradient_per_task = (self.mask_dict is not None) and (x_ is not None)

        ##--(1)-- REPLAYED DATA --##

        if x_ is not None:
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = isinstance(y_, list) if (y_ is not None) else isinstance(scores_, list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model once (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            x_is_list = isinstance(x_, list)
            if (not x_is_list) and (self.mask_dict is None):
                y_hat_all = self(x_)

            # Loop to evaluate predictions on replay according to each previous task
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if x_is_list or (self.mask_dict is not None):
                    x_temp_ = x_[replay_id] if x_is_list else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self(x_temp_)

                # -if needed, remove predictions for classes not in replayed task
                y_hat = y_hat_all if (active_classes is None) else y_hat_all[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id], reduction='mean')
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                    n_classes_to_consider = y_hat.size(1)  #--> zeros will be added to [scores] to make it this size!
                    distilL_r[replay_id] = lf.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider],
                                                         target_scores=scores_[replay_id],
                                                         T=self.KD_temp)
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If needed, perform backward pass before next task-mask (gradients of all tasks will be accumulated)
                if gradient_per_task:
                    weight = 1 if self.AGEM else (1 - rnt)
                    weighted_replay_loss_this_task = weight * loss_replay[replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

        # Calculate total replay loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays

        # If using A-GEM, calculate and store averaged gradient of replayed data
        if self.AGEM and x_ is not None:
            # Perform backward pass to calculate gradient of replayed batch (if not yet done)
            if not gradient_per_task:
                loss_replay.backward()
            # Reorganize the gradient of the replayed batch as a single vector
            grad_rep = self._agem_flat_grad()
            # Reset gradients (with GEM, gradients of replayed batch should only be used as inequality constraint)
            self.optimizer.zero_grad()

        ##--(2)-- CURRENT DATA --##

        # -> it's possible there is only "replay" [i.e., for offline with incremental task learning]
        loss_cur = predL = precision = None
        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if isinstance(active_classes[0], list) else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            predL = None if y is None else F.cross_entropy(input=y_hat, target=y, reduction='mean')

            # Weigh losses
            loss_cur = predL

            # Calculate training-precision
            precision = None if y is None else (y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If backward passes are performed per task (i.e., when XdG combined with replay), perform backward pass
            if gradient_per_task:
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()

        # Combine loss from current and replayed batch
        if x_ is None or self.AGEM:
            loss_total = loss_cur
        else:
            loss_total = loss_replay if (x is None) else rnt * loss_cur + (1 - rnt) * loss_replay

        ##--(3)-- ALLOCATION LOSSES --##

        # SI-loss (Zenke et al., 2017) and EWC-loss, weighted by their strengths
        surrogate_loss = self.surrogate_loss()
        ewc_loss = self.ewc_loss()
        loss_reg = 0.
        if self.si_c > 0:
            loss_reg = loss_reg + self.si_c * surrogate_loss
        if self.ewc_lambda > 0:
            loss_reg = loss_reg + self.ewc_lambda * ewc_loss
        # Out-of-place on purpose: "+=" on a tensor is in-place, and [loss_total] can be the very same tensor object
        # as [loss_cur] or [loss_replay], whose values are reported below.
        loss_total = loss_total + loss_reg

        # Backpropagate errors (if not yet done)
        if not gradient_per_task:
            loss_total.backward()
        elif torch.is_tensor(loss_reg) and loss_reg.requires_grad:
            # The per-task backward passes above only covered the data losses; add the regularization gradients.
            loss_reg.backward()

        # If using A-GEM:
        if self.AGEM and x_ is not None:
            # -reorganize gradient (of current batch) as single vector
            grad_cur = self._agem_flat_grad()
            # -check inequality constraint
            angle = (grad_cur * grad_rep).sum()
            if angle < 0:
                # -if violated, project the gradient of the current batch onto the gradient of the replayed batch ...
                length_rep = (grad_rep * grad_rep).sum()
                grad_proj = grad_cur - (angle / length_rep) * grad_rep
                # -...and replace all the gradients within the model with this projected gradient
                self._agem_set_flat_grad(grad_proj)

        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item() if loss_cur is not None else 0,
            'loss_replay': loss_replay.item() if loss_replay is not None else 0,
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(),
            'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }
示例#2
0
    def __init__(
            self,
            image_size,
            image_channels,
            classes,
            # -conv-layers
            conv_type="standard",
            depth=0,
            start_channels=64,
            reducing_layers=3,
            conv_bn=True,
            conv_nl="relu",
            num_blocks=2,
            global_pooling=False,
            no_fnl=True,
            conv_gated=False,
            # -fc-layers
            fc_layers=3,
            fc_units=1000,
            h_dim=400,
            fc_drop=0,
            fc_bn=True,
            fc_nl="relu",
            fc_gated=False,
            bias=True,
            excitability=False,
            excit_buffer=False,
            # -training-related parameters
            AGEM=False):
        '''Construct the classifier: conv-layers ([convE]) -> flatten -> fc-layers ([fcE]) -> output layer ([classifier]).

        [fc_layers] includes the output layer, so it has to be at least 1; otherwise a ValueError is raised.
        Layer sizes of the fc-part ("self.fc_layer_sizes", first entry = number of conv-output units):
          - 1 fc-layer:  no hidden layers
          - 2 fc-layers: one hidden layer with [h_dim] units
          - more:        hidden sizes spaced linearly (np.linspace, truncated to int) from [fc_units] to [h_dim]
        [AGEM]: if True, the gradient of replayed data is used as an inequality constraint on the gradient of the
        current data (A-GEM, Chaudhry et al., 2019; ICLR) rather than being added to it.'''
        super().__init__()

        # General configuration of the model
        self.classes = classes
        self.label = "Classifier"
        self.depth = depth
        self.fc_layers = fc_layers
        self.fc_drop = fc_drop

        # Training-related setting (see docstring)
        self.AGEM = AGEM

        # The optimizer has to be attached before training starts
        self.optimizer, self.optim_list = None, []

        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        # Convolutional part (created first; module-creation order is kept unchanged)
        self.convE = ConvLayers(conv_type=conv_type, block_type="basic", num_blocks=num_blocks,
                                image_channels=image_channels, depth=depth, start_channels=start_channels,
                                reducing_layers=reducing_layers, batch_norm=conv_bn, nl=conv_nl,
                                global_pooling=global_pooling, gated=conv_gated,
                                output="none" if no_fnl else "normal")
        self.flatten = modules.Flatten()

        # Sizes at the interface between the conv-part and the fc-part
        self.conv_out_units = self.convE.out_units(image_size)
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels

        layer_sizes = [self.conv_out_units]
        if fc_layers == 2:
            layer_sizes.append(h_dim)
        elif fc_layers > 2:
            layer_sizes.extend(int(n_units) for n_units in np.linspace(fc_units, h_dim, num=fc_layers - 1))
        self.fc_layer_sizes = layer_sizes
        self.units_before_classifier = self.conv_out_units if fc_layers <= 1 else h_dim

        # Fully-connected part
        hidden_settings = dict(drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias, excitability=excitability,
                               excit_buffer=excit_buffer, gated=fc_gated)
        self.fcE = MLP(size_per_layer=self.fc_layer_sizes, **hidden_settings)

        # Output layer (no non-linearity: returns "scores"/"logits")
        self.classifier = fc_layer(self.units_before_classifier, classes,
                                   excit_buffer=True, nl='none', drop=fc_drop)
示例#3
0
    def __init__(
            self,
            image_size,
            image_channels,
            classes,
            target_name,
            only_active=False,
            # -conv-layers
            conv_type="standard",
            depth=5,
            start_channels=16,
            reducing_layers=4,
            conv_bn=True,
            conv_nl="relu",
            num_blocks=2,
            global_pooling=False,
            no_fnl=True,
            conv_gated=False,
            # -fc-layers
            fc_layers=3,
            fc_units=2000,
            h_dim=2000,
            fc_drop=0,
            fc_bn=False,
            fc_nl="relu",
            excit_buffer=False,
            fc_gated=False,
            # -prior
            prior="GMM",
            z_dim=100,
            per_class=True,
            n_modes=1,
            # -decoder
            recon_loss='MSEnorm',
            dg_gates=True,
            dg_prop=0.5,
            device='cpu',
            # -training-specific settings (can be changed after setting up model)
            lamda_pl=1.,
            lamda_rcl=1.,
            lamda_vl=1.,
            **kwargs):
        '''Set up encoder (q[z|x]), classifier, decoder (p[x|z]) and, if [prior]=="GMM", the prior's parameters.

        [recon_loss]        "BCE" | "MSE" | "MSEnorm"; with "BCE"/"MSE" the decoder output gets a sigmoid
        [dg_gates]          use class-specific gates in the decoder (only if [dg_prop] > 0)
        [per_class]         if True, the GMM-prior has [n_modes] modes per class (i.e., [n_modes]*[classes] in total)
        [device]            only passed on to the gated decoder (MLP_gates)
        [lamda_pl/rcl/vl]   weights of classification-, reconstruction- and variational loss
        [kwargs]            accepted but ignored

        Raises a ValueError when [fc_layers] < 1.'''

        # Set configurations for setting up the model
        super().__init__()
        self.target_name = target_name
        self.label = "BIR"
        self.image_size = image_size
        self.image_channels = image_channels
        self.classes = classes
        self.fc_layers = fc_layers
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.fc_units = fc_units
        self.fc_drop = fc_drop
        self.depth = depth

        # whether always all classes can be predicted or only those seen so far
        self.only_active = only_active
        self.active_classes = []

        # -type of loss to be used for reconstruction
        self.recon_loss = recon_loss  # options: BCE|MSE|MSEnorm
        self.network_output = "sigmoid" if self.recon_loss in ("MSE", "BCE") else "none"
        # -settings for class-specific gates in fully-connected hidden layers of decoder
        self.dg_prop = dg_prop
        self.dg_gates = dg_gates if dg_prop > 0. else False
        self.gate_size = classes if self.dg_gates else 0

        # Prior-related parameters
        self.prior = prior
        self.per_class = per_class
        self.n_modes = n_modes * classes if self.per_class else n_modes
        self.modes_per_class = n_modes if self.per_class else None

        # Components deciding how to train / run the model (i.e., these can be changed after setting up the model)
        # -options for prediction loss
        self.lamda_pl = lamda_pl  # weight of classification-loss
        # -how to compute the loss function?
        self.lamda_rcl = lamda_rcl  # weight of reconstruction-loss
        self.lamda_vl = lamda_vl  # weight of variational loss

        # Check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("VAE cannot have 0 fully-connected layers!")

        ######------SPECIFY MODEL------######

        ##>----Encoder (= q[z|x])----<##
        self.convE = ConvLayers(conv_type=conv_type,
                                block_type="basic",
                                num_blocks=num_blocks,
                                image_channels=image_channels,
                                depth=self.depth,
                                start_channels=start_channels,
                                reducing_layers=reducing_layers,
                                batch_norm=conv_bn,
                                nl=conv_nl,
                                output="none" if no_fnl else "normal",
                                global_pooling=global_pooling,
                                gated=conv_gated)
        self.flatten = modules.Flatten()
        #------------------------------calculate input/output-sizes--------------------------------#
        self.conv_out_units = self.convE.out_units(image_size)
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels
        if fc_layers < 2:
            self.fc_layer_sizes = [
                self.conv_out_units
            ]  #--> this results in self.fcE = modules.Identity()
        elif fc_layers == 2:
            self.fc_layer_sizes = [self.conv_out_units, h_dim]
        else:
            self.fc_layer_sizes = [self.conv_out_units] + [
                int(x) for x in np.linspace(fc_units, h_dim, num=fc_layers - 1)
            ]
        real_h_dim = h_dim if fc_layers > 1 else self.conv_out_units
        #------------------------------------------------------------------------------------------#
        self.fcE = MLP(size_per_layer=self.fc_layer_sizes,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       excit_buffer=excit_buffer,
                       gated=fc_gated)
        # to z
        self.toZ = fc_layer_split(real_h_dim,
                                  z_dim,
                                  nl_mean='none',
                                  nl_logvar='none')  #, drop=fc_drop)

        ##>----Classifier----<##
        self.units_before_classifier = real_h_dim
        self.classifier = fc_layer(self.units_before_classifier,
                                   classes,
                                   excit_buffer=True,
                                   nl='none')

        ##>----Decoder (= p[x|z])----<##
        # -[fromZ] gets a non-linearity (and batch-norm if [fc_bn]) if there are hidden fc-layers, or otherwise if
        #  conv-layers with a final non-linearity are used
        out_nl = bool(fc_layers > 1 or (self.depth > 0 and not no_fnl))
        # -number of conv-output units ignoring global pooling
        #  (if 'gp' is used in forward pass, size of first/final hidden layer differs between forward and backward pass)
        conv_out_units_no_gp = self.convE.out_units(image_size, ignore_gp=True)
        real_h_dim_down = h_dim if fc_layers > 1 else conv_out_units_no_gp
        if self.dg_gates:
            self.fromZ = fc_layer_fixed_gates(
                z_dim,
                real_h_dim_down,
                batch_norm=(out_nl and fc_bn),
                nl=fc_nl if out_nl else "none",
                gate_size=self.gate_size,
                gating_prop=dg_prop,
            )
        else:
            self.fromZ = fc_layer(z_dim,
                                  real_h_dim_down,
                                  batch_norm=(out_nl and fc_bn),
                                  nl=fc_nl if out_nl else "none")
        # -copy (do not alias!) the encoder's layer-sizes: assigning to element [0] of an alias would overwrite
        #  [self.fc_layer_sizes], which must keep describing the encoder
        fc_layer_sizes_down = list(self.fc_layer_sizes)
        fc_layer_sizes_down[0] = conv_out_units_no_gp
        decoder_layer_sizes = list(reversed(fc_layer_sizes_down))
        if self.dg_gates:
            self.fcD = MLP_gates(
                size_per_layer=decoder_layer_sizes,
                drop=fc_drop,
                batch_norm=fc_bn,
                nl=fc_nl,
                gate_size=self.gate_size,
                gating_prop=dg_prop,
                device=device,
                output=self.network_output,
            )
        else:
            self.fcD = MLP(
                size_per_layer=decoder_layer_sizes,
                drop=fc_drop,
                batch_norm=fc_bn,
                nl=fc_nl,
                gated=fc_gated,
                output=self.network_output,
            )
        # to image-shape
        self.to_image = modules.Reshape(image_channels=self.convE.out_channels
                                        if self.depth > 0 else image_channels)
        # through deconv-layers
        self.convD = modules.Identity()

        ##>----Prior----<##
        # -if using the GMM-prior, add its parameters: one mean and one log-variance vector (length [z_dim]) per mode
        if self.prior == "GMM":
            # -create
            self.z_class_means = nn.Parameter(
                torch.Tensor(self.n_modes, self.z_dim))
            self.z_class_logvars = nn.Parameter(
                torch.Tensor(self.n_modes, self.z_dim))
            # -initialize
            self.z_class_means.data.normal_()
            self.z_class_logvars.data.normal_()
示例#4
0
class IntegratedReplayModel(ContinualLearner):
    """Class for brain-inspired replay (BI-R) models."""
    def __init__(
            self,
            image_size,
            image_channels,
            classes,
            target_name,
            only_active=False,
            # -conv-layers
            conv_type="standard",
            depth=5,
            start_channels=16,
            reducing_layers=4,
            conv_bn=True,
            conv_nl="relu",
            num_blocks=2,
            global_pooling=False,
            no_fnl=True,
            conv_gated=False,
            # -fc-layers
            fc_layers=3,
            fc_units=2000,
            h_dim=2000,
            fc_drop=0,
            fc_bn=False,
            fc_nl="relu",
            excit_buffer=False,
            fc_gated=False,
            # -prior
            prior="GMM",
            z_dim=100,
            per_class=True,
            n_modes=1,
            # -decoder
            recon_loss='MSEnorm',
            dg_gates=True,
            dg_prop=0.5,
            device='cpu',
            # -training-specific settings (can be changed after setting up model)
            lamda_pl=1.,
            lamda_rcl=1.,
            lamda_vl=1.,
            **kwargs):
        '''Set up a brain-inspired replay (BI-R) model: encoder, classifier, decoder and (optionally) GMM-prior.

        INPUT:  - [image_size]       <int> size of the input images (the variational loss divides by
                                       image_channels * image_size**2, so presumably square images -- confirm)
                - [image_channels]   <int> number of channels of the input images
                - [classes]          <int> number of classes (size of the classifier's output layer)
                - [target_name]      key under which [forward] returns its predicted labels
                - [only_active]      <bool> if True, only classes seen so far can be predicted
                - conv-layer options ([conv_type] ... [conv_gated]) are passed on to <ConvLayers>
                - [fc_layers]        <int> number of fully-connected encoder layers; must be >= 1
                - [fc_units]/[h_dim] <int> size of first/last hidden fc-layer (for [fc_layers] > 2 the hidden
                                       sizes are linearly spaced between them)
                - [prior]            "GMM" (learnable mixture-prior) or "standard"
                - [z_dim]            <int> number of latent variables
                - [per_class]        <bool> if True, the GMM-prior has [n_modes] modes for every class
                - [recon_loss]       "MSE", "BCE" or "MSEnorm"; for "MSE"/"BCE" the decoder output is a sigmoid
                - [dg_gates]         <bool> use class-specific gates in the decoder (only if [dg_prop] > 0)
                - [dg_prop]          <float> gating proportion passed to the gated decoder layers
                - [device]           passed on to the gated decoder <MLP_gates>
                - [lamda_pl], [lamda_rcl], [lamda_vl]  weights of prediction-, reconstruction- and variational-loss
                - [kwargs]           accepted and ignored

        Raises ValueError if [fc_layers] < 1.'''

        # Set configurations for setting up the model
        super().__init__()
        self.target_name = target_name
        self.label = "BIR"
        self.image_size = image_size
        self.image_channels = image_channels
        self.classes = classes
        self.fc_layers = fc_layers
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.fc_units = fc_units
        self.fc_drop = fc_drop
        self.depth = depth

        # whether always all classes can be predicted or only those seen so far
        self.only_active = only_active
        self.active_classes = []

        # -type of loss to be used for reconstruction
        self.recon_loss = recon_loss  # options: BCE|MSE|MSEnorm
        self.network_output = "sigmoid" if self.recon_loss in (
            "MSE", "BCE") else "none"
        # -settings for class-specific gates in fully-connected hidden layers of decoder
        self.dg_prop = dg_prop
        self.dg_gates = dg_gates if dg_prop > 0. else False
        self.gate_size = classes if self.dg_gates else 0

        # Prior-related parameters
        self.prior = prior
        self.per_class = per_class
        self.n_modes = n_modes * classes if self.per_class else n_modes
        self.modes_per_class = n_modes if self.per_class else None

        # Components deciding how to train / run the model (i.e., these can be changed after setting up the model)
        # -options for prediction loss
        self.lamda_pl = lamda_pl  # weight of classification-loss
        # -how to compute the loss function?
        self.lamda_rcl = lamda_rcl  # weight of reconstruction-loss
        self.lamda_vl = lamda_vl  # weight of variational loss

        # Check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("VAE cannot have 0 fully-connected layers!")

        ######------SPECIFY MODEL------######

        ##>----Encoder (= q[z|x])----<##
        self.convE = ConvLayers(conv_type=conv_type,
                                block_type="basic",
                                num_blocks=num_blocks,
                                image_channels=image_channels,
                                depth=self.depth,
                                start_channels=start_channels,
                                reducing_layers=reducing_layers,
                                batch_norm=conv_bn,
                                nl=conv_nl,
                                output="none" if no_fnl else "normal",
                                global_pooling=global_pooling,
                                gated=conv_gated)
        self.flatten = modules.Flatten()
        #------------------------------calculate input/output-sizes--------------------------------#
        self.conv_out_units = self.convE.out_units(image_size)
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels
        if fc_layers < 2:
            self.fc_layer_sizes = [
                self.conv_out_units
            ]  #--> this results in self.fcE = modules.Identity()
        elif fc_layers == 2:
            self.fc_layer_sizes = [self.conv_out_units, h_dim]
        else:
            self.fc_layer_sizes = [self.conv_out_units] + [
                int(x) for x in np.linspace(fc_units, h_dim, num=fc_layers - 1)
            ]
        real_h_dim = h_dim if fc_layers > 1 else self.conv_out_units
        #------------------------------------------------------------------------------------------#
        self.fcE = MLP(size_per_layer=self.fc_layer_sizes,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       excit_buffer=excit_buffer,
                       gated=fc_gated)
        # to z
        self.toZ = fc_layer_split(real_h_dim,
                                  z_dim,
                                  nl_mean='none',
                                  nl_logvar='none')  #, drop=fc_drop)

        ##>----Classifier----<##
        self.units_before_classifier = real_h_dim
        self.classifier = fc_layer(self.units_before_classifier,
                                   classes,
                                   excit_buffer=True,
                                   nl='none')

        ##>----Decoder (= p[x|z])----<##
        # -the layer coming out of [z] gets a non-linearity (and batch-norm) unless it directly feeds the
        #  (non-nonlinear) output of the conv-layers
        out_nl = True if fc_layers > 1 else (self.depth > 0 and not no_fnl)
        real_h_dim_down = h_dim if fc_layers > 1 else self.convE.out_units(
            image_size, ignore_gp=True)
        if self.dg_gates:
            self.fromZ = fc_layer_fixed_gates(
                z_dim,
                real_h_dim_down,
                batch_norm=(out_nl and fc_bn),
                nl=fc_nl if out_nl else "none",
                gate_size=self.gate_size,
                gating_prop=dg_prop,
            )
        else:
            self.fromZ = fc_layer(z_dim,
                                  real_h_dim_down,
                                  batch_norm=(out_nl and fc_bn),
                                  nl=fc_nl if out_nl else "none")
        # -copy the encoder's layer sizes: modifying the first entry in place would otherwise also overwrite
        #  [self.fc_layer_sizes] (which describes the encoder)
        fc_layer_sizes_down = list(self.fc_layer_sizes)
        fc_layer_sizes_down[0] = self.convE.out_units(image_size,
                                                      ignore_gp=True)
        # -> if 'gp' is used in forward pass, size of first/final hidden layer differs between forward and backward pass
        if self.dg_gates:
            self.fcD = MLP_gates(
                size_per_layer=[x for x in reversed(fc_layer_sizes_down)],
                drop=fc_drop,
                batch_norm=fc_bn,
                nl=fc_nl,
                gate_size=self.gate_size,
                gating_prop=dg_prop,
                device=device,
                output=self.network_output,
            )
        else:
            self.fcD = MLP(
                size_per_layer=[x for x in reversed(fc_layer_sizes_down)],
                drop=fc_drop,
                batch_norm=fc_bn,
                nl=fc_nl,
                gated=fc_gated,
                output=self.network_output,
            )
        # to image-shape
        self.to_image = modules.Reshape(image_channels=self.convE.out_channels
                                        if self.depth > 0 else image_channels)
        # through deconv-layers
        self.convD = modules.Identity()

        ##>----Prior----<##
        # -if using the GMM-prior, add its parameters
        if self.prior == "GMM":
            # -create
            self.z_class_means = nn.Parameter(
                torch.Tensor(self.n_modes, self.z_dim))
            self.z_class_logvars = nn.Parameter(
                torch.Tensor(self.n_modes, self.z_dim))
            # -initialize
            self.z_class_means.data.normal_()
            self.z_class_logvars.data.normal_()

    ##------ NAMES --------##

    def get_name(self):
        '''Build a string label that summarizes this model's architecture and settings.'''
        # -conv-encoder part (only present if there are conv-layers)
        if self.depth > 0:
            conv_part = "{}{}_".format(self.convE.name, "H")
        else:
            conv_part = ""
        # -fc-encoder part (with a single fc-layer, only report the size of the conv-output / input)
        if self.fc_layers > 1:
            fc_part = "{}_".format(self.fcE.name)
        else:
            prefix = "h" if self.depth > 0 else "i"
            fc_part = "{}{}_".format(prefix, self.conv_out_units)
        # -latent part (prior only mentioned if it is not the standard normal)
        if self.prior == "standard":
            prior_part = ""
        else:
            prior_part = "-{}{}{}".format(self.prior, self.n_modes,
                                          "pc" if self.per_class else "")
        latent_part = "z{}{}".format(self.z_dim, prior_part)
        classes_part = "_c{}".format(self.classes)
        gates_part = "_{}{}".format("cg", self.dg_prop) if self.dg_gates else ""
        return "{}={}{}{}{}{}".format(self.label, conv_part, fc_part,
                                      latent_part, classes_part, gates_part)

    @property
    def name(self):
        '''Label of this model (as constructed by [get_name]).'''
        label = self.get_name()
        return label

    ##------ LAYERS --------##

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).

        Layers are collected, in order, from: [convE], [fcE], [classifier], [toZ], [fromZ] and [fcD].
        ([to_image] and [convD] are not included.)'''
        init_layers = []  # -named so as not to shadow the builtin <list>
        for component in (self.convE, self.fcE, self.classifier, self.toZ,
                          self.fromZ, self.fcD):
            init_layers += component.list_init_layers()
        return init_layers

    ##------ FORWARD FUNCTIONS --------##

    def update_active_classes(self, y):
        '''Add every label in the newly observed batch [y] that was not seen before to [self.active_classes].

        Labels are appended in order of first appearance (as elements of [y].cpu().numpy()).'''
        seen = self.active_classes
        for label in y.cpu().numpy():
            if label in seen:
                continue
            seen.append(label)

    def get_logits(self, x, hidden=False):
        '''Feedforward pass that only returns the classifier's logits.

        If [hidden] is True, [x] is taken to already be the output of the conv-layers and [convE] is skipped.'''
        if hidden:
            hidden_x = x
        else:
            hidden_x = self.convE(x)
        features = self.fcE(self.flatten(hidden_x))
        return self.classifier(features)

    def forward(self, x, hidden=False):
        '''Return tensors required for updating weights and a dict (with key self.target_name) of predicted labels.

        OUTPUT: - <tuple> (hidden_x, hE, logits): conv-output, fc-encoder output and classifier logits
                - <dict> {self.target_name: <np.ndarray> predicted class-IDs}
        If [self.only_active] is set and some classes were seen already, only those classes can be predicted.'''
        hidden_x = x if hidden else self.convE(x)
        hE = self.fcE(self.flatten(hidden_x))
        logits = self.classifier(hE)

        if self.only_active and len(self.active_classes) > 0:
            # -argmax over the active classes only, then map positions back to the actual class-IDs
            best_positions = logits[:, self.active_classes].detach().cpu().numpy().argmax(1)
            predicted = np.array([self.active_classes[pos] for pos in best_positions])
        else:
            # -every class can be predicted (also those not yet observed)
            predicted = logits.detach().cpu().numpy().argmax(1)

        return (hidden_x, hE, logits), {self.target_name: predicted}

    def reparameterize(self, mu, logvar):
        '''Sample z = mu + eps * exp(0.5*logvar) with eps ~ N(0, I) ("reparametrization trick"), so that
        gradients can flow back into [mu] and [logvar].'''
        std = torch.exp(0.5 * logvar)
        eps = torch.empty_like(std).normal_()  # -fresh noise with same shape/dtype/device as [std]
        return mu + eps * std

    def decode(self, z, gate_input=None):
        '''Decode latent variable activations.

        INPUT:  - [z]            <2D-tensor>; latent variables to be decoded
                - [gate_input]   <1D-tensor> or <np.ndarray>; for each batch-element in [z] its class-/taskID  ---OR---
                                 <2D-tensor>; for each batch-element in [z] a probability for every class-/task-ID
                                 (only used if the decoder has class-specific gates)

        OUTPUT: - [image_recon]  output of [convD] applied to the reshaped decoder output'''

        # -if needed, convert [gate_input] (class-IDs) to one-hot vectors
        #  (isinstance, so that ndarray-subclasses are not mistaken for tensors and sent to .dim())
        if self.dg_gates and gate_input is not None:
            if isinstance(gate_input, np.ndarray) or gate_input.dim() < 2:
                gate_input = lf.to_one_hot(gate_input,
                                           classes=self.gate_size,
                                           device=self._device())

        # -put inputs through decoder
        if self.dg_gates:
            hD = self.fromZ(z, gate_input=gate_input)
            image_features = self.fcD(hD, gate_input=gate_input)
        else:
            hD = self.fromZ(z)
            image_features = self.fcD(hD)
        return self.convD(self.to_image(image_features))

    def continued(self, hE, gate_input=None, reparameterize=True, **kwargs):
        '''Propagate the fc-encoder output [hE] further: to the latent space and through the decoder.

        INPUT:  - [hE]             <tensor> output of the fully-connected encoder ([fcE])
                - [gate_input]     <1D-tensor> or <np.ndarray>; for each batch-element its class-ID (eg, [y]) ---OR---
                                   <2D-tensor>; for each batch-element a probability for each class-ID (eg, [y_hat])
                                   (ignored if the decoder has no class-specific gates)
                - [reparameterize] <bool>; if True [z] is sampled with the reparametrization trick, else [z]=[z_mean]
                - [kwargs]         accepted and ignored

        OUTPUT: <tuple> consisting of:
        - [x_recon]     output of [decode] for [z]
        - [z_mean]      estimated mean of [z]
        - [z_logvar]    estimated log(SD^2) of [z]
        - [z]           (possibly reparameterized) [z] used for reconstruction'''
        z_mean, z_logvar = self.toZ(hE)
        if reparameterize:
            z = self.reparameterize(z_mean, z_logvar)
        else:
            z = z_mean
        if not self.dg_gates:
            gate_input = None
        x_recon = self.decode(z, gate_input=gate_input)
        return x_recon, z_mean, z_logvar, z

    ##------ SAMPLE FUNCTIONS --------##

    def sample(self,
               size,
               allowed_classes=None,
               class_probs=None,
               sample_mode=None,
               **kwargs):
        '''Generate [size] samples from the model. Outputs are tensors (not "requiring grad"), on same device as <self>.

        INPUT:  - [size]                <int> number of samples to generate
                - [allowed_classes]     <list> of [class_ids] from which to sample
                - [class_probs]         <list>, <np.ndarray> or 1D <tensor> with for each class (the i-th entry belongs to
                                          allowed_classes[i], or to class i if [allowed_classes] is None) the probability
                                          it is sampled; the used probabilities are renormalized to sum to 1
                - [sample_mode]         <int> to sample from specific mode of [z]-distr'n, overwrites [allowed_classes]
                - [kwargs]              accepted and ignored

        OUTPUT: - [X]  <4D-tensor> generated image-features of shape [batch_size]x[out_channels]x[out_size]x[outsize]

        The model is put in eval()-mode while sampling; its original mode is restored afterwards (also on error).'''

        # set model to eval()-mode
        mode = self.training
        self.eval()
        try:
            # sample for each sample the prior-mode to be used (and, if possible, its class)
            y_used = None
            if self.prior == "GMM":
                if sample_mode is not None:
                    # -always sample from the provided mode
                    sampled_modes = np.repeat(sample_mode, size)
                    if self.per_class:
                        y_used = np.repeat(sample_mode // self.modes_per_class, size)
                elif (allowed_classes is None
                      and class_probs is None) or (not self.per_class):
                    # -randomly sample modes from all possible modes (and find their corresponding class, if applicable)
                    sampled_modes = np.random.randint(0, self.n_modes, size)
                    if self.per_class:
                        y_used = sampled_modes // self.modes_per_class
                else:
                    if allowed_classes is None:
                        allowed_classes = [i for i in range(len(class_probs))]
                    # -sample from modes belonging to [allowed_classes], possibly weighted according to [class_probs]
                    allowed_modes = []
                    for class_id in allowed_classes:
                        allowed_modes += list(
                            range(class_id * self.modes_per_class,
                                  (class_id + 1) * self.modes_per_class))
                    if class_probs is None:
                        mode_probs = None
                    else:
                        # -float() accepts Python numbers as well as 0-d tensors / numpy scalars
                        per_class_p = np.array(
                            [float(class_probs[index]) for index in range(len(allowed_classes))],
                            dtype=np.float64)
                        mode_probs = np.repeat(per_class_p, self.modes_per_class)
                        mode_probs = mode_probs / mode_probs.sum()
                    sampled_modes = np.random.choice(allowed_modes,
                                                     size,
                                                     p=mode_probs,
                                                     replace=True)
                    y_used = np.asarray(sampled_modes) // self.modes_per_class

                # sample z from the selected prior-modes
                z_means = self.z_class_means[sampled_modes, :]
                z_logvars = self.z_class_logvars[sampled_modes, :]
                with torch.no_grad():
                    z = self.reparameterize(z_means, z_logvars)
            else:
                z = torch.randn(size, self.z_dim).to(self._device())

            # if no classes are selected yet, but they are needed for the "decoder-gates", select classes to be sampled
            if (y_used is None) and self.dg_gates:
                if allowed_classes is None and class_probs is None:
                    y_used = np.random.randint(0, self.classes, size)
                else:
                    if allowed_classes is None:
                        allowed_classes = [i for i in range(len(class_probs))]
                    if class_probs is None:
                        class_p = None
                    else:
                        # -renormalize, as np.random.choice rejects (e.g. float32-)probabilities not summing to 1
                        class_p = np.array([float(p) for p in class_probs], dtype=np.float64)
                        class_p = class_p / class_p.sum()
                    y_used = np.random.choice(allowed_classes,
                                              size,
                                              p=class_p,
                                              replace=True)

            # decode z into image X
            with torch.no_grad():
                X = self.decode(z, gate_input=y_used if self.dg_gates else None)
        finally:
            # set model back to its initial mode
            self.train(mode=mode)

        # return samples as [batch_size]x[out_channels]x[out_size]x[outsize] tensor
        return X

    ##------ LOSS FUNCTIONS --------##

    def calculate_recon_loss(self, x, x_recon, average=False):
        '''Calculate reconstruction loss for each element in the batch.

        INPUT:  - [x]           <tensor> with original input (1st dimension (ie, dim=0) is "batch-dimension")
                - [x_recon]     <tensor> with reconstructed input in same shape as [x]
                - [average]     <bool>, if True, loss is average over all pixels; otherwise it is summed

        OUTPUT: - [reconL]      <1D-tensor> of length [batch_size]

        Raises NotImplementedError if [self.recon_loss] is not one of "MSE", "MSEnorm" or "BCE".'''

        batch_size = x.size(0)
        # -flatten all non-batch dimensions, so that both loss-types reduce over all elements of each sample
        #  (for inputs that are already [batch_size]x[D] this is a no-op)
        x_flat = x.reshape(batch_size, -1)
        x_recon_flat = x_recon.reshape(batch_size, -1)
        if self.recon_loss in ("MSE", "MSEnorm"):
            reconL = -lf.log_Normal_standard(
                x=x_flat, mean=x_recon_flat, average=average, dim=-1)
        elif self.recon_loss == "BCE":
            reconL = F.binary_cross_entropy(input=x_recon_flat,
                                            target=x_flat,
                                            reduction='none')
            reconL = torch.mean(reconL, dim=1) if average else torch.sum(
                reconL, dim=1)
        else:
            raise NotImplementedError(
                "Wrong choice for type of reconstruction-loss!")
        # --> if [average]=True, reconstruction loss is averaged over all pixels/elements (otherwise it is summed)
        #       (averaging over all elements in the batch will be done later)
        return reconL

    def calculate_variat_loss(self,
                              z,
                              mu,
                              logvar,
                              y=None,
                              y_prob=None,
                              allowed_classes=None):
        '''Calculate variational loss (KL-divergence between posterior q(z|x) and prior) for each element in the batch.

        For the "standard" prior the KL-divergence is computed analytically; for the "GMM" prior it is estimated
        as log q(z|x) - log p(z), using the sampled [z].

        INPUT:  - [z]        <2D-tensor> with sampled latent variables (1st dimension (ie, dim=0) is "batch-dimension")
                - [mu]       <2D-tensor> by encoder predicted mean for [z]
                - [logvar]   <2D-tensor> by encoder predicted logvar for [z]

        OPTIONS THAT ARE RELEVANT ONLY IF self.per_class IS TRUE:
            - [y]               None or <1D-tensor> with target-classes (as integers, corresponding to actual class-IDs)
            - [y_prob]          None or <2D-tensor> with probabilities for each class (in [allowed_classes])
            - [allowed_classes] None or <list> with class-IDs to use for selecting prior-mode(s)

        OUTPUT: - [variatL]   <1D-tensor> of length [batch_size]

        Raises NotImplementedError if [self.prior] is neither "standard" nor "GMM".'''

        if self.prior == "standard":
            # --> calculate analytically
            # ---- see Appendix B from: Kingma & Welling (2014) Auto-Encoding Variational Bayes, ICLR ----#
            variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),
                                       dim=1)

        elif self.prior == "GMM":
            # --> calculate "by estimation"

            ## Get [means] and [logvars] of all (possible) modes
            allowed_modes = list(range(self.n_modes))
            # -if we don't use the specific modes of a target, we could select modes based on list of classes
            if (y is None) and (allowed_classes
                                is not None) and self.per_class:
                allowed_modes = []
                for class_id in allowed_classes:
                    allowed_modes += list(
                        range(class_id * self.modes_per_class,
                              (class_id + 1) * self.modes_per_class))
            # -retrieve the means and logvars for the selected modes
            prior_means = self.z_class_means[allowed_modes, :]
            prior_logvars = self.z_class_logvars[allowed_modes, :]
            # -rearrange so every batch-element is compared with every selected mode
            z_expand = z.unsqueeze(1)  # [batch_size] x 1 x [z_dim]
            means = prior_means.unsqueeze(0)  # 1 x [n_modes] x [z_dim]
            logvars = prior_logvars.unsqueeze(0)  # 1 x [n_modes] x [z_dim]

            ## Calculate "log_p_z" (log-likelihood of "reparameterized" [z] based on selected priors)
            n_modes = self.modes_per_class if (
                ((y is not None) or (y_prob is not None))
                and self.per_class) else len(allowed_modes)
            a = lf.log_Normal_diag(
                z_expand, mean=means, log_var=logvars, average=False,
                dim=2) - math.log(n_modes)
            # --> for each element in batch, log-likelihood under every selected mode: [batch_size] x [n_modes]
            if (y is not None) and self.per_class:
                # -indices of the modes belonging to the class of each batch-element: [batch_size] x [modes_per_class]
                y_col = torch.as_tensor(y, device=a.device).long().view(-1, 1)
                mode_offsets = torch.arange(self.modes_per_class, device=a.device)
                modes_tensor = y_col * self.modes_per_class + mode_offsets
                a = a.gather(dim=1, index=modes_tensor)
                # --> reduce [a] to size [batch_size]x[modes_per_class] (ie, per batch only keep modes of [y])
                #     but within the batch, elements can have different [y], so this reduction couldn't be done before
            a_max, _ = torch.max(a, dim=1)  # [batch_size]
            # --> for each element in batch, take highest log-likelihood over all modes;
            #     subtracting it below avoids underflow (log-sum-exp trick)
            a_exp = torch.exp(a - a_max.unsqueeze(1))  # [batch_size] x [n_modes]
            if (y is None) and (y_prob is not None) and self.per_class:
                # -extend probabilities per class to probabilities per mode; y_prob: [batch_size] x [n_modes]
                y_prob = y_prob.repeat_interleave(self.modes_per_class, dim=1)
                a_logsum = torch.log(
                    torch.clamp(torch.sum(y_prob * a_exp, dim=1), min=1e-40))
            else:
                a_logsum = torch.log(
                    torch.clamp(torch.sum(a_exp, dim=1),
                                min=1e-40))  # -> sum over modes: [batch_size]
            log_p_z = a_logsum + a_max  # [batch_size]

            ## Calculate "log_q_z" (log-likelihood of "reparameterized" [z] under q(z|x))
            log_q_z = lf.log_Normal_diag(z,
                                         mean=mu,
                                         log_var=logvar,
                                         average=False,
                                         dim=1)
            # ----->  mu: [batch_size] x [z_dim]; logvar: [batch_size] x [z_dim]; z: [batch_size] x [z_dim]
            # ----->  log_q_z: [batch_size]

            ## Combine
            variatL = -(log_p_z - log_q_z)

        else:
            raise NotImplementedError("Wrong choice for type of prior!")

        return variatL

    def loss_function(self,
                      x,
                      y,
                      x_recon,
                      y_hat,
                      scores,
                      mu,
                      z,
                      logvar,
                      allowed_classes=None,
                      batch_weights=None):
        '''Calculate and return various losses that could be used for training and/or evaluating the model.

        INPUT:  - [x]               <4D-tensor> original image
                - [y]               None or <1D-tensor> with target-classes (as integers, corresponding to [allowed_classes])
                - [x_recon]         <4D-tensor> reconstructed image in same shape as [x]
                - [y_hat]           None or <2D-tensor> with predicted "logits" for each class (corresponding to [allowed_classes])
                - [scores]          None or <2D-tensor> with target "logits" for each class (corresponding to [allowed_classes])
                                      (if len(scores)<len(y_hat), 0 probs are added during distillation step at the end)
                - [mu]              <2D-tensor> with either [z] or the estimated mean of [z]
                - [z]               <2D-tensor> with reparameterized [z]
                - [logvar]          <2D-tensor> with estimated log(SD^2) of [z]
                - [allowed_classes] None or <list> with class-IDs to use for selecting prior-mode(s)
                - [batch_weights]   <1D-tensor> with a weight for each batch-element (if None, normal average over batch)

        OUTPUT: - [reconL]       reconstruction loss indicating how well [x] and [x_recon] match
                - [variatL]      variational (KL-divergence) loss "indicating how close distribion [z] is to prior"
                                   (divided by the number of input-pixels)
                - [predL]        prediction loss indicating how well targets [y] are predicted (0 if [y] or [y_hat] is None)
                - [distilL]      knowledge distillation (KD) loss indicating how well the predicted "logits" ([y_hat])
                                     match the target "logits" ([scores]) (0 if [scores] or [y_hat] is None)'''

        ###-----Reconstruction loss-----###
        batch_size = x.size(0)
        reconL = self.calculate_recon_loss(x=x.view(batch_size, -1),
                                           average=True,
                                           x_recon=x_recon.view(
                                               batch_size,
                                               -1))  # -> average over pixels
        reconL = lf.weighted_average(reconL, weights=batch_weights,
                                     dim=0)  # -> average over batch

        ###-----Variational loss-----###
        # -map labels (indices into [allowed_classes]) back to actual class-IDs; explicit dtype keeps the result
        #  an integer tensor even for an empty batch
        if (allowed_classes is not None) and (y is not None):
            actual_y = torch.tensor([allowed_classes[i] for i in y.tolist()],
                                    dtype=torch.long,
                                    device=self._device())
        else:
            actual_y = y
        if (y is None and scores is not None):
            y_prob = F.softmax(scores / self.distill_temp, dim=1)
            if allowed_classes is not None and len(
                    allowed_classes) > y_prob.size(1):
                # -pad with zero-probabilities for classes the [scores] do not cover (same dtype/device as [y_prob])
                zeros_to_add = y_prob.new_zeros(
                    (y_prob.size(0), len(allowed_classes) - y_prob.size(1)))
                y_prob = torch.cat([y_prob, zeros_to_add], dim=1)
        else:
            y_prob = None
        # ---> if [y] is not provided but [scores] is, calculate variational loss using weighted sum of prior-modes
        variatL = self.calculate_variat_loss(z=z,
                                             mu=mu,
                                             logvar=logvar,
                                             y=actual_y,
                                             y_prob=y_prob,
                                             allowed_classes=allowed_classes)
        variatL = lf.weighted_average(variatL, weights=batch_weights,
                                      dim=0)  # -> average over batch
        variatL /= (self.image_channels * self.image_size**2
                    )  # -> divide by # of input-pixels

        ###-----Prediction loss-----###
        if y is not None and y_hat is not None:
            predL = F.cross_entropy(input=y_hat, target=y, reduction='none')
            #--> no reduction needed, summing over classes is "implicit"
            predL = lf.weighted_average(predL, weights=batch_weights,
                                        dim=0)  # -> average over batch
        else:
            predL = torch.tensor(0., device=self._device())

        ###-----Distilliation loss-----###
        if scores is not None and y_hat is not None:
            # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes would be added to [scores]!
            n_classes_to_consider = y_hat.size(
                1)  #--> zeros will be added to [scores] to make it this size!
            distilL = lf.loss_fn_kd(
                scores=y_hat[:, :n_classes_to_consider],
                target_scores=scores,
                T=self.distill_temp,
                weights=batch_weights
            )  #--> summing over classes & averaging over batch in function
        else:
            distilL = torch.tensor(0., device=self._device())

        # Return a tuple of the calculated losses
        return reconL, variatL, predL, distilL

    ##------ TRAINING FUNCTIONS --------##

    def update_weights(self,
                       tensors_for_weight_update,
                       y,
                       rnt=None,
                       update=True,
                       **kwargs):
        '''Train model for one batch ([x],[y]), with [x] transformed to [tensor_for_weight_update] by forward-pass.

        If a model from a previous episode is stored ([self.previous_model]), replay is generated at the
        hidden level: that model samples hidden representations, and its logits on them are used as
        distillation targets for the current model.

        [tensors_for_weight_update]  <tuple> containing (hidden_x, hE, logits)
        [y]                          <tensor> batch of corresponding ground-truth labels
        [rnt]                        <number> in [0,1], relative importance of the current data vs. the replay;
                                       only used (and then required) when [self.previous_model] is not None
        [update]                     <bool> if True, set training-mode, backpropagate and take an optimizer-step

        Returns a <dict> with the scalar total / current / replay / SI losses and the training-precision.'''

        # Set model to training-mode
        if update:
            self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        # Unpack [tensor_for_weight_update]
        hidden_x = tensors_for_weight_update[0]
        hE = tensors_for_weight_update[1]
        logits = tensors_for_weight_update[2]

        ##--(1)-- CURRENT DATA --##

        # Hack to make it work if batch-size is 1
        y = y.expand(1) if len(y.size()) == 0 else y

        # Continue running the model
        recon_batch, mu, logvar, z = self.continued(
            hE, gate_input=y if self.dg_gates else None)

        # If requested, restrict predictions to those classes seen so far
        if self.only_active:
            # -update "active classes" (i.e., list of all classes seen so far)
            self.update_active_classes(y)
            # -remove predictions for classes not yet seen
            logits = logits[:, self.active_classes]
            # -update indeces of ground-truth labels to match those in the "active classes"-list
            y = torch.tensor([self.active_classes.index(i)
                              for i in y]).to(self._device())

        # Calculate all losses
        reconL, variatL, predL, _ = self.loss_function(
            x=hidden_x,
            y=y,
            x_recon=recon_batch,
            y_hat=logits,
            scores=None,
            mu=mu,
            z=z,
            logvar=logvar,
            allowed_classes=self.active_classes if self.only_active else None)

        # Weigh losses as requested
        loss_cur = self.lamda_rcl * reconL + self.lamda_vl * variatL + self.lamda_pl * predL

        # Calculate training-precision
        precision = (y == logits.max(1)[1]).sum().item() / logits.size(0)

        ##--(2)-- REPLAYED DATA --##

        # If model from previous episode is stored, generate replay (at the hidden/latent level!)
        if self.previous_model is not None:
            batch_size_replay = z.size(0)

            # -generate the hidden representations to be replayed
            x_ = self.previous_model.sample(
                batch_size_replay,
                allowed_classes=self.previous_model.active_classes
                if self.only_active else None,
            )

            # -generate labels for this replay
            with torch.no_grad():
                target_logits = self.previous_model.get_logits(x_, hidden=True)
            if self.only_active:
                # -remove targets for classes not yet seen (they will be set to zero prob later)
                target_logits = target_logits[:, self.previous_model.active_classes]

            # -run current model on the replayed data
            (_, hE_, logits_), _ = self.forward(x_, hidden=True)
            target_probs = F.softmax(target_logits / self.distill_temp, dim=1)
            if self.only_active:
                # Scatter the probabilities back to their class-positions; classes not in
                # [self.previous_model.active_classes] get target-probability zero. A fresh tensor is
                # allocated so that [target_probs] is never modified in place (an in-place `zero_()` on a
                # column of it would corrupt the probabilities of the previous model's first active class).
                full_target_probs = torch.zeros(target_probs.size(0),
                                                self.classes,
                                                dtype=target_probs.dtype,
                                                device=target_probs.device)
                full_target_probs[:, self.previous_model.active_classes] = target_probs
                target_probs = full_target_probs
            recon_x_, mu_, logvar_, z_ = self.continued(
                hE_, gate_input=target_probs if self.dg_gates else None)

            # -if requested, restrict predictions to classes seen so far
            if self.only_active:
                # -remove predictions for classes not yet seen
                logits_ = logits_[:, self.active_classes]

            # -evaluate replayed data
            reconL_r, variatL_r, _, distilL_r = self.loss_function(
                x=x_,
                y=None,
                x_recon=recon_x_,
                y_hat=logits_,
                scores=target_logits,
                mu=mu_,
                z=z_,
                logvar=logvar_,
                allowed_classes=self.active_classes
                if self.only_active else None)
            # -weigh losses as requested
            loss_replay = self.lamda_rcl * reconL_r + self.lamda_vl * variatL_r + self.lamda_pl * distilL_r
        else:
            loss_replay = None

        # Calculate total loss
        if self.previous_model is None:
            loss_total = loss_cur
        else:
            loss_total = rnt * loss_cur + (1 - rnt) * loss_replay

        # Add SI-loss (Zenke et al., 2017). Added out of place: [loss_total] may be the very same tensor
        # object as [loss_cur], and an in-place `+=` would also add the SI-term to the reported current loss.
        si_loss = self.si_loss()
        if self.si_c > 0:
            loss_total = loss_total + self.si_c * si_loss

        if update:
            # Backpropagate errors
            loss_total.backward()

            # Take optimization-step
            self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item(),
            'loss_replay': loss_replay.item() if
            (loss_replay is not None) else 0.,
            'si_loss': si_loss.item(),
            'precision': precision,
        }
    def __init__(
            self,
            image_size,
            image_channels,
            classes,
            target_name,
            only_active=False,
            # -conv-layers
            conv_type="standard",
            depth=0,
            start_channels=16,
            reducing_layers=4,
            conv_bn=True,
            conv_nl="relu",
            num_blocks=2,
            global_pooling=False,
            no_fnl=True,
            conv_gated=False,
            # -fc-layers
            fc_layers=3,
            fc_units=2000,
            h_dim=2000,
            fc_drop=0,
            fc_bn=False,
            fc_nl="relu",
            fc_gated=False,
            bias=True,
            excitability=False,
            excit_buffer=False):
        '''Set up conv-layers ([self.convE]), fully-connected layers ([self.fcE]) and output layer ([self.classifier]).

        [classes]      <int> number of output units of the classifier
        [target_name]  stored as [self.target_name]
        [only_active]  <bool> whether only classes seen so far should be predictable ([self.active_classes])
        [fc_layers]    <int> number of fc-layers including the output layer; must be at least 1

        Raises ValueError if [fc_layers] < 1.

        NOTE(review): this constructor is identical to that of the ``Classifier`` class that directly follows
        it in this file, but it sits after ``update_weights`` inside a different class body, where it would
        replace any earlier ``__init__`` of that class -- confirm this placement is intended.'''
        super().__init__()

        # -- configuration --
        self.classes = classes
        self.target_name = target_name
        self.label = "Classifier"
        self.depth = depth
        self.fc_layers = fc_layers
        self.fc_drop = fc_drop

        # -- either every class is predictable, or only the ones observed so far --
        self.only_active = only_active
        self.active_classes = []

        # at least the output layer has to be fully connected
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        # -- convolutional feature extractor --
        conv_output_mode = "none" if no_fnl else "normal"
        self.convE = ConvLayers(
            conv_type=conv_type,
            block_type="basic",
            num_blocks=num_blocks,
            image_channels=image_channels,
            depth=depth,
            start_channels=start_channels,
            reducing_layers=reducing_layers,
            batch_norm=conv_bn,
            nl=conv_nl,
            global_pooling=global_pooling,
            gated=conv_gated,
            output=conv_output_mode,
        )
        self.flatten = modules.Flatten()

        # -- layer sizes derived from the conv-output --
        conv_out_units = self.convE.out_units(image_size)
        self.conv_out_units = conv_out_units
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels
        if fc_layers < 2:
            # only the input size remains -> self.fcE becomes modules.Identity()
            hidden_sizes = []
        elif fc_layers == 2:
            hidden_sizes = [h_dim]
        else:
            hidden_sizes = [
                int(n_units)
                for n_units in np.linspace(fc_units, h_dim, num=fc_layers - 1)
            ]
        self.fc_layer_sizes = [conv_out_units] + hidden_sizes
        self.units_before_classifier = conv_out_units if fc_layers <= 1 else h_dim

        # -- fully connected layers --
        self.fcE = MLP(size_per_layer=self.fc_layer_sizes,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer,
                       gated=fc_gated)

        # -- output layer (raw logits, no non-linearity) --
        self.classifier = fc_layer(self.units_before_classifier,
                                   classes,
                                   excit_buffer=True,
                                   nl='none',
                                   drop=fc_drop)
class Classifier(ContinualLearner):
    '''Model for encoding (i.e., feature extraction) and classifying images, enriched as "ContinualLearner"--object.

    Predictions are returned as a dict keyed by [target_name]. With [only_active]=True, predictions and the
    training loss are restricted to the classes observed so far ([self.active_classes]).'''
    def __init__(
            self,
            image_size,
            image_channels,
            classes,
            target_name,
            only_active=False,
            # -conv-layers
            conv_type="standard",
            depth=0,
            start_channels=16,
            reducing_layers=4,
            conv_bn=True,
            conv_nl="relu",
            num_blocks=2,
            global_pooling=False,
            no_fnl=True,
            conv_gated=False,
            # -fc-layers
            fc_layers=3,
            fc_units=2000,
            h_dim=2000,
            fc_drop=0,
            fc_bn=False,
            fc_nl="relu",
            fc_gated=False,
            bias=True,
            excitability=False,
            excit_buffer=False):
        '''Set up conv-layers ([self.convE]), fully-connected layers ([self.fcE]) and output layer ([self.classifier]).

        Raises ValueError if [fc_layers] < 1.'''

        # model configurations
        super().__init__()
        self.classes = classes
        self.target_name = target_name
        self.label = "Classifier"
        self.depth = depth
        self.fc_layers = fc_layers
        self.fc_drop = fc_drop

        # whether always all classes can be predicted or only those seen so far
        self.only_active = only_active
        self.active_classes = []

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######
        #--> convolutional layers
        self.convE = ConvLayers(
            conv_type=conv_type,
            block_type="basic",
            num_blocks=num_blocks,
            image_channels=image_channels,
            depth=depth,
            start_channels=start_channels,
            reducing_layers=reducing_layers,
            batch_norm=conv_bn,
            nl=conv_nl,
            global_pooling=global_pooling,
            gated=conv_gated,
            output="none" if no_fnl else "normal",
        )
        self.flatten = modules.Flatten()
        #------------------------------calculate input/output-sizes--------------------------------#
        self.conv_out_units = self.convE.out_units(image_size)
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels
        if fc_layers < 2:
            self.fc_layer_sizes = [
                self.conv_out_units
            ]  #--> this results in self.fcE = modules.Identity()
        elif fc_layers == 2:
            self.fc_layer_sizes = [self.conv_out_units, h_dim]
        else:
            self.fc_layer_sizes = [self.conv_out_units] + [
                int(x) for x in np.linspace(fc_units, h_dim, num=fc_layers - 1)
            ]
        self.units_before_classifier = h_dim if fc_layers > 1 else self.conv_out_units
        #------------------------------------------------------------------------------------------#
        #--> fully connected layers
        self.fcE = MLP(size_per_layer=self.fc_layer_sizes,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer,
                       gated=fc_gated)
        #--> classifier
        self.classifier = fc_layer(self.units_before_classifier,
                                   classes,
                                   excit_buffer=True,
                                   nl='none',
                                   drop=fc_drop)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        # Build a new list instead of extending the one returned by [self.convE] in place
        # (and without shadowing the builtin `list`).
        return [
            *self.convE.list_init_layers(),
            *self.fcE.list_init_layers(),
            *self.classifier.list_init_layers(),
        ]

    @property
    def name(self):
        '''Short string identifying the architecture (conv-part, fc-part, dropout and number of classes).'''
        if self.depth > 0 and self.fc_layers > 1:
            return "{}_{}_c{}".format(self.convE.name, self.fcE.name,
                                      self.classes)
        elif self.depth > 0:
            return "{}_{}c{}".format(
                self.convE.name,
                "drop{}-".format(self.fc_drop) if self.fc_drop > 0 else "",
                self.classes)
        elif self.fc_layers > 1:
            return "{}_c{}".format(self.fcE.name, self.classes)
        else:
            return "i{}_{}c{}".format(
                self.fc_layer_sizes[0],
                "drop{}-".format(self.fc_drop) if self.fc_drop > 0 else "",
                self.classes)

    def update_active_classes(self, y):
        '''Given labels of newly observed batch, update list of all classes seen so far (in order of first appearance).'''
        for i in y.cpu().numpy():
            if i not in self.active_classes:
                self.active_classes.append(i)

    def forward(self, x):
        '''Return tensors required for updating weights and a dict (with key self.target_name) of predicted labels.

        Returns (logits, predictions): [logits] are the raw outputs for all classes; [predictions] maps
        [self.target_name] to a numpy array of predicted class labels.'''
        # Perform forward pass
        logits = self.classifier(self.fcE(self.flatten(self.convE(x))))
        # Get predictions
        if self.only_active and len(self.active_classes) > 0:
            # -restrict predictions to those classes listed in [self.active_classes]
            logits_for_prediction = logits[:, self.active_classes]
            predictions_shifted = logits_for_prediction.detach().cpu().numpy().argmax(1)
            # -map positions within [self.active_classes] back to the original class labels
            predictions = {
                self.target_name:
                np.array([self.active_classes[i] for i in predictions_shifted])
            }
        else:
            # -all classes can be predicted (even those not yet observed)
            predictions = {
                self.target_name: logits.detach().cpu().numpy().argmax(1)
            }
        return logits, predictions

    def update_weights(self,
                       logits,
                       y,
                       x=None,
                       rnt=None,
                       update=True,
                       **kwargs):
        '''Train model for one batch ([x],[y]).

        [logits]  <tensor> batch of logits returned by model for inputs [x]
        [y]       <tensor> batch of corresponding ground-truth labels
        [x]       <tensor> the inputs themselves; required when [self.previous_model] is set (LwF-style replay)
        [rnt]     <number> in [0,1], relative importance of the current loss; required when replaying
        [update]  <bool> if True, set training-mode, backpropagate and take an optimizer-step

        Returns a <dict> with the scalar total / current / replay / SI losses and the training-precision.'''

        # Set model to training-mode
        if update:
            self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        # Hack to make it work if batch-size is 1
        y = y.expand(1) if len(y.size()) == 0 else y

        # If requested, restrict predictions to those classes seen so far
        if self.only_active:
            # -update "active classes" (i.e., list of all classes seen so far)
            self.update_active_classes(y)
            # -remove predictions for classes not yet seen
            logits = logits[:, self.active_classes]
            # -update indeces of ground-truth labels to match those in the "active classes"-list
            y = torch.tensor([self.active_classes.index(i)
                              for i in y.tolist()]).to(self._device())

        # Calculate multiclass prediction loss
        predL = F.cross_entropy(
            input=logits, target=y,
            reduction='none')  # -> summing over classes is "implicit"
        predL = lf.weighted_average(predL, weights=None,
                                    dim=0)  # -> average over batch
        loss_cur = predL

        # Calculate training-precision
        precision = (y == logits.max(1)[1]).sum().item() / logits.size(0)

        # If doing LwF, add 'replayed' data & calculate loss on it
        if self.previous_model is not None:
            # -generate the labels to 'replay' the current inputs with
            with torch.no_grad():
                target_logits, _ = self.previous_model.forward(x)
            if self.only_active:
                target_logits = target_logits[:, self.previous_model.active_classes]
            # -evaluate replayed data
            loss_replay = lf.loss_fn_kd(scores=logits,
                                        target_scores=target_logits,
                                        T=self.distill_temp)
        else:
            loss_replay = None

        # Calculate total loss
        if self.previous_model is None:
            loss_total = loss_cur
        else:
            loss_total = rnt * loss_cur + (1 - rnt) * loss_replay

        # Add SI-loss (Zenke et al., 2017). Added out of place: [loss_total] may be the very same tensor
        # object as [loss_cur], and an in-place `+=` would also add the SI-term to the reported current loss.
        si_loss = self.si_loss()
        if self.si_c > 0:
            loss_total = loss_total + self.si_c * si_loss

        if update:
            # Backpropagate errors
            loss_total.backward()

            # Take optimization-step
            self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item(),
            'loss_replay': loss_replay.item() if
            (loss_replay is not None) else 0.,
            'si_loss': si_loss.item(),
            'precision': precision,
        }
class Classifier(ContinualLearner):
    '''Model for encoding (i.e., feature extraction) and classifying images, enriched as "ContinualLearner"--object.'''
    def __init__(
        self,
        image_size,
        image_channels,
        classes,
        # -conv-layers
        conv_type="standard",
        depth=0,
        start_channels=64,
        reducing_layers=3,
        conv_bn=True,
        conv_nl="relu",
        num_blocks=2,
        global_pooling=False,
        no_fnl=True,
        conv_gated=False,
        # -fc-layers
        fc_layers=3,
        fc_units=1000,
        h_dim=400,
        fc_drop=0,
        fc_bn=True,
        fc_nl="relu",
        fc_gated=False,
        bias=True,
        excitability=False,
        excit_buffer=False,
        # -training-specific settings (can be changed after setting up model)
        hidden=False):
        '''Set up conv-layers ([self.convE]), fully-connected layers ([self.fcE]) and output layer ([self.classifier]).

        [classes]    <int> number of output units
        [fc_layers]  <int> number of fc-layers including the output layer; must be at least 1
        [hidden]     <bool> if True, [self.classify] (and so the replayed data in [self.train_a_batch]) expects
                       "hidden" inputs that bypass [self.convE]; may be changed after construction

        Raises ValueError if [fc_layers] < 1.'''
        super().__init__()

        # -- configuration --
        self.classes = classes
        self.label = "Classifier"
        self.depth = depth
        self.fc_layers = fc_layers
        self.fc_drop = fc_drop

        # -- training settings --
        self.hidden = hidden
        # the optimizer has to be assigned before training starts
        self.optimizer = None
        self.optim_list = []

        # at least the output layer has to be fully connected
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        # -- convolutional feature extractor --
        conv_output_mode = "none" if no_fnl else "normal"
        self.convE = ConvLayers(
            conv_type=conv_type,
            block_type="basic",
            num_blocks=num_blocks,
            image_channels=image_channels,
            depth=depth,
            start_channels=start_channels,
            reducing_layers=reducing_layers,
            batch_norm=conv_bn,
            nl=conv_nl,
            global_pooling=global_pooling,
            gated=conv_gated,
            output=conv_output_mode,
        )
        self.flatten = modules.Flatten()

        # -- layer sizes derived from the conv-output --
        conv_out_units = self.convE.out_units(image_size)
        self.conv_out_units = conv_out_units
        self.conv_out_size = self.convE.out_size(image_size)
        self.conv_out_channels = self.convE.out_channels
        if fc_layers < 2:
            # only the input size remains -> self.fcE becomes modules.Identity()
            hidden_sizes = []
        elif fc_layers == 2:
            hidden_sizes = [h_dim]
        else:
            hidden_sizes = [
                int(n_units)
                for n_units in np.linspace(fc_units, h_dim, num=fc_layers - 1)
            ]
        self.fc_layer_sizes = [conv_out_units] + hidden_sizes
        self.units_before_classifier = conv_out_units if fc_layers <= 1 else h_dim

        # -- fully connected layers --
        # NOTE(review): a previous variant also passed output="none" to MLP here (marked "temporary change").
        self.fcE = MLP(
            size_per_layer=self.fc_layer_sizes,
            drop=fc_drop,
            batch_norm=fc_bn,
            nl=fc_nl,
            bias=bias,
            excitability=excitability,
            excit_buffer=excit_buffer,
            gated=fc_gated)

        # -- output layer (raw logits, no non-linearity) --
        self.classifier = fc_layer(self.units_before_classifier,
                                   classes,
                                   excit_buffer=True,
                                   nl='none',
                                   drop=fc_drop)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        list = self.convE.list_init_layers()
        list += self.fcE.list_init_layers()
        list += self.classifier.list_init_layers()
        return list

    @property
    def name(self):
        if self.depth > 0 and self.fc_layers > 1:
            return "{}_{}_c{}".format(self.convE.name, self.fcE.name,
                                      self.classes)
        elif self.depth > 0:
            return "{}_{}c{}".format(
                self.convE.name,
                "drop{}-".format(self.fc_drop) if self.fc_drop > 0 else "",
                self.classes)
        elif self.fc_layers > 1:
            return "{}_c{}".format(self.fcE.name, self.classes)
        else:
            return "i{}_{}c{}".format(
                self.fc_layer_sizes[0],
                "drop{}-".format(self.fc_drop) if self.fc_drop > 0 else "",
                self.classes)

    def forward(self, x):
        hidden_rep = self.convE(x)
        final_features = self.fcE(self.flatten(hidden_rep))
        return self.classifier(final_features)

    def input_to_hidden(self, x):
        '''Get [hidden_rep]s (inputs to final fully-connected layers) for images [x].'''
        return self.convE(x)

    def hidden_to_output(self, hidden_rep):
        '''Map [hidden_rep]s to outputs (i.e., logits for all possible output-classes).'''
        return self.classifier(self.fcE(self.flatten(hidden_rep)))

    def feature_extractor(self, images, from_hidden=False):
        return self.fcE(
            self.flatten(images if from_hidden else self.convE(images)))
        #return self.classifier(self.fcE(self.flatten(images if from_hidden else self.convE(images))))

    def classify(self, x, not_hidden=False):
        '''For input [x] (image or extracted "intermediate" image features), return all predicted "scores"/"logits".'''
        image_features = self.flatten(x) if (
            self.hidden and not not_hidden) else self.flatten(self.convE(x))
        hE = self.fcE(image_features)
        return self.classifier(hE)

    def train_a_batch(self,
                      x,
                      y=None,
                      x_=None,
                      y_=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      task=1,
                      replay_not_hidden=False,
                      freeze_convE=False,
                      **kwargs):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_]).

        [x]                 <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]                 <tensor> batch of corresponding labels
        [x_]                None or (<list> of) <tensor> batch of replayed inputs
                              NOTE: expected to be as [self.hidden] or [replay_up_to], unless [replay_not_hidden]==True
        [y_]                None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]           None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]               <number> in [0,1], relative importance of new task
        [active_classes]    None or (<list> of) <list> with "active" classes
        [task]              <int>, for setting task-specific mask
        [replay_not_hidden] <bool> provided [x_] are original images, even though other level might be expected
        [freeze_convE]      <bool> if True, [self.convE] is put in eval-mode (so its batch-norm layers do not change)

        Returns a <dict> with the scalar values of the (component) losses and the training-precision.'''

        # Set model to training-mode
        self.train()
        if freeze_convE:
            # - if conv-layers are frozen, they should be set to eval() to prevent batch-norm layers from changing
            self.convE.eval()

        # Reset optimizer
        self.optimizer.zero_grad()

        ##--(1)-- CURRENT DATA --##

        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed (e.g., "class" or "task" scenario), remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if isinstance(
                    active_classes[0], list) else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate multiclass prediction loss
            if y is not None and len(y.size()) == 0:
                y = y.expand(1)  #--> hack to make it work if batch-size is 1
            if y is None:
                predL = None
            else:
                # --> no reduction needed, summing over classes is "implicit"
                predL = F.cross_entropy(input=y_hat, target=y, reduction='none')
                predL = lf.weighted_average(predL, weights=None, dim=0)  # -> average over batch

            # Weigh losses
            loss_cur = predL

            # Calculate training-precision
            precision = None if y is None else (
                y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If XdG is combined with replay, backward-pass needs to be done before new task-mask is applied
            if (self.mask_dict is not None) and (x_ is not None):
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()
        else:
            precision = predL = None
            # -> it's possible there is only "replay" [i.e., for offline with incremental task learning scenario]

        ##--(2)-- REPLAYED DATA --##

        if x_ is not None:
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            TaskIL = isinstance(y_, list) if (y_ is not None) else isinstance(scores_, list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_replay = [torch.tensor(0., device=self._device())] * n_replays
            predL_r = [torch.tensor(0., device=self._device())] * n_replays
            distilL_r = [torch.tensor(0., device=self._device())] * n_replays

            # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            if (not isinstance(x_, list)) and (self.mask_dict is None):
                y_hat_all = self.classify(x_, not_hidden=replay_not_hidden)

            # Loop to perform each replay
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if isinstance(x_, list) or (self.mask_dict is not None):
                    x_temp_ = x_[replay_id] if isinstance(x_, list) else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self.classify(x_temp_, not_hidden=replay_not_hidden)

                # -if needed (e.g., "class" or "task" scenario), remove predictions for classes not in replayed task
                y_hat = y_hat_all if (
                    active_classes is None
                ) else y_hat_all[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    predL_r[replay_id] = F.cross_entropy(y_hat,
                                                         y_[replay_id],
                                                         reduction='none')
                    #-> average over batch
                    predL_r[replay_id] = lf.weighted_average(predL_r[replay_id], dim=0)
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                    #--> zeros will be added to [scores] to make it this size!
                    n_classes_to_consider = y_hat.size(1)
                    distilL_r[replay_id] = lf.loss_fn_kd(
                        scores=y_hat[:, :n_classes_to_consider],
                        target_scores=scores_[replay_id],
                        T=self.KD_temp,
                    )  # --> summing over classes & averaging over batch within this function
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If task-specific mask, backward pass needs to be performed before next task-mask is applied
                if self.mask_dict is not None:
                    weighted_replay_loss_this_task = (
                        1 - rnt) * loss_replay[replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

        # Calculate total loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays
        if x is None:
            loss_total = loss_replay
        elif x_ is None:
            loss_total = loss_cur
        else:
            loss_total = rnt * loss_cur + (1 - rnt) * loss_replay

        ##--(3)-- ALLOCATION LOSSES --##

        # The regularization terms are collected separately and added out of place: [loss_total] may be the
        # same tensor object as [loss_cur] or [loss_replay], which an in-place `+=` would also modify (and so
        # corrupt the reported component losses).
        reg_loss = None

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            reg_loss = self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            ewc_term = self.ewc_lambda * ewc_loss
            reg_loss = ewc_term if reg_loss is None else reg_loss + ewc_term

        if reg_loss is not None:
            loss_total = loss_total + reg_loss

        # Backpropagate errors (if not yet done)
        if (self.mask_dict is None) or (x_ is None):
            loss_total.backward()
        elif (reg_loss is not None) and reg_loss.requires_grad:
            # With XdG + replay the data-losses were already backpropagated above (one task-mask at a time);
            # only the parameter-regularization terms still need to contribute their gradients.
            reg_loss.backward()
        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total':
            loss_total.item(),
            'loss_current':
            loss_cur.item() if x is not None else 0,
            'loss_replay':
            loss_replay.item() if
            (loss_replay is not None) and (x is not None) else 0,
            'pred':
            predL.item() if predL is not None else 0,
            'pred_r':
            sum(predL_r).item() / n_replays if
            (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r':
            sum(distilL_r).item() / n_replays if
            (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc':
            ewc_loss.item(),
            'si_loss':
            surrogate_loss.item(),
            'precision':
            precision if precision is not None else 0.,
        }