def __init__(self, mixing_logits, component_stats, component_class,
             presence=None):
  """Builds the distribution.

  Args:
    mixing_logits: tensor [B, k, ...] with k the number of components.
    component_stats: list of tensors of shape [B, k, ...] or broadcastable
      to these shapes; they are passed as arguments to the chosen
      distribution class.
    component_class: callable; returns a distribution object.
    presence: [B, k] tensor of floats in [0, 1] or None.
  """
  super(MixtureDistribution, self).__init__()

  if presence is not None:
    # Mask out absent components by adding their log-presence to the mixing
    # logits; safe_log avoids -inf for zero presence.
    mixing_logits += make_brodcastable(safe_log(presence), mixing_logits)

  self._mixing_logits = mixing_logits

  component_stats = nest.flatten(component_stats)
  self._distributions = component_class(*component_stats)
  self._presence = presence
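# A minimal sketch (illustration only, not part of the module) of the density
# a mixture with this [B, k, ...] layout typically evaluates: the mixing
# logits are normalized with a logsumexp over the component axis, and
# log p(x) = logsumexp_k(log pi_k + log p_k(x)). All shapes below are
# assumptions chosen for the example.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_size, n_components, n_dims = 2, 3, 4
mixing_logits = tf.random.normal([batch_size, n_components, n_dims])
locs = tf.random.normal([batch_size, n_components, n_dims])
components = tfd.Normal(locs, 1.)

x = tf.random.normal([batch_size, 1, n_dims])  # broadcasts against components

# Normalize over the component axis, then mix in log-space; the mixture is
# applied independently per remaining dimension, matching [B, k, ...].
mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
    mixing_logits, 1, keepdims=True)
log_prob = tf.reduce_logsumexp(mixing_log_prob + components.log_prob(x), 1)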
def _build(self, pose, presence=None, template_feature=None, bg_image=None,
           img_embedding=None):
  """Builds the module.

  Args:
    pose: [B, n_templates, 6] tensor.
    presence: [B, n_templates] tensor.
    template_feature: [B, n_templates, n_features] tensor; these features
      are used to modulate templates based on the input, if present.
    bg_image: [B, *output_size] tensor representing the background.
    img_embedding: [B, d] tensor containing image embeddings.

  Returns:
    An AttrDict with the raw and transformed templates, the mixing logits
    and the reconstruction pdf.
  """
  batch_size, n_templates = pose.shape[:2].as_list()
  templates = self.make_templates(n_templates, template_feature)

  if templates.shape[0] == 1:
    templates = snt.TileByDim([0], [batch_size])(templates)

  # Work in inverse coordinates: the warper maps output-image coordinates
  # back into template coordinates, which is what the resampler expects.
  warper = snt.AffineGridWarper(self._output_size, self._template_size)
  warper = warper.inverse()
  grid_coords = snt.BatchApply(warper)(pose)

  resampler = snt.BatchApply(tf.contrib.resampler.resampler)
  transformed_templates = resampler(templates, grid_coords)

  if bg_image is not None:
    bg_image = tf.expand_dims(bg_image, axis=1)
  else:
    # Learn a constant background value and broadcast it to image shape.
    bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
    bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

  transformed_templates = tf.concat([transformed_templates, bg_image], axis=1)

  if presence is not None:
    # The background component is always present.
    presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

  if self._use_alpha_channel:
    template_mixing_logits = snt.TileByDim([0], [batch_size])(
        self._templates_alpha)
    template_mixing_logits = resampler(template_mixing_logits, grid_coords)

    bg_mixing_logit = tf.nn.softplus(
        tf.get_variable('bg_mixing_logit', initializer=[0.]))
    bg_mixing_logit = (
        tf.zeros_like(template_mixing_logits[:, :1]) + bg_mixing_logit)

    template_mixing_logits = tf.concat(
        [template_mixing_logits, bg_mixing_logit], 1)

  else:
    temperature_logit = tf.get_variable('temperature_logit', shape=[1])
    temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
    template_mixing_logits = transformed_templates / temperature

  scale = 1.
  if self._learn_output_scale:
    scale = tf.get_variable('scale', shape=[1])
    scale = tf.nn.softplus(scale) + 1e-4

  if self._output_pdf_type == 'mixture':
    if presence is not None:
      template_mixing_logits += make_brodcastable(
          math_ops.safe_log(presence), template_mixing_logits)

    rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                       [transformed_templates, scale],
                                       tfd.Normal)

  else:
    raise ValueError('Unknown pdf type: "{}".'.format(self._output_pdf_type))

  return AttrDict(
      raw_templates=tf.squeeze(self._templates, 0),
      transformed_templates=transformed_templates[:, :-1],
      mixing_logits=template_mixing_logits[:, :-1],
      pdf=rec_pdf)
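# Sketch of the warping step in isolation (assumed sizes; Sonnet v1 and
# tf.contrib are required, as in the module above). A 6-DOF pose
# parameterizes an affine warp; the inverse warper produces, for every
# output-image pixel, its coordinates in template space, and the resampler
# bilinearly samples the template there.
import tensorflow as tf
import sonnet as snt

template_size, output_size = [11, 11], [28, 28]
templates = tf.random.normal([1, template_size[0], template_size[1], 1])
pose = tf.constant([[1., 0., 0., 0., 1., 0.]])  # identity affine transform

warper = snt.AffineGridWarper(output_size, template_size).inverse()
grid_coords = warper(pose)  # [1, 28, 28, 2] coordinates in template space
warped = tf.contrib.resampler.resampler(templates, grid_coords)  # [1, 28, 28, 1]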
def _build(self, features, parent_transform=None, parent_presence=None):
  """Builds the module.

  Args:
    features: Tensor of encodings of shape [B, n_enc_dims].
    parent_transform: Tuple of (matrix, vector).
    parent_presence: [B, n_caps, 1] tensor of presence probabilities for the
      parent capsules, or None.

  Returns:
    An AttrDict of vote, scale and presence tensors, together with auxiliary
    quantities used by the losses.
  """
  batch_size = features.shape.as_list()[0]
  batch_shape = [batch_size, self._n_caps]

  # Predict capsule and additional params from the input encoding.
  # [B, n_caps, n_caps_dims]
  if self._n_caps_params is not None:

    # Use separate parameters to do predictions for different capsules.
    mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
    raw_caps_params = mlp(features)

    caps_params = tf.reshape(raw_caps_params,
                             batch_shape + [self._n_caps_params])

  else:
    assert features.shape[:2].as_list() == batch_shape
    raw_caps_params = features
    caps_params = features

  if self._caps_dropout_rate == 0.0:
    caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
  else:
    pmf = tfd.Bernoulli(probs=1. - self._caps_dropout_rate, dtype=tf.float32)
    caps_exist = pmf.sample(batch_shape + [1])

  caps_params = tf.concat([caps_params, caps_exist], -1)

  output_shapes = (
      [self._n_votes, self._n_transform_params],  # CPR_dynamic
      [1, self._n_transform_params],              # CCR
      [1],                                        # per-capsule presence
      [self._n_votes],                            # per-vote presence
      [self._n_votes],                            # per-vote scale
  )

  splits = [np.prod(i).astype(np.int32) for i in output_shapes]
  n_outputs = sum(splits)

  # We don't use bias in the output layer in order to separate the static
  # and dynamic parts of the CPR.
  caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
  all_params = caps_mlp(caps_params)
  all_params = tf.split(all_params, splits, -1)
  res = [
      tf.reshape(i, batch_shape + s)
      for (i, s) in zip(all_params, output_shapes)
  ]

  cpr_dynamic = res[0]

  # Add bias to all remaining outputs.
  res = [snt.AddBias()(i) for i in res[1:]]
  ccr, pres_logit_per_caps, pres_logit_per_vote, scale_per_vote = res

  if self._caps_dropout_rate != 0.0:
    pres_logit_per_caps += math_ops.safe_log(caps_exist)

  cpr_static = tf.get_variable(
      'cpr_static',
      shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

  def add_noise(tensor):
    """Adds noise to the presence logits."""
    if self._noise_type == 'uniform':
      noise = tf.random.uniform(tensor.shape, minval=-.5,
                                maxval=.5) * self._noise_scale

    elif self._noise_type == 'logistic':
      pdf = tfd.Logistic(0., self._noise_scale)
      noise = pdf.sample(tensor.shape)

    elif not self._noise_type:
      noise = 0.

    else:
      raise ValueError('Invalid noise type: "{}".'.format(self._noise_type))

    return tensor + noise

  pres_logit_per_caps = add_noise(pres_logit_per_caps)
  pres_logit_per_vote = add_noise(pres_logit_per_vote)

  # In the hierarchical case the parent supplies the capsule-camera
  # transform; otherwise it is predicted from the input.
  if parent_transform is None:
    ccr = self._make_transform(ccr)
  else:
    ccr = parent_transform

  if not self._deformations:
    cpr_dynamic = tf.zeros_like(cpr_dynamic)

  cpr = self._make_transform(cpr_dynamic + cpr_static)

  ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
  votes = tf.matmul(ccr_per_vote, cpr)

  if parent_presence is not None:
    pres_per_caps = parent_presence
  else:
    pres_per_caps = tf.nn.sigmoid(pres_logit_per_caps)

  pres_per_vote = pres_per_caps * tf.nn.sigmoid(pres_logit_per_vote)

  if self._learn_vote_scale:
    # For numerical stability.
    scale_per_vote = tf.nn.softplus(scale_per_vote + .5) + 1e-2
  else:
    scale_per_vote = tf.zeros_like(scale_per_vote) + 1.

  return AttrDict(
      vote=votes,
      scale=scale_per_vote,
      vote_presence=pres_per_vote,
      pres_logit_per_caps=pres_logit_per_caps,
      pres_logit_per_vote=pres_logit_per_vote,
      dynamic_weights_l2=tf.nn.l2_loss(cpr_dynamic) / batch_size,
      raw_caps_params=raw_caps_params,
      raw_caps_features=features,
  )
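# Sketch of the vote composition above with plain tensors (assumed sizes;
# identity transforms stand in for the learned CCR/CPR). Each capsule's
# capsule-camera transform is tiled over its votes and composed with the
# per-vote capsule-part transforms by matrix product.
import tensorflow as tf

batch_size, n_caps, n_votes = 2, 3, 4
ccr = tf.eye(3, batch_shape=[batch_size, n_caps, 1])        # [B, n_caps, 1, 3, 3]
cpr = tf.eye(3, batch_shape=[batch_size, n_caps, n_votes])  # [B, n_caps, n_votes, 3, 3]

ccr_per_vote = tf.tile(ccr, [1, 1, n_votes, 1, 1])
votes = tf.matmul(ccr_per_vote, cpr)  # [B, n_caps, n_votes, 3, 3]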
def _build(self, x, presence=None):
  # x is [B, n_input_points, n_input_dims]
  batch_size, n_input_points = x.shape[:2].as_list()

  # Votes and scales have shape [B, n_caps, n_input_points, n_input_dims|1];
  # scale is a per-caps scalar and there is one vote per capsule and point.
  vote_component_pdf = self._get_pdf(self._votes,
                                     tf.expand_dims(self._scales, -1))

  # Expand along the caps dimension -> [B, 1, n_input_points, n_input_dims].
  expanded_x = tf.expand_dims(x, 1)
  vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
  # [B, n_caps, n_input_points]
  vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)

  # A low-probability dummy component that absorbs unexplained points.
  dummy_vote_log_prob = tf.zeros([batch_size, 1, n_input_points])
  dummy_vote_log_prob -= 2. * tf.log(10.)

  # [B, n_caps + 1, n_input_points]
  vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 1)

  # [B, n_caps, n_input_points]
  mixing_logits = math_ops.safe_log(self._vote_presence_prob)

  dummy_logit = tf.zeros([batch_size, 1, 1]) - 2. * tf.log(10.)
  dummy_logit = snt.TileByDim([2], [n_input_points])(dummy_logit)

  # [B, n_caps + 1, n_input_points]
  mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)
  mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
      mixing_logits, 1, keepdims=True)

  # [B, n_input_points]
  mixture_log_prob_per_point = tf.reduce_logsumexp(
      mixing_logits + vote_log_prob, 1)

  if presence is not None:
    presence = tf.to_float(presence)
    mixture_log_prob_per_point *= presence

  # [B,]
  mixture_log_prob_per_example = tf.reduce_sum(mixture_log_prob_per_point, 1)

  # []
  mixture_log_prob_per_batch = tf.reduce_mean(mixture_log_prob_per_example)

  # [B, n_caps + 1, n_input_points]
  posterior_mixing_logits_per_point = mixing_logits + vote_log_prob

  # [B, n_input_points]
  winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :-1], 1)

  batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1)
  batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

  point_idx = tf.expand_dims(tf.range(n_input_points, dtype=tf.int64), 0)
  point_idx = snt.TileByDim([0], [batch_size])(point_idx)

  idx = tf.stack([batch_idx, winning_vote_idx, point_idx], -1)
  winning_vote = tf.gather_nd(self._votes, idx)
  winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
  vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

  # Map the index of the winning vote back to the capsule that produced it.
  is_from_capsule = winning_vote_idx // self._n_votes

  posterior_mixing_probs = tf.nn.softmax(posterior_mixing_logits_per_point, 1)

  dummy_vote = tf.get_variable('dummy_vote', shape=self._votes[:1, :1].shape)
  dummy_vote = snt.TileByDim([0], [batch_size])(dummy_vote)
  dummy_pres = tf.zeros([batch_size, 1, n_input_points])

  votes = tf.concat((self._votes, dummy_vote), 1)
  pres = tf.concat([self._vote_presence_prob, dummy_pres], 1)

  # The soft winner is the posterior-weighted average of all votes.
  soft_winner = tf.reduce_sum(
      tf.expand_dims(posterior_mixing_probs, -1) * votes, 1)
  soft_winner_pres = tf.reduce_sum(posterior_mixing_probs * pres, 1)

  posterior_mixing_probs = tf.transpose(posterior_mixing_probs[:, :-1],
                                        (0, 2, 1))

  assert winning_vote.shape == x.shape

  return self.OutputTuple(
      log_prob=mixture_log_prob_per_batch,
      vote_presence=tf.to_float(vote_presence),
      winner=winning_vote,
      winner_pres=winning_pres,
      soft_winner=soft_winner,
      soft_winner_pres=soft_winner_pres,
      posterior_mixing_probs=posterior_mixing_probs,
      is_from_capsule=is_from_capsule,
      mixing_logits=mixing_logits,
      mixing_log_prob=mixing_log_prob,
  )
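# Isolated sketch of the hard-assignment lookup used above (assumed sizes;
# random tensors stand in for votes and posterior logits). An argmax over the
# component axis picks one vote per point, and gather_nd pulls it out of the
# [B, n_caps, n_points, D] vote tensor.
import tensorflow as tf

batch_size, n_caps, n_points, n_dims = 2, 3, 5, 2
votes = tf.random.normal([batch_size, n_caps, n_points, n_dims])
posterior_logits = tf.random.normal([batch_size, n_caps, n_points])

winning_idx = tf.argmax(posterior_logits, 1)  # [B, n_points]
batch_idx = tf.tile(
    tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1), [1, n_points])
point_idx = tf.tile(
    tf.expand_dims(tf.range(n_points, dtype=tf.int64), 0), [batch_size, 1])

idx = tf.stack([batch_idx, winning_idx, point_idx], -1)  # [B, n_points, 3]
winning_vote = tf.gather_nd(votes, idx)  # [B, n_points, n_dims]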
def _build(self, x, presence=None):
  batch_size, n_input_points = x.shape[:2].as_list()

  # We don't know what order the input points come in, so we need to create
  # a big mixture of all votes for every input point.
  # [B, 1, n_votes, n_input_dims]
  expanded_votes = tf.expand_dims(self._votes, 1)
  expanded_scale = tf.expand_dims(tf.expand_dims(self._scales, 1), -1)
  vote_component_pdf = self._get_pdf(expanded_votes, expanded_scale)

  # [B, n_points, n_votes, n_input_dims]
  expanded_x = tf.expand_dims(x, 2)
  vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
  # [B, n_points, n_votes]
  vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)

  # A low-probability dummy vote that absorbs unexplained points.
  dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
  dummy_vote_log_prob -= 2. * tf.log(10.)
  # [B, n_points, n_votes + 1]
  vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

  # [B, n_votes]
  mixing_logits = math_ops.safe_log(self._vote_presence_prob)

  dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
  # [B, n_votes + 1]
  mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

  mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
      mixing_logits, 1, keepdims=True)

  expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
  # [B, n_points]
  mixture_log_prob_per_point = tf.reduce_logsumexp(
      expanded_mixing_logits + vote_log_prob, 2)

  if presence is not None:
    presence = tf.to_float(presence)
    mixture_log_prob_per_point *= presence

  # [B,]
  mixture_log_prob_per_example = tf.reduce_sum(mixture_log_prob_per_point, 1)

  # []
  mixture_log_prob_per_batch = tf.reduce_mean(mixture_log_prob_per_example)

  # [B, n_points, n_votes + 1]
  posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob

  # [B, n_points]
  winning_vote_idx = tf.argmax(
      posterior_mixing_logits_per_point[:, :, :-1], 2)

  batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
  batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

  idx = tf.stack([batch_idx, winning_vote_idx], -1)
  winning_vote = tf.gather_nd(self._votes, idx)
  winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
  vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

  # Map the winning vote back to the capsule that produced it; e.g. in the
  # constellation experiments the first four votes belong to the square.
  is_from_capsule = winning_vote_idx // self._n_votes

  posterior_mixing_probs = tf.nn.softmax(
      posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

  assert winning_vote.shape == x.shape

  return self.OutputTuple(
      log_prob=mixture_log_prob_per_batch,
      vote_presence=tf.to_float(vote_presence),
      winner=winning_vote,
      winner_pres=winning_pres,
      is_from_capsule=is_from_capsule,
      mixing_logits=mixing_logits,
      mixing_log_prob=mixing_log_prob,
      # TODO(adamrk): this is broken
      soft_winner=tf.zeros_like(winning_vote),
      soft_winner_pres=tf.zeros_like(winning_pres),
      posterior_mixing_probs=posterior_mixing_probs,
  )
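# Minimal numeric sketch of the order-invariant per-point likelihood above
# (assumed sizes; random tensors stand in for votes and presences). Each
# input point is scored against every vote, and a logsumexp over the vote
# axis marginalizes out the unknown point-to-vote correspondence.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_size, n_points, n_votes, n_dims = 2, 5, 4, 2
x = tf.random.normal([batch_size, n_points, n_dims])
votes = tf.random.normal([batch_size, n_votes, n_dims])
mixing_logits = tf.random.normal([batch_size, n_votes])

pdf = tfd.Normal(tf.expand_dims(votes, 1), 1.)  # loc: [B, 1, n_votes, D]
vote_log_prob = tf.reduce_sum(
    pdf.log_prob(tf.expand_dims(x, 2)), -1)  # [B, n_points, n_votes]

mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
    mixing_logits, 1, keepdims=True)
log_prob_per_point = tf.reduce_logsumexp(
    tf.expand_dims(mixing_log_prob, 1) + vote_log_prob, 2)  # [B, n_points]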