Example #1
    def _forward_loss(self,
                      training_network,
                      inputs,
                      targets,
                      memory_flags,
                      metrics,
                      gradcam_grad=None,
                      gradcam_act=None,
                      **kwargs):
        inputs, targets = inputs.to(self._device), targets.to(self._device)
        onehot_targets = utils.to_onehot(targets,
                                         self._n_classes).to(self._device)

        outputs = training_network(inputs)
        # Pre-computed Grad-CAM gradients/activations, when provided, are
        # injected into the outputs so that _compute_loss can use them.
        if gradcam_act is not None:
            outputs["gradcam_gradients"] = gradcam_grad
            outputs["gradcam_activations"] = gradcam_act

        loss = self._compute_loss(inputs, outputs, targets, onehot_targets,
                                  memory_flags, metrics)

        if not utils.check_loss(loss):
            raise ValueError("Loss became invalid ({}).".format(loss))

        metrics["loss"] += loss.item()

        return loss
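
A minimal sketch of how a training step could drive a _forward_loss with this signature. The train_epoch helper, the loader yielding (inputs, targets, memory_flags) batches, and the trainer/optimizer objects are hypothetical; only the _forward_loss call mirrors the example above.

import collections


def train_epoch(trainer, training_network, loader, optimizer):
    # Accumulates per-batch metric sums; _forward_loss adds its own entries.
    metrics = collections.defaultdict(float)

    for inputs, targets, memory_flags in loader:
        optimizer.zero_grad()
        loss = trainer._forward_loss(training_network, inputs, targets,
                                     memory_flags, metrics)
        loss.backward()
        optimizer.step()

    # Average the accumulated sums over the number of batches.
    return {key: value / max(len(loader), 1) for key, value in metrics.items()}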
Example #2
    def _forward_loss(self, training_network, inputs, targets, memory_flags,
                      metrics):
        inputs, targets = inputs.to(self._device), targets.to(self._device)
        onehot_targets = utils.to_onehot(targets,
                                         self._n_classes).to(self._device)

        outputs = training_network(inputs)

        loss = self._compute_loss(inputs, outputs, targets, onehot_targets,
                                  memory_flags, metrics)

        if not utils.check_loss(loss):
            raise ValueError("Loss became invalid ({}).".format(loss))

        metrics["loss"] += loss.item()

        return loss
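
The utils.to_onehot and utils.check_loss helpers used throughout these examples are not shown on this page. A minimal sketch of what they might look like, assuming to_onehot expects integer class indices and check_loss only guards against non-finite values:

import torch


def to_onehot(targets, n_classes):
    # One row per sample with a single 1 at the target class index.
    onehot = torch.zeros(targets.shape[0], n_classes, device=targets.device)
    onehot.scatter_(1, targets.long().view(-1, 1), 1.)
    return onehot


def check_loss(loss):
    # Reject NaN or infinite losses so training fails fast with a clear error.
    return torch.isfinite(loss).item()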
Example #3
    def _compute_loss(self, inputs, outputs, targets, onehot_targets,
                      memory_flags, metrics):
        logits = outputs["logits"]

        if self._old_model is None:
            # Classification loss
            loss = F.cross_entropy(logits, targets)
            metrics["clf"] += loss.item()
        else:
            self._old_model.zero_grad()
            old_outputs = self._old_model(inputs)
            old_logits = old_outputs["logits"]

            # Classification loss
            loss = F.cross_entropy(
                logits[..., -self._task_size:],
                (targets - self._n_classes + self._task_size))
            metrics["clf"] += loss.item()

            # Distillation on probabilities
            distill_loss = self._distillation_config[
                "factor"] * F.binary_cross_entropy_with_logits(
                    logits[..., :-self._task_size],
                    torch.sigmoid(old_logits.detach()))
            metrics["dis"] += distill_loss.item()
            loss += distill_loss

            # Distillation on gradcam-generated attentions
            if self._attention_config:
                top_logits_indexes = logits[..., :-self._task_size].argmax(
                    dim=1)
                onehot_top_logits = utils.to_onehot(
                    top_logits_indexes,
                    self._n_classes - self._task_size).to(self._device)

                logits[..., :-self._task_size].backward(
                    gradient=onehot_top_logits, retain_graph=True)
                old_logits.backward(gradient=onehot_top_logits)

                if len(outputs["gradcam_gradients"]) > 1:
                    gradcam_gradients = torch.cat([
                        g.to(self._device)
                        for g in outputs["gradcam_gradients"]
                    ])
                    gradcam_activations = torch.cat([
                        a.to(self._device)
                        for a in outputs["gradcam_activations"]
                    ])
                else:
                    gradcam_gradients = outputs["gradcam_gradients"][0]
                    gradcam_activations = outputs["gradcam_activations"][0]

                attention_loss = losses.gradcam_distillation(
                    gradcam_gradients,
                    old_outputs["gradcam_gradients"][0].detach(),
                    gradcam_activations,
                    old_outputs["gradcam_activations"][0].detach(),
                    **self._attention_config)
                metrics["ad"] += attention_loss.item()
                loss += attention_loss

                self._old_model.zero_grad()
                self._network.zero_grad()

        return loss
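
The losses.gradcam_distillation call above is likewise not defined on this page. A minimal sketch of a Grad-CAM based distillation term, under the assumption that an attention map is built as ReLU(gradient) * activation, flattened and L2-normalized per sample, and that old and new maps are compared with an L1 distance:

import torch
import torch.nn.functional as F


def gradcam_distillation_sketch(new_gradients, old_gradients,
                                new_activations, old_activations,
                                factor=1.0):
    def attention(gradients, activations):
        # Weight each activation map by its rectified gradient, then flatten
        # and L2-normalize per sample.
        maps = F.relu(gradients) * activations
        maps = maps.view(maps.shape[0], -1)
        return F.normalize(maps, p=2, dim=-1)

    new_attention = attention(new_gradients, new_activations)
    old_attention = attention(old_gradients, old_activations)

    # Penalize how far the new model's attention drifts from the old one's.
    return factor * torch.abs(new_attention - old_attention).sum(dim=-1).mean()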
Example #4
    def _forward_loss(self, inputs, targets):
        inputs, targets = inputs.to(self._device), targets.to(self._device)
        targets = utils.to_onehot(targets, self._n_classes).to(self._device)
        logits = self._network(inputs)

        return self._compute_loss(inputs, logits, targets)
    def _compute_loss(self, inputs, outputs, targets, onehot_targets,
                      memory_flags):

        if self._args['use_sim_clr']:
            # The SimCLR contrastive term (see _forward_loss below) uses the
            # full batch; only the first half is kept for the losses computed
            # here.
            half_batch = inputs.shape[0] // 2
            inputs = inputs[:half_batch]
            memory_flags = memory_flags[:half_batch]
            onehot_targets = onehot_targets[:half_batch]
            targets = targets[:half_batch]
            outputs['raw_features'] = outputs['raw_features'][:half_batch]
            outputs['features'] = outputs['features'][:half_batch]
            outputs['logits'] = outputs['logits'][:half_batch]
            outputs['raw_logits'] = outputs['raw_logits'][:half_batch]
            for i in range(len(outputs['attention'])):
                outputs['attention'][i] = outputs['attention'][i][:half_batch]

        features, logits, atts = outputs["raw_features"], outputs[
            "logits"], outputs["attention"]

        if self._post_processing_type is None:
            scaled_logits = self._network.post_process(logits)
        else:
            scaled_logits = logits * self._post_processing_type

        if self._old_model is not None:
            with torch.no_grad():
                old_outputs = self._old_model(inputs)
                old_features = old_outputs["raw_features"]
                old_atts = old_outputs["attention"]

        if self._nca_config:
            nca_config = copy.deepcopy(self._nca_config)
            if self._network.post_processor:
                nca_config["scale"] = self._network.post_processor.factor

            loss = losses.nca(logits,
                              targets,
                              memory_flags=memory_flags,
                              class_weights=self._class_weights,
                              **nca_config)
            self._metrics["nca"] += loss.item()
        elif self._softmax_ce:
            loss = F.cross_entropy(scaled_logits, targets)
            self._metrics["cce"] += loss.item()

        # --------------------
        # Distillation losses:
        # --------------------

        if self._old_model is not None:
            if self._pod_flat_config:
                if self._pod_flat_config["scheduled_factor"]:
                    factor = self._pod_flat_config[
                        "scheduled_factor"] * math.sqrt(
                            self._n_classes / self._task_size)
                else:
                    factor = self._pod_flat_config.get("factor", 1.)

                pod_flat_loss = factor * losses.embeddings_similarity(
                    old_features, features)
                loss += pod_flat_loss
                self._metrics["flat"] += pod_flat_loss.item()

            if self._pod_spatial_config:
                if self._pod_spatial_config.get("scheduled_factor", False):
                    factor = self._pod_spatial_config[
                        "scheduled_factor"] * math.sqrt(
                            self._n_classes / self._task_size)
                else:
                    factor = self._pod_spatial_config.get("factor", 1.)

                pod_spatial_loss = factor * losses.pod(
                    old_atts,
                    atts,
                    memory_flags=memory_flags.bool(),
                    task_percent=(self._task + 1) / self._n_tasks,
                    **self._pod_spatial_config)
                loss += pod_spatial_loss
                self._metrics["pod"] += pod_spatial_loss.item()

            if self._perceptual_features:
                percep_feat = losses.perceptual_features_reconstruction(
                    old_atts, atts, **self._perceptual_features)
                loss += percep_feat
                self._metrics["p_feat"] += percep_feat.item()

            if self._perceptual_style:
                percep_style = losses.perceptual_style_reconstruction(
                    old_atts, atts, **self._perceptual_style)
                loss += percep_style
                self._metrics["p_sty"] += percep_style.item()

            if self._gradcam_distil:
                top_logits_indexes = logits[..., :-self._task_size].argmax(
                    dim=1)
                onehot_top_logits = utils.to_onehot(
                    top_logits_indexes,
                    self._n_classes - self._task_size).to(self._device)

                old_logits = old_outputs["logits"]

                logits[..., :-self._task_size].backward(
                    gradient=onehot_top_logits, retain_graph=True)
                old_logits.backward(gradient=onehot_top_logits)

                if len(outputs["gradcam_gradients"]) > 1:
                    gradcam_gradients = torch.cat([
                        g.to(self._device)
                        for g in outputs["gradcam_gradients"] if g is not None
                    ])
                    gradcam_activations = torch.cat([
                        a.to(self._device)
                        for a in outputs["gradcam_activations"]
                        if a is not None
                    ])
                else:
                    gradcam_gradients = outputs["gradcam_gradients"][0]
                    gradcam_activations = outputs["gradcam_activations"][0]

                if self._gradcam_distil.get("scheduled_factor", False):
                    factor = self._gradcam_distil[
                        "scheduled_factor"] * math.sqrt(
                            self._n_classes / self._task_size)
                else:
                    factor = self._gradcam_distil.get("factor", 1.)

                attention_loss = factor * losses.gradcam_distillation(
                    gradcam_gradients,
                    old_outputs["gradcam_gradients"][0].detach(),
                    gradcam_activations,
                    old_outputs["gradcam_activations"][0].detach())

                self._metrics["grad"] += attention_loss.item()
                loss += attention_loss

                self._old_model.zero_grad()
                self._network.zero_grad()

        return loss
    def _forward_loss(self,
                      training_network,
                      inputs,
                      targets,
                      memory_flags,
                      gradcam_grad=None,
                      gradcam_act=None,
                      **kwargs):
        inputs, targets = inputs.to(self._device), targets.to(self._device)
        onehot_targets = utils.to_onehot(targets,
                                         self._n_classes).to(self._device)

        outputs = training_network(inputs)
        if gradcam_act is not None:
            outputs["gradcam_gradients"] = gradcam_grad
            outputs["gradcam_activations"] = gradcam_act
        loss = 0
        if self._args['use_sim_clr']:
            # Contrastive (NT-Xent) term, computed on the full batch of
            # features before _compute_loss slices it down.
            similarity_loss = self.nt_xent_loss(outputs['features'])
            loss += self._args['sim_clr_alpha'] * similarity_loss
            self._metrics['xt_xent_loss'] += self._args[
                'sim_clr_alpha'] * similarity_loss.item()

        loss += self._compute_loss(inputs, outputs, targets, onehot_targets,
                                   memory_flags)

        if self._args['sv_regularization']:
            # sv regularization goes here
            loss += self.sv_loss(training_network, self._metrics)
        else:
            # add sv values to the metrics
            with torch.no_grad():
                # Gather the classifier's weight (proxy) vectors into a
                # single matrix.
                linear_tensors = []
                for linear_layer in training_network.classifier.parameters():
                    linear_tensors.append(linear_layer)
                linear_matrix = torch.cat(linear_tensors)

                # Average the proxies of each class so that every class is
                # represented by one vector.
                mean_linear_tensors = []
                num_proxy_per_class = self._args['classifier_config'][
                    'proxy_per_class']
                num_classes = int(linear_matrix.shape[0] / num_proxy_per_class)
                for i in range(num_classes):
                    from_ = i * num_proxy_per_class
                    to = (i + 1) * num_proxy_per_class
                    mean_linear_tensors.append(
                        torch.mean(linear_matrix[from_:to],
                                   dim=0,
                                   keepdim=True))
                linear_matrix = torch.cat(mean_linear_tensors)

                # Singular values of the class-vector Gram matrix: track the
                # largest/smallest ratio, the entropy of the softmax over
                # their square roots, and the mean class-vector norm.
                u, s, v = torch.svd(
                    torch.matmul(linear_matrix, linear_matrix.T))
                sv_ratio = s[0] / (s[-1] + 0.00001)
                sv_entropy_positive = torch.sum(
                    F.softmax(torch.sqrt(s), dim=0) *
                    F.log_softmax(torch.sqrt(s), dim=0))
                sv_entropy_negative = -sv_entropy_positive
                norm = torch.mean(torch.norm(linear_matrix, dim=1))

                self._metrics['norm'] += norm.item()
                self._metrics['sv_ratio'] += sv_ratio.item()
                self._metrics[
                    'sv_entropy_positive'] += sv_entropy_positive.item()
                self._metrics[
                    'sv_entropy_negative'] += sv_entropy_negative.item()

        #if not utils.check_loss(loss):
        #    raise ValueError("A loss is NaN: {}".format(self._metrics))

        self._metrics["loss"] += loss.item()

        return loss
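
The self.nt_xent_loss method used in the last _forward_loss is not shown either. A minimal NT-Xent (SimCLR) sketch, assuming the batch is laid out as two concatenated augmented views so that sample i and sample i + N form a positive pair, which matches the half-batch slicing in _compute_loss above:

import torch
import torch.nn.functional as F


def nt_xent_loss_sketch(features, temperature=0.5):
    n = features.shape[0] // 2
    z = F.normalize(features, dim=1)

    # Pairwise cosine similarities; mask self-similarity on the diagonal.
    similarities = z @ z.t() / temperature
    similarities.fill_diagonal_(float("-inf"))

    # The positive for row i is row i + n (and vice versa).
    positives = torch.cat([
        torch.arange(n, 2 * n, device=features.device),
        torch.arange(0, n, device=features.device),
    ])

    return F.cross_entropy(similarities, positives)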