def bce_loss(self, outputs, labels, encode=False, row_start=None, row_end=None,
             col_start=None, col_end=None):
    """Return the mean BCE-with-logits loss over a rectangular slice of the batch.

    Args:
        outputs: raw logits; sliced as ``outputs[rows, cols]``.
        labels: targets laid out like ``outputs``, or labels to be one-hot
            encoded when ``encode`` is True.
        encode: when True, ``labels`` are first converted with
            ``utils._one_hot_encode(labels, self.n_classes, self.reverse_index,
            device=self.DEVICE)``.
        row_start, row_end, col_start, col_end: slice bounds applied to both
            tensors as ``[row_start:row_end, col_start:col_end]``; ``None``
            means the full extent along that axis.

    Returns:
        The scalar produced by ``nn.BCEWithLogitsLoss(reduction='mean')``.
    """
    if encode:
        labels = utils._one_hot_encode(labels, self.n_classes, self.reverse_index, device=self.DEVICE)

    # One index tuple reused for both tensors keeps the two slices in lockstep.
    window = (slice(row_start, row_end), slice(col_start, col_end))
    # The loss expects the target in the same floating dtype as the logits.
    targets = labels.type_as(outputs)

    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    return loss_fn(outputs[window], targets[window])
def l2_loss(self, outputs, labels, encode=False, row_start=None, row_end=None,
            col_start=None, col_end=None):
    """Return the mean-squared-error loss over a slice, post-processed by ``self.limit_loss``.

    Args:
        outputs: network outputs; sliced as ``outputs[rows, cols]``.
        labels: targets laid out like ``outputs``, or labels to be one-hot
            encoded when ``encode`` is True.
        encode: when True, ``labels`` are first converted with
            ``utils._one_hot_encode(labels, self.n_classes, self.reverse_index,
            device=self.DEVICE)``.
        row_start, row_end, col_start, col_end: slice bounds applied to both
            tensors as ``[row_start:row_end, col_start:col_end]``; ``None``
            means the full extent along that axis.

    Returns:
        ``self.limit_loss`` applied to the ``nn.MSELoss(reduction='mean')`` value.
    """
    if encode:
        labels = utils._one_hot_encode(labels, self.n_classes, self.reverse_index, device=self.DEVICE)

    rows = slice(row_start, row_end)
    cols = slice(col_start, col_end)
    # Cast before slicing so the target matches the outputs' dtype.
    mse = nn.MSELoss(reduction='mean')(outputs[rows, cols], labels.type_as(outputs)[rows, cols])
    return self.limit_loss(mse)
def update_representation(self, dataset, train_dataset_big, new_classes):
    """Run one incremental-learning step on a newly arrived group of classes.

    Outline (numbering follows the original iCaRL-style comments):
      1-2. the current train subset and its new classes are chosen by the
           caller and passed in as ``dataset`` / ``new_classes``;
      3.   grow the classifier via ``self.increment_classes(len(new_classes))``;
      4.   concatenate ``dataset`` with the exemplar dataset built from
           ``train_dataset_big`` (when any exemplars exist);
      5.   if exemplar sets exist, deep-copy the network before training so its
           sigmoid outputs can serve as distillation targets for old classes;
      6.   train with SGD + MultiStepLR for ``self.NUM_EPOCHS`` epochs.

    Afterwards ``self.net`` holds a copy of the trained network and
    ``self.feature_extractor`` holds a copy whose ``fc`` is replaced by an empty
    ``nn.Sequential()``, so calling it skips the final ``fc`` layer.

    Args:
        dataset: training subset for this step; the loader unpacks each batch
            as ``(_, images, labels)``.
        train_dataset_big: dataset passed to ``self.build_exemplars_dataset``.
        new_classes: classes introduced in this step; only its length is used.
    """
    self.increment_classes(len(new_classes))

    exemplars_dataset = self.build_exemplars_dataset(train_dataset_big)
    if len(exemplars_dataset) > 0:
        # torch.utils.data.ConcatDataset takes a single sequence of datasets;
        # passing two positional arguments raises TypeError.
        # NOTE(review): if ConcatDataset here is a project-local class with a
        # (a, b) signature, revert to the two-argument call.
        augmented_dataset = ConcatDataset([dataset, exemplars_dataset])
    else:
        augmented_dataset = dataset  # first iteration: no exemplars yet

    # NOTE(review): right-hand side was missing in the source; ``self.net`` is
    # assumed from the write-back at the end of this method — confirm.
    net = self.net

    optimizer = optim.SGD(net.parameters(), lr=self.LR, weight_decay=self.WEIGHT_DECAY,
                          momentum=self.MOMENTUM)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.MILESTONES,
                                               gamma=self.GAMMA, last_epoch=-1)
    criterion = utils.getLossCriterion()

    # The original bare expression ``cudnn.benchmark`` only read the flag and
    # had no effect; assignment is what enables cuDNN algorithm autotuning.
    cudnn.benchmark = True

    # NOTE(review): right-hand side was missing in the source; moving the
    # network to self.DEVICE is assumed — confirm.
    net = net.to(self.DEVICE)

    loader = DataLoader(augmented_dataset, batch_size=self.BATCH_SIZE, shuffle=True,
                        num_workers=4, drop_last=True)

    # exemplar_sets is not modified during training, so evaluate this once.
    has_old_classes = len(self.exemplar_sets) > 0
    # Snapshot taken after increment_classes but before any training step.
    # NOTE(review): old_net is never put in eval() mode, so any BatchNorm /
    # Dropout layers run in training mode when producing targets — confirm.
    old_net = copy.deepcopy(net) if has_old_classes else None

    loss = None
    for epoch in range(self.NUM_EPOCHS):
        print("NUM_EPOCHS: ", epoch, "/", self.NUM_EPOCHS)
        for _, images, labels in loader:
            # NOTE(review): right-hand sides were missing in the source;
            # moving the batch to self.DEVICE is assumed — confirm.
            images = images.to(self.DEVICE)
            labels = labels.to(self.DEVICE)

            net.train()
            # PyTorch accumulates gradients across backward passes; reset them.
            optimizer.zero_grad()

            outputs = net(images)
            labels_one_hot = utils._one_hot_encode(labels, self.n_classes, self.reverse_index,
                                                   device=self.DEVICE)
            labels_one_hot = labels_one_hot.type_as(outputs)

            if not has_old_classes:
                # First step: classification loss on the new classes only.
                loss = criterion(outputs, labels_one_hot)
            else:
                # Old-class columns use the snapshot's sigmoid outputs as
                # targets (distillation); new-class columns use the one-hot
                # ground truth. The snapshot is a fixed target, so no autograd
                # graph is needed through it.
                with torch.no_grad():
                    out_old = torch.sigmoid(old_net(images))
                target = torch.cat((out_old[:, :self.n_known],
                                    labels_one_hot[:, self.n_known:]), dim=1)
                loss = criterion(outputs, target)
                self._print_loss_variants(criterion, outputs, labels, labels_one_hot,
                                          out_old, loss)

            loss.backward()
            optimizer.step()

        # NOTE(review): the source layout was collapsed; one scheduler step per
        # epoch is assumed (MILESTONES read as epoch indices) — confirm.
        scheduler.step()
        if loss is not None:  # the loader may yield no batches (drop_last=True)
            print("LOSS: ", loss.item())

    # NOTE(review): assignment target was missing in the source; ``self.net``
    # is assumed — confirm.
    self.net = copy.deepcopy(net)
    self.feature_extractor = copy.deepcopy(net)
    self.feature_extractor.fc = nn.Sequential()

    # Drop local references, including the snapshot, before releasing cached
    # CUDA memory; otherwise empty_cache() cannot free old_net's tensors.
    del net, old_net
    torch.cuda.empty_cache()


def _print_loss_variants(self, criterion, outputs, labels, labels_one_hot, out_old, loss):
    """Print alternative loss formulations alongside the loss actually optimized.

    Diagnostic only: nothing computed here is used for the parameter update.
    Runs under ``torch.no_grad()`` so these extra losses build no autograd
    graphs; the printed values are unaffected.

    Args:
        criterion: the training criterion from ``utils.getLossCriterion()``.
        outputs: current network outputs for the batch.
        labels: labels for the batch, as yielded by the loader.
        labels_one_hot: ``labels`` one-hot encoded and cast to ``outputs``' dtype.
        out_old: sigmoid outputs of the pre-training network snapshot.
        loss: the loss value that will be back-propagated.
    """
    with torch.no_grad():
        print('original loss', loss.item())

        loss1 = criterion(outputs[:, self.n_known:], labels_one_hot[:, self.n_known:])
        loss2 = criterion(outputs[:, :self.n_known], out_old[:, :self.n_known])
        # Weight of the old-class (distillation) term in the rebalanced variants.
        alpha = self.n_known / self.n_classes
        splittedloss = loss1 + loss2
        splittedloss2 = (1 - alpha) * loss1 + alpha * loss2
        print('summed loss1', splittedloss.item(), loss1.item(), loss2.item())
        print('summed loss2', splittedloss2.item(), (1 - alpha) * loss1.item(), alpha * loss2.item())

        donDistLoss = sum(criterion(outputs[:, y], out_old[:, y]) for y in range(self.n_known))
        print('donlee dist loss', donDistLoss.item())
        CE = nn.CrossEntropyLoss()
        # NOTE(review): CrossEntropyLoss indexes columns by ``labels`` directly,
        # whereas the one-hot path remaps through self.reverse_index; this only
        # works if labels are already column indices — confirm.
        donClassLoss = CE(outputs, labels)
        print('donlee class loss', donClassLoss.item())
        print('donlee loss (CE + BCE)', (donClassLoss + donDistLoss).item())
        print('donlee loss (CE + BCE) rebalanced',
              ((1 - alpha) * donClassLoss + alpha * donDistLoss).item())

        l2_loss1 = self.l2_class_loss(outputs, labels)
        l2_loss2 = self.l2_dist_loss(outputs, out_old)
        l2_loss = (1 - alpha) * l2_loss1 + alpha * l2_loss2
        print('L2 loss', (l2_loss1 + l2_loss2).item(), l2_loss1.item(), l2_loss2.item())
        print('L2 loss rebalanced', l2_loss.item(), (1 - alpha) * l2_loss1.item(),
              alpha * l2_loss2.item())

        self_loss1 = self.class_loss(outputs, labels, col_start=self.n_known)
        self_loss2 = self.dist_loss(outputs, out_old, col_end=self.n_known)
        print('Self losses', (self_loss1 + self_loss2).item(), self_loss1.item(),
              self_loss2.item())
        print()