def __init__(self,
                 forward_op,
                 primal_architecture_factory=primal_net_factory,
                 dual_architecture_factory=dual_net_factory,
                 n_iter=10,
                 n_primal=5,
                 n_dual=5):
        """Set up the operator layers and the per-iteration sub-networks.

        Parameters
        ----------
        forward_op :
            Operator exposing ``domain.shape``, ``range.shape`` and
            ``adjoint``; it is wrapped with ``odl_torch.OperatorModule``.
        primal_architecture_factory : callable
            Called as ``factory(n_primal)`` once per iteration to build a
            primal sub-network.
        dual_architecture_factory : callable
            Called as ``factory(n_dual)`` once per iteration to build a
            dual sub-network.
        n_iter : int
            Number of primal/dual sub-network pairs to create.
        n_primal : int
            Leading size of ``primal_shape``; also passed to the primal
            factory and the split layers.
        n_dual : int
            Leading size of ``dual_shape``; also passed to the dual
            factory and the split layers.
        """
        super(LearnedPrimalDual, self).__init__()

        # Remember the construction arguments on the instance.
        self.forward_op = forward_op
        self.primal_architecture_factory = primal_architecture_factory
        self.dual_architecture_factory = dual_architecture_factory
        self.n_iter, self.n_primal, self.n_dual = n_iter, n_primal, n_dual

        # Shapes with the channel count prepended to the operator's
        # domain / range shape.
        domain_shape = forward_op.domain.shape
        range_shape = forward_op.range.shape
        self.primal_shape = (n_primal, ) + domain_shape
        self.dual_shape = (n_dual, ) + range_shape

        # Torch wrappers around the operator and its adjoint.
        self.primal_op_layer = odl_torch.OperatorModule(forward_op)
        self.dual_op_layer = odl_torch.OperatorModule(forward_op.adjoint)

        # Containers for the per-iteration sub-networks (filled below).
        self.primal_nets = nn.ModuleList()
        self.dual_nets = nn.ModuleList()

        self.concatenate_layer = ConcatenateLayer()
        self.primal_split_layer = SplitLayer([n_primal, n_dual, 1])
        self.dual_split_layer = SplitLayer([n_primal, n_dual])

        # One primal and one dual network per iteration; the factories are
        # called alternately, primal first.
        for _ in range(n_iter):
            self.primal_nets.append(primal_architecture_factory(n_primal))
            self.dual_nets.append(dual_architecture_factory(n_dual))
Beispiel #2
0
def test_module_forward(shape, device):
    """Check forward evaluation of an ODL operator wrapped as a module.

    A scaling-by-2 operator on a space of the given ``shape`` is wrapped in
    ``odl_torch.OperatorModule`` and applied to an all-ones input carrying
    one and then two extra leading axes. Output shape, values and device
    are compared against the plain ODL evaluation.
    """
    # Build the ODL operator and its torch wrapper
    n_axes = len(shape)
    space = odl.uniform_discr([0] * n_axes, shape, shape, dtype='float32')
    odl_op = odl.ScalingOperator(space, 2)
    op_mod = odl_torch.OperatorModule(odl_op)

    # All-ones input without any extra axes
    x_arr = np.ones(shape, dtype='float32')

    # One extra leading axis is the minimum; two are checked as well
    for n_extra in (1, 2):
        expand = (None, ) * n_extra + (Ellipsis, )
        x = torch.from_numpy(x_arr).to(device)[expand]
        x.requires_grad_(True)
        res = op_mod(x)
        res_arr = res.detach().cpu().numpy()
        assert res_arr.shape == (1, ) * n_extra + odl_op.range.shape
        assert all_almost_equal(res_arr,
                                np.asarray(odl_op(x_arr))[expand])
        assert x.device.type == res.device.type == device
Beispiel #3
0
def test_module_forward_diff_shapes(device):
    """Check the operator module when input and output shapes differ.

    A random 2x3 matrix operator is wrapped in ``odl_torch.OperatorModule``
    and applied to an all-ones vector of length 3 carrying one and then two
    extra leading axes.
    """
    # Build the ODL operator and its torch wrapper
    matrix = np.random.rand(2, 3).astype('float32')
    odl_op = odl.MatrixOperator(matrix)
    op_mod = odl_torch.OperatorModule(odl_op)

    # All-ones input without any extra axes
    x_arr = np.ones(3, dtype='float32')

    # One extra leading axis is the minimum; two are checked as well
    for n_extra in (1, 2):
        expand = (None, ) * n_extra + (Ellipsis, )
        x = torch.from_numpy(x_arr).to(device)[expand]
        x.requires_grad_(True)
        res = op_mod(x)
        res_arr = res.detach().cpu().numpy()
        assert res_arr.shape == (1, ) * n_extra + odl_op.range.shape
        assert all_almost_equal(res_arr,
                                np.asarray(odl_op(x_arr))[expand])
        assert x.device.type == res.device.type == device
Beispiel #4
0
def test_module_backward(device):
    """Test backpropagation with operators as modules.

    A random 2x3 matrix operator is wrapped as a module and sandwiched
    between two trainable layers. After ``loss.backward()`` every parameter
    of the model must hold a gradient and the gradient with respect to the
    input must be non-zero.
    """
    # Define ODL operator and wrap as module
    matrix = np.random.rand(2, 3).astype('float32')
    odl_op = odl.MatrixOperator(matrix)
    op_mod = odl_torch.OperatorModule(odl_op)
    loss_fn = nn.MSELoss()

    # Test with linear layers (1 extra dim)
    layer_before = nn.Linear(3, 3)
    layer_after = nn.Linear(2, 2)
    model = nn.Sequential(layer_before, op_mod, layer_after).to(device)
    x = torch.from_numpy(np.ones(3, dtype='float32'))[None, ...].to(device)
    x.requires_grad_(True)
    target = torch.from_numpy(np.zeros(2, dtype='float32'))[None,
                                                            ...].to(device)
    loss = loss_fn(model(x), target)
    loss.backward()
    # Parameters themselves are never None, so check their gradients:
    # this is what shows that backprop reached the layers on both sides
    # of the operator module.
    assert all(p.grad is not None for p in model.parameters())
    assert x.grad.detach().cpu().abs().sum() != 0
    assert x.device.type == loss.device.type == device

    # Test with conv layers (2 extra dims)
    layer_before = nn.Conv1d(1, 2, 2)  # 1->2 channels
    layer_after = nn.Conv1d(2, 1, 2)  # 2->1 channels
    model = nn.Sequential(layer_before, op_mod, layer_after).to(device)
    # Input size 4 since initial convolution reduces by 1
    x = torch.from_numpy(np.ones(4, dtype='float32'))[None, None,
                                                      ...].to(device)
    x.requires_grad_(True)
    # Output size 1 since final convolution reduces by 1
    target = torch.from_numpy(np.zeros(1, dtype='float32'))[None, None,
                                                            ...].to(device)

    loss = loss_fn(model(x), target)
    loss.backward()
    assert all(p.grad is not None for p in model.parameters())
    assert x.grad.detach().cpu().abs().sum() != 0
    assert x.device.type == loss.device.type == device
Beispiel #5
0
def main():
    """Train (or reload) up to ``--k_max`` reconstruction blocks in sequence.

    Steps, in order:

    1. Parse command-line options and overlay them with the JSON file given
       by ``--config``.
    2. Build the ray-transform forward model, its adjoint and an FBP
       pseudo-inverse, wrap them as torch modules and construct the dataset.
    3. For every block index: build a block, either load its weights from
       ``<results dir>/<idx>.pt`` (only when ``--load`` is set and the file
       exists) or optimise it, push the train and validation data through
       it via ``next_step_update``, and write statistics and, when
       ``--save`` is set, the weights for freshly trained blocks.

    Raises
    ------
    NotImplementedError
        If ``--img_mode`` or ``--block_type`` is not one of the supported
        values.
    """

    def _str2bool(text):
        # Boolean flags arrive as strings on the command line.
        return bool(strtobool(text))

    parser = argparse.ArgumentParser()
    # general & dataset & training settings
    parser.add_argument('--k_max', type=int, default=5,
                        help='Max reconstruction iterations')
    parser.add_argument('--save_figs', type=_str2bool, default=True,
                        help='save pics in reconstruction')
    parser.add_argument('--img_mode', type=str, default='SimpleCT',
                        help=' image-modality reconstruction: SimpleCT')
    parser.add_argument('--train_size', type=int, default=4000,
                        help='dataset size')
    parser.add_argument('--dataset_type', type=str, default='GenEllipsesSamples',
                        help='GenEllipsesSamples or GenFoamSamples')
    parser.add_argument('--pseudo_inverse_init', type=_str2bool, default=True,
                        help='initialise with pseudoinverse')
    parser.add_argument('--epochs', type=int, default=150,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='input batch size for training')
    parser.add_argument('--initial_lr', type=float, default=1e-3,
                        help='initial_lr')
    parser.add_argument('--val_batch_size', type=int, default=128,
                        help='input batch size for valing')
    parser.add_argument('--arch_args', type=json.loads, default=dict(),
                        help='load architecture dictionary')
    parser.add_argument('--block_type', type=str, default='bayesian_homo',
                        help='deterministic, bayesian_homo, bayesian_hetero')
    parser.add_argument('--save', type=_str2bool, default=True,
                        help='save model')
    # Help text previously read 'save model' (copy-paste slip).
    parser.add_argument('--load', type=_str2bool, default=False,
                        help='load model')

    # forward models setting
    parser.add_argument('--size', type=int, default=128,
                        help='image size')
    parser.add_argument('--beam_num_angle', type=int, default=30,
                        help='number of angles / projections')
    parser.add_argument('--limited_view', type=_str2bool, default=False,
                        help='limited view geometry instead of sparse view geometry')
    # options
    parser.add_argument('--no_cuda', type=_str2bool, default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=222,
                        help='random seed')
    parser.add_argument('--config', default='configs/bayesian_arch_config.json',
                        help='config file path')

    args = parser.parse_args()
    if args.config is not None:
        # Values from the config file override the parsed arguments.
        with open(args.config) as handle:
            config = json.load(handle)
        vars(args).update(config)

    # NOTE(review): GPU mode is switched on before (and independently of)
    # the --no_cuda / availability check below -- confirm this is intended.
    block_utils.set_gpu_mode(True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if args.img_mode == SimpleCT.__name__:
        img_mode = SimpleCT()
        half_size = args.size / 2
        space = odl.uniform_discr([-half_size, -half_size],
                                  [half_size, half_size],
                                  [args.size, args.size], dtype='float32')
        img_mode.space = space
        # The flag is boolean, so two branches cover every case.
        if args.limited_view:
            geometry = limited_view_parallel_beam_geometry(space, beam_num_angle=args.beam_num_angle)
        else:
            geometry = odl.tomo.parallel_beam_geometry(space, num_angles=args.beam_num_angle)
        img_mode.geometry = geometry
        operator = odl.tomo.RayTransform(space, geometry)
        # Forward operator and adjoint are divided by the estimated operator
        # norm; the FBP pseudo-inverse is multiplied by it.
        opnorm = odl.power_method_opnorm(operator)
        img_mode.operator = odl_torch.OperatorModule((1 / opnorm) * operator)
        img_mode.adjoint = odl_torch.OperatorModule((1 / opnorm) * operator.adjoint)
        pseudoinverse = odl.tomo.fbp_op(operator)
        pseudoinverse = odl_torch.OperatorModule(pseudoinverse * opnorm)
        img_mode.pseudoinverse = pseudoinverse

        geometry_prefix = 'limited_view_' if args.limited_view else 'full_view_sparse_'
        geometry_specs = geometry_prefix + str(args.beam_num_angle)
        dataset_name = '_'.join(['dataset', args.img_mode, str(args.size),
                                 str(args.train_size), geometry_specs,
                                 args.dataset_type])

        data_constructor = DatasetConstructor(img_mode, train_size=args.train_size, dataset_name=dataset_name)
        data = data_constructor.data()
    else:
        raise NotImplementedError
    dataset = DataSet(data, img_mode, args.pseudo_inverse_init)

    optim_parms = {'epochs': args.epochs, 'initial_lr': args.initial_lr, 'batch_size': args.batch_size}

    if args.block_type == 'deterministic':
        from blocks import DeterministicBlock as Block
    elif args.block_type == 'bayesian_homo':
        from blocks import BlockHomo as Block
    elif args.block_type == 'bayesian_hetero':
        from blocks import BlockHetero as Block
    else:
        raise NotImplementedError

    # results directory
    path = os.path.dirname(__file__)
    dir_path = os.path.join(path, 'results', args.img_mode, args.block_type, args.dataset_type, str(args.train_size), geometry_specs, str(args.size), str(args.seed))
    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(dir_path, exist_ok=True)

    # all config
    print('===========================\n', flush=True)
    for key, val in vars(args).items():
        print('{}: {}'.format(key, val), flush=True)
    print('===========================\n', flush=True)

    blocks_history = {'block': [], 'optimizer': []}
    # savings training procedures
    filename = 'train_phase'
    filepath = os.path.join(dir_path, filename)
    vis = TrainVisualiser(filepath)

    start_time = time.time()
    # looping through architecture-blocs
    for idx in range(1, args.k_max + 1):

        print('============== training block number: {} ============= \n'.format(idx), flush=True)

        train_tensor = dataset.construct(flag='train')
        val_tensor = dataset.construct(flag='validation')

        train_loader = DataLoader(train_tensor, batch_size=args.batch_size, shuffle=True)

        block = Block(args.arch_args)
        block = block.to(device)

        path_block = os.path.join(dir_path, str(idx) + '.pt')
        if args.load and os.path.exists(path_block):
            # map_location lets a checkpoint written on GPU be restored on
            # a CPU-only run (and vice versa).
            block.load_state_dict(torch.load(path_block, map_location=device))
            loaded = True
            print('============= loaded idx: {} ============='.format(idx), flush=True)

        else:
            block.optimise(train_loader, **optim_parms)
            loaded = False

        # Push both data splits through the block, train split first.
        for flag, tensor, label in (('train', train_tensor, 'training reconstruction'),
                                    ('validation', val_tensor, 'validation reconstruction')):
            start = time.time()
            info = next_step_update(dataset, tensor, block, device, flag=flag)
            end = time.time()
            print('============= {} {:.4f} ============= \n'.format(label, end-start), flush=True)
            for key, val in info.items():
                print('{}: {} \n'.format(key, val), flush=True)

        vis.update(dataset, flag='validation')
        blocks_history['block'].append(block)

        # reconstruction (only for freshly trained blocks)
        resonstruction_dir_path = os.path.join(dir_path, str(idx))
        if not loaded:
            os.makedirs(resonstruction_dir_path, exist_ok=True)
            get_stats(dataset, blocks_history, device, resonstruction_dir_path)

        if args.save and not loaded:
            torch.save(block.state_dict(), path_block)

    print('--- training time: %s seconds ---' % (time.time() - start_time), flush=True)
    vis.generate()
def main():
    """Train (or reload) up to ``--k_max`` ``HybridModel`` instances in sequence.

    Steps, in order:

    1. Parse command-line options and seed torch / numpy.
    2. Build the ray-transform forward model and an FBP pseudo-inverse, wrap
       them as torch modules and construct the dataset.
    3. For every index: build a model, load its weights from
       ``<results dir>/<idx>.pt`` when that file exists, otherwise optimise
       and save it; then push the train and validation data through it via
       ``next_step_update`` and write statistics with ``get_stats``.

    Raises
    ------
    NotImplementedError
        If ``--img_mode`` is not ``SimpleCT``.
    """

    def _str2bool(text):
        # Boolean flags arrive as strings on the command line.
        return bool(strtobool(text))

    parser = argparse.ArgumentParser()
    # general & dataset & training settings
    parser.add_argument('--k_max', type=int, default=5,
                        help='Max reconstruction iterations')
    parser.add_argument('--save_figs', type=_str2bool, default=True,
                        help='save pics in reconstruction')
    parser.add_argument('--img_mode', type=str, default='SimpleCT',
                        help=' image-modality reconstruction: SimpleCT')
    parser.add_argument('--train_size', type=int, default=4000,
                        help='dataset size')
    parser.add_argument('--pseudo_inverse_init', type=_str2bool, default=True,
                        help='initialise with pseudoinverse')
    parser.add_argument('--brain', type=_str2bool, default=False,
                        help='test set of brain images')
    parser.add_argument('--epochs', type=int, default=150,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='input batch size for training')
    parser.add_argument('--initial_lr', type=float, default=1e-3,
                        help='initial_lr')
    parser.add_argument('--val_batch_size', type=int, default=128,
                        help='input batch size for valing')

    # forward models setting
    parser.add_argument('--size', type=int, default=128,
                        help='image size')
    parser.add_argument('--beam_num_angle', type=int, default=30,
                        help='number of angles / projections')
    # options
    parser.add_argument('--no_cuda', type=_str2bool, default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=222,
                        help='random seed')

    args = parser.parse_args()
    # NOTE(review): GPU mode is switched on before (and independently of)
    # the --no_cuda / availability check below -- confirm this is intended.
    layer_utils.set_gpu_mode(True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if args.img_mode is not None:
        forward_model = ForwardModel()
        half_size = args.size / 2
        space = odl.uniform_discr([-half_size, -half_size],
                                  [half_size, half_size],
                                  [args.size, args.size], dtype='float32')
        forward_model.space = space
        geometry = odl.tomo.parallel_beam_geometry(space, num_angles=args.beam_num_angle)
        forward_model.geometry = geometry
        operator = odl.tomo.RayTransform(space, geometry)
        # The forward operator is divided by the estimated operator norm and
        # the FBP pseudo-inverse is multiplied by it.
        opnorm = odl.power_method_opnorm(operator)
        forward_model.operator = odl_torch.OperatorModule((1 / opnorm) * operator)
        # NOTE(review): the adjoint is wrapped WITHOUT the 1/opnorm factor
        # applied to the forward operator above -- confirm this is intended.
        forward_model.adjoint = odl_torch.OperatorModule(operator.adjoint)
        pseudoinverse = odl.tomo.fbp_op(operator)
        pseudoinverse = odl_torch.OperatorModule(pseudoinverse * opnorm)
        forward_model.pseudoinverse = pseudoinverse

        geometry_specs = 'full_view_sparse_' + str(args.beam_num_angle)
        dataset_name = '_'.join(['dataset', args.img_mode, str(args.size),
                                 str(args.train_size), geometry_specs,
                                 'brain', str(args.brain)])

    if args.img_mode == SimpleCT.__name__:
        img_mode = SimpleCT(forward_model)
        data_constructor = DatasetConstructor(img_mode, train_size=args.train_size, brain=args.brain, dataset_name=dataset_name)
        data = data_constructor.data()
    else:
        raise NotImplementedError
    dataset = DataSet(data, img_mode, args.pseudo_inverse_init)

    optim_parms = {'epochs': args.epochs, 'initial_lr': args.initial_lr, 'batch_size': args.batch_size}
    from hybrid_model import HybridModel as NeuralLearner

    # results directory
    path = os.path.dirname(__file__)
    dir_path = os.path.join(path, 'results', args.img_mode, 'MFVI', str(args.train_size), geometry_specs, str(args.seed))
    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(dir_path, exist_ok=True)

    # all config
    print('===========================\n', flush=True)
    for key, val in vars(args).items():
        print('{}: {}'.format(key, val), flush=True)
    print('===========================\n', flush=True)

    blocks_history = {'model': [], 'optimizer': []}
    arch_args = {'arch': {'up':  [[1, 16, 3, 1, 1],  [16, 32, 3, 1, 1]],
                          'low': [[1, 16, 3, 1, 1],  [16, 32, 3, 1, 1]],
                          'cm':  [[64, 32, 3, 1, 1], [32, 16, 3, 1, 1]]}}

    # savings training procedures
    filename = 'train_phase'
    filepath = os.path.join(dir_path, filename)
    vis = TrainVisualiser(filepath)

    start_time = time.time()
    # looping through architecture-blocs
    for idx in range(1, args.k_max + 1):

        print('============== training block number: {} ============= \n'.format(idx), flush=True)

        train_tensor = dataset.construct(flag='train')
        val_tensor = dataset.construct(flag='validation')

        train_loader = DataLoader(train_tensor, batch_size=args.batch_size, shuffle=True)

        model = NeuralLearner(arch_args)
        model = model.to(device)
        model_path = os.path.join(dir_path, str(idx) + '.pt')
        if os.path.exists(model_path):
            model_loaded = True
            # map_location lets a checkpoint written on GPU be restored on
            # a CPU-only run (and vice versa).
            model.load_state_dict(torch.load(model_path, map_location=device))
            print('idx: {} model loaded!\npath to model:\n{}'.format(idx, model_path), flush=True)
        else:
            model_loaded = False
            model.optimise(train_loader, **optim_parms)
            save_net(model, model_path)
            print('idx: {} optimisation finished!'.format(idx), flush=True)

        # Push both data splits through the model, train split first.
        for flag, tensor, label in (('train', train_tensor, 'training reconstruction'),
                                    ('validation', val_tensor, 'validation reconstruction')):
            start = time.time()
            info = next_step_update(dataset, tensor, model, device, flag=flag)
            end = time.time()
            print('============= {} {:.4f} ============= \n'.format(label, end-start), flush=True)
            for key, val in info.items():
                print('{}: {} \n'.format(key, val), flush=True)

        vis.update(dataset, flag='validation')
        blocks_history['model'].append(model)

        # reconstruction: reloaded models write into a 're-loaded' subfolder
        if model_loaded:
            resonstruction_dir_path = os.path.join(dir_path, str(idx), 're-loaded')
        else:
            resonstruction_dir_path = os.path.join(dir_path, str(idx))

        os.makedirs(resonstruction_dir_path, exist_ok=True)
        get_stats(dataset, blocks_history, device, resonstruction_dir_path)

    print('--- training time: %s seconds ---' % (time.time() - start_time), flush=True)
    vis.generate()
def main():
    """Run Monte-Carlo uncertainty analysis over previously trained blocks.

    Steps, in order:

    1. Parse command-line options and overlay them with the JSON file given
       by ``--config``.
    2. Build the ray-transform forward model, its adjoint and an FBP
       pseudo-inverse, wrap them as torch modules and construct the dataset.
    3. Load the weights of blocks ``1 .. k_max`` from ``<results dir>/<idx>.pt``.
    4. Repeat 100 times: push the test split through all blocks in order and
       record the final reconstruction and the variance output of the last
       block.
    5. Pickle the sample mean, the epistemic and aleatoric terms and the
       combined standard deviation, then call ``get_stats``.

    Raises
    ------
    NotImplementedError
        If ``--img_mode`` or ``--block_type`` is unsupported, if a block
        checkpoint is missing, or if the last block has no
        ``bayes_CNN_log_std`` attribute.
    """

    def _str2bool(text):
        # Boolean flags arrive as strings on the command line.
        return bool(strtobool(text))

    parser = argparse.ArgumentParser()
    # general & dataset & training settings
    parser.add_argument('--k_max', type=int, default=5,
                        help='Max reconstruction iterations')
    parser.add_argument('--save_figs', type=_str2bool, default=True,
                        help='save pics in reconstruction')
    parser.add_argument('--img_mode', type=str, default='SimpleCT',
                        help=' image-modality reconstruction: SimpleCT')
    parser.add_argument('--train_size', type=int, default=4000,
                        help='dataset size')
    parser.add_argument('--dataset_type', type=str, default='GenEllipsesSamples',
                        help='GenEllipsesSamples or GenFoamSamples')
    parser.add_argument('--pseudo_inverse_init', type=_str2bool, default=True,
                        help='initialise with pseudoinverse')
    parser.add_argument('--epochs', type=int, default=150,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='input batch size for training')
    parser.add_argument('--initial_lr', type=float, default=1e-3,
                        help='initial_lr')
    parser.add_argument('--val_batch_size', type=int, default=128,
                        help='input batch size for valing')
    parser.add_argument('--arch_args', type=json.loads, default=dict(),
                        help='load architecture dictionary')
    parser.add_argument('--block_type', type=str, default='bayesian_homo',
                        help='deterministic, bayesian_homo, bayesian_hetero')

    # forward models setting
    parser.add_argument('--size', type=int, default=128,
                        help='image size')
    parser.add_argument('--beam_num_angle', type=int, default=30,
                        help='number of angles / projections')
    parser.add_argument('--limited_view', type=_str2bool, default=False,
                        help='limited view geometry instead of sparse view geometry')
    # options
    parser.add_argument('--no_cuda', type=_str2bool, default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=222,
                        help='random seed')
    parser.add_argument('--config', default='configs/bayesian_arch_config.json',
                        help='config file path')

    args = parser.parse_args()
    if args.config is not None:
        # Values from the config file override the parsed arguments.
        with open(args.config) as handle:
            config = json.load(handle)
        vars(args).update(config)

    # NOTE(review): GPU mode is switched on before (and independently of)
    # the --no_cuda / availability check below -- confirm this is intended.
    block_utils.set_gpu_mode(True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if args.img_mode == SimpleCT.__name__:
        img_mode = SimpleCT()
        half_size = args.size / 2
        space = odl.uniform_discr([-half_size, -half_size],
                                  [half_size, half_size],
                                  [args.size, args.size], dtype='float32')
        img_mode.space = space
        # The flag is boolean, so two branches cover every case.
        if args.limited_view:
            geometry = limited_view_parallel_beam_geometry(space, beam_num_angle=args.beam_num_angle)
        else:
            geometry = odl.tomo.parallel_beam_geometry(space, num_angles=args.beam_num_angle)
        img_mode.geometry = geometry
        operator = odl.tomo.RayTransform(space, geometry)
        # Forward operator and adjoint are divided by the estimated operator
        # norm; the FBP pseudo-inverse is multiplied by it.
        opnorm = odl.power_method_opnorm(operator)
        img_mode.operator = odl_torch.OperatorModule((1 / opnorm) * operator)
        img_mode.adjoint = odl_torch.OperatorModule((1 / opnorm) * operator.adjoint)
        pseudoinverse = odl.tomo.fbp_op(operator)
        pseudoinverse = odl_torch.OperatorModule(pseudoinverse * opnorm)
        img_mode.pseudoinverse = pseudoinverse

        geometry_prefix = 'limited_view_' if args.limited_view else 'full_view_sparse_'
        geometry_specs = geometry_prefix + str(args.beam_num_angle)
        dataset_name = '_'.join(['dataset', args.img_mode, str(args.size),
                                 str(args.train_size), geometry_specs,
                                 args.dataset_type])

        data_constructor = DatasetConstructor(img_mode, train_size=args.train_size, dataset_name=dataset_name)
        data = data_constructor.data()
    else:
        raise NotImplementedError
    dataset = DataSet(data, img_mode, args.pseudo_inverse_init)

    optim_parms = {'epochs': args.epochs, 'initial_lr': args.initial_lr, 'batch_size': args.batch_size}

    if args.block_type == 'deterministic':
        from blocks import DeterministicBlock as Block
    elif args.block_type == 'bayesian_homo':
        from blocks import BlockHomo as Block
    elif args.block_type == 'bayesian_hetero':
        from blocks import BlockHetero as Block
    else:
        raise NotImplementedError

    # results directory
    path = os.path.dirname(__file__)
    dir_path = os.path.join(path, 'results', args.img_mode, args.block_type, args.dataset_type, str(args.train_size), geometry_specs, str(args.size), str(args.seed))
    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(dir_path, exist_ok=True)

    # all config
    print('===========================\n', flush=True)
    for key, val in vars(args).items():
        print('{}: {}'.format(key, val), flush=True)
    print('===========================\n', flush=True)

    blocks_history = {'block': []}
    start_time = time.time()
    # looping through architecture-blocs: every block must already be trained
    for idx in range(1, args.k_max + 1):

        print('============== training block number: {} ============= \n'.format(idx), flush=True)

        block = Block(args.arch_args)
        block = block.to(device)

        path_block = os.path.join(dir_path, str(idx) + '.pt')
        if os.path.exists(path_block):
            # map_location lets a checkpoint written on GPU be restored on
            # a CPU-only run (and vice versa).
            block.load_state_dict(torch.load(path_block, map_location=device))
        else:
            raise NotImplementedError

        blocks_history['block'].append(block)

    start = time.time()
    with torch.no_grad():
        mc_samples_mean, mc_samples_var = [], []
        for _ in range(100):
            # One Monte-Carlo pass: chain the test split through all blocks.
            for block in blocks_history['block']:
                block.eval()
                test_tensor = dataset.construct(flag='test', display=False)
                dataloader = DataLoader(deepcopy(test_tensor), batch_size=64, shuffle=False, drop_last=False)
                X_, Var_ = [], []
                for batch_idx, (data, grad, target) in enumerate(dataloader):
                    data, grad, target = data.to(device), grad.to(device), target.to(device)
                    output, var = block.forward(data, grad)
                    X_.append(output)
                    Var_.append(var)
                dataset.update(torch.cat(X_).cpu(), flag='test')
            # After the inner loop, Var_ holds the variances of the LAST block.
            mc_samples_mean.append(dataset.X_['test'])
            mc_samples_var.append(torch.cat(Var_).cpu())
            dataset.reset('test')

        print('time: {}'.format(time.time() - start))
        stacked_means = torch.stack(mc_samples_mean)
        mean = torch.mean(stacked_means, dim=0)
        # `block` is the last block from the Monte-Carlo loop above.
        if hasattr(block, 'bayes_CNN_log_std'):
            # Spread of the sampled means plus the average predicted variance.
            epistemic = torch.std(stacked_means, dim=0)**2
            aleatoric = torch.mean(torch.stack(mc_samples_var), dim=0)
            std = torch.sqrt(epistemic + aleatoric)
        else:
            raise NotImplementedError

        # From here on dir_path points at the 'uncertainty analysis' subfolder.
        dir_path = os.path.join(dir_path, 'uncertainty analysis')
        os.makedirs(dir_path, exist_ok=True)
        filename = 'data' + '_' + str(args.k_max) + '.p'
        filepath = os.path.join(dir_path, filename)
        # The with-block closes the file; no explicit close() is needed.
        with open(filepath, 'wb') as handle:
            pickle.dump({'mean': mean, 'aleatoric': aleatoric, 'epistemic': epistemic, 'std': std}, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # reconstruction: `idx` is the last block index from the loading loop and
    # dir_path was rebound above, so this lands under 'uncertainty analysis'.
    resonstruction_dir_path = os.path.join(dir_path, str(idx))
    os.makedirs(resonstruction_dir_path, exist_ok=True)
    get_stats(dataset, blocks_history, device, resonstruction_dir_path)

    print('--- training time: %s seconds ---' % (time.time() - start_time), flush=True)
Beispiel #8
0
import numpy as np
import torch
from torch import nn

import odl
from odl.contrib import torch as odl_torch

# --- Forward --- #

# Define an ODL operator from a 2x3 matrix, i.e. it maps vectors of
# length 3 to vectors of length 2
matrix = np.array([[1, 0, 0], [0, 1, 1]], dtype='float32')
odl_op = odl.MatrixOperator(matrix)

# Wrap the ODL operator as a torch `Module`
op_layer = odl_torch.OperatorModule(odl_op)

# Test with some inputs. We need to add at least one batch axis in front
# of the operator's input shape
inp = torch.ones((1, 3))
print('Operator layer evaluated on a 1x3 tensor:')
print(op_layer(inp))
print()

# Same input values, now with two extra leading axes
inp = torch.ones((1, 1, 3))
print('Operator layer evaluated on a 1x1x3 tensor:')
print(op_layer(inp))
print()

# We combine the module with some builtin pytorch layers of matching shapes:
# 3 features going into the operator, 2 features coming out of it
layer_before = nn.Linear(3, 3)
layer_after = nn.Linear(2, 2)