def infer_intensity_base(self, thickness, intensity):
    """Invert the thickness-conditioned intensity flow.

    Builds the conditional intensity distribution, conditions it on
    ``thickness`` exactly as passed in, and applies the inverse of the
    resulting transform chain to ``intensity``.

    NOTE(review): ``ConditionalFlowSEM.infer_intensity_base`` in this file
    first maps thickness through ``thickness_flow_constraint_transforms.inv``
    before conditioning. Confirm which representation the owning model's
    ``pgm_model`` conditions on; if it uses the unconstrained thickness, this
    function should transform ``thickness`` the same way.
    """
    # Only the conditioned transform chain is used; the base distribution is
    # needed solely to construct the ConditionalTransformedDistribution.
    base_dist = Normal(self.intensity_base_loc, self.intensity_base_scale)
    conditioned_dist = ConditionalTransformedDistribution(
        base_dist, self.intensity_flow_transforms).condition(thickness)
    inverse_flow = ComposeTransform(conditioned_dist.transforms).inv
    return inverse_flow(intensity)
class ConditionalDecoderVISEM(BaseVISEM):
    """Variational SEM with independent thickness/intensity flows and an
    image decoder conditioned on both.

    ``pgm_model`` samples thickness and intensity from unconditional flows.
    ``model`` decodes ``x`` from ``z`` concatenated with the unconstrained
    thickness and the ``intensity_flow_norm``-inverted intensity; ``guide``
    applies the same conditioning on the encoder side.
    """

    # Number of conditioning variables appended to the latent code
    # (thickness_ and intensity_ in ``model``/``guide``). Presumably read by
    # BaseVISEM when sizing the decoder/encoder -- TODO confirm.
    context_dim = 2

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Thickness flow: learned spline, then the fixed constraint chain
        # (thickness_flow_lognorm followed by exp, i.e. positive output).
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # Intensity flow: learned affine + spline, then sigmoid followed by
        # intensity_flow_norm. Kept as a plain list (TransformedDistribution
        # accepts a list); a list has no ``.inv``, so compose it before
        # inverting (see ``infer_intensity_base``).
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

    @pyro_method
    def pgm_model(self):
        """Sample ``thickness`` and ``intensity`` from their independent flows.

        Returns:
            Tuple ``(thickness, intensity)``.
        """
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # Access the nn.Module holding the learnable thickness flow so it is
        # registered with pyro.
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # Access the nn.Module holding the learnable intensity flow so it is
        # registered with pyro.
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Generative model: PGM variables, latent ``z``, then image ``x``.

        Returns:
            Tuple ``(x, z, thickness, intensity)``.
        """
        thickness, intensity = self.pgm_model()

        # Decoder context: unconstrained thickness and normalised intensity.
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, thickness_, intensity_], 1)

        x_dist = self._get_transformed_x_dist(latent)

        x = pyro.sample('x', x_dist)

        return x, z, thickness, intensity

    @pyro_method
    def guide(self, x, thickness, intensity):
        """Amortised posterior over ``z`` given ``x`` and its parents.

        Runs inside an ``'observations'`` plate of size ``x.shape[0]``; the
        encoder features are concatenated with the same thickness/intensity
        context used by ``model``.

        Returns:
            The sampled ``z``.
        """
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)

            thickness_ = self.thickness_flow_constraint_transforms.inv(
                thickness)
            intensity_ = self.intensity_flow_norm.inv(intensity)

            hidden = torch.cat([hidden, thickness_, intensity_], 1)
            latent_dist = self.latent_encoder.predict(hidden)

            z = pyro.sample('z', latent_dist)

        return z

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Map ``thickness`` to its base space by inverting the full flow."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        """Map ``intensity`` to its base space by inverting the full flow.

        ``intensity_flow_transforms`` is a plain list, which has no ``.inv``
        attribute, so it is composed first and the inverse of the whole chain
        is applied.
        """
        return ComposeTransform(self.intensity_flow_transforms).inv(intensity)
Example #3
0
class ConditionalFlowSEM(BaseFlowSEM):
    """Normalising-flow SEM: thickness -> intensity -> image.

    Thickness has an unconditional flow, intensity has an affine flow
    conditioned on the unconstrained thickness (``thickness_``), and the
    image ``x`` has a multi-scale flow conditioned on
    ``[thickness_, intensity_]``.
    """
    def __init__(self, use_affine_ex: bool = True, **kwargs):
        """Build the thickness, intensity and image flows.

        Args:
            use_affine_ex: if True, ``_build_image_flow`` appends a
                conditional affine transform as the last image transform.
                NOTE(review): ``add_arguments`` registers ``--use_affine_ex``
                with default False while this constructor defaults to True;
                confirm which default is intended.
            **kwargs: forwarded to ``BaseFlowSEM.__init__``.
        """
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Thickness flow: learned spline, then the fixed constraint chain
        # (thickness_flow_lognorm followed by exp, i.e. positive output).
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # Intensity flow: a conditional affine transform whose two parameter
        # outputs (param_dims=[1, 1], presumably loc and scale) come from a
        # 1-input DenseNN with identity nonlinearity, i.e. they are linear in
        # the 1-dim context (thickness_, see pgm_model). It is followed by
        # sigmoid and intensity_flow_norm.
        intensity_net = DenseNN(1, [1],
                                param_dims=[1, 1],
                                nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]
        # Design note: build flow as s_affine_w * t * e_s + b -> depends on t

        # Image flow (RealNVP/Glow-style coupling stack) for x.
        self._build_image_flow()

    def _build_image_flow(self):
        """Populate ``self.trans_modules`` and ``self.x_transforms``.

        ``x_transforms`` is the ordered chain mapping an image to the flow's
        base space (``infer_x_base`` applies it forward; ``model`` samples via
        its inverse). Layout: preprocessing; then for each of ``num_scales``
        scales a squeeze (channel count x4) followed by ``flows_per_scale``
        blocks of [optional ActNorm] -> channel permute -> transpose ->
        conditional affine coupling -> transpose back, plus one channel
        permute per scale; then a reshape to (1, 32, 32); optionally a final
        conditional affine transform. Learnable transforms are also appended
        to ``trans_modules`` (presumably so their parameters are registered
        on this module -- confirm).

        NOTE(review): the reshape hard-codes a 32x32 spatial size.
        """

        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        # c tracks the channel count after each squeeze.
        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # Permute dims (0, 1, 2) -> (1, 2, 0) around the coupling and
                # undo it afterwards; presumably the coupling expects the
                # channel axis last -- confirm.
                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                # Coupling conditioned on the context; the trailing ``2`` to
                # BasicFlowConvNet presumably is the context size
                # ([thickness_, intensity_]).
                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            # Optional conditional affine on the whole image (event_dim=3),
            # driven by the 2-dim context.
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)

    @pyro_method
    def pgm_model(self):
        """Sample ``thickness``, then ``intensity`` conditioned on it.

        Intensity is conditioned on ``thickness_``, the inverse of
        ``thickness_flow_constraint_transforms`` applied to thickness.

        Returns:
            Tuple ``(thickness, intensity)``.
        """
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo call to thickness_flow_components to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = ConditionalTransformedDistribution(
            intensity_base_dist,
            self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_components to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Generative model: PGM variables, then ``x`` from the image flow.

        ``x`` is obtained by applying the inverse of the context-conditioned
        ``x_transforms`` chain to a sample from the image base distribution.

        Returns:
            Tuple ``(x, thickness, intensity)``.
        """
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # Conditioning the list of transforms on ``context`` yields concrete
        # transforms; their composed inverse maps base samples to images.
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Map ``thickness`` to its base space by inverting the full flow."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, thickness, intensity):
        """Map ``intensity`` to its base space given its parent ``thickness``.

        Conditions the intensity flow on ``thickness_`` (as in
        ``pgm_model``) and applies the inverse of the composed chain.
        """
        # Only the conditioned transforms are used; the base distribution is
        # needed to construct the ConditionalTransformedDistribution.
        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        cond_intensity_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                intensity_base_dist, self.intensity_flow_transforms).condition(
                    thickness_).transforms)
        return cond_intensity_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Map image ``x`` to the image flow's base space given its parents.

        Applies the context-conditioned ``x_transforms`` chain forward (the
        inverse direction of ``model``).
        """
        # No ``to_event`` needed: only the transforms are used below.
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist, self.x_transforms).condition(context).transforms)
        return cond_x_transforms(x)

    @classmethod
    def add_arguments(cls, parser):
        """Extend the base argparse parser with ``--use_affine_ex``.

        NOTE(review): the CLI default (False) differs from the constructor
        default (True).
        """
        parser = super().add_arguments(parser)

        parser.add_argument(
            '--use_affine_ex',
            default=False,
            action='store_true',
            help=
            "whether to use conditional affine transformation on e_x (default: %(default)s)"
        )

        return parser
class ConditionalSTNVISEM(BaseVISEM):
    """Variational SEM for (sex, age, ventricle_volume, brain_volume, x) with
    a spatial-transformer decoder.

    ``pgm_model`` samples sex and age, then ventricle and brain volumes from
    affine flows conditioned on ``[sex, age_]``. ``model`` decodes ``x`` from
    ``[z, age_, ventricle_volume_, brain_volume_]`` and warps the decoded
    mean with an affine grid predicted from the same latent vector.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # decoder parts: the decoder input is z plus 3 unconstrained
        # covariates (age_, ventricle_volume_, brain_volume_).
        self.decoder = Decoder(num_convolutions=self.num_convolutions,
                               filters=self.dec_filters,
                               latent_dim=self.latent_dim + 3,
                               upconv=self.use_upconv)

        self.decoder_mean = torch.nn.Conv2d(1, 1, 1)
        self.decoder_logstd = torch.nn.Parameter(
            torch.ones([]) * self.logstd_init)

        # Predicts the 6 entries of a 2x3 affine matrix for the STN warp.
        self.decoder_affine_param_net = nn.Sequential(
            nn.Linear(self.latent_dim + 3, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(),
            nn.Linear(self.latent_dim, 6))

        # Zero weights + identity bias: the warp starts as the identity map.
        self.decoder_affine_param_net[-1].weight.data.zero_()
        self.decoder_affine_param_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # age flow: learned spline, then a fixed affine + exp (positive age).
        self.age_flow_components = ComposeTransformModule([Spline(1)])
        self.age_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.age_flow_constraint_transforms = ComposeTransform(
            [self.age_flow_lognorm, ExpTransform()])
        self.age_flow_transforms = ComposeTransform(
            [self.age_flow_components, self.age_flow_constraint_transforms])

        # ventricle_volume flow: affine transform conditioned on [sex, age_]
        # (2 inputs), then a fixed affine + exp (positive volume).
        # TODO: decide on how many things to condition on
        ventricle_volume_net = DenseNN(2, [8, 16],
                                       param_dims=[1, 1],
                                       nonlinearity=torch.nn.Identity())
        self.ventricle_volume_flow_components = ConditionalAffineTransform(
            context_nn=ventricle_volume_net, event_dim=0)
        self.ventricle_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.ventricle_volume_flow_constraint_transforms = ComposeTransform(
            [self.ventricle_volume_flow_lognorm,
             ExpTransform()])
        self.ventricle_volume_flow_transforms = [
            self.ventricle_volume_flow_components,
            self.ventricle_volume_flow_constraint_transforms
        ]

        # brain_volume flow: same structure as the ventricle_volume flow.
        # TODO: decide on how many things to condition on
        brain_volume_net = DenseNN(2, [8, 16],
                                   param_dims=[1, 1],
                                   nonlinearity=torch.nn.Identity())
        self.brain_volume_flow_components = ConditionalAffineTransform(
            context_nn=brain_volume_net, event_dim=0)
        self.brain_volume_flow_lognorm = AffineTransform(loc=0., scale=1.)
        self.brain_volume_flow_constraint_transforms = ComposeTransform(
            [self.brain_volume_flow_lognorm,
             ExpTransform()])
        self.brain_volume_flow_transforms = [
            self.brain_volume_flow_components,
            self.brain_volume_flow_constraint_transforms
        ]

        # encoder parts
        self.encoder = Encoder(num_convolutions=self.num_convolutions,
                               filters=self.enc_filters,
                               latent_dim=self.latent_dim)

        # TODO: do we need to replicate the PGM here to be able to run
        # counterfactuals?
        latent_layers = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim + 3, self.latent_dim),
            torch.nn.ReLU())
        self.latent_encoder = DeepIndepNormal(latent_layers, self.latent_dim,
                                              self.latent_dim)

    @pyro_method
    def pgm_model(self):
        """Sample the PGM variables.

        Returns:
            Tuple ``(age, sex, ventricle_volume, brain_volume)``.
        """
        # ``sex_logits`` must be passed as logits: the first positional
        # argument of Bernoulli is ``probs``.
        sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)

        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)

        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)
        # Pseudo call to register the learnable age flow with pyro. The
        # parameters live in the nn.Module ``age_flow_components``;
        # ``age_flow_transforms`` is only a ComposeTransform wrapper around
        # it, so accessing that wrapper does not register anything.
        _ = self.age_flow_components

        context = torch.cat([sex, age_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(context)

        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo call to ventricle_volume_flow_components to register with pyro
        _ = self.ventricle_volume_flow_components

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(context)

        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo call to brain_volume_flow_components to register with pyro
        _ = self.brain_volume_flow_components

        return age, sex, ventricle_volume, brain_volume

    @pyro_method
    def model(self):
        """Generative model: PGM variables, latent ``z``, then image ``x``.

        The decoded mean is warped by an affine grid predicted from the
        latent vector before defining the image likelihood.

        Returns:
            Tuple ``(x, z, age, sex, ventricle_volume, brain_volume)``.
        """
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        # affine_grid/grid_sample emit warnings (e.g. about align_corners
        # defaults); they are silenced here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume

    @pyro_method
    def guide(self, x, age, sex, ventricle_volume, brain_volume):
        """Amortised posterior over ``z`` given ``x`` and its covariates.

        Runs inside an ``'observations'`` plate of size ``x.shape[0]``.
        ``sex`` is accepted for signature parity with ``model`` but is not
        used by the encoder.

        Returns:
            The sampled ``z``.
        """
        with pyro.plate('observations', x.shape[0]):
            hidden = self.encoder(x)

            ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
                ventricle_volume)
            brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
                brain_volume)
            age_ = self.age_flow_constraint_transforms.inv(age)

            hidden = torch.cat(
                [hidden, age_, ventricle_volume_, brain_volume_], 1)

            latent_dist = self.latent_encoder.predict(hidden)

            z = pyro.sample('z', latent_dist)

        return z
Example #5
0
class IndependentFlowSEM(BaseFlowSEM):
    """Normalising-flow SEM in which thickness, intensity and the image are
    modelled independently (the image flow uses no conditioning context).
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        """Build the thickness, intensity and image flows.

        Args:
            use_affine_ex: stored on the instance; this class's
                ``_build_image_flow`` does not read it.
            **kwargs: forwarded to ``BaseFlowSEM.__init__``.
        """
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Thickness flow: learned spline, then the fixed constraint chain
        # (thickness_flow_lognorm followed by exp, i.e. positive output).
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform(
            [self.thickness_flow_lognorm,
             ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([
            self.thickness_flow_components,
            self.thickness_flow_constraint_transforms
        ])

        # Intensity flow: learned affine + spline, then sigmoid followed by
        # intensity_flow_norm. Kept as a plain list (TransformedDistribution
        # accepts a list); a list has no ``.inv``, so compose it before
        # inverting (see ``infer_intensity_base``).
        self.intensity_flow_components = ComposeTransformModule(
            [LearnedAffineTransform(), Spline(1)])
        self.intensity_flow_constraint_transforms = ComposeTransform(
            [SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [
            self.intensity_flow_components,
            self.intensity_flow_constraint_transforms
        ]

        # Image flow (RealNVP/Glow-style coupling stack) for x.
        self._build_image_flow()

    def _build_image_flow(self):
        """Populate ``self.trans_modules`` and ``self.x_transforms``.

        ``x_transforms`` is the ordered chain mapping an image to the flow's
        base space: preprocessing; per scale a squeeze (channel count x4)
        followed by ``flows_per_scale`` blocks of [optional ActNorm] ->
        channel permute -> transpose -> unconditional affine coupling ->
        transpose back, plus one channel permute per scale; finally a
        reshape to (1, 32, 32). Learnable transforms are also appended to
        ``trans_modules``.
        """

        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        # c tracks the channel count after each squeeze.
        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = AffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2)))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

    @pyro_method
    def pgm_model(self):
        """Sample ``thickness`` and ``intensity`` from their independent flows.

        Returns:
            Tuple ``(thickness, intensity)``.
        """
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_components to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_components to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Generative model: PGM variables, then an unconditioned image ``x``.

        Returns:
            Tuple ``(x, thickness, intensity)``.
        """
        thickness, intensity = self.pgm_model()

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform(self.x_transforms).inv)

        x = pyro.sample('x', x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        """Map ``thickness`` to its base space by inverting the full flow."""
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, intensity):
        """Map ``intensity`` to its base space by inverting the full flow.

        ``intensity_flow_transforms`` is a plain list, which has no ``.inv``
        attribute, so it is composed first and the inverse of the whole chain
        is applied.
        """
        return ComposeTransform(self.intensity_flow_transforms).inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        """Map image ``x`` to the image flow's base space.

        ``thickness`` and ``intensity`` are accepted for interface parity
        with the conditional models but are unused: the image flow here is
        unconditional.
        """
        return ComposeTransform(self.x_transforms)(x)
Example #6
0
class FlowSCM(pyroModule):
    """ definition of FlowSCM class"""
    def __init__(self, use_affine_ex=True, **kwargs):
        """Construct a glasses -> image flow SCM (work in progress).

        As written this constructor cannot complete (see the NOTE(review)
        comments below). It registers base and normalisation buffers for the
        ``glasses`` variable, builds a spline + constraint flow for it, draws
        samples with ``pyro.sample`` at construction time, builds the image
        flow and then tries to sample ``x``.

        Args:
            use_affine_ex: accepted but never stored or read in this method.
                NOTE(review): ``_build_image_flow`` reads
                ``self.use_affine_ex``; presumably this should be assigned.
            **kwargs: intended for the base-class constructor.
        """
        # NOTE(review): ``super.__init__`` calls the builtin ``super`` type's
        # __init__ without an instance, which raises TypeError; presumably
        # ``super().__init__(**kwargs)`` was intended. The base class name
        # ``pyroModule`` on the class line also looks like a typo for
        # ``PyroModule`` -- confirm.
        super.__init__(**kwargs)

        self.num_scales = 2

        # Parameters of the Normal base distribution for the glasses flow.
        self.register_buffer("glasses_base_loc",
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        # Scalar loc/scale intended for a fixed affine "lognorm" transform.
        # NOTE(review): only these buffers are registered; the attribute
        # ``self.glasses_flow_lognorm`` read below is never assigned in this
        # method.
        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        # Glasses flow: learned spline, then the lognorm affine and a sigmoid,
        # so the flow output lies in (0, 1) and is used as a Bernoulli
        # probability below.
        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm,
             SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        self.glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        # NOTE(review): sampling inside __init__ draws these values once at
        # construction instead of on every model execution; this logic
        # presumably belongs in a ``@pyro_method`` model.
        glasses_ = pyro.sample("glasses_", self.glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        # NOTE(review): ``glasses_context`` is computed but never used, and
        # ``context`` below is not defined in this scope (NameError). The
        # image couplings in ``_build_image_flow`` are built with
        # ``DenseNN(2, ...)`` / a trailing ``2`` argument, presumably a 2-dim
        # context, while this context would have 1 feature -- confirm.
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        # NOTE(review): the visible part of ``_build_image_flow`` assigns
        # ``self.x_transforms`` itself and shows no ``return``; if it returns
        # None, this assignment overwrites the list with None. It also reads
        # ``flows_per_scale``, ``use_actnorm`` and ``hidden_channels``, which
        # are not set in this method.
        self.x_transforms = self._build_image_flow()
        # NOTE(review): the base is 1x64x64, while ``_build_image_flow``
        # reshapes to (1, 32, 32) -- shapes look inconsistent.
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample("x", cond_x_dist)

        # NOTE(review): __init__ must return None; returning a tuple raises
        # TypeError when the object is constructed.
        return x, glasses

    def _build_image_flow(self):
        self.trans_modules = ComposeTransformModule([])
        self.x_transforms = []
        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(
                    c // 2,
                    BasicFlowConvNet(c // 2, self.hidden_channels,
                                     (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(
                    TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales,
                              32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net,
                                                      event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)