def __init__(self, mixing_logits, component_stats, component_class,
             presence=None):
    """Builds the module.

    Args:
      mixing_logits: tensor [B, k, ...] with k the number of components.
      component_stats: list of tensors of shape [B, k, ...] or broadcastable
        to these shapes; they are argument to the chosen distribution class.
        The structure is flattened and passed as positional arguments.
      component_class: callable; returns a distribution object.
      presence: [B, k] tensor of floats in [0, 1] or None. When given, its
        log (via `safe_log`) is made broadcastable against `mixing_logits`
        and added to them.
    """
    super(MixtureDistribution, self).__init__()

    if presence is not None:
        log_presence = make_brodcastable(safe_log(presence), mixing_logits)
        # Out-of-place add on purpose: `+=` would modify the caller's object
        # in place for array types that implement `__iadd__`.
        mixing_logits = mixing_logits + log_presence

    self._mixing_logits = mixing_logits
    self._distributions = component_class(*nest.flatten(component_stats))
    self._presence = presence
def _maybe_mask(self, tensor): if self._presence is None: return tensor pres = make_brodcastable(self._presence, tensor) return tensor * pres
def _build(self, pose, presence=None, template_feature=None, bg_image=None,
           img_embedding=None):
    """Builds the module.

    Args:
      pose: [B, n_templates, 6] tensor. The first two dimensions must be
        statically known, since they are read with `.as_list()`.
      presence: [B, n_templates] tensor, or None. When None, no presence
        term is added to the mixing logits.
      template_feature: [B, n_templates, n_features] tensor; these features
        are used to change templates based on the input, if present.
      bg_image: [B, *output_size] tensor representing the background. When
        None, a constant background given by the sigmoid of the 'bg_value'
        variable is used instead.
      img_embedding: [B, d] tensor containing image embeddings. Not read by
        this method.

    Returns:
      AttrDict with fields:
        raw_templates: `self._templates` with its leading axis squeezed out.
        transformed_templates: the warped templates, without the background
          entry that is appended as the last element along axis 1.
        mixing_logits: the mixing logits, likewise without the background
          entry.
        pdf: `prob.MixtureDistribution` over the templates plus background.

    Raises:
      ValueError: if `self._output_pdf_type` is not 'mixture'.
    """
    batch_size, n_templates = pose.shape[:2].as_list()
    templates = self.make_templates(n_templates, template_feature)

    # Templates that are shared across the batch get tiled to batch size.
    if templates.shape[0] == 1:
        templates = snt.TileByDim([0], [batch_size])(templates)

    # it's easier for me to think in inverse coordinates
    warper = snt.AffineGridWarper(self._output_size, self._template_size)
    warper = warper.inverse()

    grid_coords = snt.BatchApply(warper)(pose)
    resampler = snt.BatchApply(contrib_resampler.resampler)
    transformed_templates = resampler(templates, grid_coords)

    if bg_image is not None:
        bg_image = tf.expand_dims(bg_image, axis=1)
    else:
        bg_value = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
        bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_value

    # The background is appended as one extra, last entry along axis 1.
    transformed_templates = tf.concat([transformed_templates, bg_image],
                                      axis=1)

    if presence is not None:
        # The appended background entry gets a presence of one.
        presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

    if self._use_alpha_channel:
        template_mixing_logits = snt.TileByDim([0], [batch_size])(
            self._templates_alpha)
        template_mixing_logits = resampler(template_mixing_logits,
                                           grid_coords)

        bg_mixing_logit = tf.nn.softplus(
            tf.get_variable('bg_mixing_logit', initializer=[0.]))
        bg_mixing_logit = (
            tf.zeros_like(template_mixing_logits[:, :1]) + bg_mixing_logit)

        template_mixing_logits = tf.concat(
            [template_mixing_logits, bg_mixing_logit], 1)
    else:
        temperature_logit = tf.get_variable('temperature_logit', shape=[1])
        # The 1e-4 offset keeps the divisor strictly positive.
        temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
        template_mixing_logits = transformed_templates / temperature

    scale = 1.
    if self._learn_output_scale:
        scale = tf.get_variable('scale', shape=[1])
        scale = tf.nn.softplus(scale) + 1e-4

    if self._output_pdf_type != 'mixture':
        raise ValueError('Unknown pdf type: "{}".'.format(
            self._output_pdf_type))

    # `presence` defaults to None; taking its log unconditionally would fail,
    # so the presence term is only added when a presence tensor was given.
    if presence is not None:
        log_presence = make_brodcastable(
            math_ops.safe_log(presence), template_mixing_logits)
        template_mixing_logits = template_mixing_logits + log_presence

    rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                       [transformed_templates, scale],
                                       tfd.Normal)

    # The trailing background entry is dropped from the per-template outputs.
    return AttrDict(
        raw_templates=tf.squeeze(self._templates, 0),
        transformed_templates=transformed_templates[:, :-1],
        mixing_logits=template_mixing_logits[:, :-1],
        pdf=rec_pdf)