Exemplo n.º 1
0
    # Inference setup and loop: build a loader over ds_test, restore WNet weights
    # from the checkpoint at config.model_tested, then run the model batch by batch.
    # ds_val = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRITTemp","train",(1,0.5,0.5))

    # NOTE(review): checkname is not referenced again in the visible lines; the
    # checkpoint that is actually loaded below comes from config.model_tested.
    checkname = None  #'/raid/scratch/schatter/Dataset/dhanun/checkpoints/checkpoint_9_14_13_52_epoch_15' #Set None if no need to load

    # NOTE(review): result_upsample and upinterp_fact are assigned here but not
    # used in the visible lines -- presumably consumed further down; confirm.
    result_upsample = True
    #upunterp_fact = None
    upinterp_fact = (1, 1, 1)
    #model_downscale = False
    # shuffle=True means batches arrive in a different random order on every run.
    # NOTE(review): for a test pass this makes the 'Step N' output
    # non-reproducible -- confirm the shuffling is intended.
    dataloader = DataLoader(ds_test, batch_size=config.BatchSize, shuffle=True)
    model = WNet(is_cuda)
    #model = WNet()
    # The GPU index is hard-coded to the second CUDA device ("cuda:1").
    if is_cuda:
        device = torch.device("cuda:1")
    else:
        device = torch.device("cpu")
    model.to(device)
    #model.cuda()
    # Put the module in evaluation mode before the forward passes below.
    model.eval()
    #model_downscale = False
    # mode is forwarded to the model's call in the loop below.
    mode = 'test'
    # NOTE(review): an optimizer is built although no optimisation step appears
    # in the visible lines -- possibly left over from the training code.
    optimizer = torch.optim.Adam(model.parameters(), lr=config.init_lr)
    #optimizer
    with open(config.model_tested, 'rb') as f:
        # The second positional argument of torch.load is map_location: the
        # checkpoint is deserialised onto the CPU regardless of where it was saved.
        para = torch.load(f, "cpu")
        #para = torch.load(f,"cuda:0")
        # The checkpoint is indexed with the key 'state_dict' to get the weights.
        model.load_state_dict(para['state_dict'])
    # Each batch is unpacked as a one-element sequence, so the loader is expected
    # to yield exactly one tensor per batch.
    for step, [x] in enumerate(dataloader):
        # Progress output uses a 1-based step counter.
        print('Step' + str(step + 1))
        #print(x.shape)
        x = x.to(device)
        # NOTE(review): no torch.no_grad() is visible around this loop, so
        # autograd state is presumably tracked during inference -- confirm.
        pred, pad_pred = model(x, mode, config.ModelDownscale)
Exemplo n.º 2
0
    # Training setup: data loaders, model and device placement, optimizer and
    # loss objects.
    # Loader over ds_train, reshuffled every epoch.
    dataloader = DataLoader(ds_train,
                            batch_size=config.BatchSize,
                            shuffle=True)
    # Loader over ds_val.
    # NOTE(review): shuffle=True is also set here -- confirm a shuffled
    # validation pass is intended.
    dataloader1 = DataLoader(ds_val, batch_size=config.BatchSize, shuffle=True)
    #eval_set = DataLoader("MRI/new_test","train")

    #eval_loader = eval_set.torch_loader()
    model = WNet(is_cuda)
    #model = torch.nn.DataParallel(WNet())
    # GPU indices are hard-coded; both names fall back to the CPU without CUDA.
    # NOTE(review): device2 is not used in the visible lines -- presumably
    # referenced later in the function; confirm.
    if is_cuda:
        device1 = torch.device("cuda:1")
        device2 = torch.device("cuda:0")
    else:
        device1 = torch.device("cpu")
        device2 = torch.device("cpu")
    model.to(device1)
    #model_eval = torch.nn.DataParallel(WNet())
    #model.cuda()
    #model_eval.cuda()
    #model_eval.eval()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.init_lr)
    #reconstr = torch.nn.MSELoss().cuda(config.cuda_dev)
    # Pick the object bound to reconstr: SSIM when config.useSSIMLoss is truthy,
    # otherwise MSE. pytorch_ssim is imported lazily, so it is only required
    # when the SSIM option is enabled.
    # NOTE(review): SSIM is a similarity score (higher is better) while MSE is
    # an error (lower is better) -- confirm the training step accounts for the
    # opposite direction when useSSIMLoss is on.
    if config.useSSIMLoss:
        import pytorch_ssim
        reconstr = pytorch_ssim.SSIM()
    else:
        reconstr = torch.nn.MSELoss()
    # NCutsLoss is defined outside the visible lines.
    Ncuts = NCutsLoss()
    # NOTE(review): presumably forwarded to the model's call, as in the test
    # code ('test' there) -- confirm.
    mode = 'train'
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=config.lr_decay_iter,