def _forward_loss(self, training_network, inputs, targets, memory_flags, metrics, gradcam_grad=None, gradcam_act=None, **kwargs):
    """Run one forward pass and accumulate the resulting training loss.

    Moves the batch to the training device, optionally attaches
    externally captured Grad-CAM gradients/activations to the network
    outputs, then delegates the loss computation to ``_compute_loss``.

    :raises ValueError: if the computed loss is not a valid finite value.
    """
    inputs = inputs.to(self._device)
    targets = targets.to(self._device)
    one_hot = utils.to_onehot(targets, self._n_classes).to(self._device)

    outputs = training_network(inputs)
    if gradcam_act is not None:
        # Hooks captured these outside this call; expose them to the loss.
        outputs["gradcam_gradients"] = gradcam_grad
        outputs["gradcam_activations"] = gradcam_act

    loss = self._compute_loss(inputs, outputs, targets, one_hot,
                              memory_flags, metrics)
    if not utils.check_loss(loss):
        raise ValueError("Loss became invalid ({}).".format(loss))

    metrics["loss"] += loss.item()
    return loss
def _forward_loss(self, training_network, inputs, targets, memory_flags, metrics):
    """Forward a batch and accumulate its training loss into ``metrics``.

    :raises ValueError: if the computed loss is not a valid finite value.
    """
    inputs = inputs.to(self._device)
    targets = targets.to(self._device)
    one_hot = utils.to_onehot(targets, self._n_classes).to(self._device)

    outputs = training_network(inputs)
    loss = self._compute_loss(inputs, outputs, targets, one_hot,
                              memory_flags, metrics)

    if not utils.check_loss(loss):
        raise ValueError("Loss became invalid ({}).".format(loss))

    metrics["loss"] += loss.item()
    return loss
def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags, metrics):
    """Classification loss, plus distillation terms when an old model exists.

    First task (``self._old_model is None``): plain cross-entropy.
    Later tasks: cross-entropy on the new-task logit slice, sigmoid-based
    distillation against the old model's logits, and (optionally)
    Grad-CAM attention distillation.  Per-term values are accumulated
    into ``metrics`` under "clf" / "dis" / "ad".
    """
    logits = outputs["logits"]

    if self._old_model is None:
        # Classification loss
        loss = F.cross_entropy(logits, targets)
        metrics["clf"] += loss.item()
    else:
        self._old_model.zero_grad()
        old_outputs = self._old_model(inputs)
        old_logits = old_outputs["logits"]

        # Classification loss — only over the new task's logit slice;
        # targets are shifted so new classes start at index 0.
        loss = F.cross_entropy(
            logits[..., -self._task_size:],
            (targets - self._n_classes + self._task_size))
        metrics["clf"] += loss.item()

        # Distillation on probabilities: match the old model's sigmoid
        # outputs on the old-class logits (old model detached).
        distill_loss = self._distillation_config[
            "factor"] * F.binary_cross_entropy_with_logits(
                logits[..., :-self._task_size],
                torch.sigmoid(old_logits.detach()))
        metrics["dis"] += distill_loss.item()
        loss += distill_loss

        # Distillation on gradcam-generated attentions
        if self._attention_config:
            # Backprop a one-hot of the top predicted old class through
            # both models so the Grad-CAM hooks capture gradients.
            top_logits_indexes = logits[..., :-self._task_size].argmax(
                dim=1)
            onehot_top_logits = utils.to_onehot(
                top_logits_indexes,
                self._n_classes - self._task_size).to(self._device)

            logits[..., :-self._task_size].backward(
                gradient=onehot_top_logits, retain_graph=True)
            old_logits.backward(gradient=onehot_top_logits)

            # Multi-GPU case: hooks yield one tensor per device — concat.
            if len(outputs["gradcam_gradients"]) > 1:
                gradcam_gradients = torch.cat([
                    g.to(self._device)
                    for g in outputs["gradcam_gradients"]
                ])
                gradcam_activations = torch.cat([
                    a.to(self._device)
                    for a in outputs["gradcam_activations"]
                ])
            else:
                gradcam_gradients = outputs["gradcam_gradients"][0]
                gradcam_activations = outputs["gradcam_activations"][0]

            attention_loss = losses.gradcam_distillation(
                gradcam_gradients,
                old_outputs["gradcam_gradients"][0].detach(),
                gradcam_activations,
                old_outputs["gradcam_activations"][0].detach(),
                **self._attention_config)
            metrics["ad"] += attention_loss.item()
            loss += attention_loss

            # Clear gradients accumulated by the manual backward() calls
            # before the real optimization step.
            self._old_model.zero_grad()
            self._network.zero_grad()

    return loss
def _forward_loss(self, inputs, targets):
    """Forward a batch through the network and return its loss.

    Targets are converted to one-hot before being handed to
    ``_compute_loss``.
    """
    inputs = inputs.to(self._device)
    targets = utils.to_onehot(
        targets.to(self._device), self._n_classes).to(self._device)
    predictions = self._network(inputs)
    return self._compute_loss(inputs, predictions, targets)
def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
    """Compute the PODNet-style training loss.

    Classification term (NCA or softmax cross-entropy) plus, when an old
    model exists, distillation terms: POD-flat (embeddings), POD-spatial
    (attention maps), perceptual features/style, and Grad-CAM attention
    distillation.  Each term is accumulated into ``self._metrics``.

    When SimCLR is enabled, the second half of the batch holds augmented
    views used only by the contrastive loss, so everything is truncated
    to the first half here.
    """
    if self._args['use_sim_clr']:
        half_batch = inputs.shape[0] // 2
        inputs = inputs[:half_batch]
        memory_flags = memory_flags[:half_batch]
        onehot_targets = onehot_targets[:half_batch]
        targets = targets[:half_batch]
        outputs['raw_features'] = outputs['raw_features'][:half_batch]
        # Fixed: previously sliced 'raw_features' into 'features',
        # clobbering the projected features with the raw ones.
        outputs['features'] = outputs['features'][:half_batch]
        outputs['logits'] = outputs['logits'][:half_batch]
        outputs['raw_logits'] = outputs['raw_logits'][:half_batch]
        for i in range(len(outputs['attention'])):
            outputs['attention'][i] = outputs['attention'][i][:half_batch]

    features, logits, atts = outputs["raw_features"], outputs[
        "logits"], outputs["attention"]

    if self._post_processing_type is None:
        scaled_logits = self._network.post_process(logits)
    else:
        scaled_logits = logits * self._post_processing_type

    if self._old_model is not None:
        with torch.no_grad():
            old_outputs = self._old_model(inputs)
            old_features = old_outputs["raw_features"]
            old_atts = old_outputs["attention"]

    # --------------------
    # Classification loss:
    # --------------------
    if self._nca_config:
        nca_config = copy.deepcopy(self._nca_config)
        if self._network.post_processor:
            nca_config["scale"] = self._network.post_processor.factor

        loss = losses.nca(
            logits,
            targets,
            memory_flags=memory_flags,
            class_weights=self._class_weights,
            **nca_config)
        self._metrics["nca"] += loss.item()
    elif self._softmax_ce:
        loss = F.cross_entropy(scaled_logits, targets)
        self._metrics["cce"] += loss.item()

    # --------------------
    # Distillation losses:
    # --------------------
    if self._old_model is not None:
        if self._pod_flat_config:
            # Use .get() like the other distillation branches so a
            # missing "scheduled_factor" key doesn't raise KeyError.
            if self._pod_flat_config.get("scheduled_factor", False):
                factor = self._pod_flat_config[
                    "scheduled_factor"] * math.sqrt(
                        self._n_classes / self._task_size)
            else:
                factor = self._pod_flat_config.get("factor", 1.)

            pod_flat_loss = factor * losses.embeddings_similarity(
                old_features, features)
            loss += pod_flat_loss
            self._metrics["flat"] += pod_flat_loss.item()

        if self._pod_spatial_config:
            if self._pod_spatial_config.get("scheduled_factor", False):
                factor = self._pod_spatial_config[
                    "scheduled_factor"] * math.sqrt(
                        self._n_classes / self._task_size)
            else:
                factor = self._pod_spatial_config.get("factor", 1.)

            pod_spatial_loss = factor * losses.pod(
                old_atts,
                atts,
                memory_flags=memory_flags.bool(),
                task_percent=(self._task + 1) / self._n_tasks,
                **self._pod_spatial_config)
            loss += pod_spatial_loss
            self._metrics["pod"] += pod_spatial_loss.item()

        if self._perceptual_features:
            percep_feat = losses.perceptual_features_reconstruction(
                old_atts, atts, **self._perceptual_features)
            loss += percep_feat
            self._metrics["p_feat"] += percep_feat.item()

        if self._perceptual_style:
            percep_style = losses.perceptual_style_reconstruction(
                old_atts, atts, **self._perceptual_style)
            loss += percep_style
            self._metrics["p_sty"] += percep_style.item()

        if self._gradcam_distil:
            # Backprop a one-hot of the top predicted old class through
            # both networks so Grad-CAM hooks capture gradients.
            # (Removed leftover `except: pdb.set_trace()` debugger traps;
            # errors now propagate normally.)
            top_logits_indexes = logits[..., :-self._task_size].argmax(
                dim=1)
            onehot_top_logits = utils.to_onehot(
                top_logits_indexes,
                self._n_classes - self._task_size).to(self._device)

            old_logits = old_outputs["logits"]
            # NOTE(review): old_outputs was computed under torch.no_grad()
            # above — confirm old_logits actually carries a graph here.
            logits[..., :-self._task_size].backward(
                gradient=onehot_top_logits, retain_graph=True)
            old_logits.backward(gradient=onehot_top_logits)

            # Multi-GPU case: one hooked tensor per device — concat them.
            if len(outputs["gradcam_gradients"]) > 1:
                gradcam_gradients = torch.cat([
                    g.to(self._device)
                    for g in outputs["gradcam_gradients"] if g is not None
                ])
                gradcam_activations = torch.cat([
                    a.to(self._device)
                    for a in outputs["gradcam_activations"]
                    if a is not None
                ])
            else:
                gradcam_gradients = outputs["gradcam_gradients"][0]
                gradcam_activations = outputs["gradcam_activations"][0]

            if self._gradcam_distil.get("scheduled_factor", False):
                factor = self._gradcam_distil[
                    "scheduled_factor"] * math.sqrt(
                        self._n_classes / self._task_size)
            else:
                factor = self._gradcam_distil.get("factor", 1.)

            attention_loss = factor * losses.gradcam_distillation(
                gradcam_gradients,
                old_outputs["gradcam_gradients"][0].detach(),
                gradcam_activations,
                old_outputs["gradcam_activations"][0].detach())

            self._metrics["grad"] += attention_loss.item()
            loss += attention_loss

            # Clear gradients accumulated by the manual backward() calls.
            self._old_model.zero_grad()
            self._network.zero_grad()

    return loss
def _forward_loss(self, training_network, inputs, targets, memory_flags, gradcam_grad=None, gradcam_act=None, **kwargs):
    """Forward a batch and accumulate the total training loss.

    Optionally adds a SimCLR NT-Xent contrastive term and an SV
    (singular-value) regularization term; when SV regularization is
    disabled, singular-value statistics of the (per-class averaged)
    classifier weights are still logged into ``self._metrics``.
    """
    inputs, targets = inputs.to(self._device), targets.to(self._device)
    onehot_targets = utils.to_onehot(targets,
                                     self._n_classes).to(self._device)
    outputs = training_network(inputs)
    if gradcam_act is not None:
        # Hooks captured these outside this call; expose them to the loss.
        outputs["gradcam_gradients"] = gradcam_grad
        outputs["gradcam_activations"] = gradcam_act

    loss = 0
    if self._args['use_sim_clr']:
        similarity_loss = self.nt_xent_loss(outputs['features'])
        loss += self._args['sim_clr_alpha'] * similarity_loss
        self._metrics['xt_xent_loss'] += self._args[
            'sim_clr_alpha'] * similarity_loss.item()

    # Fixed: was `loss = self._compute_loss(...)`, which silently
    # discarded the SimCLR term accumulated just above.
    loss += self._compute_loss(inputs, outputs, targets, onehot_targets,
                               memory_flags)

    if self._args['sv_regularization']:
        # sv regularization goes here
        loss += self.sv_loss(training_network, self._metrics)
    else:
        # Log singular-value statistics of the classifier weights.
        with torch.no_grad():
            linear_matrix = torch.cat(
                list(training_network.classifier.parameters()))

            # Average the proxies of each class into one row.
            num_proxy_per_class = self._args['classifier_config'][
                'proxy_per_class']
            num_classes = linear_matrix.shape[0] // num_proxy_per_class
            mean_linear_tensors = []
            for i in range(num_classes):
                from_ = i * num_proxy_per_class
                to = (i + 1) * num_proxy_per_class
                mean_linear_tensors.append(
                    torch.mean(linear_matrix[from_:to],
                               dim=0,
                               keepdim=True))
            linear_matrix = torch.cat(mean_linear_tensors)

            # Singular values of the Gram matrix W @ W^T.
            u, s, v = torch.svd(
                torch.matmul(linear_matrix, linear_matrix.T))
            # Epsilon guards against a zero smallest singular value.
            sv_ratio = s[0] / (s[-1] + 0.00001)

            # Computed once, then negated — the original duplicated the
            # whole softmax/log-softmax expression for both signs.
            # "positive" is sum(p * log p), i.e. the negative entropy.
            sv_entropy_positive = torch.sum(
                F.softmax(torch.sqrt(s), dim=0) *
                F.log_softmax(torch.sqrt(s), dim=0))
            sv_entropy_negative = -sv_entropy_positive

            norm = torch.mean(torch.norm(linear_matrix, dim=1))

            self._metrics['norm'] += norm.item()
            self._metrics['sv_ratio'] += sv_ratio.item()
            self._metrics[
                'sv_entropy_positive'] += sv_entropy_positive.item()
            self._metrics[
                'sv_entropy_negative'] += sv_entropy_negative.item()

    # NOTE(review): the NaN check below was deliberately disabled
    # upstream — left disabled to preserve behavior.
    #if not utils.check_loss(loss):
    #    raise ValueError("A loss is NaN: {}".format(self._metrics))

    self._metrics["loss"] += loss.item()
    return loss