def validation(self, epoch):
    """Evaluate the network on ``self.val_loader`` and checkpoint the best model.

    The value reported as "DC" is the sample-weighted mean of
    ``MulticlassJaccardLoss`` over the validation set. It is a *loss*
    (lower is better), so the model is saved whenever it drops below
    ``self.best_score``.

    Args:
        epoch: Current epoch index; stored in ``self.best_epoch`` when the
            score improves.

    Raises:
        ValueError: If the validation loader yields no samples.
    """
    self.net.eval()
    # Build the criterion once rather than once per batch (23 classes,
    # matching the 23-entry class weight vector used for training).
    criterion = MulticlassJaccardLoss(classes=list(range(23)))

    total_loss = 0.
    n_samples = 0
    # Evaluation needs no gradients; without no_grad() each batch's autograd
    # graph would stay alive through the running sum.
    with torch.no_grad():
        for imgs, probs, gts in self.val_loader:
            imgs = imgs.to(self.device)
            probs = probs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs, probs)
            # Reshape targets to (N, H, W) using the output's spatial size
            # (outputs are (N, C, H, W), as the cross-entropy loss in training requires).
            target = gts.reshape(-1, *outputs.shape[-2:])
            batch_loss = criterion(outputs, target)
            # NOTE(review): assumes the criterion returns a batch-mean scalar;
            # weighting by batch size makes the final value a per-sample mean.
            batch_size = imgs.size(0)
            total_loss += batch_loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('validation loader yielded no samples')

    DC = total_loss / n_samples
    score = DC
    print('[Validation] DC: %.4f' % (DC))

    # Save the best model. The score is actually a loss, so keep the smaller one.
    if score < self.best_score:
        self.best_score = score
        self.best_epoch = epoch
        print('Best %s model score: %.4f' % (self.model_type, self.best_score))
        torch.save(self.net.state_dict(), self.net_path)
def test(self):
    """Reload the best checkpoint, evaluate it on ``self.test_loader`` and log to CSV.

    The network is rebuilt with ``self.build_model()`` and its weights are
    loaded from ``self.net_path``. The sample-weighted mean
    ``MulticlassJaccardLoss`` over the test set (reported as "DC"; lower is
    better) is appended, together with the run's hyper-parameters, as one
    row of ``<csv_path>/result.csv``.

    Raises:
        ValueError: If the test loader yields no samples.
    """
    del self.net
    self.build_model()
    # map_location lets a checkpoint saved on one device load on another.
    self.net.load_state_dict(torch.load(self.net_path, map_location=self.device))
    self.net.eval()
    # Build the criterion once rather than once per batch (23 classes).
    criterion = MulticlassJaccardLoss(classes=list(range(23)))

    total_loss = 0.
    n_samples = 0
    # Evaluation needs no gradients; this also keeps the running sum a plain float.
    with torch.no_grad():
        for imgs, probs, gts in self.test_loader:
            imgs = imgs.to(self.device)
            probs = probs.to(self.device)
            gts = gts.round().long().to(self.device)
            outputs = self.net(imgs, probs)
            # Reshape targets to (N, H, W) using the output's spatial size.
            target = gts.reshape(-1, *outputs.shape[-2:])
            batch_loss = criterion(outputs, target)
            # NOTE(review): assumes the criterion returns a batch-mean scalar.
            batch_size = imgs.size(0)
            total_loss += batch_loss.item() * batch_size
            # Count each batch's samples exactly once.
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('test loader yielded no samples')

    DC = total_loss / n_samples

    # Append the result row; the context manager closes the file even on error.
    with open(os.path.join(self.csv_path, 'result.csv'), 'a',
              encoding='utf8', newline='') as f:
        wr = csv.writer(f)
        wr.writerow([
            self.model_type, DC, self.lr, self.best_epoch, self.num_epochs,
            self.num_epochs_decay, self.augmentation_prob, self.batch_size,
            self.comment,
        ])
def train(self, epoch):
    """Run one training epoch over ``self.train_loader``.

    Once ``epoch + 1`` exceeds ``num_epochs - num_epochs_decay``, the learning
    rate is reduced by ``lr / num_epochs_decay`` per epoch and written to
    every optimizer param group. The loss is class-weighted cross-entropy
    plus ``MulticlassJaccardLoss`` over all classes. The sample-weighted
    mean loss for the epoch is printed at the end.

    Args:
        epoch: Current epoch index (used for the decay schedule and logging).
    """
    self.net.train(True)

    # Decay learning rate
    if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
        self.decayed_lr -= (self.lr / float(self.num_epochs_decay))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.decayed_lr
        print('epoch{}: Decay learning rate to lr: {}.'.format(epoch, self.decayed_lr))

    # Per-class cross-entropy weights (one entry per class, 23 classes).
    # The weights and criteria are built once per epoch instead of per batch.
    weight = torch.tensor([
        1., 100., 130., 130., 1000., 1000., 700., 700., 900., 30., 30., 1000.,
        60., 60., 200., 200., 100., 100., 300., 300., 100., 55., 55.,
    ]).to(self.device)
    ce_criterion = nn.CrossEntropyLoss(weight=weight, reduction='mean')
    # The Jaccard term covers the same classes as the CE weights (previously
    # only classes 0-13, a leftover from a 14-class setup).
    jaccard_criterion = MulticlassJaccardLoss(classes=list(range(len(weight))))

    epoch_loss = 0.
    length = 0
    for img, prob, gt in tqdm(self.train_loader):
        img = img.to(self.device)
        prob = prob.to(self.device)
        gt = gt.round().long().to(self.device)

        self.optimizer.zero_grad()
        outputs = self.net(img, prob)
        # Targets reshaped to (N, H, W) using the output's spatial size.
        target = gt.reshape(-1, *outputs.shape[-2:])

        ce_loss = ce_criterion(outputs, target)
        dice_loss = jaccard_criterion(outputs, target)
        loss = ce_loss + dice_loss

        # reduction='mean', so scale by batch size to get a per-sample sum.
        epoch_loss += loss.item() * img.size(0)
        loss.backward()
        self.optimizer.step()
        length += img.size(0)

    # max() guards against an empty loader.
    print('EPOCH{}, Loss{}'.format(epoch, epoch_loss / max(length, 1)))