Example #1
0
    # The validation loader is only built when args.val is truthy; otherwise it is None.
    val_loader = None if not args.val else DataLoader(dataset=valDS, batch_size=args.batchsize, shuffle=False,
                                                       num_workers=args.nworkers, pin_memory=True)

    # Select the network architecture from args.modelid. Several IDs share one
    # architecture class and differ only in the keyword arguments passed.
    # Every branch tests the same attribute (`args.modelid`) so that all IDs,
    # including 9 and 10, are reachable.
    if args.modelid == 0:
        model = UNet(in_channels=args.nchannel, n_classes=args.nchannel, depth=args.mdepth,
                     wf=round(math.log(args.nfeatures, 2)), batch_norm=args.batchnorm,
                     up_mode=args.upmode, dropout=bool(args.dropprob))
    elif args.modelid in (1, 2, 3):
        model = SRCNN3D(n_channels=args.nchannel, scale_factor=model_scale_factor, num_features=args.nfeatures)
    elif args.modelid in (4, 5):
        model = UNetVSeg(in_ch=args.nchannel, out_ch=args.nchannel, n1=args.nfeatures)
    elif args.modelid == 6:
        model = DenseNet(model_depth=args.mdepth, n_input_channels=args.nchannel, num_classes=args.nchannel,
                         drop_rate=args.dropprob)
    elif args.modelid in (7, 8):
        # ID 8 is the same network as ID 7 but with sliceup_first enabled.
        model = ThisNewNet(in_channels=args.nchannel, n_classes=args.nchannel, depth=args.mdepth,
                           batch_norm=args.batchnorm, up_mode=args.upmode, dropout=bool(args.dropprob),
                           scale_factor=model_scale_factor, num_features=args.nfeatures,
                           sliceup_first=args.modelid == 8,
                           loss_slice_count=args.tnnlslc, loss_inplane=args.tnnlinp)
    elif args.modelid == 9:
        # TODO: put all params as args
        model = ResNet(n_channels=args.nchannel, is3D=True, res_blocks=14, starting_nfeatures=args.nfeatures,
                       updown_blocks=2, is_relu_leaky=True, do_batchnorm=args.batchnorm, res_drop_prob=0.2,
                       out_act="sigmoid", forwardV=0, upinterp_algo='convtrans', post_interp_convtrans=False)
    elif args.modelid == 10:
        model = ShuffleUNet(in_ch=args.nchannel, num_features=args.nfeatures, out_ch=args.nchannel)
    else:
        sys.exit("Invalid Model ID")

    # Deep supervision is enabled only for model ID 5.
    IsDeepSup = args.modelid == 5

    if args.profile:
        # Profile one forward pass on a random input of shape
        # (batchsize, nchannel, *inshape).
        dummy = torch.randn(args.batchsize, args.nchannel, *args.inshape)
        with profiler.profile(profile_memory=True, record_shapes=True, use_cuda=True) as prof:
            model(dummy)
        # Export only after the profiler context has exited: the recorded events
        # are finalized on exit, and exporting inside the `with` block raises.
        # NOTE(review): `save_path` is not defined in the visible code, while the
        # surrounding code uses `args.save_path` — confirm which directory is intended.
        prof.export_chrome_trace(os.path.join(save_path, 'model_trace'))
    # Append DEBUG-level log records for this run to
    # <args.save_path>/log_test_<args.trainID>.txt.
    logname = os.path.join(args.save_path, 'log_test_' + args.trainID + '.txt')
    logging.basicConfig(filename=logname,
                        filemode='a',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.DEBUG)

    # Forward every `motion_<name>` argument to createTIODS as `<name>=value`.
    # Filtering on the full 'motion_' prefix guarantees the prefix can be stripped
    # (a bare 'motion...' key would otherwise raise IndexError), and slicing keeps
    # any later occurrence of 'motion_' inside the name intact.
    motion_prefix = 'motion_'
    motion_params = {k[len(motion_prefix):]: v for k, v in vars(args).items() if k.startswith(motion_prefix)}
    testDS = createTIODS(args.gt_vols_test, args.corrupt_vols_test, is_infer=True, p=args.corrupt_prob,
                         **motion_params)

    # Test data is iterated in a fixed order (no shuffling).
    test_loader = DataLoader(dataset=testDS, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Build the test network from args.modelID. For ID 0, every `model_<name>`
    # argument is forwarded to ResNet as `<name>=value`.
    # NOTE(review): IDs other than 0 and 1 do not (re)assign `model` here — confirm intended.
    if args.modelID == 0:
        model_prefix = 'model_'
        model_params = {k[len(model_prefix):]: v for k, v in vars(args).items() if k.startswith(model_prefix)}
        model = ResNet(n_channels=args.n_channels, is3D=True, **model_params)
    elif args.modelID == 1:
        model = ShuffleUNet(in_ch=args.n_channels, num_features=args.model_starting_nfeatures, out_ch=args.n_channels)

    if args.do_profile:
        # Profile one forward pass on a random input of shape
        # (batch_size, n_channels, *input_shape).
        dummy = torch.randn(args.batch_size, args.n_channels, *args.input_shape)
        with profiler.profile(profile_memory=True, record_shapes=True, use_cuda=True) as prof:
            model(dummy)
        # Export only after the profiler context has exited; the recorded events
        # are not finalized inside the `with` block and exporting there raises.
        prof.export_chrome_trace(os.path.join(args.save_path, 'model_trace'))
    model.to(device)

    # Restore trained weights; the checkpoint must contain 'state_dict' and 'epoch'.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    chk = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(chk['state_dict'])
    trained_epoch = chk['epoch']
    model.eval()