def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):

    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # we enclose all the sample statements in the guide in a plate.
    # this marks that each datapoint is conditionally independent of the others.
    with pyro.plate("z_minibatch", len(mini_batch)):
        # sample the latents z one time step at a time
        # we wrap this loop in pyro.markov so that TraceEnum_ELBO can use
        # multiple samples from the guide at each z
        for t in pyro.markov(range(1, T_max + 1)):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                assert z_dist.event_shape == (self.z_q_0.size(0),)
                assert z_dist.batch_shape[-1:] == (len(mini_batch),)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
                assert z_dist.event_shape == ()
                assert z_dist.batch_shape[-2:] == (len(mini_batch), self.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                if len(self.iafs) > 0:
                    # in output of normalizing flow, all dimensions are correlated
                    # (event shape is not empty)
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1]))
                else:
                    # when no normalizing flow used, ".to_event(1)" indicates
                    # latent dimensions are independent
                    z_t = pyro.sample("z_%d" % t,
                                      z_dist.mask(mini_batch_mask[:, t - 1:t])
                                      .to_event(1))
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
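# A minimal sketch (not part of the original class) of how a `self.iafs` stack
# like the one used above can be constructed, following the pattern of Pyro's
# DMM example. The sizes `z_dim`, `iaf_dim`, and `num_iafs` are hypothetical
# placeholders.
import torch.nn as nn
from pyro.distributions.transforms import affine_autoregressive

z_dim, iaf_dim, num_iafs = 100, 50, 2
iafs = [affine_autoregressive(z_dim, hidden_dims=[iaf_dim]) for _ in range(num_iafs)]
# wrapping the transforms in an nn.ModuleList makes their parameters visible
# to the pyro.module("dmm", self) registration in the guide
iafs_modules = nn.ModuleList(iafs)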
def __init__(self, use_affine_ex=True, **kwargs):
    super().__init__(**kwargs)

    self.num_scales = 2

    self.register_buffer("glasses_base_loc",
                         torch.zeros([1, ], requires_grad=False))
    self.register_buffer("glasses_base_scale",
                         torch.ones([1, ], requires_grad=False))

    self.register_buffer("glasses_flow_lognorm_loc",
                         torch.zeros([], requires_grad=False))
    self.register_buffer("glasses_flow_lognorm_scale",
                         torch.ones([], requires_grad=False))

    self.glasses_flow_components = ComposeTransformModule([Spline(1)])
    self.glasses_flow_constraint_transforms = ComposeTransform(
        [self.glasses_flow_lognorm, SigmoidTransform()])
    self.glasses_flow_transforms = ComposeTransform([
        self.glasses_flow_components, self.glasses_flow_constraint_transforms
    ])

    glasses_base_dist = Normal(self.glasses_base_loc,
                               self.glasses_base_scale).to_event(1)
    self.glasses_dist = TransformedDistribution(glasses_base_dist,
                                                self.glasses_flow_transforms)

    self.x_transforms = self._build_image_flow()
    self.register_buffer("x_base_loc",
                         torch.zeros([1, 64, 64], requires_grad=False))
    self.register_buffer("x_base_scale",
                         torch.ones([1, 64, 64], requires_grad=False))

def model(self):
    # the generative part: sampling happens per call, not at construction time
    glasses_ = pyro.sample("glasses_", self.glasses_dist)
    glasses = pyro.sample("glasses", dist.Bernoulli(glasses_))
    glasses_context = self.glasses_flow_constraint_transforms.inv(glasses_)

    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    cond_x_transforms = ComposeTransform(
        ConditionalTransformedDistribution(
            x_base_dist,
            self.x_transforms).condition(glasses_context).transforms).inv
    cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

    x = pyro.sample("x", cond_x_dist)

    return x, glasses
def _get_transformed_x_dist(self, latent, ctx=None):
    # returns a normal dist with mean of the predicted image
    x_pred_dist = self.decoder.predict(latent, ctx)
    if self.laplace_likelihood:
        x_base_dist = Laplace(self.x_base_loc, self.x_base_scale).to_event(3)
    else:
        # 3 event dimensions, counted from the right
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

    preprocess_transform = self._get_preprocess_transforms()

    if isinstance(x_pred_dist, (MultivariateNormal, LowRankMultivariateNormal)):
        chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                             x_pred_dist.scale_tril)
        reshape_transform = ReshapeTransform(self.img_shape,
                                             (np.prod(self.img_shape), ))
        x_reparam_transform = ComposeTransform(
            [reshape_transform, chol_transform, reshape_transform.inv])
    elif isinstance(x_pred_dist, Independent):
        x_pred_dist = x_pred_dist.base_dist
        x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                              x_pred_dist.scale, 3)
    else:
        raise ValueError(f'{x_pred_dist} not valid.')

    return TransformedDistribution(
        x_base_dist,
        ComposeTransform([x_reparam_transform, preprocess_transform]))
def guide(self, obs):
    batch_size = obs['x'].shape[0]

    with pyro.plate('observations', batch_size):
        hidden = self.encoder(obs['x'])

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            obs['ventricle_volume'])
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            obs['brain_volume'])
        lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
            obs['lesion_volume'])
        slice_number = obs['slice_number']
        ctx = torch.cat([ventricle_volume_, brain_volume_,
                         lesion_volume_, slice_number], 1)
        hidden = torch.cat([hidden, ctx], 1)

        z_base_dist = self.latent_encoder.predict(hidden)
        z_dist = TransformedDistribution(
            z_base_dist, self.posterior_flow_transforms
        ) if self.use_posterior_flow else z_base_dist
        # pseudo calls to register the flow parameters with pyro
        _ = self.posterior_affine
        _ = self.posterior_flow_components
        with poutine.scale(scale=self.annealing_factor[-1]):
            z = pyro.sample('z', z_dist)

    return z
def model(self):
    obs = self.pgm_model()

    ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
        obs['ventricle_volume'])
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
        obs['brain_volume'])
    lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
        obs['lesion_volume'])
    slice_number = obs['slice_number']
    ctx = torch.cat(
        [ventricle_volume_, brain_volume_, lesion_volume_, slice_number], 1)

    if self.prior_components > 1:
        # the z_scale parameter is stored as a logvar
        z_scale = (0.5 * self.z_scale).exp() + 1e-5
        z_base_dist = MixtureOfDiagNormalsSharedCovariance(
            self.z_loc, z_scale, self.z_components).to_event(0)
    else:
        z_base_dist = Normal(self.z_loc, self.z_scale).to_event(1)
    z_dist = TransformedDistribution(
        z_base_dist,
        self.prior_flow_transforms) if self.use_prior_flow else z_base_dist
    # pseudo calls to register the flow parameters with pyro
    _ = self.prior_affine
    _ = self.prior_flow_components
    with poutine.scale(scale=self.annealing_factor[-1]):
        z = pyro.sample('z', z_dist)

    latent = torch.cat([z, ctx], 1)

    x_dist = self._get_transformed_x_dist(latent)  # run decoder
    x = pyro.sample('x', x_dist)

    obs.update(dict(x=x, z=z))
    return obs
def guide(self, response, mask, annealing_factor=1):
    pyro.module("item_encoder", self.item_encoder)
    pyro.module("ability_encoder", self.ability_encoder)
    device = response.device

    item_domain = torch.arange(self.num_item).unsqueeze(1).to(device)
    item_feat_mu, item_feat_logvar = self.item_encoder(item_domain)
    item_feat_scale = torch.exp(0.5 * item_feat_logvar)

    with poutine.scale(scale=annealing_factor):
        item_feat = pyro.sample(
            "item_feat",
            dist.Normal(item_feat_mu, item_feat_scale),
        )

    if self.conditional_posterior:
        ability_mu, ability_logvar = self.ability_encoder(response, mask, item_feat)
    else:
        ability_mu, ability_logvar = self.ability_encoder(response, mask)

    ability_scale = torch.exp(0.5 * ability_logvar)
    ability_dist = dist.Normal(ability_mu, ability_scale)

    if self.num_iafs > 0:
        ability_dist = TransformedDistribution(ability_dist, self.iafs)

    with poutine.scale(scale=annealing_factor):
        ability = pyro.sample("ability", ability_dist)

    return ability_mu, ability_logvar, item_feat_mu, item_feat_logvar
def adapt_variational_distribution(
    q: TransformedDistribution,
    prior: Distribution,
    link_transform: Callable,
    parameters: Iterable = [],
    modules: Iterable = [],
) -> Distribution:
    """Adapts a distribution to be compatible with DivergenceOptimizers.

    In particular, this makes sure that the distribution has parameters and
    that it satisfies the obvious constraints a posterior must satisfy, i.e.
    its support must equal that of the prior.

    Args:
        q: Variational distribution.
        prior: Prior distribution.
        link_transform: Transformation onto the prior's support.
        parameters: List of parameters.
        modules: List of modules.

    Returns:
        TransformedDistribution: Compatible variational distribution.
    """

    # Expose the user-defined parameters and modules through callables
    def parameters_fn():
        """Returns the parameters of the distribution."""
        return parameters

    def modules_fn():
        """Returns the modules of the distribution."""
        return modules

    if isinstance(q, TransformedDistribution):
        if parameters == [] and modules == []:
            add_parameter_attributes_to_transformed_distribution(q)
        else:
            add_parameters_module_attributes(q, parameters_fn, modules_fn)
        if hasattr(prior, "support") and q.support != prior.support:
            q = TransformedDistribution(q.base_dist,
                                        q.transforms + [link_transform])
    else:
        if hasattr(prior, "support") and q.support != prior.support:
            q = TransformedDistribution(q, [link_transform])
        add_parameters_module_attributes(q, parameters_fn, modules_fn)

    return q
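# A minimal usage sketch (under assumptions, not this codebase's API): a
# natural choice for `link_transform` is the bijection onto the prior's
# support, which torch.distributions exposes via `biject_to`. Here an
# unconstrained Normal is pushed onto the support of a positive-valued prior.
import torch
from torch.distributions import biject_to, Gamma, Normal, TransformedDistribution

prior = Gamma(2.0, 2.0)
link_transform = biject_to(prior.support)  # exp-like bijection onto (0, inf)
loc = torch.zeros(1, requires_grad=True)
q = TransformedDistribution(Normal(loc, torch.ones(1)), [link_transform])
x = q.rsample()
assert (x > 0).all()  # samples now live on the prior's support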
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):

    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # if on gpu we need the fully broadcast view of the rnn initial state
    # to be in contiguous gpu memory
    h_0_contig = self.h_0.expand(1, mini_batch.size(0),
                                 self.rnn.hidden_size).contiguous()
    # push the observed x's through the rnn;
    # rnn_output contains the hidden state at each time step
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # reverse the time-ordering in the hidden state and un-pack it
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))

    # we enclose all the sample statements in the guide in an iarange.
    # this marks that each datapoint is conditionally independent of the others.
    with pyro.iarange("z_minibatch", len(mini_batch)):
        # sample the latents z one time step at a time
        for t in range(1, T_max + 1):
            # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), self.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample("z_%d" % t,
                                  z_dist.mask(mini_batch_mask[:, t - 1:t])
                                  .independent(1))
            # the latent sampled at this time step will be conditioned upon
            # in the next time step so keep track of it
            z_prev = z_t
def sample_latent_space(self, annealing_factor, batch_size, t, z_loc_q, z_scale_q):
    if len(self.iafs) > 0:
        z_dist = TransformedDistribution(dist.Normal(z_loc_q, z_scale_q),
                                         self.iafs)
        assert z_dist.event_shape == (self.z_q_0.size(0), )
        assert z_dist.batch_shape[-1:] == (batch_size, )
    else:
        z_dist = dist.Normal(z_loc_q, z_scale_q)
        assert z_dist.event_shape == ()
        assert z_dist.batch_shape[-2:] == (batch_size, self.z_q_0.size(0))

    with pyro.poutine.scale(scale=annealing_factor):
        if len(self.iafs) > 0:
            # in output of normalizing flow, all dimensions are correlated
            # (event shape is not empty)
            z_t = pyro.sample(f"z_{t}", z_dist)
        else:
            # when no normalizing flow used, ".to_event(1)" indicates latent
            # dimensions are independent
            z_t = pyro.sample(f"z_{t}", z_dist.to_event(1))

    return z_t
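# A quick shape check (sketch, not part of the class above) motivating the
# asserts: with an IAF the latent dimensions become correlated, so the latent
# dimension moves into event_shape; without it, the Normal stays batched over
# the latent dimension until .to_event(1) reinterprets it. `batch_size` and
# `z_dim` are placeholders.
import torch
import pyro.distributions as dist
from pyro.distributions import TransformedDistribution
from pyro.distributions.transforms import affine_autoregressive

batch_size, z_dim = 16, 8
loc, scale = torch.zeros(batch_size, z_dim), torch.ones(batch_size, z_dim)

flow_dist = TransformedDistribution(dist.Normal(loc, scale),
                                    [affine_autoregressive(z_dim)])
assert flow_dist.event_shape == (z_dim,)      # correlated latent dimensions
assert flow_dist.batch_shape == (batch_size,)

plain_dist = dist.Normal(loc, scale)
assert plain_dist.batch_shape == (batch_size, z_dim)
assert plain_dist.to_event(1).event_shape == (z_dim,)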
def model(self):
    thickness, intensity = self.pgm_model()

    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    x_dist = TransformedDistribution(
        x_base_dist, ComposeTransform(self.x_transforms).inv)

    x = pyro.sample('x', x_dist)

    return x, thickness, intensity
def pgm_model(self):
    thickness_base_dist = Normal(self.thickness_base_loc,
                                 self.thickness_base_scale).to_event(1)
    thickness_dist = TransformedDistribution(thickness_base_dist,
                                             self.thickness_flow_transforms)

    thickness = pyro.sample('thickness', thickness_dist)
    # pseudo call to thickness_flow_transforms to register with pyro
    _ = self.thickness_flow_components

    intensity_base_dist = Normal(self.intensity_base_loc,
                                 self.intensity_base_scale).to_event(1)
    intensity_dist = TransformedDistribution(intensity_base_dist,
                                             self.intensity_flow_transforms)

    intensity = pyro.sample('intensity', intensity_dist)
    # pseudo call to intensity_flow_transforms to register with pyro
    _ = self.intensity_flow_components

    return thickness, intensity
def model(n_samples=None, scale=2.):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))

        loc = (thickness - 2.5) * 2

        transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(10, 15)])

        width = pyro.sample(
            'width', TransformedDistribution(Normal(loc, scale), transforms))

    return thickness, width
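# Usage sketch for the model above: draw samples, then apply an intervention
# with pyro.poutine.do, which fixes 'thickness' regardless of its prior. The
# intervened value 3.0 is an arbitrary illustrative choice.
import torch
import pyro
from pyro.poutine import do

thickness, width = model(n_samples=100)

intervened_model = do(model, data={'thickness': torch.tensor(3.0)})
_, width_do = intervened_model(n_samples=100)  # width under do(thickness=3.0)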
def check_parameters_modules_attribute(q: TransformedDistribution) -> None:
    """Checks a parameterized distribution object for valid `parameters` and
    `modules`.

    Args:
        q: Distribution object
    """
    if not hasattr(q, "parameters"):
        raise ValueError(
            """The variational distribution requires a `parameters` attribute,
            which returns an iterable of parameters""")
    else:
        assert isinstance(q.parameters, Callable), "The parameters attribute must be callable"
        parameters = q.parameters()
        assert isinstance(parameters, Iterable), "The parameters return value must be iterable"
        trainable = 0
        for p in parameters:
            assert isinstance(p, torch.Tensor)
            if p.requires_grad:
                trainable += 1
        assert trainable > 0, \
            "Nothing to train: at least one of the parameters must have gradients enabled."

    if not hasattr(q, "modules"):
        raise ValueError(
            """The variational distribution requires a `modules` attribute,
            which returns an iterable of modules.""")
    else:
        assert isinstance(q.modules, Callable), "The modules attribute must be callable"
        modules = q.modules()
        assert isinstance(modules, Iterable), "The modules return value must be iterable"
        for m in modules:
            assert isinstance(m, Module), "The modules must contain PyTorch Module objects"
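# A minimal sketch of a distribution object that passes the checks above. The
# `parameters` and `modules` attributes are attached as plain callables here,
# standing in for what add_parameters_module_attributes does in this codebase.
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

loc = torch.zeros(1, requires_grad=True)
q = TransformedDistribution(Normal(loc, torch.ones(1)), [ExpTransform()])
q.parameters = lambda: [loc]  # one trainable parameter
q.modules = lambda: []        # no submodules in this toy case
check_parameters_modules_attribute(q)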
def pgm_model(self):
    sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)
    # pseudo call to register with pyro
    _ = self.sex_logits
    sex = pyro.sample('sex', sex_dist,
                      infer=dict(baseline={'use_decaying_avg_baseline': True}))

    slice_number_dist = Uniform(self.slice_number_min,
                                self.slice_number_max).to_event(1)
    slice_number = pyro.sample('slice_number', slice_number_dist)

    age_base_dist = Normal(self.age_base_loc, self.age_base_scale).to_event(1)
    age_dist = TransformedDistribution(age_base_dist, self.age_flow_transforms)
    _ = self.age_flow_components
    age = pyro.sample('age', age_dist)
    age_ = self.age_flow_constraint_transforms.inv(age)

    duration_context = torch.cat([sex, age_], 1)
    duration_base_dist = Normal(self.duration_base_loc,
                                self.duration_base_scale).to_event(1)
    duration_dist = ConditionalTransformedDistribution(duration_base_dist, self.duration_flow_transforms).condition(duration_context)  # noqa: E501
    duration = pyro.sample('duration', duration_dist)
    _ = self.duration_flow_components
    duration_ = self.duration_flow_constraint_transforms.inv(duration)

    edss_context = torch.cat([sex, duration_], 1)
    edss_base_dist = Normal(self.edss_base_loc,
                            self.edss_base_scale).to_event(1)
    edss_dist = ConditionalTransformedDistribution(edss_base_dist, self.edss_flow_transforms).condition(edss_context)  # noqa: E501
    edss = pyro.sample('edss', edss_dist)
    _ = self.edss_flow_components
    edss_ = self.edss_flow_constraint_transforms.inv(edss)

    brain_context = torch.cat([sex, age_], 1)
    brain_volume_base_dist = Normal(self.brain_volume_base_loc,
                                    self.brain_volume_base_scale).to_event(1)
    brain_volume_dist = ConditionalTransformedDistribution(
        brain_volume_base_dist,
        self.brain_volume_flow_transforms).condition(brain_context)
    _ = self.brain_volume_flow_components
    brain_volume = pyro.sample('brain_volume', brain_volume_dist)
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(brain_volume)

    ventricle_context = torch.cat([age_, brain_volume_, duration_], 1)
    ventricle_volume_base_dist = Normal(
        self.ventricle_volume_base_loc,
        self.ventricle_volume_base_scale).to_event(1)
    ventricle_volume_dist = ConditionalTransformedDistribution(ventricle_volume_base_dist, self.ventricle_volume_flow_transforms).condition(ventricle_context)  # noqa: E501
    ventricle_volume = pyro.sample('ventricle_volume', ventricle_volume_dist)
    _ = self.ventricle_volume_flow_components
    ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(ventricle_volume)

    lesion_context = torch.cat([brain_volume_, ventricle_volume_,
                                duration_, edss_], 1)
    lesion_volume_base_dist = Normal(self.lesion_volume_base_loc,
                                     self.lesion_volume_base_scale).to_event(1)
    lesion_volume_dist = ConditionalTransformedDistribution(
        lesion_volume_base_dist,
        self.lesion_volume_flow_transforms).condition(lesion_context)
    lesion_volume = pyro.sample('lesion_volume', lesion_volume_dist)
    _ = self.lesion_volume_flow_components

    return dict(age=age, sex=sex, ventricle_volume=ventricle_volume,
                brain_volume=brain_volume, lesion_volume=lesion_volume,
                duration=duration, edss=edss, slice_number=slice_number)
def transform(self, t: Transform):
    """Performs a transformation on the distribution underlying the RV.

    :param t: The transformation (or sequence of transformations) to be
        applied to the distribution. There are many examples to be found in
        `torch.distributions.transforms` and `pyro.distributions.transforms`,
        or you can subclass directly from `Transform`.
    :type t: ~pyro.distributions.transforms.Transform

    :return: The transformed `RandomVariable`
    :rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable
    """
    dist = TransformedDistribution(self.distribution, t)
    return RandomVariable(dist)
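# Usage sketch for the method above, assuming Pyro's RandomVariable wrapper
# (pyro.contrib.randomvariable). ExpTransform maps the underlying Normal to a
# log-normal; the wrapped distribution is reachable via `.distribution`.
from pyro.contrib.randomvariable import RandomVariable
from pyro.distributions import Normal
from pyro.distributions.transforms import ExpTransform

X = RandomVariable(Normal(0., 1.))
Y = X.transform(ExpTransform())  # Y is log-normally distributed
sample = Y.distribution.sample()
assert sample > 0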
def pgm_model(self):
    sex_dist = Bernoulli(logits=self.sex_logits).to_event(1)
    # pseudo call to register sex_logits with pyro
    _ = self.sex_logits
    sex = pyro.sample('sex', sex_dist)

    age_base_dist = Normal(self.age_base_loc, self.age_base_scale).to_event(1)
    age_dist = TransformedDistribution(age_base_dist, self.age_flow_transforms)

    age = pyro.sample('age', age_dist)
    age_ = self.age_flow_constraint_transforms.inv(age)
    # pseudo call to age_flow_components to register with pyro
    _ = self.age_flow_components

    brain_context = torch.cat([sex, age_], 1)

    brain_volume_base_dist = Normal(
        self.brain_volume_base_loc, self.brain_volume_base_scale).to_event(1)
    brain_volume_dist = ConditionalTransformedDistribution(
        brain_volume_base_dist,
        self.brain_volume_flow_transforms).condition(brain_context)

    brain_volume = pyro.sample('brain_volume', brain_volume_dist)
    # pseudo call to brain_volume_flow_components to register with pyro
    _ = self.brain_volume_flow_components
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
        brain_volume)

    ventricle_context = torch.cat([age_, brain_volume_], 1)

    ventricle_volume_base_dist = Normal(
        self.ventricle_volume_base_loc,
        self.ventricle_volume_base_scale).to_event(1)
    ventricle_volume_dist = ConditionalTransformedDistribution(
        ventricle_volume_base_dist,
        self.ventricle_volume_flow_transforms).condition(ventricle_context)  # noqa: E501

    ventricle_volume = pyro.sample('ventricle_volume', ventricle_volume_dist)
    # pseudo call to ventricle_volume_flow_components to register with pyro
    _ = self.ventricle_volume_flow_components

    return age, sex, ventricle_volume, brain_volume
def getDistribution(self, z_mean, z_sig, cond_input,
                    use_cached_flows=False, extra_cond=True):
    # Gets either a multivariate Gaussian or the transformed distribution.
    # extra_cond is the extra condition under which to use the transformed
    # distribution.
    base_dist = Normal(z_mean, z_sig).to_event(1)
    if len(self.nf) > 0 and extra_cond and not self.use_cond_flow:
        dist = TransformedDistribution(base_dist, self.nf)
    elif len(self.nf) > 0 and self.use_cond_flow and extra_cond:
        # conditioning flows like this yields event_dim=0 by default, which is
        # wrong here, so we set event_dim=1 manually
        if use_cached_flows and self.cached_flows is not None:
            flows = self.cached_flows
        else:
            flows = [nf.condition(cond_input) for nf in self.nf]
            for f in flows:
                f.event_dim = 1
            self.cached_flows = flows
        dist = TransformedDistribution(base_dist, flows)
    else:
        dist = base_dist
    return dist
def model(n_samples=None, scale=0.5, invert=False):
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))

        if invert:
            loc = (thickness - 2) * -2
        else:
            loc = (thickness - 2.5) * 2

        transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])

        intensity = pyro.sample(
            'intensity', TransformedDistribution(Normal(loc, scale), transforms))

    return thickness, intensity
def model(self):
    thickness, intensity = self.pgm_model()

    thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
    intensity_ = self.intensity_flow_norm.inv(intensity)

    context = torch.cat([thickness_, intensity_], 1)

    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    cond_x_transforms = ComposeTransform(
        ConditionalTransformedDistribution(
            x_base_dist, self.x_transforms).condition(context).transforms).inv
    cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

    x = pyro.sample('x', cond_x_dist)

    return x, thickness, intensity
def model(self):
    obs = self.pgm_model()

    ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
        obs['ventricle_volume'])
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
        obs['brain_volume'])
    lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
        obs['lesion_volume'])
    slice_number = obs['slice_number']
    ctx = torch.cat(
        [ventricle_volume_, brain_volume_, lesion_volume_, slice_number], 1)

    z = []
    for i in range(self.n_levels):
        last_layer = self.hierarchical_layers[i] == self.last_layer
        if last_layer:
            z_base_dist = Normal(self.z_loc, self.z_scale).to_event(1)
            z_dist = TransformedDistribution(
                z_base_dist, self.prior_flow_transforms
            ) if self.use_prior_flow else z_base_dist
            # pseudo calls to register the flow parameters with pyro
            _ = self.prior_affine
            _ = self.prior_flow_components
        else:
            z_probs = getattr(self, f'z_probs_{i}')
            temperature = torch.tensor(self.temperature,
                                       device=ctx.device,
                                       requires_grad=False)
            z_dist = RelaxedBernoulliStraightThrough(
                temperature, probs=z_probs).to_event(3)
        with poutine.scale(scale=self.annealing_factor[i]):
            z.append(pyro.sample(f'z{i}', z_dist))

    # append the context to the last (deepest) latent before decoding
    z[-1] = torch.cat([z[-1], ctx], 1)

    x_dist = self._get_transformed_x_dist(z, ctx)  # run decoder
    x = pyro.sample('x', x_dist)

    obs.update(dict(x=x, z=z))
    return obs
def _get_transformed_x_dist(self, latent):
    x_pred_dist = self.decoder.predict(latent)
    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)

    preprocess_transform = self._get_preprocess_transforms()

    if isinstance(x_pred_dist, (MultivariateNormal, LowRankMultivariateNormal)):
        chol_transform = LowerCholeskyAffine(x_pred_dist.loc,
                                             x_pred_dist.scale_tril)
        reshape_transform = ReshapeTransform(self.img_shape,
                                             (np.prod(self.img_shape), ))
        x_reparam_transform = ComposeTransform(
            [reshape_transform, chol_transform, reshape_transform.inv])
    elif isinstance(x_pred_dist, Independent):
        x_pred_dist = x_pred_dist.base_dist
        x_reparam_transform = AffineTransform(x_pred_dist.loc,
                                              x_pred_dist.scale, 3)
    else:
        raise ValueError(f'{x_pred_dist} not valid.')

    return TransformedDistribution(
        x_base_dist,
        ComposeTransform([x_reparam_transform, preprocess_transform]))
def guide(self, obs):
    batch_size = obs['x'].shape[0]

    with pyro.plate('observations', batch_size):
        hidden = self.encoder(obs['x'])

        ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
            obs['ventricle_volume'])
        brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
            obs['brain_volume'])
        lesion_volume_ = self.lesion_volume_flow_constraint_transforms.inv(
            obs['lesion_volume'])
        slice_number = obs['slice_number']
        ctx = torch.cat([ventricle_volume_, brain_volume_,
                         lesion_volume_, slice_number], 1)

        z = []
        layers = zip(self.latent_encoder, self.guide_ctx_attn,
                     self.hierarchical_layers)
        for i, (latent_enc, ctx_attn, layer) in enumerate(layers):
            last_layer = layer == self.last_layer
            if last_layer:
                hidden_i = torch.cat([hidden[i], ctx], 1)
                z_base_dist = latent_enc.predict(hidden_i)
                z_dist = TransformedDistribution(
                    z_base_dist, self.posterior_flow_transforms
                ) if self.use_posterior_flow else z_base_dist
                # pseudo calls to register the flow parameters with pyro
                _ = self.posterior_affine
                _ = self.posterior_flow_components
            else:
                ctx_ = ctx_attn(ctx).view(batch_size, -1, 1, 1)
                hidden_i = hidden[i] * ctx_
                z_dist = latent_enc.predict(hidden_i)
            with poutine.scale(scale=self.annealing_factor[i]):
                z.append(pyro.sample(f'z{i}', z_dist))

    return z
def model(self):
    age, sex, ventricle_volume, brain_volume = self.pgm_model()

    ventricle_volume_ = self.ventricle_volume_flow_constraint_transforms.inv(
        ventricle_volume)
    brain_volume_ = self.brain_volume_flow_constraint_transforms.inv(
        brain_volume)
    age_ = self.age_flow_constraint_transforms.inv(age)

    z = pyro.sample('z', Normal(self.z_loc, self.z_scale).to_event(1))

    latent = torch.cat([z, age_, ventricle_volume_, brain_volume_], 1)

    x_loc = self.decoder_mean(self.decoder(latent))
    x_scale = torch.exp(self.decoder_logstd)

    theta = self.decoder_affine_param_net(latent).view(-1, 2, 3)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        grid = nn.functional.affine_grid(theta, x_loc.size())
        x_loc_deformed = nn.functional.grid_sample(x_loc, grid)

    x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
    preprocess_transform = self._get_preprocess_transforms()

    x_dist = TransformedDistribution(
        x_base_dist,
        ComposeTransform([
            AffineTransform(x_loc_deformed, x_scale, 3), preprocess_transform
        ]))

    x = pyro.sample('x', x_dist)

    return x, z, age, sex, ventricle_volume, brain_volume
def __abs__(self):
    return RandomVariable(
        TransformedDistribution(self.distribution, AbsTransform()))
def test_minibatch(which_mini_batch, shuffled_indices):
    # compute which sequences in the training set we should grab
    mini_batch_start = (which_mini_batch * args.mini_batch_size)
    mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size,
                             N_test_data])
    mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                              test_seq_lengths, cuda=args.cuda)

    # Get the initial RNN state.
    h_0 = dmm.h_0
    h_0_contig = h_0.expand(1, mini_batch.size(0),
                            dmm.rnn.hidden_size).contiguous()

    # Feed the test sequence into the RNN.
    rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

    # Reverse the time ordering of the hidden state and unpack it.
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    print(rnn_output)
    print(rnn_output.shape)

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

    # sample the latents z one time step at a time
    T_max = mini_batch.size(1)
    sequence_output = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

        # if we are using normalizing flows, we apply the sequence of transformations
        # parameterized by self.iafs to the base distribution defined in the previous line
        # to yield a transformed distribution that we use for q(z_t|...)
        if len(dmm.iafs) > 0:
            z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                             dmm.iafs)
        else:
            z_dist = dist.Normal(z_loc, z_scale)
        assert z_dist.event_shape == ()
        assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

        # sample z_t from the distribution z_dist
        annealing_factor = 1.0
        with pyro.poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))
        print("z_{}:".format(t), z_t)
        print(z_t.shape)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)
        print("x_{}:".format(t), emission_probs_t)
        print(emission_probs_t.shape)

        # the latent sampled at this time step will be conditioned upon in the
        # next time step so keep track of it
        z_prev = z_t

    # Run the model another few steps.
    n_steps = 100
    for t in range(1, n_steps + 1):
        # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
        z_loc, z_scale = dmm.trans(z_prev)

        # then sample z_t according to dist.Normal(z_loc, z_scale)
        # note that we use .to_event(1) so that the univariate Normal distribution
        # is treated as a multivariate Normal distribution with a diagonal covariance.
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                dist.Normal(z_loc, z_scale).mask(
                    mini_batch_mask[:, t - 1:t]).to_event(1))

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # # the next statement instructs pyro to observe x_t according to the
        # # bernoulli distribution p(x_t|z_t)
        # pyro.sample("obs_x_%d" % t,
        #             # dist.Bernoulli(emission_probs_t)
        #             dist.Normal(emission_probs_t, 0.5)
        #             .mask(mini_batch_mask[:, t - 1:t])
        #             .to_event(1),
        #             obs=mini_batch[:, t - 1, :])

        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t

    sequence_output = np.concatenate(sequence_output, axis=1)
    print(sequence_output.shape)

    n_plots = 5
    fig, axes = plt.subplots(nrows=n_plots, ncols=1)
    x = range(sequence_output.shape[1])
    for i in range(n_plots):
        input = mini_batch[i, :].numpy().squeeze()
        output = sequence_output[i, :]
        axes[i].plot(range(input.shape[0]), input)
        axes[i].plot(range(len(output)), output)
        axes[i].grid()

    # plt.plot(sequence_output[0, :])
    plt.show()
def __rsub__(self, x: Union[float, Tensor]):
    return RandomVariable(
        TransformedDistribution(self.distribution, AffineTransform(x, -1)))
def test_minibatch(dmm, mini_batch, args, sample_z=True):
    # Generate data that we can feed into the below fn.
    test_data_sequences = mini_batch.type(torch.FloatTensor)
    mini_batch_indices = torch.arange(0, test_data_sequences.size(0))
    test_seq_lengths = torch.full(
        (test_data_sequences.size(0), ),
        test_data_sequences.size(1)).type(torch.IntTensor)

    # grab a fully prepped mini-batch using the helper function in the data loader
    mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
        = poly.get_mini_batch(mini_batch_indices, test_data_sequences,
                              test_seq_lengths, cuda=args.cuda)

    # Get the initial RNN state.
    h_0 = dmm.h_0
    h_0_contig = h_0.expand(1, mini_batch.size(0),
                            dmm.rnn.hidden_size).contiguous()

    # Feed the test sequence into the RNN.
    rnn_output, rnn_hidden_state = dmm.rnn(mini_batch_reversed, h_0_contig)

    # Reverse the time ordering of the hidden state and unpack it.
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
    # print(rnn_output)
    # print(rnn_output.shape)

    # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
    z_prev = dmm.z_q_0.expand(mini_batch.size(0), dmm.z_q_0.size(0))

    # sample the latents z one time step at a time
    T_max = mini_batch.size(1)
    sequence_z = []
    sequence_output = []
    for t in range(1, T_max + 1):
        # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
        z_loc, z_scale = dmm.combiner(z_prev, rnn_output[:, t - 1, :])

        if sample_z:
            # if we are using normalizing flows, we apply the sequence of transformations
            # parameterized by self.iafs to the base distribution defined in the previous line
            # to yield a transformed distribution that we use for q(z_t|...)
            if len(dmm.iafs) > 0:
                z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale),
                                                 dmm.iafs)
            else:
                z_dist = dist.Normal(z_loc, z_scale)
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), dmm.z_q_0.size(0))

            # sample z_t from the distribution z_dist
            annealing_factor = 1.0
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(
                    "z_%d" % t,
                    z_dist.mask(mini_batch_mask[:, t - 1:t]).to_event(1))
        else:
            z_t = z_loc

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)
        # print("z_{}:".format(t), z_t)
        # print(z_t.shape)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)
        # print("x_{}:".format(t), emission_probs_t)
        # print(emission_probs_t.shape)

        # the latent sampled at this time step will be conditioned upon in the
        # next time step so keep track of it
        z_prev = z_t

    # Run the model another few steps.
    n_extra_steps = 100
    for t in range(1, n_extra_steps + 1):
        # first compute the parameters of the diagonal gaussian distribution p(z_t | z_{t-1})
        z_loc, z_scale = dmm.trans(z_prev)

        # then sample z_t according to dist.Normal(z_loc, z_scale)
        # note that we use .to_event(1) so that the univariate Normal distribution
        # is treated as a multivariate Normal distribution with a diagonal covariance.
        annealing_factor = 1.0
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(
                "z_%d" % t,
                dist.Normal(z_loc, z_scale)
                # .mask(mini_batch_mask[:, t - 1:t])
                .to_event(1))

        z_t_np = z_t.detach().numpy()
        z_t_np = z_t_np[:, np.newaxis, :]
        sequence_z.append(z_t_np)

        # compute the probabilities that parameterize the bernoulli likelihood
        emission_probs_t = dmm.emitter(z_t)
        emission_probs_t_np = emission_probs_t.detach().numpy()
        sequence_output.append(emission_probs_t_np)

        # the latent sampled at this time step will be conditioned upon
        # in the next time step so keep track of it
        z_prev = z_t

    sequence_z = np.concatenate(sequence_z, axis=1)
    sequence_output = np.concatenate(sequence_output, axis=1)
    # print(sequence_output.shape)

    # n_plots = 5
    # fig, axes = plt.subplots(nrows=n_plots, ncols=1)
    # x = range(sequence_output.shape[1])
    # for i in range(n_plots):
    #     input = mini_batch[i, :].numpy().squeeze()
    #     output = sequence_output[i, :]
    #     axes[i].plot(range(input.shape[0]), input)
    #     axes[i].plot(range(len(output)), output)
    #     axes[i].grid()

    return mini_batch, sequence_z, sequence_output  # fig
def __truediv__(self, x: Union[float, Tensor]):
    return RandomVariable(
        TransformedDistribution(self.distribution, AffineTransform(0, 1 / x)))
def __neg__(self):
    return RandomVariable(
        TransformedDistribution(self.distribution, AffineTransform(0, -1)))
def unconstrained_prior(self) -> TransformedDistribution:
    return TransformedDistribution(self(), self.bijection.inv)
def __pow__(self, x):
    return RandomVariable(
        TransformedDistribution(self.distribution, PowerTransform(x)))
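# A small usage sketch tying the operator overloads above together: each
# arithmetic operation returns a new RandomVariable whose distribution is a
# TransformedDistribution of the original (same assumed import path as in the
# RandomVariable.transform example above).
from pyro.contrib.randomvariable import RandomVariable
from pyro.distributions import Uniform

X = RandomVariable(Uniform(0., 1.))
Z = abs((1 - X) / 2) ** 2  # chains __rsub__, __truediv__, __abs__, __pow__
s = Z.distribution.sample()
assert 0.0 <= s <= 0.25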