def val_seg(args, model):
    """Evaluate ``model`` on the ``'val'`` split described by ``args``.

    Builds the validation transform pipeline: optional resize, optional
    crop, label transform, tensor conversion, and normalisation with the
    mean/std read from ``<list_dir>/info.json``. Wraps the split in a
    batch-size-1, unshuffled ``DataLoader`` and delegates scoring to ``val``.

    Args:
        args: Parsed options. Reads ``list_dir``, ``data_dir``, ``resize``,
            ``crop_size`` and ``workers``.
        model: Network forwarded unchanged to ``val``.

    Returns:
        The 9-tuple ``(dice_avg, dice_1, dice_2, dice_3, dice_list, auc,
        auc_1, auc_2, auc_3)`` obtained from ``val``.
    """
    # Close the stats file deterministically instead of leaking the handle
    # until garbage collection.
    with open(osp.join(args.list_dir, 'info.json'), 'r') as f:
        info = json.load(f)
    normalize = dt.Normalize(mean=info['mean'], std=info['std'])

    t = []
    if args.resize:
        t.append(dt.Resize(args.resize))
    if args.crop_size:
        # NOTE(review): judging by its name, RandomCrop likely picks a random
        # offset, which would make validation scores vary between runs --
        # confirm this is intended for the val split.
        t.append(dt.RandomCrop(args.crop_size))
    t.extend([dt.Label_Transform(), dt.ToTensor(), normalize])

    dataset = SegList(args.data_dir, 'val', dt.Compose(t), list_dir=args.list_dir)
    val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False,
    )

    cudnn.benchmark = True
    # Explicit unpacking keeps the 9-value contract enforced at this boundary.
    dice_avg, dice_1, dice_2, dice_3, dice_list, auc, auc_1, auc_2, auc_3 = val(
        args, val_loader, model)
    return dice_avg, dice_1, dice_2, dice_3, dice_list, auc, auc_1, auc_2, auc_3
def _load_seg_net(name, checkpoint_path):
    """Build ``name`` via ``net_builder``, wrap it in DataParallel on GPU, and load weights.

    Args:
        name: Architecture identifier forwarded to ``net_builder``.
        checkpoint_path: File passed to ``torch.load``. The loaded object must
            contain a ``'state_dict'`` entry.

    Returns:
        The ``nn.DataParallel``-wrapped network with the checkpoint weights loaded.
    """
    net = nn.DataParallel(net_builder(name)).cuda()
    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint['state_dict'])
    return net


def test_seg(args, result_path, fusion_models=None):
    """Run inference on the ``'test'`` split and hand results to ``test``/``test_fusion``.

    Args:
        args: Parsed options. Reads ``fusion``, ``seg_name``, ``seg_path``,
            ``list_dir``, ``data_dir``, ``resize`` and ``workers``.
        result_path: Output location forwarded to ``test`` or ``test_fusion``.
        fusion_models: Optional sequence of ``(net_name, checkpoint_path)``
            pairs used when ``args.fusion`` is true. ``None`` (the default)
            selects the two checkpoints this function has always used.
    """
    print('Loading test model ...')
    if args.fusion:
        if fusion_models is None:
            fusion_models = (
                ('unet_nested',
                 'result/ori_3D/train/unet_nested_nopre_mix_33_NEW_multi_2_another/checkpoint/model_best.pth.tar'),
                ('unet',
                 'result/ori_3D/train/unet_nopre_mix_3_NEW_multi_2/checkpoint/model_best.pth.tar'),
            )
        # Ensemble: a list of loaded networks, built in the order given.
        net = [_load_seg_net(name, path) for name, path in fusion_models]
    else:
        net = _load_seg_net(args.seg_name, args.seg_path)

    # Close the stats file deterministically instead of leaking the handle.
    with open(osp.join(args.list_dir, 'info_test.json'), 'r') as f:
        info = json.load(f)
    normalize = st.Normalize(mean=info['mean'], std=info['std'])

    t = []
    if args.resize:
        t.append(st.Resize(args.resize))
    t.extend([st.ToTensor(), normalize])

    dataset = SegList(args.data_dir, 'test', st.Compose(t), list_dir=args.list_dir)
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=False,
    )

    cudnn.benchmark = True
    if args.fusion:
        test_fusion(args, test_loader, net, result_path)
    else:
        test(args, test_loader, net, result_path)
def train_seg(args, result_path, logger):
    """Train a segmentation network, validating and checkpointing every epoch.

    Steps:
        1. Print every option in ``args``.
        2. Build the network with ``net_builder``, wrap it in
           ``DataParallel`` on GPU, and report its parameter count.
        3. Build the loss (``loss_builder``), the augmented training loader,
           and an SGD or Adam optimizer.
        4. Optionally load weights from ``args.model_path`` and/or resume
           from ``args.resume``.
        5. For each remaining epoch: adjust the LR, call ``train``, call
           ``val_seg``, save the latest and best checkpoints under
           ``<result_path>/checkpoint``, and append a row to ``logger``.

    Args:
        args: Parsed options (name, model_path, pretrained, loss, list_dir,
            data_dir, resize, random_rotate, random_scale, crop_size,
            batch_size, workers, optimizer, lr, momentum, weight_decay,
            resume, epochs, save_every_checkpoint, ...).
        result_path: Directory under which ``checkpoint/`` is created.
        logger: Object with an ``append(row)`` method receiving per-epoch metrics.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    for k, v in args.__dict__.items():
        print(k, ':', v)

    # --- network -----------------------------------------------------------
    net = net_builder(args.name, args.model_path, args.pretrained)
    model = torch.nn.DataParallel(net).cuda()
    param = count_param(model)
    print('###################################')
    print('Model #%s# parameters: %.2f M' % (args.name, param / 1e6))

    criterion = loss_builder(args.loss)

    # --- data --------------------------------------------------------------
    # Close the stats file deterministically instead of leaking the handle.
    with open(osp.join(args.list_dir, 'info.json'), 'r') as f:
        info = json.load(f)
    normalize = dt.Normalize(mean=info['mean'], std=info['std'])

    t = []
    if args.resize:
        t.append(dt.Resize(args.resize))
    if args.random_rotate > 0:
        t.append(dt.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(dt.RandomScale(args.random_scale))
    if args.crop_size:
        t.append(dt.RandomCrop(args.crop_size))
    t.extend([dt.Label_Transform(), dt.RandomHorizontalFlip(), dt.ToTensor(), normalize])

    train_loader = torch.utils.data.DataLoader(
        SegList(args.data_dir, 'train', dt.Compose(t), list_dir=args.list_dir),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
    )

    # --- optimizer ---------------------------------------------------------
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(net.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), args.lr,
                                     betas=(0.9, 0.99),
                                     weight_decay=args.weight_decay)
    else:
        # Previously `optimizer` was left unbound here and the run crashed
        # later with a NameError inside the epoch loop.
        raise ValueError("Unsupported optimizer {!r}; expected 'SGD' or 'Adam'"
                         .format(args.optimizer))

    cudnn.benchmark = True
    best_dice = 0
    start_epoch = 0

    # --- optional weight loading / resume ---------------------------------
    if args.model_path:
        print("=> loading pretrained model '{}'".format(args.model_path))
        checkpoint = torch.load(args.model_path)
        model.load_state_dict(checkpoint['state_dict'])

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_dice = checkpoint['best_dice']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Loop-invariant checkpoint locations; exist_ok avoids the check-then-create race.
    checkpoint_dir = osp.join(result_path, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_latest = osp.join(checkpoint_dir, 'checkpoint_latest.pth.tar')

    # --- main training loop -----------------------------------------------
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        # NOTE(review): `logger_vis` is not defined in this function; it is
        # presumably a module-level logger defined elsewhere -- confirm.
        logger_vis.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))

        # train for one epoch
        (_loss, dice_train, dice_1, dice_2, dice_3, dice_4, dice_5,
         dice_6, dice_7, dice_8, dice_9) = train(
            args, train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        # NOTE(review): `val_seg` in this module returns 9 values
        # (dice_avg, dice_1..3, dice_list, auc, auc_1..3), but 11 are
        # unpacked here, which raises ValueError unless val_seg is redefined
        # elsewhere. The logger row below depends on these names, so
        # reconcile both sides before changing either.
        (dice_val, dice_11, dice_22, dice_33, dice_44, dice_55, dice_66,
         dice_77, dice_88, dice_99, dice_list) = val_seg(args, model)

        # save latest checkpoint and track the best one
        is_best = dice_val > best_dice
        best_dice = max(dice_val, best_dice)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'dice_epoch': dice_val,
            'best_dice': best_dice,
        }, is_best, checkpoint_dir, filename=checkpoint_latest)

        if args.save_every_checkpoint:
            history_path = osp.join(
                checkpoint_dir, 'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
            shutil.copyfile(checkpoint_latest, history_path)

        logger.append([epoch, dice_train, dice_val,
                       dice_1, dice_11, dice_2, dice_22, dice_3, dice_33,
                       dice_4, dice_44, dice_5, dice_55, dice_6, dice_66,
                       dice_7, dice_77, dice_8, dice_88, dice_9, dice_99])