Example #1
    def __init__(self,
                 image_size,
                 image_channels,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=False,
                 fc_nl="relu",
                 gated=False,
                 bias=True,
                 excitability=False,
                 excit_buffer=False,
                 binaryCE=False,
                 binaryCE_distill=False,
                 AGEM=False,
                 dataset="mnist"):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE  #-> use binary (instead of multiclass) prediction error
        self.binaryCE_distill = binaryCE_distill  #-> for classes from previous tasks, use the probabilities
        #   predicted by the previous model as binary targets (only in Class-IL with binaryCE)
        self.AGEM = AGEM  #-> use the gradient of replayed data as an inequality constraint on (instead of adding it to)
        #   the gradient of the current data (as in A-GEM, see Chaudhry et al., 2019; ICLR)

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        if dataset == "ckplus" or dataset == "affectnet":
            self.input_size = image_size[0] * image_size[1] * image_channels
        else:
            self.input_size = image_channels * image_size**2

        self.fcE = MLP(input_size=self.input_size,
                       output_size=fc_units,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer,
                       gated=gated,
                       latent_space=128)
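The [binaryCE] flag above switches the classifier from multiclass to per-class binary prediction error. A minimal standalone sketch of the two losses in plain PyTorch (illustrative tensors only, not tied to the classes in these examples):

import torch
import torch.nn.functional as F

# Standalone sketch: multiclass vs. binary cross-entropy on the same logits,
# as toggled by the [binaryCE] flag in the examples above.
logits = torch.randn(4, 10)             # [batch] x [classes]
targets = torch.randint(0, 10, (4,))    # integer class labels

# multiclass: softmax cross-entropy, averaged over the batch
multiclass_loss = F.cross_entropy(logits, targets)

# binary: one-hot targets, per-class sigmoid BCE; sum over classes, average over batch
one_hot = F.one_hot(targets, num_classes=10).float()
binary_loss = F.binary_cross_entropy_with_logits(
    logits, one_hot, reduction='none').sum(dim=1).mean()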
Example #2
    def __init__(self, image_size, image_channels, classes,
                 fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=True, fc_nl="relu", gated=False, z_dim=20, 
                 dataset="mnist"):
        '''Class for variational auto-encoder (VAE) models.'''

        # Set configurations
        super().__init__()
        self.label = "VAE"
        self.image_size = image_size
        self.image_channels = image_channels
        self.classes = classes
        self.fc_layers = fc_layers
        self.z_dim = z_dim
        self.fc_units = fc_units

        # Weights of different components of the loss function
        self.lamda_rcl = 1.
        self.lamda_vl = 1.
        self.lamda_pl = 0.  #--> when used as "classifier with feedback-connections", this should be set to 1.

        self.average = True  #--> if True, [reconL] and [variatL] are both divided by the number of input pixels

        # Check whether there is at least 1 fc-layer
        if fc_layers<1:
            raise ValueError("VAE cannot have 0 fully-connected layers!")


        ######------SPECIFY MODEL------######

        ##>----Encoder (= q[z|x])----<##
        # -flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        if dataset == "ckplus" or dataset == "affectnet": 
            self.input_size = image_size[0] * image_size[1] * image_channels
        else: 
            self.input_size = image_channels*image_size**2

        # -fully connected hidden layers
        self.fcE = MLP(input_size=self.input_size, output_size=fc_units, layers=fc_layers-1,
                       hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, gated=gated)
        mlp_output_size = fc_units if fc_layers > 1 else self.input_size
        # -to z
        self.toZ = fc_layer_split(mlp_output_size, z_dim, nl_mean='none', nl_logvar='none')

        ##>----Classifier----<##
        self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none')

        ##>----Decoder (= p[x|z])----<##
        # -from z
        out_nl = fc_layers > 1
        self.fromZ = fc_layer(z_dim, mlp_output_size, batch_norm=(out_nl and fc_bn), nl=fc_nl if out_nl else "none")
        # -fully connected hidden layers
        self.fcD = MLP(input_size=fc_units, output_size=self.input_size, layers=fc_layers-1,
                       hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, gated=gated, output='BCE')
        # -to image-shape
        self.to_image = utils.Reshape(image_channels=image_channels)
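The [toZ] layer above maps the encoder's hidden vector to the two Gaussian parameters in one pass. A hypothetical plain-PyTorch stand-in for fc_layer_split (the real helper has more options; this only mirrors the mean/log-variance split with 'none' nonlinearities used here):

import torch
from torch import nn

# Hypothetical stand-in for fc_layer_split: one hidden vector in,
# (z_mean, z_logvar) out, each through its own linear head.
class SplitHead(nn.Module):
    def __init__(self, in_size: int, z_dim: int):
        super().__init__()
        self.mean = nn.Linear(in_size, z_dim)
        self.logvar = nn.Linear(in_size, z_dim)

    def forward(self, h):
        return self.mean(h), self.logvar(h)

z_mean, z_logvar = SplitHead(1000, 20)(torch.randn(4, 1000))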
Example #3
    def __init__(self,
                 image_size,
                 image_channels,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=True,
                 fc_nl="relu",
                 gated=False,
                 bias=True,
                 excitability=False,
                 excit_buffer=False,
                 binaryCE=False,
                 binaryCE_distill=False):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE
        self.binaryCE_distill = binaryCE_distill

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        self.fcE = MLP(input_size=image_channels * image_size**2,
                       output_size=fc_units,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer,
                       gated=gated)
        mlp_output_size = fc_units if fc_layers > 1 else image_channels * image_size**2

        # classifier
        self.classifier = fc_layer(mlp_output_size,
                                   classes,
                                   excit_buffer=True,
                                   nl='none',
                                   drop=fc_drop)
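Example #3 wires flatten -> MLP trunk -> linear classifier head. A hypothetical plain-PyTorch equivalent of that wiring (the real MLP/fc_layer helpers add options such as excitability, gating, and batch-norm that are omitted here):

import torch
from torch import nn

# Hypothetical plain-PyTorch equivalent of the wiring in Example #3:
# flatten -> (fc_layers - 1) hidden layers -> linear classifier head.
def build_classifier(image_size, image_channels, classes,
                     fc_layers=3, fc_units=1000):
    in_size = image_channels * image_size ** 2
    layers = [nn.Flatten()]
    for _ in range(fc_layers - 1):
        layers += [nn.Linear(in_size, fc_units), nn.ReLU()]
        in_size = fc_units
    layers.append(nn.Linear(in_size, classes))
    return nn.Sequential(*layers)

logits = build_classifier(28, 1, 10)(torch.randn(4, 1, 28, 28))  # -> [4, 10]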
Example #4
    def __init__(self, image_size, image_channels, classes,
                 fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=False, fc_nl="relu", gated=False,
                 bias=True, excitability=False, excit_buffer=False, binaryCE=False, binaryCE_distill=False, AGEM=False,
                 experiment='splitMNIST'):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE  # -> use binary (instead of multiclass) prediction error
        self.binaryCE_distill = binaryCE_distill  # -> for classes from previous tasks, use the probabilities
        #   predicted by the previous model as binary targets (only in Class-IL with binaryCE)
        self.AGEM = AGEM  # -> use the gradient of replayed data as an inequality constraint on (instead of adding it to)
        #   the gradient of the current data (as in A-GEM, see Chaudhry et al., 2019; ICLR)

        # Online mem distillation
        self.is_offline_training = False
        self.is_ready_distill = False
        self.alpha_t = 0.5
        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("The classifier needs to have at least 1 fully-connected layer.")

        ######------SPECIFY MODEL------######
        self.experiment = experiment
        if self.experiment in ['CIFAR10', 'CIFAR100', 'CUB2011']:
            self.fcE = rn.resnet32(classes, pretrained=False)
            self.fcE.linear = nn.Identity()

            self.classifier = fc_layer(64, classes, excit_buffer=True, nl='none', drop=fc_drop)
        elif self.experiment == 'ImageNet':
            ResNet.name = 'ResNet-18'
            self.fcE = resnet18(pretrained=True)
            self.fcE.fc = nn.Identity()

            self.classifier = fc_layer(512, classes, excit_buffer=True, nl='none', drop=fc_drop)
        else:
            # flatten image to 2D-tensor
            self.flatten = utils.Flatten()

            # fully connected hidden layers
            self.fcE = MLP(input_size=image_channels * image_size ** 2, output_size=fc_units, layers=fc_layers - 1,
                           hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias,
                           excitability=excitability, excit_buffer=excit_buffer, gated=gated)
            mlp_output_size = fc_units if fc_layers > 1 else image_channels * image_size ** 2

            # classifier
            self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none', drop=fc_drop)
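Example #4's CIFAR/ImageNet branches turn a ResNet into a feature extractor by replacing its final linear layer with nn.Identity(). A standalone sketch of that trick with torchvision (random weights here; the ImageNet branch above loads pretrained ones):

import torch
from torch import nn
from torchvision.models import resnet18

# Replace the final fc-layer with an identity so the backbone returns
# 512-dimensional features; a separate classifier head consumes them.
backbone = resnet18()
backbone.fc = nn.Identity()
features = backbone(torch.randn(1, 3, 224, 224))   # -> shape [1, 512]
head = nn.Linear(512, 100)                          # e.g., 100 classes
logits = head(features)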
Example #5
    def __init__(self,
                 image_size,
                 image_channels,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=True,
                 fc_nl="relu",
                 bias=True,
                 excitability=False,
                 excit_buffer=False):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        self.fcE = MLP(input_size=image_channels * image_size**2,
                       output_size=fc_units,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       final_nl=True,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer)
        mlp_output_size = fc_units if fc_layers > 1 else image_channels * image_size**2

        # classifier
        self.classifier = nn.Sequential(
            nn.Dropout(fc_drop),
            eM.LinearExcitability(mlp_output_size, classes, excit_buffer=True))
Example #6
class AutoEncoderLatent(Replayer):
    """Class for variational auto-encoder (VAE) models."""
    def __init__(self,
                 latent_size,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=True,
                 fc_nl="relu",
                 gated=False,
                 z_dim=20):
        '''Class for variational auto-encoder (VAE) models.'''

        # Set configurations
        super().__init__()
        self.latent_size = latent_size
        self.label = "VAE"
        self.classes = classes
        self.fc_layers = fc_layers
        self.z_dim = z_dim
        self.fc_units = fc_units

        # Weights of different components of the loss function
        self.lamda_rcl = 1.
        self.lamda_vl = 1.
        self.lamda_pl = 0.  #--> when used as "classifier with feedback-connections", this should be set to 1.

        self.average = True  #--> if True, [reconL] and [variatL] are both divided by the number of input units

        # Check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("VAE cannot have 0 fully-connected layers!")

        ######------SPECIFY MODEL------######

        ##>----Encoder (= q[z|x])----<##
        # -fully connected hidden layers
        self.fcE = MLP(input_size=latent_size,
                       output_size=fc_units,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       gated=gated)
        mlp_output_size = fc_units if fc_layers > 1 else latent_size
        # -to z
        self.toZ = fc_layer_split(mlp_output_size,
                                  z_dim,
                                  nl_mean='none',
                                  nl_logvar='none')

        ##>----Classifier----<##
        self.classifier = fc_layer(mlp_output_size,
                                   classes,
                                   excit_buffer=True,
                                   nl='none')

        ##>----Decoder (= p[x|z])----<##
        # -from z
        out_nl = fc_layers > 1
        self.fromZ = fc_layer(z_dim,
                              mlp_output_size,
                              batch_norm=(out_nl and fc_bn),
                              nl=fc_nl if out_nl else "none")
        # -fully connected hidden layers
        self.fcD = MLP(input_size=fc_units,
                       output_size=latent_size,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       gated=gated,
                       output='BCE')

    @property
    def name(self):
        fc_label = "{}--".format(self.fcE.name) if self.fc_layers > 1 else ""
        hid_label = "{}{}-".format(
            "i", self.latent_size) if self.fc_layers == 1 else ""
        z_label = "z{}".format(self.z_dim)
        return "{}({}{}{}-c{})".format(self.label, fc_label, hid_label,
                                       z_label, self.classes)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        list = []
        list += self.fcE.list_init_layers()
        list += self.toZ.list_init_layers()
        list += self.classifier.list_init_layers()
        list += self.fromZ.list_init_layers()
        list += self.fcD.list_init_layers()
        return list

    ##------ FORWARD FUNCTIONS --------##

    def encode(self, x):
        '''Pass input through feed-forward connections, to get [hE], [z_mean] and [z_logvar].'''
        # extract final hidden features (forward-pass)
        hE = self.fcE(x)
        # get parameters for reparametrization
        (z_mean, z_logvar) = self.toZ(hE)
        return z_mean, z_logvar, hE

    def classify(self, x):
        '''For input [x], return all predicted "scores"/"logits".'''
        hE = self.fcE(x)
        y_hat = self.classifier(hE)
        return y_hat

    def reparameterize(self, mu, logvar):
        '''Perform "reparametrization trick" to make these stochastic variables differentiable.'''
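        # z = mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, I)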
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        hD = self.fromZ(z)
        features = self.fcD(hD)
        return features

    def forward(self, x, full=False, reparameterize=True):
        '''Forward function to propagate [x] through the encoder, reparametrization and decoder.

        Input:  - [x]   <4D-tensor> of shape [batch_size]x[channels]x[image_size]x[image_size]

        If [full] is True, output should be a <tuple> consisting of:
        - [x_recon]     <4D-tensor> reconstructed image (features) in same shape as [x]
        - [y_hat]       <2D-tensor> with predicted logits for each class
        - [mu]          <2D-tensor> with either [z] or the estimated mean of [z]
        - [logvar]      None or <2D-tensor> estimated log(SD^2) of [z]
        - [z]           <2D-tensor> reparameterized [z] used for reconstruction
        If [full] is False, output is simply the predicted logits (i.e., [y_hat]).'''
        if full:
            # encode (forward), reparameterize and decode (backward)
            mu, logvar, hE = self.encode(x)
            z = self.reparameterize(mu, logvar) if reparameterize else mu
            x_recon = self.decode(z)
            # classify
            y_hat = self.classifier(hE)
            # return
            return (x_recon, y_hat, mu, logvar, z)
        else:
            return self.classify(x)  # -> if [full]=False, only forward pass for prediction

    ##------ SAMPLE FUNCTIONS --------##

    def sample(self, size):
        '''Generate [size] samples from the model. Output is tensor (not "requiring grad"), on same device as <self>.'''

        # set model to eval()-mode
        mode = self.training
        self.eval()

        # sample z
        z = torch.randn(size, self.z_dim).to(self._device())

        # decode z into image X
        with torch.no_grad():
            X = self.decode(z)

        # set model back to its initial mode
        self.train(mode=mode)

        # return samples as [batch_size]x[channels]x[image_size]x[image_size] tensor
        return X

    ##------ LOSS FUNCTIONS --------##

    def calculate_recon_loss(self, x, x_recon, average=False):
        '''Calculate reconstruction loss for each element in the batch.

        INPUT:  - [x]           <tensor> with original input (1st dimension (i.e., dim=0) is "batch-dimension")
                - [x_recon]     <tensor> with reconstructed input in same shape as [x]
                - [average]     <bool>, if True, loss is average over all pixels; otherwise it is summed

        OUTPUT: - [reconL]      <1D-tensor> of length [batch_size]'''

        batch_size = x.size(0)
        reconL = F.binary_cross_entropy(input=x_recon.view(batch_size, -1),
                                        target=x.view(batch_size, -1),
                                        reduction='none')
        reconL = torch.mean(reconL, dim=1) if average else torch.sum(reconL,
                                                                     dim=1)

        return reconL

    def calculate_variat_loss(self, mu, logvar):
        '''Calculate KL-divergence ("variational") loss for each element in the batch.

        INPUT:  - [mu]        <2D-tensor> mean for [z] as predicted by the encoder
                - [logvar]    <2D-tensor> log-variance for [z] as predicted by the encoder

        OUTPUT: - [variatL]   <1D-tensor> of length [batch_size]'''

        # --> calculate analytically
        # ---- see Appendix B from: Kingma & Welling (2014) Auto-Encoding Variational Bayes, ICLR ----#
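        # KL( N(mu, sigma^2) || N(0, 1) ) = -0.5 * sum( 1 + log(sigma^2) - mu^2 - sigma^2 )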
        variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),
                                   dim=1)

        return variatL

    def loss_function(self,
                      recon_x,
                      x,
                      y_hat=None,
                      y_target=None,
                      scores=None,
                      mu=None,
                      logvar=None):
        '''Calculate and return various losses that could be used for training and/or evaluating the model.

        INPUT:  - [recon_x]         <4D-tensor> reconstructed image in same shape as [x]
                - [x]               <4D-tensor> original image
                - [y_hat]           <2D-tensor> with predicted "logits" for each class
                - [y_target]        <1D-tensor> with target-classes (as integers)
                - [scores]          <2D-tensor> with target "logits" for each class
                - [mu]              <2D-tensor> with either [z] or the estimated mean of [z]
                - [logvar]          None or <2D-tensor> with estimated log(SD^2) of [z]

        SETTING:- [self.average]    <bool>, if True, both [reconL] and [variatL] are divided by number of input elements

        OUTPUT: - [reconL]       reconstruction loss indicating how well [x] and [x_recon] match
                - [variatL]      variational (KL-divergence) loss "indicating how normally distributed [z] is"
                - [predL]        prediction loss indicating how well targets [y] are predicted
                - [distilL]      knowledge distillation (KD) loss indicating how well the predicted "logits" ([y_hat])
                                     match the target "logits" ([scores])'''

        ###-----Reconstruction loss-----###
        reconL = self.calculate_recon_loss(
            x=x, x_recon=recon_x,
            average=self.average)  #-> possibly average over pixels
        reconL = torch.mean(reconL)  #-> average over batch

        ###-----Variational loss-----###
        if logvar is not None:
            variatL = self.calculate_variat_loss(mu=mu, logvar=logvar)
            variatL = torch.mean(variatL)  #-> average over batch
            if self.average:
                variatL /= self.latent_size  #-> divide by the number of input units (here: latent features), if [self.average]
        else:
            variatL = torch.tensor(0., device=self._device())

        ###-----Prediction loss-----###
        if y_target is not None:
            predL = F.cross_entropy(y_hat, y_target,
                                    reduction='mean')  #-> average over batch
        else:
            predL = torch.tensor(0., device=self._device())

        ###-----Distillation loss-----###
        if scores is not None:
            n_classes_to_consider = y_hat.size(1)  #--> zeroes will be added to [scores] to make its size match [y_hat]
            distilL = utils.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider],
                                       target_scores=scores,
                                       T=self.KD_temp)
        else:
            distilL = torch.tensor(0., device=self._device())

        # Return a tuple of the calculated losses
        return reconL, variatL, predL, distilL

    ##------ TRAINING FUNCTIONS --------##

    def train_a_batch(self,
                      x,
                      y,
                      x_=None,
                      y_=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      task=1,
                      **kwargs):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes'''

        # Set model to training-mode
        self.train()

        ##--(1)-- CURRENT DATA --##
        precision = 0.
        if x is not None:
            # Run the model
            recon_batch, y_hat, mu, logvar, z = self(x, full=True)
            # If needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in current task
            if active_classes is not None:
                y_hat = y_hat[:, active_classes[-1]] if type(
                    active_classes[0]) == list else y_hat[:, active_classes]
            # Calculate all losses
            reconL, variatL, predL, _ = self.loss_function(recon_x=recon_batch,
                                                           x=x,
                                                           y_hat=y_hat,
                                                           y_target=y,
                                                           mu=mu,
                                                           logvar=logvar)
            # Weigh losses as requested
            loss_cur = self.lamda_rcl * reconL + self.lamda_vl * variatL + self.lamda_pl * predL

            # Calculate training-precision
            if y is not None:
                _, predicted = y_hat.max(1)
                precision = (y == predicted).sum().item() / x.size(0)

        ##--(2)-- REPLAYED DATA --##
        if x_ is not None:
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = (type(y_) == list) if (y_ is not None) else (type(scores_)
                                                                  == list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes
                                                      is not None) else None
                n_replays = len(x_) if (type(x_) == list) else 1
            else:
                n_replays = len(y_) if (y_ is not None) else (len(scores_) if (
                    scores_ is not None) else 1)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            reconL_r = [None] * n_replays
            variatL_r = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model (if [x_] is not a list with separate replay per task)
            if (not type(x_) == list):
                x_temp_ = x_
                recon_batch, y_hat_all, mu, logvar, z = self(x_temp_,
                                                             full=True)

            # Loop to perform each replay
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if (type(x_) == list):
                    x_temp_ = x_[replay_id]
                    recon_batch, y_hat_all, mu, logvar, z = self(x_temp_,
                                                                 full=True)

                # If needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                if active_classes is not None:
                    y_hat = y_hat_all[:, active_classes[replay_id]]
                else:
                    y_hat = y_hat_all

                # Calculate all losses
                reconL_r[replay_id], variatL_r[replay_id], predL_r[
                    replay_id], distilL_r[replay_id] = self.loss_function(
                        recon_x=recon_batch,
                        x=x_temp_,
                        y_hat=y_hat,
                        y_target=y_[replay_id] if (y_ is not None) else None,
                        scores=scores_[replay_id] if
                        (scores_ is not None) else None,
                        mu=mu,
                        logvar=logvar,
                    )

                # Weigh losses as requested
                loss_replay[replay_id] = self.lamda_rcl * reconL_r[
                    replay_id] + self.lamda_vl * variatL_r[replay_id]
                if self.replay_targets == "hard":
                    loss_replay[
                        replay_id] += self.lamda_pl * predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[
                        replay_id] += self.lamda_pl * distilL_r[replay_id]

        # Calculate total loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays
        loss_total = loss_replay if (
            x is None) else (loss_cur if x_ is None else rnt * loss_cur +
                             (1 - rnt) * loss_replay)

        # Reset optimizer
        self.optimizer.zero_grad()
        # Backpropagate errors
        loss_total.backward()
        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'precision': precision,
            'recon': reconL.item() if x is not None else 0,
            'variat': variatL.item() if x is not None else 0,
            'pred': predL.item() if x is not None else 0,
            'recon_r': sum(reconL_r).item() / n_replays if x_ is not None else 0,
            'variat_r': sum(variatL_r).item() / n_replays if x_ is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
        }
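A standalone sketch of the three VAE loss pieces computed by the class above, on illustrative tensors (z_dim=20, 784 input units):

import torch
import torch.nn.functional as F

mu = torch.randn(4, 20)        # encoder mean,         [batch] x [z_dim]
logvar = torch.randn(4, 20)    # encoder log-variance, [batch] x [z_dim]

# reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

# analytic KL-divergence to the standard-normal prior (variational loss)
variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# per-element binary cross-entropy reconstruction loss, summed over units
x = torch.rand(4, 784)         # original input in [0, 1]
x_recon = torch.rand(4, 784)   # decoder output (sigmoid-activated)
reconL = F.binary_cross_entropy(x_recon, x, reduction='none').sum(dim=1)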
Example #7
class RootClassifier(ContinualLearner, Replayer, ExemplarHandler):
    '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.'''

    # TODO: Do I need the `classes` argument?
    def __init__(self,
                 image_size,
                 image_channels,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=False,
                 fc_nl="relu",
                 gated=False,
                 bias=True,
                 excitability=False,
                 excit_buffer=False,
                 binaryCE=False,
                 binaryCE_distill=False,
                 AGEM=False,
                 dataset="mnist"):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE  #-> use binary (instead of multiclass) prediction error
        self.binaryCE_distill = binaryCE_distill  #-> for classes from previous tasks, use the probabilities
        #   predicted by the previous model as binary targets (only in Class-IL with binaryCE)
        self.AGEM = AGEM  #-> use the gradient of replayed data as an inequality constraint on (instead of adding it to)
        #   the gradient of the current data (as in A-GEM, see Chaudhry et al., 2019; ICLR)

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        if dataset == "ckplus" or dataset == "affectnet":
            self.input_size = image_size[0] * image_size[1] * image_channels
        else:
            self.input_size = image_channels * image_size**2

        self.fcE = MLP(input_size=self.input_size,
                       output_size=fc_units,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer,
                       gated=gated,
                       latent_space=128)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        list = []
        list += self.fcE.list_init_layers()
        # list += self.classifier.list_init_layers()
        return list

    @property
    def name(self):
        return "{}_c{}".format(self.fcE.name, self.classes)

    def forward(self, x):
        final_features = self.fcE(self.flatten(x))
        return final_features

    def feature_extractor(self, images):
        return self.fcE(self.flatten(images))

    def train_a_batch(self,
                      x,
                      y,
                      scores=None,
                      x_=None,
                      y_=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      task=1):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [scores]          None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x]
                            NOTE: only to be used for "BCE with distill" (only when scenario=="class")
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            <int>, for setting task-specific mask'''

        # Set model to training-mode
        self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        # Should gradient be computed separately for each task? (needed when a task-mask is combined with replay)
        gradient_per_task = True if ((self.mask_dict is not None) and
                                     (x_ is not None)) else False

        ##--(1)-- REPLAYED DATA --##

        if x_ is not None:
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = (type(y_) == list) if (y_ is not None) else (type(scores_)
                                                                  == list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes
                                                      is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            if (not type(x_) == list) and (self.mask_dict is None):
                y_hat_all = self(x_)

            # Loop to evaluate predictions on replay according to each previous task
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if (type(x_) == list) or (self.mask_dict is not None):
                    x_temp_ = x_[replay_id] if type(x_) == list else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self(x_temp_)

                # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                y_hat = y_hat_all if (
                    active_classes is None
                ) else y_hat_all[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    if self.binaryCE:
                        binary_targets_ = utils.to_one_hot(
                            y_[replay_id].cpu(),
                            y_hat.size(1)).to(y_[replay_id].device)
                        predL_r[replay_id] = F.binary_cross_entropy_with_logits(
                            input=y_hat,
                            target=binary_targets_,
                            reduction='none').sum(dim=1).mean(
                            )  #--> sum over classes, then average over batch
                    else:
                        predL_r[replay_id] = F.cross_entropy(y_hat,
                                                             y_[replay_id],
                                                             reduction='mean')
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                    n_classes_to_consider = y_hat.size(1)  #--> zeros will be added to [scores] to make it this size!
                    kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd
                    distilL_r[replay_id] = kd_fn(
                        scores=y_hat[:, :n_classes_to_consider],
                        target_scores=scores_[replay_id],
                        T=self.KD_temp)
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If needed, perform backward pass before next task-mask (gradients of all tasks will be accumulated)
                if gradient_per_task:
                    weight = 1 if self.AGEM else (1 - rnt)
                    weighted_replay_loss_this_task = weight * loss_replay[
                        replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

        # Calculate total replay loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays

        # If using A-GEM, calculate and store averaged gradient of replayed data
        if self.AGEM and x_ is not None:
            # Perform backward pass to calculate gradient of replayed batch (if not yet done)
            if not gradient_per_task:
                loss_replay.backward()
            # Reorganize the gradient of the replayed batch as a single vector
            grad_rep = []
            for p in self.parameters():
                if p.requires_grad:
                    grad_rep.append(p.grad.view(-1))
            grad_rep = torch.cat(grad_rep)
            # Reset gradients (with A-GEM, gradients of replayed batch should only be used as inequality constraint)
            self.optimizer.zero_grad()

        ##--(2)-- CURRENT DATA --##

        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if type(
                    active_classes[0]) == list else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            if self.binaryCE:
                # -binary prediction loss
                binary_targets = utils.to_one_hot(y.cpu(),
                                                  y_hat.size(1)).to(y.device)
                if self.binaryCE_distill and (scores is not None):
                    classes_per_task = int(y_hat.size(1) / task)
                    binary_targets = binary_targets[:, -(classes_per_task):]
                    binary_targets = torch.cat(
                        [torch.sigmoid(scores / self.KD_temp), binary_targets],
                        dim=1)
                predL = None if y is None else F.binary_cross_entropy_with_logits(
                    input=y_hat, target=binary_targets, reduction='none').sum(
                        dim=1).mean(
                        )  #--> sum over classes, then average over batch
            else:
                # -multiclass prediction loss
                predL = None if y is None else F.cross_entropy(
                    input=y_hat, target=y, reduction='mean')

            # Weigh losses
            loss_cur = predL

            # Calculate training-precision
            precision = None if y is None else (
                y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If backward passes are performed per task (e.g., XdG combined with replay), perform backward pass
            if gradient_per_task:
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()
        else:
            precision = predL = None
            # -> it's possible there is only "replay" [e.g., for offline with task-incremental learning]

        # Combine loss from current and replayed batch
        if x_ is None or self.AGEM:
            loss_total = loss_cur
        else:
            loss_total = loss_replay if (
                x is None) else rnt * loss_cur + (1 - rnt) * loss_replay

        ##--(3)-- ALLOCATION LOSSES --##

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            loss_total += self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            loss_total += self.ewc_lambda * ewc_loss

        # Backpropagate errors (if not yet done)
        if not gradient_per_task:
            loss_total.backward()

        # If using A-GEM, potentially change gradient:
        if self.AGEM and x_ is not None:
            # -reorganize gradient (of current batch) as single vector
            grad_cur = []
            for p in self.parameters():
                if p.requires_grad:
                    grad_cur.append(p.grad.view(-1))
            grad_cur = torch.cat(grad_cur)
            # -check inequality constraint
            angle = (grad_cur * grad_rep).sum()
            if angle < 0:
                # -if violated, project the gradient of the current batch onto the gradient of the replayed batch ...
                length_rep = (grad_rep * grad_rep).sum()
                grad_proj = grad_cur - (angle / length_rep) * grad_rep
                # -...and replace all the gradients within the model with this projected gradient
                index = 0
                for p in self.parameters():
                    if p.requires_grad:
                        n_param = p.numel()  # number of parameters in [p]
                        p.grad.copy_(grad_proj[index:index +
                                               n_param].view_as(p))
                        index += n_param

        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item() if x is not None else 0,
            'loss_replay': loss_replay.item() if (loss_replay is not None) and (x is not None) else 0,
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(),
            'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }
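The A-GEM branch above flattens all gradients into a single vector, checks the inequality constraint via a dot product, and projects when it is violated. A standalone sketch of just that projection step (plain torch, following Chaudhry et al., 2019):

import torch

def agem_project(grad_cur: torch.Tensor, grad_rep: torch.Tensor) -> torch.Tensor:
    '''Project [grad_cur] so it no longer conflicts with [grad_rep].'''
    angle = (grad_cur * grad_rep).sum()
    if angle < 0:  # constraint violated: the two gradients point in opposing directions
        length_rep = (grad_rep * grad_rep).sum()
        return grad_cur - (angle / length_rep) * grad_rep
    return grad_cur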
Example #8
class Classifier(ContinualLearner, Replayer, ExemplarHandler):
    '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.'''
    def __init__(self,
                 image_size,
                 image_channels,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=True,
                 fc_nl="relu",
                 gated=False,
                 bias=True,
                 excitability=False,
                 excit_buffer=False,
                 binaryCE=False,
                 binaryCE_distill=False):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE
        self.binaryCE_distill = binaryCE_distill

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        self.fcE = MLP(input_size=image_channels * image_size**2,
                       output_size=fc_units,
                       layers=fc_layers - 1,
                       hid_size=fc_units,
                       drop=fc_drop,
                       batch_norm=fc_bn,
                       nl=fc_nl,
                       bias=bias,
                       excitability=excitability,
                       excit_buffer=excit_buffer,
                       gated=gated)
        mlp_output_size = fc_units if fc_layers > 1 else image_channels * image_size**2
        print('*************num of classes in encoder: ' + str(classes))
        self.vgg = vgg16(classes)

        # classifier
        #self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none', drop=fc_drop)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        list = []
        list += self.fcE.list_init_layers()
        # list += self.classifier.list_init_layers()  #-> [self.classifier] is commented out above
        return list

    @property
    def name(self):
        #return "{}_c{}".format(self.fcE.name, self.classes)
        return "vgg"

    def forward(self, x):
        return self.vgg(x)

    def feature_extractor(self, images):
        return self.fcE(self.flatten(images))

    def train_a_batch(self,
                      x,
                      y,
                      scores=None,
                      x_=None,
                      y_=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      task=1):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [scores]          None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x]

        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            <int>, for setting task-specific mask'''

        # Set model to training-mode
        self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        ##--(1)-- CURRENT DATA --##

        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if type(
                    active_classes[0]) == list else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            if self.binaryCE:
                # -binary prediction loss
                binary_targets = utils.to_one_hot(y.cpu(),
                                                  y_hat.size(1)).to(y.device)
                if self.binaryCE_distill and (scores is not None):
                    classes_per_task = int(y_hat.size(1) / task)
                    binary_targets = binary_targets[:, -(classes_per_task):]
                    binary_targets = torch.cat(
                        [torch.sigmoid(scores / self.KD_temp), binary_targets],
                        dim=1)
                predL = None if y is None else F.binary_cross_entropy_with_logits(
                    input=y_hat, target=binary_targets, reduction='none').sum(
                        dim=1).mean(
                        )  #--> sum over classes, then average over batch
            else:
                # -multiclass prediction loss
                predL = None if y is None else F.cross_entropy(
                    input=y_hat, target=y, reduction='mean')

            # Weigh losses
            loss_cur = predL

            # Calculate training-precision
            precision = None if y is None else (
                y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If XdG is combined with replay, backward-pass needs to be done before new task-mask is applied
            if (self.mask_dict is not None) and (x_ is not None):
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()
        else:
            precision = predL = None
            # -> it's possible there is only "replay" [i.e., for offline with incremental task learning]

        ##--(2)-- REPLAYED DATA --##

        if x_ is not None:
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = (type(y_) == list) if (y_ is not None) else (type(scores_)
                                                                  == list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes
                                                      is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            if (not type(x_) == list) and (self.mask_dict is None):
                y_hat_all = self(x_)

            # Loop to evaluate predictions on replay according to each previous task
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if (type(x_) == list) or (self.mask_dict is not None):
                    x_temp_ = x_[replay_id] if type(x_) == list else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self(x_temp_)

                # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                y_hat = y_hat_all if (
                    active_classes is None
                ) else y_hat_all[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    if self.binaryCE:
                        binary_targets_ = utils.to_one_hot(
                            y_[replay_id].cpu(),
                            y_hat.size(1)).to(y_[replay_id].device)
                        predL_r[replay_id] = F.binary_cross_entropy_with_logits(
                            input=y_hat,
                            target=binary_targets_,
                            reduction='none').sum(dim=1).mean(
                            )  #--> sum over classes, then average over batch
                    else:
                        predL_r[replay_id] = F.cross_entropy(
                            y_hat, y_[replay_id], reduction='mean')
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                    n_classes_to_consider = y_hat.size(1)  #--> zeros will be added to [scores] to make it this size!
                    kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd
                    distilL_r[replay_id] = kd_fn(
                        scores=y_hat[:, :n_classes_to_consider],
                        target_scores=scores_[replay_id],
                        T=self.KD_temp)
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If task-specific mask, backward pass needs to be performed before next task-mask is applied
                if self.mask_dict is not None:
                    weighted_replay_loss_this_task = (
                        1 - rnt) * loss_replay[replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

        # Calculate total loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays
        loss_total = loss_replay if (
            x is None) else (loss_cur if x_ is None else rnt * loss_cur +
                             (1 - rnt) * loss_replay)

        ##--(3)-- ALLOCATION LOSSES --##

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            loss_total += self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            loss_total += self.ewc_lambda * ewc_loss

        # Backpropagate errors (if not yet done)
        if (self.mask_dict is None) or (x_ is None):
            loss_total.backward()
        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item() if x is not None else 0,
            'loss_replay': loss_replay.item() if (loss_replay is not None) and (x is not None) else 0,
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(),
            'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }
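A standalone sketch of the [binaryCE_distill] target construction used in the two training loops above: sigmoided previous-task scores are concatenated with the one-hot block for the current task's classes (illustrative sizes only):

import torch

batch, old_classes, new_classes = 4, 6, 2
scores = torch.randn(batch, old_classes)       # previous model's logits
labels = torch.randint(0, new_classes, (batch,))
one_hot_new = torch.zeros(batch, new_classes)  # one-hot block for current task
one_hot_new[torch.arange(batch), labels] = 1.

KD_temp = 2.0
binary_targets = torch.cat(
    [torch.sigmoid(scores / KD_temp), one_hot_new], dim=1)
# -> [batch] x [old_classes + new_classes]: soft targets for old classes,
#    hard targets for the new ones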
class Classifier(ContinualLearner, Replayer, ExemplarHandler):
    '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.'''
    def __init__(self,
                 num_features,
                 num_seq,
                 classes,
                 fc_layers=3,
                 fc_units=1000,
                 fc_drop=0,
                 fc_bn=True,
                 fc_nl="relu",
                 gated=False,
                 bias=True,
                 excitability=None,
                 excit_buffer=False,
                 binaryCE=False,
                 binaryCE_distill=False,
                 experiment='splitMNIST',
                 cls_type='mlp',
                 args=None):

        # configurations
        super().__init__()
        self.num_features = num_features
        self.num_seq = num_seq
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers
        self.hidden_dim = fc_units
        self.layer_dim = fc_layers - 1
        self.cuda = None if args is None else args.cuda
        self.device = None if args is None else args.device
        self.weights_per_class = None if args is None else torch.FloatTensor(
            args.weights_per_class).to(args.device)

        # store precision_dict into model so that we can fetch
        # self.precision_dict_list = [[] for i in range(len(args.num_classes_per_task_l))]
        # self.precision_dict = {}

        # settings for training
        self.binaryCE = binaryCE
        self.binaryCE_distill = binaryCE_distill

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError(
                "The classifier needs to have at least 1 fully-connected layer."
            )

        ######------SPECIFY MODEL------######
        self.cls_type = cls_type
        self.experiment = experiment
        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        if experiment == 'sensor':
            if self.cls_type == 'mlp':
                self.fcE = MLP(input_size=num_seq * num_features,
                               output_size=fc_units,
                               layers=fc_layers - 1,
                               hid_size=fc_units,
                               drop=fc_drop,
                               batch_norm=fc_bn,
                               nl=fc_nl,
                               bias=bias,
                               excitability=excitability,
                               excit_buffer=excit_buffer,
                               gated=gated)
            elif self.cls_type == 'lstm':
                self.lstm_input_dropout = nn.Dropout(args.input_drop)
                self.lstm = nn.LSTM(input_size=num_features,
                                    hidden_size=fc_units,
                                    num_layers=fc_layers - 1,
                                    dropout=0.0 if
                                    (fc_layers - 1) == 1 else fc_drop,
                                    batch_first=True)
                # self.name = "LSTM([{} X {} X {}])".format(num_features, num_seq, classes) if self.fc_layers > 0 else ""
        else:
            self.fcE = MLP(input_size=num_seq * num_features**2,
                           output_size=fc_units,
                           layers=fc_layers - 1,
                           hid_size=fc_units,
                           drop=fc_drop,
                           batch_norm=fc_bn,
                           nl=fc_nl,
                           bias=bias,
                           excitability=excitability,
                           excit_buffer=excit_buffer,
                           gated=gated)
        # classifier
        if self.cls_type == 'mlp':
            mlp_output_size = fc_units if fc_layers > 1 else num_seq * num_features**2
            self.classifier = fc_layer(mlp_output_size,
                                       classes,
                                       excit_buffer=True,
                                       nl='none',
                                       drop=fc_drop)
        elif self.cls_type == 'lstm':
            self.lstm_fc = nn.Linear(fc_units, classes)

        #################
        # +++++ GEM +++++
        #####
        if args.gem:
            print('Setting up GEM: allocating episodic memory and gradient buffers')
            self.margin = args.memory_strength
            self.ce = nn.CrossEntropyLoss()
            self.n_outputs = classes
            self.n_memories = args.n_memories
            self.gpu = args.cuda
            n_tasks = len(args.num_classes_per_task_l)
            # allocate episodic memory
            self.memory_data = torch.FloatTensor(n_tasks, self.n_memories,
                                                 self.num_seq,
                                                 self.num_features)
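            #   -> one ring buffer per task: [n_tasks] x [n_memories] x [num_seq] x [num_features]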
            self.memory_labs = torch.LongTensor(n_tasks, self.n_memories)
            if args.cuda:
                # self.memory_data = self.memory_data.cuda()
                self.memory_data = self.memory_data.to(self.device)
                # self.memory_labs = self.memory_labs.cuda()
                self.memory_labs = self.memory_labs.to(self.device)

            # allocate temporary synaptic memory
            self.grad_dims = []
            for param in self.parameters():
                self.grad_dims.append(param.data.numel())
            self.grads = torch.Tensor(sum(self.grad_dims), n_tasks)
            if args.cuda:
                # self.grads = self.grads.cuda()
                self.grads = self.grads.to(self.device)

            # allocate counters
            self.observed_tasks = []
            self.old_task = -1
            self.mem_cnt = 0

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        layer_list = []
        layer_list += self.fcE.list_init_layers()
        layer_list += self.classifier.list_init_layers()
        return layer_list

    @property
    def name(self):
        if self.cls_type == 'mlp':
            return "{}_c{}".format(self.fcE.name, self.classes)
        elif self.cls_type == 'lstm':
            return "LSTM([{} X {}]_c{})".format(self.num_seq,
                                                self.num_features,
                                                self.classes)

    def forward(self, x):
        if self.cls_type == 'mlp':
            final_features = self.fcE(self.flatten(x))
            return self.classifier(final_features)
        elif self.cls_type == 'lstm':
            x = self.lstm_input_dropout(x)
            h0, c0 = self.init_hidden(x)
            out, (hn, cn) = self.lstm(x, (h0, c0))
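            # classify from the hidden state of the last time step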
            return self.lstm_fc(out[:, -1, :])
            # lstm_out, hidden = self.lstm(x)
            # print(lstm_out.size())
            # print(x.size())
            # print(lstm_out[-1].size())
            # return self.lstm_fc(lstm_out[-1].view(x.size(0), -1))

    def feature_extractor(self, x):
        if self.cls_type == 'mlp':
            return self.fcE(self.flatten(x))
        elif self.cls_type == 'lstm':
            x = self.lstm_input_dropout(x)
            h0, c0 = self.init_hidden(x)
            out, (hn, cn) = self.lstm(x, (h0, c0))
            return out[:, -1, :]
            # lstm_out, hidden = self.lstm(images)
            # return lstm_out[-1]

    def forward_from_hidden_layer(self, x):
        if self.cls_type == 'mlp': return self.classifier(x)
        elif self.cls_type == 'lstm': return self.lstm_fc(x)

    def init_hidden(self, x):
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        return [t.to(self.device) for t in (h0, c0)] if self.cuda else (h0, c0)

    def train_a_batch(self,
                      x,
                      y,
                      x_=None,
                      y_=None,
                      x_ex=None,
                      y_ex=None,
                      scores=None,
                      scores_=None,
                      rnt=0.5,
                      active_classes=None,
                      num_classes_per_task_l=None,
                      task=1,
                      args=None):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [scores]          None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x]
                            NOTE: only to be used for "BCE with distill" (only when scenario=="class")
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [x_ex]            None or (<list> of) <tensor> batch of exemplar inputs
        [y_ex]            None or (<list> of) <tensor> batch of corresponding exemplar labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            <int>, for setting task-specific mask'''

        # [.long()] is not in-place, so the results must be reassigned
        y = y.long()
        if y_ is not None:
            y_ = y_.long()
        if y_ex is not None:
            y_ex = y_ex.long()

        # Set model to training-mode
        self.train()

        if args.gem:
            start_gem = time.time()
            t = task - 1
            # update memory
            if t != self.old_task:
                self.observed_tasks.append(t)
                self.old_task = t

            # Update ring buffer storing examples from current task
            bsz = y.data.size(0)
            endcnt = min(self.mem_cnt + bsz, self.n_memories)
            effbsz = endcnt - self.mem_cnt
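            # [effbsz] = how many examples of this batch still fit before the buffer wraps around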
            self.memory_data[t, self.mem_cnt:endcnt].copy_(x.data[:effbsz])
            if bsz == 1:
                self.memory_labs[t, self.mem_cnt] = y.data[0]
            else:
                self.memory_labs[t, self.mem_cnt:endcnt].copy_(y.data[:effbsz])
            self.mem_cnt += effbsz
            if self.mem_cnt == self.n_memories:
                self.mem_cnt = 0
            args.train_time_gem_update_memory += time.time() - start_gem

            start_gem = time.time()
            # compute gradient on previous tasks
            if len(self.observed_tasks) > 1:
                # print('self.observed_tasks: ', self.observed_tasks)
                for tt in range(len(self.observed_tasks) - 1):
                    self.zero_grad()
                    # fwd/bwd on the examples in the memory
                    past_task = self.observed_tasks[tt]

                    offset1, offset2 = my_compute_offsets(
                        task=past_task,
                        num_classes_per_task_l=num_classes_per_task_l)
                    # ptloss = self.ce(
                    #     input=self.forward(self.memory_data[past_task])[:, offset1: offset2],
                    #     target=self.memory_labs[past_task] - offset1)
                    ptloss = F.cross_entropy(
                        input=self.forward(
                            self.memory_data[past_task])[:, offset1:offset2],
                        target=self.memory_labs[past_task] - offset1,
                        weight=self.weights_per_class[offset1:offset2])
                    ptloss.backward()
                    store_grad(self.parameters, self.grads, self.grad_dims,
                               past_task)

            args.train_time_gem_compute_gradient += time.time() - start_gem
            # now compute the grad on the current minibatch
            self.zero_grad()

            # print(t, num_classes_per_task_l)
            offset1, offset2 = my_compute_offsets(
                task=t, num_classes_per_task_l=num_classes_per_task_l)
            # print(self.forward(x)[:, offset1: offset2].size())
            # print(y.size())
            # loss = self.ce(
            #     input=self.forward(x)[:, offset1: offset2],
            #     target=y - offset1)
            loss = F.cross_entropy(
                input=self.forward(x)[:, offset1:offset2],
                target=y - offset1,
                weight=self.weights_per_class[offset1:offset2])
            loss.backward()

            # check if gradient violates constraints
            start_gem = time.time()
            if len(self.observed_tasks) > 1:
                # copy gradient
                store_grad(self.parameters, self.grads, self.grad_dims, t)
                # indx = torch.cuda.LongTensor(self.observed_tasks[:-1]) if self.gpu \
                #     else torch.LongTensor(self.observed_tasks[:-1])
                indx = torch.LongTensor(self.observed_tasks[:-1]).to(self.device) if self.gpu \
                    else torch.LongTensor(self.observed_tasks[:-1])
                # print(indx)
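                # a negative dot product between the current gradient and any stored
                # past-task gradient means this update would increase that task's loss,
                # so project the gradient back into the feasible cone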
                dotp = torch.mm(self.grads[:, t].unsqueeze(0),
                                self.grads.index_select(1, indx))
                if (dotp < 0).sum() != 0:
                    project2cone2(self.grads[:, t].unsqueeze(1),
                                  self.grads.index_select(1, indx),
                                  self.margin)
                    # copy gradients back
                    overwrite_grad(self.parameters, self.grads[:, t],
                                   self.grad_dims)
            args.train_time_gem_violation_check += time.time() - start_gem
            self.optimizer.step()

            return {
                'loss_total': loss.item(),
                'loss_current': 0,
                'loss_replay': 0,
                'pred': 0,
                'pred_r': 0,
                'distil_r': 0,
                'ewc': 0.,
                'si_loss': 0.,
                'precision': 0.,
            }
        else:
            # Reset optimizer
            self.optimizer.zero_grad()
            ##--(1)-- CURRENT DATA --##

            if x is not None:
                # If requested, apply correct task-specific mask
                if self.mask_dict is not None:
                    self.apply_XdGmask(task=task)

                # Run model
                if args.augment < 0:
                    y_hat = self(x)
                else:
                    features = self.feature_extractor(x)
                    features_ex = self.feature_extractor(x_ex)
                    # y_hat = self.forward_from_hidden_layer(features)
                    # y_hat_ex = self.forward_from_hidden_layer(features_ex)

                #####
                ##### perform augmentation on feature space #####
                #####
                if args.augment == 0:  # random augmentation

                    features = torch.cat([features, features_ex])
                    y = torch.cat([y, y_ex])
                    # now let's add some noise!
                    for i in range(args.scaling - 1):
                        features = torch.cat([
                            features, features_ex +
                            torch.randn(features_ex.shape).to(self.device) *
                            args.sd
                        ])
                        y = torch.cat([y, y_ex])
                        if features.shape[0] > args.batch:
                            features = features[:args.batch, :]
                            y = y[:args.batch]
                            break

                    y_hat = self.forward_from_hidden_layer(features)
                elif args.augment == 1:  # feature augmentation based on standard deviation
                    # not implemented; fail explicitly instead of leaving [y_hat] undefined below
                    raise NotImplementedError("augment==1 (std-based feature augmentation)")
                elif args.augment == 2:  # feature augmentation based on SMOTE method
                    raise NotImplementedError("augment==2 (SMOTE-based feature augmentation)")

                # -if needed, remove predictions for classes not in current task
                if active_classes is not None:
                    class_entries = active_classes[-1] if type(
                        active_classes[0]) == list else active_classes
                    y_hat = y_hat[:, class_entries]
                    # if args.augment >= 0:
                    #     y_hat_ex = y_hat_ex[:, class_entries]

                # Calculate prediction loss
                if self.binaryCE:  # iCaRL enters this branch (binaryCE is True)
                    # -binary prediction loss
                    # print('x.shape: ', x.shape)
                    # print('y.shape: ', y.shape)
                    # print('y_hat.shape: ', y_hat.shape)
                    # print('y: ', y)
                    start_icarl = time.time()
                    binary_targets = utils.to_one_hot(
                        y.cpu(), y_hat.size(1)).to(y.device)
                    # binary_targets = [0, 0, 1, ... , 0] <=> class = 2
                    # print(binary_targets.size()) # [128 x 17]
                    if self.binaryCE_distill and (scores is not None):
                        # iCaRL does not enter this branch because [scores] is None
                        if args.experiment == 'sensor':
                            binary_targets = binary_targets[:, sum(num_classes_per_task_l[:(task - 1)]):]
                        else:
                            classes_per_task = int(y_hat.size(1) / task)
                            binary_targets = binary_targets[:, -classes_per_task:]
                        # prepend the previous model's temperature-scaled probabilities as soft targets
                        binary_targets = torch.cat(
                            [torch.sigmoid(scores / self.KD_temp), binary_targets], dim=1)
                    predL = None if y is None else F.binary_cross_entropy_with_logits(
                        input=y_hat, target=binary_targets,
                        reduction='none').sum(dim=1).mean()  #--> sum over classes, then average over batch
                    args.train_time_icarl_loss += time.time() - start_icarl
                else:
                    # -multiclass prediction loss
                    # print("x", x.shape, x)
                    # print("y_hat", y_hat.shape, y_hat)
                    # print("y", y.shape, y)
                    predL = None if y is None else F.cross_entropy(
                        input=y_hat,
                        target=y,
                        weight=self.weights_per_class[class_entries],
                        reduction='mean')  # 'elementwise_mean' is deprecated in current PyTorch

                # Weigh losses
                loss_cur = predL

                # Calculate training-precision
                precision = None if y is None else (
                    y == y_hat.max(1)[1]).sum().item() / x.size(0)

                # If XdG is combined with replay, backward-pass needs to be done before new task-mask is applied
                if (self.mask_dict is not None) and (x_ is not None):
                    weighted_current_loss = rnt * loss_cur
                    weighted_current_loss.backward()
            else:
                precision = predL = None
                # -> it's possible there is only "replay" [i.e., for offline with incremental task learning]

            ##--(2)-- REPLAYED DATA --##

            if x_ is not None:
                # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
                # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
                start_lwf = time.time()

                TaskIL = (type(y_) == list) if (y_ is not None) else (
                    type(scores_) == list)
                if not TaskIL:
                    y_ = [y_]
                    scores_ = [scores_]
                    active_classes = [active_classes] if (
                        active_classes is not None) else None
                n_replays = len(y_) if (y_ is not None) else len(scores_)

                # Prepare lists to store losses for each replay
                loss_replay = [None] * n_replays
                predL_r = [None] * n_replays
                distilL_r = [None] * n_replays

                # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
                if (not type(x_) == list) and (self.mask_dict is None):
                    y_hat_all = self(x_)

                # Loop to evaluate predictions on replay according to each previous task
                for replay_id in range(n_replays):

                    # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                    if (type(x_) == list) or (self.mask_dict is not None):
                        x_temp_ = x_[replay_id] if type(x_) == list else x_
                        if self.mask_dict is not None:
                            self.apply_XdGmask(task=replay_id + 1)
                        y_hat_all = self(x_temp_)

                    # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                    y_hat = y_hat_all if (
                        active_classes is None
                    ) else y_hat_all[:, active_classes[replay_id]]

                    # Calculate losses
                    if (y_ is not None) and (y_[replay_id] is not None):
                        if self.binaryCE:
                            binary_targets_ = utils.to_one_hot(
                                y_[replay_id].cpu(),
                                y_hat.size(1)).to(y_[replay_id].device)
                            predL_r[
                                replay_id] = F.binary_cross_entropy_with_logits(
                                    input=y_hat,
                                    target=binary_targets_,
                                    reduction='none'
                                ).sum(dim=1).mean(
                                )  #--> sum over classes, then average over batch
                        else:
                            predL_r[replay_id] = F.cross_entropy(
                                input=y_hat,
                                target=y_[replay_id],
                                weight=self.weights_per_class[
                                    active_classes[replay_id]],
                                reduction='mean')
                    if (scores_ is not None) and (scores_[replay_id]
                                                  is not None):
                        # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                        n_classes_to_consider = y_hat.size(1)  #--> zeros will be added to [scores] to make it this size!
                        kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd
                        distilL_r[replay_id] = kd_fn(
                            scores=y_hat[:, :n_classes_to_consider],
                            target_scores=scores_[replay_id],
                            T=self.KD_temp)
                    # Weigh losses
                    if self.replay_targets == "hard":
                        loss_replay[replay_id] = predL_r[replay_id]
                    elif self.replay_targets == "soft":
                        loss_replay[replay_id] = distilL_r[replay_id]

                    # If task-specific mask, backward pass needs to be performed before next task-mask is applied
                    if self.mask_dict is not None:
                        weighted_replay_loss_this_task = (
                            1 - rnt) * loss_replay[replay_id] / n_replays
                        weighted_replay_loss_this_task.backward()
                args.train_time_lwf_loss += time.time() - start_lwf

            # Calculate total loss with replay loss if it exists.
            if x_ is None:
                loss_replay = None
            else:
                start_lwf = time.time()
                loss_replay = sum(loss_replay) / n_replays
                args.train_time_lwf_loss += time.time() - start_lwf
            if x is None:
                start_lwf = time.time()
                loss_total = loss_replay
                args.train_time_lwf_loss += time.time() - start_lwf
            else:
                if x_ is None:
                    loss_total = loss_cur
                else:
                    start_lwf = time.time()
                    loss_total = rnt * loss_cur + (1 - rnt) * loss_replay
                    args.train_time_lwf_loss += time.time() - start_lwf


            ##--(3)-- ALLOCATION LOSSES --##

            # Add SI-loss (Zenke et al., 2017)
            if self.si_c > 0:
                start_si = time.time()
                surrogate_loss = self.surrogate_loss()
                loss_total += self.si_c * surrogate_loss
                args.train_time_si_loss += time.time() - start_si

            # Add EWC-loss
            if self.ewc_lambda > 0:
                start_ewc = time.time()
                ewc_loss = self.ewc_loss()
                loss_total += self.ewc_lambda * ewc_loss
                args.train_time_ewc_loss += time.time() - start_ewc

            # Backpropagate errors (if not yet done)
            if (self.mask_dict is None) or (x_ is None):
                loss_total.backward()
            # Take optimization-step
            self.optimizer.step()

            # Return the dictionary with different training-loss split in categories
            return {
                'loss_total': loss_total.item(),
                'loss_current': loss_cur.item() if x is not None else 0,
                'loss_replay': loss_replay.item() if (loss_replay is not None) and (x is not None) else 0,
                'pred': predL.item() if predL is not None else 0,
                'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
                'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
                'ewc': ewc_loss.item() if self.ewc_lambda > 0 else 0.,
                'si_loss': surrogate_loss.item() if self.si_c > 0 else 0.,
                'precision': precision if precision is not None else 0.,
            }
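
The GEM branch of `train_a_batch` above relies on helpers that are not shown in this example: `store_grad`, `overwrite_grad` and `project2cone2` (which follow the reference GEM implementation of Lopez-Paz & Ranzato, 2017) and `my_compute_offsets`. A plausible reconstruction of the latter, inferred from how its return values slice the output layer and shift the targets, together with a sketch of `store_grad`:

def my_compute_offsets(task, num_classes_per_task_l):
    """Hypothetical sketch: return the [offset1, offset2) range of output units
    that belongs to 0-indexed [task], allowing unequal class counts per task."""
    offset1 = sum(num_classes_per_task_l[:task])
    offset2 = sum(num_classes_per_task_l[:task + 1])
    return offset1, offset2

def store_grad(pp, grads, grad_dims, tid):
    """Sketch after the reference GEM code: flatten the current .grad of all
    parameters (pp is the model's .parameters method) into column [tid] of [grads]."""
    grads[:, tid].fill_(0.0)
    cnt = 0
    for param in pp():
        if param.grad is not None:
            beg = 0 if cnt == 0 else sum(grad_dims[:cnt])
            en = sum(grad_dims[:cnt + 1])
            grads[beg:en, tid].copy_(param.grad.data.view(-1))
        cnt += 1
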
Example #11
0
class Classifier(ContinualLearner, Replayer, ExemplarHandler):
    '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.'''

    def __init__(self, image_size, image_channels, classes,
                 fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=False, fc_nl="relu", gated=False,
                 bias=True, excitability=False, excit_buffer=False, binaryCE=False, binaryCE_distill=False, AGEM=False,
                 experiment='splitMNIST'):

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE  # -> use binary (instead of multiclass) prediction error
        self.binaryCE_distill = binaryCE_distill  # -> for classes from previous tasks, use the by the previous model
        #   predicted probs as binary targets (only in Class-IL with binaryCE)
        self.AGEM = AGEM  # -> use gradient of replayed data as inequality constraint for (instead of adding it to)
        #   the gradient of the current data (as in A-GEM, see Chaudry et al., 2019; ICLR)

        # Online mem distillation
        self.is_offline_training = False
        self.is_ready_distill = False
        self.alpha_t = 0.5
        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("The classifier needs to have at least 1 fully-connected layer.")

        ######------SPECIFY MODEL------######
        self.experiment = experiment
        if self.experiment in ['CIFAR10', 'CIFAR100', 'CUB2011']:
            self.fcE = rn.resnet32(classes, pretrained=False)
            self.fcE.linear = nn.Identity()

            self.classifier = fc_layer(64, classes, excit_buffer=True, nl='none', drop=fc_drop)
        elif self.experiment == 'ImageNet':
            ResNet.name = 'ResNet-18'
            self.fcE = resnet18(pretrained=True)
            self.fcE.fc = nn.Identity()

            self.classifier = fc_layer(512, classes, excit_buffer=True, nl='none', drop=fc_drop)
        else:
            # flatten image to 2D-tensor
            self.flatten = utils.Flatten()

            # fully connected hidden layers
            self.fcE = MLP(input_size=image_channels * image_size ** 2, output_size=fc_units, layers=fc_layers - 1,
                           hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias,
                           excitability=excitability, excit_buffer=excit_buffer, gated=gated)
            mlp_output_size = fc_units if fc_layers > 1 else image_channels * image_size ** 2

            # classifier
            self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none', drop=fc_drop)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        layer_list = []
        layer_list += self.fcE.list_init_layers()
        layer_list += self.classifier.list_init_layers()
        return layer_list

    @property
    def name(self):
        return "{}_c{}".format(self.fcE.name, self.classes)

    def forward(self, x):
        final_features = self.feature_extractor(x)
        return self.classifier(final_features)

    def feature_extractor(self, images):
        if self.experiment not in ['splitMNIST', 'permMNIST', 'rotMNIST']:
            return self.fcE(images)
        else:
            return self.fcE(self.flatten(images))

    def select_triplets(self, embeds, y_score, x, y, triplet_selection, task, scenario, use_embeddings, multi_negative):
        uq = torch.unique(y).cpu().numpy()
        selection_strategies = triplet_selection.split('-')
        # Select instances in the batch for replay later
        for m in uq:
            neg_y = np.delete(uq, np.where(uq == m))
            mask = y == m
            mask_neg = y != m
            ce_m = y_score[mask]
            if ce_m.size(0) != 0:
                # Select anchor and hard positive instances for class m
                positive_batch = x[mask]
                positive_embed_batch = embeds[mask]
                anchor_idx = torch.argmin(ce_m)
                anchor_x = positive_batch[anchor_idx].unsqueeze(dim=0)
                anchor_embed = positive_embed_batch[anchor_idx].unsqueeze(dim=0)
                # anchor should not equal positive
                positive_batch = torch.cat(
                    (positive_batch[:anchor_idx], positive_batch[anchor_idx + 1:]), dim=0)
                positive_embed_batch = torch.cat(
                    (positive_embed_batch[:anchor_idx], positive_embed_batch[anchor_idx + 1:]), dim=0)
                if positive_batch.size(0) != 0:
                    if use_embeddings:
                        anchor_batch = anchor_embed.expand(positive_embed_batch.size())
                        positive_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1),
                                                            positive_embed_batch.view(positive_embed_batch.size(0), -1))
                    else:
                        anchor_batch = anchor_x.expand(positive_batch.size())
                        positive_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1),
                                                            positive_batch.view(positive_batch.size(0), -1))

                    if selection_strategies[0] == 'HP':
                        # Hard positive
                        _, positive_idx = torch.topk(positive_dist, 1)
                    else:
                        # Easy positive
                        _, positive_idx = torch.topk(positive_dist, 1, largest=False)

                    positive_x = positive_batch[positive_idx]
                    x_m = torch.cat((anchor_x, positive_x), dim=0)
                    y_m = torch.tensor([m, m])
                else:
                    x_m = anchor_x
                    y_m = torch.tensor([m])

                if scenario in ['task', 'domain']:
                    self.add_instances_to_online_exemplar_sets(x_m, y_m,
                                                               (y_m + len(uq) * (task - 1)).detach().cpu().numpy())
                else:
                    self.add_instances_to_online_exemplar_sets(x_m, y_m, y_m.detach().cpu().numpy())

                negative_batch = x[mask_neg]
                negative_batch_y = y[mask_neg]
                negative_embed_batch = embeds[mask_neg]

                if negative_batch.size(0) != 0:
                    if use_embeddings:
                        anchor_batch = anchor_embed.expand(negative_embed_batch.size())
                        negative_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1),
                                                            negative_embed_batch.view(negative_embed_batch.size(0), -1))
                    else:
                        anchor_batch = anchor_x.expand(negative_batch.size())
                        negative_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1),
                                                            negative_batch.view(negative_batch.size(0), -1))

                # Select instances for each negative class
                if multi_negative:
                    for n in neg_y:
                        mask_neg_n = negative_batch_y == n
                        negative_dist_n = negative_dist[mask_neg_n]
                        negative_batch_n = negative_batch[mask_neg_n]
                        negative_batch_y_n = negative_batch_y[mask_neg_n]

                        if selection_strategies[1] == 'HN':
                            # Hard negative
                            _, negative_idx = torch.topk(negative_dist_n, int(selection_strategies[2]), largest=False)
                            negative_x = negative_batch_n[negative_idx]
                            negative_y = negative_batch_y_n[negative_idx]
                        elif selection_strategies[1] == 'SHN':
                            # Semi-hard negative
                            if use_embeddings:
                                positive_embed = positive_embed_batch[positive_idx].unsqueeze(dim=0)
                                dap = F.pairwise_distance(anchor_embed.view(anchor_x.size(0), -1),
                                                          positive_embed.view(positive_x.size(0), -1))
                            else:
                                dap = F.pairwise_distance(anchor_x.view(anchor_x.size(0), -1),
                                                          positive_x.view(positive_x.size(0), -1))
                            valid_shn_idx = negative_dist_n > dap
                            if valid_shn_idx.any():
                                shn_batch = negative_batch_n[valid_shn_idx]
                                shn_y = negative_batch_y_n[valid_shn_idx]
                                # negative_idx = torch.argmin(negative_dist[valid_shn_idx])
                                # rank distances within the semi-hard subset, so the indices
                                # match [shn_batch] (consistent with the single-negative branch)
                                _, negative_idx = torch.topk(negative_dist_n[valid_shn_idx],
                                                             int(selection_strategies[2]),
                                                             largest=False)
                                negative_x = shn_batch[negative_idx]
                                negative_y = shn_y[negative_idx]
                            else:
                                # There is no semi-hard negative sample, ignore negative sample
                                negative_x = None
                                negative_y = None
                        else:
                            # Easy negative
                            _, negative_idx = torch.topk(negative_dist_n, int(selection_strategies[2]))
                            negative_x = negative_batch_n[negative_idx]
                            negative_y = negative_batch_y_n[negative_idx]

                        if negative_x is not None and negative_y is not None:
                            if scenario in ['task', 'domain']:
                                self.add_instances_to_online_exemplar_sets(negative_x, negative_y,
                                                                           (negative_y + len(uq) * (
                                                                                   task - 1)).detach().cpu().numpy())
                            else:
                                self.add_instances_to_online_exemplar_sets(negative_x, negative_y,
                                                                           negative_y.detach().cpu().numpy())
                else:
                    if selection_strategies[1] == 'HN':
                        # Hard negative
                        _, negative_idx = torch.topk(negative_dist, int(selection_strategies[2]), largest=False)
                        negative_x = negative_batch[negative_idx]
                        negative_y = negative_batch_y[negative_idx]
                    elif selection_strategies[1] == 'SHN':
                        # Semi-hard negative
                        if use_embeddings:
                            positive_embed = positive_embed_batch[positive_idx].unsqueeze(dim=0)
                            dap = F.pairwise_distance(anchor_embed.view(anchor_x.size(0), -1),
                                                      positive_embed.view(positive_x.size(0), -1))
                        else:
                            dap = F.pairwise_distance(anchor_x.view(anchor_x.size(0), -1),
                                                      positive_x.view(positive_x.size(0), -1))
                        valid_shn_idx = negative_dist > dap
                        if valid_shn_idx.any():
                            shn_batch = negative_batch[valid_shn_idx]
                            shn_y = negative_batch_y[valid_shn_idx]
                            # negative_idx = torch.argmin(negative_dist[valid_shn_idx])
                            _, negative_idx = torch.topk(negative_dist[valid_shn_idx], int(selection_strategies[2]), largest=False)
                            negative_x = shn_batch[negative_idx]
                            negative_y = shn_y[negative_idx]
                        else:
                            # There is no semi-hard negative sample, ignore negative sample
                            negative_x = None
                            negative_y = None
                    else:
                        # Easy negative
                        _, negative_idx = torch.topk(negative_dist, int(selection_strategies[2]))
                        negative_x = negative_batch[negative_idx]
                        negative_y = negative_batch_y[negative_idx]

                    # store the selected negative(s); the multi-negative branch above already
                    # added its selections inside the per-class loop
                    if negative_x is not None and negative_y is not None:
                        if scenario in ['task', 'domain']:
                            self.add_instances_to_online_exemplar_sets(
                                negative_x, negative_y,
                                (negative_y + len(uq) * (task - 1)).detach().cpu().numpy())
                        else:
                            self.add_instances_to_online_exemplar_sets(
                                negative_x, negative_y, negative_y.detach().cpu().numpy())

    def select_instances(self, embeds, x, y, scenario, task):
        uq, _ = torch.sort(torch.unique(y))
        uq = uq.cpu().numpy()
        exemplars_per_class = int(np.floor(self.memory_budget / (len(uq) * task)))
        exemplar_set = []
        if self.herding:
            # Accumulate class means
            for m in uq:
                mask = y == m
                xm = x[mask]
                embedsm = embeds[mask]

                if self.norm_exemplars:
                    features = F.normalize(embedsm, p=2, dim=1)
                else:
                    features = embedsm

                # calculate mean of all features
                class_mean = torch.mean(features, dim=0, keepdim=True)
                # if self.norm_exemplars:
                #     class_mean = F.normalize(class_mean, p=2, dim=1)

                # one by one, select exemplar that makes mean of all exemplars as close to [class_mean] as possible
                exemplar_features = torch.zeros_like(features[:min(exemplars_per_class, embedsm.size(0))])
                list_of_selected = []
                for k in range(min(exemplars_per_class, embedsm.size(0))):
                    if k > 0:
                        exemplar_sum = torch.sum(exemplar_features[:k], dim=0).unsqueeze(0)
                        features_means = (features + exemplar_sum) / (k + 1)
                        features_dists = features_means - class_mean
                    else:
                        features_dists = features - class_mean
                    # pick the example that keeps the mean of the selected exemplars closest to [class_mean]
                    index_selected = np.argmin(torch.norm(features_dists, p=2, dim=1).detach().cpu().numpy())
                    if index_selected in list_of_selected:
                        raise ValueError("Exemplars should not be repeated!!!!")
                    list_of_selected.append(index_selected)

                    exemplar_set.append(xm[index_selected].detach().cpu().numpy())
                    exemplar_features[k] = features[index_selected].clone()

                    # make sure this example won't be selected again
                    features[index_selected] = features[index_selected] + 10000

                if scenario in ['task', 'domain']:
                    if len(self.exemplar_sets) == ((task - 1) * len(uq) + m % len(uq)):
                        self.exemplar_means.append(class_mean)
                        self.exemplar_sets.append(np.array(exemplar_set))
                    elif len(self.exemplar_sets) < ((task - 1) * len(uq) + m % len(uq)):
                        self.exemplar_means[m + len(uq) * (task - 1)] = (self.exemplar_means[m + len(uq) * (task - 1)]+ class_mean)/2
                        self.exemplar_sets[m + len(uq) * (task - 1)] = np.concatenate(
                            (self.exemplar_sets[m + len(uq) * (task - 1)], exemplar_set), axis=0)
                else:
                    if len(self.exemplar_sets) == ((task - 1) * len(uq) + m % len(uq)):
                        self.exemplar_means.append(class_mean)
                        self.exemplar_sets.append(np.array(exemplar_set))
                    elif len(self.exemplar_sets) < ((task - 1) * len(uq) + m % len(uq)):
                        self.exemplar_means[m] = (self.exemplar_means[m] + class_mean) / 2
                        self.exemplar_sets[m] = np.concatenate(
                            (self.exemplar_sets[m], exemplar_set), axis=0)
        else:
            for m in uq:
                mask = y == m
                xm = x[mask]
                indices_selected = np.random.choice(xm.size(0), size=min(xm.size(0), exemplars_per_class), replace=False)
                if scenario in ['task', 'domain']:
                    if len(self.exemplar_sets) < task * len(uq):
                        self.exemplar_sets.append(xm[indices_selected].detach().cpu().numpy())
                    else:
                        # Concatenate to the existing exemplar set
                        self.exemplar_sets[m + len(uq) * (task - 1)] = np.concatenate(
                            (self.exemplar_sets[m + len(uq) * (task - 1)], xm[indices_selected].detach().cpu().numpy()), axis=0)
                else:
                    if len(self.exemplar_sets) < task * len(uq):
                        self.exemplar_sets.append(xm[indices_selected].detach().cpu().numpy())
                    else:
                        # Concatenate to the existing exemplar set
                        self.exemplar_sets[m] = np.concatenate(
                            (self.exemplar_sets[m], xm[indices_selected].detach().cpu().numpy()), axis=0)

        self.reduce_exemplar_sets(exemplars_per_class)
        # for i in range(len(self.exemplar_sets)):
        #     print("Task %d Class %d" % (task, i), self.exemplar_sets[i].shape)

    def train_a_batch(self, x, y, scores=None, x_=None, y_=None, scores_=None, rnt=0.5,
                      active_classes=None, task=1, scenario='class', teacher=None,
                      params_dict=None, epoch=0):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [scores]          None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x]
                            NOTE: only to be used for "BCE with distill" (only when scenario=="class")
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            <int>, for setting task-specific mask'''

        # Set model to training-mode
        self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        # Should gradient be computed separately for each task? (needed when a task-mask is combined with replay)
        gradient_per_task = (self.mask_dict is not None) and (x_ is not None)
        ##--(1)-- REPLAYED DATA --##

        if x_ is not None:
            # print(y_, task)
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = (type(y_) == list) if (y_ is not None) else (type(scores_) == list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_KD = [None] * n_replays
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            if (not type(x_) == list) and (self.mask_dict is None):
                y_hat_all = self(x_)
                if teacher is not None and task > 1:
                    if teacher.is_ready_distill:
                        teacher.eval()
                        with torch.no_grad():
                            embeds_teacher = teacher.feature_extractor(x_)
                            y_hat_teacher = teacher.classifier(embeds_teacher)
                    else:
                        y_hat_teacher = None
                else:
                    y_hat_teacher = None

            # Loop to evaluate predictions on replay according to each previous task
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if (type(x_) == list) or (self.mask_dict is not None):
                    x_temp_ = x_[replay_id] if type(x_) == list else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self(x_temp_)

                    if teacher is not None and task > 1:
                        if teacher.is_ready_distill:
                            teacher.eval()
                            with torch.no_grad():
                                embeds_teacher = teacher.feature_extractor(x_temp_)
                                y_hat_teacher = teacher.classifier(embeds_teacher)
                        else:
                            y_hat_teacher = None
                    else:
                        y_hat_teacher = None

                # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                y_hat = y_hat_all if (active_classes is None) else y_hat_all[:, active_classes[replay_id]]
                if y_hat_teacher is not None:
                    y_hat_teacher = y_hat_teacher if (active_classes is None) else y_hat_teacher[:, active_classes[replay_id]]
                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    if self.binaryCE:
                        binary_targets_ = utils.to_one_hot(y_[replay_id].cpu(), y_hat.size(1)).to(y_[replay_id].device)
                        predL_r[replay_id] = F.binary_cross_entropy_with_logits(
                            input=y_hat, target=binary_targets_, reduction='none'
                        ).sum(dim=1).mean()  # --> sum over classes, then average over batch
                    else:
                        predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id], reduction='mean')

                # Compute distillation loss from teacher outputs
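                # (distill_type codes, as read off the branches below: variants containing 'E'
                #  distill from a 50/50 student-teacher ensemble, and 'ET'/'ETS' additionally
                #  distill from the teacher alone; plain 'T'/'TS' use only the teacher)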
                if y_hat_teacher is not None:
                    if params_dict['distill_type'] in ['E', 'ET', 'ES', 'ETS']:
                        # Ensemble target: average of the current model's and the teacher's logits
                        with torch.no_grad():
                            y_hat_ensemble = 0.5 * (y_hat.clone() + y_hat_teacher.clone())

                        if params_dict['distill_type'] in ['ET', 'ETS']:
                            # Distill from both the ensemble and the teacher (equally weighted).
                            # NOTE: reduction='batchmean' gives the mathematically correct KL;
                            # the PyTorch default 'mean' also divides by the number of classes.
                            loss_KD[replay_id] = 0.5 * (
                                F.kl_div(F.log_softmax(y_hat / self.KD_temp, dim=1),
                                         F.softmax(y_hat_ensemble / self.KD_temp, dim=1),
                                         reduction='batchmean') * (self.KD_temp ** 2)
                                + F.kl_div(F.log_softmax(y_hat / self.KD_temp, dim=1),
                                           F.softmax(y_hat_teacher / self.KD_temp, dim=1),
                                           reduction='batchmean') * (self.KD_temp ** 2)
                            )
                        else:  # distill_type 'E' or 'ES': distill from the ensemble only
                            loss_KD[replay_id] = F.kl_div(
                                F.log_softmax(y_hat / self.KD_temp, dim=1),
                                F.softmax(y_hat_ensemble / self.KD_temp, dim=1),
                                reduction='batchmean'
                            ) * (self.KD_temp ** 2)

                    else:  # distill_type 'T' or 'TS': distill from the teacher only
                        loss_KD[replay_id] = F.kl_div(
                            F.log_softmax(y_hat / self.KD_temp, dim=1),
                            F.softmax(y_hat_teacher / self.KD_temp, dim=1),
                            reduction='batchmean'
                        ) * (self.KD_temp ** 2)
                        # loss_KD = self.alpha_t * loss_KD + F.cross_entropy(y_hat, y) * (1. - self.alpha_t)
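                # Sketch of the loss used above: with temperature T = self.KD_temp, each KD
                # term is  T^2 * KL(p_target || p_student),  where p_student = softmax(y_hat/T)
                # and p_target = softmax(y_target/T) for teacher and/or ensemble logits y_target.
                # The T^2 factor keeps gradient magnitudes comparable across temperatures
                # (Hinton et al., 2015).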

                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                    n_classes_to_consider = y_hat.size(1)  # --> zeros will be added to [scores] to make it this size!
                    kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd
                    distilL_r[replay_id] = kd_fn(scores=y_hat[:, :n_classes_to_consider],
                                                 target_scores=scores_[replay_id], T=self.KD_temp)
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If needed, perform backward pass before next task-mask (gradients of all tasks will be accumulated)
                if gradient_per_task:
                    weight = 1 if self.AGEM else (1 - rnt)
                    weighted_replay_loss_this_task = weight * loss_replay[replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

            # Calculate total replay loss
            loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays

            # Calculate total KD loss
            loss_KD = None if any(lkd is None for lkd in loss_KD) else sum(loss_KD) / n_replays
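            # (loss_KD stays None whenever any replay lacked a teacher prediction, e.g.,
            #  during the first task or before the teacher is ready to distill)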
        else:
            loss_KD = None

        # If using A-GEM, calculate and store averaged gradient of replayed data
        if self.AGEM and x_ is not None:
            # Perform backward pass to calculate gradient of replayed batch (if not yet done)
            if not gradient_per_task:
                loss_replay = loss_replay.clamp(min=1e-6)  # presumably guards against a (near-)zero replay loss before backward
                loss_replay.backward()
            # Reorganize the gradient of the replayed batch as a single vector
            grad_rep = []
            for p in self.parameters():
                if p.requires_grad:
                    grad_rep.append(p.grad.view(-1))
            grad_rep = torch.cat(grad_rep)
            # Reset gradients (with A-GEM, gradients of replayed batch should only be used as inequality constraint)
            self.optimizer.zero_grad()

        ##--(2)-- CURRENT DATA --##
        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            embeds = self.feature_extractor(x)
            y_hat = self.classifier(embeds)

            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            if self.binaryCE:
                # -binary prediction loss
                binary_targets = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device)
                if self.binaryCE_distill and (scores is not None):
                    classes_per_task = int(y_hat.size(1) / task)
                    binary_targets = binary_targets[:, -(classes_per_task):]
                    binary_targets = torch.cat([torch.sigmoid(scores / self.KD_temp), binary_targets], dim=1)
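                    # Worked shape example (assuming an equal number of classes per task):
                    # with task=3 and y_hat.size(1)=6, classes_per_task=2, so the hard one-hot
                    # targets are kept only for the 2 current-task classes, while
                    # sigmoid(scores / KD_temp) supplies soft targets for the 4 earlier classes.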
                y_score = F.binary_cross_entropy_with_logits(
                    input=y_hat, target=binary_targets, reduction='none'
                ).sum(dim=1)  # --> sum over classes,
                predL = None if y is None else y_score.mean()  # average over batch
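                # Exemplar selection for the replay memory (project-specific helpers; the exact
                # strategies live in select_instances / select_triplets). Selection runs only in
                # the first epoch of a task: per batch when 'mem_online' is set, otherwise via
                # online triplet selection when 'use_otr' is set.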
                if params_dict['mem_online'] and epoch == 0:
                    self.select_instances(embeds, x, y, scenario, task)
                elif params_dict['use_otr'] and epoch == 0:
                    self.select_triplets(embeds, y_score, x, y,
                                         params_dict['triplet_selection'], task, scenario,
                                         params_dict['use_embeddings'], params_dict['multi_negative'])
            else:
                # -multiclass prediction loss
                y_score = F.cross_entropy(input=y_hat, target=y, reduction='none')
                predL = None if y is None else y_score.mean()

                if params_dict['mem_online'] and epoch == 0:
                    self.select_instances(embeds, x, y, scenario, task)
                elif params_dict['use_otr'] and epoch == 0:
                    self.select_triplets(embeds, y_score, x, y,
                                         params_dict['triplet_selection'], task, scenario,
                                         params_dict['use_embeddings'], params_dict['multi_negative'])

            loss_cur = predL
            # Calculate training-precision
            precision = None if y is None else (y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If backward passes are performed per task (e.g., XdG combined with replay), perform backward pass
            if gradient_per_task:
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()
        else:
            precision = predL = None
            # -> it's possible there is only "replay" [e.g., for offline with task-incremental learning]

        # Combine loss from current and replayed batch
        if x_ is None or self.AGEM:
            loss_total = loss_cur
        else:
            loss_total = loss_replay if (x is None) else rnt * loss_cur + (1 - rnt) * loss_replay
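        # Example weighting: rnt=0.5 makes current and replayed batches contribute equally,
        # while rnt closer to 1 prioritizes the current task. Under A-GEM the replay loss is
        # deliberately excluded here, since it acts only through the gradient constraint below.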
        if loss_KD is not None:
            loss_total = loss_total + loss_KD

        ##--(3)-- ALLOCATION LOSSES --##

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            loss_total += self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            loss_total += self.ewc_lambda * ewc_loss

        # Backpropagate errors (if not yet done)
        if not gradient_per_task:
            loss_total.backward()

        # If using A-GEM, potentially change gradient:
        if self.AGEM and x_ is not None:
            # -reorganize gradient (of current batch) as single vector
            grad_cur = []
            for p in self.parameters():
                if p.requires_grad:
                    grad_cur.append(p.grad.view(-1))
            grad_cur = torch.cat(grad_cur)
            # -check the inequality constraint (dot product of current and replayed gradients)
            angle = (grad_cur * grad_rep).sum()
            if angle < 0:
                # -if violated, project the gradient of the current batch onto the gradient of the replayed batch ...
                length_rep = (grad_rep * grad_rep).sum()
                grad_proj = grad_cur - (angle / length_rep) * grad_rep
                # -...and replace all the gradients within the model with this projected gradient
                index = 0
                for p in self.parameters():
                    if p.requires_grad:
                        n_param = p.numel()  # number of parameters in [p]
                        p.grad.copy_(grad_proj[index:index + n_param].view_as(p))
                        index += n_param
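        # A-GEM projection in equations (Chaudhry et al., 2019): with g = grad_cur and
        # g_ref = grad_rep, the constraint is g . g_ref >= 0; when it is violated, the
        # update instead uses
        #     g_proj = g - ((g . g_ref) / (g_ref . g_ref)) * g_ref,
        # which removes the component of g that would increase the loss on replayed data
        # while leaving the orthogonal component untouched.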

        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item() if x is not None else 0,
            'loss_replay': loss_replay.item() if (x_ is not None and loss_replay is not None) else 0,
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(), 'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }

    def train_epoch(self, train_loader, criterion, optimizer, active_classes, params_dict, writer=None):
        # class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes
        self.train()
        tlosses = []
        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            x, y = x.to(self._device()), y.to(self._device())
            optimizer.zero_grad()
            y_hat = self(x)
            # y_hat = y_hat[:, class_entries]

            if params_dict['teacher_loss'] == 'BCE':
                y = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device)

            loss = criterion(y_hat, y)
            loss.backward()
            tlosses.append(loss.item())
            # writer.add_scalar('Training loss', loss.item(), params_dict['epoch'] * len(train_loader) + batch_idx)
            optimizer.step()
        return tlosses
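    # Minimal usage sketch (hypothetical driver code; loader, criterion and optimizer
    # must be supplied by the caller):
    #     for epoch in range(n_epochs):
    #         tl = model.train_epoch(train_loader, criterion, optimizer, active_classes, params_dict)
    #         vl = model.valid_epoch(val_loader, criterion, active_classes, params_dict)
    #         print(f"epoch {epoch}: train {sum(tl)/len(tl):.4f} | valid {sum(vl)/len(vl):.4f}")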

    def valid_epoch(self, val_loader, criterion, active_classes, params_dict, writer=None):
        # class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes
        valid_losses = []
        self.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader, 0):
                x, y = batch
                x, y = x.to(self._device()), y.to(self._device())
                y_hat = self(x)
                # y_hat = y_hat[:, class_entries]

                if params_dict['teacher_loss'] == 'BCE':
                    y = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device)

                valid_loss = criterion(y_hat, y)
                valid_losses.append(valid_loss.item())
                # writer.add_scalar('Validation loss', valid_loss.item(), params_dict['epoch'] * len(val_loader) + batch_idx)
        self.train()
        return valid_losses

    def train_via_KD(self, model, x, distill_type, active_classes):
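        '''Distill knowledge into this network from [model] on inputs [x].

        Presumably [self] is the consolidated (teacher) network and [model] the network
        trained on the current task: [model]'s logits serve as frozen targets and only
        [self]'s parameters are updated. With distill_type 'T' (teacher-to-student only),
        no teacher-side update is required, so the method returns immediately.'''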
        if distill_type == 'T':
            return

        model.eval()
        with torch.no_grad():
            y_hat = model(x)
        model.train()

        self.train()
        self.optimizer.zero_grad()
        y_hat_teacher = self(x)
        if active_classes is not None:
            class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes
            y_hat = y_hat[:, class_entries]
            y_hat_teacher = y_hat_teacher[:, class_entries]

        if distill_type in ['E', 'ET', 'ES', 'ETS']:
            # Ensemble target: average of the teacher's and the (frozen) student's logits
            with torch.no_grad():
                y_hat_ensemble = 0.5 * (y_hat_teacher.clone() + y_hat)
            if distill_type in ['ES', 'ETS']:  # distill from ensemble and student into the teacher
                # NOTE: reduction='batchmean' gives the mathematically correct KL;
                # the PyTorch default 'mean' also divides by the number of classes.
                loss = 0.5 * (
                    F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1),
                             F.softmax(y_hat_ensemble / self.KD_temp, dim=1),
                             reduction='batchmean') * (self.KD_temp ** 2)
                    + F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1),
                               F.softmax(y_hat / self.KD_temp, dim=1),
                               reduction='batchmean') * (self.KD_temp ** 2)
                )
            else:  # distill from the ensemble into the teacher
                loss = F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1),
                                F.softmax(y_hat_ensemble / self.KD_temp, dim=1),
                                reduction='batchmean') * (self.KD_temp ** 2)
        else:  # distill directly from the student into the teacher
            loss = F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1),
                            F.softmax(y_hat / self.KD_temp, dim=1),
                            reduction='batchmean') * (self.KD_temp ** 2)
        loss.backward()
        self.optimizer.step()
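    # End-to-end sketch of the reciprocal distillation loop (hypothetical driver code;
    # method and variable names outside this excerpt are assumptions):
    #     for task, loader in enumerate(task_loaders, start=1):
    #         for x, y in loader:
    #             stats = student.train_a_batch(x, y, ..., teacher=teacher, task=task)
    #             teacher.train_via_KD(student, x, params_dict['distill_type'], active_classes)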