class AutoEncoderLatent(Replayer):
    """Variational auto-encoder (VAE) with an additional classification head.

    The encoder (`fcE` + `toZ`) maps an input of size [latent_size] to the mean and
    log-variance of [z]; the decoder (`fromZ` + `fcD`) maps [z] back to a vector of
    size [latent_size]; the classifier maps the encoder's hidden features to class logits.
    """

    def __init__(self, latent_size, classes, fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=True, fc_nl="relu",
                 gated=False, z_dim=20):
        '''Build encoder, classifier and decoder.

        [latent_size]   <int> size of the input vectors (and of the reconstructions)
        [classes]       <int> number of output units of the classifier
        [fc_layers]     <int> number of fully-connected layers (must be at least 1)
        [fc_units]      <int> number of units in each hidden layer
        [fc_drop]       dropout setting passed on to the MLPs
        [fc_bn]         <bool> whether the MLPs use batch-norm
        [fc_nl]         <str> non-linearity passed on to the MLPs
        [gated]         passed on to the MLPs
        [z_dim]         <int> size of the stochastic latent variable [z]

        Raises ValueError if [fc_layers] < 1.'''

        # Set configurations
        super().__init__()
        self.latent_size = latent_size
        self.label = "VAE"
        self.classes = classes
        self.fc_layers = fc_layers
        self.z_dim = z_dim
        self.fc_units = fc_units

        # Weights of different components of the loss function
        self.lamda_rcl = 1.
        self.lamda_vl = 1.
        self.lamda_pl = 0.  #--> when used as "classifier with feedback-connections", this should be set to 1.

        self.average = True  #--> makes that [reconL] and [variatL] are both divided by number of input-elements

        # Check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("VAE cannot have 0 fully-connected layers!")

        ######------SPECIFY MODEL------######

        ##>----Encoder (= q[z|x])----<##
        # -fully connected hidden layers
        self.fcE = MLP(input_size=latent_size, output_size=fc_units, layers=fc_layers - 1, hid_size=fc_units,
                       drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, gated=gated)
        mlp_output_size = fc_units if fc_layers > 1 else latent_size
        # -to z
        self.toZ = fc_layer_split(mlp_output_size, z_dim, nl_mean='none', nl_logvar='none')

        ##>----Classifier----<##
        self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none')

        ##>----Decoder (= p[x|z])----<##
        # -from z (only followed by a non-linearity / batch-norm if more layers come after it)
        out_nl = fc_layers > 1
        self.fromZ = fc_layer(z_dim, mlp_output_size, batch_norm=(out_nl and fc_bn),
                              nl=fc_nl if out_nl else "none")
        # -fully connected hidden layers
        self.fcD = MLP(input_size=fc_units, output_size=latent_size, layers=fc_layers - 1, hid_size=fc_units,
                       drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, gated=gated, output='BCE')

    @property
    def name(self):
        '''Return a string summarizing the architecture of this model.'''
        fc_label = "{}--".format(self.fcE.name) if self.fc_layers > 1 else ""
        hid_label = "{}{}-".format("i", self.latent_size) if self.fc_layers == 1 else ""
        z_label = "z{}".format(self.z_dim)
        return "{}({}{}{}-c{})".format(self.label, fc_label, hid_label, z_label, self.classes)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        layers = []
        layers += self.fcE.list_init_layers()
        layers += self.toZ.list_init_layers()
        layers += self.classifier.list_init_layers()
        layers += self.fromZ.list_init_layers()
        layers += self.fcD.list_init_layers()
        return layers

    ##------ FORWARD FUNCTIONS --------##

    def encode(self, x):
        '''Pass input through feed-forward connections, to get [z_mean], [z_logvar] and [hE].'''
        # extract final hidden features (forward-pass)
        hE = self.fcE(x)
        # get parameters for reparametrization
        (z_mean, z_logvar) = self.toZ(hE)
        return z_mean, z_logvar, hE

    def classify(self, x):
        '''For input [x], return all predicted "scores"/"logits".'''
        hE = self.fcE(x)
        y_hat = self.classifier(hE)
        return y_hat

    def reparameterize(self, mu, logvar):
        '''Perform "reparametrization trick" to make these stochastic variables differentiable.

        Returns [mu] + eps * exp(0.5 * [logvar]), with eps drawn from a standard normal.'''
        std = logvar.mul(0.5).exp_()
        # noise with same shape, dtype and device as [std]
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        '''Pass latent variable [z] through the decoder and return the reconstruction.'''
        hD = self.fromZ(z)
        features = self.fcD(hD)
        return features

    def forward(self, x, full=False, reparameterize=True):
        '''Forward function to propagate [x] through the encoder, reparametrization and decoder.

        Input:  - [x]   <tensor> batch of inputs
                        NOTE(review): [x] is fed to the encoder-MLP without flattening, so it is presumably a
                        2D-tensor of shape [batch_size]x[latent_size] -- confirm against callers.
                - [reparameterize]  <bool>, if False, [mu] itself is decoded instead of a sample of [z]

        If [full] is True, output should be a <tuple> consisting of:
        - [x_recon]     <tensor> reconstruction of [x] (output of the decoder)
        - [y_hat]       <2D-tensor> with predicted logits for each class
        - [mu]          <2D-tensor> with the estimated mean of [z]
        - [logvar]      <2D-tensor> with the estimated log(SD^2) of [z]
        - [z]           <2D-tensor> (reparameterized) [z] used for reconstruction
        If [full] is False, output is simply the predicted logits (i.e., [y_hat]).'''
        if not full:
            # -> only forward pass for prediction
            return self.classify(x)
        # encode (forward), reparameterize and decode (backward)
        mu, logvar, hE = self.encode(x)
        z = self.reparameterize(mu, logvar) if reparameterize else mu
        x_recon = self.decode(z)
        # classify (re-using the hidden features of the encoder)
        y_hat = self.classifier(hE)
        return (x_recon, y_hat, mu, logvar, z)

    ##------ SAMPLE FUNCTIONS --------##

    def sample(self, size):
        '''Generate [size] samples from the model. Output is tensor (not "requiring grad"), on same device as <self>.

        The model is temporarily put in eval()-mode; its original mode is restored afterwards, also if
        decoding raises an exception.'''
        # remember current mode and set model to eval()-mode
        mode = self.training
        self.eval()
        try:
            # sample z
            z = torch.randn(size, self.z_dim).to(self._device())
            # decode z
            with torch.no_grad():
                X = self.decode(z)
        finally:
            # set model back to its initial mode
            self.train(mode=mode)
        # return the decoded samples (first dimension is [size])
        return X

    ##------ LOSS FUNCTIONS --------##

    def calculate_recon_loss(self, x, x_recon, average=False):
        '''Calculate reconstruction loss (binary cross-entropy) for each element in the batch.

        INPUT:  - [x]           <tensor> with original input (1st dimension (ie, dim=0) is "batch-dimension")
                - [x_recon]     <tensor> with reconstructed input, with same number of elements as [x]
                - [average]     <bool>, if True, loss is averaged over all input-elements; otherwise it is summed

        OUTPUT: - [reconL]      <1D-tensor> of length [batch_size]'''
        batch_size = x.size(0)
        reconL = F.binary_cross_entropy(input=x_recon.view(batch_size, -1), target=x.view(batch_size, -1),
                                        reduction='none')
        reconL = torch.mean(reconL, dim=1) if average else torch.sum(reconL, dim=1)
        return reconL

    def calculate_variat_loss(self, mu, logvar):
        '''Calculate variational (KL-divergence) loss for each element in the batch.

        INPUT:  - [mu]        <2D-tensor> by encoder predicted mean for [z]
                - [logvar]    <2D-tensor> by encoder predicted logvar for [z]

        OUTPUT: - [variatL]   <1D-tensor> of length [batch_size]'''
        # --> calculate analytically
        # ---- see Appendix B from: Kingma & Welling (2014) Auto-Encoding Variational Bayes, ICLR ----#
        variatL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        return variatL

    def loss_function(self, recon_x, x, y_hat=None, y_target=None, scores=None, mu=None, logvar=None):
        '''Calculate and return various losses that could be used for training and/or evaluating the model.

        INPUT:  - [recon_x]         <tensor> reconstruction, with same number of elements as [x]
                - [x]               <tensor> original input
                - [y_hat]           <2D-tensor> with predicted "logits" for each class
                - [y_target]        <1D-tensor> with target-classes (as integers)
                - [scores]          <2D-tensor> with target "logits" for each class
                - [mu]              <2D-tensor> with either [z] or the estimated mean of [z]
                - [logvar]          None or <2D-tensor> with estimated log(SD^2) of [z]

        SETTING:- [self.average]    <bool>, if True, [reconL] is averaged over the input-elements and [variatL] is
                                      divided by [self.latent_size]

        OUTPUT: - [reconL]       reconstruction loss indicating how well [x] and [x_recon] match
                - [variatL]      variational (KL-divergence) loss "indicating how normally distributed [z] is"
                - [predL]        prediction loss indicating how well targets [y] are predicted
                - [distilL]      knowledge distillation (KD) loss indicating how well the predicted "logits" ([y_hat])
                                     match the target "logits" ([scores])

        Losses whose inputs are not provided are returned as a zero-tensor on the model's device.'''
        device = self._device()

        ###-----Reconstruction loss-----###
        reconL = self.calculate_recon_loss(x=x, x_recon=recon_x, average=self.average)  #-> possibly average over elements
        reconL = torch.mean(reconL)                                                       #-> average over batch

        ###-----Variational loss-----###
        if logvar is not None:
            variatL = self.calculate_variat_loss(mu=mu, logvar=logvar)
            variatL = torch.mean(variatL)                     #-> average over batch
            if self.average:
                variatL = variatL / self.latent_size          #-> divide by # of input-elements, if [self.average]
        else:
            variatL = torch.tensor(0., device=device)

        ###-----Prediction loss-----###
        if y_target is not None:
            predL = F.cross_entropy(y_hat, y_target, reduction='mean')  #-> average over batch
        else:
            predL = torch.tensor(0., device=device)

        ###-----Distillation loss-----###
        if scores is not None:
            n_classes_to_consider = y_hat.size(1)  #--> zeroes will be added to [scores] to make its size match [y_hat]
            distilL = utils.loss_fn_kd(scores=y_hat[:, :n_classes_to_consider], target_scores=scores, T=self.KD_temp)
        else:
            distilL = torch.tensor(0., device=device)

        # Return a tuple of the calculated losses
        return reconL, variatL, predL, distilL

    ##------ TRAINING FUNCTIONS --------##

    def train_a_batch(self, x, y, x_=None, y_=None, scores_=None, rnt=0.5, active_classes=None, task=1, **kwargs):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            not used by this method
        [kwargs]          ignored

        Returns a <dict> with the total loss, the training-precision and the separate loss-terms for the
        current and the replayed ("_r") data, all as python numbers.
        Raises ValueError if both [x] and [x_] are None (there would be nothing to train on).'''

        if x is None and x_ is None:
            raise ValueError("train_a_batch() needs current data ([x]) and/or replayed data ([x_]).")

        # Set model to training-mode
        self.train()

        ##--(1)-- CURRENT DATA --##
        precision = 0.
        if x is not None:
            # Run the model
            recon_batch, y_hat, mu, logvar, z = self(x, full=True)
            # If needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in current task
            if active_classes is not None:
                current_classes = active_classes[-1] if isinstance(active_classes[0], list) else active_classes
                y_hat = y_hat[:, current_classes]
            # Calculate all losses
            reconL, variatL, predL, _ = self.loss_function(recon_x=recon_batch, x=x, y_hat=y_hat, y_target=y,
                                                           mu=mu, logvar=logvar)
            # Weigh losses as requested
            loss_cur = self.lamda_rcl * reconL + self.lamda_vl * variatL + self.lamda_pl * predL

            # Calculate training-precision
            if y is not None:
                _, predicted = y_hat.max(1)
                precision = (y == predicted).sum().item() / x.size(0)

        ##--(2)-- REPLAYED DATA --##
        if x_ is not None:
            x_is_list = isinstance(x_, list)
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = isinstance(y_, list) if (y_ is not None) else isinstance(scores_, list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None
                n_replays = len(x_) if x_is_list else 1
            else:
                n_replays = len(y_) if (y_ is not None) else (len(scores_) if (scores_ is not None) else 1)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            reconL_r = [None] * n_replays
            variatL_r = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model (if [x_] is not a list with separate replay per task)
            if not x_is_list:
                x_temp_ = x_
                recon_batch, y_hat_all, mu, logvar, z = self(x_temp_, full=True)

            # Loop to perform each replay
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if x_is_list:
                    x_temp_ = x_[replay_id]
                    recon_batch, y_hat_all, mu, logvar, z = self(x_temp_, full=True)

                # If needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                if active_classes is not None:
                    y_hat = y_hat_all[:, active_classes[replay_id]]
                else:
                    y_hat = y_hat_all

                # Calculate all losses
                reconL_r[replay_id], variatL_r[replay_id], predL_r[replay_id], distilL_r[replay_id] = \
                    self.loss_function(
                        recon_x=recon_batch, x=x_temp_, y_hat=y_hat,
                        y_target=y_[replay_id] if (y_ is not None) else None,
                        scores=scores_[replay_id] if (scores_ is not None) else None,
                        mu=mu, logvar=logvar,
                    )

                # Weigh losses as requested
                loss_replay[replay_id] = self.lamda_rcl * reconL_r[replay_id] + self.lamda_vl * variatL_r[replay_id]
                if self.replay_targets == "hard":
                    loss_replay[replay_id] += self.lamda_pl * predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] += self.lamda_pl * distilL_r[replay_id]

        # Calculate total loss (average over replays, then weigh current vs replayed data with [rnt])
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays
        if x is None:
            loss_total = loss_replay
        elif x_ is None:
            loss_total = loss_cur
        else:
            loss_total = rnt * loss_cur + (1 - rnt) * loss_replay

        # Reset optimizer
        self.optimizer.zero_grad()
        # Backpropagate errors
        loss_total.backward()
        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        has_current = x is not None
        has_replay = x_ is not None
        return {
            'loss_total': loss_total.item(),
            'precision': precision,
            'recon': reconL.item() if has_current else 0,
            'variat': variatL.item() if has_current else 0,
            'pred': predL.item() if has_current else 0,
            'recon_r': sum(reconL_r).item() / n_replays if has_replay else 0,
            'variat_r': sum(variatL_r).item() / n_replays if has_replay else 0,
            'pred_r': sum(predL_r).item() / n_replays if (has_replay and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (has_replay and distilL_r[0] is not None) else 0,
        }
class RootClassifier(ContinualLearner, Replayer, ExemplarHandler):
    '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.

    NOTE: this "root" model only consists of a flatten-operation followed by an MLP ([self.fcE]); it has no
    separate classification-layer, so [forward] returns the output of that MLP.'''

    # TODO: Do I need the `classes` argument?
    def __init__(self, image_size, image_channels, classes, fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=False,
                 fc_nl="relu", gated=False, bias=True, excitability=False, excit_buffer=False, binaryCE=False,
                 binaryCE_distill=False, AGEM=False, dataset="mnist"):
        '''Build the model.

        [image_size]        <int> side of the (square) input images, or -if [dataset] is "ckplus" or "affectnet"-
                              an indexable with (at least) the two spatial sizes of the input images
        [image_channels]    <int> number of channels of the input images
        [classes]           <int> number of classes (only stored and used for [name])
        [fc_layers]         <int> number of fully-connected layers (must be at least 1)
        [fc_units]          <int> number of units per layer of the MLP
        [binaryCE]          <bool> use binary (instead of multiclass) prediction error
        [binaryCE_distill]  <bool> use predicted probs of previous model as binary targets for old classes
        [AGEM]              <bool> use gradient of replayed data as inequality constraint (A-GEM)
        [dataset]           <str> name of the dataset, determines how [image_size] is interpreted
        (remaining arguments are passed on to the MLP)

        Raises ValueError if [fc_layers] < 1.'''

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE                  #-> use binary (instead of multiclass) prediction error
        self.binaryCE_distill = binaryCE_distill  #-> for classes from previous tasks, use the by the previous model
                                                  #   predicted probs as binary targets (only in Class-IL with binaryCE)
        self.AGEM = AGEM  #-> use gradient of replayed data as inequality constraint for (instead of adding it to)
                          #   the gradient of the current data (as in A-GEM, see Chaudry et al., 2019; ICLR)

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("The classifier needs to have at least 1 fully-connected layer.")

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers
        if dataset in ("ckplus", "affectnet"):
            # -for these datasets [image_size] holds both spatial sizes separately
            self.input_size = image_size[0] * image_size[1] * image_channels
        else:
            self.input_size = image_channels * image_size**2
        self.fcE = MLP(input_size=self.input_size, output_size=fc_units, layers=fc_layers - 1, hid_size=fc_units,
                       drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias, excitability=excitability,
                       excit_buffer=excit_buffer, gated=gated, latent_space=128)

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        layers = []
        layers += self.fcE.list_init_layers()
        # layers += self.classifier.list_init_layers()
        return layers

    @property
    def name(self):
        '''Return a string summarizing the architecture of this model.'''
        return "{}_c{}".format(self.fcE.name, self.classes)

    def forward(self, x):
        '''Flatten [x] and pass it through the MLP; return the resulting features.'''
        final_features = self.fcE(self.flatten(x))
        return final_features

    def feature_extractor(self, images):
        '''Flatten [images] and pass them through the MLP; return the resulting features.'''
        return self.fcE(self.flatten(images))

    def train_a_batch(self, x, y, scores=None, x_=None, y_=None, scores_=None, rnt=0.5, active_classes=None, task=1):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [scores]          None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x]
                            NOTE: only to be used for "BCE with distill" (only when scenario=="class")
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            <int>, for setting task-specific mask

        Returns a <dict> with the different training-losses and the training-precision, as python numbers.
        Raises ValueError if there is no data from which the loss to be optimized can be computed (i.e., if
        [x] is None while there is no replayed data or while A-GEM is used).'''

        # With A-GEM the replayed data only constrains the gradient, so current data is always required
        if x is None and (x_ is None or self.AGEM):
            raise ValueError("train_a_batch() needs current data ([x]), unless replayed data ([x_]) is "
                             "provided and A-GEM is not used.")

        # Set model to training-mode
        self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        # Should gradient be computed separately for each task? (needed when a task-mask is combined with replay)
        gradient_per_task = (self.mask_dict is not None) and (x_ is not None)

        ##--(1)-- REPLAYED DATA --##

        if x_ is not None:
            x_is_list = isinstance(x_, list)
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = isinstance(y_, list) if (y_ is not None) else isinstance(scores_, list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            separate_forward_passes = x_is_list or (self.mask_dict is not None)
            if not separate_forward_passes:
                y_hat_all = self(x_)

            # Loop to evaluate predictions on replay according to each previous task
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task (or a task-mask is used), evaluate the model
                #  separately for this task's replay
                if separate_forward_passes:
                    x_temp_ = x_[replay_id] if x_is_list else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self(x_temp_)

                # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                y_hat = y_hat_all if (active_classes is None) else y_hat_all[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    if self.binaryCE:
                        binary_targets_ = utils.to_one_hot(y_[replay_id].cpu(), y_hat.size(1)).to(y_[replay_id].device)
                        predL_r[replay_id] = F.binary_cross_entropy_with_logits(
                            input=y_hat, target=binary_targets_, reduction='none'
                        ).sum(dim=1).mean()  #--> sum over classes, then average over batch
                    else:
                        predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id], reduction='mean')
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]!
                    n_classes_to_consider = y_hat.size(1)    #--> zeros will be added to [scores] to make it this size!
                    kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd
                    distilL_r[replay_id] = kd_fn(scores=y_hat[:, :n_classes_to_consider],
                                                 target_scores=scores_[replay_id], T=self.KD_temp)
                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If needed, perform backward pass before next task-mask (gradients of all tasks will be accumulated)
                if gradient_per_task:
                    weight = 1 if self.AGEM else (1 - rnt)
                    weighted_replay_loss_this_task = weight * loss_replay[replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

        # Calculate total replay loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays

        # If using A-GEM, calculate and store averaged gradient of replayed data
        if self.AGEM and x_ is not None:
            # Perform backward pass to calculate gradient of replayed batch (if not yet done)
            if not gradient_per_task:
                loss_replay.backward()
            # Reorganize the gradient of the replayed batch as a single vector
            grad_rep = torch.cat([p.grad.view(-1) for p in self.parameters() if p.requires_grad])
            # Reset gradients (with A-GEM, gradients of replayed batch should only be used as inequality constraint)
            self.optimizer.zero_grad()

        ##--(2)-- CURRENT DATA --##

        if x is not None:
            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if isinstance(active_classes[0], list) else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            if self.binaryCE:
                # -binary prediction loss
                binary_targets = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device)
                if self.binaryCE_distill and (scores is not None):
                    # -for the classes of previous tasks, use the provided [scores] as (soft) binary targets
                    classes_per_task = int(y_hat.size(1) / task)
                    binary_targets = binary_targets[:, -(classes_per_task):]
                    binary_targets = torch.cat([torch.sigmoid(scores / self.KD_temp), binary_targets], dim=1)
                predL = None if y is None else F.binary_cross_entropy_with_logits(
                    input=y_hat, target=binary_targets, reduction='none'
                ).sum(dim=1).mean()  #--> sum over classes, then average over batch
            else:
                # -multiclass prediction loss
                predL = None if y is None else F.cross_entropy(input=y_hat, target=y, reduction='mean')

            # Weigh losses
            loss_cur = predL

            # Calculate training-precision
            precision = None if y is None else (y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If backward passes are performed per task (e.g., XdG combined with replay), perform backward pass
            if gradient_per_task:
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()
        else:
            # -> it's possible there is only "replay" [e.g., for offline with task-incremental learning]
            precision = predL = loss_cur = None

        # Combine loss from current and replayed batch
        if x_ is None or self.AGEM:
            loss_total = loss_cur
        else:
            loss_total = loss_replay if (x is None) else rnt * loss_cur + (1 - rnt) * loss_replay

        ##--(3)-- ALLOCATION LOSSES --##

        # NOTE: the penalties are added out-of-place on purpose. [loss_total] can be the very same tensor as
        #       [loss_cur] / [predL] / [loss_replay]; an in-place `+=` would then also change the values that
        #       are reported for those separate loss-terms below.
        # NOTE(review): if [gradient_per_task] is True the backward passes have already been performed, so these
        #       penalties are then only reported and do not contribute to the gradient -- confirm this is intended.

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            loss_total = loss_total + self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            loss_total = loss_total + self.ewc_lambda * ewc_loss

        # Backpropagate errors (if not yet done)
        if not gradient_per_task:
            loss_total.backward()

        # If using A-GEM, potentially change gradient:
        if self.AGEM and x_ is not None:
            # -reorganize gradient (of current batch) as single vector
            grad_cur = torch.cat([p.grad.view(-1) for p in self.parameters() if p.requires_grad])
            # -check inequality constraint
            angle = (grad_cur * grad_rep).sum()
            if angle < 0:
                # -if violated, project the gradient of the current batch onto the gradient of the replayed batch ...
                length_rep = (grad_rep * grad_rep).sum()
                grad_proj = grad_cur - (angle / length_rep) * grad_rep
                # -...and replace all the gradients within the model with this projected gradient
                index = 0
                for p in self.parameters():
                    if p.requires_grad:
                        n_param = p.numel()  # number of parameters in [p]
                        p.grad.copy_(grad_proj[index:index + n_param].view_as(p))
                        index += n_param

        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        has_replay = x_ is not None
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item() if x is not None else 0,
            'loss_replay': loss_replay.item() if (loss_replay is not None) and (x is not None) else 0,
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (has_replay and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (has_replay and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(),
            'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }
class Classifier(ContinualLearner, Replayer, ExemplarHandler): '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.''' def __init__(self, num_features, num_seq, classes, fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=True, fc_nl="relu", gated=False, bias=True, excitability=None, excit_buffer=False, binaryCE=False, binaryCE_distill=False, experiment='splitMNIST', cls_type='mlp', args=None): # configurations super().__init__() self.num_features = num_features self.num_seq = num_seq self.classes = classes self.label = "Classifier" self.fc_layers = fc_layers self.hidden_dim = fc_units self.layer_dim = fc_layers - 1 self.cuda = None if args is None else args.cuda self.device = args.device self.weights_per_class = None if args is None else torch.FloatTensor( args.weights_per_class).to(args.device) # store precision_dict into model so that we can fetch # self.precision_dict_list = [[] for i in range(len(args.num_classes_per_task_l))] # self.precision_dict = {} # settings for training self.binaryCE = binaryCE self.binaryCE_distill = binaryCE_distill # check whether there is at least 1 fc-layer if fc_layers < 1: raise ValueError( "The classifier needs to have at least 1 fully-connected layer." 
) ######------SPECIFY MODEL------###### self.cls_type = cls_type self.experiment = experiment # flatten image to 2D-tensor self.flatten = utils.Flatten() # fully connected hidden layers if experiment == 'sensor': if self.cls_type == 'mlp': self.fcE = MLP(input_size=num_seq * num_features, output_size=fc_units, layers=fc_layers - 1, hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias, excitability=excitability, excit_buffer=excit_buffer, gated=gated) elif self.cls_type == 'lstm': self.lstm_input_dropout = nn.Dropout(args.input_drop) self.lstm = nn.LSTM(input_size=num_features, hidden_size=fc_units, num_layers=fc_layers - 1, dropout=0.0 if (fc_layers - 1) == 1 else fc_drop, batch_first=True) # self.name = "LSTM([{} X {} X {}])".format(num_features, num_seq, classes) if self.fc_layers > 0 else "" else: self.fcE = MLP(input_size=num_seq * num_features**2, output_size=fc_units, layers=fc_layers - 1, hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias, excitability=excitability, excit_buffer=excit_buffer, gated=gated) # classifier if self.cls_type == 'mlp': mlp_output_size = fc_units if fc_layers > 1 else num_seq * num_features**2 self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none', drop=fc_drop) elif self.cls_type == 'lstm': self.lstm_fc = nn.Linear(fc_units, classes) ################# # +++++ GEM +++++ ##### if args.gem: print('this is test for GEM ') self.margin = args.memory_strength self.ce = nn.CrossEntropyLoss() self.n_outputs = classes self.n_memories = args.n_memories self.gpu = args.cuda n_tasks = len(args.num_classes_per_task_l) # allocate episodic memory self.memory_data = torch.FloatTensor(n_tasks, self.n_memories, self.num_seq, self.num_features) self.memory_labs = torch.LongTensor(n_tasks, self.n_memories) if args.cuda: # self.memory_data = self.memory_data.cuda() self.memory_data = self.memory_data.to(self.device) # self.memory_labs = self.memory_labs.cuda() self.memory_labs = 
self.memory_labs.to(self.device) # allocate temporary synaptic memory self.grad_dims = [] for param in self.parameters(): self.grad_dims.append(param.data.numel()) self.grads = torch.Tensor(sum(self.grad_dims), n_tasks) if args.cuda: # self.grads = self.grads.cuda() self.grads = self.grads.to(self.device) # allocate counters self.observed_tasks = [] self.old_task = -1 self.mem_cnt = 0 def list_init_layers(self): '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).''' list = [] list += self.fcE.list_init_layers() list += self.classifier.list_init_layers() return list @property def name(self): if self.cls_type == 'mlp': return "{}_c{}".format(self.fcE.name, self.classes) elif self.cls_type == 'lstm': return "LSTM([{} X {}]_c{})".format(self.num_seq, self.num_features, self.classes) def forward(self, x): if self.cls_type == 'mlp': final_features = self.fcE(self.flatten(x)) return self.classifier(final_features) elif self.cls_type == 'lstm': x = self.lstm_input_dropout(x) h0, c0 = self.init_hidden(x) out, (hn, cn) = self.lstm(x, (h0, c0)) return self.lstm_fc(out[:, -1, :]) # lstm_out, hidden = self.lstm(x) # print(lstm_out.size()) # print(x.size()) # print(lstm_out[-1].size()) # return self.lstm_fc(lstm_out[-1].view(x.size(0), -1)) def feature_extractor(self, x): if self.cls_type == 'mlp': return self.fcE(self.flatten(x)) elif self.cls_type == 'lstm': x = self.lstm_input_dropout(x) h0, c0 = self.init_hidden(x) out, (hn, cn) = self.lstm(x, (h0, c0)) return out[:, -1, :] # lstm_out, hidden = self.lstm(images) # return lstm_out[-1] def forward_from_hidden_layer(self, x): if self.cls_type == 'mlp': return self.classifier(x) elif self.cls_type == 'lstm': return self.lstm_fc(x) def init_hidden(self, x): h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim) c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim) return [t.to(self.device) for t in (h0, c0)] if self.cuda else (h0, c0) def train_a_batch(self, x, y, 
x_=None, y_=None, x_ex=None, y_ex=None, scores=None, scores_=None, rnt=0.5, active_classes=None, num_classes_per_task_l=None, task=1, args=None): '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]). [x] <tensor> batch of inputs (could be None, in which case only 'replayed' data is used) [y] <tensor> batch of corresponding labels [scores] None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x] NOTE: only to be used for "BCE with distill" (only when scenario=="class") [x_] None or (<list> of) <tensor> batch of replayed inputs [y_] None or (<list> of) <tensor> batch of corresponding "replayed" labels [x_ex] None or (<list> of) <tensor> batch of exemplars inputs [y_ex] None or (<list> of) <tensor> batch of exemplars inputs' labels [scores_] None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_] [rnt] <number> in [0,1], relative importance of new task [active_classes] None or (<list> of) <list> with "active" classes [task] <int>, for setting task-specific mask''' y.long() if y_ is not None: y_.long() if y_ex is not None: y_ex.long() # Set model to training-mode self.train() if args.gem: start_gem = time.time() t = task - 1 # update memory if t != self.old_task: self.observed_tasks.append(t) self.old_task = t # Update ring buffer storing examples from current task bsz = y.data.size(0) endcnt = min(self.mem_cnt + bsz, self.n_memories) effbsz = endcnt - self.mem_cnt self.memory_data[t, self.mem_cnt:endcnt].copy_(x.data[:effbsz]) if bsz == 1: self.memory_labs[t, self.mem_cnt] = y.data[0] else: self.memory_labs[t, self.mem_cnt:endcnt].copy_(y.data[:effbsz]) self.mem_cnt += effbsz if self.mem_cnt == self.n_memories: self.mem_cnt = 0 args.train_time_gem_update_memory += time.time() - start_gem start_gem = time.time() # compute gradient on previous tasks if len(self.observed_tasks) > 1: # print('self.observed_tasks: ', self.observed_tasks) for tt in 
range(len(self.observed_tasks) - 1): self.zero_grad() # fwd/bwd on the examples in the memory past_task = self.observed_tasks[tt] offset1, offset2 = my_compute_offsets( task=past_task, num_classes_per_task_l=num_classes_per_task_l) # ptloss = self.ce( # input=self.forward(self.memory_data[past_task])[:, offset1: offset2], # target=self.memory_labs[past_task] - offset1) ptloss = F.cross_entropy( input=self.forward( self.memory_data[past_task])[:, offset1:offset2], target=self.memory_labs[past_task] - offset1, weight=self.weights_per_class[offset1:offset2]) ptloss.backward() store_grad(self.parameters, self.grads, self.grad_dims, past_task) args.train_time_gem_compute_gradient += time.time() - start_gem # now compute the grad on the current minibatch self.zero_grad() # print(t, num_classes_per_task_l) offset1, offset2 = my_compute_offsets( task=t, num_classes_per_task_l=num_classes_per_task_l) # print(self.forward(x)[:, offset1: offset2].size()) # print(y.size()) # loss = self.ce( # input=self.forward(x)[:, offset1: offset2], # target=y - offset1) loss = F.cross_entropy( input=self.forward(x)[:, offset1:offset2], target=y - offset1, weight=self.weights_per_class[offset1:offset2]) loss.backward() # check if gradient violates constraints start_gem = time.time() if len(self.observed_tasks) > 1: # copy gradient store_grad(self.parameters, self.grads, self.grad_dims, t) # indx = torch.cuda.LongTensor(self.observed_tasks[:-1]) if self.gpu \ # else torch.LongTensor(self.observed_tasks[:-1]) indx = torch.LongTensor(self.observed_tasks[:-1]).to(self.device) if self.gpu \ else torch.LongTensor(self.observed_tasks[:-1]) # print(indx) dotp = torch.mm(self.grads[:, t].unsqueeze(0), self.grads.index_select(1, indx)) if (dotp < 0).sum() != 0: project2cone2(self.grads[:, t].unsqueeze(1), self.grads.index_select(1, indx), self.margin) # copy gradients back overwrite_grad(self.parameters, self.grads[:, t], self.grad_dims) args.train_time_gem_violation_check += time.time() - start_gem 
self.optimizer.step() return { 'loss_total': loss.item(), 'loss_current': 0, 'loss_replay': 0, 'pred': 0, 'pred_r': 0, 'distil_r': 0, 'ewc': 0., 'si_loss': 0., 'precision': 0., } else: # Reset optimizer self.optimizer.zero_grad() ##--(1)-- CURRENT DATA --## if x is not None: # If requested, apply correct task-specific mask if self.mask_dict is not None: self.apply_XdGmask(task=task) # Run model if args.augment < 0: y_hat = self(x) else: features = self.feature_extractor(x) features_ex = self.feature_extractor(x_ex) # y_hat = self.forward_from_hidden_layer(features) # y_hat_ex = self.forward_from_hidden_layer(features_ex) ##### ##### perform augmentation on feature space ##### ##### if args.augment == 0: # random augmentation features = torch.cat([features, features_ex]) y = torch.cat([y, y_ex]) # now let's add some noise! for i in range(args.scaling - 1): features = torch.cat([ features, features_ex + torch.randn(features_ex.shape).to(self.device) * args.sd ]) y = torch.cat([y, y_ex]) if features.shape[0] > args.batch: features = features[:args.batch, :] y = y[:args.batch] break y_hat = self.forward_from_hidden_layer(features) elif args.augment == 1: # feature augmentation based on standard deviation pass elif args.augment == 2: # feature augmentation based on SMOTE method pass # -if needed, remove predictions for classes not in current task if active_classes is not None: class_entries = active_classes[-1] if type( active_classes[0]) == list else active_classes y_hat = y_hat[:, class_entries] # if args.augment >= 0: # y_hat_ex = y_hat_ex[:, class_entries] # Calculate prediction loss if self.binaryCE: # ICARL goes into this. binaryCE is True. # -binary prediction loss # print('x.shape: ', x.shape) # print('y.shape: ', y.shape) # print('y_hat.shape: ', y_hat.shape) # print('y: ', y) start_icarl = time.time() binary_targets = utils.to_one_hot( y.cpu(), y_hat.size(1)).to(y.device) # binary_targets = [0, 0, 1, ... 
, 0] <=> class = 2 # print(binary_targets.size()) # [128 x 17] if self.binaryCE_distill and ( scores is not None ): # ICARL does not go into this cuz scores is None if args.experiment == 'sensor': binary_targets = binary_targets[:, sum(num_classes_per_task_l[:( task - 1)]):] else: classes_per_task = int(y_hat.size(1) / task) binary_targets = binary_targets[:, -( classes_per_task):] # print(classes_per_task) # 8 # print(binary_targets.size()) # [128 x 1] binary_targets = torch.cat([ torch.sigmoid(scores / self.KD_temp), binary_targets ], dim=1) # print(binary_targets.size()) # [128 x 17] => this supposed to be [128 x 17 (16 + 1)] # print(scores.size()) # [128 x 16] # print(self.KD_temp) # 2 predL = None if y is None else F.binary_cross_entropy_with_logits( input=y_hat, target=binary_targets, reduction='none').sum(dim=1).mean( ) #--> sum over classes, then average over batch args.train_time_icarl_loss += time.time() - start_icarl else: # -multiclass prediction loss # print("x", x.shape, x) # print("y_hat", y_hat.shape, y_hat) # print("y", y.shape, y) predL = None if y is None else F.cross_entropy( input=y_hat, target=y, weight=self.weights_per_class[class_entries], reduction='elementwise_mean') # Weigh losses loss_cur = predL # Calculate training-precision precision = None if y is None else ( y == y_hat.max(1)[1]).sum().item() / x.size(0) # If XdG is combined with replay, backward-pass needs to be done before new task-mask is applied if (self.mask_dict is not None) and (x_ is not None): weighted_current_loss = rnt * loss_cur weighted_current_loss.backward() else: precision = predL = None # -> it's possible there is only "replay" [i.e., for offline with incremental task learning] ##--(2)-- REPLAYED DATA --## if x_ is not None: # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them # (in case of 'exact' or 'exemplar' replay, [x_] is also a list! 
start_lwf = time.time() TaskIL = (type(y_) == list) if (y_ is not None) else ( type(scores_) == list) if not TaskIL: y_ = [y_] scores_ = [scores_] active_classes = [active_classes] if ( active_classes is not None) else None n_replays = len(y_) if (y_ is not None) else len(scores_) # Prepare lists to store losses for each replay loss_replay = [None] * n_replays predL_r = [None] * n_replays distilL_r = [None] * n_replays # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask) if (not type(x_) == list) and (self.mask_dict is None): y_hat_all = self(x_) # Loop to evalute predictions on replay according to each previous task for replay_id in range(n_replays): # -if [x_] is a list with separate replay per task, evaluate model on this task's replay if (type(x_) == list) or (self.mask_dict is not None): x_temp_ = x_[replay_id] if type(x_) == list else x_ if self.mask_dict is not None: self.apply_XdGmask(task=replay_id + 1) y_hat_all = self(x_temp_) # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task y_hat = y_hat_all if ( active_classes is None ) else y_hat_all[:, active_classes[replay_id]] # Calculate losses if (y_ is not None) and (y_[replay_id] is not None): if self.binaryCE: binary_targets_ = utils.to_one_hot( y_[replay_id].cpu(), y_hat.size(1)).to(y_[replay_id].device) predL_r[ replay_id] = F.binary_cross_entropy_with_logits( input=y_hat, target=binary_targets_, reduction='none' ).sum(dim=1).mean( ) #--> sum over classes, then average over batch else: predL_r[replay_id] = F.cross_entropy( input=y_hat, target=y_[replay_id], weight=self.weights_per_class[ active_classes[replay_id]], reduction='elementwise_mean') if (scores_ is not None) and (scores_[replay_id] is not None): # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]! n_classes_to_consider = y_hat.size( 1 ) #--> zeros will be added to [scores] to make it this size! 
kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd distilL_r[replay_id] = kd_fn( scores=y_hat[:, :n_classes_to_consider], target_scores=scores_[replay_id], T=self.KD_temp) # Weigh losses if self.replay_targets == "hard": loss_replay[replay_id] = predL_r[replay_id] elif self.replay_targets == "soft": loss_replay[replay_id] = distilL_r[replay_id] # If task-specific mask, backward pass needs to be performed before next task-mask is applied if self.mask_dict is not None: weighted_replay_loss_this_task = ( 1 - rnt) * loss_replay[replay_id] / n_replays weighted_replay_loss_this_task.backward() args.train_time_lwf_loss += time.time() - start_lwf # Calculate total loss with replay loss if it exists. if x_ is None: loss_replay = None else: start_lwf = time.time() loss_replay = sum(loss_replay) / n_replays args.train_time_lwf_loss += time.time() - start_lwf if x is None: start_lwf = time.time() loss_total = loss_replay args.train_time_lwf_loss += time.time() - start_lwf else: if x_ is None: loss_total = loss_cur else: start_lwf = time.time() loss_total = rnt * loss_cur + (1 - rnt) * loss_replay args.train_time_lwf_loss += time.time() - start_lwf # loss_replay = None if (x_ is None) else sum(loss_replay)/n_replays # loss_total = loss_replay if (x is None) else (loss_cur if x_ is None else rnt*loss_cur+(1-rnt)*loss_replay) ##--(3)-- ALLOCATION LOSSES --## # Add SI-loss (Zenke et al., 2017) if self.si_c > 0: start_si = time.time() surrogate_loss = self.surrogate_loss() loss_total += self.si_c * surrogate_loss args.train_time_si_loss += time.time() - start_si # Add EWC-loss if self.ewc_lambda > 0: start_ewc = time.time() ewc_loss = self.ewc_loss() loss_total += self.ewc_lambda * ewc_loss args.train_time_ewc_loss += time.time() - start_ewc # Backpropagate errors (if not yet done) if (self.mask_dict is None) or (x_ is None): loss_total.backward() # Take optimization-step self.optimizer.step() # Return the dictionary with different training-loss split in 
categories return { 'loss_total': loss_total.item(), 'loss_current': loss_cur.item() if x is not None else 0, 'loss_replay': loss_replay.item() if (loss_replay is not None) and (x is not None) else 0, 'pred': predL.item() if predL is not None else 0, 'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0, 'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0, 'ewc': ewc_loss.item() if self.ewc_lambda > 0 else 0.0, 'si_loss': surrogate_loss.item() if self.si_c > 0 else 0.0, 'precision': precision if precision is not None else 0., }
class Classifier(ContinualLearner, Replayer, ExemplarHandler):
    '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.

    In this variant [forward] routes inputs through [self.vgg]; the MLP encoder [self.fcE] is still built and
    is what [feature_extractor] uses.'''

    def __init__(self, image_size, image_channels, classes,
                 fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=True, fc_nl="relu", gated=False,
                 bias=True, excitability=False, excit_buffer=False, binaryCE=False, binaryCE_distill=False):
        '''Set up the configuration and build the sub-networks.

        [image_size]        <int> side length of the (square) input images
        [image_channels]    <int> number of channels of the input images
        [classes]           <int> number of output classes (passed on to [vgg16])
        [fc_layers]         <int> number of fully-connected layers (must be at least 1)
        [fc_units] ... [gated], [bias], [excitability], [excit_buffer]
                            options forwarded unchanged to the [MLP] encoder
        [binaryCE]          <bool> use binary (instead of multiclass) prediction loss
        [binaryCE_distill]  <bool> with [binaryCE], use provided [scores] as targets for the leading classes

        Raises ValueError if [fc_layers] < 1.'''

        # configurations
        super().__init__()
        self.classes = classes
        self.label = "Classifier"
        self.fc_layers = fc_layers

        # settings for training
        self.binaryCE = binaryCE
        self.binaryCE_distill = binaryCE_distill

        # check whether there is at least 1 fc-layer
        if fc_layers < 1:
            raise ValueError("The classifier needs to have at least 1 fully-connected layer.")

        ######------SPECIFY MODEL------######

        # flatten image to 2D-tensor
        self.flatten = utils.Flatten()

        # fully connected hidden layers (used by [feature_extractor] only)
        self.fcE = MLP(input_size=image_channels * image_size**2, output_size=fc_units, layers=fc_layers - 1,
                       hid_size=fc_units, drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias,
                       excitability=excitability, excit_buffer=excit_buffer, gated=gated)

        print('*************num of classes in encoder: ' + str(classes))
        self.vgg = vgg16(classes)

        # NOTE: the separate fc-classifier head is disabled in this variant; it used to be
        #   fc_layer(<fcE output size>, classes, excit_buffer=True, nl='none', drop=fc_drop)
        # so there is no [self.classifier] attribute unless it is attached externally.

    def list_init_layers(self):
        '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).'''
        layers = []
        layers += self.fcE.list_init_layers()
        # The classifier head is not created in [__init__]; only include it when one is present,
        # instead of failing with an AttributeError.
        classifier = getattr(self, "classifier", None)
        if classifier is not None:
            layers += classifier.list_init_layers()
        return layers

    @property
    def name(self):
        '''Fixed identifier of this model variant.'''
        return "vgg"

    def forward(self, x):
        '''Return the output of [self.vgg] for input [x] (used as class scores by [train_a_batch]).'''
        return self.vgg(x)

    def feature_extractor(self, images):
        '''Return the MLP-encoder features of the flattened [images].

        NOTE(review): this goes through [self.fcE], not through [self.vgg] as [forward] does -- confirm
        that this is intended.'''
        return self.fcE(self.flatten(images))

    def train_a_batch(self, x, y, scores=None, x_=None, y_=None, scores_=None, rnt=0.5,
                      active_classes=None, task=1):
        '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]).

        [x]               <tensor> batch of inputs (could be None, in which case only 'replayed' data is used)
        [y]               <tensor> batch of corresponding labels
        [scores]          None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x]
        [x_]              None or (<list> of) <tensor> batch of replayed inputs
        [y_]              None or (<list> of) <tensor> batch of corresponding "replayed" labels
        [scores_]         None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_]
        [rnt]             <number> in [0,1], relative importance of new task
        [active_classes]  None or (<list> of) <list> with "active" classes
        [task]            <int>, for setting task-specific mask

        Returns a <dict> with the scalar entries 'loss_total', 'loss_current', 'loss_replay', 'pred',
        'pred_r', 'distil_r', 'ewc', 'si_loss' and 'precision'.'''

        # Set model to training-mode
        self.train()

        # Reset optimizer
        self.optimizer.zero_grad()

        ##--(1)-- CURRENT DATA --##

        if x is not None:

            # If requested, apply correct task-specific mask
            if self.mask_dict is not None:
                self.apply_XdGmask(task=task)

            # Run model
            y_hat = self(x)
            # -if needed, remove predictions for classes not in current task
            if active_classes is not None:
                class_entries = active_classes[-1] if isinstance(active_classes[0], list) else active_classes
                y_hat = y_hat[:, class_entries]

            # Calculate prediction loss
            if self.binaryCE:
                # -binary prediction loss
                binary_targets = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device)
                if self.binaryCE_distill and (scores is not None):
                    # keep the one-hot targets only for the last [classes_per_task] columns and use the
                    # (temperature-scaled) sigmoid of [scores] as targets for the columns before them
                    classes_per_task = int(y_hat.size(1) / task)
                    binary_targets = binary_targets[:, -(classes_per_task):]
                    binary_targets = torch.cat([torch.sigmoid(scores / self.KD_temp), binary_targets], dim=1)
                predL = None if y is None else F.binary_cross_entropy_with_logits(
                    input=y_hat, target=binary_targets, reduction='none'
                ).sum(dim=1).mean()  #--> sum over classes, then average over batch
            else:
                # -multiclass prediction loss
                # ('mean' is the current spelling of the deprecated 'elementwise_mean')
                predL = None if y is None else F.cross_entropy(input=y_hat, target=y, reduction='mean')

            # Weigh losses
            loss_cur = predL

            # Calculate training-precision
            precision = None if y is None else (y == y_hat.max(1)[1]).sum().item() / x.size(0)

            # If XdG is combined with replay, backward-pass needs to be done before new task-mask is applied
            if (self.mask_dict is not None) and (x_ is not None):
                weighted_current_loss = rnt * loss_cur
                weighted_current_loss.backward()
        else:
            # -> it's possible there is only "replay" [i.e., for offline with incremental task learning]
            precision = predL = None

        ##--(2)-- REPLAYED DATA --##

        if x_ is not None:
            # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them
            # (in case of 'exact' or 'exemplar' replay, [x_] is also a list!)
            TaskIL = isinstance(y_, list) if (y_ is not None) else isinstance(scores_, list)
            if not TaskIL:
                y_ = [y_]
                scores_ = [scores_]
                active_classes = [active_classes] if (active_classes is not None) else None
            n_replays = len(y_) if (y_ is not None) else len(scores_)

            # Prepare lists to store losses for each replay
            loss_replay = [None] * n_replays
            predL_r = [None] * n_replays
            distilL_r = [None] * n_replays

            x_is_list = isinstance(x_, list)

            # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask)
            if (not x_is_list) and (self.mask_dict is None):
                y_hat_all = self(x_)

            # Loop to evaluate predictions on replay according to each previous task
            for replay_id in range(n_replays):

                # -if [x_] is a list with separate replay per task, evaluate model on this task's replay
                if x_is_list or (self.mask_dict is not None):
                    x_temp_ = x_[replay_id] if x_is_list else x_
                    if self.mask_dict is not None:
                        self.apply_XdGmask(task=replay_id + 1)
                    y_hat_all = self(x_temp_)

                # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task
                y_hat = y_hat_all if (active_classes is None) else y_hat_all[:, active_classes[replay_id]]

                # Calculate losses
                if (y_ is not None) and (y_[replay_id] is not None):
                    if self.binaryCE:
                        binary_targets_ = utils.to_one_hot(
                            y_[replay_id].cpu(), y_hat.size(1)).to(y_[replay_id].device)
                        predL_r[replay_id] = F.binary_cross_entropy_with_logits(
                            input=y_hat, target=binary_targets_, reduction='none'
                        ).sum(dim=1).mean()  #--> sum over classes, then average over batch
                    else:
                        predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id], reduction='mean')
                if (scores_ is not None) and (scores_[replay_id] is not None):
                    # n_classes_to_consider = scores.size(1)  #--> with this version, no zeroes are added to [scores]!
                    n_classes_to_consider = y_hat.size(1)  #--> zeros will be added to [scores] to make it this size!
                    kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd
                    distilL_r[replay_id] = kd_fn(scores=y_hat[:, :n_classes_to_consider],
                                                 target_scores=scores_[replay_id], T=self.KD_temp)

                # Weigh losses
                if self.replay_targets == "hard":
                    loss_replay[replay_id] = predL_r[replay_id]
                elif self.replay_targets == "soft":
                    loss_replay[replay_id] = distilL_r[replay_id]

                # If task-specific mask, backward pass needs to be performed before next task-mask is applied
                if self.mask_dict is not None:
                    weighted_replay_loss_this_task = (1 - rnt) * loss_replay[replay_id] / n_replays
                    weighted_replay_loss_this_task.backward()

        # Calculate total loss
        loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays
        loss_total = loss_replay if (x is None) else (
            loss_cur if x_ is None else rnt * loss_cur + (1 - rnt) * loss_replay)

        ##--(3)-- ALLOCATION LOSSES --##

        # NOTE: the penalties are added out-of-place on purpose. Without replay, [loss_total] is the very
        # same tensor as [loss_cur]/[predL]; an in-place "+=" would therefore also change the values that
        # are reported below as 'loss_current' and 'pred'.

        # Add SI-loss (Zenke et al., 2017)
        surrogate_loss = self.surrogate_loss()
        if self.si_c > 0:
            loss_total = loss_total + self.si_c * surrogate_loss

        # Add EWC-loss
        ewc_loss = self.ewc_loss()
        if self.ewc_lambda > 0:
            loss_total = loss_total + self.ewc_lambda * ewc_loss

        # Backpropagate errors (if not yet done)
        if (self.mask_dict is None) or (x_ is None):
            loss_total.backward()
        # Take optimization-step
        self.optimizer.step()

        # Return the dictionary with different training-loss split in categories
        return {
            'loss_total': loss_total.item(),
            'loss_current': loss_cur.item() if x is not None else 0,
            'loss_replay': loss_replay.item() if (loss_replay is not None) and (x is not None) else 0,
            'pred': predL.item() if predL is not None else 0,
            'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0,
            'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0,
            'ewc': ewc_loss.item(),
            'si_loss': surrogate_loss.item(),
            'precision': precision if precision is not None else 0.,
        }
class Classifier(ContinualLearner, Replayer, ExemplarHandler): '''Model for classifying images, "enriched" as "ContinualLearner"-, Replayer- and ExemplarHandler-object.''' def __init__(self, image_size, image_channels, classes, fc_layers=3, fc_units=1000, fc_drop=0, fc_bn=False, fc_nl="relu", gated=False, bias=True, excitability=False, excit_buffer=False, binaryCE=False, binaryCE_distill=False, AGEM=False, experiment='splitMNIST'): # configurations super().__init__() self.classes = classes self.label = "Classifier" self.fc_layers = fc_layers # settings for training self.binaryCE = binaryCE # -> use binary (instead of multiclass) prediction error self.binaryCE_distill = binaryCE_distill # -> for classes from previous tasks, use the by the previous model # predicted probs as binary targets (only in Class-IL with binaryCE) self.AGEM = AGEM # -> use gradient of replayed data as inequality constraint for (instead of adding it to) # the gradient of the current data (as in A-GEM, see Chaudry et al., 2019; ICLR) # Online mem distillation self.is_offline_training = False self.is_ready_distill = False self.alpha_t = 0.5 # check whether there is at least 1 fc-layer if fc_layers < 1: raise ValueError("The classifier needs to have at least 1 fully-connected layer.") ######------SPECIFY MODEL------###### self.experiment = experiment if self.experiment in ['CIFAR10', 'CIFAR100', 'CUB2011']: self.fcE = rn.resnet32(classes, pretrained=False) self.fcE.linear = nn.Identity() self.classifier = fc_layer(64, classes, excit_buffer=True, nl='none', drop=fc_drop) elif self.experiment == 'ImageNet': ResNet.name = 'ResNet-18' self.fcE = resnet18(pretrained=True) self.fcE.fc = nn.Identity() self.classifier = fc_layer(512, classes, excit_buffer=True, nl='none', drop=fc_drop) else: # flatten image to 2D-tensor self.flatten = utils.Flatten() # fully connected hidden layers self.fcE = MLP(input_size=image_channels * image_size ** 2, output_size=fc_units, layers=fc_layers - 1, hid_size=fc_units, 
drop=fc_drop, batch_norm=fc_bn, nl=fc_nl, bias=bias, excitability=excitability, excit_buffer=excit_buffer, gated=gated) mlp_output_size = fc_units if fc_layers > 1 else image_channels * image_size ** 2 # classifier self.classifier = fc_layer(mlp_output_size, classes, excit_buffer=True, nl='none', drop=fc_drop) def list_init_layers(self): '''Return list of modules whose parameters could be initialized differently (i.e., conv- or fc-layers).''' list = [] list += self.fcE.list_init_layers() list += self.classifier.list_init_layers() return list @property def name(self): return "{}_c{}".format(self.fcE.name, self.classes) def forward(self, x): final_features = self.feature_extractor(x) return self.classifier(final_features) def feature_extractor(self, images): if self.experiment not in ['splitMNIST', 'permMNIST', 'rotMNIST']: return self.fcE(images) else: return self.fcE(self.flatten(images)) def select_triplets(self, embeds, y_score, x, y, triplet_selection, task, scenario, use_embeddings, multi_negative): uq = torch.unique(y).cpu().numpy() selection_strategies = triplet_selection.split('-') # Select instances in the batch for replay later for m in uq: neg_y = np.delete(uq, np.where(uq == m)) mask = y == m mask_neg = y != m ce_m = y_score[mask] if ce_m.size(0) != 0: # Select anchor and hard positive instances for class m positive_batch = x[mask] positive_embed_batch = embeds[mask] anchor_idx = torch.argmin(ce_m) anchor_x = positive_batch[anchor_idx].unsqueeze(dim=0) anchor_embed = positive_embed_batch[anchor_idx].unsqueeze(dim=0) # anchor should not equal positive positive_batch = torch.cat( (positive_batch[:anchor_idx], positive_batch[anchor_idx + 1:]), dim=0) positive_embed_batch = torch.cat( (positive_embed_batch[:anchor_idx], positive_embed_batch[anchor_idx + 1:]), dim=0) if positive_batch.size(0) != 0: if use_embeddings: anchor_batch = anchor_embed.expand(positive_embed_batch.size()) positive_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1), 
positive_embed_batch.view(positive_embed_batch.size(0), -1)) else: anchor_batch = anchor_x.expand(positive_batch.size()) positive_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1), positive_batch.view(positive_batch.size(0), -1)) if selection_strategies[0] == 'HP': # Hard positive _, positive_idx = torch.topk(positive_dist, 1) else: # Easy positive _, positive_idx = torch.topk(positive_dist, 1, largest=False) positive_x = positive_batch[positive_idx] x_m = torch.cat((anchor_x, positive_x), dim=0) y_m = torch.tensor([m, m]) else: x_m = anchor_x y_m = torch.tensor([m]) if scenario in ['task', 'domain']: self.add_instances_to_online_exemplar_sets(x_m, y_m, (y_m + len(uq) * (task - 1)).detach().cpu().numpy()) else: self.add_instances_to_online_exemplar_sets(x_m, y_m, y_m.detach().cpu().numpy()) negative_batch = x[mask_neg] negative_batch_y = y[mask_neg] negative_embed_batch = embeds[mask_neg] if negative_batch.size(0) != 0: if use_embeddings: anchor_batch = anchor_embed.expand(negative_embed_batch.size()) negative_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1), negative_embed_batch.view(negative_embed_batch.size(0), -1)) else: anchor_batch = anchor_x.expand(negative_batch.size()) negative_dist = F.pairwise_distance(anchor_batch.view(anchor_batch.size(0), -1), negative_batch.view(negative_batch.size(0), -1)) # Select instances for each negative class if multi_negative: for n in neg_y: mask_neg_n = negative_batch_y == n negative_dist_n = negative_dist[mask_neg_n] negative_batch_n = negative_batch[mask_neg_n] negative_batch_y_n = negative_batch_y[mask_neg_n] if selection_strategies[1] == 'HN': # Hard negative _, negative_idx = torch.topk(negative_dist_n, int(selection_strategies[2]), largest=False) negative_x = negative_batch_n[negative_idx] negative_y = negative_batch_y_n[negative_idx] elif selection_strategies[1] == 'SHN': # Semi-hard negative if use_embeddings: positive_embed = 
positive_embed_batch[positive_idx].unsqueeze(dim=0) dap = F.pairwise_distance(anchor_embed.view(anchor_x.size(0), -1), positive_embed.view(positive_x.size(0), -1)) else: dap = F.pairwise_distance(anchor_x.view(anchor_x.size(0), -1), positive_x.view(positive_x.size(0), -1)) valid_shn_idx = negative_dist_n > dap if valid_shn_idx.any(): shn_batch = negative_batch_n[valid_shn_idx] shn_y = negative_batch_y_n[valid_shn_idx] # negative_idx = torch.argmin(negative_dist[valid_shn_idx]) _, negative_idx = torch.topk(negative_dist_n, int(selection_strategies[2]), largest=False) negative_x = shn_batch[negative_idx] negative_y = shn_y[negative_idx] else: # There is no semi-hard negative sample, ignore negative sample negative_x = None negative_y = None else: # Easy negative _, negative_idx = torch.topk(negative_dist_n, int(selection_strategies[2])) negative_x = negative_batch_n[negative_idx] negative_y = negative_batch_y_n[negative_idx] if negative_x is not None and negative_y is not None: if scenario in ['task', 'domain']: self.add_instances_to_online_exemplar_sets(negative_x, negative_y, (negative_y + len(uq) * ( task - 1)).detach().cpu().numpy()) else: self.add_instances_to_online_exemplar_sets(negative_x, negative_y, negative_y.detach().cpu().numpy()) else: if selection_strategies[1] == 'HN': # Hard negative _, negative_idx = torch.topk(negative_dist, int(selection_strategies[2]), largest=False) negative_x = negative_batch[negative_idx] negative_y = negative_batch_y[negative_idx] elif selection_strategies[1] == 'SHN': # Semi-hard negative if use_embeddings: positive_embed = positive_embed_batch[positive_idx].unsqueeze(dim=0) dap = F.pairwise_distance(anchor_embed.view(anchor_x.size(0), -1), positive_embed.view(positive_x.size(0), -1)) else: dap = F.pairwise_distance(anchor_x.view(anchor_x.size(0), -1), positive_x.view(positive_x.size(0), -1)) valid_shn_idx = negative_dist > dap if valid_shn_idx.any(): shn_batch = negative_batch[valid_shn_idx] shn_y = 
negative_batch_y[valid_shn_idx] # negative_idx = torch.argmin(negative_dist[valid_shn_idx]) _, negative_idx = torch.topk(negative_dist[valid_shn_idx], int(selection_strategies[2]), largest=False) negative_x = shn_batch[negative_idx] negative_y = shn_y[negative_idx] else: # There is no semi-hard negative sample, ignore negative sample negative_x = None negative_y = None else: # Easy negative _, negative_idx = torch.topk(negative_dist, int(selection_strategies[2])) negative_x = negative_batch[negative_idx] negative_y = negative_batch_y[negative_idx] if negative_x is not None and negative_y is not None: if scenario in ['task', 'domain']: self.add_instances_to_online_exemplar_sets(negative_x, negative_y, (negative_y + len(uq) * ( task - 1)).detach().cpu().numpy()) else: self.add_instances_to_online_exemplar_sets(negative_x, negative_y, negative_y.detach().cpu().numpy()) def select_instances(self, embeds, x, y, scenario, task): uq, _ = torch.sort(torch.unique(y)) uq = uq.cpu().numpy() exemplars_per_class = int(np.floor(self.memory_budget / (len(uq) * task))) exemplar_set = [] if self.herding: # Accumulate class means for m in uq: mask = y == m xm = x[mask] embedsm = embeds[mask] if self.norm_exemplars: features = F.normalize(embedsm, p=2, dim=1) # calculate mean of all features class_mean = torch.mean(features, dim=0, keepdim=True) # if self.norm_exemplars: # class_mean = F.normalize(class_mean, p=2, dim=1) # one by one, select exemplar that makes mean of all exemplars as close to [class_mean] as possible exemplar_features = torch.zeros_like(features[:min(exemplars_per_class, embedsm.size(0))]) list_of_selected = [] for k in range(min(exemplars_per_class, embedsm.size(0))): if k > 0: exemplar_sum = torch.sum(exemplar_features[:k], dim=0).unsqueeze(0) features_means = (features + exemplar_sum) / (k + 1) features_dists = features_means - class_mean else: features_dists = features - class_mean index_selected = np.argmin(torch.norm(features_dists, p=2, 
dim=1).detach().cpu().numpy()) if index_selected in list_of_selected: raise ValueError("Exemplars should not be repeated!!!!") list_of_selected.append(index_selected) exemplar_set.append(xm[index_selected].detach().cpu().numpy()) exemplar_features[k] = features[index_selected].clone() # make sure this example won't be selected again features[index_selected] = features[index_selected] + 10000 if scenario in ['task', 'domain']: if len(self.exemplar_sets) == ((task - 1) * len(uq) + m % len(uq)): self.exemplar_means.append(class_mean) self.exemplar_sets.append(np.array(exemplar_set)) elif len(self.exemplar_sets) < ((task - 1) * len(uq) + m % len(uq)): self.exemplar_means[m + len(uq) * (task - 1)] = (self.exemplar_means[m + len(uq) * (task - 1)]+ class_mean)/2 self.exemplar_sets[m] = np.concatenate( (self.exemplar_sets[m + len(uq) * (task - 1)], exemplar_set), axis=0) else: if len(self.exemplar_sets) == ((task - 1) * len(uq) + m % len(uq)): self.exemplar_means.append(class_mean) self.exemplar_sets.append(np.array(exemplar_set)) elif len(self.exemplar_sets) < ((task - 1) * len(uq) + m % len(uq)): self.exemplar_means[m] = (self.exemplar_means[m] + class_mean) / 2 self.exemplar_sets[m] = np.concatenate( (self.exemplar_sets[m], exemplar_set), axis=0) else: for m in uq: mask = y == m xm = x[mask] indeces_selected = np.random.choice(xm.size(0), size=min(xm.size(0),exemplars_per_class), replace=False) if scenario in ['task', 'domain']: if len(self.exemplar_sets) < task * len(uq): self.exemplar_sets.append(xm[indeces_selected].detach().cpu().numpy()) else: # Concate to exsisting self.exemplar_sets[m + len(uq) * (task - 1)] = np.concatenate( (self.exemplar_sets[m + len(uq) * (task - 1)], xm[indeces_selected].detach().cpu().numpy()), axis=0) else: if len(self.exemplar_sets) < task * len(uq): self.exemplar_sets.append(xm[indeces_selected].detach().cpu().numpy()) else: # Concate to exsisting self.exemplar_sets[m] = np.concatenate( (self.exemplar_sets[m], 
xm[indeces_selected].detach().cpu().numpy()), axis=0) self.reduce_exemplar_sets(exemplars_per_class) # for i in range(len(self.exemplar_sets)): # print("Task %d Class %d" % (task, i), self.exemplar_sets[i].shape) def train_a_batch(self, x, y, scores=None, x_=None, y_=None, scores_=None, rnt=0.5, active_classes=None, task=1, scenario='class', teacher=None, params_dict=None, epoch=0): '''Train model for one batch ([x],[y]), possibly supplemented with replayed data ([x_],[y_/scores_]). [x] <tensor> batch of inputs (could be None, in which case only 'replayed' data is used) [y] <tensor> batch of corresponding labels [scores] None or <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x] NOTE: only to be used for "BCE with distill" (only when scenario=="class") [x_] None or (<list> of) <tensor> batch of replayed inputs [y_] None or (<list> of) <tensor> batch of corresponding "replayed" labels [scores_] None or (<list> of) <tensor> 2Dtensor:[batch]x[classes] predicted "scores"/"logits" for [x_] [rnt] <number> in [0,1], relative importance of new task [active_classes] None or (<list> of) <list> with "active" classes [task] <int>, for setting task-specific mask''' # Set model to training-mode self.train() # Reset optimizer self.optimizer.zero_grad() # Should gradient be computed separately for each task? (needed when a task-mask is combined with replay) gradient_per_task = True if ((self.mask_dict is not None) and (x_ is not None)) else False ##--(1)-- REPLAYED DATA --## if x_ is not None: # print(y_, task) # In the Task-IL scenario, [y_] or [scores_] is a list and [x_] needs to be evaluated on each of them # (in case of 'exact' or 'exemplar' replay, [x_] is also a list! 
TaskIL = (type(y_) == list) if (y_ is not None) else (type(scores_) == list) if not TaskIL: y_ = [y_] scores_ = [scores_] active_classes = [active_classes] if (active_classes is not None) else None n_replays = len(y_) if (y_ is not None) else len(scores_) # Prepare lists to store losses for each replay loss_KD = [None] * n_replays loss_replay = [None] * n_replays predL_r = [None] * n_replays distilL_r = [None] * n_replays # Run model (if [x_] is not a list with separate replay per task and there is no task-specific mask) if (not type(x_) == list) and (self.mask_dict is None): y_hat_all = self(x_) if teacher is not None and task > 1: if teacher.is_ready_distill: teacher.eval() with torch.no_grad(): embeds_teacher = teacher.feature_extractor(x_) y_hat_teacher = teacher.classifier(embeds_teacher) else: y_hat_teacher = None else: y_hat_teacher = None # Loop to evalute predictions on replay according to each previous task for replay_id in range(n_replays): # -if [x_] is a list with separate replay per task, evaluate model on this task's replay if (type(x_) == list) or (self.mask_dict is not None): x_temp_ = x_[replay_id] if type(x_) == list else x_ if self.mask_dict is not None: self.apply_XdGmask(task=replay_id + 1) y_hat_all = self(x_temp_) if teacher is not None and task > 1: if teacher.is_ready_distill: teacher.eval() with torch.no_grad(): embeds_teacher = teacher.feature_extractor(x_temp_) y_hat_teacher = teacher.classifier(embeds_teacher) else: y_hat_teacher = None else: y_hat_teacher = None # -if needed (e.g., Task-IL or Class-IL scenario), remove predictions for classes not in replayed task y_hat = y_hat_all if (active_classes is None) else y_hat_all[:, active_classes[replay_id]] if y_hat_teacher is not None: y_hat_teacher = y_hat_teacher if (active_classes is None) else y_hat_teacher[:, active_classes[replay_id]] # Calculate losses if (y_ is not None) and (y_[replay_id] is not None): if self.binaryCE: binary_targets_ = utils.to_one_hot(y_[replay_id].cpu(), 
y_hat.size(1)).to(y_[replay_id].device) predL_r[replay_id] = F.binary_cross_entropy_with_logits( input=y_hat, target=binary_targets_, reduction='none' ).sum(dim=1).mean() # --> sum over classes, then average over batch else: predL_r[replay_id] = F.cross_entropy(y_hat, y_[replay_id], reduction='mean') # Compute distillation loss from teacher outputs if y_hat_teacher is not None: if params_dict['distill_type'] in ['E', 'ET', 'ES', 'ETS']: with torch.no_grad(): y_hat_ensemble = 0.5 * (y_hat.clone() + y_hat_teacher.clone()) if params_dict['distill_type'] in ['ET', 'ETS']: loss_KD[replay_id] = 0.5 * (F.kl_div(F.log_softmax(y_hat / self.KD_temp, dim=1), F.softmax(y_hat_ensemble / self.KD_temp, dim=1)) * (self.KD_temp * self.KD_temp) + F.kl_div(F.log_softmax(y_hat / self.KD_temp, dim=1), F.softmax(y_hat_teacher / self.KD_temp, dim=1)) * (self.KD_temp * self.KD_temp)) else: # distill: E, ES loss_KD[replay_id] = F.kl_div(F.log_softmax(y_hat / self.KD_temp, dim=1), F.softmax(y_hat_ensemble / self.KD_temp, dim=1)) \ * (self.KD_temp * self.KD_temp) else: # distill: T, TS loss_KD[replay_id] = F.kl_div(F.log_softmax(y_hat / self.KD_temp, dim=1), F.softmax(y_hat_teacher / self.KD_temp, dim=1)) \ * (self.KD_temp * self.KD_temp) # loss_KD = self.alpha_t * loss_KD + F.cross_entropy(y_hat, y) * (1. - self.alpha_t) if (scores_ is not None) and (scores_[replay_id] is not None): # n_classes_to_consider = scores.size(1) #--> with this version, no zeroes are added to [scores]! n_classes_to_consider = y_hat.size(1) # --> zeros will be added to [scores] to make it this size! 
kd_fn = utils.loss_fn_kd_binary if self.binaryCE else utils.loss_fn_kd distilL_r[replay_id] = kd_fn(scores=y_hat[:, :n_classes_to_consider], target_scores=scores_[replay_id], T=self.KD_temp) # Weigh losses if self.replay_targets == "hard": loss_replay[replay_id] = predL_r[replay_id] elif self.replay_targets == "soft": loss_replay[replay_id] = distilL_r[replay_id] # If needed, perform backward pass before next task-mask (gradients of all tasks will be accumulated) if gradient_per_task: weight = 1 if self.AGEM else (1 - rnt) weighted_replay_loss_this_task = weight * loss_replay[replay_id] / n_replays weighted_replay_loss_this_task.backward() # Calculate total replay loss loss_replay = None if (x_ is None) else sum(loss_replay) / n_replays # Calculate total kd loss loss_KD = None if any(lkd is None for lkd in loss_KD) else sum(loss_KD) / n_replays else: loss_KD = None # If using A-GEM, calculate and store averaged gradient of replayed data if self.AGEM and x_ is not None: # Perform backward pass to calculate gradient of replayed batch (if not yet done) if not gradient_per_task: loss_replay = loss_replay.clamp(min=1e-6) loss_replay.backward() # Reorganize the gradient of the replayed batch as a single vector grad_rep = [] for p in self.parameters(): if p.requires_grad: grad_rep.append(p.grad.view(-1)) grad_rep = torch.cat(grad_rep) # Reset gradients (with A-GEM, gradients of replayed batch should only be used as inequality constraint) self.optimizer.zero_grad() ##--(2)-- CURRENT DATA --## if x is not None: # If requested, apply correct task-specific mask if self.mask_dict is not None: self.apply_XdGmask(task=task) # Run model embeds = self.feature_extractor(x) y_hat = self.classifier(embeds) # -if needed, remove predictions for classes not in current task if active_classes is not None: class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes y_hat = y_hat[:, class_entries] # Calculate prediction loss if self.binaryCE: # -binary 
prediction loss binary_targets = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device) if self.binaryCE_distill and (scores is not None): classes_per_task = int(y_hat.size(1) / task) binary_targets = binary_targets[:, -(classes_per_task):] binary_targets = torch.cat([torch.sigmoid(scores / self.KD_temp), binary_targets], dim=1) y_score = F.binary_cross_entropy_with_logits( input=y_hat, target=binary_targets, reduction='none' ).sum(dim=1) # --> sum over classes, predL = None if y is None else y_score.mean() # average over batch if params_dict['mem_online'] and epoch == 0: self.select_instances(embeds, x, y, scenario, task) else: if params_dict['use_otr'] and epoch == 0: self.select_triplets(embeds, y_score, x, y, params_dict['triplet_selection'], task, scenario, params_dict['use_embeddings'], params_dict['multi_negative']) else: # -multiclass prediction loss y_score = F.cross_entropy(input=y_hat, target=y, reduction='none') predL = None if y is None else y_score.mean() if params_dict['mem_online'] and epoch == 0: self.select_instances(embeds, x, y, scenario, task) else: if params_dict['use_otr'] and epoch == 0: self.select_triplets(embeds, y_score, x, y, params_dict['triplet_selection'], task, scenario, params_dict['use_embeddings'], params_dict['multi_negative']) loss_cur = predL # Calculate training-precision precision = None if y is None else (y == y_hat.max(1)[1]).sum().item() / x.size(0) # If backward passes are performed per task (e.g., XdG combined with replay), perform backward pass if gradient_per_task: weighted_current_loss = rnt * loss_cur weighted_current_loss.backward() else: precision = predL = None # -> it's possible there is only "replay" [e.g., for offline with task-incremental learning] # Combine loss from current and replayed batch if x_ is None or self.AGEM: loss_total = loss_cur else: loss_total = loss_replay if (x is None) else rnt * loss_cur + (1 - rnt) * loss_replay if loss_KD is not None: loss_total = loss_total + loss_KD ##--(3)-- 
ALLOCATION LOSSES --## # Add SI-loss (Zenke et al., 2017) surrogate_loss = self.surrogate_loss() if self.si_c > 0: loss_total += self.si_c * surrogate_loss # Add EWC-loss ewc_loss = self.ewc_loss() if self.ewc_lambda > 0: loss_total += self.ewc_lambda * ewc_loss # Backpropagate errors (if not yet done) if not gradient_per_task: loss_total.backward() # If using A-GEM, potentially change gradient: if self.AGEM and x_ is not None: # -reorganize gradient (of current batch) as single vector grad_cur = [] for p in self.parameters(): if p.requires_grad: grad_cur.append(p.grad.view(-1)) grad_cur = torch.cat(grad_cur) # -check inequality constrain angle = (grad_cur * grad_rep).sum() if angle < 0: # -if violated, project the gradient of the current batch onto the gradient of the replayed batch ... length_rep = (grad_rep * grad_rep).sum() grad_proj = grad_cur - (angle / length_rep) * grad_rep # -...and replace all the gradients within the model with this projected gradient index = 0 for p in self.parameters(): if p.requires_grad: n_param = p.numel() # number of parameters in [p] p.grad.copy_(grad_proj[index:index + n_param].view_as(p)) index += n_param # Take optimization-step self.optimizer.step() # Return the dictionary with different training-loss split in categories return { 'loss_total': loss_total.item(), 'loss_current': loss_cur.item() if x is not None else 0, 'loss_replay': loss_replay.item() if (x_ is not None and loss_replay is not None) else 0, 'pred': predL.item() if predL is not None else 0, 'pred_r': sum(predL_r).item() / n_replays if (x_ is not None and predL_r[0] is not None) else 0, 'distil_r': sum(distilL_r).item() / n_replays if (x_ is not None and distilL_r[0] is not None) else 0, 'ewc': ewc_loss.item(), 'si_loss': surrogate_loss.item(), 'precision': precision if precision is not None else 0., } def train_epoch(self, train_loader, criterion, optimizer, active_classes, params_dict, writer=None): # class_entries = active_classes[-1] if 
type(active_classes[0]) == list else active_classes self.train() tlosses = [] for batch_idx, batch in enumerate(train_loader): x, y = batch x, y = x.to(self._device()), y.to(self._device()) optimizer.zero_grad() y_hat = self(x) # y_hat = y_hat[:, class_entries] if params_dict['teacher_loss'] == 'BCE': y = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device) loss = criterion(y_hat, y) loss.backward() tlosses.append(loss.item()) # writer.add_scalar('Training loss', loss.item(), params_dict['epoch'] * len(train_loader) + batch_idx) optimizer.step() return tlosses def valid_epoch(self, val_loader, criterion, active_classes, params_dict, writer=None): # class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes valid_losses = [] self.eval() with torch.no_grad(): for batch_idx, batch in enumerate(val_loader, 0): x, y = batch x, y = x.to(self._device()), y.to(self._device()) y_hat = self(x) # y_hat = y_hat[:, class_entries] if params_dict['teacher_loss'] == 'BCE': y = utils.to_one_hot(y.cpu(), y_hat.size(1)).to(y.device) valid_loss = criterion(y_hat, y) valid_losses.append(valid_loss.item()) # writer.add_scalar('Validation loss', valid_loss.item(), params_dict['epoch'] * len(val_loader) + batch_idx) self.train() return valid_losses def train_via_KD(self, model, x, distill_type, active_classes): if distill_type == 'T': return model.eval() with torch.no_grad(): y_hat = model(x) model.train() self.train() self.optimizer.zero_grad() y_hat_teacher = self(x) if active_classes is not None: class_entries = active_classes[-1] if type(active_classes[0]) == list else active_classes y_hat = y_hat[:, class_entries] y_hat_teacher = y_hat_teacher[:, class_entries] if distill_type in ['E', 'ET', 'ES', 'ETS']: with torch.no_grad(): y_hat_ensemble = 0.5 * (y_hat_teacher.clone() + y_hat) if distill_type in ['ES', 'ETS']: # distill from ensemble and student to teacher loss = 0.5 * (F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1), 
F.softmax(y_hat_ensemble / self.KD_temp, dim=1)) * (self.KD_temp * self.KD_temp) + F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1), F.softmax(y_hat / self.KD_temp, dim=1)) * (self.KD_temp * self.KD_temp)) else: # distill from ensemble to teacher loss = F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1), F.softmax(y_hat_ensemble / self.KD_temp, dim=1)) \ * (self.KD_temp * self.KD_temp) else: loss = F.kl_div(F.log_softmax(y_hat_teacher / self.KD_temp, dim=1), F.softmax(y_hat / self.KD_temp, dim=1)) \ * (self.KD_temp * self.KD_temp) loss.backward() self.optimizer.step()