Ejemplo n.º 1
0
def create_datasets(args):
    """Build the train/dev (and optionally DICOM) SliceData datasets.

    When ``args.overfit`` is set the dev set is drawn from the *train*
    split so the model can be sanity-checked on data it has seen.

    Returns:
        (dev_data, train_data), plus dicom_data when ``args.use_dicom``.
    """
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(train_mask, args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge
    )

    # use_seed=True keeps the dev masks deterministic per volume.
    dev_split = '_train' if args.overfit else '_val'
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}{dev_split}',
        transform=DataTransform(dev_mask, args.resolution, args.challenge, use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )

    if args.use_dicom:
        dicom_data = SliceDICOM(
            root=args.data_path,
            transform=DICOMTransform(args.resolution),
            sample_rate=args.sample_rate,
        )
        return dev_data, train_data, dicom_data
    return dev_data, train_data
Ejemplo n.º 2
0
def get_transforms(args):
    """Create the train/val/test DataTransforms.

    Training masks are drawn randomly per call; the val transform is
    seeded for reproducibility, and the test transform applies no mask.
    """
    train_transform = DataTransform(
        MaskFunc(args.center_fractions, args.accelerations),
        args.resolution, args.challenge)
    val_transform = DataTransform(
        MaskFunc(args.center_fractions, args.accelerations),
        args.resolution, args.challenge, use_seed=True)
    test_transform = DataTransform(
        None, args.resolution, args.challenge, use_seed=True, use_mask=False)
    return train_transform, val_transform, test_transform
Ejemplo n.º 3
0
def create_data_loader(args):
    """Return a SliceData dataset over the validation split, with k-space
    subsampled by a MaskFunc built from the CLI arguments."""
    val_mask = MaskFunc(args.center_fractions, args.accelerations)
    return SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(val_mask),
        challenge=args.challenge,
        sample_rate=args.sample_rate,
    )
Ejemplo n.º 4
0
def create_datasets(args):
    """Build the singlecoil train and validation SliceData datasets.

    The validation transform is seeded so its masks are deterministic.
    Returns (dev_data, train_data).
    """
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    # NOTE: args.data_path is concatenated as a plain string here, so it is
    # expected to end with a path separator.
    train_data = SliceData(
        root=args.data_path + 'singlecoil_train',
        transform=DataTransform(train_mask, args.resolution),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    dev_data = SliceData(
        root=args.data_path + 'singlecoil_val',
        transform=DataTransform(dev_mask, args.resolution, use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data
def create_datasets(args):
    """Build the train/dev SliceData datasets.

    Augmentation (``args.aug``) applies only to the training transform;
    the dev transform is seeded and never augmented so evaluation stays
    fixed.  Returns (dev_data, train_data).
    """
    train_transform = DataTransform(
        MaskFunc(args.center_fractions, args.accelerations),
        args.resolution, args.challenge, use_aug=args.aug)
    dev_transform = DataTransform(
        MaskFunc(args.center_fractions, args.accelerations),
        args.resolution, args.challenge, use_seed=True, use_aug=False)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=train_transform,
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=dev_transform,
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data
Ejemplo n.º 6
0
def test_mask_reuse(center_fracs, accelerations, batch_size, dim):
    """Calling a MaskFunc repeatedly with the same seed must yield the
    identical mask every time."""
    mask_func = MaskFunc(center_fracs, accelerations)
    shape = (batch_size, dim, dim, 2)
    masks = [mask_func(shape, seed=123) for _ in range(3)]
    for prev, curr in zip(masks, masks[1:]):
        assert torch.all(prev == curr)
Ejemplo n.º 7
0
def test_apply_mask(shape, center_fractions, accelerations):
    """apply_mask must reproduce the MaskFunc output for the same seed and
    zero out exactly the k-space entries where the mask is 0."""
    mask_func = MaskFunc(center_fractions, accelerations)
    expected_mask = mask_func(shape, seed=123)
    # Renamed from `input` to avoid shadowing the builtin.
    data = create_input(shape)
    output, mask = transforms.apply_mask(data, mask_func, seed=123)
    assert output.shape == data.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())
    # Wherever the mask is 0 the output must already be 0.
    assert np.all(np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy())
Ejemplo n.º 8
0
def create_datasets(args, limit=-1):
    """Build the singlecoil train/val SliceData datasets.

    Args:
        args: namespace carrying paths, mask and transform settings.
        limit: cap on dataset size passed through to SliceData (-1 = all).

    Returns:
        (dev_data, train_data)
    """
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    # Keyword arguments shared by both splits.
    shared = dict(
        sample_rate=args.sample_rate,
        challenge=args.challenge,
        non_zero_ratio=args.non_zero_ratio,
        limit=limit,
    )
    train_data = SliceData(
        root=args.data_path + '/singlecoil_train',
        transform=DataTransform(train_mask, args.resolution,
                                args.reduce, args.polar),
        **shared)
    dev_data = SliceData(
        root=args.data_path + '/singlecoil_val',
        transform=DataTransform(dev_mask, args.resolution,
                                args.reduce, args.polar, use_seed=True),
        **shared)
    return dev_data, train_data
Ejemplo n.º 9
0
def test_mask_low_freqs(center_fracs, accelerations, batch_size, dim):
    """The mask must be broadcastable ([1, ..., dim, 1]) and fully sample a
    centered low-frequency band for one of the given center fractions."""
    mask_func = MaskFunc(center_fracs, accelerations)
    shape = (batch_size, dim, dim, 2)
    mask = mask_func(shape, seed=123)
    mask_shape = [1 for _ in shape]
    mask_shape[-2] = dim
    assert list(mask.shape) == mask_shape

    # BUG FIX: the mask has shape [1, ..., dim, 1]; slicing it directly on
    # axis 0 (size 1) yields an EMPTY array for pad >= 1, and np.all of an
    # empty array is True — the original check was vacuous.  Squeeze to a
    # 1-D vector of length `dim` before slicing the low-frequency band.
    mask_1d = mask.squeeze()
    num_low_freqs_matched = False
    for center_frac in center_fracs:
        num_low_freqs = int(round(dim * center_frac))
        pad = (dim - num_low_freqs + 1) // 2
        if np.all(mask_1d[pad:pad + num_low_freqs].numpy() == 1):
            num_low_freqs_matched = True
    assert num_low_freqs_matched
Ejemplo n.º 10
0
def create_data_loaders(args):
    """Build a DataLoader over the requested data split.

    k-space is masked only when ``args.mask_kspace`` is set; otherwise the
    transform receives mask_func=None and the data stays fully sampled.
    """
    if args.mask_kspace:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = None

    dataset = SliceData(
        root=args.data_path / f'{args.challenge}_{args.data_split}',
        transform=DataTransform(args.resolution, args.challenge, mask_func),
        sample_rate=1.,
        challenge=args.challenge)
    return DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
    )
Ejemplo n.º 11
0
def get_epoch_batch(subject_id, acc, center_fract):
    """Load one subject's raw k-space and coil sensitivities from .mat files
    and package them, together with a fresh MaskFunc, for training.

    Args:
        subject_id: (rawdata_path, coil_sensitivity_path) pair of .mat files.
        acc: acceleration factor for the subsampling mask.
        center_fract: fully-sampled center fraction for the mask.
    """
    rawdata_name, coil_name = subject_id

    # Move the last axis to the front (presumably the coil axis — matches
    # the sensitivity transpose below; confirm against the .mat layout).
    rawdata = np.complex64(loadmat(rawdata_name)['rawdata']).transpose(2, 0, 1)
    sensitivity = np.complex64(loadmat(coil_name)['sensitivities'])

    mask_func = MaskFunc(center_fractions=[center_fract], accelerations=[acc])

    kspace_t = T.to_tensor(rawdata)
    sens_t = T.to_tensor(sensitivity.transpose(2, 0, 1))

    return data_for_training(kspace_t, sens_t, mask_func)
Ejemplo n.º 12
0
# As we can see, each slice in a multi-coil MRI scan focusses on a different region of the image. These slices can be combined into the full image using the Root-Sum-of-Squares (RSS) transform.

# In[11]:

# RSS collapses the coil dimension (dim=0) into a single magnitude image.
slice_image_rss = T.root_sum_of_squares(slice_image_abs, dim=0)

# In[12]:

plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray')

# So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space.

# In[13]:

from common.subsample import MaskFunc
# 4% fully-sampled center, 8x acceleration.
mask_func = MaskFunc(center_fractions=[0.04],
                     accelerations=[8])  # Create the mask function object

# In[14]:

masked_kspace, mask = T.apply_mask(slice_kspace2,
                                   mask_func)  # Apply the mask to k-space

# Let's see what the subsampled image looks like:

# In[15]:

sampled_image = T.ifft2(
    masked_kspace)  # Apply Inverse Fourier Transform to get the complex image
sampled_image_abs = T.complex_abs(
    sampled_image)  # Compute absolute value to get a real image
# Combine coils of the subsampled reconstruction the same way as above.
sampled_image_rss = T.root_sum_of_squares(sampled_image_abs, dim=0)
Ejemplo n.º 13
0
                    # Paths for this subject/slice: raw k-space, ESPIRiT coil
                    # sensitivities, and the shared subsampling mask file.
                    raw = '{0}/rawdata{1}.mat'.format(subject_id, i)
                    sen = '{0}/espirit{1}.mat'.format(subject_id, i)
                    mask = '{0}/{1}'.format(which_view, dataset['mask'])

                    rawdata = sio.loadmat(raw)['rawdata']
                    # axial_t2 sensitivities are stored in a different layout
                    # and need the custom loader + transpose.
                    if dataset['name'] == 'axial_t2':
                        coil_sensitivities = load_file(sen)
                        coil_sensitivities = data2complex(
                            coil_sensitivities['sensitivities']).transpose(
                                2, 1, 0)
                    else:
                        coil_sensitivities = np.complex64(
                            sio.loadmat(sen)['sensitivities'])

                    # `center_fract` and `acc` come from the enclosing scope
                    # (not visible in this chunk).
                    mask_func = MaskFunc(center_fractions=[center_fract],
                                         accelerations=[acc])
                    img_und, img_gt, rawdata_und, masks, sensitivity = data_for_training(
                        rawdata, coil_sensitivities, mask_func)

                    # add batch dimension
                    batch_img_und = img_und.unsqueeze(0).to(device)
                    batch_rawdata_und = rawdata_und.unsqueeze(0).to(device)
                    batch_masks = masks.unsqueeze(0).to(device)
                    batch_sensitivities = sensitivity.unsqueeze(0).to(device)

                    # deploy the model
                    rec = rec_net(batch_img_und, batch_rawdata_und,
                                  batch_masks, batch_sensitivities)

                    # convert to complex
                    batch_recon = tensor_to_complex_np(rec.to('cpu'))
Ejemplo n.º 14
0
def main():
    """Run a DARTS-style architecture search on masked singlecoil MRI data.

    Seeds all RNGs, builds the training SliceData set, splits it in half
    (first half trains network weights, second half trains architecture
    alphas), then alternates train/validate for config.epochs epochs while
    checkpointing and tracking the best genotype by validation score.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed for numpy, torch CPU, and all CUDA devices
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    # let cuDNN benchmark kernels (fine here since input sizes are fixed)
    torch.backends.cudnn.benchmark = True

    # get data with meta info
    # input_size, input_channels, n_classes, train_data = utils.get_data(
    #     config.dataset, config.data_path, cutout_length=0, validation=False)
    input_size = 320      # image resolution fed to the search network
    input_channels = 2    # real + imaginary channels
    n_classes = 2
    # 8%/4% center fractions paired with 4x/8x accelerations
    train_mask = MaskFunc([0.08, 0.04], [4, 8])

    train_data = SliceData(
        root=config.dataset + 'train',
        transform=DataTransform(train_mask, input_size, 'singlecoil'),
        challenge='singlecoil'
    )

    net_crit = nn.L1Loss().to(device)
    model = SearchCNNController(input_channels, config.init_channels, n_classes, config.layers,
                                net_crit, device_ids=config.gpus)
    model = model.to(device)
    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(), config.w_lr, momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(), config.alpha_lr, betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation: disjoint halves of the same dataset,
    # one for weight updates and one for architecture updates
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)
    # training loop
    best_top1 = 0.
    for epoch in range(config.epochs):
        # NOTE(review): stepping the scheduler *before* training follows the
        # pre-1.1 PyTorch convention; on PyTorch >= 1.1 step() should be
        # called after the epoch's optimizer steps — confirm installed version.
        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]

        model.print_alphas(logger)

        # training
        train(train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch)

        # validation
        cur_step = (epoch+1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path, "EP{:02d}".format(epoch+1))
        caption = "Epoch {}".format(epoch+1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save a checkpoint every epoch; flag it "best" when the score improves
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
            is_best = True
        else:
            is_best = False
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    # NOTE(review): best_genotype is unbound if config.epochs == 0.
    logger.info("Final best PSNR = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))