Esempio n. 1
0
    def init_model(self):
        """Create ``self.model`` (a ``PrimalDualNet``) and initialize it.

        The method performs four steps:

        1. If ``hyper_params['init_fbp']`` is truthy, build an FBP operator
           with ``fbp_op`` on ``self.non_normed_ray_trafo``, wrap it in an
           ``OperatorModule`` and store it as ``self.init_mod``; otherwise
           ``self.init_mod`` is set to ``None``.
        2. Construct the ``PrimalDualNet`` from ``self.niter``, the ray
           transform modules and the entries of ``self.hyper_params``.
        3. For every ``torch.nn.Conv2d`` in the network, zero the bias (when
           the layer has one) and draw the weights with Xavier-uniform
           initialization.
        4. Move the network to ``self.device`` when ``self.use_cuda`` is set.
        """
        hp = self.hyper_params

        if hp['init_fbp']:
            fbp = fbp_op(
                self.non_normed_ray_trafo,
                filter_type=hp['init_filter_type'],
                frequency_scaling=hp['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                # NOTE(review): presumably the network inputs are scaled by
                # 1/opnorm, and multiplying the FBP input by opnorm undoes
                # that for the non-normalized operator -- confirm.
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None

        self.model = PrimalDualNet(
            n_iter=self.niter,
            op=self.ray_trafo_mod,
            op_adj=self.ray_trafo_adj_mod,
            op_init=self.init_mod,
            n_primal=hp['nprimal'],
            n_dual=hp['ndual'],
            use_sigmoid=hp['use_sigmoid'],
            internal_ch=hp['internal_ch'],
            kernel_size=hp['kernel_size'],
            batch_norm=hp['batch_norm'],
            prelu=hp['prelu'],
            lrelu_coeff=hp['lrelu_coeff'])

        def weights_init(m):
            """Reset a single ``Conv2d`` module; ignore every other module."""
            if isinstance(m, torch.nn.Conv2d):
                # Conv2d(bias=False) stores ``bias`` as None; dereferencing
                # ``.data`` on it would raise AttributeError.
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
                torch.nn.init.xavier_uniform_(m.weight)

        # ``Module.apply`` calls ``weights_init`` on every submodule.
        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work because of astra-gpu
            self.model = self.model.to(self.device)
Esempio n. 2
0
    def init_model(self):
        """Create ``self.model`` (an ``IterativeNet``) and initialize it.

        The method performs five steps:

        1. Wrap ``self.op`` and ``self.op.adjoint`` in ``OperatorModule``
           instances (``self.op_mod`` and ``self.op_adj_mod``).
        2. Build the regularization module ``self.reg_mod`` from the sum of
           ``partial.adjoint * partial`` over axes 0 and 1, where ``partial``
           is ``odl.PartialDerivative`` on ``self.op.domain``.
        3. If ``hyper_params['init_fbp']`` is truthy, build an FBP operator
           with ``fbp_op`` on ``self.non_normed_op`` and store its
           ``OperatorModule`` wrapper as ``self.init_mod``; otherwise
           ``self.init_mod`` is set to ``None``.
        4. Construct the ``IterativeNet`` and, for every ``torch.nn.Conv2d``
           in it, zero the bias (when the layer has one) and, if
           ``hyper_params['init_weight_xavier_normal']`` is truthy, draw the
           weights with Xavier-normal initialization using gain
           ``hyper_params['init_weight_gain']``.
        5. Move the network to ``self.device`` when ``self.use_cuda`` is set.
        """
        hp = self.hyper_params

        self.op_mod = OperatorModule(self.op)
        self.op_adj_mod = OperatorModule(self.op.adjoint)

        # Regularization operator: D0^* D0 + D1^* D1 with D_i the partial
        # derivative along axis i of the operator's domain.
        partial0 = odl.PartialDerivative(self.op.domain, axis=0)
        partial1 = odl.PartialDerivative(self.op.domain, axis=1)
        self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                      partial1.adjoint * partial1)

        if hp['init_fbp']:
            fbp = fbp_op(
                self.non_normed_op,
                filter_type=hp['init_filter_type'],
                frequency_scaling=hp['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                # NOTE(review): presumably the network inputs are scaled by
                # 1/opnorm, and multiplying the FBP input by opnorm undoes
                # that for the non-normalized operator -- confirm.
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None

        self.model = IterativeNet(n_iter=self.niter,
                                  n_memory=5,
                                  op=self.op_mod,
                                  op_adj=self.op_adj_mod,
                                  op_init=self.init_mod,
                                  op_reg=self.reg_mod,
                                  use_sigmoid=hp['use_sigmoid'],
                                  n_layer=hp['nlayer'],
                                  internal_ch=hp['internal_ch'],
                                  kernel_size=hp['kernel_size'],
                                  batch_norm=hp['batch_norm'],
                                  prelu=hp['prelu'],
                                  lrelu_coeff=hp['lrelu_coeff'])

        def weights_init(m):
            """Reset a single ``Conv2d`` module; ignore every other module."""
            if isinstance(m, torch.nn.Conv2d):
                # Conv2d(bias=False) stores ``bias`` as None; dereferencing
                # ``.data`` on it would raise AttributeError.
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
                # Without this flag the weights keep whatever values the
                # layer was constructed with.
                if hp['init_weight_xavier_normal']:
                    torch.nn.init.xavier_normal_(
                        m.weight, gain=hp['init_weight_gain'])

        # ``Module.apply`` calls ``weights_init`` on every submodule.
        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work, probably
            # astra_cuda is not thread-safe
            self.model = self.model.to(self.device)