def evaluation(self, dataloader):
        """Run one evaluation pass of ``self.network`` over ``dataloader``.

        The network is switched to eval mode and, for every batch, softmax
        probabilities, the configured loss and the number of correct argmax
        predictions are computed. Gradient tracking is disabled because
        nothing is back-propagated here.

        Args:
            dataloader: iterable of ``(raw_inputs, labels)`` batches that also
                exposes ``.dataset`` (its length is the normaliser).
                ``labels`` are reduced with ``.max(1)``, so they must hold one
                row of per-class target values per sample.

        Returns:
            tuple: ``(test_loss, test_accuracy)``. Both are the accumulated
            per-batch values divided by ``len(dataloader.dataset)``. When the
            loader yields no batch, both are ``0.0``.

        Raises:
            ZeroDivisionError: if ``dataloader.dataset`` is empty.
        """
        # Function-scope import: the original block never referenced ``torch``
        # directly, so do not rely on a module-level import being present.
        import torch

        self.network.eval()

        test_loss = 0.0
        # Plain int counter: stays valid even if the loop body never runs.
        test_correct = 0

        with torch.no_grad():
            for raw_inputs, labels in dataloader:
                # #############################
                # data ingestion
                network_inputs = raw_inputs
                if self.device == "gpu":
                    network_inputs = network_inputs.cuda()
                    labels = labels.cuda()
                # #############################

                # #############################
                # forward
                logits = self.network(network_inputs)
                pred_probs = logits2probs_softmax(logits=logits)
                # #############################

                # #############################
                # loss (only proselflc takes the extra time argument)
                if self.loss_name != "proselflc":
                    loss = self.loss_criterion(
                        pred_probs=pred_probs,
                        target_probs=labels,
                    )
                else:
                    loss = self.loss_criterion(
                        pred_probs=pred_probs,
                        target_probs=labels,
                        cur_time=self.cur_time,
                    )
                # #############################

                test_loss += loss.item()
                _, preds = pred_probs.max(1)
                _, annotations = labels.max(1)
                test_correct += preds.eq(annotations).sum().item()

        # NOTE(review): the summed per-batch loss is divided by the number of
        # samples, not batches; if the criterion already averages over the
        # batch this under-reports the loss -- confirm the intended scale.
        num_samples = len(dataloader.dataset)
        test_loss = test_loss / num_samples
        test_accuracy = test_correct / num_samples

        return test_loss, test_accuracy
# Example no. 2
# 0
    def test_epsilon0_0(self):
        """Compare logit gradients of the three loss criteria on FashionMNIST.

        For every batch the test

        1. runs ``my_model`` and turns the logits into probabilities,
        2. back-propagates ``loss_criterion1`` and ``loss_criterion2`` and
           asserts their logit gradients differ by less than ``1e-4``,
        3. back-propagates ``loss_criterion3`` and asserts its logit gradient
           is within ``1e-4`` of the closed-form expression built from
           ``self.params["epsilon"]``, the probabilities and their entropy.

        ``logits.grad`` is read after each backward call; it is presumably
        populated by the hook returned from ``set_grad(logits)`` -- confirm
        against that helper.
        """
        # FashionMNIST Datasets for training/test
        trn_ds = datasets.FashionMNIST(
            "./datasets",
            download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
        # Dataloader for training/test
        batch_size = 4
        trn_dl = DataLoader(trn_ds, batch_size=batch_size, shuffle=True)

        # parameters initialization
        input_dim = 784  # 28x28 FashionMNIST data
        output_dim = 10
        w_init = np.random.normal(scale=0.05, size=(input_dim, output_dim))
        w_init = torch.tensor(w_init, requires_grad=True).float()
        b = torch.zeros(output_dim)

        epsilon = self.params["epsilon"]
        tolerance = 1e-4

        for X, y in trn_dl:
            logits = my_model(X, input_dim, w_init, b)
            # I would like to use probs to calculate losses
            # for ProSelfLC, CCE, LS, LS, LC, etc.
            probs = logits2probs_softmax(logits)
            # One-hot float32 target distribution, one row per sample.
            y_vector = torch.nn.functional.one_hot(
                y.long(), num_classes=output_dim
            ).float()

            loss1 = self.loss_criterion1(probs, y_vector)
            loss2 = self.loss_criterion2(probs, y_vector)
            loss3 = self.loss_criterion3(probs, y_vector)

            print(f"loss1: {loss1.item()} " + f" versus loss2: {loss2.item()}")
            print("logits shape: " + str(logits.shape))
            print("label shape: " + str(y.shape))
            print("distribution shape: " + str(y_vector.shape))
            print("probs shape: " + str(probs.shape))

            print("gradient check")

            # register_hook for logits (an intermediate, non-leaf tensor)
            logits.register_hook(set_grad(logits))

            # retain_graph=True: the same graph is back-propagated three times.
            loss1.backward(retain_graph=True)
            loss1_logit_grad = logits.grad

            loss2.backward(retain_graph=True)
            loss2_logit_grad = logits.grad

            self.assertTrue(
                torch.all(
                    torch.lt(
                        torch.abs(loss1_logit_grad - loss2_logit_grad),
                        tolerance,
                    )
                )
            )

            loss3.backward(retain_graph=True)
            loss3_logit_grad = logits.grad
            with torch.no_grad():
                # Per-sample entropy kept as a (batch, 1) column so that it
                # broadcasts over the class axis for any batch size.
                H_pred_probs = torch.sum(
                    -(probs + 1e-12) * torch.log(probs + 1e-12),
                    1,
                    keepdim=True,
                )
            logit_grad_derived = (
                (1 - epsilon) * (probs - y_vector)
                - epsilon * probs * (torch.log(probs + 1e-12) + H_pred_probs)
            ) / batch_size

            self.assertTrue(
                torch.all(
                    torch.lt(
                        torch.abs(loss3_logit_grad - logit_grad_derived),
                        tolerance,
                    )
                )
            )
    def train_one_epoch(self, epoch: int, dataloader) -> None:
        """Train ``self.network`` for one pass over ``dataloader``.

        Per batch: update the ProSelfLC time counter (if that loss is used),
        run the forward pass, compute the loss, back-propagate and step the
        optimizer. While ``epoch <= self.warmup_epochs`` the warmup scheduler
        is stepped after every batch.

        Args:
            epoch: current epoch number. The iteration counter formula
                ``(epoch - 1) * len(dataloader) + batch_index + 1`` suggests
                epochs are counted from 1 -- TODO confirm with the caller.
            dataloader: sized iterable of ``(raw_inputs, labels)`` batches.
                In the ``dm_exp_pi`` branch ``labels`` are indexed through
                ``labels.nonzero()``, which presumes exactly one non-zero
                entry per row (one-hot targets) -- verify against the data
                pipeline.
        """
        self.network.train()  # self.network.train(mode=True)

        for batch_index, (raw_inputs, labels) in enumerate(dataloader):
            # #############################
            # track time for proselflc
            if self.loss_name == "proselflc":
                if self.counter == "epoch":
                    self.cur_time = epoch
                else:
                    # epoch counter to iteration counter
                    self.cur_time = (
                        (epoch - 1) * len(dataloader) + batch_index + 1
                    )
            # #############################

            # #############################
            # data ingestion
            network_inputs = raw_inputs
            if self.device == "gpu":
                network_inputs = network_inputs.cuda()
                labels = labels.cuda()
            # #############################

            # #############################
            # forward
            logits = self.network(network_inputs)
            pred_probs = logits2probs_softmax(logits=logits)
            # #############################

            # #############################
            # loss (only proselflc takes the extra time argument)
            if self.loss_name != "proselflc":
                loss = self.loss_criterion(
                    pred_probs=pred_probs,
                    target_probs=labels,
                )
            else:
                loss = self.loss_criterion(
                    pred_probs=pred_probs,
                    target_probs=labels,
                    cur_time=self.cur_time,
                )
            # #############################

            # backward
            self.optim.optimizer.zero_grad()
            if self.params["loss_name"] == "dm_exp_pi":
                # ########################################################
                # Implementation for derivative manipulation + Improved MAE
                # Novelty: From Loss Design to Derivative Design
                # Our work inspired: ICML-2020 (Normalised Loss Functions)
                # and ICML-2021 (Asymmetric Loss Functions)
                # ########################################################
                # The tensor built here is only fed to logits.backward() as
                # the upstream gradient, so it needs no autograd history.
                probs = pred_probs.detach()
                # probability at the non-zero target position, as a column
                p_i = probs[labels.nonzero(as_tuple=True)][:, None]
                # remove original weights
                logit_grad_derived = (probs - labels) / (
                    2.0 * (1.0 - p_i) + 1e-8
                )
                # new weight: derivative manipulation or IMAE (computed once)
                dm_weight = torch.exp(
                    self.params["dm_beta"]
                    * (1.0 - p_i)
                    * torch.pow(p_i + 1e-8, self.params["dm_lambda"])
                )
                # derivative normalisation,
                # which inspired the ICML-2020 paper-Normalised Loss Functions
                sum_weight = dm_weight.sum(dim=0)
                logit_grad_derived = logit_grad_derived * dm_weight / sum_weight
                logits.backward(logit_grad_derived)
            else:
                loss.backward()

            # update params
            self.optim.optimizer.step()
            # #############################

            # warmup iteration-wise lr scheduler
            if epoch <= self.warmup_epochs:
                self.optim.warmup_scheduler.step()