Ejemplo n.º 1
0
    def pgm_model(self):
        """Sample every covariate of the causal graph once.

        Sites are drawn in this order: ``sex``, ``slice_number``, ``age``,
        ``duration``, ``edss``, ``brain_volume``, ``ventricle_volume`` and
        ``lesion_volume``. Every site after ``age`` comes from a conditional
        flow whose context concatenates (along dim 1) its parents. Continuous
        parents enter the context after being mapped through the inverse of
        their ``*_flow_constraint_transforms``.

        Returns:
            dict keyed by site name holding the sampled values.
        """
        # sex: Bernoulli site with a decaying-average baseline passed to pyro.
        sex_prior = Bernoulli(logits=self.sex_logits).to_event(1)
        # pseudo call to register with pyro
        _ = self.sex_logits
        sex = pyro.sample(
            'sex', sex_prior,
            infer={'baseline': {'use_decaying_avg_baseline': True}})

        # slice number: independent uniform site.
        slice_number = pyro.sample(
            'slice_number',
            Uniform(self.slice_number_min, self.slice_number_max).to_event(1))

        # age: unconditional flow.
        age_prior = TransformedDistribution(
            Normal(self.age_base_loc, self.age_base_scale).to_event(1),
            self.age_flow_transforms)
        _ = self.age_flow_components  # pseudo call to register with pyro
        age = pyro.sample('age', age_prior)
        age_unc = self.age_flow_constraint_transforms.inv(age)

        # duration | sex, age
        duration_ctx = torch.cat([sex, age_unc], 1)
        duration_prior = ConditionalTransformedDistribution(
            Normal(self.duration_base_loc,
                   self.duration_base_scale).to_event(1),
            self.duration_flow_transforms).condition(duration_ctx)
        duration = pyro.sample('duration', duration_prior)
        _ = self.duration_flow_components  # pseudo call to register with pyro
        duration_unc = self.duration_flow_constraint_transforms.inv(duration)

        # edss | sex, duration
        edss_ctx = torch.cat([sex, duration_unc], 1)
        edss_prior = ConditionalTransformedDistribution(
            Normal(self.edss_base_loc, self.edss_base_scale).to_event(1),
            self.edss_flow_transforms).condition(edss_ctx)
        edss = pyro.sample('edss', edss_prior)
        _ = self.edss_flow_components  # pseudo call to register with pyro
        edss_unc = self.edss_flow_constraint_transforms.inv(edss)

        # brain volume | sex, age
        brain_ctx = torch.cat([sex, age_unc], 1)
        brain_volume_prior = ConditionalTransformedDistribution(
            Normal(self.brain_volume_base_loc,
                   self.brain_volume_base_scale).to_event(1),
            self.brain_volume_flow_transforms).condition(brain_ctx)
        _ = self.brain_volume_flow_components  # pseudo call to register with pyro
        brain_volume = pyro.sample('brain_volume', brain_volume_prior)
        brain_volume_unc = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)

        # ventricle volume | age, brain volume, duration
        ventricle_ctx = torch.cat([age_unc, brain_volume_unc, duration_unc], 1)
        ventricle_volume_prior = ConditionalTransformedDistribution(
            Normal(self.ventricle_volume_base_loc,
                   self.ventricle_volume_base_scale).to_event(1),
            self.ventricle_volume_flow_transforms).condition(ventricle_ctx)
        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_prior)
        _ = self.ventricle_volume_flow_components  # pseudo call to register with pyro
        ventricle_volume_unc = (
            self.ventricle_volume_flow_constraint_transforms.inv(
                ventricle_volume))

        # lesion volume | brain volume, ventricle volume, duration, edss
        lesion_ctx = torch.cat(
            [brain_volume_unc, ventricle_volume_unc, duration_unc, edss_unc],
            1)
        lesion_volume_prior = ConditionalTransformedDistribution(
            Normal(self.lesion_volume_base_loc,
                   self.lesion_volume_base_scale).to_event(1),
            self.lesion_volume_flow_transforms).condition(lesion_ctx)
        lesion_volume = pyro.sample('lesion_volume', lesion_volume_prior)
        _ = self.lesion_volume_flow_components  # pseudo call to register with pyro

        return {
            'age': age,
            'sex': sex,
            'ventricle_volume': ventricle_volume,
            'brain_volume': brain_volume,
            'lesion_volume': lesion_volume,
            'duration': duration,
            'edss': edss,
            'slice_number': slice_number,
        }
Ejemplo n.º 2
0
    def pgm_model(self):
        """Sample sex, age, brain volume and ventricle volume.

        ``age`` comes from an unconditional normalising flow.
        ``brain_volume`` is conditioned on ``(sex, age_)`` and
        ``ventricle_volume`` on ``(age_, brain_volume_)``. A trailing ``_``
        marks a value mapped through the inverse of that variable's
        ``*_flow_constraint_transforms``. Contexts are concatenated along
        dim 1.

        Returns:
            tuple ``(age, sex, ventricle_volume, brain_volume)``.
        """
        sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)

        # pseudo call to sex_logits to register with pyro
        _ = self.sex_logits

        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)

        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)
        # pseudo call to age_flow_components to register with pyro
        _ = self.age_flow_components

        brain_context = torch.cat([sex, age_], 1)

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(brain_context)

        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo call to brain_volume_flow_components to register with pyro
        _ = self.brain_volume_flow_components

        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)

        ventricle_context = torch.cat([age_, brain_volume_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(
                ventricle_context)

        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo call to ventricle_volume_flow_components to register with pyro
        _ = self.ventricle_volume_flow_components

        return age, sex, ventricle_volume, brain_volume
Ejemplo n.º 3
0
 def infer_intensity_base(self, thickness, intensity):
     """Recover the base variable of the intensity flow for observed data.

     Args:
         thickness: observed thickness, in the constrained space that
             ``thickness_flow_constraint_transforms`` maps into.
         intensity: observed intensity.

     Returns:
         The value that the intensity flow, conditioned on the thickness,
         maps to ``intensity``. It is computed by inverting the composed
         conditional transforms.
     """
     intensity_base_dist = Normal(self.intensity_base_loc,
                                  self.intensity_base_scale)
     # The conditioning context must match the one used when intensity is
     # sampled: the unconstrained thickness, not the raw observation.
     # NOTE(review): assumes this class's pgm_model conditions intensity on
     # thickness_flow_constraint_transforms.inv(thickness), as the other
     # model variants do; confirm.
     thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
     cond_intensity_transforms = ComposeTransform(
         ConditionalTransformedDistribution(
             intensity_base_dist, self.intensity_flow_transforms).condition(
                 thickness_).transforms)
     return cond_intensity_transforms.inv(intensity)
Ejemplo n.º 4
0
    def infer_x_base(self, thickness, intensity, x):
        """Push an observed image ``x`` through its conditional flow transforms.

        The conditioning context concatenates, along dim 1:
          * ``thickness`` mapped through
            ``thickness_flow_constraint_transforms.inv``;
          * ``intensity`` mapped through ``intensity_flow_norm.inv``.

        Returns:
            ``x`` after applying the composed conditional transforms in the
            forward direction.
        """
        base = Normal(self.x_base_loc, self.x_base_scale)

        context = torch.cat(
            [self.thickness_flow_constraint_transforms.inv(thickness),
             self.intensity_flow_norm.inv(intensity)],
            1)

        conditioned = ConditionalTransformedDistribution(
            base, self.x_transforms).condition(context)
        forward = ComposeTransform(conditioned.transforms)
        return forward(x)
Ejemplo n.º 5
0
    def __init__(self, use_affine_ex=True, **kwargs):
        """Create the glasses flow, the image flow and their buffers.

        Only modules, buffers and transforms are built here. A constructor
        must return ``None``, and sampling belongs in the model, so the
        generative sampling lives in ``_sample_glasses_and_x``.

        Args:
            use_affine_ex: stored as ``self.use_affine_ex`` before the image
                flow is built. Presumably ``_build_image_flow`` reads it;
                confirm.
            **kwargs: forwarded to the parent class constructor.
        """
        from torch.distributions.transforms import AffineTransform

        # ``super()`` (not the bare ``super`` type) is required to reach the
        # parent constructor.
        super().__init__(**kwargs)

        self.use_affine_ex = use_affine_ex
        self.num_scales = 2

        self.register_buffer("glasses_base_loc",
                             torch.zeros([1], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([1], requires_grad=False))

        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        # Affine normalisation built from the scalar buffers above. Plain
        # floats (.item()) keep the transform device-agnostic.
        self.glasses_flow_lognorm = AffineTransform(
            loc=self.glasses_flow_lognorm_loc.item(),
            scale=self.glasses_flow_lognorm_scale.item())

        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm,
             SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        self.x_transforms = self._build_image_flow()
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))

    def _sample_glasses_and_x(self):
        """Sample ``glasses`` and then an image ``x`` conditioned on it.

        Returns:
            tuple ``(x, glasses)``.
        """
        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        # The flow ends in SigmoidTransform, so glasses_ lies in (0, 1) and
        # is used as the Bernoulli probability below.
        glasses_ = pyro.sample("glasses_", glasses_dist)
        # pseudo call to glasses_flow_components to register with pyro
        _ = self.glasses_flow_components
        # to_event(1) gives the Bernoulli site the same one-dim event shape
        # as glasses_ (its base distribution is .to_event(1)).
        glasses = pyro.sample("glasses",
                              dist.Bernoulli(glasses_).to_event(1))
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample("x", cond_x_dist)

        return x, glasses
Ejemplo n.º 6
0
    def model(self):
        """Sample thickness and intensity, then an image conditioned on both.

        Returns:
            tuple ``(x, thickness, intensity)``.
        """
        thickness, intensity = self.pgm_model()

        context = torch.cat(
            [self.thickness_flow_constraint_transforms.inv(thickness),
             self.intensity_flow_norm.inv(intensity)],
            1)

        base = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        conditioned = ConditionalTransformedDistribution(
            base, self.x_transforms).condition(context)
        # The composed conditional transforms are inverted before building the
        # sampling distribution. Presumably the flow is defined in the
        # x -> base direction.
        x_dist = TransformedDistribution(
            base, ComposeTransform(conditioned.transforms).inv)

        x = pyro.sample('x', x_dist)
        return x, thickness, intensity
Ejemplo n.º 7
0
    def pgm_model(self):
        """Sample thickness and then intensity conditioned on it.

        ``thickness`` comes from an unconditional flow. ``intensity`` comes
        from a conditional flow whose context is the thickness mapped through
        ``thickness_flow_constraint_transforms.inv``.

        Returns:
            tuple ``(thickness, intensity)``.
        """
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo call to thickness_flow_components to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = ConditionalTransformedDistribution(
            intensity_base_dist,
            self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_components to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity