# Validation loader is optional: it is only built when args.val is truthy.
val_loader = None if not args.val else DataLoader(
    dataset=valDS,
    batch_size=args.batchsize,
    shuffle=False,
    num_workers=args.nworkers,
    pin_memory=True,
)

# Instantiate the network selected by args.modelid.
# Every branch reads the same attribute (args.modelid), which is also used for
# IsDeepSup below. Mixing in a differently-cased attribute would make any ID
# past that point raise AttributeError instead of reaching the "Invalid Model ID" exit.
if args.modelid == 0:
    model = UNet(
        in_channels=args.nchannel,
        n_classes=args.nchannel,
        depth=args.mdepth,
        # wf receives round(log2(nfeatures)); presumably UNet uses 2**wf
        # filters in its first layer -- confirm against UNet.
        wf=round(math.log(args.nfeatures, 2)),
        batch_norm=args.batchnorm,
        up_mode=args.upmode,
        dropout=bool(args.dropprob),
    )
elif args.modelid in (1, 2, 3):
    model = SRCNN3D(
        n_channels=args.nchannel,
        scale_factor=model_scale_factor,
        num_features=args.nfeatures,
    )
elif args.modelid in (4, 5):
    model = UNetVSeg(in_ch=args.nchannel, out_ch=args.nchannel, n1=args.nfeatures)
elif args.modelid == 6:
    model = DenseNet(
        model_depth=args.mdepth,
        n_input_channels=args.nchannel,
        num_classes=args.nchannel,
        drop_rate=args.dropprob,
    )
elif args.modelid in (7, 8):
    model = ThisNewNet(
        in_channels=args.nchannel,
        n_classes=args.nchannel,
        depth=args.mdepth,
        batch_norm=args.batchnorm,
        up_mode=args.upmode,
        dropout=bool(args.dropprob),
        scale_factor=model_scale_factor,
        num_features=args.nfeatures,
        sliceup_first=args.modelid == 8,  # ID 8 is the slice-up-first variant
        loss_slice_count=args.tnnlslc,
        loss_inplane=args.tnnlinp,
    )
elif args.modelid == 9:
    # TODO: put all params as args (the ResNet hyper-parameters below are hard-coded).
    model = ResNet(
        n_channels=args.nchannel,
        is3D=True,
        res_blocks=14,
        starting_nfeatures=args.nfeatures,
        updown_blocks=2,
        is_relu_leaky=True,
        do_batchnorm=args.batchnorm,
        res_drop_prob=0.2,
        out_act="sigmoid",
        forwardV=0,
        upinterp_algo='convtrans',
        post_interp_convtrans=False,
    )
elif args.modelid == 10:
    model = ShuffleUNet(in_ch=args.nchannel, num_features=args.nfeatures, out_ch=args.nchannel)
else:
    sys.exit("Invalid Model ID")

# Flag consumed later in this script; presumably enables deep supervision for
# model ID 5 -- confirm against the training loop.
IsDeepSup = args.modelid == 5

if args.profile:
    # One forward pass on random input to record a memory/shape profile,
    # exported as a Chrome trace.
    dummy = torch.randn(args.batchsize, args.nchannel, *args.inshape)
    with profiler.profile(profile_memory=True, record_shapes=True, use_cuda=True) as prof:
        model(dummy)
    # NOTE(review): uses a bare `save_path` name; confirm it is defined earlier
    # in this script (not visible here).
    prof.export_chrome_trace(os.path.join(save_path, 'model_trace'))
# Append this run's log to <save_path>/log_test_<trainID>.txt.
logname = os.path.join(args.save_path, 'log_test_' + args.trainID + '.txt')
logging.basicConfig(
    filename=logname,
    filemode='a',
    format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
    datefmt='%H:%M:%S',
    level=logging.DEBUG,
)

# Forward every `motion_<name>` argument to the dataset as `<name>=value`.
# Strip the exact prefix rather than splitting on it: splitting raised
# IndexError for attributes that start with "motion" but lack the underscore,
# and returned the wrong suffix when "motion_" occurred more than once.
motion_params = {
    k[len('motion_'):]: v
    for k, v in vars(args).items()
    if k.startswith('motion_')
}
testDS = createTIODS(
    args.gt_vols_test,
    args.corrupt_vols_test,
    is_infer=True,
    p=args.corrupt_prob,
    **motion_params,
)
test_loader = DataLoader(
    dataset=testDS,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)

# Instantiate the network selected by args.modelID.
if args.modelID == 0:
    # Every `model_<name>` argument is forwarded to ResNet as `<name>=value`.
    model_params = {
        k[len('model_'):]: v
        for k, v in vars(args).items()
        if k.startswith('model_')
    }
    model = ResNet(n_channels=args.n_channels, is3D=True, **model_params)
elif args.modelID == 1:
    model = ShuffleUNet(
        in_ch=args.n_channels,
        num_features=args.model_starting_nfeatures,
        out_ch=args.n_channels,
    )
else:
    # Fail fast: otherwise `model` is unbound and the script dies later with a
    # NameError. Raising SystemExit is what sys.exit() does internally.
    raise SystemExit("Invalid Model ID")

if args.do_profile:
    # One forward pass on random input to record a memory/shape profile,
    # exported as a Chrome trace next to the log file.
    dummy = torch.randn(args.batch_size, args.n_channels, *args.input_shape)
    with profiler.profile(profile_memory=True, record_shapes=True, use_cuda=True) as prof:
        model(dummy)
    prof.export_chrome_trace(os.path.join(args.save_path, 'model_trace'))

# Restore trained weights and switch to inference mode.
model.to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from
# trusted sources.
chk = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(chk['state_dict'])
trained_epoch = chk['epoch']
model.eval()