예제 #1
0
def load_model(checkpoint_file, device=None):
    """Rebuild a ``Subsampling_Model`` from a checkpoint and restore its weights.

    The checkpoint is read as a dict holding the construction arguments under
    ``'args'`` and the model ``state_dict`` under ``'model'``.

    Args:
        checkpoint_file: Path or file-like object accepted by ``torch.load``.
        device: Device to place the model on. When ``None`` (the default) the
            ``args.device`` stored in the checkpoint is used, matching the
            previous behavior.

    Returns:
        The model with its weights loaded, wrapped in
        ``torch.nn.DataParallel`` when ``args.data_parallel`` is set.
    """
    # Deserialize onto the CPU: ``load_state_dict`` below copies the tensors
    # into the model's own (already device-placed) parameters, so restoring
    # them onto the device they were saved from is unnecessary and fails when
    # that device is not available on this machine.
    # NOTE(review): ``torch.load`` unpickles arbitrary objects (here the stored
    # ``args``), so only load trusted checkpoints. On newer torch releases that
    # default to ``weights_only=True`` the pickled ``args`` may be rejected;
    # pass ``weights_only=False`` there — confirm against the torch version used.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    args = checkpoint['args']
    if device is None:
        device = args.device
    # The ``args`` attribute names differ from the constructor's keywords
    # (e.g. ``num_chans`` -> ``chans``, ``num_pools`` -> ``num_pool_layers``).
    model = Subsampling_Model(out_directions=args.out_directions,
                              dir_decimation_rate=args.dir_decimation_rate,
                              direction_learning=args.direction_learning,
                              initialization=args.initialization,
                              chans=args.num_chans,
                              num_pool_layers=args.num_pools,
                              drop_prob=args.drop_prob).to(device)
    # Wrap before loading: presumably the state dict was saved from the wrapped
    # model (keys prefixed with ``module.``) — confirm against the saving code.
    if args.data_parallel:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    return model
예제 #2
0
def load_model(checkpoint, device=None):
    """Rebuild a ``Subsampling_Model`` from a checkpoint and restore its weights.

    The checkpoint is read as a dict holding the construction arguments under
    ``'args'`` and the model ``state_dict`` under ``'model'``.

    Args:
        checkpoint: Path or file-like object accepted by ``torch.load``.
        device: Device to place the model on. When ``None`` (the default) the
            ``args.device`` stored in the checkpoint is used, matching the
            previous behavior.

    Returns:
        The model with its weights loaded, wrapped in
        ``torch.nn.DataParallel`` when ``args.data_parallel`` is set.
    """
    # Deserialize onto the CPU: ``load_state_dict`` below copies the tensors
    # into the model's own (already device-placed) parameters, so restoring
    # them onto the device they were saved from is unnecessary and fails when
    # that device is not available on this machine.
    # NOTE(review): ``torch.load`` unpickles arbitrary objects (here the stored
    # ``args``), so only load trusted checkpoints. On newer torch releases that
    # default to ``weights_only=True`` the pickled ``args`` may be rejected;
    # pass ``weights_only=False`` there — confirm against the torch version used.
    ckpt = torch.load(checkpoint, map_location='cpu')
    args = ckpt['args']
    if device is None:
        device = args.device
    # in_chans=15 / out_chans=1 are fixed here rather than read from ``args``.
    model = Subsampling_Model(in_chans=15,
                              out_chans=1,
                              chans=args.num_chans,
                              num_pool_layers=args.num_pools,
                              drop_prob=args.drop_prob,
                              decimation_rate=args.decimation_rate,
                              res=args.resolution,
                              trajectory_learning=args.trajectory_learning,
                              initialization=args.initialization,
                              SNR=args.SNR).to(device)
    # Wrap before loading: presumably the state dict was saved from the wrapped
    # model (keys prefixed with ``module.``) — confirm against the saving code.
    if args.data_parallel:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(ckpt['model'])
    return model
예제 #3
0
def build_model(args):
    """Construct a fresh ``Subsampling_Model`` from parsed arguments.

    Args:
        args: Namespace-like object exposing ``out_directions``,
            ``dir_decimation_rate``, ``direction_learning``,
            ``initialization``, ``num_chans``, ``num_pools``, ``drop_prob``
            and ``device``.

    Returns:
        The newly built model, moved to ``args.device``.
    """
    model_kwargs = dict(
        out_directions=args.out_directions,
        dir_decimation_rate=args.dir_decimation_rate,
        direction_learning=args.direction_learning,
        initialization=args.initialization,
        # ``args`` field names differ from the constructor's keyword names.
        chans=args.num_chans,
        num_pool_layers=args.num_pools,
        drop_prob=args.drop_prob,
    )
    network = Subsampling_Model(**model_kwargs)
    return network.to(args.device)
예제 #4
0
def build_model(args):
    """Construct a fresh single-channel ``Subsampling_Model`` from parsed arguments.

    Args:
        args: Namespace-like object exposing ``num_chans``, ``num_pools``,
            ``drop_prob``, ``decimation_rate``, ``resolution``,
            ``trajectory_learning``, ``initialization``, ``SNR``,
            ``n_shots``, ``interp_gap`` and ``device``.

    Returns:
        The newly built model, moved to ``args.device``.
    """
    model_kwargs = dict(
        # Input/output channel counts are fixed to 1 rather than read from args.
        in_chans=1,
        out_chans=1,
        chans=args.num_chans,
        num_pool_layers=args.num_pools,
        drop_prob=args.drop_prob,
        decimation_rate=args.decimation_rate,
        res=args.resolution,
        trajectory_learning=args.trajectory_learning,
        initialization=args.initialization,
        SNR=args.SNR,
        n_shots=args.n_shots,
        interp_gap=args.interp_gap,
    )
    network = Subsampling_Model(**model_kwargs)
    return network.to(args.device)