def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

        # realnvp or so for x
        self._build_image_flow()
    def _get_transformed_x_dist(self, latent, ctx=None):
        """Build the transformed likelihood for x given the latent code."""
        # The decoder emits a distribution centred on the predicted image.
        x_pred_dist = self.decoder.predict(latent, ctx)

        # Base noise over the image; event over the 3 rightmost dimensions.
        if self.laplace_likelihood:
            x_base_dist = Laplace(self.x_base_loc, self.x_base_scale).to_event(3)
        else:
            x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()

        if isinstance(x_pred_dist,
                      (MultivariateNormal, LowRankMultivariateNormal)):
            # Full/low-rank covariance: flatten, apply the Cholesky affine,
            # then reshape back into image form.
            flatten = ReshapeTransform(self.img_shape,
                                       (np.prod(self.img_shape), ))
            chol = LowerCholeskyAffine(x_pred_dist.loc, x_pred_dist.scale_tril)
            x_reparam_transform = ComposeTransform([flatten, chol, flatten.inv])
        elif isinstance(x_pred_dist, Independent):
            base = x_pred_dist.base_dist
            x_reparam_transform = AffineTransform(base.loc, base.scale, 3)
        else:
            raise ValueError(f'{x_pred_dist} not valid.')

        return TransformedDistribution(
            x_base_dist,
            ComposeTransform([x_reparam_transform, preprocess_transform]))
# --- Example 3 ---
    def __init__(self, use_affine_ex: bool = True, **kwargs):
        """Construct the thickness flow and the thickness-conditional intensity flow."""
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # --- decoder parts ---

        # Thickness ~ Gamma: spline flow, then lognorm + exp constraint.
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform(
            [self.thickness_flow_components,
             self.thickness_flow_constraint_transforms])

        # Intensity ~ Normal, conditioned on thickness through a tiny
        # identity-activation network: s_affine_w * t * e_s + b.
        intensity_net = DenseNN(1, [1],
                                param_dims=[1, 1],
                                nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms,
        ]

        # RealNVP-style flow for the image x.
        self._build_image_flow()
# --- Example 4 ---
    def __init__(self, **kwargs):
        """Build the thickness flow and the thickness-conditional intensity flow."""
        super().__init__(**kwargs)

        # Thickness ~ Gamma: spline flow, then lognorm + exp constraint.
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform(
            [self.thickness_flow_components,
             self.thickness_flow_constraint_transforms])

        # Intensity ~ Normal, conditioned on thickness via a tiny linear net.
        intensity_net = DenseNN(1, [1],
                                param_dims=[1, 1],
                                nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms,
        ]
# --- Example 5 ---
 def infer_intensity_base(self, thickness, intensity):
     """Invert the thickness-conditioned intensity flow to its base value."""
     base = Normal(self.intensity_base_loc, self.intensity_base_scale)
     conditioned = ConditionalTransformedDistribution(
         base, self.intensity_flow_transforms).condition(thickness)
     return ComposeTransform(conditioned.transforms).inv(intensity)
# --- Example 6 ---
    def __init__(self, use_affine_ex=True, **kwargs):
        """Build the glasses flow and the conditional image flow.

        NOTE(review): this constructor also calls ``pyro.sample`` directly,
        which is unusual for ``__init__``; the sampling is preserved, but the
        trailing ``return`` was removed because returning a non-None value
        from ``__init__`` raises a TypeError at instantiation.
        """
        # BUG FIX: was `super.__init__(**kwargs)` (missing call parentheses),
        # which invokes __init__ on the builtin `super` type and fails.
        super().__init__(**kwargs)

        self.num_scales = 2
        # Store the flag for parity with sibling implementations that read
        # `self.use_affine_ex` inside `_build_image_flow`.
        self.use_affine_ex = use_affine_ex

        # Base N(0, 1) for the scalar glasses variable.
        self.register_buffer("glasses_base_loc",
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        # Normalisation statistics for the glasses flow.
        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        # Spline flow followed by lognorm + sigmoid squashing into (0, 1).
        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm,
             SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        self.glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        glasses_ = pyro.sample("glasses_", self.glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        # NOTE(review): in sibling implementations `_build_image_flow`
        # populates `self.x_transforms` in place and returns None, which would
        # make this assignment clobber the list — confirm its return value.
        self.x_transforms = self._build_image_flow()
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # BUG FIX: `context` was undefined (NameError); the conditioning
        # context computed above is `glasses_context`.
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        pyro.sample("x", cond_x_dist)
        # BUG FIX: dropped `return x, glasses` — __init__ must return None.
    def generate(self, x, num_particles):
        """Reconstruct x by encoding, decoding, and averaging particle samples."""
        latent = self.encoder.predict(x).sample()
        x_pred_dist = self.decoder.predict(latent)

        # Standard-normal base over the flattened image, event over the last dim.
        x_base_dist = dist.Normal(
            torch.zeros_like(x, requires_grad=False).view(x.shape[0], -1),
            torch.ones_like(x, requires_grad=False).view(x.shape[0], -1),
        ).to_event(1)

        if 'normal' in self.decoder_output or \
                self.decoder_output in ('deepvar', 'deepmean'):
            transform = AffineTransform(x_pred_dist.mean, x_pred_dist.stddev,
                                        1)
        elif self.decoder_output == 'low_rank_mvn':
            transform = LowerCholeskyAffine(x_pred_dist.loc,
                                            x_pred_dist.scale_tril)
        else:
            raise Exception('Unknown decoder output')

        x_dist = dist.TransformedDistribution(x_base_dist,
                                              ComposeTransform([transform]))

        # Draw several particles and return their pixel-wise mean.
        samples = [
            pyro.sample('x', x_dist).view(x.shape[0], self.shape, 3)
            for _ in range(num_particles)
        ]
        return torch.stack(samples).mean(0)
# --- Example 8 ---
def test_reparam_log_joint(model, kwargs):
    """Check NeuTra reparametrisation shifts the potential by log|det J|."""
    guide = AutoIAFNormal(model)
    guide(**kwargs)
    neutra = NeuTraReparam(guide)
    reparam_model = neutra.reparam(model)
    _, pe_fn, transforms, _ = initialize_model(model, model_kwargs=kwargs)
    init_params, pe_fn_neutra, _, _ = initialize_model(reparam_model,
                                                       model_kwargs=kwargs)
    latent_x = next(iter(init_params.values()))
    transformed_params = neutra.transform_sample(latent_x)
    pe_transformed = pe_fn_neutra(init_params)
    flow = ComposeTransform(guide.get_posterior(**kwargs).transforms)
    latent_y = flow(latent_x)
    log_det = flow.log_abs_det_jacobian(latent_x, latent_y)
    constrained = {k: transforms[k](v) for k, v in transformed_params.items()}
    pe = pe_fn(constrained)
    assert_close(pe_transformed, pe - log_det)
# --- Example 9 ---
 def _get_preprocess_transforms(self):
     alpha = 0.05
     num_bits = 8
     if self.preprocessing == 'glow':
         # Map to [-0.5,0.5]
         a1 = AffineTransform(-0.5, (1. / 2 ** num_bits))
         preprocess_transform = ComposeTransform([a1])
     elif self.preprocessing == 'realnvp':
         # Map to [0,1]
         a1 = AffineTransform(0., (1. / 2 ** num_bits))
         # Map into unconstrained space as done in RealNVP
         a2 = AffineTransform(alpha, (1 - alpha))
         s = SigmoidTransform()
         preprocess_transform = ComposeTransform([a1, a2, s.inv])
     else:
         raise ValueError(f'{self.preprocessing} not valid.')
     return preprocess_transform
def model(n_samples=None, scale=2.):
    """Sample (thickness, width); width's mean depends on thickness."""
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        # Width location is an affine function of thickness.
        loc = (thickness - 2.5) * 2

        squash = ComposeTransform(
            [SigmoidTransform(), AffineTransform(10, 15)])

        width = pyro.sample(
            'width', TransformedDistribution(Normal(loc, scale), squash))

    return thickness, width
# --- Example 11 ---
    def model(self):
        """Sample (x, thickness, intensity) from the unconditional SCM."""
        thickness, intensity = self.pgm_model()

        base = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # The image flow maps x -> base, so sampling uses its inverse.
        x_dist = TransformedDistribution(
            base, ComposeTransform(self.x_transforms).inv)

        x = pyro.sample('x', x_dist)

        return x, thickness, intensity
# --- Example 12 ---
    def infer_x_base(self, thickness, intensity, x):
        """Push x through the parent-conditioned image flow to its base value."""
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        # Normalise the parents before concatenating them as context.
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)
        context = torch.cat([thickness_, intensity_], 1)

        conditioned = ConditionalTransformedDistribution(
            x_base_dist, self.x_transforms).condition(context)
        return ComposeTransform(conditioned.transforms)(x)
# --- Example 13 ---
    def _get_transformed_x_dist(self, latent):
        """Return the transformed likelihood of x given `latent`.

        Raises:
            ValueError: if the decoder emits an unsupported distribution type.
        """
        x_pred_dist = self.decoder.predict(latent)
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()

        if isinstance(x_pred_dist, MultivariateNormal) or isinstance(
                x_pred_dist, LowRankMultivariateNormal):
            # Full/low-rank covariance: flatten, Cholesky affine, unflatten.
            chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                                 x_pred_dist.scale_tril)
            reshape_transform = ReshapeTransform(self.img_shape,
                                                 (np.prod(self.img_shape), ))
            x_reparam_transform = ComposeTransform(
                [reshape_transform, chol_transform, reshape_transform.inv])
        elif isinstance(x_pred_dist, Independent):
            x_pred_dist = x_pred_dist.base_dist
            x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                                  x_pred_dist.scale, 3)
        else:
            # BUG FIX: there was no else branch, so an unsupported decoder
            # distribution crashed below with an opaque UnboundLocalError
            # instead of a clear error (sibling implementation raises here).
            raise ValueError(f'{x_pred_dist} not valid.')

        return TransformedDistribution(
            x_base_dist,
            ComposeTransform([x_reparam_transform, preprocess_transform]))
    def __init__(self, **kwargs):
        """Construct the thickness and intensity flows (no image flow here)."""
        super().__init__(**kwargs)

        # Thickness ~ Gamma: spline flow, then lognorm + exp constraint.
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform(
            [self.thickness_flow_components,
             self.thickness_flow_constraint_transforms])

        # Intensity ~ Normal: learned affine + spline flow, then sigmoid
        # squashing and renormalisation.
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms,
        ]
# --- Example 15 ---
def model(n_samples=None, scale=0.5, invert=False):
    """Sample (thickness, intensity); intensity depends on thickness."""
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))

        # Optionally invert the direction of the dependency.
        loc = (thickness - 2) * -2 if invert else (thickness - 2.5) * 2

        squash = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])

        intensity = pyro.sample(
            'intensity',
            TransformedDistribution(Normal(loc, scale), squash))

    return thickness, intensity
# --- Example 16 ---
    def infer_exogeneous(self, **obs):
        """Recover exogenous base values by inverting each transformed site.

        Assumes every stochastic site uses a TransformedDistribution.
        """
        conditioned = pyro.condition(self.sample, data=obs)
        trace = pyro.poutine.trace(conditioned).get_trace(obs['x'].shape[0])

        output = {}
        for name, node in trace.nodes.items():
            # Bookkeeping nodes (e.g. trace input/return) carry no 'fn'.
            if 'fn' not in node:
                continue

            fn = node['fn']
            if isinstance(fn, Independent):
                fn = fn.base_dist
            if isinstance(fn, TransformedDistribution):
                inverse = ComposeTransform(fn.transforms).inv
                output[name + '_base'] = inverse(node['value'])

        return output
# --- Example 17 ---
    def model(self):
        """Sample x conditioned on the parents drawn from the pgm model."""
        thickness, intensity = self.pgm_model()

        # Normalised parents form the conditioning context of the image flow.
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)
        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        conditioned = ConditionalTransformedDistribution(
            x_base_dist, self.x_transforms).condition(context)
        cond_x_dist = TransformedDistribution(
            x_base_dist, ComposeTransform(conditioned.transforms).inv)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity
# --- Example 18 ---
    def infer_exogenous(self, obs):
        """Invert every transformed-distribution site to its exogenous base.

        Sites whose name contains 'z' are skipped, as are bookkeeping nodes.
        """
        conditioned = pyro.condition(self.sample, data=obs)
        trace = pyro.poutine.trace(conditioned).get_trace(obs['x'].shape[0])

        output = {}
        for name, node in trace.nodes.items():
            if 'z' in name or 'fn' not in node:
                continue

            fn = node['fn']
            if isinstance(fn, Independent):
                fn = fn.base_dist
            if isinstance(fn, TransformedDistribution):
                # Base values match what TransformReparam would expose.
                output[name + '_base'] = ComposeTransform(
                    fn.transforms).inv(node['value'])

        return output
    def model(self):
        """Generative model: sample parents, latent z, then the image x.

        Returns (x, z, age, sex, ventricle_volume, brain_volume).
        """
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        # Map the parents back into their unconstrained/normalised space
        # before using them as decoder inputs.
        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        # Latent code with a standard factorised-normal prior.
        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

        # Decoder mean image and a single learned (log-)std shared per pixel.
        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        # Predicted 2x3 affine parameters for a spatial-transformer warp.
        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        with warnings.catch_warnings():
            # affine_grid/grid_sample emit align_corners deprecation warnings.
            warnings.simplefilter("ignore")

            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        # Likelihood: base noise shifted/scaled by the deformed decoder output,
        # then passed through the image preprocessing transform.
        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume
# --- Example 20 ---
class FlowSCM(pyroModule):
    """Flow-based SCM over a binary 'glasses' attribute and a 64x64 image.

    NOTE(review): the base class name `pyroModule` is presumably a typo for
    pyro.nn.PyroModule — confirm; left unchanged as its definition is not
    visible here.
    """

    def __init__(self, use_affine_ex=True, **kwargs):
        """Build the glasses flow and the glasses-conditioned image flow.

        NOTE(review): this constructor also calls ``pyro.sample`` directly,
        which is unusual for ``__init__``; the sampling is preserved, but the
        trailing ``return`` was removed because returning a non-None value
        from ``__init__`` raises a TypeError at instantiation.
        """
        # BUG FIX: was `super.__init__(**kwargs)` (missing call parentheses),
        # which invokes __init__ on the builtin `super` type and fails.
        super().__init__(**kwargs)

        self.num_scales = 2
        # BUG FIX: `use_affine_ex` was accepted but never stored, while
        # `_build_image_flow` reads `self.use_affine_ex` below.
        self.use_affine_ex = use_affine_ex

        # Base N(0, 1) for the scalar glasses variable.
        self.register_buffer("glasses_base_loc",
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        # Normalisation statistics for the glasses flow.
        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        # Spline flow followed by lognorm + sigmoid squashing into (0, 1).
        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm,
             SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        self.glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        glasses_ = pyro.sample("glasses_", self.glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        # BUG FIX: `_build_image_flow` populates `self.x_transforms` in place
        # and returns None, so `self.x_transforms = self._build_image_flow()`
        # overwrote the freshly built list with None.
        self._build_image_flow()
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # BUG FIX: `context` was undefined (NameError); the conditioning
        # context computed above is `glasses_context`.
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        pyro.sample("x", cond_x_dist)
        # BUG FIX: dropped `return x, glasses` — __init__ must return None.

    def _build_image_flow(self):
        """Assemble the multi-scale RealNVP-style transform list for x.

        Populates `self.x_transforms` (the transform sequence) and
        `self.trans_modules` (learnable modules, so pyro registers params).

        NOTE(review): `self.flows_per_scale`, `self.use_actnorm` and
        `self.hidden_channels` are read but never set in this class —
        presumably provided by the base class; confirm.
        """
        self.trans_modules = ComposeTransformModule([])
        self.x_transforms = []
        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            # Space-to-depth: quadruples channels, halves spatial dims.
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # Move channels last for the coupling layer, then back.
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # Undo the squeezes so the flow ends in image shape (1, 32, 32).
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            # Exogenous affine transform conditioned by a small dense net.
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)
# --- Example 21 ---
 def infer_x_base(self, thickness, intensity, x):
     """Map x to its exogenous base via the (unconditional) image flow.

     `thickness` and `intensity` are accepted for interface parity with
     sibling implementations but are not used here.
     """
     forward = ComposeTransform(self.x_transforms)
     return forward(x)
# --- Example 22 ---
class IndependentFlowSEM(BaseFlowSEM):
    """Flow-based SEM whose image flow is independent of the parents.

    Thickness and intensity each get their own normalising flow; the image x
    is modelled by a multi-scale RealNVP-style flow that does not condition
    on (thickness, intensity).
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        """Build the thickness, intensity and image flows.

        NOTE(review): `use_affine_ex` is stored but `_build_image_flow` in
        this class never reads it — confirm whether it is used elsewhere.
        """
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Assemble the (unconditional) multi-scale image flow.

        Populates `self.x_transforms` (the transform sequence) and
        `self.trans_modules` (learnable modules for parameter registration).
        """

        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        # c tracks the channel count; each squeeze multiplies it by 4.
        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # Move channels last for the coupling layer, then back after.
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = AffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2)))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # Undo the squeezes so the flow ends in image shape (1, 32, 32).
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

    @pyro_method
    def pgm_model(self):
        """Sample (thickness, intensity) from their respective flows."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Sample (x, thickness, intensity); x is independent of the parents."""
        thickness, intensity = self.pgm_model()

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # The image flow maps x -> base, hence the inverse for sampling.
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform(self.x_transforms).inv)

        x = pyro.sample('x', x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Invert the thickness flow to its exogenous base value."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        """Invert the intensity flow to its exogenous base value."""
        return self.intensity_flow_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Map x to its base value; parents are unused (independent flow)."""
        return ComposeTransform(self.x_transforms)(x)
    def __init__(self,
                 latent_dim: int,
                 prior_components: int = 1,
                 posterior_components: int = 1,
                 logstd_init: float = -5,
                 enc_filters: Tuple[int] = (16, 32, 64, 128),
                 dec_filters: Tuple[int] = (128, 64, 32, 16),
                 num_convolutions: int = 3,
                 use_upconv: bool = False,
                 decoder_type: str = 'fixed_var',
                 decoder_cov_rank: int = 10,
                 img_shape: Tuple[int] = (128, 128),
                 use_nvae=False,
                 use_weight_norm=False,
                 use_spectral_norm=False,
                 laplace_likelihood=False,
                 eps=0.1,
                 n_prior_flows=3,
                 n_posterior_flows=3,
                 use_autoregressive=False,
                 use_swish=False,
                 use_spline=False,
                 use_stable=False,
                 pseudo3d=False,
                 head_filters=(16, 16),
                 **kwargs):
        super().__init__(**kwargs)
        self.encoder_shape = ((3, ) if pseudo3d else (1, )) + tuple(img_shape)
        self.decoder_shape = (head_filters[0], ) + tuple(img_shape)
        self.img_shape = (1, ) + tuple(img_shape)
        self.latent_dim = latent_dim
        self.prior_components = prior_components
        self.posterior_components = posterior_components
        self.logstd_init = logstd_init
        self.enc_filters = enc_filters
        self.dec_filters = dec_filters
        self.head_filters = head_filters
        self.num_convolutions = num_convolutions
        self.use_upconv = use_upconv
        self.decoder_type = decoder_type
        self.decoder_cov_rank = decoder_cov_rank
        self.use_nvae = use_nvae
        self.use_weight_norm = use_weight_norm
        self.use_spectral_norm = use_spectral_norm
        self.laplace_likelihood = laplace_likelihood
        self.eps = eps
        self.n_prior_flows = n_prior_flows
        self.n_posterior_flows = n_posterior_flows
        self.use_autoregressive = use_autoregressive
        self.use_spline = use_spline
        self.use_swish = use_swish
        self.use_stable = use_stable
        self.pseudo3d = pseudo3d
        self.annealing_factor = [
            1.
        ]  # initialize here; will be changed during training
        self.n_levels = 0

        # decoder parts
        if use_nvae:
            decoder = NDecoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + self.context_dim,
                               output_size=self.decoder_shape)
        else:
            decoder = Decoder(
                num_convolutions=self.num_convolutions,
                filters=self.dec_filters,
                latent_dim=self.latent_dim + self.context_dim,
                upconv=self.use_upconv,
                output_size=self.decoder_shape,
                use_weight_norm=self.use_weight_norm,
                use_spectral_norm=self.use_spectral_norm,
            )

        self._create_decoder(decoder)

        # encoder parts
        if self.use_nvae:
            self.encoder = NEncoder(num_convolutions=self.num_convolutions,
                                    filters=self.enc_filters,
                                    latent_dim=self.latent_dim,
                                    input_size=self.encoder_shape)
        else:
            self.encoder = Encoder(num_convolutions=self.num_convolutions,
                                   filters=self.enc_filters,
                                   latent_dim=self.latent_dim,
                                   input_size=self.encoder_shape,
                                   use_weight_norm=self.use_weight_norm,
                                   use_spectral_norm=self.use_spectral_norm)

        nonlinearity = Swish() if self.use_swish else torch.nn.LeakyReLU(0.1)
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + self.context_dim,
                            self.latent_dim), nonlinearity)

        if self.posterior_components > 1:
            self.latent_encoder = DeepIndepMixtureNormal(
                latent_layers, self.latent_dim, self.latent_dim,
                self.posterior_components)
        else:
            self.latent_encoder = DeepIndepNormal(latent_layers,
                                                  self.latent_dim,
                                                  self.latent_dim)

        if self.prior_components > 1:
            self.z_loc = torch.nn.Parameter(
                torch.randn([self.prior_components, self.latent_dim]))
            self.z_scale = torch.nn.Parameter(
                torch.randn([self.latent_dim]).clamp(min=-1.,
                                                     max=None))  # log scale
            self.register_buffer(
                'z_components',  # don't be bayesian about the mixture components
                ((1 / self.prior_components) * torch.ones(
                    [self.prior_components], requires_grad=False)).log())
        else:
            self.register_buffer(
                'z_loc', torch.zeros([
                    latent_dim,
                ], requires_grad=False))
            self.register_buffer(
                'z_scale', torch.ones([
                    latent_dim,
                ], requires_grad=False))
            self.z_components = None

        # priors
        self.sex_logits = torch.nn.Parameter(torch.zeros([
            1,
        ]))
        self.register_buffer('slice_number_min',
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer(
            'slice_number_max', 241. * torch.ones([
                1,
            ], requires_grad=False) + 1.)

        for k in self.required_data - {'sex', 'x', 'slice_number'}:
            self.register_buffer(f'{k}_base_loc',
                                 torch.zeros([
                                     1,
                                 ], requires_grad=False))
            self.register_buffer(f'{k}_base_scale',
                                 torch.ones([
                                     1,
                                 ], requires_grad=False))

        self.register_buffer('x_base_loc',
                             torch.zeros(self.img_shape, requires_grad=False))
        self.register_buffer('x_base_scale',
                             torch.ones(self.img_shape, requires_grad=False))

        for k in self.required_data - {'sex', 'x', 'slice_number'}:
            self.register_buffer(f'{k}_flow_lognorm_loc',
                                 torch.zeros([], requires_grad=False))
            self.register_buffer(f'{k}_flow_lognorm_scale',
                                 torch.ones([], requires_grad=False))

        perm = lambda: torch.randperm(
            self.latent_dim, dtype=torch.long, requires_grad=False)

        self.use_prior_flow = self.n_prior_flows > 0
        self.use_prior_permutations = self.n_prior_flows > 1
        if self.use_prior_permutations:
            for i in range(self.n_prior_flows):
                self.register_buffer(f'prior_flow_permutation_{i}', perm())

        self.use_posterior_flow = self.n_posterior_flows > 0
        self.use_posterior_permutations = self.n_posterior_flows > 1
        if self.use_posterior_permutations:
            for i in range(self.n_posterior_flows):
                self.register_buffer(f'posterior_flow_permutation_{i}', perm())

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(
            loc=self.age_flow_lognorm_loc.item(),
            scale=self.age_flow_lognorm_scale.item())
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # other flows shared components
        self.ventricle_volume_flow_lognorm = AffineTransform(
            loc=self.ventricle_volume_flow_lognorm_loc.item(),
            scale=self.ventricle_volume_flow_lognorm_scale.item(
            ))  # noqa: E501
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])

        self.brain_volume_flow_lognorm = AffineTransform(
            loc=self.brain_volume_flow_lognorm_loc.item(),
            scale=self.brain_volume_flow_lognorm_scale.item())
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])

        self.lesion_volume_flow_lognorm = AffineTransform(
            loc=self.lesion_volume_flow_lognorm_loc.item(),
            scale=self.lesion_volume_flow_lognorm_scale.item())
        self.lesion_volume_flow_eps = AffineTransform(loc=-eps, scale=1.)
        self.lesion_volume_flow_constraint_transforms = ComposeTransform([
            self.lesion_volume_flow_lognorm,
            ExpTransform(), self.lesion_volume_flow_eps
        ])

        self.duration_flow_lognorm = AffineTransform(
            loc=self.duration_flow_lognorm_loc.item(),
            scale=self.duration_flow_lognorm_scale.item())
        self.duration_flow_eps = AffineTransform(loc=-eps, scale=1.)
        self.duration_flow_constraint_transforms = ComposeTransform([
            self.duration_flow_lognorm,
            ExpTransform(), self.duration_flow_eps
        ])

        self.edss_flow_lognorm = AffineTransform(
            loc=self.edss_flow_lognorm_loc.item(),
            scale=self.edss_flow_lognorm_scale.item())
        self.edss_flow_eps = AffineTransform(loc=-eps, scale=1.)
        self.edss_flow_constraint_transforms = ComposeTransform(
            [self.edss_flow_lognorm,
             ExpTransform(), self.edss_flow_eps])

        hidden_dims = (3 * self.latent_dim +
                       1, ) if self.use_autoregressive else (2 *
                                                             self.latent_dim,
                                                             2 *
                                                             self.latent_dim)
        flow_kwargs = dict(hidden_dims=hidden_dims, nonlinearity=nonlinearity)
        if self.use_spline:
            flow_ = spline_autoregressive if self.use_autoregressive else spline_coupling
        else:
            flow_ = affine_autoregressive if self.use_autoregressive else affine_coupling
        if self.use_autoregressive:
            flow_kwargs['stable'] = self.use_stable

        if self.use_prior_permutations:
            self.prior_affine = iterated(
                self.n_prior_flows, batchnorm, self.latent_dim,
                momentum=0.05) if self.use_prior_flow else []
            self.prior_permutations = [
                Permute(getattr(self, f'prior_flow_permutation_{i}'))
                for i in range(self.n_prior_flows)
            ]
            self.prior_flow_components = iterated(
                self.n_prior_flows, flow_, self.latent_dim, **
                flow_kwargs) if self.use_prior_flow else []
            self.prior_flow_transforms = [
                x for c in zip(self.prior_permutations, self.prior_affine,
                               self.prior_flow_components) for x in c
            ]
        else:
            self.prior_affine = []
            self.prior_permutations = []
            self.prior_flow_components = flow_(
                self.latent_dim, **flow_kwargs) if self.use_prior_flow else []
            self.prior_flow_transforms = [self.prior_flow_components]

        if self.use_posterior_permutations:
            self.posterior_affine = iterated(self.n_posterior_flows,
                                             batchnorm,
                                             self.latent_dim,
                                             momentum=0.05)
            self.posterior_permutations = [
                Permute(getattr(self, f'posterior_flow_permutation_{i}'))
                for i in range(self.n_posterior_flows)
            ]
            self.posterior_flow_components = iterated(self.n_posterior_flows,
                                                      flow_, self.latent_dim,
                                                      **flow_kwargs)
            self.posterior_flow_transforms = [
                x
                for c in zip(self.posterior_permutations, self.
                             posterior_affine, self.posterior_flow_components)
                for x in c
            ]
        else:
            self.posterior_affine = []
            self.posterior_permutations = []
            self.posterior_flow_components = flow_(
                self.latent_dim, **
                flow_kwargs) if self.use_posterior_flow else []
            self.posterior_flow_transforms = [self.posterior_flow_components]
class ConditionalDecoderVISEM(BaseVISEM):
    """VI-based SEM whose image decoder is conditioned on (thickness, intensity).

    Thickness and intensity each get a normalising-flow prior; their
    (inverse-normalised) values are concatenated to the latent code ``z``
    before decoding.  The guide amortises inference of ``z`` from the image
    together with the same normalised covariates.
    """

    # number of covariates concatenated to the latent code: thickness, intensity
    context_dim = 2

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        # NOTE(review): kept as a plain list (TransformedDistribution accepts
        # one), unlike the thickness transforms above which are composed.
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

    @pyro_method
    def pgm_model(self):
        """Sample thickness and intensity from their flow-based priors."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_components to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_components to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Generative model: sample covariates, z, then the image x.

        The covariates are mapped back to base/normalised space before
        being concatenated with ``z`` as decoder context.
        """
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        # standard-normal prior over z (z_loc/z_scale registered as buffers)
        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, thickness_, intensity_], 1)

        x_dist = self._get_transformed_x_dist(latent)

        x = pyro.sample('x', x_dist)

        return x, z, thickness, intensity

    @pyro_method
    def guide(self, x, thickness, intensity):
        """Amortised posterior q(z | x, thickness, intensity).

        NOTE(review): samples 'z' inside an 'observations' plate while the
        model samples 'z' without one — presumably the model is run under an
        outer plate by the caller; confirm against the training loop.
        """
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)

            thickness_ = self.thickness_flow_constraint_transforms.inv(
                thickness)
            intensity_ = self.intensity_flow_norm.inv(intensity)

            hidden = torch.cat([hidden, thickness_, intensity_], 1)
            latent_dist = self.latent_encoder.predict(hidden)

            z = pyro.sample('z', latent_dist)

        return z

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Invert the thickness flow: observed value -> base-space noise."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        """Invert the intensity flow: observed value -> base-space noise."""
        return self.intensity_flow_transforms.inv(intensity)
# Example #25 (示例#25)
# 0
    def __init__(self,
                 latent_dim: int,
                 logstd_init: float = -5,
                 enc_filters: str = '16,32,64,128',
                 dec_filters: str = '128,64,32,16',
                 num_convolutions: int = 2,
                 use_upconv: bool = False,
                 decoder_type: str = 'fixed_var',
                 decoder_cov_rank: int = 10,
                 **kwargs):
        """Build encoder/decoder networks, priors and flow components.

        Args:
            latent_dim: dimensionality of the latent code ``z``.
            logstd_init: initial decoder log standard deviation.
            enc_filters: comma-separated channel counts per encoder stage.
            dec_filters: comma-separated channel counts per decoder stage.
            num_convolutions: convolutions per encoder/decoder stage.
            use_upconv: use transposed convolutions for upsampling.
            decoder_type: likelihood parameterisation; one of 'fixed_var',
                'learned_var', 'independent_gaussian',
                'multivariate_gaussian', 'sharedvar_multivariate_gaussian',
                'lowrank_multivariate_gaussian',
                'sharedvar_lowrank_multivariate_gaussian'.
            decoder_cov_rank: covariance rank for the low-rank decoder types.

        Raises:
            ValueError: if ``decoder_type`` is not a supported name.
        """
        super().__init__(**kwargs)

        # 192x192 single-channel images, optionally downsampled
        # (self.downsample is set by the base class)
        self.img_shape = (1, 192 // self.downsample, 192 //
                          self.downsample) if self.downsample > 0 else (1, 192,
                                                                        192)

        self.latent_dim = latent_dim
        self.logstd_init = logstd_init

        self.enc_filters = tuple(
            int(f.strip()) for f in enc_filters.split(','))
        self.dec_filters = tuple(
            int(f.strip()) for f in dec_filters.split(','))
        self.num_convolutions = num_convolutions
        self.use_upconv = use_upconv
        self.decoder_type = decoder_type
        self.decoder_cov_rank = decoder_cov_rank

        # decoder parts
        decoder = Decoder(num_convolutions=self.num_convolutions,
                          filters=self.dec_filters,
                          latent_dim=self.latent_dim + self.context_dim,
                          upconv=self.use_upconv,
                          output_size=self.img_shape)

        def _flat_decoder():
            # wrap the conv decoder so it emits one flat vector per sample,
            # as required by the (low-rank) multivariate heads below
            return torch.nn.Sequential(
                decoder, Lambda(lambda x: x.view(x.shape[0], -1)))

        n_pixels = np.prod(self.img_shape)

        if self.decoder_type in ('fixed_var', 'learned_var'):
            # pixelwise independent Normal with a single shared log-variance:
            # the logvar head's weight is frozen at zero so only its bias
            # (initialised to logstd_init) contributes
            self.decoder = Conv2dIndepNormal(decoder, 1, 1)

            torch.nn.init.zeros_(self.decoder.logvar_head.weight)
            self.decoder.logvar_head.weight.requires_grad = False

            torch.nn.init.constant_(self.decoder.logvar_head.bias,
                                    self.logstd_init)
            # only 'learned_var' trains the shared log-variance bias
            self.decoder.logvar_head.bias.requires_grad = (
                self.decoder_type == 'learned_var')
        elif self.decoder_type == 'independent_gaussian':
            # pixelwise Normal with fully learned, input-dependent variance
            self.decoder = Conv2dIndepNormal(decoder, 1, 1)

            torch.nn.init.zeros_(self.decoder.logvar_head.weight)
            self.decoder.logvar_head.weight.requires_grad = True

            torch.nn.init.normal_(self.decoder.logvar_head.bias,
                                  self.logstd_init, 1e-1)
            self.decoder.logvar_head.bias.requires_grad = True
        elif self.decoder_type == 'multivariate_gaussian':
            # full-covariance Normal over all pixels
            self.decoder = DeepMultivariateNormal(_flat_decoder(), n_pixels,
                                                  n_pixels)
        elif self.decoder_type == 'sharedvar_multivariate_gaussian':
            # full-covariance Normal whose covariance is input-independent:
            # both covariance heads' weights are frozen at zero and only the
            # log-diagonal bias is trained
            self.decoder = DeepMultivariateNormal(_flat_decoder(), n_pixels,
                                                  n_pixels)

            torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
            self.decoder.logdiag_head.weight.requires_grad = False

            torch.nn.init.zeros_(self.decoder.lower_head.weight)
            self.decoder.lower_head.weight.requires_grad = False

            torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                                  self.logstd_init, 1e-1)
            self.decoder.logdiag_head.bias.requires_grad = True
        elif self.decoder_type == 'lowrank_multivariate_gaussian':
            self.decoder = DeepLowRankMultivariateNormal(
                _flat_decoder(), n_pixels, n_pixels, decoder_cov_rank)
        elif self.decoder_type == 'sharedvar_lowrank_multivariate_gaussian':
            # low-rank covariance, input-independent (see sharedvar above)
            self.decoder = DeepLowRankMultivariateNormal(
                _flat_decoder(), n_pixels, n_pixels, decoder_cov_rank)

            torch.nn.init.zeros_(self.decoder.logdiag_head.weight)
            self.decoder.logdiag_head.weight.requires_grad = False

            torch.nn.init.zeros_(self.decoder.factor_head.weight)
            self.decoder.factor_head.weight.requires_grad = False

            torch.nn.init.normal_(self.decoder.logdiag_head.bias,
                                  self.logstd_init, 1e-1)
            self.decoder.logdiag_head.bias.requires_grad = True
        else:
            # was: ValueError('unknown  ') — name the offending value
            raise ValueError(
                f'unknown decoder type: {self.decoder_type!r}')

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim,
                               input_size=self.img_shape)

        # maps [encoder hidden ; context covariates] -> posterior over z
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + self.context_dim,
                            self.latent_dim), torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)

        # priors: standard-Normal base distributions for each covariate
        self.register_buffer('age_base_loc',
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer('age_base_scale',
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        self.sex_logits = torch.nn.Parameter(torch.zeros([
            1,
        ]))

        self.register_buffer('ventricle_volume_base_loc',
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer('ventricle_volume_base_scale',
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        self.register_buffer('brain_volume_base_loc',
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer('brain_volume_base_scale',
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        self.register_buffer('z_loc',
                             torch.zeros([
                                 latent_dim,
                             ], requires_grad=False))
        self.register_buffer('z_scale',
                             torch.ones([
                                 latent_dim,
                             ], requires_grad=False))

        self.register_buffer('x_base_loc',
                             torch.zeros(self.img_shape, requires_grad=False))
        self.register_buffer('x_base_scale',
                             torch.ones(self.img_shape, requires_grad=False))

        # normalisation statistics for the covariate flows; buffers so they
        # move with the module and are stored in checkpoints
        self.register_buffer('age_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('age_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        self.register_buffer('ventricle_volume_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('ventricle_volume_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        self.register_buffer('brain_volume_flow_lognorm_loc',
                             torch.zeros([], requires_grad=False))
        self.register_buffer('brain_volume_flow_lognorm_scale',
                             torch.ones([], requires_grad=False))

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(
            loc=self.age_flow_lognorm_loc.item(),
            scale=self.age_flow_lognorm_scale.item())
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # other flows shared components
        self.ventricle_volume_flow_lognorm = AffineTransform(
            loc=self.ventricle_volume_flow_lognorm_loc.item(),
            scale=self.ventricle_volume_flow_lognorm_scale.item(
            ))  # noqa: E501
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])

        self.brain_volume_flow_lognorm = AffineTransform(
            loc=self.brain_volume_flow_lognorm_loc.item(),
            scale=self.brain_volume_flow_lognorm_scale.item())
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])
# Example #26 (示例#26)
# 0
class ConditionalFlowSEM(BaseFlowSEM):
    """Flow-based SEM: the image is modelled by a multi-scale conditional
    normalising flow conditioned on (thickness, intensity); intensity itself
    is conditioned on thickness via a conditional affine flow.
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        # whether to append a conditional affine transform on e_x (see
        # _build_image_flow and add_arguments)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # affine flow for s normal
        # context net: maps (normalised) thickness -> affine loc/scale params
        intensity_net = DenseNN(1, [1],
                                param_dims=[1, 1],
                                nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        # list (not ComposeTransform) because the first element is
        # conditional and must be .condition()-ed before composing
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]
        # build flow as s_affine_w * t * e_s + b -> depends on t though

        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Assemble the multi-scale conditional flow over images.

        Per scale: squeeze (channels x4), then ``flows_per_scale`` steps of
        [optional ActNorm, channel permutation, transpose to channels-last,
        conditional affine coupling, transpose back], then a final channel
        permutation.  Learnable pieces are collected in ``trans_modules`` so
        they are registered as submodules; the usable transform sequence is
        ``x_transforms``.
        """
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # move channels last for the coupling layer, back afterwards
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # undo the squeezes: back to a (1, 32, 32) image
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            # global affine on e_x conditioned on (thickness, intensity)
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)

    @pyro_method
    def pgm_model(self):
        """Sample thickness, then intensity conditioned on thickness."""
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = ConditionalTransformedDistribution(
            intensity_base_dist,
            self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_components to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Generative model: covariates, then image via the conditional flow.

        NOTE(review): the conditioned transforms are composed and then
        inverted (`.inv`) before being used in TransformedDistribution —
        presumably x_transforms map image -> base, so the inverse maps
        base -> image; confirm against _get_preprocess_transforms.
        """
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Invert the thickness flow: observed value -> base-space noise."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, thickness, intensity):
        """Invert the intensity flow (conditioned on thickness)."""
        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        cond_intensity_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                intensity_base_dist, self.intensity_flow_transforms).condition(
                    thickness_).transforms)
        return cond_intensity_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Map an observed image to its base-space noise e_x."""
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist, self.x_transforms).condition(context).transforms)
        return cond_x_transforms(x)

    @classmethod
    def add_arguments(cls, parser):
        """Register this model's CLI flags on top of the base parser."""
        parser = super().add_arguments(parser)

        parser.add_argument(
            '--use_affine_ex',
            default=False,
            action='store_true',
            help=
            "whether to use conditional affine transformation on e_x (default: %(default)s)"
        )

        return parser
class ConditionalSTNVISEM(BaseVISEM):
    def __init__(self, **kwargs):
        """Conditional VISEM variant whose decoder warps its predicted mean
        image with a latent-predicted 2x3 affine (spatial-transformer) grid.

        Module registration order is kept stable so parameter/state-dict
        ordering matches the original implementation.
        """
        super().__init__(**kwargs)

        # --- decoder ---
        self.decoder = Decoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + 3,
                               upconv=self.use_upconv)

        # 1x1 conv for the image mean; a single scalar log-std parameter.
        self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
        self.decoder_logstd = torch.nn.Parameter(
            torch.ones([]) * self.logstd_init)

        # Small MLP predicting the 6 affine-grid parameters from the latent.
        self.decoder_affine_param_net = nn.Sequential(
            nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, 6))

        # Initialise the last layer so the predicted warp starts as identity.
        self.decoder_affine_param_net[-1].weight.data.zero_()
        self.decoder_affine_param_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # --- age flow ---
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # --- ventricle_volume flow (conditioned on [sex, age] in pgm_model) ---
        # TODO: decide on how many things to condition on
        vvol_ctx_net = DenseNN(2, [8, 16],
                               param_dims=[1, 1],
                               nonlinearity=torch.nn.Identity())
        self.ventricle_volume_flow_components = ConditionalAffineTransform(
            context_nn=vvol_ctx_net, event_dim=0)
        self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm, ExpTransform()])
        self.ventricle_volume_flow_transforms = [
            self.ventricle_volume_flow_components,
            self.ventricle_volume_flow_constraint_transforms
        ]

        # --- brain_volume flow (conditioned on [sex, age] in pgm_model) ---
        # TODO: decide on how many things to condition on
        bvol_ctx_net = DenseNN(2, [8, 16],
                               param_dims=[1, 1],
                               nonlinearity=torch.nn.Identity())
        self.brain_volume_flow_components = ConditionalAffineTransform(
            context_nn=bvol_ctx_net, event_dim=0)
        self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm, ExpTransform()])
        self.brain_volume_flow_transforms = [
            self.brain_volume_flow_components,
            self.brain_volume_flow_constraint_transforms
        ]

        # --- encoder ---
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim)

        # TODO: do we need to replicate the PGM here to be able to run conterfactuals? oO
        latent_trunk = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
            torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_trunk, self.latent_dim,
                                              self.latent_dim)

    @pyro_method
    def pgm_model(self):
        """Sample (sex, age, ventricle_volume, brain_volume) from the causal
        graph: sex and age are roots; both volumes are sampled from flows
        conditioned on the context [sex, unconstrained age].
        """
        sex_dist = Bernoulli(self.sex_logits).to_event(1)

        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)

        age = pyro.sample('age', age_dist)
        # Unconstrained age used as part of the conditioning context below.
        age_ = self.age_flow_constraint_transforms.inv(age)
        # pseudo call to age_flow_transforms to register with pyro
        _ = self.age_flow_transforms

        context = torch.cat([sex, age_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(context)

        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo call to ventricle_volume_flow_transforms to register with pyro
        _ = self.ventricle_volume_flow_transforms

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(context)

        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo call to brain_volume_flow_transforms to register with pyro
        _ = self.brain_volume_flow_transforms

        return age, sex, ventricle_volume, brain_volume

    @pyro_method
    def model(self):
        """Full generative model: sample the PGM variables and latent z, then
        sample the image x from a spatially transformed decoder output."""
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        # Unconstrained parent values fed to the decoder as conditioning.
        vvol_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        bvol_ = self.brain_volume_flow_constraint_transforms.inv(brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, vvol_, bvol_], 1)

        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        # Per-sample 2x3 affine parameters predicted from the latent code.
        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        with warnings.catch_warnings():
            # Silence warnings raised by affine_grid/grid_sample here
            # (presumably the align_corners deprecation notices).
            warnings.simplefilter("ignore")

            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        # Shift/scale the base distribution by the warped decoder output,
        # then apply the dataset preprocessing transform.
        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume

    @pyro_method
    def guide(self, x, age, sex, ventricle_volume, brain_volume):
        """Amortised posterior: sample z from the latent encoder applied to
        the image features concatenated with the unconstrained parents."""
        with pyro.plate('observations', x.shape[0]):
            features = self.encoder(x)

            vvol_ = self.ventricle_volume_flow_constraint_transforms.inv(
                ventricle_volume)
            bvol_ = self.brain_volume_flow_constraint_transforms.inv(
                brain_volume)
            age_ = self.age_flow_constraint_transforms.inv(age)

            # Condition the encoder features on the unconstrained parents.
            features = torch.cat([features, age_, vvol_, bvol_], 1)

            z = pyro.sample('z', self.latent_encoder.predict(features))

        return z
    def __init__(self, **kwargs):
        # NOTE(review): this __init__ is byte-identical to
        # ConditionalSTNVISEM.__init__ above. If it really sits inside that
        # same class it silently overrides the earlier definition; confirm
        # whether a `class ...(BaseVISEM):` header was lost above this line.
        super().__init__(**kwargs)

        # decoder parts
        self.decoder = Decoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + 3,
                               upconv=self.use_upconv)

        # 1x1 conv producing the image mean; scalar log-std parameter.
        self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
        self.decoder_logstd = torch.nn.Parameter(
            torch.ones([]) * self.logstd_init)

        # MLP predicting the 6 affine-grid parameters from the latent code.
        self.decoder_affine_param_net = nn.Sequential(
            nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, 6))

        # Initialise the final layer so the predicted warp starts as identity.
        self.decoder_affine_param_net[-1].weight.data.zero_()
        self.decoder_affine_param_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # age flow
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # ventricle_volume flow
        # TODO: decide on how many things to condition on
        ventricle_volume_net = DenseNN(2, [8, 16],
                                       param_dims=[1, 1],
                                       nonlinearity=torch.nn.Identity())
        self.ventricle_volume_flow_components = ConditionalAffineTransform(
            context_nn=ventricle_volume_net, event_dim=0)
        self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])
        self.ventricle_volume_flow_transforms = [
            self.ventricle_volume_flow_components,
            self.ventricle_volume_flow_constraint_transforms
        ]

        # brain_volume flow
        # TODO: decide on how many things to condition on
        brain_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
        self.brain_volume_flow_components = ConditionalAffineTransform(
            context_nn=brain_volume_net, event_dim=0)
        self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])
        self.brain_volume_flow_transforms = [
            self.brain_volume_flow_components,
            self.brain_volume_flow_constraint_transforms
        ]

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim)

        # TODO: do we need to replicate the PGM here to be able to run conterfactuals? oO
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
            torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)