Example #1
0
    random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Build the training and test data pipelines from `config`.
# Both datasets share every constructor option except their root directory.
# TODO: just pass config and figure it out
# trainp/testp are left over from a former random train/test split of
# config.data_dir; that split is no longer performed in this block.
trainp = 285
testp = 0

_shared_dataset_opts = {
    'modes': config.modes,
    'debug': config.debug,
    'dims': config.dims,
    'downsample': 4,
}

brats_data = BraTSDataset(config.data_dir, **_shared_dataset_opts)
trainloader = DataLoader(
    brats_data,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

# The test data lives in its own directory (config.test_dir); the loader uses
# DataLoader's defaults.
testloader = DataLoader(BraTSDataset(config.test_dir, **_shared_dataset_opts))

# One network input channel per entry in config.modes.
# specify in config?
input_channels = len(config.modes)
Example #2
0
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Build the training loader and instantiate the model selected by
# config.model_type, then move it to `device`.
# TODO: just pass config and figure it out
# trainp/testp are left over from a former random train/test split; with
# testp == 0 every case goes to training and no test loader is built.
trainp = 285
testp = 0

brats_data = BraTSDataset(config.data_dir,
                          modes=config.modes,
                          debug=config.debug,
                          dims=config.dims)
trainloader = DataLoader(brats_data, batch_size=1, shuffle=True, num_workers=0)
testloader = None

# TODO: Replace with builder.
# Use a single if/elif/else chain so an unrecognised model_type fails here
# with a clear message instead of leaving `model` undefined (NameError later).
if config.model_type == 'baseline':
    model = vaereg.UNet()
elif config.model_type == 'reconreg':
    model = vaereg.ReconReg()
else:
    raise ValueError(
        "Unknown model_type {!r}; expected 'baseline' or 'reconreg'.".format(
            config.model_type))
model = model.to(device)
Example #3
0
import torch
from model import vaereg

from torch.utils.data import DataLoader
from datasets.data_loader import BraTSDataset

from utils import _validate
from losses import losses
# Validate the "downsampled4" UNet checkpoint on the downsampled test set.
# Use the GPU when available and fall back to CPU otherwise, instead of
# hard-coding CUDA (which crashes on machines without a GPU).
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = vaereg.UNet()
checkpoint = torch.load('./checkpoints/downsampled4/downsampled4',
                        map_location='cuda:0' if use_cuda else 'cpu')
# strict=False: state-dict entries that exist in only one of the checkpoint
# and the model are skipped instead of raising.
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.eval()
model = model.to(device)

brats_data = BraTSDataset('./data/brats2018downsampled/test',
                          dims=[32, 32, 32],
                          modes=["t1", "t1ce", "t2", "flair"],
                          downsample=4)
dataloader = DataLoader(brats_data, batch_size=1, shuffle=True, num_workers=0)

# The trailing positional False is passed through unchanged; its meaning is
# defined by utils._validate.
_validate(model, losses.DiceLoss(), dataloader, device, False)
Example #4
0
from torch.utils.data import DataLoader
from datasets.data_loader import BraTSDataset
import nibabel as nib

# Load a trained segmentation model from the config given on the command line
# and dump input volumes from the test loader as NIfTI files.
parser = argparse.ArgumentParser(
    description='Perform inference using trained MRI segmentation model.')
parser.add_argument('--config')
args = parser.parse_args()

config = MRISegConfigParser(args.config)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

testloader = DataLoader(BraTSDataset(config.data_dir, modes=config.modes),
                        batch_size=1)
# map_location lets a checkpoint saved during a GPU run be restored on a
# CPU-only host; without it torch.load tries to recreate CUDA tensors.
checkpoint = torch.load('checkpoints/best_' + config.model_name,
                        map_location=device)
model = BraTSSegmentation(input_channels=len(config.modes))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

# nib.save does not create missing directories.
out_dir = 'scratch'
os.makedirs(out_dir, exist_ok=True)
# NOTE(review): `model` is not applied in this loop; only inputs are written.
for src, target in testloader:
    # Drop only the batch axis (batch_size=1) and take channel 1; squeeze()
    # would also remove any other size-1 axis and shift the channel index.
    src_npy = src[0, 1].numpy()
    img = nib.Nifti1Image(src_npy, np.eye(4))
    # NOTE(review): every iteration writes the same path, so only the last
    # sample is kept on disk — confirm whether per-sample names are wanted.
    nib.save(img, os.path.join(out_dir, 'test.nii.gz'))
Example #5
0
from configparser import ConfigParser

from torch.utils.data import DataLoader

from datasets.data_loader import BraTSDataset
from losses.dice import DiceLoss
from model.btseg import BraTSSegmentation
from utils import dice_score

# Read config/test.cfg, split the dataset indices by the configured
# train_split fraction, and build a loader over the held-out (test) tail.
# ConfigParser replaces SafeConfigParser, which was deprecated in Python 3.2
# and removed in 3.12 (it was only an alias of ConfigParser).
config_path = "config/test.cfg"
config = ConfigParser()
# read() silently skips unreadable files and returns the ones it parsed;
# fail here rather than with a confusing NoSectionError below.
if not config.read(config_path):
    raise FileNotFoundError("Could not read config file: " + config_path)

train_split = config.getfloat('train_params', 'train_split')
if not 0.0 <= train_split <= 1.0:
    raise ValueError(
        "train_params.train_split must be in [0, 1], got {}".format(train_split))
data_dir = config.get('data', 'data_dir')

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

brats_data = BraTSDataset(data_dir)
num_examples = len(brats_data)
data_indices = np.arange(num_examples)
deterministic_test = True
# Fix stochasticity in data sampling: seeding NumPy's global RNG makes the
# shuffle below (and hence the train/test partition) reproducible.
if deterministic_test:
    np.random.seed(0)

# TODO: Doesn't really seem to belong here. Make a new
# class for handling this or push it to the dataloader?
np.random.shuffle(data_indices)
split_idx = int(num_examples * train_split)
# Indices after split_idx form the test subset.
test_sampler = sampler.SubsetRandomSampler(data_indices[split_idx:])
testloader = DataLoader(brats_data, batch_size=1, sampler=test_sampler)

model = BraTSSegmentation(input_channels=2)
Example #6
0
import nibabel as nib

# Load the vaereg UNet checkpoint and build an in-order loader over the
# BraTS 2018 validation directory.
# Use the GPU when available and fall back to CPU otherwise, instead of
# hard-coding CUDA (which crashes on machines without a GPU).
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = vaereg.UNet()
checkpoint = torch.load(
    'checkpoints/vaereg-fulltrain-smallcrop-eloss/vaereg-fulltrain-smallcrop-eloss',
    map_location='cuda:0' if use_cuda else 'cpu')
# strict=False: state-dict entries that exist in only one of the checkpoint
# and the model are skipped instead of raising.
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model = model.to(device)

# Crop size handed to the dataset; the same values are later used to compute
# the x/y padding offsets, so define them once.
dims = [128, 128, 128]
brats_data = BraTSDataset('/data/cddunca2/brats2018/validation/', dims=dims)
dataloader = DataLoader(brats_data, batch_size=1, num_workers=0)
with torch.no_grad():
  model.eval()
  for src, tgt in tqdm(dataloader):
    ID = tgt[0].split("/")[5]
    src = src.to(device, dtype=torch.float)
    
    output = model(src)
    x_off = int((240 - dims[0]) / 4)*2
    y_off = int((240 - dims[1]) / 4)*2
    m = nn.ConstantPad3d((13, 14, x_off, x_off, y_off, y_off), 0)
    
    ncr_net = m(output[0, 0, :, :, :])
    ed = m(output[0, 1, :, :, :])