def __init__(self, forward_op, primal_architecture_factory=primal_net_factory,
             dual_architecture_factory=dual_net_factory, n_iter=10, n_primal=5, n_dual=5):
    """Build an unrolled learned primal-dual network around ``forward_op``.

    Args:
        forward_op: ODL operator. Its ``domain``/``range`` shapes define
            ``primal_shape``/``dual_shape``. It and its adjoint are wrapped
            as torch modules.
        primal_architecture_factory: callable invoked as
            ``factory(n_primal)`` once per iteration; each result is stored
            in ``self.primal_nets``.
        dual_architecture_factory: callable invoked as ``factory(n_dual)``
            once per iteration; each result is stored in ``self.dual_nets``.
        n_iter: number of unrolled iterations, i.e. the number of
            primal/dual network pairs.
        n_primal: leading dimension of ``primal_shape``.
        n_dual: leading dimension of ``dual_shape``.
    """
    super(LearnedPrimalDual, self).__init__()

    # Plain configuration attributes.
    self.forward_op = forward_op
    self.primal_architecture_factory = primal_architecture_factory
    self.dual_architecture_factory = dual_architecture_factory
    self.n_iter, self.n_primal, self.n_dual = n_iter, n_primal, n_dual

    domain_shape = forward_op.domain.shape
    range_shape = forward_op.range.shape
    self.primal_shape = (n_primal, ) + domain_shape
    self.dual_shape = (n_dual, ) + range_shape

    # Keep this assignment order: nn.Module registers sub-modules in the
    # order they are assigned.
    self.primal_op_layer = odl_torch.OperatorModule(forward_op)
    self.dual_op_layer = odl_torch.OperatorModule(forward_op.adjoint)

    primal_nets, dual_nets = nn.ModuleList(), nn.ModuleList()
    self.primal_nets = primal_nets
    self.dual_nets = dual_nets

    self.concatenate_layer = ConcatenateLayer()
    self.primal_split_layer = SplitLayer([n_primal, n_dual, 1])
    self.dual_split_layer = SplitLayer([n_primal, n_dual])

    # Build the networks interleaved (primal, dual, primal, dual, ...) so
    # any random initialisation inside the factories happens in a fixed
    # sequence.
    for _ in range(n_iter):
        primal_nets.append(primal_architecture_factory(n_primal))
        dual_nets.append(dual_architecture_factory(n_dual))
def test_module_forward(shape, device):
    """Test forward evaluation with operators as modules.

    A scaling operator (factor 2) on a uniform grid of size ``shape`` is
    wrapped as a torch module. It is applied to an all-ones tensor with one
    and then two extra leading axes. Output shape, values and device are
    compared with direct ODL evaluation.
    """
    # Define ODL operator and wrap as module
    ndim = len(shape)
    space = odl.uniform_discr([0] * ndim, shape, shape, dtype='float32')
    odl_op = odl.ScalingOperator(space, 2)
    op_mod = odl_torch.OperatorModule(odl_op)

    # Input data and the reference result of the bare ODL operator
    x_arr = np.ones(shape, dtype='float32')
    reference = np.asarray(odl_op(x_arr))

    # First 1 extra dim (the minimum), then 2 extra dims
    for n_extra in (1, 2):
        expand = (None,) * n_extra + (Ellipsis,)
        x = torch.from_numpy(x_arr).to(device)[expand]
        x.requires_grad_(True)
        res = op_mod(x)
        res_arr = res.detach().cpu().numpy()
        assert res_arr.shape == (1,) * n_extra + odl_op.range.shape
        assert all_almost_equal(res_arr, reference[expand])
        assert x.device.type == res.device.type == device
def test_module_forward_diff_shapes(device):
    """Test operator module with different shapes of input and output.

    A random 2x3 ``MatrixOperator`` (input length 3, output length 2) is
    wrapped as a torch module. It is evaluated on an all-ones vector with
    one and then two extra leading axes, and checked against direct ODL
    evaluation.
    """
    # Define ODL operator and wrap as module
    matrix = np.random.rand(2, 3).astype('float32')
    odl_op = odl.MatrixOperator(matrix)
    op_mod = odl_torch.OperatorModule(odl_op)

    # Input data; the ODL result is the reference for every batched case
    x_arr = np.ones(3, dtype='float32')
    expected = np.asarray(odl_op(x_arr))

    batch_prefixes = [(1, ), (1, 1)]  # 1 extra dim (minimum), 2 extra dims
    for prefix in batch_prefixes:
        new_axes = tuple(None for _ in prefix) + (Ellipsis, )
        inp = torch.from_numpy(x_arr).to(device)[new_axes]
        inp.requires_grad_(True)
        out = op_mod(inp)
        out_arr = out.detach().cpu().numpy()
        assert out_arr.shape == prefix + odl_op.range.shape
        assert all_almost_equal(out_arr, expected[new_axes])
        assert inp.device.type == out.device.type == device
def test_module_backward(device):
    """Test backpropagation with operators as modules.

    A random 2x3 ``MatrixOperator`` is wrapped as a torch module and placed
    between two trainable torch layers: first linear layers (1 extra dim),
    then 1-D convolutions (2 extra dims). After ``loss.backward()``, every
    trainable parameter must have a populated ``.grad``. The input gradient
    must be non-zero, which shows gradients flow through the ODL operator.
    """
    # Define ODL operator and wrap as module
    matrix = np.random.rand(2, 3).astype('float32')
    odl_op = odl.MatrixOperator(matrix)
    op_mod = odl_torch.OperatorModule(odl_op)
    loss_fn = nn.MSELoss()

    def check_backprop(model, x, target):
        """Backpropagate an MSE loss through ``model`` and check the gradients."""
        x.requires_grad_(True)
        loss = loss_fn(model(x), target)
        loss.backward()
        # `p is not None` is trivially true for any parameter; the meaningful
        # check is that backprop populated `p.grad`.
        assert all(p.grad is not None for p in model.parameters())
        assert x.grad.detach().cpu().abs().sum() != 0
        assert x.device.type == loss.device.type == device

    # Test with linear layers (1 extra dim)
    layer_before = nn.Linear(3, 3)
    layer_after = nn.Linear(2, 2)
    model = nn.Sequential(layer_before, op_mod, layer_after).to(device)
    x = torch.from_numpy(np.ones(3, dtype='float32'))[None, ...].to(device)
    target = torch.from_numpy(np.zeros(2, dtype='float32'))[None, ...].to(device)
    check_backprop(model, x, target)

    # Test with conv layers (2 extra dims)
    layer_before = nn.Conv1d(1, 2, 2)  # 1->2 channels
    layer_after = nn.Conv1d(2, 1, 2)  # 2->1 channels
    model = nn.Sequential(layer_before, op_mod, layer_after).to(device)
    # Input size 4 since initial convolution reduces by 1
    x = torch.from_numpy(np.ones(4, dtype='float32'))[None, None, ...].to(device)
    # Output size 1 since final convolution reduces by 1
    target = torch.from_numpy(np.zeros(1, dtype='float32'))[None, None, ...].to(device)
    check_backprop(model, x, target)
def main():
    """Train a sequence of reconstruction blocks for simple CT and evaluate them.

    Steps:

    * Parse command-line arguments, then overlay the JSON file given by
      ``--config``.
    * Build the normalised ray transform, its adjoint and an FBP
      pseudo-inverse.
    * Obtain the dataset from ``DatasetConstructor``.
    * For each block index ``1..k_max``, either load ``<idx>.pt`` (when
      ``--load`` is set and the file exists) or train a fresh block.
    * After each block, update the train/validation reconstructions, run
      ``get_stats`` and save the block (freshly trained blocks only).

    Raises:
        NotImplementedError: for an unsupported ``--img_mode`` or
            ``--block_type``.
    """
    def str2bool(value):
        return bool(strtobool(value))

    parser = argparse.ArgumentParser()
    # general & dataset & training settings
    parser.add_argument('--k_max', type=int, default=5, help='Max reconstruction iterations')
    parser.add_argument('--save_figs', type=str2bool, default=True, help='save pics in reconstruction')
    parser.add_argument('--img_mode', type=str, default='SimpleCT', help=' image-modality reconstruction: SimpleCT')
    parser.add_argument('--train_size', type=int, default=4000, help='dataset size')
    parser.add_argument('--dataset_type', type=str, default='GenEllipsesSamples', help='GenEllipsesSamples or GenFoamSamples')
    parser.add_argument('--pseudo_inverse_init', type=str2bool, default=True, help='initialise with pseudoinverse')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training')
    parser.add_argument('--initial_lr', type=float, default=1e-3, help='initial_lr')
    parser.add_argument('--val_batch_size', type=int, default=128, help='input batch size for valing')
    parser.add_argument('--arch_args', type=json.loads, default=dict(), help='load architecture dictionary')
    parser.add_argument('--block_type', type=str, default='bayesian_homo', help='deterministic, bayesian_homo, bayesian_hetero')
    parser.add_argument('--save', type=str2bool, default=True, help='save model')
    parser.add_argument('--load', type=str2bool, default=False, help='load model')
    # forward models setting
    parser.add_argument('--size', type=int, default=128, help='image size')
    parser.add_argument('--beam_num_angle', type=int, default=30, help='number of angles / projections')
    parser.add_argument('--limited_view', type=str2bool, default=False, help='limited view geometry instead of sparse view geometry')
    # options
    parser.add_argument('--no_cuda', type=str2bool, default=False, help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=222, help='random seed')
    parser.add_argument('--config', default='configs/bayesian_arch_config.json', help='config file path')
    args = parser.parse_args()

    if args.config is not None:
        # NOTE: values from the config file override command-line flags.
        with open(args.config) as handle:
            config = json.load(handle)
        vars(args).update(config)

    # NOTE(review): GPU mode is enabled unconditionally, even with --no_cuda;
    # confirm this is intended.
    block_utils.set_gpu_mode(True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if args.img_mode != SimpleCT.__name__:
        raise NotImplementedError

    img_mode = SimpleCT()
    half_size = args.size / 2
    space = odl.uniform_discr([-half_size, -half_size], [half_size, half_size],
                              [args.size, args.size], dtype='float32')
    img_mode.space = space
    if args.limited_view:
        geometry = limited_view_parallel_beam_geometry(space, beam_num_angle=args.beam_num_angle)
    else:
        geometry = odl.tomo.parallel_beam_geometry(space, num_angles=args.beam_num_angle)
    img_mode.geometry = geometry

    # Divide the forward operator by its estimated norm; the adjoint is
    # divided and the pseudo-inverse multiplied by the same factor so all
    # three stay consistent.
    operator = odl.tomo.RayTransform(space, geometry)
    opnorm = odl.power_method_opnorm(operator)
    img_mode.operator = odl_torch.OperatorModule((1 / opnorm) * operator)
    img_mode.adjoint = odl_torch.OperatorModule((1 / opnorm) * operator.adjoint)
    pseudoinverse = odl.tomo.fbp_op(operator)
    img_mode.pseudoinverse = odl_torch.OperatorModule(pseudoinverse * opnorm)

    if args.limited_view:
        geometry_specs = 'limited_view_' + str(args.beam_num_angle)
    else:
        geometry_specs = 'full_view_sparse_' + str(args.beam_num_angle)
    dataset_name = ('dataset' + '_' + args.img_mode + '_' + str(args.size)
                    + '_' + str(args.train_size) + '_' + geometry_specs
                    + '_' + args.dataset_type)
    data_constructor = DatasetConstructor(img_mode, train_size=args.train_size, dataset_name=dataset_name)
    data = data_constructor.data()
    dataset = DataSet(data, img_mode, args.pseudo_inverse_init)

    optim_parms = {'epochs': args.epochs, 'initial_lr': args.initial_lr, 'batch_size': args.batch_size}

    if args.block_type == 'deterministic':
        from blocks import DeterministicBlock as Block
    elif args.block_type == 'bayesian_homo':
        from blocks import BlockHomo as Block
    elif args.block_type == 'bayesian_hetero':
        from blocks import BlockHetero as Block
    else:
        raise NotImplementedError

    # results directory
    path = os.path.dirname(__file__)
    dir_path = os.path.join(path, 'results', args.img_mode, args.block_type, args.dataset_type,
                            str(args.train_size), geometry_specs, str(args.size), str(args.seed))
    os.makedirs(dir_path, exist_ok=True)

    # all config
    print('===========================\n', flush=True)
    for key, val in vars(args).items():
        print('{}: {}'.format(key, val), flush=True)
    print('===========================\n', flush=True)

    blocks_history = {'block': [], 'optimizer': []}

    # savings training procedures
    filename = 'train_phase'
    filepath = os.path.join(dir_path, filename)
    vis = TrainVisualiser(filepath)

    start_time = time.time()
    # looping through architecture-blocs
    for idx in range(1, args.k_max + 1):
        print('============== training block number: {} ============= \n'.format(idx), flush=True)

        train_tensor = dataset.construct(flag='train')
        val_tensor = dataset.construct(flag='validation')
        train_loader = DataLoader(train_tensor, batch_size=args.batch_size, shuffle=True)

        block = Block(args.arch_args).to(device)
        path_block = os.path.join(dir_path, str(idx) + '.pt')
        if args.load and os.path.exists(path_block):
            # map_location allows loading a GPU-saved checkpoint on a CPU-only run.
            block.load_state_dict(torch.load(path_block, map_location=device))
            loaded = True
            print('============= loaded idx: {} ============='.format(idx), flush=True)
        else:
            block.optimise(train_loader, **optim_parms)
            loaded = False

        start = time.time()
        info = next_step_update(dataset, train_tensor, block, device, flag='train')
        end = time.time()
        print('============= {} {:.4f} ============= \n'.format('training reconstruction', end - start), flush=True)
        for key in info.keys():
            print('{}: {} \n'.format(key, info[key]), flush=True)

        start = time.time()
        info = next_step_update(dataset, val_tensor, block, device, flag='validation')
        end = time.time()
        print('============= {} {:.4f} ============= \n'.format('validation reconstruction', end - start), flush=True)
        for key in info.keys():
            print('{}: {} \n'.format(key, info[key]), flush=True)

        vis.update(dataset, flag='validation')
        blocks_history['block'].append(block)

        # reconstruction statistics are only produced for freshly trained blocks
        reconstruction_dir_path = os.path.join(dir_path, str(idx))
        if not loaded:
            os.makedirs(reconstruction_dir_path, exist_ok=True)
            get_stats(dataset, blocks_history, device, reconstruction_dir_path)

        if args.save and not loaded:
            torch.save(block.state_dict(), path_block)

    print('--- training time: %s seconds ---' % (time.time() - start_time), flush=True)
    vis.generate()
def main():
    """Train and evaluate a sequence of hybrid (MFVI) models for simple CT.

    Steps:

    * Parse command-line arguments.
    * Build the normalised ray transform, its adjoint and an FBP
      pseudo-inverse on a ``ForwardModel``.
    * Obtain the dataset (optionally with a brain test set).
    * For each index ``1..k_max``, load ``<idx>.pt`` if it exists, otherwise
      train a fresh ``HybridModel`` and save it.
    * Update the train/validation reconstructions and run ``get_stats``
      (into ``<idx>/re-loaded`` for loaded models).

    Raises:
        NotImplementedError: if ``--img_mode`` is not ``SimpleCT``.
    """
    def str2bool(value):
        return bool(strtobool(value))

    parser = argparse.ArgumentParser()
    # general & dataset & training settings
    parser.add_argument('--k_max', type=int, default=5, help='Max reconstruction iterations')
    parser.add_argument('--save_figs', type=str2bool, default=True, help='save pics in reconstruction')
    parser.add_argument('--img_mode', type=str, default='SimpleCT', help=' image-modality reconstruction: SimpleCT')
    parser.add_argument('--train_size', type=int, default=4000, help='dataset size')
    parser.add_argument('--pseudo_inverse_init', type=str2bool, default=True, help='initialise with pseudoinverse')
    parser.add_argument('--brain', type=str2bool, default=False, help='test set of brain images')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training')
    parser.add_argument('--initial_lr', type=float, default=1e-3, help='initial_lr')
    parser.add_argument('--val_batch_size', type=int, default=128, help='input batch size for valing')
    # forward models setting
    parser.add_argument('--size', type=int, default=128, help='image size')
    parser.add_argument('--beam_num_angle', type=int, default=30, help='number of angles / projections')
    # options
    parser.add_argument('--no_cuda', type=str2bool, default=False, help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=222, help='random seed')
    args = parser.parse_args()

    # NOTE(review): GPU mode is enabled unconditionally, even with --no_cuda;
    # confirm this is intended.
    layer_utils.set_gpu_mode(True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Only SimpleCT is supported. Checking up front also rejects a None mode,
    # which previously fell through to an undefined `data`.
    if args.img_mode != SimpleCT.__name__:
        raise NotImplementedError

    forward_model = ForwardModel()
    half_size = args.size / 2
    space = odl.uniform_discr([-half_size, -half_size], [half_size, half_size],
                              [args.size, args.size], dtype='float32')
    forward_model.space = space
    geometry = odl.tomo.parallel_beam_geometry(space, num_angles=args.beam_num_angle)
    forward_model.geometry = geometry

    operator = odl.tomo.RayTransform(space, geometry)
    opnorm = odl.power_method_opnorm(operator)
    forward_model.operator = odl_torch.OperatorModule((1 / opnorm) * operator)
    # The adjoint must belong to the *normalised* operator, so it is divided
    # by opnorm as well. The unscaled adjoint is off by a factor of opnorm.
    forward_model.adjoint = odl_torch.OperatorModule((1 / opnorm) * operator.adjoint)
    pseudoinverse = odl.tomo.fbp_op(operator)
    forward_model.pseudoinverse = odl_torch.OperatorModule(pseudoinverse * opnorm)

    geometry_specs = 'full_view_sparse_' + str(args.beam_num_angle)
    dataset_name = ('dataset' + '_' + args.img_mode + '_' + str(args.size)
                    + '_' + str(args.train_size) + '_' + geometry_specs + '_'
                    + 'brain' + '_' + str(args.brain))

    img_mode = SimpleCT(forward_model)
    data_constructor = DatasetConstructor(img_mode, train_size=args.train_size, brain=args.brain,
                                          dataset_name=dataset_name)
    data = data_constructor.data()
    dataset = DataSet(data, img_mode, args.pseudo_inverse_init)

    optim_parms = {'epochs': args.epochs, 'initial_lr': args.initial_lr, 'batch_size': args.batch_size}
    from hybrid_model import HybridModel as NeuralLearner

    # results directory
    path = os.path.dirname(__file__)
    dir_path = os.path.join(path, 'results', args.img_mode, 'MFVI', str(args.train_size),
                            geometry_specs, str(args.seed))
    os.makedirs(dir_path, exist_ok=True)

    # all config
    print('===========================\n', flush=True)
    for key, val in vars(args).items():
        print('{}: {}'.format(key, val), flush=True)
    print('===========================\n', flush=True)

    blocks_history = {'model': [], 'optimizer': []}
    arch_args = {'arch': {'up': [[1, 16, 3, 1, 1], [16, 32, 3, 1, 1]],
                          'low': [[1, 16, 3, 1, 1], [16, 32, 3, 1, 1]],
                          'cm': [[64, 32, 3, 1, 1], [32, 16, 3, 1, 1]]}}

    # savings training procedures
    filename = 'train_phase'
    filepath = os.path.join(dir_path, filename)
    vis = TrainVisualiser(filepath)

    start_time = time.time()
    # looping through architecture-blocs
    for idx in range(1, args.k_max + 1):
        print('============== training block number: {} ============= \n'.format(idx), flush=True)

        train_tensor = dataset.construct(flag='train')
        val_tensor = dataset.construct(flag='validation')
        train_loader = DataLoader(train_tensor, batch_size=args.batch_size, shuffle=True)

        model = NeuralLearner(arch_args).to(device)
        model_path = os.path.join(dir_path, str(idx) + '.pt')
        if os.path.exists(model_path):
            model_loaded = True
            # map_location allows loading a GPU-saved checkpoint on a CPU-only run.
            model.load_state_dict(torch.load(model_path, map_location=device))
            print('idx: {} model loaded!\npath to model:\n{}'.format(idx, model_path), flush=True)
        else:
            model_loaded = False
            model.optimise(train_loader, **optim_parms)
            save_net(model, model_path)
            print('idx: {} optimisation finished!'.format(idx), flush=True)

        start = time.time()
        info = next_step_update(dataset, train_tensor, model, device, flag='train')
        end = time.time()
        print('============= {} {:.4f} ============= \n'.format('training reconstruction', end - start), flush=True)
        for key in info.keys():
            print('{}: {} \n'.format(key, info[key]), flush=True)

        start = time.time()
        info = next_step_update(dataset, val_tensor, model, device, flag='validation')
        end = time.time()
        print('============= {} {:.4f} ============= \n'.format('validation reconstruction', end - start), flush=True)
        for key in info.keys():
            print('{}: {} \n'.format(key, info[key]), flush=True)

        vis.update(dataset, flag='validation')
        blocks_history['model'].append(model)

        # reconstruction statistics; loaded models write to a separate sub-folder
        if model_loaded:
            reconstruction_dir_path = os.path.join(dir_path, str(idx), 're-loaded')
        else:
            reconstruction_dir_path = os.path.join(dir_path, str(idx))
        os.makedirs(reconstruction_dir_path, exist_ok=True)
        get_stats(dataset, blocks_history, device, reconstruction_dir_path)

    print('--- training time: %s seconds ---' % (time.time() - start_time), flush=True)
    vis.generate()
def main():
    """Monte-Carlo uncertainty analysis of pre-trained reconstruction blocks.

    Steps:

    * Parse command-line arguments (overlaid by ``--config``), rebuild the
      normalised CT operators and the dataset.
    * Load the saved blocks ``1..k_max``. A missing checkpoint raises
      ``NotImplementedError``.
    * Run 100 passes of the full block sequence on the test set.
    * Pickle the mean, epistemic variance (spread of the per-pass final
      estimates), aleatoric variance (mean of the last block's predicted
      variances) and their combined std to
      ``uncertainty analysis/data_<k_max>.p``.
    * Call ``get_stats``.

    Raises:
        NotImplementedError: for an unsupported ``--img_mode`` /
            ``--block_type``, a missing block checkpoint, or a block without
            a ``bayes_CNN_log_std`` attribute.
    """
    def str2bool(value):
        return bool(strtobool(value))

    parser = argparse.ArgumentParser()
    # general & dataset & training settings
    parser.add_argument('--k_max', type=int, default=5, help='Max reconstruction iterations')
    parser.add_argument('--save_figs', type=str2bool, default=True, help='save pics in reconstruction')
    parser.add_argument('--img_mode', type=str, default='SimpleCT', help=' image-modality reconstruction: SimpleCT')
    parser.add_argument('--train_size', type=int, default=4000, help='dataset size')
    parser.add_argument('--dataset_type', type=str, default='GenEllipsesSamples', help='GenEllipsesSamples or GenFoamSamples')
    parser.add_argument('--pseudo_inverse_init', type=str2bool, default=True, help='initialise with pseudoinverse')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training')
    parser.add_argument('--initial_lr', type=float, default=1e-3, help='initial_lr')
    parser.add_argument('--val_batch_size', type=int, default=128, help='input batch size for valing')
    parser.add_argument('--arch_args', type=json.loads, default=dict(), help='load architecture dictionary')
    parser.add_argument('--block_type', type=str, default='bayesian_homo', help='deterministic, bayesian_homo, bayesian_hetero')
    # forward models setting
    parser.add_argument('--size', type=int, default=128, help='image size')
    parser.add_argument('--beam_num_angle', type=int, default=30, help='number of angles / projections')
    parser.add_argument('--limited_view', type=str2bool, default=False, help='limited view geometry instead of sparse view geometry')
    # options
    parser.add_argument('--no_cuda', type=str2bool, default=False, help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=222, help='random seed')
    parser.add_argument('--config', default='configs/bayesian_arch_config.json', help='config file path')
    args = parser.parse_args()

    if args.config is not None:
        # NOTE: values from the config file override command-line flags.
        with open(args.config) as handle:
            config = json.load(handle)
        vars(args).update(config)

    # NOTE(review): GPU mode is enabled unconditionally, even with --no_cuda;
    # confirm this is intended.
    block_utils.set_gpu_mode(True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    if args.img_mode != SimpleCT.__name__:
        raise NotImplementedError

    img_mode = SimpleCT()
    half_size = args.size / 2
    space = odl.uniform_discr([-half_size, -half_size], [half_size, half_size],
                              [args.size, args.size], dtype='float32')
    img_mode.space = space
    if args.limited_view:
        geometry = limited_view_parallel_beam_geometry(space, beam_num_angle=args.beam_num_angle)
    else:
        geometry = odl.tomo.parallel_beam_geometry(space, num_angles=args.beam_num_angle)
    img_mode.geometry = geometry

    # Divide the forward operator by its estimated norm; adjoint and
    # pseudo-inverse are rescaled consistently.
    operator = odl.tomo.RayTransform(space, geometry)
    opnorm = odl.power_method_opnorm(operator)
    img_mode.operator = odl_torch.OperatorModule((1 / opnorm) * operator)
    img_mode.adjoint = odl_torch.OperatorModule((1 / opnorm) * operator.adjoint)
    pseudoinverse = odl.tomo.fbp_op(operator)
    img_mode.pseudoinverse = odl_torch.OperatorModule(pseudoinverse * opnorm)

    if args.limited_view:
        geometry_specs = 'limited_view_' + str(args.beam_num_angle)
    else:
        geometry_specs = 'full_view_sparse_' + str(args.beam_num_angle)
    dataset_name = ('dataset' + '_' + args.img_mode + '_' + str(args.size)
                    + '_' + str(args.train_size) + '_' + geometry_specs
                    + '_' + args.dataset_type)
    data_constructor = DatasetConstructor(img_mode, train_size=args.train_size, dataset_name=dataset_name)
    data = data_constructor.data()
    dataset = DataSet(data, img_mode, args.pseudo_inverse_init)

    if args.block_type == 'deterministic':
        from blocks import DeterministicBlock as Block
    elif args.block_type == 'bayesian_homo':
        from blocks import BlockHomo as Block
    elif args.block_type == 'bayesian_hetero':
        from blocks import BlockHetero as Block
    else:
        raise NotImplementedError

    # results directory
    path = os.path.dirname(__file__)
    dir_path = os.path.join(path, 'results', args.img_mode, args.block_type, args.dataset_type,
                            str(args.train_size), geometry_specs, str(args.size), str(args.seed))
    os.makedirs(dir_path, exist_ok=True)

    # all config
    print('===========================\n', flush=True)
    for key, val in vars(args).items():
        print('{}: {}'.format(key, val), flush=True)
    print('===========================\n', flush=True)

    blocks_history = {'block': []}
    start_time = time.time()
    # load every pre-trained block; this script cannot train missing ones
    for idx in range(1, args.k_max + 1):
        print('============== training block number: {} ============= \n'.format(idx), flush=True)
        block = Block(args.arch_args).to(device)
        path_block = os.path.join(dir_path, str(idx) + '.pt')
        if not os.path.exists(path_block):
            raise NotImplementedError
        # map_location allows loading a GPU-saved checkpoint on a CPU-only run.
        block.load_state_dict(torch.load(path_block, map_location=device))
        blocks_history['block'].append(block)

    start = time.time()
    with torch.no_grad():
        mc_samples_mean, mc_samples_var = [], []
        for _ in range(100):
            # One MC pass: apply every block in sequence to the test set.
            for block in blocks_history['block']:
                block.eval()
                test_tensor = dataset.construct(flag='test', display=False)
                dataloader = DataLoader(deepcopy(test_tensor), batch_size=64, shuffle=False, drop_last=False)
                outputs, variances = [], []
                for batch_data, batch_grad, _ in dataloader:
                    batch_data, batch_grad = batch_data.to(device), batch_grad.to(device)
                    output, var = block.forward(batch_data, batch_grad)
                    outputs.append(output)
                    variances.append(var)
                dataset.update(torch.cat(outputs).cpu(), flag='test')
            mc_samples_mean.append(dataset.X_['test'])
            mc_samples_var.append(torch.cat(variances).cpu())
            dataset.reset('test')
    print('time: {}'.format(time.time() - start))

    # Stack the MC samples once and derive every statistic from the stacks.
    stacked_means = torch.stack(mc_samples_mean)
    mean = torch.mean(stacked_means, dim=0)
    if not hasattr(block, 'bayes_CNN_log_std'):
        raise NotImplementedError
    stacked_vars = torch.stack(mc_samples_var)
    epistemic = torch.std(stacked_means, dim=0)**2
    aleatoric = torch.mean(stacked_vars, dim=0)
    std = torch.sqrt(epistemic + aleatoric)

    analysis_dir = os.path.join(dir_path, 'uncertainty analysis')
    os.makedirs(analysis_dir, exist_ok=True)
    filename = 'data' + '_' + str(args.k_max) + '.p'
    filepath = os.path.join(analysis_dir, filename)
    with open(filepath, 'wb') as handle:
        pickle.dump({'mean': mean, 'aleatoric': aleatoric, 'epistemic': epistemic, 'std': std},
                    handle, protocol=pickle.HIGHEST_PROTOCOL)

    # reconstruction
    reconstruction_dir_path = os.path.join(analysis_dir, str(idx))
    os.makedirs(reconstruction_dir_path, exist_ok=True)
    get_stats(dataset, blocks_history, device, reconstruction_dir_path)
    print('--- training time: %s seconds ---' % (time.time() - start_time), flush=True)
import numpy as np
import torch
from torch import nn

import odl
from odl.contrib import torch as odl_torch

# --- Forward --- #

# Define ODL operator from a 2x3 matrix
matrix = np.array([[1, 0, 0],
                   [0, 1, 1]], dtype='float32')
odl_op = odl.MatrixOperator(matrix)

# Wrap ODL operator as `Module`
op_layer = odl_torch.OperatorModule(odl_op)

# Test with some inputs. We need to add at least one batch axis.
demo_inputs = [
    ((1, 3), 'Operator layer evaluated on a 1x3 tensor:'),
    ((1, 1, 3), 'Operator layer evaluated on a 1x1x3 tensor:'),
]
for inp_shape, header in demo_inputs:
    inp = torch.ones(inp_shape)
    print(header)
    print(op_layer(inp))
    print()

# We combine the module with some builtin pytorch layers of matching shapes
layer_before = nn.Linear(3, 3)
layer_after = nn.Linear(2, 2)