def test_get_center_delta(self):
    result = get_center_delta(
        self.features, self.centers, self.targets, self.alpha)

    # size should match
    self.assertTrue(result.size() == self.centers.size())

    # for class 1
    class1_result = -((self.features[0] + self.features[2]) -
                      2 * self.centers[1]) / 3 * self.alpha
    self.assertEqual(3, torch.sum(result[1] == class1_result).item())

    # for class 3
    class3_result = -(self.features[1] - self.centers[3]) / 2 * self.alpha
    self.assertEqual(3, torch.sum(result[3] == class3_result).item())

    # others should all be zero
    sum_others = torch.sum(result[(0, 2), :]).item()
    self.assertEqual(0, sum_others)
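# Note: get_center_delta itself is not shown in this section. The sketch below
# is a hypothetical implementation written only to be consistent with the
# values asserted by the test above (each center j moves by
# delta_j = alpha * sum_{i: y_i == j} (c_j - x_i) / (1 + n_j)); it is not
# necessarily the project's actual code.
import torch


def get_center_delta(features, centers, targets, alpha):
    # features: (batch, feat_dim), centers: (num_classes, feat_dim),
    # targets: (batch,) class indices, alpha: scalar update rate
    deltas = torch.zeros_like(centers)
    counts = torch.zeros(centers.size(0), device=centers.device)
    for feature, target in zip(features, targets):
        deltas[target] += centers[target] - feature
        counts[target] += 1
    # dividing by (1 + n_j) keeps a zero delta for classes absent from the batch
    return alpha * deltas / (1 + counts).unsqueeze(1)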
def run_epoch(self, mode):
    if mode == 'train':
        dataloader = self.training_dataloader
        loss_recorder = self.training_losses
        self.model.train()
    else:
        dataloader = self.validation_dataloader
        loss_recorder = self.validation_losses
        self.model.eval()

    total_cross_entropy_loss = 0
    total_center_loss = 0
    total_loss = 0
    total_top1_matches = 0
    total_top3_matches = 0
    batch = 0

    with torch.set_grad_enabled(mode == 'train'):
        for images, targets, names in dataloader:
            batch += 1
            targets = torch.tensor(targets).to(device)
            images = images.to(device)
            centers = self.model.centers

            logits, features = self.model(images)
            cross_entropy_loss = torch.nn.functional.cross_entropy(
                logits, targets)
            center_loss = compute_center_loss(features, centers, targets)
            loss = self.lamda * center_loss + cross_entropy_loss

            print("[{}:{}] cross entropy loss: {:.8f} - center loss: "
                  "{:.8f} - total weighted loss: {:.8f}".format(
                      mode, self.current_epoch, cross_entropy_loss.item(),
                      center_loss.item(), loss.item()))

            # accumulate plain scalars so the autograd graphs of earlier
            # batches are not kept alive across the epoch
            total_cross_entropy_loss += cross_entropy_loss.item()
            total_center_loss += center_loss.item()
            total_loss += loss.item()

            if mode == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # make features untracked by autograd, or there will be
                # a memory leak when updating the centers
                center_deltas = get_center_delta(
                    features.data, centers, targets, self.alpha)
                self.model.centers = centers - center_deltas

            # accumulate top-1 / top-3 hits for the accuracy numbers
            total_top1_matches += self._get_matches(targets, logits, 1)
            total_top3_matches += self._get_matches(targets, logits, 3)

    center_loss = total_center_loss / batch
    cross_entropy_loss = total_cross_entropy_loss / batch
    loss = center_loss + cross_entropy_loss
    top1_acc = total_top1_matches / len(dataloader.dataset)
    top3_acc = total_top3_matches / len(dataloader.dataset)

    loss_recorder['center'].append(center_loss)
    loss_recorder['cross_entropy'].append(cross_entropy_loss)
    loss_recorder['together'].append(total_loss / batch)
    loss_recorder['top1acc'].append(top1_acc)
    loss_recorder['top3acc'].append(top3_acc)

    print("[{}:{}] finished. cross entropy loss: {:.8f} - "
          "center loss: {:.8f} - together: {:.8f} - "
          "top1 acc: {:.4f} % - top3 acc: {:.4f} %".format(
              mode, self.current_epoch, cross_entropy_loss,
              center_loss, loss, top1_acc * 100, top3_acc * 100))
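# Note: the _get_matches helper used above is not shown in this section. A
# plausible implementation, assuming it is a method on the same trainer class
# that returns the number of samples whose target appears among the top-n
# predicted classes for the batch:
def _get_matches(self, targets, logits, n=1):
    _, preds = logits.topk(n, dim=1)          # (batch, n) predicted class ids
    matches = preds.eq(targets.unsqueeze(1))  # broadcast compare against targets
    return matches.any(dim=1).sum().item()    # hits in this batch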
# decay the learning rate before fine-tuning
for param_group in optimizer.param_groups:
    param_group['lr'] = LR * 0.1

cnn.train()
for step, (b_x, b_y) in enumerate(train_loader):
    # mnist_vis(b_x, b_y)
    b_x = b_x.cuda()
    b_y = b_y.cuda()

    predictions, features = cnn(b_x)
    cel = loss_func(predictions, b_y)
    centers = cnn.centers
    # print(loss)
    # if epoch > 160:
    #     cel = ohkpm(cel, 24)

    center_loss = compute_center_loss(features, centers, b_y)
    loss = LAMBDA * center_loss + cel
    center_deltas = get_center_delta(features.data, centers, b_y, ALPHA)
    cnn.centers = centers - center_deltas
    # else:
    #     loss = ohkpm(loss, 24)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # loss = loss.cpu()
    # loss_avg.val = loss.data.numpy()
    if step % 500 == 0:
        print('Step: ', step,
              '| class loss: %.8f' % cel.item(),
              '| center loss: %.8f' % center_loss.item())

cnn.eval()
for step, (b_x, b_y) in enumerate(test_loader):
    b_x = b_x.cuda()
    b_y = b_y.cuda()
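# Note: the model, loaders and constants (cnn, LR, LAMBDA, ALPHA, loss_func,
# train_loader, test_loader) are defined elsewhere. The class below is a
# hypothetical MNIST-sized CNN, sketched only to show the interface the loop
# above relies on: forward() returns (predictions, features), and the class
# centers live on the model as a buffer that is updated manually rather than
# by the optimizer. It is not the author's actual network.
import torch
import torch.nn as nn


class CenterLossCNN(nn.Module):
    def __init__(self, num_classes=10, feature_dim=2):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, feature_dim),
        )
        self.classifier = nn.Linear(feature_dim, num_classes)
        # registered as a buffer (not a Parameter): the optimizer never sees
        # it, so assigning cnn.centers = centers - center_deltas is the only
        # way the centers move
        self.register_buffer('centers', torch.zeros(num_classes, feature_dim))

    def forward(self, x):
        features = self.backbone(x)
        predictions = self.classifier(features)
        return predictions, features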