Beispiel #1
0
def create_datasets(args):
    """Build the validation, training and (optionally) DICOM datasets.

    Args:
        args: Parsed options. Reads ``center_fractions``, ``accelerations``,
            ``data_path``, ``challenge``, ``resolution``, ``sample_rate``,
            ``overfit`` and ``use_dicom``.

    Returns:
        ``(dev_data, train_data)``, or ``(dev_data, train_data, dicom_data)``
        when ``args.use_dicom`` is truthy.
    """
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(train_mask, args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    # With --overfit the dev set is read from the training split (same files
    # as train_data, but with a seeded mask); otherwise from the val split.
    dev_split = 'train' if args.overfit else 'val'
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_{dev_split}',
        transform=DataTransform(dev_mask, args.resolution, args.challenge, use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    if args.use_dicom:
        dicom_data = SliceDICOM(
            root=args.data_path,
            transform=DICOMTransform(args.resolution),
            sample_rate=args.sample_rate,
        )
        return dev_data, train_data, dicom_data
    return dev_data, train_data
Beispiel #2
0
def create_datasets(args):
    """Return ``(dev_data, train_data)`` for the multicoil train/val splits.

    Both splits use ``DataTransform(args.resolution)`` and share
    ``args.sample_rate``; only the directory under ``args.data_path`` differs.
    """
    def _split(dirname):
        # One SliceData per split directory, each with its own transform.
        return SliceData(
            root=args.data_path / dirname,
            transform=DataTransform(args.resolution),
            sample_rate=args.sample_rate,
        )

    training = _split('multicoil_train')
    validation = _split('multicoil_val')
    return validation, training
Beispiel #3
0
def create_data_loader(args):
    """Build the ``SliceData`` dataset for the ``<challenge>_val`` split.

    Despite the name, this returns the dataset itself, not a ``DataLoader``.
    The transform is a ``DataTransform`` wrapping a fresh ``MaskFunc`` built
    from ``args.center_fractions`` and ``args.accelerations``.
    """
    mask = MaskFunc(args.center_fractions, args.accelerations)
    return SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(mask),
        challenge=args.challenge,
        sample_rate=args.sample_rate,
    )
Beispiel #4
0
    def _create_data_loader(self,
                            data_transform,
                            data_partition,
                            sample_rate=None):
        """Build a ``DataLoader`` over one partition of the dataset.

        Args:
            data_transform: Transform applied to each slice by ``SliceData``.
            data_partition: Partition suffix (e.g. ``'train'``); data is read
                from ``<data_path>/<challenge>_<data_partition>``.
            sample_rate: Fraction passed to ``SliceData``; ``None`` falls back
                to ``self.hparams.sample_rate``.

        Returns:
            A ``DataLoader`` that uses a ``DistributedSampler`` for the
            ``'train'`` partition and a ``VolumeSampler`` otherwise.
        """
        # Substitute the default only when nothing was passed; ``or`` would
        # also discard an explicit falsy value such as ``0.0``.
        if sample_rate is None:
            sample_rate = self.hparams.sample_rate
        dataset = SliceData(root=self.hparams.data_path /
                            f'{self.hparams.challenge}_{data_partition}',
                            transform=data_transform,
                            sample_rate=sample_rate,
                            challenge=self.hparams.challenge)

        is_train = data_partition == 'train'
        sampler = DistributedSampler(dataset) if is_train else VolumeSampler(dataset)

        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=4,
            pin_memory=False,
            # Drop the ragged final batch only while training.
            drop_last=is_train,
            sampler=sampler,
        )
def create_datasets(args):
    """Return ``(dev_data, train_data)`` for ``args.challenge``.

    Training slices use ``use_aug=args.aug``; validation slices use
    ``use_seed=True`` and ``use_aug=False``. Each split gets its own
    ``MaskFunc`` built from the same center fractions and accelerations.
    """
    mask_params = (args.center_fractions, args.accelerations)
    train_mask = MaskFunc(*mask_params)
    dev_mask = MaskFunc(*mask_params)

    shared = dict(sample_rate=args.sample_rate, challenge=args.challenge)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(train_mask, args.resolution, args.challenge,
                                use_aug=args.aug),
        **shared,
    )
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(dev_mask, args.resolution, args.challenge,
                                use_seed=True, use_aug=False),
        **shared,
    )
    return dev_data, train_data
Beispiel #6
0
def create_datasets_multi(args):
    """Return ``(dev_data, train_data)`` using ``arc_masking_func`` as the mask.

    Both splits wrap slices in ``SquareDataTransformC3_multi``; only the
    validation transform receives ``use_seed=True``.
    """
    def _build(root, transform):
        # Every split shares the same sampling rate and challenge.
        return SliceData(
            root=root,
            transform=transform,
            sample_rate=args.sample_rate,
            challenge=args.challenge,
        )

    training = _build(
        args.data_path / f'{args.challenge}_train',
        SquareDataTransformC3_multi(arc_masking_func, args.resolution, args.challenge),
    )
    validation = _build(
        args.data_path / f'{args.challenge}_val',
        SquareDataTransformC3_multi(arc_masking_func, args.resolution, args.challenge,
                                    use_seed=True),
    )
    return validation, training
def create_datasets(args):
    """Return ``(dev_data, train_data)`` with masks chosen by ``args.mask_type``.

    Only the training split receives ``bbox_root=args.bbox_root``; the
    validation transform is built with ``use_seed=True``.
    """
    def _mask():
        return create_mask_for_mask_type(
            args.mask_type, args.center_fractions, args.accelerations)

    train_mask = _mask()
    dev_mask = _mask()

    train_transform = DataTransform(train_mask, args.resolution, args.challenge)
    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=train_transform,
        sample_rate=args.sample_rate,
        challenge=args.challenge,
        bbox_root=args.bbox_root,
    )

    dev_transform = DataTransform(dev_mask, args.resolution, args.challenge, use_seed=True)
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=dev_transform,
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data
def create_datasets(args):
    """Return ``(dev_data, train_data)`` for ``args.challenge``.

    Args:
        args: Parsed options. Reads ``center_fractions``, ``accelerations``,
            ``data_path`` (a directory path, with or without a trailing
            separator), ``challenge``, ``resolution`` and ``sample_rate``.

    Returns:
        ``(dev_data, train_data)``; the dev transform uses ``use_seed=True``.
    """
    import os

    train_mask = subsample.MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = subsample.MaskFunc(args.center_fractions, args.accelerations)

    # os.path.join adds the separator only when data_path lacks one; plain
    # string ``+`` silently produced e.g. "datasinglecoil_train".
    train_data = SliceData(
        root=os.path.join(args.data_path, '{}_train'.format(args.challenge)),
        transform=DataTransform(train_mask, args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge)
    dev_data = SliceData(
        root=os.path.join(args.data_path, '{}_val'.format(args.challenge)),
        transform=DataTransform(dev_mask,
                                args.resolution,
                                args.challenge,
                                use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data
Beispiel #9
0
def load_data(args):
    """Build the train, validation and display data loaders.

    Args:
        args: Options object. Reads ``data_path`` (a string; split
            directories are appended with ``+``), ``resolution``,
            ``challenge``, ``sample_rate``, ``batch_size``, ``num_workers``
            and ``display_images``.

    Returns:
        ``(train_data_loader, val_data_loader, display_data_loader)``. The
        display loader holds up to ``args.display_images`` validation slices
        taken at evenly spaced indices.
    """
    print("Starting load_data()")
    # NOTE(review): split directories are hard-coded to ``singlecoil_*`` even
    # though ``args.challenge`` is forwarded to SliceData -- confirm.
    train_dataset = SliceData(
        root=args.data_path + '/singlecoil_train',
        transform=DataTransform(resolution=args.resolution),
        challenge=args.challenge,
        sample_rate=args.sample_rate)
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    val_dataset = SliceData(
        root=args.data_path + '/singlecoil_val',
        transform=DataTransform(resolution=args.resolution),
        challenge=args.challenge,
        sample_rate=args.sample_rate)
    val_data_loader = DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        # Keep the final partial batch so validation covers every slice.
        drop_last=False,
    )

    n_display = args.display_images
    n_val = len(val_dataset)
    # max(1, ...) avoids a zero range() step (ValueError) when there are fewer
    # validation slices than requested. Slicing the range caps the selection
    # at n_display, so no surplus slices are loaded when n_val is not a
    # multiple of n_display.
    step = max(1, n_val // n_display)
    display_dataset = [val_dataset[i] for i in range(0, n_val, step)[:n_display]]
    display_data_loader = DataLoader(
        dataset=display_dataset,
        batch_size=n_display,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        # display_dataset fits in a single batch; keep it even when short.
        drop_last=False,
    )
    print("Ended load_data()")
    return train_data_loader, val_data_loader, display_data_loader
Beispiel #10
0
def create_data_loaders(args):
    """Wrap the ``multicoil_<args.data_split>`` split in a ``DataLoader``.

    Uses a fixed 4 worker processes and pinned memory.
    """
    split_root = args.data_path / f'multicoil_{args.data_split}'
    dataset = SliceData(
        root=split_root,
        transform=DataTransform(args.resolution),
        sample_rate=args.sample_rate,
    )
    return DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
    )
Beispiel #11
0
def create_data_loaders(args):
    """Wrap every slice under ``args.origin_file`` in a ``DataLoader``.

    ``sample_rate`` is fixed at 1.0, so no volumes are subsampled.
    """
    dataset = SliceData(
        root=args.origin_file,
        transform=DataTransform(args.resolution, args.challenge),
        sample_rate=1.,
        challenge=args.challenge,
    )
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)
    return DataLoader(dataset=dataset, **loader_kwargs)
def create_data_loaders(args):
    """Wrap the ``<challenge>_<data_split>`` split in a ``DataLoader``.

    A ``MaskFunc`` is passed to the transform only when ``args.mask_kspace``
    is truthy; otherwise the transform receives ``None``. ``sample_rate`` is
    fixed at 1.0.
    """
    mask_func = (MaskFunc(args.center_fractions, args.accelerations)
                 if args.mask_kspace else None)
    dataset = SliceData(
        root=args.data_path / f'{args.challenge}_{args.data_split}',
        transform=DataTransform(args.resolution, args.challenge, mask_func),
        sample_rate=1.,
        challenge=args.challenge,
    )
    return DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
    )
Beispiel #13
0
def create_dataset(args, transform, split):
    """Return the dataset for ``split`` of ``args.challenge``.

    ``args.model == 'unet_volumes'`` selects ``MultiSliceData`` (which also
    receives ``args.num_volumes``); any other model gets ``SliceData``.
    """
    root = args.data_path / f'{args.challenge}_{split}'
    if args.model != 'unet_volumes':
        return SliceData(
            root=root,
            transform=transform,
            sample_rate=args.sample_rate,
            overfit=args.overfit,
            challenge=args.challenge,
        )
    return MultiSliceData(
        root=root,
        transform=transform,
        sample_rate=args.sample_rate,
        overfit=args.overfit,
        challenge=args.challenge,
        num_volumes=args.num_volumes,
    )
Beispiel #14
0
 def _create_data_loader(self,
                         data_transform,
                         data_partition,
                         sample_rate=None):
     """Build a ``DataLoader`` over one partition of the dataset.

     Args:
         data_transform: Transform applied to each slice by ``SliceData``.
         data_partition: Partition suffix; data is read from
             ``<data_path>/<challenge>_<data_partition>``.
         sample_rate: Fraction passed to ``SliceData``; ``None`` falls back
             to ``self.hparams.sample_rate``.

     Returns:
         A ``DataLoader`` driven by a ``DistributedSampler``.
     """
     # Substitute the default only when nothing was passed; ``or`` would also
     # discard an explicit falsy value such as ``0.0``.
     if sample_rate is None:
         sample_rate = self.hparams.sample_rate
     dataset = SliceData(root=self.hparams.data_path /
                         f'{self.hparams.challenge}_{data_partition}',
                         transform=data_transform,
                         sample_rate=sample_rate,
                         challenge=self.hparams.challenge)
     # NOTE(review): the same DistributedSampler is used for every partition,
     # including validation -- confirm that is intended.
     sampler = DistributedSampler(dataset)
     return DataLoader(
         dataset=dataset,
         batch_size=self.hparams.batch_size,
         num_workers=8,
         pin_memory=True,
         sampler=sampler,
     )
Beispiel #15
0
def create_data_loader(args, transform=None):
    """Build the ``SliceData`` dataset for the ``<challenge>_test`` split.

    Despite the name, this returns the dataset itself, not a ``DataLoader``.

    Args:
        args: Parsed options. Reads ``data_path``, ``challenge`` and
            ``sample_rate``.
        transform: Per-slice transform. When ``None`` (the default, matching
            the previous behaviour), falls back to a module-level
            ``data_transform``, which must be defined elsewhere in the module
            (not visible here -- NameError otherwise).
    """
    if transform is None:
        transform = data_transform
    return SliceData(root=args.data_path / f'{args.challenge}_test',
                     transform=transform,
                     challenge=args.challenge,
                     sample_rate=args.sample_rate)
Beispiel #16
0
def main():
    """Run the architecture search on single-coil slices.

    All hyper-parameters come from the module-level ``config``. The training
    split under ``config.dataset`` is split in half: one half feeds
    ``train_loader``, the other ``valid_loader``. After every epoch the
    current genotype is logged and plotted, and a checkpoint is saved.
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # Fixed single-coil setup (replaces utils.get_data from the original
    # classification code).
    input_size = 320
    input_channels = 2
    n_classes = 2
    train_mask = MaskFunc([0.08, 0.04], [4, 8])

    train_data = SliceData(
        root=config.dataset + 'train',
        transform=DataTransform(train_mask, input_size, 'singlecoil'),
        challenge='singlecoil'
    )

    net_crit = nn.L1Loss().to(device)
    model = SearchCNNController(input_channels, config.init_channels, n_classes, config.layers,
                                net_crit, device_ids=config.gpus)
    model = model.to(device)
    # weights optimizer
    w_optim = torch.optim.SGD(model.weights(), config.w_lr, momentum=config.w_momentum,
                              weight_decay=config.w_weight_decay)
    # alphas optimizer
    alpha_optim = torch.optim.Adam(model.alphas(), config.alpha_lr, betas=(0.5, 0.999),
                                   weight_decay=config.alpha_weight_decay)

    # split data to train/validation (first half / second half of indices)
    n_train = len(train_data)
    split = n_train // 2
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=train_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               sampler=valid_sampler,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_lr_min)
    architect = Architect(model, config.w_momentum, config.w_weight_decay)

    # training loop
    # Start below any achievable score and pre-bind best_genotype so the
    # final log cannot raise NameError when no epoch improves (or epochs == 0).
    best_top1 = float('-inf')
    best_genotype = None
    for epoch in range(config.epochs):
        # The scheduler is stepped at the END of the epoch, after the
        # optimizer updates, so the first epoch runs at the initial LR. The
        # current LR is read from the optimizer, which the scheduler updates.
        lr = w_optim.param_groups[0]['lr']

        model.print_alphas(logger)

        # training
        train(train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, epoch, cur_step)

        # log
        # genotype
        genotype = model.genotype()
        logger.info("genotype = {}".format(genotype))

        # genotype as a image
        plot_path = os.path.join(config.plot_path, "EP{:02d}".format(epoch + 1))
        caption = "Epoch {}".format(epoch + 1)
        plot(genotype.normal, plot_path + "-normal", caption)
        plot(genotype.reduce, plot_path + "-reduce", caption)

        # save
        is_best = top1 > best_top1
        if is_best:
            best_top1 = top1
            best_genotype = genotype
        utils.save_checkpoint(model, config.path, is_best)

        lr_scheduler.step()
        print("")

    # PSNR is not a fraction; '{:.4%}' would multiply it by 100 and add '%'.
    logger.info("Final best PSNR = {:.4f}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
Beispiel #17
0
def create_datasets(args):
    """Return ``(dev_data, train_data)`` built from explicit train/val paths.

    Both datasets receive ``args.acceleration``; the training split uses
    ``SliceData`` and the validation split ``SliceDataDev``.
    """
    training = SliceData(args.train_path, args.acceleration)
    return SliceDataDev(args.val_path, args.acceleration), training