Example #1
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # we enclose all the sample statements in the guide in a plate.
        # this marks that each datapoint is conditionally independent of the others.
        with pyro.plate("z_minibatch", len(mini_batch)):
            # sample the latents z one time step at a time
            # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use multiple samples from the guide at each z
            for t in pyro.markov(range(1, T_max + 1)):
                # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                # if we are using normalizing flows, we apply the sequence of transformations
                # parameterized by self.iafs to the base distribution defined in the previous line
                # to yield a transformed distribution that we use for q(z_t|...)
                if len(self.iafs) > 0:
                    z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                    assert z_dist.event_shape == (self.z_q_0.size(0),)
                    assert z_dist.batch_shape[-1:] == (len(mini_batch),)
                else:
                    z_dist = dist.Normal(z_loc, z_scale)
                    assert z_dist.event_shape == ()
                    assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0))

                # sample z_t from the distribution z_dist
                with pyro.poutine.scale(scale=annealing_factor):
                    if len(self.iafs) > 0:
                        # in output of normalizing flow, all dimensions are correlated (event shape is not empty)
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1]))
                    else:
                        # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1:t])
                                          .to_event(1))
                # the latent sampled at this time step will be conditioned upon in the next time step
                # so keep track of it
                z_prev = z_t
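A note on the shape assertions above: an IAF is a transform with event_dim=1, so wrapping the diagonal Normal in a TransformedDistribution moves the latent dimension from the batch shape into the event shape, and the flow branch treats the whole latent vector as one correlated event. A minimal sketch of that behaviour, using Pyro's affine_autoregressive helper as a stand-in for the IAFs assumed above:

import torch
import pyro.distributions as dist
from pyro.distributions import TransformedDistribution
from pyro.distributions.transforms import affine_autoregressive

batch, z_dim = 4, 3
loc, scale = torch.zeros(batch, z_dim), torch.ones(batch, z_dim)

# plain diagonal Normal: every dimension is a batch dimension
base = dist.Normal(loc, scale)
assert base.batch_shape == (batch, z_dim) and base.event_shape == ()

# an IAF has event_dim=1, so the rightmost batch dimension becomes an event dimension
flow_dist = TransformedDistribution(base, [affine_autoregressive(z_dim)])
assert flow_dist.batch_shape == (batch,) and flow_dist.event_shape == (z_dim,)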
Example #2
    def __init__(self, use_affine_ex=True, **kwargs):
        super().__init__(**kwargs)

        self.num_scales = 2

        self.register_buffer("glasses_base_loc",
                             torch.zeros([
                                 1,
                             ], requires_grad=False))
        self.register_buffer("glasses_base_scale",
                             torch.ones([
                                 1,
                             ], requires_grad=False))

        self.register_buffer("glasses_flow_lognorm_loc",
                             torch.zeros([], requires_grad=False))
        self.register_buffer("glasses_flow_lognorm_scale",
                             torch.ones([], requires_grad=False))

        self.glasses_flow_components = ComposeTransformModule([Spline(1)])
        self.glasses_flow_constraint_transforms = ComposeTransform(
            [self.glasses_flow_lognorm,
             SigmoidTransform()])
        self.glasses_flow_transforms = ComposeTransform([
            self.glasses_flow_components,
            self.glasses_flow_constraint_transforms
        ])

        glasses_base_dist = Normal(self.glasses_base_loc,
                                   self.glasses_base_scale).to_event(1)
        self.glasses_dist = TransformedDistribution(
            glasses_base_dist, self.glasses_flow_transforms)
        glasses_ = pyro.sample("glasses_", self.glasses_dist)
        glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
        glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

        self.x_transforms = self._build_image_flow()
        self.register_buffer("x_base_loc",
                             torch.zeros([1, 64, 64], requires_grad=False))
        self.register_buffer("x_base_scale",
                             torch.ones([1, 64, 64], requires_grad=False))
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(glasses_context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample("x", cond_x_dist)

        return x, glasses
    def _get_transformed_x_dist(self, latent, ctx=None):
        x_pred_dist = self.decoder.predict(
            latent,
            ctx)  # returns a normal dist with mean of the predicted image
        if self.laplace_likelihood:
            x_base_dist = Laplace(self.x_base_loc,
                                  self.x_base_scale).to_event(3)
        else:
            x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(
                3)  # the 3 rightmost dimensions are treated as event dimensions

        preprocess_transform = self._get_preprocess_transforms()

        if isinstance(x_pred_dist, MultivariateNormal) or isinstance(
                x_pred_dist, LowRankMultivariateNormal):
            chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                                 x_pred_dist.scale_tril)
            reshape_transform = ReshapeTransform(self.img_shape,
                                                 (np.prod(self.img_shape), ))
            x_reparam_transform = ComposeTransform(
                [reshape_transform, chol_transform, reshape_transform.inv])
        elif isinstance(x_pred_dist, Independent):
            x_pred_dist = x_pred_dist.base_dist
            x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                                  x_pred_dist.scale, 3)
        else:
            raise ValueError(f'{x_pred_dist} not valid.')

        return TransformedDistribution(
            x_base_dist,
            ComposeTransform([x_reparam_transform, preprocess_transform]))
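The Independent branch above re-expresses the decoder's diagonal-Normal prediction as an AffineTransform with event_dim=3 applied to the standard-Normal base, so that the preprocessing transform can be composed on top. A minimal sketch of that step with plain torch.distributions and made-up shapes (the models here use (1, 64, 64) images):

import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

img_shape = (1, 8, 8)                 # made-up small image shape
x_base_dist = Independent(Normal(torch.zeros(img_shape), torch.ones(img_shape)), 3)

x_loc = torch.randn(img_shape)        # stand-in for the decoder's predicted mean
x_scale = torch.full(img_shape, 0.1)  # stand-in for the decoder's predicted scale

# shifting/scaling the whole image as one event reproduces Normal(x_loc, x_scale).to_event(3)
x_dist = TransformedDistribution(x_base_dist, AffineTransform(x_loc, x_scale, event_dim=3))
assert x_dist.rsample().shape == img_shape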
Example #4
    def guide(self, obs):
        batch_size = obs['x'].shape[0]
        with pyro.plate('observations', batch_size):
            hidden = self.encoder(obs['x'])

            ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
                obs['ventricle_volume'])
            brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
                obs['brain_volume'])
            lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
                obs['lesion_volume'])
            slice_number = obs['slice_number']
            ctx = torch.cat([
                ventricle_volume_, brain_volume_, lesion_volume_, slice_number
            ], 1)
            hidden = torch.cat([hidden, ctx], 1)

            z_base_dist = self.latent_encoder.predict(hidden)
            z_dist = TransformedDistribution(
                z_base_dist, self.posterior_flow_transforms
            ) if self.use_posterior_flow else z_base_dist
            # pseudo calls to posterior_affine / posterior_flow_components to register with pyro
            _ = self.posterior_affine
            _ = self.posterior_flow_components
            with poutine.scale(scale=self.annealing_factor[-1]):
                z = pyro.sample('z', z_dist)

        return z
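The poutine.scale context above implements KL annealing: it multiplies the log-probability contribution of the enclosed sample site by the annealing factor when the ELBO is assembled. A minimal sketch of that mechanism, with a hypothetical site name 'z':

import pyro
import pyro.distributions as dist
from pyro import poutine

def annealed_guide(annealing_factor=0.5):
    # poutine.scale rescales the log-prob contribution of the enclosed sites
    with poutine.scale(scale=annealing_factor):
        pyro.sample('z', dist.Normal(0., 1.))

trace = poutine.trace(annealed_guide).get_trace(0.5)
assert trace.nodes['z']['scale'] == 0.5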
Example #5
    def model(self):
        obs = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            obs['ventricle_volume'])
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            obs['brain_volume'])
        lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
            obs['lesion_volume'])
        slice_number = obs['slice_number']
        ctx = torch.cat(
            [ventricle_volume_, brain_volume_, lesion_volume_, slice_number],
            1)

        if self.prior_components > 1:
            z_scale = (
                0.5 * self.z_scale).exp() + 1e-5  # z_scale parameter is logvar
            z_base_dist = MixtureOfDiagNormalsSharedCovariance(
                self.z_loc, z_scale, self.z_components).to_event(0)
        else:
            z_base_dist = Normal(self.z_loc, self.z_scale).to_event(1)
        z_dist = TransformedDistribution(
            z_base_dist,
            self.prior_flow_transforms) if self.use_prior_flow else z_base_dist
        # pseudo calls to prior_affine / prior_flow_components to register with pyro
        _ = self.prior_affine
        _ = self.prior_flow_components
        with poutine.scale(scale=self.annealing_factor[-1]):
            z = pyro.sample('z', z_dist)
        latent = torch.cat([z, ctx], 1)

        x_dist = self._get_transformed_x_dist(latent)  # run decoder
        x = pyro.sample('x', x_dist)

        obs.update(dict(x=x, z=z))
        return obs
    def guide(self, response, mask, annealing_factor=1):
        pyro.module("item_encoder", self.item_encoder)
        pyro.module("ability_encoder", self.ability_encoder)
        device = response.device

        item_domain = torch.arange(self.num_item).unsqueeze(1).to(device)
        item_feat_mu, item_feat_logvar = self.item_encoder(item_domain)
        item_feat_scale = torch.exp(0.5 * item_feat_logvar)

        with poutine.scale(scale=annealing_factor):
            item_feat = pyro.sample(
                "item_feat",
                dist.Normal(item_feat_mu, item_feat_scale),
            )

        if self.conditional_posterior:
            ability_mu, ability_logvar = self.ability_encoder(
                response, mask, item_feat)
        else:
            ability_mu, ability_logvar = self.ability_encoder(response, mask)

        ability_scale = torch.exp(0.5 * ability_logvar)
        ability_dist = dist.Normal(ability_mu, ability_scale)

        if self.num_iafs > 0:
            ability_dist = TransformedDistribution(ability_dist, self.iafs)

        with poutine.scale(scale=annealing_factor):
            ability = pyro.sample("ability", ability_dist)

        return ability_mu, ability_logvar, item_feat_mu, item_feat_logvar
Example #7
File: vi_utils.py Project: bkmi/sbi
def adapt_variational_distribution(
    q: TransformedDistribution,
    prior: Distribution,
    link_transform: Callable,
    parameters: Iterable = [],
    modules: Iterable = [],
) -> Distribution:
    """This will adapt a distribution to be compatible with DivergenceOptimizers.
    Especially it will make sure that the distribution has parameters and that it
    satisfies obvious contraints which a posterior must satisfy i.e. the support must be
    equal to that of the prior.

    Args:
        q: Variational distribution.
        prior: Prior distribution
        theta_transform: Theta transformation.
        parameters: List of parameters.
        modules: List of modules.

    Returns:
        TransformedDistribution: Compatible variational distribution.

    """

    # Extract the user-defined parameters and modules as callables
    def parameters_fn():
        """Returns the parameters of the distribution."""
        return parameters

    def modules_fn():
        """Returns the modules of the distribution."""
        return modules

    if isinstance(q, TransformedDistribution):
        if parameters == [] or modules == []:
            add_parameter_attributes_to_transformed_distribution(q)
        else:
            add_parameters_module_attributes(q, parameters_fn, modules_fn)
        if hasattr(prior, "support") and q.support != prior.support:
            q = TransformedDistribution(q.base_dist,
                                        q.transforms + [link_transform])
    else:
        if hasattr(prior, "support") and q.support != prior.support:
            q = TransformedDistribution(q, [link_transform])
        add_parameters_module_attributes(q, parameters_fn, modules_fn)

    return q
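When the supports do not match, the function above appends the link transform so that samples from q land on the prior's support. A sketch of that adjustment with plain torch.distributions, assuming a hypothetical prior supported on the unit interval:

import torch
from torch.distributions import Normal, TransformedDistribution, biject_to, constraints
from torch.distributions.transforms import AffineTransform

# variational distribution on the reals and a prior supported on (0, 1)
q = TransformedDistribution(Normal(torch.zeros(2), torch.ones(2)),
                            [AffineTransform(0.0, 2.0)])
link_transform = biject_to(constraints.unit_interval)   # a sigmoid bijection onto (0, 1)

# appending the link transform maps q's samples onto the prior's support
q_adapted = TransformedDistribution(q.base_dist, list(q.transforms) + [link_transform])
theta = q_adapted.sample((100,))
assert ((theta > 0) & (theta < 1)).all()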
Example #8
File: dmm.py Project: lewisKit/pyro
    def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):

        # this is the number of time steps we need to process in the mini-batch
        T_max = mini_batch.size(1)
        # register all PyTorch (sub)modules with pyro
        pyro.module("dmm", self)

        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        h_0_contig = self.h_0.expand(1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
        # push the observed x's through the rnn;
        # rnn_output contains the hidden state at each time step
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        # reverse the time-ordering in the hidden state and un-pack it
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

        # we enclose all the sample statements in the guide in an iarange.
        # this marks that each datapoint is conditionally independent of the others.
        with pyro.iarange("z_minibatch", len(mini_batch)):
            # sample the latents z one time step at a time
            for t in range(1, T_max + 1):
                # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                # if we are using normalizing flows, we apply the sequence of transformations
                # parameterized by self.iafs to the base distribution defined in the previous line
                # to yield a transformed distribution that we use for q(z_t|...)
                if len(self.iafs) > 0:
                    z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                else:
                    z_dist = dist.Normal(z_loc, z_scale)
                assert z_dist.event_shape == ()
                assert z_dist.batch_shape == (len(mini_batch), self.z_q_0.size(0))

                # sample z_t from the distribution z_dist
                with pyro.poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1:t])
                                            .independent(1))
                # the latent sampled at this time step will be conditioned upon in the next time step
                # so keep track of it
                z_prev = z_t
Example #9
 def sample_latent_space(self, annealing_factor, batch_size, t, z_loc_q,
                         z_scale_q):
     if len(self.iafs) > 0:
         z_dist = TransformedDistribution(dist.Normal(z_loc_q, z_scale_q),
                                          self.iafs)
         assert z_dist.event_shape == (self.z_q_0.size(0), )
         assert z_dist.batch_shape[-1:] == (batch_size, )
     else:
         z_dist = dist.Normal(z_loc_q, z_scale_q)
         assert z_dist.event_shape == ()
         assert z_dist.batch_shape[-2:] == (batch_size, self.z_q_0.size(0))
     with pyro.poutine.scale(scale=annealing_factor):
         if len(self.iafs) > 0:
             # in output of normalizing flow, all dimensions are correlated (event shape is not empty)
             z_t = pyro.sample(f"z_{t}", z_dist)
         else:
             # when no normalizing flow used, ".to_event(1)" indicates latent dimensions are independent
             z_t = pyro.sample(f"z_{t}", z_dist.to_event(1))
     return z_t
Example #10
    def model(self):
        thickness, intensity = self.pgm_model()

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform(self.x_transforms).inv)

        x = pyro.sample('x', x_dist)

        return x, thickness, intensity
    def pgm_model(self):
        thickness_base_dist = Normal(self.thickness_base_loc,
                                     self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(
            thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc,
                                     self.intensity_base_scale).to_event(1)
        intensity_dist = TransformedDistribution(
            intensity_base_dist, self.intensity_flow_transforms)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to intensity_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity
def model(n_samples=None, scale=2.):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        loc = (thickness - 2.5) * 2

        transforms = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])

        width = pyro.sample('width', TransformedDistribution(Normal(loc, scale), transforms))

    return thickness, width
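Because the width site composes SigmoidTransform with AffineTransform(10, 15), every sample is squashed into (10, 25) regardless of loc and scale. A short usage sketch, assuming the model above is in scope:

# draw a batch of joint samples from the model
thickness, width = model(n_samples=1000)
assert thickness.shape == (1000,) and width.shape == (1000,)
# width = 10 + 15 * sigmoid(Normal(loc, scale)), so it always lies in (10, 25)
assert ((width > 10.) & (width < 25.)).all()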
Example #13
File: vi_utils.py Project: bkmi/sbi
def check_parameters_modules_attribute(q: TransformedDistribution) -> None:
    """Checks a parameterized distribution object for valid `parameters` and `modules`.

    Args:
        q: Distribution object
    """

    if not hasattr(q, "parameters"):
        raise ValueError(
            """The variational distribution requires an `parameters` attribute, which
            returns an iterable of parameters""")
    else:
        assert isinstance(q.parameters,
                          Callable), "The parameters must be callable"
        parameters = q.parameters()
        assert isinstance(
            parameters,
            Iterable), "The parameters return value must be iterable"
        trainable = 0
        for p in parameters:
            assert isinstance(p, torch.Tensor)
            if p.requires_grad:
                trainable += 1
        assert (
            trainable > 0
        ), """Nothing to train, atleast one of the parameters must have an enabled
            gradient."""
    if not hasattr(q, "modules"):
        raise ValueError(
            """The variational distribution requires an modules attribute, which returns
            an iterable of parameters.""")
    else:
        assert isinstance(q.modules,
                          Callable), "The parameters must be callable"
        modules = q.modules()
        assert isinstance(
            modules, Iterable), "The parameters return value must be iterable"
        for m in modules:
            assert isinstance(
                m, Module), "The modules must contain PyTorch Module objects"
Example #14
    def pgm_model(self):
        sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)
        # pseudo call to register with pyro
        _ = self.sex_logits
        sex = pyro.sample('sex', sex_dist, infer=dict(baseline={'use_decaying_avg_baseline': True}))

        slice_number_dist = Uniform(self.slice_number_min, self.slice_number_max).to_event(1)
        slice_number = pyro.sample('slice_number', slice_number_dist)

        age_base_dist = Normal(self.age_base_loc, self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist, self.age_flow_transforms)
        _ = self.age_flow_components
        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)

        duration_context = torch.cat([sex, age_], 1)
        duration_base_dist = Normal(self.duration_base_loc, self.duration_base_scale).to_event(1)
        duration_dist = ConditionalTransformedDistribution(duration_base_dist, self.duration_flow_transforms).condition(duration_context)  # noqa: E501
        duration = pyro.sample('duration', duration_dist)
        _ = self.duration_flow_components
        duration_ = self.duration_flow_constraint_transforms.inv(duration)

        edss_context = torch.cat([sex, duration_], 1)
        edss_base_dist = Normal(self.edss_base_loc, self.edss_base_scale).to_event(1)
        edss_dist = ConditionalTransformedDistribution(edss_base_dist, self.edss_flow_transforms).condition(edss_context)  # noqa: E501
        edss = pyro.sample('edss', edss_dist)
        _ = self.edss_flow_components
        edss_ = self.edss_flow_constraint_transforms.inv(edss)

        brain_context = torch.cat([sex, age_], 1)
        brain_volume_base_dist = Normal(self.brain_volume_base_loc, self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(brain_volume_base_dist, self.brain_volume_flow_transforms).condition(brain_context)
        _ = self.brain_volume_flow_components
        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(brain_volume)

        ventricle_context = torch.cat([age_, brain_volume_, duration_], 1)
        ventricle_volume_base_dist = Normal(self.ventricle_volume_base_loc, self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(ventricle_volume_base_dist, self.ventricle_volume_flow_transforms).condition(ventricle_context)  # noqa: E501
        ventricle_volume = pyro.sample('ventricle_volume', ventricle_volume_dist)
        _ = self.ventricle_volume_flow_components
        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(ventricle_volume)

        lesion_context = torch.cat([brain_volume_, ventricle_volume_, duration_, edss_], 1)
        lesion_volume_base_dist = Normal(self.lesion_volume_base_loc, self.lesion_volume_base_scale).to_event(1)
        lesion_volume_dist = ConditionalTransformedDistribution(lesion_volume_base_dist, self.lesion_volume_flow_transforms).condition(lesion_context)
        lesion_volume = pyro.sample('lesion_volume', lesion_volume_dist)
        _ = self.lesion_volume_flow_components

        return dict(age=age, sex=sex, ventricle_volume=ventricle_volume, brain_volume=brain_volume,
                    lesion_volume=lesion_volume, duration=duration, edss=edss, slice_number=slice_number)
Example #15
    def transform(self, t: Transform):
        """Performs a transformation on the distribution underlying the RV.

        :param t: The transformation (or sequence of transformations) to be
            applied to the distribution. There are many examples to be found in
            `torch.distributions.transforms` and `pyro.distributions.transforms`,
            or you can subclass directly from `Transform`.
        :type t: ~pyro.distributions.transforms.Transform

        :return: The transformed `RandomVariable`
        :rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable
        """
        dist = TransformedDistribution(self.distribution, t)
        return RandomVariable(dist)
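A short usage sketch for the transform method, assuming Pyro's contrib.randomvariable module: applying ExpTransform to a standard-normal RandomVariable yields a log-normally distributed one.

import torch
from pyro.distributions import Normal
from pyro.distributions.transforms import ExpTransform
from pyro.contrib.randomvariable import RandomVariable

x = RandomVariable(Normal(torch.zeros(3), torch.ones(3)))
y = x.transform(ExpTransform())        # exp of a standard normal, i.e. log-normal
sample = y.distribution.sample()
assert (sample > 0).all()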
Example #16
    def pgm_model(self):
        sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)

        _ = self.sex_logits

        sex = pyro.sample('sex', sex_dist)

        age_base_dist = Normal(self.age_base_loc,
                               self.age_base_scale).to_event(1)
        age_dist = TransformedDistribution(age_base_dist,
                                           self.age_flow_transforms)

        age = pyro.sample('age', age_dist)
        age_ = self.age_flow_constraint_transforms.inv(age)
        # pseudo call to age_flow_components to register with pyro
        _ = self.age_flow_components

        brain_context = torch.cat([sex, age_], 1)

        brain_volume_base_dist = Normal(
            self.brain_volume_base_loc,
            self.brain_volume_base_scale).to_event(1)
        brain_volume_dist = ConditionalTransformedDistribution(
            brain_volume_base_dist,
            self.brain_volume_flow_transforms).condition(brain_context)

        brain_volume = pyro.sample('brain_volume', brain_volume_dist)
        # pseudo call to brain_volume_flow_components to register with pyro
        _ = self.brain_volume_flow_components

        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)

        ventricle_context = torch.cat([age_, brain_volume_], 1)

        ventricle_volume_base_dist = Normal(
            self.ventricle_volume_base_loc,
            self.ventricle_volume_base_scale).to_event(1)
        ventricle_volume_dist = ConditionalTransformedDistribution(
            ventricle_volume_base_dist,
            self.ventricle_volume_flow_transforms).condition(
                ventricle_context)  # noqa: E501

        ventricle_volume = pyro.sample('ventricle_volume',
                                       ventricle_volume_dist)
        # pseudo call to ventricle_volume_flow_components to register with pyro
        _ = self.ventricle_volume_flow_components

        return age, sex, ventricle_volume, brain_volume
 def getDistribution(self,
                     z_mean,
                     z_sig,
                     cond_input,
                     use_cached_flows=False,
                     extra_cond=True):
      # Returns either a diagonal Gaussian or the flow-transformed distribution.
      # extra_cond is an additional condition that must hold to use the transformed distribution.
     base_dist = Normal(z_mean, z_sig).to_event(1)
     if len(self.nf) > 0 and extra_cond and not self.use_cond_flow:
         dist = TransformedDistribution(base_dist, self.nf)
     elif len(self.nf) > 0 and self.use_cond_flow and extra_cond:
          # conditioned flows here default to event_dim=0 (which is wrong), so fix it below
         if (use_cached_flows and self.cached_flows is not None):
             flows = self.cached_flows
         else:
             flows = [nf.condition(cond_input) for nf in self.nf]
             for f in flows:
                 f.event_dim = 1
             self.cached_flows = flows
         dist = TransformedDistribution(base_dist, flows)
     else:
         dist = base_dist
     return dist
Example #18
def model(n_samples=None, scale=0.5, invert=False):
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))

        if invert:
            loc = (thickness - 2) * -2
        else:
            loc = (thickness - 2.5) * 2

        transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])

        intensity = pyro.sample(
            'intensity', TransformedDistribution(Normal(loc, scale),
                                                 transforms))

    return thickness, intensity
Example #19
    def model(self):
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(
            ConditionalTransformedDistribution(
                x_base_dist,
                self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity
Example #20
    def model(self):
        obs = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            obs['ventricle_volume'])
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            obs['brain_volume'])
        lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
            obs['lesion_volume'])
        slice_number = obs['slice_number']
        ctx = torch.cat(
            [ventricle_volume_, brain_volume_, lesion_volume_, slice_number],
            1)

        z = []
        for i in range(self.n_levels):
            last_layer = self.hierarchical_layers[i] == self.last_layer
            if last_layer:
                z_base_dist = Normal(self.z_loc, self.z_scale).to_event(1)
                z_dist = TransformedDistribution(
                    z_base_dist, self.prior_flow_transforms
                ) if self.use_prior_flow else z_base_dist
                _ = self.prior_affine
                _ = self.prior_flow_components
            else:
                z_probs = getattr(self, f'z_probs_{i}')
                temperature = torch.tensor(self.temperature,
                                           device=ctx.device,
                                           requires_grad=False)
                z_dist = RelaxedBernoulliStraightThrough(
                    temperature, probs=z_probs).to_event(3)
            with poutine.scale(scale=self.annealing_factor[i]):
                z.append(pyro.sample(f'z{i}', z_dist))
        z[-1] = torch.cat([z[-1], ctx], 1)

        x_dist = self._get_transformed_x_dist(z, ctx)  # run decoder
        x = pyro.sample('x', x_dist)

        obs.update(dict(x=x, z=z))
        return obs
Example #21
    def _get_transformed_x_dist(self, latent):
        x_pred_dist = self.decoder.predict(latent)
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()

        if isinstance(x_pred_dist, MultivariateNormal) or isinstance(
                x_pred_dist, LowRankMultivariateNormal):
            chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                                 x_pred_dist.scale_tril)
            reshape_transform = ReshapeTransform(self.img_shape,
                                                 (np.prod(self.img_shape), ))
            x_reparam_transform = ComposeTransform(
                [reshape_transform, chol_transform, reshape_transform.inv])
        elif isinstance(x_pred_dist, Independent):
            x_pred_dist = x_pred_dist.base_dist
            x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                                  x_pred_dist.scale, 3)
        else:
            raise ValueError(f'{x_pred_dist} not valid.')

        return TransformedDistribution(
            x_base_dist,
            ComposeTransform([x_reparam_transform, preprocess_transform]))
Example #22
    def guide(self, obs):
        batch_size = obs['x'].shape[0]
        with pyro.plate('observations', batch_size):
            hidden = self.encoder(obs['x'])

            ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
                obs['ventricle_volume'])
            brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
                obs['brain_volume'])
            lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
                obs['lesion_volume'])
            slice_number = obs['slice_number']
            ctx = torch.cat([
                ventricle_volume_, brain_volume_, lesion_volume_, slice_number
            ], 1)

            z = []
            layers = zip(self.latent_encoder, self.guide_ctx_attn,
                         self.hierarchical_layers)
            for i, (latent_enc, ctx_attn, layer) in enumerate(layers):
                last_layer = layer == self.last_layer
                if last_layer:
                    hidden_i = torch.cat([hidden[i], ctx], 1)
                    z_base_dist = latent_enc.predict(hidden_i)
                    z_dist = TransformedDistribution(
                        z_base_dist, self.posterior_flow_transforms
                    ) if self.use_posterior_flow else z_base_dist
                    _ = self.posterior_affine
                    _ = self.posterior_flow_components
                else:
                    ctx_ = ctx_attn(ctx).view(batch_size, -1, 1, 1)
                    hidden_i = hidden[i] * ctx_
                    z_dist = latent_enc.predict(hidden_i)
                with poutine.scale(scale=self.annealing_factor[i]):
                    z.append(pyro.sample(f'z{i}', z_dist))

        return z
Example #23
    def model(self):
        age, sex, ventricle_volume, brain_volume = self.pgm_model()

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            ventricle_volume)
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            brain_volume)
        age_ = self.age_flow_constraint_transforms.inv(age)

        z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

        latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

        x_loc = self.decoder_mean(self.decoder(latent))
        x_scale = torch.exp(self.decoder_logstd)

        theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            grid = nn.functional.affine_grid(theta, x_loc.size())
            x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

        preprocess_transform = self._get_preprocess_transforms()
        x_dist = TransformedDistribution(
            x_base_dist,
            ComposeTransform([
                AffineTransform(x_loc_deformed, x_scale, 3),
                preprocess_transform
            ]))

        x = pyro.sample('x', x_dist)

        return x, z, age, sex, ventricle_volume, brain_volume
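Before it is used as the location of the AffineTransform, the decoder mean above is warped by a spatial-transformer step (affine_grid followed by grid_sample) driven by the predicted theta. A minimal sketch of just that warping step, with made-up shapes and an identity theta:

import torch
import torch.nn as nn

batch = 2
x_loc = torch.rand(batch, 1, 64, 64)   # stand-in for the decoder_mean output

# identity affine parameters of shape (N, 2, 3); the model above predicts these per sample
theta = torch.tensor([[1., 0., 0.], [0., 1., 0.]]).repeat(batch, 1, 1)

grid = nn.functional.affine_grid(theta, x_loc.size(), align_corners=False)
x_loc_deformed = nn.functional.grid_sample(x_loc, grid, align_corners=False)
assert x_loc_deformed.shape == x_loc.shape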
Example #24
 def __abs__(self):
     return RandomVariable(
         TransformedDistribution(self.distribution, AbsTransform()))
Example #25
    def test_minibatch(which_mini_batch, shuffled_indices):

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                                 N_test_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                                  test_seq_lengths, cuda=args.cuda)

        # Get the initial RNN state.
        h_0 = dmm.h_0
        h_0_contig = h_0.expand(1, mini_batch.size(0),
                                dmm.rnn.hidden_size).contiguous()

        # Feed the test sequence into the RNN.
        rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

        # Reverse the time ordering of the hidden state and unpack it.
        rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
        print(rnn_output)
        print(rnn_output.shape)

        # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
        z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

        # sample the latents z one time step at a time
        T_max = mini_batch.size(1)
        sequence_output = []
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))

            print("z_{}:".format(t), z_t)
            print(z_t.shape)

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = dmm.emitter(z_t)

            emission_probs_t_np = emission_probs_t.detach().numpy()
            sequence_output.append(emission_probs_t_np)

            print("x_{}:".format(t), emission_probs_t)
            print(emission_probs_t.shape)

            # the latent sampled at this time step will be conditioned upon in the next time step
            # so keep track of it
            z_prev = z_t

        # Run the model another few steps.
        n_steps = 100
        for t in range(1, n_steps + 1):
            # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
            z_loc, z_scale = dmm.trans(z_prev)

            # then sample z_t according to dist.Normal(z_loc, z_scale)
            # note that we use the reshape method so that the univariate Normal distribution
            # is treated as a multivariate Normal distribution with a diagonal covariance.
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    dist.Normal(z_loc, z_scale).mask(
                        mini_batch_mask[:, t - 1:t]).to_event(1))

            # compute the probabilities that parameterize the bernoulli likelihood
            emission_probs_t = dmm.emitter(z_t)

            emission_probs_t_np = emission_probs_t.detach().numpy()
            sequence_output.append(emission_probs_t_np)

            # # the next statement instructs pyro to observe x_t according to the
            # # bernoulli distribution p(x_t|z_t)
            # pyro.sample("obs_x_%d" % t,
            #             # dist.Bernoulli(emission_probs_t)
            #             dist.Normal(emission_probs_t, 0.5)
            #             .mask(mini_batch_mask[:, t - 1:t])
            #             .to_event(1),
            #             obs=mini_batch[:, t - 1, :])

            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t

        sequence_output = np.concatenate(sequence_output, axis=1)
        print(sequence_output.shape)

        n_plots = 5
        fig, axes = plt.subplots(nrows=n_plots, ncols=1)
        x = range(sequence_output.shape[1])
        for i in range(n_plots):
            input = mini_batch[i, :].numpy().squeeze()
            output = sequence_output[i, :]
            axes[i].plot(range(input.shape[0]), input)
            axes[i].plot(range(len(output)), output)
            axes[i].grid()

        # plt.plot(sequence_output[0, :])
        plt.show()
Example #26
 def __rsub__(self, x: Union[float, Tensor]):
     return RandomVariable(
         TransformedDistribution(self.distribution, AffineTransform(x, -1)))
Example #27
def test_minibatch(dmm, mini_batch, args, sample_z=True):

    # Generate data that we can feed into the below fn.
    test_data_sequences = mini_batch.type(torch.FloatTensor)
    mini_batch_indices = torch.arange(0, test_data_sequences.size(0))
    test_seq_lengths = torch.full(
        (test_data_sequences.size(0), ),
        test_data_sequences.size(1)).type(torch.IntTensor)

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                              test_seq_lengths, cuda=args.cuda)

    # Get the initial RNN state.
    h_0 = dmm.h_0
    h_0_contig = h_0.expand(1, mini_batch.size(0),
                            dmm.rnn.hidden_size).contiguous()

    # Feed the test sequence into the RNN.
    rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

    # Reverse the time ordering of the hidden state and unpack it.
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # print(rnn_output)
    # print(rnn_output.shape)

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

    # sample the latents z one time step at a time
    T_max = mini_batch.size(1)
    sequence_z = []
    sequence_output = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

        if sample_z:
            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))
        else:
            z_t = z_loc

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)

        # print("z_{}:".format(t), z_t)
        # print(z_t.shape)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)

        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # print("x_{}:".format(t), emission_probs_t)
        # print(emission_probs_t.shape)

        # the latent sampled at this time step will be conditioned upon in the next time step
        # so keep track of it
        z_prev = z_t

    # Run the model another few steps.
    n_extra_steps = 100
    for t in range(1, n_extra_steps + 1):
        # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
        z_loc, z_scale = dmm.trans(z_prev)

        # then sample z_t according to dist.Normal(z_loc, z_scale)
        # note that we use the reshape method so that the univariate Normal distribution
        # is treated as a multivariate Normal distribution with a diagonal covariance.
        annealing_factor = 1.0
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                dist.Normal(z_loc, z_scale)
                # .mask(mini_batch_mask[:, t - 1:t])
                .to_event(1))

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)

        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t

    sequence_z = np.concatenate(sequence_z, axis=1)
    sequence_output = np.concatenate(sequence_output, axis=1)
    # print(sequence_output.shape)

    # n_plots = 5
    # fig, axes = plt.subplots(nrows=n_plots, ncols=1)
    # x = range(sequence_output.shape[1])
    # for i in range(n_plots):
    #     input = mini_batch[i, :].numpy().squeeze()
    #     output = sequence_output[i, :]
    #     axes[i].plot(range(input.shape[0]), input)
    #     axes[i].plot(range(len(output)), output)
    #     axes[i].grid()

    return mini_batch, sequence_z, sequence_output  #fig
Example #28
 def __truediv__(self, x: Union[float, Tensor]):
     return RandomVariable(
         TransformedDistribution(self.distribution,
                                 AffineTransform(0, 1 / x)))
Example #29
 def __neg__(self):
     return RandomVariable(
         TransformedDistribution(self.distribution, AffineTransform(0, -1)))
Example #30
 def unconstrained_prior(self) -> TransformedDistribution:
     return TransformedDistribution(self(), self.bijection.inv)
Example #31
 def __pow__(self, x):
     return RandomVariable(
         TransformedDistribution(self.distribution, PowerTransform(x)))
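Each of these operator overloads wraps the underlying distribution in another TransformedDistribution, so RandomVariable objects can be manipulated with ordinary Python arithmetic. A short usage sketch relying only on the operators shown above:

import torch
from pyro.distributions import Normal
from pyro.contrib.randomvariable import RandomVariable

x = RandomVariable(Normal(torch.zeros(4), torch.ones(4)))

# __neg__ applies AffineTransform(0, -1) and __abs__ stacks AbsTransform on top,
# so abs(-x) is half-normal and every sample is non-negative
y = abs(-x)
assert (y.distribution.sample() >= 0).all()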