def calc_loss(pred, target, metrics, dataset, phase='train', bce_weight=0.3):
    """Blend BCE-with-logits and Dice into one loss and record batch metrics.

    Args:
        pred: Raw network output; fed unchanged to
            ``F.binary_cross_entropy_with_logits``.
        target: Ground-truth tensor; ``target.size(0)`` scales every metric.
        metrics: Mutable mapping updated in place (see below).
        dataset: ``"roof"`` activates ``pred`` with a sigmoid; any other
            value activates it with ``torch.exp``.
        phase: ``'test'`` binarises the activated prediction at 0.5 and
            overwrites the metric entries; any other value accumulates.
        bce_weight: Weight of the BCE term; Dice gets ``1 - bce_weight``.

    Returns:
        ``bce * bce_weight + dice * (1 - bce_weight)``.

    Metric keys written:
        * test phase: ``bce``, ``loss``, ``dice``, ``jaccard`` (assigned).
        * otherwise: ``bce`` (assigned), ``loss``, ``dice_loss``,
          ``jaccard_loss`` (added to the existing values).
    """
    # BCE is computed on the raw predictions regardless of the dataset.
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = torch.sigmoid(pred) if dataset == "roof" else torch.exp(pred)

    is_test = phase == 'test'
    if is_test:
        # Hard mask at 0.5 (the original author noted 0.55 scored slightly
        # better).
        pred = (pred > 0.5).float()

    dice = dice_loss(pred, target)
    jaccard_loss = metric_jaccard(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    def _batch_scaled(value):
        # Leave the autograd graph, move to host memory and weight the value
        # by the number of samples in this batch.
        return value.data.cpu().numpy() * target.size(0)

    # NOTE(review): 'bce' is overwritten in both phases while the other
    # training metrics accumulate -- confirm this is intended.
    metrics['bce'] = _batch_scaled(bce)
    if is_test:
        metrics['loss'] = _batch_scaled(loss)
        # NOTE(review): evaluates as 1 - (value * batch_size), not
        # (1 - value) * batch_size -- confirm against the consumer.
        metrics['dice'] = 1 - _batch_scaled(dice)
        metrics['jaccard'] = 1 - _batch_scaled(jaccard_loss)
    else:
        metrics['loss'] += _batch_scaled(loss)
        metrics['dice_loss'] += _batch_scaled(dice)
        metrics['jaccard_loss'] += _batch_scaled(jaccard_loss)

    return loss
def test_step(self, batch, batch_idx):
    """Evaluate one test batch and return its Dice loss and scores.

    Args:
        batch: Raw batch, unpacked by ``self.prepare_batch`` into
            ``(inputs, targets)``.
        batch_idx: Index of the batch; every 50th batch (except batch 0)
            has its first sample passed to ``log_all_info``.

    Returns:
        dict with the keys ``test_step_loss``, ``test_step_dice``,
        ``test_step_IoU``, ``test_step_sensitivity`` and
        ``test_step_specificity``.
    """
    inputs, targets = self.prepare_batch(batch)

    logits = self(inputs)
    # NOTE(review): the requested size is the logits' own trailing
    # dimensions, so this call does not change the resolution. Possibly the
    # targets' size was intended -- confirm before changing.
    logits = F.interpolate(logits, size=logits.size()[2:])
    probs = torch.sigmoid(logits)

    if batch_idx != 0 and batch_idx % 50 == 0:
        # chunk(batch_size, 0)[0] keeps only the first sample of the batch
        # (with its leading dimension of length 1).
        sample_input = inputs.chunk(inputs.size()[0], 0)[0]
        sample_target = targets.chunk(targets.size()[0], 0)[0]
        sample_prob = probs.chunk(logits.size()[0], 0)[0]
        log_all_info(self, sample_input, sample_target, sample_prob,
                     batch_idx, "testing")

    loss = dice_loss(probs, targets)
    # Scores are computed once; the original evaluated get_score twice on
    # the same tensors and discarded the first result.
    dice, iou, sensitivity, specificity = get_score(probs, targets)

    return {
        'test_step_loss': loss,
        'test_step_dice': dice,
        'test_step_IoU': iou,
        'test_step_sensitivity': sensitivity,
        'test_step_specificity': specificity
    }
def training_step(self, batch, batch_idx):
    """Run one optimisation step's forward pass and return the Dice loss.

    Args:
        batch: Raw batch, unpacked by ``self.prepare_batch`` into
            ``(inputs, targets)``.
        batch_idx: Index of the batch inside the current epoch.

    Returns:
        dict with ``'loss'`` (the Dice loss) and ``'log'`` (a dict holding
        ``train_loss``, ``train_IoU`` and ``train_dice``).

    The first sample of the batch is sent to ``log_all_info`` whenever the
    batch is not the first one and either the batch index is a multiple of
    100 or, from epoch 1 onwards, the batch Dice score is below 0.5.
    """
    inputs, targets = self.prepare_batch(batch)
    logits = self(inputs)
    probs = torch.sigmoid(logits)
    dice, iou, _, _ = get_score(probs, targets)

    log_sample = False
    if batch_idx != 0:
        log_sample = (
            (self.current_epoch >= 1 and dice.item() < 0.5)
            or batch_idx % 100 == 0
        )

    if log_sample:
        # Splitting dimension 0 into batch-size pieces and taking piece 0
        # isolates the first sample of the batch.
        first_input = torch.chunk(inputs, inputs.size(0), 0)[0]
        first_target = torch.chunk(targets, targets.size(0), 0)[0]
        first_prob = torch.chunk(probs, logits.size(0), 0)[0]
        sample_dice, _, _, _ = get_score(first_prob.unsqueeze(0),
                                         first_target.unsqueeze(0))
        log_all_info(self, first_input, first_target, first_prob, batch_idx,
                     "training", sample_dice.item())

    loss = dice_loss(probs, targets)
    return {
        'loss': loss,
        "log": {
            "train_loss": loss,
            "train_IoU": iou,
            "train_dice": dice
        },
    }
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Blend BCE-with-logits and Dice into one loss and accumulate metrics.

    Args:
        pred: Raw network output (logits).
        target: Ground-truth tensor; ``target.size(0)`` weights the metrics.
        metrics: Mutable mapping; the entries ``bce``, ``dice`` and ``loss``
            must already exist because they are updated with ``+=``.
        bce_weight: Weight of the BCE term; Dice gets ``1 - bce_weight``.

    Returns:
        ``bce * bce_weight + dice * (1 - bce_weight)``.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # Tensor.sigmoid() replaces the deprecated F.sigmoid, which emits a
    # warning on many PyTorch releases.
    probs = pred.sigmoid()
    dice = dice_loss(probs, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # detach() is the supported way to leave the autograd graph; each value
    # is weighted by the number of samples in the batch.
    batch_size = target.size(0)
    metrics['bce'] += bce.detach().cpu().numpy() * batch_size
    metrics['dice'] += dice.detach().cpu().numpy() * batch_size
    metrics['loss'] += loss.detach().cpu().numpy() * batch_size

    return loss
def validation_step(self, batch, batch_id):
    """Evaluate one validation batch and return its Dice loss and scores.

    Args:
        batch: Raw batch, unpacked by ``self.prepare_batch`` into
            ``(inputs, targets)``.
        batch_id: Index of the batch (unused here).

    Returns:
        dict with the keys ``val_step_loss``, ``val_step_dice``,
        ``val_step_IoU``, ``val_step_sensitivity`` and
        ``val_step_specificity``.
    """
    inputs, targets = self.prepare_batch(batch)
    probs = torch.sigmoid(self(inputs))

    step_loss = dice_loss(probs, targets)
    dice, iou, sensitivity, specificity = get_score(probs, targets)

    results = {'val_step_loss': step_loss}
    results['val_step_dice'] = dice
    results['val_step_IoU'] = iou
    results["val_step_sensitivity"] = sensitivity
    results["val_step_specificity"] = specificity
    return results
def epoch_dice_loss(self, **kw):
    """Train for one pass over ``kw['data_loader']`` using a Dice loss.

    Keyword Args:
        data_loader: Iterable of dicts with ``'inputs'`` and ``'labels'``.
        epoch: Current epoch number (only used for logging).

    For every batch the model is stepped once, precision / recall / F1 /
    accuracy are computed for that batch alone (the accumulator is reset
    first), a console line is printed every ``self.log_frequency`` batches
    and one CSV line is flushed to ``self.train_logger``.
    """
    score_acc = ScoreAccumulator()
    loader = kw['data_loader']
    # Candidate beta values; one is drawn at random for each batch.
    beta_choices = np.arange(1, 2, 0.1).tolist()

    running_loss = 0.0
    for i, data in enumerate(loader, 1):
        inputs = data['inputs'].to(self.device).float()
        labels = data['labels'].to(self.device).long()

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        # Index of the largest value along dimension 1.
        predicted = torch.max(outputs, 1)[1]

        # The Dice loss is evaluated on index 1 of dimension 1 only.
        loss = dice_loss(outputs[:, 1, :, :], labels,
                         beta=rd.choice(beta_choices))
        loss.backward()
        self.optimizer.step()

        current_loss = loss.item()
        running_loss += current_loss

        p, r, f1, a = score_acc.reset().add_tensor(predicted,
                                                   labels).get_prfa()

        if i % self.log_frequency == 0:
            print(
                'Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f'
                % (kw['epoch'], self.epochs, i, len(loader),
                   running_loss / self.log_frequency, p, r, f1, a))
            running_loss = 0.0

        row = [0, kw['epoch'], i, p, r, f1, a, current_loss]
        self.flush(self.train_logger, ','.join(map(str, row)))
def epoch_dice_loss(self, **kw):
    """Run one pass over ``kw['data_loader']`` with a Dice loss.

    Keyword Args:
        data_loader: Iterable of dicts with ``'inputs'`` and ``'labels'``.
        epoch: Current epoch number (only used for logging).
        logger: Destination handed to ``self.flush`` for the CSV lines.
        score_acc: ``ScoreAccumulator`` to update; required when the model
            is not in training mode.

    When ``self.model.training`` is true the optimizer is stepped on every
    batch and the scores are reset per batch. Otherwise no optimisation
    happens and scores keep accumulating in the supplied ``score_acc``.
    """
    if self.model.training:
        score_acc = ScoreAccumulator()
    else:
        score_acc = kw.get('score_acc')
    assert isinstance(score_acc, ScoreAccumulator)

    loader = kw['data_loader']
    # Candidate beta values; one is drawn at random for each batch.
    beta_choices = np.arange(1, 2, 0.1).tolist()

    running_loss = 0.0
    for i, data in enumerate(loader, 1):
        inputs = data['inputs'].to(self.device).float()
        labels = data['labels'].to(self.device).long()

        if self.model.training:
            self.optimizer.zero_grad()

        outputs = F.softmax(self.model(inputs), 1)
        # Index of the largest probability along dimension 1.
        predicted = torch.max(outputs, 1)[1]

        # The Dice loss is evaluated on index 1 of dimension 1 only.
        loss = dice_loss(outputs[:, 1, :, :], labels,
                         beta=rd.choice(beta_choices))
        if self.model.training:
            loss.backward()
            self.optimizer.step()

        current_loss = loss.item()
        running_loss += current_loss

        if self.model.training:
            # Training scores describe the current batch only.
            score_acc.reset()
        p, r, f1, a = score_acc.add_tensor(predicted, labels).get_prfa()

        if i % self.log_frequency == 0:
            print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' % (
                kw['epoch'], self.epochs, i, len(loader),
                running_loss / self.log_frequency, p, r, f1, a))
            running_loss = 0.0

        row = [0, kw['epoch'], i, p, r, f1, a, current_loss]
        self.flush(kw['logger'], ','.join(map(str, row)))
def train(epoch, model, criterion, optimizer, writer, opt, data_loader):
    """Train for one epoch with gradient accumulation, then log and checkpoint.

    Args:
        epoch: Epoch number, used in log lines, scalar steps and file names.
        model: Mapping with the entries ``'densesharp'``, ``'feature2SAT'``,
            ``'p_encoder'`` and ``'sat'``. Only ``'densesharp'`` and
            ``'sat'`` are switched to training mode and checkpointed here.
        criterion: Loss applied to the ``'sat'`` output and ``target``.
        optimizer: Optimizer, stepped every ``divide_batch`` batches.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        opt: Options object; ``n_sat``, ``batch_size``, ``write_log`` and
            ``file_name`` are read.
        data_loader: Iterable yielding ``(data, target, segment_target)``
            and exposing a sized ``dataset`` attribute.

    ``divide_batch`` is not a parameter; it is looked up in the enclosing
    (module) scope.

    Side effects: prints progress, optionally appends to
    ``./log/<file_name>.txt`` and writes ROC points under
    ``./data/<file_name>/``, emits scalars through ``writer`` and saves two
    state dicts under ``./model_saved/<file_name>/``.
    """
    model['densesharp'].train()
    model['sat'].train()

    num_samples = len(data_loader.dataset)
    total_cross_loss = 0
    total_dice_loss = 0
    eval_loss = 0
    correct = 0
    positive_target = np.zeros(num_samples)
    positive_score = np.zeros(num_samples)

    optimizer.zero_grad()
    for batch_idx, (data, target, segment_target) in enumerate(tqdm(data_loader)):
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
            segment_target = segment_target.cuda()

        # forward the models
        _, features, segment_output = model['densesharp'](data)
        n_segment = segment2n_segment(segment_target, opt.n_sat)
        batch_features = model['feature2SAT'](features, n_segment)
        batch_features = batch_features + model['p_encoder'](n_segment)
        output_SAT = model['sat'](batch_features)

        # get the loss
        indiv_cross_loss = criterion(output_SAT, target)
        indiv_dice_loss = dice_loss(segment_output, segment_target)
        loss = indiv_cross_loss + 0.2 * indiv_dice_loss
        loss.backward()

        total_cross_loss += indiv_cross_loss.item()
        total_dice_loss += indiv_dice_loss.item()
        eval_loss += loss.item()

        detached_output = output_SAT.detach()
        pred = detached_output.max(1, keepdim=True)[1]
        # .item() keeps `correct` a plain int; accumulating the 0-dim tensor
        # made the epoch log line render it as "tensor(N)".
        correct += pred.eq(target.view_as(pred)).sum().item()

        # Record label and class-0 probability per sample for the ROC curve.
        # The tensors are copied to the host once per batch instead of once
        # per sample.
        labels = target.detach().cpu().reshape(-1).tolist()
        class0_scores = F.softmax(detached_output, dim=1)[:, 0].cpu().tolist()
        offset = opt.batch_size * batch_idx
        for i, (label, score) in enumerate(zip(labels, class0_scores)):
            positive_target[offset + i] = label
            positive_score[offset + i] = score

        # NOTE(review): gradients accumulated after the last full window of
        # `divide_batch` batches are never applied; they are cleared by the
        # zero_grad() at the start of the next call.
        if (batch_idx + 1) % divide_batch == 0:
            optimizer.step()
            optimizer.zero_grad()

        if batch_idx % 10 == 0:
            log_tmp = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tCrossEntropyLoss: {:.6f}\tDiceLoss: {:.6f}\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / len(data_loader),
                indiv_cross_loss.item() / opt.batch_size,
                indiv_dice_loss.item(), loss.item())
            print(log_tmp)
            if opt.write_log:
                with open("./log/{}.txt".format(opt.file_name), "a") as log:
                    log.write('{}\n'.format(log_tmp))

    # Cross-entropy and total loss are divided by the sample count, the Dice
    # loss by the (fractional) number of batches.
    eval_loss /= num_samples
    total_cross_loss /= num_samples
    total_dice_loss = total_dice_loss / (num_samples / opt.batch_size)
    accuracy = 100. * float(correct) / num_samples

    log_tmp = 'Eval Epoch:{} CrossEntropyLoss:{:.6f} DiceLoss:{:.6f} Average loss: {:.6f}, Accuracy: {}/{} ({:.6f}%)'.format(
        epoch, total_cross_loss, total_dice_loss, eval_loss, correct,
        num_samples, accuracy)
    print(log_tmp)

    # draw the ROC curve (class 0 is treated as the positive label)
    fpr, tpr, _ = roc_curve(positive_target, positive_score, pos_label=0)
    roc_auc = auc(fpr, tpr)
    print('Train_AUC = %.8f' % roc_auc)

    if opt.write_log:
        with open(
                "./data/{}/epoch{}_train_fpr.json".format(
                    opt.file_name, epoch), "w") as f:
            json.dump(fpr.tolist(), f)
        with open(
                "./data/{}/epoch{}_train_tpr.json".format(
                    opt.file_name, epoch), "w") as f:
            json.dump(tpr.tolist(), f)
        with open("./log/{}.txt".format(opt.file_name), "a") as log:
            log.write('{}\n'.format(log_tmp))
            log.write('Train_AUC = %.8f\n' % roc_auc)

    writer.add_scalar('Train_AUC', roc_auc, epoch)
    writer.add_scalar('Train_CrossEntropyLoss', total_cross_loss, epoch)
    writer.add_scalar('Train_DiceLoss', total_dice_loss, epoch)
    writer.add_scalar('Eval_loss', eval_loss, epoch)
    writer.add_scalar('Eval_accuracy', accuracy, epoch)

    torch.save(
        model['densesharp'].state_dict(),
        './model_saved/{}/densesharp_{}_epoch_{}_dict.pkl'.format(
            opt.file_name, opt.file_name, epoch))
    torch.save(
        model['sat'].state_dict(),
        './model_saved/{}/sat_{}_epoch_{}_dict.pkl'.format(
            opt.file_name, opt.file_name, epoch))
def test(epoch, model, criterion, writer, opt, data_loader, dichotomy=False):
    """Evaluate the models on ``data_loader`` and log loss, accuracy and AUC.

    Args:
        epoch: Epoch number, used in log lines, scalar steps and file names.
        model: Mapping with the entries ``'densesharp'``, ``'feature2SAT'``,
            ``'p_encoder'`` and ``'sat'``. Only ``'densesharp'`` and
            ``'sat'`` are switched to evaluation mode here.
        criterion: Loss applied to the ``'sat'`` output and ``target``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        opt: Options object; ``n_sat``, ``batch_size``, ``write_log`` and
            ``file_name`` are read.
        data_loader: Iterable yielding ``(data, target, segment_target)``
            and exposing a sized ``dataset`` attribute.
        dichotomy: When true, per-sample labels and class-0 probabilities
            are recorded for the ROC curve and, if ``opt.write_log`` is set,
            the ROC points are written to JSON files.

    Side effects: prints a summary, optionally appends to
    ``./log/<file_name>.txt`` and writes ROC points under
    ``./data/<file_name>/``, and emits scalars through ``writer``.
    """
    model['densesharp'].eval()
    model['sat'].eval()

    num_samples = len(data_loader.dataset)
    total_cross_loss = 0
    total_dice_loss = 0
    test_loss = 0
    correct = 0
    positive_target = np.zeros(num_samples)
    positive_score = np.zeros(num_samples)

    # Nothing is back-propagated during evaluation, so autograd tracking is
    # switched off to avoid building graphs for every batch.
    with torch.no_grad():
        for index, (data, target, segment_target) in enumerate(tqdm(data_loader)):
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
                segment_target = segment_target.cuda()

            # forward the models
            _, features, segment_output = model['densesharp'](data)
            n_segment = segment2n_segment(segment_target, opt.n_sat)
            batch_features = model['feature2SAT'](features, n_segment)
            batch_features = batch_features + model['p_encoder'](n_segment)
            output_SAT = model['sat'](batch_features)

            # sum up batch loss
            indiv_cross_loss = criterion(output_SAT, target)
            indiv_dice_loss = dice_loss(segment_output, segment_target)
            loss = indiv_cross_loss + 0.2 * indiv_dice_loss
            total_cross_loss += indiv_cross_loss.item()
            total_dice_loss += indiv_dice_loss.item()
            test_loss += loss.item()

            # get the index of the max output along dimension 1
            pred = output_SAT.max(1, keepdim=True)[1]
            # .item() keeps `correct` a plain int; accumulating the 0-dim
            # tensor made the log line render it as "tensor(N)".
            correct += pred.eq(target.view_as(pred)).sum().item()

            if dichotomy:
                # Copy to the host once per batch instead of once per sample.
                labels = target.cpu().reshape(-1).tolist()
                class0_scores = F.softmax(output_SAT, dim=1)[:, 0].cpu().tolist()
                offset = opt.batch_size * index
                for i, (label, score) in enumerate(zip(labels, class0_scores)):
                    positive_target[offset + i] = label
                    positive_score[offset + i] = score

    # Cross-entropy and total loss are divided by the sample count, the Dice
    # loss by the (fractional) number of batches.
    total_cross_loss /= num_samples
    total_dice_loss /= (num_samples / opt.batch_size)
    test_loss /= num_samples
    accuracy = 100. * float(correct) / num_samples

    log_tmp = 'Test epoch:{} CrossEntropyLoss:{:.6f} DiceLoss:{:.6f} Average loss:{:.6f}, Accuracy: {}/{} ({:.6f}%)'.format(
        epoch, total_cross_loss, total_dice_loss, test_loss, correct,
        num_samples, accuracy)
    print(log_tmp)

    # draw the ROC curve (class 0 is treated as the positive label)
    # NOTE(review): when `dichotomy` is false both arrays are still all
    # zeros, yet the curve and AUC are computed and logged anyway.
    fpr, tpr, _ = roc_curve(positive_target, positive_score, pos_label=0)
    roc_auc = auc(fpr, tpr)
    print('AUC = %.8f' % roc_auc)

    if opt.write_log:
        with open("./log/{}.txt".format(opt.file_name), "a") as log:
            log.write('{}\n'.format(log_tmp))

    writer.add_scalar('Test_AUC', roc_auc, epoch)
    writer.add_scalar('Test_loss', test_loss, epoch)
    writer.add_scalar('Test_CrossEntropyLoss', total_cross_loss, epoch)
    writer.add_scalar('Test_DiceLoss', total_dice_loss, epoch)
    writer.add_scalar('Test_accuracy', accuracy, epoch)

    if dichotomy and opt.write_log:
        with open("./log/{}.txt".format(opt.file_name), "a") as log:
            log.write('Test_AUC = %.8f\n' % roc_auc)
        with open(
                "./data/{}/epoch{}_test_fpr.json".format(opt.file_name, epoch),
                "w") as f:
            json.dump(fpr.tolist(), f)
        with open(
                "./data/{}/epoch{}_test_tpr.json".format(opt.file_name, epoch),
                "w") as f:
            json.dump(tpr.tolist(), f)