예제 #1
0
def create_datasets(args):
    """Build the train and validation SliceData datasets.

    Each split gets its own undersampling mask function derived from the
    CLI arguments; the validation transform passes ``use_seed=True`` so
    the mask pattern is reproducible across evaluation runs.

    Args:
        args: parsed namespace with mask_type, center_fractions,
            accelerations, data_path, challenge, resolution, sample_rate.

    Returns:
        Tuple ``(dev_data, train_data)`` of SliceData instances
        (note: validation set first).
    """
    def _new_mask():
        # Build a fresh mask function so train/val do not share state.
        return create_mask_for_mask_type(
            args.mask_type, args.center_fractions, args.accelerations)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(_new_mask(), args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(_new_mask(), args.resolution, args.challenge,
                                use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data
예제 #2
0
 def train_data_transform(self):
     """Build the DataTransform used for training batches.

     The mask function is rebuilt from hyperparameters on each call;
     ``use_seed=False`` keeps the undersampling pattern random per sample.
     """
     mask_func = create_mask_for_mask_type(
         self.hparams.mask_type,
         self.hparams.center_fractions,
         self.hparams.accelerations,
     )
     return DataTransform(
         self.hparams.resolution,
         self.hparams.challenge,
         mask_func,
         use_seed=False,
     )
예제 #3
0
def create_data_loader(args):
    """Build the validation-split SliceData dataset.

    A mask function is constructed from the CLI arguments and handed to
    the DataTransform for k-space undersampling.

    Returns:
        A SliceData instance over ``<data_path>/<challenge>_val``.
    """
    mask_func = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations)
    val_root = args.data_path / f'{args.challenge}_val'
    return SliceData(
        root=val_root,
        transform=DataTransform(mask_func),
        challenge=args.challenge,
        sample_rate=args.sample_rate,
    )
예제 #4
0
 def __init__(self, reg_network=None, hparams=None):
     """Set up the Neumann-series unrolled network.

     Args:
         reg_network: learned regularizer module applied inside each
             Neumann block.
         hparams: hyperparameter namespace providing gpus, mask_type,
             center_fractions, accelerations and n_blocks.
             NOTE(review): a ``None`` value here raises AttributeError
             below — the default appears unusable; confirm callers
             always pass hparams.
     """
     super(NeumannNetwork, self).__init__()
     self.hparams = hparams
     # Run on CPU when no GPUs were requested, otherwise CUDA.
     self.device = "cpu" if hparams.gpus == 0 else "cuda"
     self.mask_func = create_mask_for_mask_type(
         hparams.mask_type,
         hparams.center_fractions,
         hparams.accelerations,
     )
     self.reg_network = reg_network
     self.n_blocks = hparams.n_blocks
     # Learnable step-size parameter for the unrolled iterations.
     self.eta = nn.Parameter(torch.Tensor([0.1]), requires_grad=True)
     self.preconditioned = False
예제 #5
0
def create_data_loaders(args):
    """Create a DataLoader over one data split.

    A k-space mask function is built only when ``args.mask_kspace`` is
    set; otherwise the transform receives ``None`` (no undersampling).
    The full split is used (``sample_rate=1.``).

    Returns:
        A torch DataLoader over the requested split.
    """
    mask_func = (
        create_mask_for_mask_type(
            args.mask_type, args.center_fractions, args.accelerations)
        if args.mask_kspace
        else None
    )
    dataset = SliceData(
        root=args.data_path / f'{args.challenge}_{args.data_split}',
        transform=DataTransform(args.resolution, args.challenge, mask_func),
        sample_rate=1.,
        challenge=args.challenge,
    )
    return DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
    )
예제 #6
0
 def validation_step(self, batch, batch_idx):
     """Run one validation step and collect de-normalized outputs.

     Returns a dict with the file name, slice index, de-normalized
     output/target/undersampled images as numpy arrays, and the L1
     validation loss on the normalized tensors.
     """
     image, target, kspace, mean, std, fname, slice_idx = batch
     output = self.forward(kspace)
     # Rebuild the undersampling mask from hyperparameters for this step.
     mask = create_mask_for_mask_type(
         self.hparams.mask_type,
         self.hparams.center_fractions,
         self.hparams.accelerations,
     )
     _, undersampled_img = forward_adjoint_helper(
         self.device, self.hparams, mask, kspace, None)
     # Broadcast per-volume normalization stats over the spatial dims.
     mean = mean.unsqueeze(1).unsqueeze(2)
     std = std.unsqueeze(1).unsqueeze(2)

     def _denorm(t):
         # Undo normalization and move to host memory for aggregation.
         return (t * std + mean).cpu().numpy()

     return {
         "fname": fname,
         "slice": slice_idx,
         "output": _denorm(output),
         "target": _denorm(target),
         "undersampled_img": _denorm(undersampled_img),
         "val_loss": F.l1_loss(output, target),
     }
예제 #7
0
 def val_data_transform(self):
     """Build the DataTransform used for validation batches."""
     mask_func = create_mask_for_mask_type(
         self.hparams.mask_type,
         self.hparams.center_fractions,
         self.hparams.accelerations,
     )
     return DataTransform(mask_func, self.hparams.resolution)
예제 #8
0
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    """Train ``model`` for one epoch over ``data_loader``.

    Fixes vs. previous revision: removed the dead
    ``n_pixel_range = (10, 100)`` assignment (it was immediately
    overwritten), renamed locals ``iter``/``input`` which shadowed the
    builtins, and dropped commented-out code. Behavior is unchanged,
    including the logged messages.

    Args:
        args: parsed CLI namespace (device, fn_train, bbox_root,
            min_n_pixel, max_n_pixel, report_interval, num_epochs).
        epoch: current epoch index, used only for global-step bookkeeping.
        model: network applied to a (B, 1, H, W) input batch.
        data_loader: iterable yielding training batches; the batch layout
            depends on args.bbox_root / args.fn_train.
        optimizer: optimizer stepping ``model``'s parameters.
        writer: TensorBoard-style writer used for scalar logging.

    Returns:
        Tuple ``(avg_loss, epoch_seconds)`` where avg_loss is an
        exponential moving average of the per-batch loss.
    """
    model.train()
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(data_loader)

    if args.fn_train:
        loss_f = torch.nn.L1Loss(reduction='none')
        # NOTE(review): mask_f is built but not used in this function;
        # kept for parity with the previous revision.
        mask_f = create_mask_for_mask_type('random', [0.08, 0.04], [4, 8])
        n_pixel_range = (args.min_n_pixel, args.max_n_pixel)

    for step, data in enumerate(data_loader):
        # Batch layout differs per training mode.
        if args.bbox_root:
            (image, target, mean, std, norm), seg = data
        elif args.fn_train:
            image, target, mean, std, norm, fn_image = data
        else:
            image, target, mean, std, norm = data
        image = image.unsqueeze(1).to(args.device)
        target = target.to(args.device)

        output = model(image).squeeze(1)
        loss = F.l1_loss(output, target)

        if args.bbox_root:
            writer.add_scalar('L1_Loss', loss.item(), global_step + step)

            # Extra NMSE penalty restricted to annotated segmentation
            # channels; assumes 11 channels in seg's last dim — TODO confirm.
            bbox_loss = []
            for j in range(11):
                seg_mask = seg[:, :, :, j]
                if seg_mask.sum() > 0:
                    seg_mask = seg_mask.to(args.device)
                    bbox_loss.append(nmse(target * seg_mask, output * seg_mask))

            if len(bbox_loss) > 0:
                bbox_loss = 2 * torch.stack(bbox_loss).mean()
                writer.add_scalar('BBOX_Loss', bbox_loss.item(),
                                  global_step + step)
                loss += bbox_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.fn_train:
            # Adversarial false-negative training pass. 'from_boarder' is a
            # margin presumably defined at module level — verify elsewhere
            # in the file.
            get_attack_loss_new(
                model, fn_image,
                loss_f=loss_f,
                xs=np.random.randint(low=from_boarder, high=320 - from_boarder,
                                     size=(fn_image.size(0),)),
                ys=np.random.randint(low=from_boarder, high=320 - from_boarder,
                                     size=(fn_image.size(0),)),
                shape=(320, 320), n_pixel_range=n_pixel_range,
                train=True, optimizer=optimizer)

        # Exponential moving average for smoother progress reporting.
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if step > 0 else loss.item()
        writer.add_scalar('TrainLoss', loss.item(), global_step + step)

        if step % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{step:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s',
            )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
예제 #9
0
    parser.add_argument('--bbox_root', type=str, default=None)

    parser.add_argument('--fnaf_eval', type=pathlib.Path, default=None,
                        help='Path where fnaf eval results should be saved')

    parser.add_argument('--fnaf_eval_control', action='store_true')

    
    return parser


# Script entry point: parse CLI args, pin the CUDA device, seed all RNGs,
# then hand off to main().
if __name__ == '__main__':

    
    args = create_arg_parser().parse_args()
                        
    # Pin GPU visibility before any CUDA initialization happens.
    import os
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
    # The GPU id to use, usually either "0" or "1"
    os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 
                        
    # NOTE(review): mask_f is not referenced again in this visible block —
    # presumably read as a module-level global by main() or other
    # functions; confirm before removing.
    mask_f = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations)
                        
                      


    # Seed every RNG source for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    main(args)