Beispiel #1
0
    def init_model(self):
        """Build the primal-dual network and store it in ``self.model``.

        If ``self.hyper_params['init_fbp']`` is truthy, a filtered
        back-projection operator is built with ``fbp_op`` from
        ``self.non_normed_ray_trafo`` (filter type and frequency scaling are
        taken from ``hyper_params``), optionally right-multiplied by
        ``self.opnorm``, wrapped in an ``OperatorModule`` and stored in
        ``self.init_mod``. Otherwise ``self.init_mod`` is ``None``. It is
        passed to the network as ``op_init``.

        Every ``torch.nn.Conv2d`` in the model then gets Xavier-uniform
        weights and, if it has a bias, a zero bias. When ``self.use_cuda`` is
        set, the model is moved to ``self.device``.
        """
        if self.hyper_params['init_fbp']:
            fbp = fbp_op(
                self.non_normed_ray_trafo,
                filter_type=self.hyper_params['init_filter_type'],
                frequency_scaling=self.hyper_params['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                # NOTE(review): presumably compensates for the ray transform
                # having been divided by its operator norm -- confirm.
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None
        self.model = PrimalDualNet(
            n_iter=self.niter,
            op=self.ray_trafo_mod,
            op_adj=self.ray_trafo_adj_mod,
            op_init=self.init_mod,
            n_primal=self.hyper_params['nprimal'],
            n_dual=self.hyper_params['ndual'],
            use_sigmoid=self.hyper_params['use_sigmoid'],
            internal_ch=self.hyper_params['internal_ch'],
            kernel_size=self.hyper_params['kernel_size'],
            batch_norm=self.hyper_params['batch_norm'],
            prelu=self.hyper_params['prelu'],
            lrelu_coeff=self.hyper_params['lrelu_coeff'])

        def weights_init(m):
            # Only 2D convolutions are re-initialized; all other modules keep
            # their default initialization.
            if isinstance(m, torch.nn.Conv2d):
                # A Conv2d created with ``bias=False`` has ``m.bias is None``;
                # filling it unconditionally would raise AttributeError.
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
                torch.nn.init.xavier_uniform_(m.weight)

        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work because of astra-gpu
            self.model = self.model.to(self.device)
Beispiel #2
0
    def __init__(self, func, scalar):
        """Initialize a new instance.

        The resulting functional evaluates ``func`` at the scaled argument,
        i.e. ``x -> func(scalar * x)``.

        Parameters
        ----------
        func : `Functional`
            The functional which will have its argument scaled.
        scalar : float, nonzero
            The scaling parameter with which the argument is scaled.
            It is converted to an element of ``func.domain.field``.

        Raises
        ------
        TypeError
            If ``func`` is not a `Functional` instance.
        """
        if not isinstance(func, Functional):
            raise TypeError('`func` {!r} is not a `Functional` instance'
                            ''.format(func))

        scalar = func.domain.field.element(scalar)

        # The gradient of x -> f(c * x) is x -> c * grad f(c * x) (conj(c)
        # for complex c), hence
        #   ||c grad f(c x) - c grad f(c y)|| <= |c| * L * |c| * ||x - y||.
        # The Lipschitz constant of the gradient therefore scales with
        # |c|**2; using |c| alone underestimates it whenever |c| > 1.
        grad_lipschitz = np.abs(scalar) ** 2 * func.grad_lipschitz

        Functional.__init__(self, space=func.domain, linear=func.is_linear,
                            grad_lipschitz=grad_lipschitz)

        OperatorRightScalarMult.__init__(self, operator=func, scalar=scalar)
Beispiel #3
0
    def __init__(self, func, scalar):
        """Initialize a new instance.

        The resulting functional evaluates ``func`` at the scaled argument,
        i.e. ``x -> func(scalar * x)``.

        Parameters
        ----------
        func : `Functional`
            The functional which will have its argument scaled.
        scalar : float, nonzero
            The scaling parameter with which the argument is scaled.
            It is converted to an element of ``func.domain.field``.

        Raises
        ------
        TypeError
            If ``func`` is not a `Functional` instance.
        """
        if not isinstance(func, Functional):
            raise TypeError('`func` {!r} is not a `Functional` instance'
                            ''.format(func))

        scalar = func.domain.field.element(scalar)

        # The gradient of x -> f(c * x) is x -> c * grad f(c * x) (conj(c)
        # for complex c), hence
        #   ||c grad f(c x) - c grad f(c y)|| <= |c| * L * |c| * ||x - y||.
        # The Lipschitz constant of the gradient therefore scales with
        # |c|**2; using |c| alone underestimates it whenever |c| > 1.
        grad_lipschitz = np.abs(scalar) ** 2 * func.grad_lipschitz

        Functional.__init__(self, space=func.domain, linear=func.is_linear,
                            grad_lipschitz=grad_lipschitz)

        OperatorRightScalarMult.__init__(self, operator=func, scalar=scalar)
Beispiel #4
0
    def init_model(self):
        """Build the iterative network and store it in ``self.model``.

        Wraps ``self.op``, ``self.op.adjoint`` and a regularization operator
        built from partial derivatives along axes 0 and 1 of
        ``self.op.domain`` as ``OperatorModule`` instances
        (``self.op_mod``, ``self.op_adj_mod``, ``self.reg_mod``).

        If ``self.hyper_params['init_fbp']`` is truthy, a filtered
        back-projection operator is built with ``fbp_op`` from
        ``self.non_normed_op``, optionally right-multiplied by
        ``self.opnorm``, wrapped in an ``OperatorModule`` and stored in
        ``self.init_mod``. Otherwise ``self.init_mod`` is ``None``.

        Every ``torch.nn.Conv2d`` in the model gets a zero bias (if it has
        one). Its weights are re-drawn with Xavier-normal initialization
        (gain ``hyper_params['init_weight_gain']``) only when
        ``hyper_params['init_weight_xavier_normal']`` is truthy. When
        ``self.use_cuda`` is set, the model is moved to ``self.device``.
        """
        self.op_mod = OperatorModule(self.op)
        self.op_adj_mod = OperatorModule(self.op.adjoint)
        partial0 = odl.PartialDerivative(self.op.domain, axis=0)
        partial1 = odl.PartialDerivative(self.op.domain, axis=1)
        # Sum of D_i^* D_i over both axes; presumably a (negative)
        # Laplacian-type smoothness operator -- confirm intended use.
        self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                      partial1.adjoint * partial1)
        if self.hyper_params['init_fbp']:
            fbp = fbp_op(
                self.non_normed_op,
                filter_type=self.hyper_params['init_filter_type'],
                frequency_scaling=self.hyper_params['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                # NOTE(review): presumably compensates for the operator
                # having been divided by its operator norm -- confirm.
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None
        # NOTE(review): n_memory is fixed at 5 rather than read from
        # hyper_params.
        self.model = IterativeNet(n_iter=self.niter,
                                  n_memory=5,
                                  op=self.op_mod,
                                  op_adj=self.op_adj_mod,
                                  op_init=self.init_mod,
                                  op_reg=self.reg_mod,
                                  use_sigmoid=self.hyper_params['use_sigmoid'],
                                  n_layer=self.hyper_params['nlayer'],
                                  internal_ch=self.hyper_params['internal_ch'],
                                  kernel_size=self.hyper_params['kernel_size'],
                                  batch_norm=self.hyper_params['batch_norm'],
                                  prelu=self.hyper_params['prelu'],
                                  lrelu_coeff=self.hyper_params['lrelu_coeff'])

        def weights_init(m):
            # Only 2D convolutions are touched; other modules keep their
            # default initialization.
            if isinstance(m, torch.nn.Conv2d):
                # A Conv2d created with ``bias=False`` has ``m.bias is None``;
                # filling it unconditionally would raise AttributeError.
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
                if self.hyper_params['init_weight_xavier_normal']:
                    torch.nn.init.xavier_normal_(
                        m.weight, gain=self.hyper_params['init_weight_gain'])

        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work, probably
            # astra_cuda is not thread-safe
            self.model = self.model.to(self.device)