Esempio n. 1
0
    def test_compute_center_loss(self):
        """Check compute_center_loss against a hand-built expected value.

        The expected loss is the mean of the squared differences between
        feature rows (0, 2, 1) and center rows (1, 1, 3). Those index tuples
        are hard-coded here and presumably mirror the fixture's targets --
        TODO confirm against setUp.
        """
        expected = torch.mean(
            (self.features[(0, 2, 1), :] - self.centers[(1, 1, 3), :])**2)
        actual = compute_center_loss(self.features, self.centers,
                                     self.targets)

        # Compare as Python floats with a small relative tolerance: the
        # expected value sums the squared errors in a different row order
        # than compute_center_loss may use, so bit-exact float equality is
        # fragile.
        expected_value = float(expected)
        self.assertAlmostEqual(float(actual), expected_value,
                               delta=1e-6 * max(1.0, abs(expected_value)))
Esempio n. 2
0
    def run_epoch(self, mode):
        """Run one pass over the training or validation dataloader.

        With ``mode == 'train'`` the training dataloader is used, the model is
        put in train mode, gradients are enabled, the optimizer is stepped
        after every batch, and ``self.model.centers`` is replaced by the
        centers minus the deltas from ``get_center_delta``. Any other
        ``mode`` uses the validation dataloader with the model in eval mode
        and gradients disabled.

        Per-batch losses are printed. At the end of the epoch, the mean
        cross-entropy loss, mean center loss, mean ``lamda``-weighted total
        loss and the top-1/top-3 accuracies are appended to the matching
        recorder (``self.training_losses`` or ``self.validation_losses``)
        and printed.

        Args:
            mode: ``'train'`` for a training epoch; anything else is treated
                as validation.
        """
        if mode == 'train':
            dataloader = self.training_dataloader
            loss_recorder = self.training_losses
            self.model.train()
        else:
            dataloader = self.validation_dataloader
            loss_recorder = self.validation_losses
            self.model.eval()

        total_cross_entropy_loss = 0
        total_center_loss = 0
        total_loss = 0
        total_top1_matches = 0
        total_top3_matches = 0
        batch = 0

        with torch.set_grad_enabled(mode == 'train'):
            for images, targets, _names in dataloader:
                batch += 1
                # as_tensor accepts either a list or a tensor; torch.tensor()
                # would copy (and warn) when targets is already a tensor.
                targets = torch.as_tensor(targets).to(device)
                images = images.to(device)
                centers = self.model.centers

                logits, features = self.model(images)

                cross_entropy_loss = torch.nn.functional.cross_entropy(
                    logits, targets)
                center_loss = compute_center_loss(features, centers, targets)
                loss = self.lamda * center_loss + cross_entropy_loss

                print("[{}:{}] cross entropy loss: {:.8f} - center loss: "
                      "{:.8f} - total weighted loss: {:.8f}".format(
                          mode, self.current_epoch, cross_entropy_loss.item(),
                          center_loss.item(), loss.item()))

                # Accumulate detached values: adding the live loss tensors
                # would chain every batch's autograd graph onto the running
                # totals and keep them all alive for the whole epoch.
                total_cross_entropy_loss += cross_entropy_loss.detach()
                total_center_loss += center_loss.detach()
                total_loss += loss.detach()

                if mode == 'train':
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # make features untrack by autograd, or there will be
                    # a memory leak when updating the centers
                    center_deltas = get_center_delta(features.data, centers,
                                                     targets, self.alpha)
                    self.model.centers = centers - center_deltas

                # compute acc here
                total_top1_matches += self._get_matches(targets, logits, 1)
                total_top3_matches += self._get_matches(targets, logits, 3)

            center_loss = total_center_loss / batch
            cross_entropy_loss = total_cross_entropy_loss / batch
            # Mean of the lamda-weighted per-batch losses: the quantity that
            # was actually optimized, and the one recorded as 'together'.
            loss = total_loss / batch
            top1_acc = total_top1_matches / len(dataloader.dataset)
            top3_acc = total_top3_matches / len(dataloader.dataset)

            loss_recorder['center'].append(center_loss)
            loss_recorder['cross_entropy'].append(cross_entropy_loss)
            loss_recorder['together'].append(loss)
            loss_recorder['top1acc'].append(top1_acc)
            loss_recorder['top3acc'].append(top3_acc)

            print("[{}:{}] finished. cross entropy loss: {:.8f} - "
                  "center loss: {:.8f} - together: {:.8f} - "
                  "top1 acc: {:.4f} % - top3 acc: {:.4f} %".format(
                      mode, self.current_epoch, cross_entropy_loss.item(),
                      center_loss.item(), loss.item(), top1_acc * 100,
                      top3_acc * 100))
 # Training-loop excerpt: repeats training passes until `epoch` reaches EPOCH.
 # NOTE(review): `epoch` is not incremented in the lines shown here; presumably
 # the increment (and any evaluation) follows further down -- confirm.
 while epoch < EPOCH:
     # Set the learning rate of every param group to LR * 0.1 at epochs 35
     # and 45.
     # NOTE(review): both milestones set the same value, so the one at 45 has
     # no additional effect -- confirm whether it was meant to decay further.
     if epoch in [35, 45]:
         for param_group in optimizer.param_groups:
             param_group['lr'] = LR * 0.1
     cnn.train()
     for step, (b_x, b_y) in enumerate(train_loader):
         # mnist_vis(b_x, b_y)
         # Move the batch inputs and labels to the GPU.
         b_x = b_x.cuda()
         b_y = b_y.cuda()
         predictions, features = cnn(b_x)
         # Classification loss from loss_func on the model's predictions.
         cel = loss_func(predictions, b_y)
         centers = cnn.centers
         # print(loss)
         # if epoch > 160:
         # cel = ohkpm(cel, 24)
         # Objective: LAMBDA-weighted center loss plus the classification loss.
         center_loss = compute_center_loss(features, centers, b_y)
         loss = LAMBDA * center_loss + cel
         # Update the centers from detached features (`.data`) so the update
         # is not tracked by autograd. `centers - center_deltas` builds a new
         # tensor, so the `centers` tensor used for `loss` above is unchanged
         # even though this runs before backward().
         center_deltas = get_center_delta(features.data, centers, b_y,
                                          ALPHA)
         cnn.centers = centers - center_deltas
         # else:
         # loss = ohkpm(loss, 24)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         # loss = loss.cpu()
         # loss_avg.val = loss.data.numpy()
         # Log every 500 steps.
         # NOTE(review): `step` is passed to print() twice, so it is printed
         # again after the class loss -- the second one looks accidental.
         if step % 500 == 0:
             print('Step: ', step, '| class loss: %.8f' % cel, step,
                   '| center loss: %.8f' % center_loss)
     # Switch the model to eval mode after this epoch's training pass.
     cnn.eval()