    def _compute_distillation_loss(self, images, labels, new_outputs):
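        """Binary cross-entropy loss with LwF-style distillation targets.

        On the first task (no known classes) this reduces to plain BCE against
        one-hot labels; afterwards, the target entries of the old classes are
        replaced by the sigmoid outputs of the stored old network.
        """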
        if self.known_classes == 0:
            return self.bce_loss(
                new_outputs, get_one_hot(labels, self.num_classes,
                                         self.device))

        n_old_classes = self.known_classes

        # The old network only provides distillation targets: no gradients needed.
        with torch.no_grad():
            old_outputs = self.old_net(images)

        targets = get_one_hot(labels, self.num_classes, self.device)
        # Replace the one-hot targets of the old classes with the old network's
        # sigmoid probabilities, so BCE distills the old knowledge on those units.
        targets[:, :n_old_classes] = torch.sigmoid(old_outputs)[:, :n_old_classes]
        tot_loss = self.bce_loss(new_outputs, targets)

        return tot_loss

    def _bias_training(self, eval_dataloader):
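        """Bias-correction stage (BiC): train only the bias layer on the validation split.

        The main network is kept frozen (eval mode, forward under no_grad); only the
        parameters of self.bias_layer are optimized with BCE against one-hot labels.
        """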
        args = utils.get_arguments()
        """bias_optimizer = torch.optim.Adam(self.bias_layer.parameters(), lr=0.01)
        scheduler_bias = torch.optim.lr_scheduler.MultiStepLR(bias_optimizer, milestones=[13], gamma=args['GAMMA'], last_epoch=-1)"""
        bias_optimizer, scheduler_bias = utils.get_otpmizer_scheduler(
            self.bias_layer.parameters(), args['LR'], args['MOMENTUM'],
            args['WEIGHT_DECAY'], [13], args['GAMMA'])
        criterion = self.criterion_bias

        if self.known_classes > 0:
            self.net.eval()
            current_step = 0
            epochs = 20

            self.bias_layer.train()
            for epoch in range(epochs):
                print(
                    f"\tSTARTING Bias Training EPOCH {epoch + 1} - LR={scheduler_bias.get_last_lr()}..."
                )

                # Iterate over the dataset
                for i, (images, labels) in enumerate(eval_dataloader):
                    # Bring data over the device of choice
                    images = images.to(self.device)
                    labels = labels.to(self.device)

                    bias_optimizer.zero_grad()  # Zero-ing the gradients

                    # Forward pass through the frozen main network, then apply the
                    # bias-correction layer to the logits of the new classes
                    with torch.no_grad():
                        outputs = self.net(images)
                    outputs = self.bias_forward(outputs, self.known_classes)

                    # One hot encoding labels for binary cross-entropy loss
                    labels_one_hot = utils.get_one_hot(labels,
                                                       self.num_classes,
                                                       self.device)

                    loss = criterion(outputs, labels_one_hot)

                    if i != 0 and i % 20 == 0:
                        print(
                            f"\t\tBias Training Epoch {epoch + 1}: Train_loss = {loss.item()}"
                        )

                    loss.backward()  # backward pass: computes gradients
                    bias_optimizer.step()  # update weights based on accumulated gradients
                    current_step += 1
                scheduler_bias.step()

    def compute_loss(self, images, labels, new_outputs):
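        """Compute the training loss for the current batch.

        Falls back to the built-in distillation loss when no loss computer is
        injected; otherwise delegates to self.loss_computer, passing classification
        and distillation inputs/targets plus the ratio of new to total classes.
        """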
        if self.loss_computer is None:
            return self._compute_distillation_loss(images, labels, new_outputs)
        else:
            class_ratio = self.class_batch_size / (self.class_batch_size +
                                                   self.known_classes)
            labels = utils.get_one_hot(labels, self.num_classes, self.device)
            class_inputs = new_outputs
            class_targets = labels
            if self.known_classes == 0:
                dist_inputs = None
                dist_targets = None
            else:
                dist_inputs = new_outputs[:, :self.known_classes]
                # Old-network outputs are fixed targets: detach them from the graph.
                with torch.no_grad():
                    dist_targets = self.old_net(images)[:, :self.known_classes]

            return self.loss_computer(class_inputs, class_targets, dist_inputs,
                                      dist_targets, class_ratio)