eval.add_batch(y_true, y_scores) confusion, accuracy, specificity, sensitivity, precision = eval.confusion_matrix( ) log = OrderedDict([('val_auc_roc', eval.auc_roc()), ('val_f1', eval.f1_score()), ('val_acc', accuracy), ('SE', sensitivity), ('SP', specificity)]) return dict_round(log, 6) if __name__ == '__main__': args = parse_args() save_path = join(args.outf, args.save) sys.stdout = Print_Logger(os.path.join(save_path, 'test_log.txt')) device = torch.device( "cuda" if torch.cuda.is_available() and args.cuda else "cpu") # net = models.UNetFamily.Dense_Unet(1,2).to(device) net = models.LadderNet(inplanes=1, num_classes=2, layers=3, filters=16).to(device) cudnn.benchmark = True # Load checkpoint print('==> Loading checkpoint...') checkpoint = torch.load(join(save_path, 'best_model.pth')) net.load_state_dict(checkpoint['net']) eval = Test(args) eval.inference(net) print(eval.evaluate()) eval.save_segmentation_result()
def main():
    """Train the segmentation network, checkpointing the latest and best model.

    All settings come from ``parse_args()``. Each epoch trains on the training
    loader, then validates either on the validation loader or, when
    ``args.val_on_test`` is set, on the test set via ``Test``. Metrics are
    passed to ``Logger.update``. ``latest_model.pth`` is written every epoch
    into ``join(args.outf, args.save)``. ``best_model.pth`` is written whenever
    the validation ``val_auc_roc`` exceeds the best value seen so far.

    If ``args.pre_trained`` is set, training resumes from
    ``join(args.outf, args.pre_trained, 'latest_model.pth')``: network,
    optimizer and (when stored) LR-scheduler state are restored, and training
    continues at the epoch after the saved one. If ``args.early_stop`` is not
    None, training stops after that many consecutive epochs without an
    AUC-ROC improvement.
    """
    setpu_seed(2021)
    args = parse_args()
    save_path = join(args.outf, args.save)
    save_args(args, save_path)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    cudnn.benchmark = True

    log = Logger(save_path)
    sys.stdout = Print_Logger(os.path.join(save_path, 'train_log.txt'))
    print('The computing device used is: ',
          'GPU' if device.type == 'cuda' else 'CPU')

    # Alternative model: models.UNetFamily.U_Net(1, 2).to(device)
    net = models.LadderNet(inplanes=1, num_classes=2, layers=3,
                           filters=16).to(device)
    print("Total number of parameters: " + str(count_parameters(net)))
    # Save the model structure to the tensorboard file (traced with a dummy
    # 1x1x48x48 input).
    log.save_graph(net, torch.randn((1, 1, 48, 48)).to(device))

    # Alternative loss: LossMulti(jaccard_weight=0, class_weights=np.array([0.5, 0.5]))
    criterion = CrossEntropyLoss2d()  # Initialize loss function

    # The optimizer and scheduler must exist *before* a checkpoint is
    # restored, because the resume branch loads their saved state. Previously
    # the optimizer was referenced before assignment, so resuming crashed.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.N_epochs, eta_min=0)

    # The training speed of this task is fast, so pre-training is not
    # recommended.
    if args.pre_trained is not None:
        print('==> Resuming from checkpoint..')
        # join() supplies the path separator that plain string concatenation
        # omitted when args.outf lacks a trailing slash. map_location allows a
        # checkpoint saved on GPU to be resumed on a CPU-only device.
        checkpoint = torch.load(
            join(args.outf, args.pre_trained, 'latest_model.pth'),
            map_location=device)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints have no scheduler state; in that case the cosine
        # schedule's epoch counter starts over, as it did before.
        if 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    train_loader, val_loader = get_dataloader(args)  # create dataloader

    if args.val_on_test:
        print(
            '\033[0;32m===============Validation on Testset!!!===============\033[0m'
        )
        val_tool = Test(args)

    # Best epoch so far and its validation AUC-ROC. Starting at 0.5 means a
    # model is only saved as "best" once its AUC-ROC exceeds 0.5.
    best = {'epoch': 0, 'AUC_roc': 0.5}
    trigger = 0  # Early-stop counter: epochs since the last improvement

    for epoch in range(args.start_epoch, args.N_epochs + 1):
        print('\nEPOCH: %d/%d --(learn_rate:%.6f) | Time: %s' %
              (epoch, args.N_epochs, optimizer.param_groups[0]['lr'],
               time.asctime()))

        # train stage
        train_log = train(train_loader, net, criterion, optimizer, device)

        # val stage
        if args.val_on_test:
            val_tool.inference(net)
            val_log = val_tool.val()
        else:
            val_log = val(val_loader, net, criterion, device)

        log.update(epoch, train_log, val_log)  # Add log information
        lr_scheduler.step()

        # Save checkpoint of latest and best model. The scheduler state is
        # stored as well so a resumed run continues the same LR schedule.
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
        }
        torch.save(state, join(save_path, 'latest_model.pth'))

        trigger += 1
        if val_log['val_auc_roc'] > best['AUC_roc']:
            print('\033[0;33mSaving best model!\033[0m')
            torch.save(state, join(save_path, 'best_model.pth'))
            best['epoch'] = epoch
            best['AUC_roc'] = val_log['val_auc_roc']
            trigger = 0
        print('Best performance at Epoch: {} | AUC_roc: {}'.format(
            best['epoch'], best['AUC_roc']))

        # early stopping
        if args.early_stop is not None and trigger >= args.early_stop:
            print("=> early stopping")
            break
        torch.cuda.empty_cache()