Example #1
0
def load_model(checkpoint_file):
    """Rebuild a single-channel UnetModel from a checkpoint file and load its weights.

    The hyper-parameters used to construct the network are read back from the
    ``'args'`` entry stored inside the checkpoint itself.
    """
    state = torch.load(checkpoint_file)
    saved_args = state['args']
    net = UnetModel(1, 1, saved_args.num_chans, saved_args.num_pools,
                    saved_args.drop_prob).to(saved_args.device)
    if saved_args.data_parallel:
        # Re-wrap in DataParallel when the checkpoint was trained that way,
        # so the saved state_dict keys line up with the module hierarchy.
        net = torch.nn.DataParallel(net)
    net.load_state_dict(state['model'])
    return net
Example #2
0
def train():
    """Top-level training entry point.

    Parses the command-line arguments, sets up logging and the output
    directory, builds the sampler / adversarial / reconstructor networks with
    one Adam optimizer each, loads the data, and runs the epoch loop.
    """
    # Arguments, summary writer and output directory.
    args = Args().parse_args()
    summary_writer = SummaryWriter(log_dir=args.output_dir)
    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Dump every parsed argument to a text file for later reference.
    with open(args.output_dir + '/args.txt', "w") as text_file:
        for arg in vars(args):
            print(f'{arg}: {getattr(args, arg)}', file=text_file)

    # Build the three networks.
    sampler = Sampler(args.resolution, args.decimation_rate,
                      args.decision_levels, args.sampler_convolution_channels,
                      args.sampler_convolution_layers,
                      args.sampler_linear_layers)
    adversarial = Adversarial(args.resolution,
                              args.adversarial_convolution_channels,
                              args.adversarial_convolution_layers,
                              args.adversarial_linear_layers)
    reconstructor = UnetModel(
        in_chans=2,
        out_chans=1,
        chans=args.reconstruction_unet_chans,
        num_pool_layers=args.reconstruction_unet_num_pool_layers,
        drop_prob=args.reconstruction_unet_drop_prob)

    # One Adam optimizer per network.
    adversarial_optimizer = torch.optim.Adam(adversarial.parameters(),
                                             lr=args.adversarial_lr)
    sampler_optimizer = torch.optim.Adam(sampler.parameters(),
                                         lr=args.sampler_lr)
    reconstructor_optimizer = torch.optim.Adam(reconstructor.parameters(),
                                               lr=args.reconstructor_lr)

    # TODO: check this
    # Used only to reset the gradients of the entire model in one call.
    over_all_optimizer = torch.optim.Adam(
        list(adversarial.parameters()) + list(sampler.parameters()) +
        list(reconstructor.parameters()))

    # TODO: remove this line, each NN needs its own data loader with its own sample rate
    args.sample_rate = 0.2
    train_data_loader, val_data_loader, display_data_loader = load_data(args)

    criterion = nn.MSELoss() if args.loss_fn == "MSE" else nn.L1Loss()

    print('~~~Starting Training~~~')
    print(f'We will run now {args.num_epochs} epochs')
    for epoch_number in range(args.num_epochs):
        train_epoch(sampler, adversarial, reconstructor, train_data_loader,
                    display_data_loader, criterion,
                    args.reconstructor_sub_epochs, adversarial_optimizer,
                    sampler_optimizer, reconstructor_optimizer,
                    over_all_optimizer, epoch_number + 1, summary_writer)
    print('~~~Finished Training~~~')
    summary_writer.close()
Example #3
0
def load_model(checkpoint_file):
    """Restore a UnetModel from *checkpoint_file* and load its trained weights.

    All constructor hyper-parameters come from the ``'args'`` object that was
    saved inside the checkpoint.
    """
    state = torch.load(checkpoint_file)
    saved_args = state['args']
    net = UnetModel(in_chans=1,
                    out_chans=1,
                    chans=saved_args.num_chans,
                    num_pool_layers=saved_args.num_pools,
                    drop_prob=saved_args.drop_prob,
                    acceleration=saved_args.accelerations,
                    center_fraction=saved_args.center_fractions,
                    res=saved_args.resolution).to(saved_args.device)
    if saved_args.data_parallel:
        # Mirror the training-time DataParallel wrapper before loading weights.
        net = torch.nn.DataParallel(net)
    net.load_state_dict(state['model'])
    return net
 def __init__(self, number_of_conv_layers, unet_chans, unet_num_pool_layers,
              unet_drop_prob):
     """Build a k-space reconstruction conv stack followed by a U-Net.

     The conv stack first doubles the channel count for half the layers,
     then halves it back for the remaining layers, each conv followed by a
     ReLU (3x3 kernels, no bias, no padding).
     """
     super().__init__()
     layers = []
     # Two input channels, because of the ifft result.
     width = 2
     expanding = number_of_conv_layers // 2
     for _ in range(expanding):
         layers.append(nn.Conv2d(in_channels=width,
                                 out_channels=width * 2,
                                 kernel_size=3,
                                 bias=False))
         layers.append(nn.ReLU())
         width *= 2
     # The remaining layers contract the channels back down.
     # NOTE(review): for an odd number_of_conv_layers this ends at 1 channel
     # while the U-Net below expects 2 — confirm callers pass an even count.
     for _ in range(number_of_conv_layers - expanding):
         layers.append(nn.Conv2d(in_channels=width,
                                 out_channels=width // 2,
                                 kernel_size=3,
                                 bias=False))
         layers.append(nn.ReLU())
         width //= 2
     self.K_space_reconstruction = nn.Sequential(*layers)
     self.Unet_model = UnetModel(in_chans=2,
                                 out_chans=1,
                                 chans=unet_chans,
                                 num_pool_layers=unet_num_pool_layers,
                                 drop_prob=unet_drop_prob)
Example #5
0
 def __init__(self, chans, num_pools):
     """Wrap a two-channel-in / two-channel-out U-Net with no dropout."""
     super().__init__()
     backbone = UnetModel(in_chans=2,
                          out_chans=2,
                          drop_prob=0,
                          chans=chans,
                          num_pool_layers=num_pools)
     self.unet = backbone
def build_model(args):
    """Construct a single-channel UnetModel on the device named in *args*."""
    return UnetModel(in_chans=1,
                     out_chans=1,
                     chans=args.num_chans,
                     num_pool_layers=args.num_pools,
                     drop_prob=args.drop_prob).to(args.device)
Example #7
0
 def __init__(self, hparams):
     """Initialize the parent with *hparams* and build a 1-channel U-Net."""
     super().__init__(hparams)
     backbone = UnetModel(in_chans=1,
                          out_chans=1,
                          drop_prob=hparams.drop_prob,
                          chans=hparams.num_chans,
                          num_pool_layers=hparams.num_pools)
     self.unet = backbone
Example #8
0
 def __init__(self, resolution, unet_chans, unet_num_pool_layers,
              unet_drop_prob):
     """Store the working resolution and build a 2-in/1-out U-Net sampler."""
     super().__init__()
     self.resolution = resolution
     sampler_net = UnetModel(in_chans=2,
                             out_chans=1,
                             drop_prob=unet_drop_prob,
                             chans=unet_chans,
                             num_pool_layers=unet_num_pool_layers)
     self.sampler = sampler_net
Example #9
0
def build_model(args):
    """Build the U-Net generator and its discriminator, both on args.device.

    Returns the pair ``(generator, discriminator)``.
    """
    generator = UnetModel(in_chans=1,
                          out_chans=1,
                          chans=args.num_chans,
                          num_pool_layers=args.num_pools,
                          drop_prob=args.drop_prob).to(args.device)
    # add model params here depending on unet_model.py
    discriminator = DiscModel().to(args.device)
    return generator, discriminator
Example #10
0
def build_model(args):
    """Construct a 1-channel UnetModel (with acceleration / center-fraction
    sampling parameters) on the device named in *args*."""
    return UnetModel(in_chans=1,
                     out_chans=1,
                     chans=args.num_chans,
                     num_pool_layers=args.num_pools,
                     drop_prob=args.drop_prob,
                     acceleration=args.accelerations,
                     center_fraction=args.center_fractions,
                     res=args.resolution).to(args.device)
Example #11
0
 def __init__(self, number_of_samples, unet_chans, unet_num_pool_layers,
              unet_drop_prob):
     """Build a 2-in/1-out U-Net plus softmax / ReLU output activations."""
     super().__init__()
     self.number_of_samples = number_of_samples
     self.NN = UnetModel(in_chans=2,
                         out_chans=1,
                         drop_prob=unet_drop_prob,
                         chans=unet_chans,
                         num_pool_layers=unet_num_pool_layers)
     # NOTE(review): Softmax without an explicit dim relies on the deprecated
     # implicit-dim behavior — confirm the intended axis.
     self.max = nn.Softmax()
     self.sqwish = nn.ReLU()  # can try sigmoid instead
Example #12
0
 def __init__(self, resolution, unet_chans, unet_num_pool_layers,
              unet_drop_prob):
     """Two size-preserving 3x3 convs over a 3-channel input, then a U-Net."""
     super().__init__()
     self.resolution = resolution
     # Two identical 3x3 convolutions (padding=1 keeps spatial size).
     self.conv_layer = nn.Sequential(*[
         nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
         for _ in range(2)
     ])
     self.sampeling_unet = UnetModel(in_chans=3,
                                     out_chans=1,
                                     drop_prob=unet_drop_prob,
                                     chans=unet_chans,
                                     num_pool_layers=unet_num_pool_layers)
Example #13
0
def build_model(args):
    """Build the U-Net plus an n-layer instance-norm discriminator.

    Returns the pair ``(model, model_dis)``, both moved to args.device.
    """
    model = UnetModel(in_chans=1,
                      out_chans=1,
                      drop_prob=args.drop_prob,
                      chans=args.num_chans,
                      num_pool_layers=args.num_pools).to(args.device)
    model_dis = define_Dis(input_nc=1,
                           ndf=64,
                           netD='n_layers',
                           n_layers_D=5,
                           norm='instance').to(args.device)
    return model, model_dis
Example #14
0
 def __init__(self, decimation_rate, resolution, trajectory_learning,
              subsampling_trajectory, spiral_density, unet_chans,
              unet_num_pool_layers, unet_drop_prob):
     """Compose a sub-sampling layer with a 2-in/1-out U-Net reconstructor."""
     super().__init__()
     self.sub_sampling_layer = SubSamplingLayer(decimation_rate, resolution,
                                                trajectory_learning,
                                                subsampling_trajectory,
                                                spiral_density)
     reconstructor = UnetModel(in_chans=2,
                               out_chans=1,
                               drop_prob=unet_drop_prob,
                               chans=unet_chans,
                               num_pool_layers=unet_num_pool_layers)
     self.reconstruction_model = reconstructor
Example #15
0
 def __init__(self, hparams):
     """Pick the device from hparams.gpus and wrap a 1-channel U-Net in a
     Neumann network regularizer."""
     super().__init__()
     self.hparams = hparams
     # Fall back to CPU when no GPUs were requested.
     self.device = "cpu" if hparams.gpus == 0 else "cuda"
     print(f"Batch size:{hparams.batch_size}, number of blocks:{hparams.n_blocks}")
     regularizer = UnetModel(in_chans=1,
                             out_chans=1,
                             drop_prob=hparams.drop_prob,
                             chans=hparams.num_chans,
                             num_pool_layers=hparams.num_pools)
     self.neumann = NeumannNetwork(reg_network=regularizer, hparams=hparams)