def map_func(data):
    """Copies `data`, tiling it across replicas and applying image transforms.

    Reads `n_replicas` and `transforms` from an enclosing/module scope (they
    are not parameters). When `n_replicas > 1`, every value is tiled along
    dim 0. When `transforms` is given, each `(key, fn)` pair stores
    `fn(image)` under `key`, where `image` is the (possibly tiled) value of
    `data['image']` taken before any transform is applied.
    """
    result = dict(data)

    if n_replicas > 1:
        tile = snt.TileByDim([0], [n_replicas])
        result = {key: tile(value) for key, value in result.items()}

    if transforms is not None:
        # Capture the source image once so every transform sees the same input.
        source_image = result['image']
        for key, fn in transforms.items():
            result[key] = fn(source_image)

    return result
def _self_attention(self, x, presence):
    """Induced self-attention over `x` through learned inducing points.

    Inducing points first attend to `x`; `x` then attends to that summary.
    The same `MultiHeadQKVAttention` module instance is used for both steps.
    `presence` is accepted for interface compatibility but is not used here.
    """
    attention = MultiHeadQKVAttention(self._n_heads)

    n_features = int(x.shape[-1])
    batch_size = int(x.shape[0])

    # One shared set of inducing points, tiled over the batch dimension.
    points = tf.get_variable(
        'inducing_points', shape=[1, self._n_inducing_points, n_features])
    points = snt.TileByDim([0], [batch_size])(points)

    summary = attention(points, x, x)
    return attention(x, summary, summary)
def _build(self, x):
    """Applies the module.

    Default initializers are stored into `self.initializers` on first use:
    truncated normal with stddev 1/sqrt(d) for 'w', zeros for 'b'.

    Args:
      x: tensor of shape [B, k, d].

    Returns:
      Tensor of shape [B, k, n_units].
    """
    shape = x.shape.as_list()

    if 'w' not in self.initializers:
        # Fan-in scaling over the last (input feature) dimension.
        stddev = 1 / math.sqrt(shape[-1])
        self.initializers['w'] = tf.truncated_normal_initializer(
            stddev=stddev)

    # Weights have shape x.shape + [n_units]. Every dim listed in
    # self._tile_dims is stored with size 1 and tiled, so a single copy is
    # shared along it.
    weights_shape = shape + [self._n_units]
    tiles = []
    for i in self._tile_dims:
        tiles.append(weights_shape[i])
        weights_shape[i] = 1

    weights = tf.get_variable('weights', shape=weights_shape,
                              initializer=self._init('w'))
    weights = snt.TileByDim(self._tile_dims, tiles)(weights)

    # [..., d] -> [..., 1, d] so the matmul yields [..., 1, n_units].
    x = tf.expand_dims(x, -2)
    y = tf.matmul(x, weights)
    y = tf.squeeze(y, -2)

    if self._use_bias:
        if 'b' not in self.initializers:
            self.initializers['b'] = tf.zeros_initializer()

        init = dict(b=self._init('b'))
        # Bias is shared along the tiled dims, like the weights.
        bias_dims = [
            i for i in range(len(shape)) if i not in self._tile_dims
        ]
        add_bias = snt.AddBias(bias_dims=bias_dims, initializers=init)
        y = add_bias(y)

    return y
def render_by_scatter(size, points, colors=None, gt_presence=None):
    """Renders points into an RGB canvas using tf.scatter_nd.

    Args:
      size: spatial size of the canvas (iterable of ints).
      points: [B, n_points, ...] tensor whose last dim holds pixel
        coordinates; they are cast to int32 and used as scatter indices.
      colors: optional per-point colors; defaults to ones of shape
        [B, n_points, 3].
      gt_presence: optional mask; a trailing axis is added and it is
        multiplied into `colors`.

    Returns:
      [B, *size, 3] tensor.
    """
    point_batch_shape = points.shape[:-1].as_list()

    if colors is None:
        colors = tf.ones(point_batch_shape + [3], dtype=tf.float32)

    if gt_presence is not None:
        mask = tf.cast(tf.expand_dims(gt_presence, -1), colors.dtype)
        colors *= mask

    batch_size, n_points = point_batch_shape
    out_shape = [batch_size] + list(size) + [3]

    # Prefix each point's coordinates with its batch index.
    batch_idx = tf.reshape(tf.range(batch_size), [batch_size, 1, 1])
    batch_idx = snt.TileByDim([1], [n_points])(batch_idx)
    scatter_idx = tf.concat([batch_idx, tf.cast(points, tf.int32)], -1)

    return tf.scatter_nd(scatter_idx, colors, out_shape)
def clevr_veggies_map_func(index, image, label):
    """Packs an (index, image, label) example into a dict and post-processes it.

    Reads `n_replicas` and `transforms` from an enclosing/module scope. When
    `n_replicas > 1`, every value is tiled along dim 0. When `transforms` is
    given, each `(key, fn)` pair stores `fn(image)` under `key`, where
    `image` is `data['image']` taken before any transform is applied.
    """
    data = {'index': index, 'image': image, 'label': label}

    if n_replicas > 1:
        tile_by_batch = snt.TileByDim([0], [n_replicas])
        data = {k: tile_by_batch(v) for k, v in data.items()}

    if transforms is not None:
        img = data['image']
        for k, transform in transforms.items():
            data[k] = transform(img)

    return data
def _build(self, x, presence=None):
    """Encodes a set with stacked attention blocks, then pools it.

    The input is linearly projected to `self._n_dims`. It then passes
    through `self._n_layers` SelfAttention blocks, or InducedSelfAttention
    blocks when `self._n_inducing_points > 0`, and is projected to
    `self._n_output_dims`. Finally, `self._n_outputs` learned query points
    attend to the result.
    """
    batch_size = int(x.shape[0])
    h = snt.BatchApply(snt.Linear(self._n_dims))(x)

    common_args = [self._n_heads, self._layer_norm, self._dropout_rate]
    if self._n_inducing_points > 0:
        block_cls = InducedSelfAttention
        block_args = [self._n_inducing_points] + common_args
    else:
        block_cls = SelfAttention
        block_args = common_args

    for _ in range(self._n_layers):
        h = block_cls(*block_args)(h, presence)

    z = snt.BatchApply(snt.Linear(self._n_output_dims))(h)

    # Learned queries shared across the batch.
    queries = tf.get_variable(
        'inducing_points', shape=[1, self._n_outputs, self._n_output_dims])
    queries = snt.TileByDim([0], [batch_size])(queries)

    pool = MultiHeadQKVAttention(self._n_heads)
    return pool(queries, z, z, presence)
def _build(self, features, parent_transform=None, parent_presence=None):
    """Builds the module.

    Args:
      features: Tensor of encodings whose leading dim is the batch. When
        `self._n_caps_params` is None it must already be [B, n_caps, ...].
      parent_transform: if given, used instead of the predicted
        capsule-to-canvas transform (CCR).
      parent_presence: if given, used instead of
        sigmoid(pres_logit_per_caps) as the per-capsule presence.

    Returns:
      AttrDict with votes, scales, presences, presence logits, an L2 penalty
      on the dynamic part of the CPR, the raw capsule params (None when
      `self._n_caps_params` is None) and the input features.
    """
    batch_size = features.shape.as_list()[0]
    batch_shape = [batch_size, self._n_caps]

    # Predict capsule and additional params from the input encoding.
    # [B, n_caps, n_caps_dims]
    if self._n_caps_params is not None:
        # Use separate parameters to do predictions for different capsules.
        mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
        raw_caps_params = mlp(features)

        caps_params = tf.reshape(raw_caps_params,
                                 batch_shape + [self._n_caps_params])
    else:
        assert features.shape[:2].as_list() == batch_shape
        caps_params = features
        # No MLP is applied in this branch. Bind the name so the returned
        # AttrDict can be built (it used to raise NameError here).
        raw_caps_params = None

    if self._caps_dropout_rate == 0.0:
        caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
    else:
        pmf = tfd.Bernoulli(1. - self._caps_dropout_rate, dtype=tf.float32)
        caps_exist = pmf.sample(batch_shape + [1])

    caps_params = tf.concat([caps_params, caps_exist], -1)

    output_shapes = (
        [self._n_votes, self._n_transform_params],  # CPR_dynamic
        [1, self._n_transform_params],  # CCR
        [1],  # per-capsule presence
        [self._n_votes],  # per-vote-presence
        [self._n_votes],  # per-vote scale
    )

    splits = [np.prod(i).astype(np.int32) for i in output_shapes]
    n_outputs = sum(splits)

    # we don't use bias in the output layer in order to separate the static
    # and dynamic parts of the CPR
    caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
    all_params = caps_mlp(caps_params)
    all_params = tf.split(all_params, splits, -1)
    res = [
        tf.reshape(i, batch_shape + s)
        for (i, s) in zip(all_params, output_shapes)
    ]

    cpr_dynamic = res[0]

    # add bias to all remaining outputs
    res = [snt.AddBias()(i) for i in res[1:]]
    ccr, pres_logit_per_caps, pres_logit_per_vote, scale_per_vote = res

    if self._caps_dropout_rate != 0.0:
        # Dropped capsules get a (safe) log(0) added to their presence logit.
        pres_logit_per_caps += math_ops.safe_log(caps_exist)

    cpr_static = tf.get_variable(
        'cpr_static',
        shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

    def add_noise(tensor):
        """Adds noise to tensors."""
        if self._noise_type == 'uniform':
            noise = tf.random.uniform(
                tensor.shape, minval=-.5, maxval=.5) * self._noise_scale
        elif self._noise_type == 'logistic':
            pdf = tfd.Logistic(0., self._noise_scale)
            noise = pdf.sample(tensor.shape)
        elif not self._noise_type:
            noise = 0.
        else:
            raise ValueError('Invalid noise type: "{}".'.format(
                self._noise_type))

        return tensor + noise

    pres_logit_per_caps = add_noise(pres_logit_per_caps)
    pres_logit_per_vote = add_noise(pres_logit_per_vote)

    # this is for hierarchical
    if parent_transform is None:
        ccr = self._make_transform(ccr)
    else:
        ccr = parent_transform

    if not self._deformations:
        cpr_dynamic = tf.zeros_like(cpr_dynamic)

    cpr = self._make_transform(cpr_dynamic + cpr_static)

    # votes = CCR @ CPR, with CCR repeated for every vote of a capsule.
    ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
    votes = tf.matmul(ccr_per_vote, cpr)

    if parent_presence is not None:
        pres_per_caps = parent_presence
    else:
        pres_per_caps = tf.nn.sigmoid(pres_logit_per_caps)

    pres_per_vote = pres_per_caps * tf.nn.sigmoid(pres_logit_per_vote)

    if self._learn_vote_scale:
        # for numerical stability
        scale_per_vote = tf.nn.softplus(scale_per_vote + .5) + 1e-2
    else:
        scale_per_vote = tf.zeros_like(scale_per_vote) + 1.

    return AttrDict(
        vote=votes,
        scale=scale_per_vote,
        vote_presence=pres_per_vote,
        pres_logit_per_caps=pres_logit_per_caps,
        pres_logit_per_vote=pres_logit_per_vote,
        dynamic_weights_l2=tf.nn.l2_loss(cpr_dynamic) / batch_size,
        raw_caps_params=raw_caps_params,
        raw_caps_features=features,
    )
def _build(self, x, presence=None):
    """Scores input points under per-capsule vote mixtures (ordered variant).

    Each input point is compared against the vote at the same point index
    from every capsule, plus a constant "dummy" component appended last.

    Args:
      x: [B, n_input_points, n_input_dims] tensor of points.
      presence: optional per-point mask; per-point log-likelihoods are
        multiplied by it.

    Returns:
      self.OutputTuple with the batch-mean log-likelihood, hard and soft
      winners, and the posterior/mixing terms.
    """
    # x is [B, n_input_points, n_input_dims]
    batch_size, n_input_points = x.shape[:2].as_list()

    # votes and scale have shape [B, n_caps, n_input_points, n_input_dims|1]
    # since scale is a per-caps scalar and we have one vote per capsule
    vote_component_pdf = self._get_pdf(self._votes,
                                       tf.expand_dims(self._scales, -1))

    # expand along caps dimensions -> [B, 1, n_input_points, n_input_dims]
    expanded_x = tf.expand_dims(x, 1)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
    # [B, n_caps, n_input_points]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
    # Dummy component with constant log-prob -2*log(10) = log(0.01).
    dummy_vote_log_prob = tf.zeros([batch_size, 1, n_input_points])
    dummy_vote_log_prob -= 2. * tf.log(10.)

    # [B, n_caps + 1, n_input_points]
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 1)

    # [B, n_caps, n_input_points]
    mixing_logits = math_ops.safe_log(self._vote_presence_prob)

    dummy_logit = tf.zeros([batch_size, 1, 1]) - 2. * tf.log(10.)
    dummy_logit = snt.TileByDim([2], [n_input_points])(dummy_logit)

    # [B, n_caps + 1, n_input_points]
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)
    # Normalized over the component axis; only returned, not used below.
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)
    # [B, n_input_points]
    # NOTE(review): this uses the unnormalized `mixing_logits` rather than
    # `mixing_log_prob`, so the per-point mixture may not be a normalized
    # density — confirm this is intended.
    mixture_log_prob_per_point = tf.reduce_logsumexp(
        mixing_logits + vote_log_prob, 1)

    if presence is not None:
        presence = tf.cast(presence, tf.float32)
        mixture_log_prob_per_point *= presence

    # [B,]
    mixture_log_prob_per_example\
        = tf.reduce_sum(mixture_log_prob_per_point, 1)

    # []
    mixture_log_prob_per_batch = tf.reduce_mean(
        mixture_log_prob_per_example)

    # [B, n_caps + 1, n_input_points]
    posterior_mixing_logits_per_point = mixing_logits + vote_log_prob

    # [B, n_input_points]; the dummy component is excluded from the argmax.
    winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :-1], 1)

    # Build (batch, capsule, point) gather indices for the winners.
    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1)
    batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

    point_idx = tf.expand_dims(tf.range(n_input_points, dtype=tf.int64), 0)
    point_idx = snt.TileByDim([0], [batch_size])(point_idx)

    idx = tf.stack([batch_idx, winning_vote_idx, point_idx], -1)
    winning_vote = tf.gather_nd(self._votes, idx)
    winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
    # A vote counts as present when its logit beats the dummy's.
    vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

    # the first four votes belong to the square
    # NOTE(review): the comment above looks dataset-specific; the code just
    # integer-divides the winning index by self._n_votes.
    is_from_capsule = winning_vote_idx // self._n_votes

    posterior_mixing_probs = tf.nn.softmax(
        posterior_mixing_logits_per_point, 1)

    # A learned vote for the dummy component, so the soft winner is defined
    # over all n_caps + 1 components; the dummy's presence is zero.
    dummy_vote = tf.get_variable('dummy_vote', shape=self._votes[:1, :1].shape)
    dummy_vote = snt.TileByDim([0], [batch_size])(dummy_vote)
    dummy_pres = tf.zeros([batch_size, 1, n_input_points])

    votes = tf.concat((self._votes, dummy_vote), 1)
    pres = tf.concat([self._vote_presence_prob, dummy_pres], 1)

    # Posterior-weighted averages over components.
    soft_winner = tf.reduce_sum(
        tf.expand_dims(posterior_mixing_probs, -1) * votes, 1)
    soft_winner_pres = tf.reduce_sum(posterior_mixing_probs * pres, 1)

    # Drop the dummy and move points before capsules:
    # [B, n_input_points, n_caps]
    posterior_mixing_probs = tf.transpose(posterior_mixing_probs[:, :-1],
                                          (0, 2, 1))

    assert winning_vote.shape == x.shape

    return self.OutputTuple(
        log_prob=mixture_log_prob_per_batch,
        vote_presence=tf.cast(vote_presence, tf.float32),
        winner=winning_vote,
        winner_pres=winning_pres,
        soft_winner=soft_winner,
        soft_winner_pres=soft_winner_pres,
        posterior_mixing_probs=posterior_mixing_probs,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
    )
def _build(self, x, presence=None):
    """Scores input points under a mixture of all votes (order-invariant).

    Because the correspondence between input points and votes is unknown,
    every point is scored under every vote, plus a constant "dummy"
    component appended last.

    Args:
      x: [B, n_points, n_input_dims] tensor of points.
      presence: optional per-point mask; per-point log-likelihoods are
        multiplied by it.

    Returns:
      self.OutputTuple with the batch-mean log-likelihood, the hard winner
      per point, and the posterior/mixing terms. soft_winner and
      soft_winner_pres are placeholder zeros.
    """
    batch_size, n_input_points = x.shape[:2].as_list()

    # we don't know what order the initial points came in, so we need to create
    # a big mixture of all votes for every input point
    # [B, 1, n_votes, n_input_dims]
    expanded_votes = tf.expand_dims(self._votes, 1)
    expanded_scale = tf.expand_dims(tf.expand_dims(self._scales, 1), -1)
    vote_component_pdf = self._get_pdf(expanded_votes, expanded_scale)

    # x becomes [B, n_points, 1, n_input_dims]; log_prob broadcasts to
    # [B, n_points, n_votes, n_input_dims]
    expanded_x = tf.expand_dims(x, 2)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
    # [B, n_points, n_votes]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
    # Dummy component with constant log-prob -2*log(10) = log(0.01).
    dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
    dummy_vote_log_prob -= 2. * tf.log(10.)
    # [B, n_points, n_votes + 1]
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

    # [B, n_votes] -> [B, n_votes + 1] after appending the dummy logit
    mixing_logits = math_ops.safe_log(self._vote_presence_prob)

    dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

    # Normalize over components.
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)

    # [B, 1, n_votes + 1] for broadcasting against points
    expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
    # Per-point mixture log-likelihood: [B, n_points]
    mixture_log_prob_per_component\
        = tf.reduce_logsumexp(expanded_mixing_logits + vote_log_prob, 2)

    if presence is not None:
        presence = tf.cast(presence, tf.float32)
        mixture_log_prob_per_component *= presence

    # [B]
    mixture_log_prob_per_example\
        = tf.reduce_sum(mixture_log_prob_per_component, 1)

    # []
    mixture_log_prob_per_batch = tf.reduce_mean(
        mixture_log_prob_per_example)

    # [B, n_points, n_votes + 1]
    posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob

    # [B, n_points]; the dummy component is excluded from the argmax.
    winning_vote_idx = tf.argmax(
        posterior_mixing_logits_per_point[:, :, :-1], 2)

    # Build (batch, vote) gather indices for the winners.
    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
    batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

    idx = tf.stack([batch_idx, winning_vote_idx], -1)
    winning_vote = tf.gather_nd(self._votes, idx)
    winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
    # A vote counts as present when its logit beats the dummy's.
    vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

    # the first four votes belong to the square
    # NOTE(review): integer division assumes votes are laid out with
    # self._n_votes consecutive votes per capsule — confirm.
    is_from_capsule = winning_vote_idx // self._n_votes

    # Softmax over components, then drop the dummy: [B, n_points, n_votes]
    posterior_mixing_probs = tf.nn.softmax(
        posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

    assert winning_vote.shape == x.shape

    return self.OutputTuple(
        log_prob=mixture_log_prob_per_batch,
        vote_presence=tf.cast(vote_presence, tf.float32),
        winner=winning_vote,
        winner_pres=winning_pres,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
        # TODO(adamrk): this is broken
        soft_winner=tf.zeros_like(winning_vote),
        soft_winner_pres=tf.zeros_like(winning_pres),
        posterior_mixing_probs=posterior_mixing_probs,
    )
def naive_log_likelihood(x, presence=None):
    """Order-invariant capsule log-likelihood, adapted from the original repo.

    Every input point is scored under a mixture of all votes plus one
    constant "dummy" component. Reads `_votes`, `_scales`, `_get_pdf`,
    `_vote_presence_prob`, `_n_votes` and `OutputTuple` from an
    enclosing/module scope.

    Args:
      x: [B, n_points, n_input_dims] tensor of points.
      presence: optional per-point mask; per-point log-likelihoods are
        multiplied by it.

    Returns:
      OutputTuple with the batch-mean log-likelihood, the hard winner per
      point, and the posterior/mixing terms. soft_winner and
      soft_winner_pres are placeholder zeros.
    """
    batch_size, n_input_points = x.shape[:2].as_list()

    # One component distribution per vote: [B, 1, n_votes, n_input_dims]
    expanded_votes = tf.expand_dims(_votes, 1)
    expanded_scale = tf.expand_dims(tf.expand_dims(_scales, 1), -1)
    vote_component_pdf = _get_pdf(expanded_votes, expanded_scale)

    # Score every point under every vote; x becomes [B, n_points, 1, d] and
    # broadcasting gives [B, n_points, n_votes, n_input_dims].
    expanded_x = tf.expand_dims(x, 2)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
    # Sum over coordinate dims: [B, n_points, n_votes]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)

    # Append a "dummy" component with constant log-prob -2*log(10) = log(0.01)
    # -> [B, n_points, n_votes + 1]. It is excluded from the argmax below,
    # so it presumably serves as a catch-all for points no vote explains.
    dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
    dummy_vote_log_prob -= 2. * tf.log(10.)
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

    # Mixing logits a_(k,n): log presence of each vote, plus the dummy's
    # constant logit -> [B, n_votes + 1]
    mixing_logits = math_ops.safe_log(_vote_presence_prob)
    dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

    # Normalize over components: log(p_k / sum_j p_j).
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)

    # [B, 1, n_votes + 1] for broadcasting against points.
    expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)

    # Per-point mixture log-likelihood, marginalizing over components:
    # [B, n_points]
    mixture_log_prob_per_component = tf.reduce_logsumexp(
        expanded_mixing_logits + vote_log_prob, 2)

    if presence is not None:
        presence = tf.cast(presence, tf.float32)
        mixture_log_prob_per_component *= presence

    # Sum over points, i.e. a product of per-point likelihoods: [B]
    mixture_log_prob_per_example = tf.reduce_sum(
        mixture_log_prob_per_component, 1)
    # Mean over the batch: []
    mixture_log_prob_per_batch = tf.reduce_mean(mixture_log_prob_per_example)

    # Unnormalized log posterior over components: [B, n_points, n_votes + 1]
    posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob

    # Hard assignment to the best real vote (dummy excluded): [B, n_points]
    winning_vote_idx = tf.argmax(
        posterior_mixing_logits_per_point[:, :, :-1], 2)

    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
    batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

    idx = tf.stack([batch_idx, winning_vote_idx], -1)
    winning_vote = tf.gather_nd(_votes, idx)
    winning_pres = tf.gather_nd(_vote_presence_prob, idx)
    # A vote counts as present when its logit beats the dummy's.
    vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

    # Assumes votes are laid out with _n_votes consecutive votes per capsule,
    # so integer division recovers the capsule index.
    is_from_capsule = winning_vote_idx // _n_votes

    # Softmax over components, then drop the dummy: [B, n_points, n_votes]
    posterior_mixing_probs = tf.nn.softmax(
        posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

    assert winning_vote.shape == x.shape

    return OutputTuple(
        log_prob=mixture_log_prob_per_batch,
        vote_presence=tf.cast(vote_presence, tf.float32),
        winner=winning_vote,
        winner_pres=winning_pres,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
        # TODO(adamrk): this is broken
        soft_winner=tf.zeros_like(winning_vote),
        soft_winner_pres=tf.zeros_like(winning_pres),
        posterior_mixing_probs=posterior_mixing_probs,
    )
def argmax_log_likelihood(x, presence=None):
    """Simplest of the routing schemes: hard argmax assignment only.

    Skips the closed-form product of part likelihoods over all data and
    uses the argmax vote for each part as a proxy instead. No likelihood is
    computed, so `log_prob` is None, and `presence` is accepted but unused.
    Reads `_votes`, `_scales`, `_get_pdf`, `_vote_presence_prob`,
    `_n_votes` and `OutputTuple` from an enclosing/module scope.

    Args:
      x: [B, n_points, n_input_dims] tensor of points.
      presence: unused.

    Returns:
      OutputTuple with log_prob=None, the hard winner per point, and the
      posterior/mixing terms. soft_winner and soft_winner_pres are
      placeholder zeros.
    """
    batch_size, n_input_points = x.shape[:2].as_list()

    # One component distribution per vote: [B, 1, n_votes, n_input_dims]
    expanded_votes = tf.expand_dims(_votes, 1)
    expanded_scale = tf.expand_dims(tf.expand_dims(_scales, 1), -1)
    vote_component_pdf = _get_pdf(expanded_votes, expanded_scale)

    # Score every point under every vote:
    # [B, n_points, n_votes, n_input_dims]
    expanded_x = tf.expand_dims(x, 2)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
    # Sum over coordinate dims: [B, n_points, n_votes]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)

    # Append a "dummy" component with constant log-prob -2*log(10) = log(0.01)
    # -> [B, n_points, n_votes + 1]; excluded from the argmax below.
    dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
    dummy_vote_log_prob -= 2. * tf.log(10.)
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

    # Mixing logits a_(k,n): log presence per vote, plus the dummy's
    # constant logit -> [B, n_votes + 1]
    mixing_logits = math_ops.safe_log(_vote_presence_prob)
    dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

    # Normalize over components: log(p_k / sum_j p_j).
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)

    expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)

    # Unnormalized log posterior over components: [B, n_points, n_votes + 1]
    posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob

    # Hard assignment to the best real vote (dummy excluded): [B, n_points]
    winning_vote_idx = tf.argmax(
        posterior_mixing_logits_per_point[:, :, :-1], 2)

    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
    batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

    idx = tf.stack([batch_idx, winning_vote_idx], -1)
    winning_vote = tf.gather_nd(_votes, idx)
    winning_pres = tf.gather_nd(_vote_presence_prob, idx)
    # A vote counts as present when its logit beats the dummy's.
    vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

    # Assumes votes are laid out with _n_votes consecutive votes per capsule,
    # so integer division recovers the capsule index.
    is_from_capsule = winning_vote_idx // _n_votes

    # Softmax over components, then drop the dummy: [B, n_points, n_votes]
    posterior_mixing_probs = tf.nn.softmax(
        posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

    assert winning_vote.shape == x.shape

    return OutputTuple(
        log_prob=None,
        vote_presence=tf.cast(vote_presence, tf.float32),
        winner=winning_vote,
        winner_pres=winning_pres,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
        # TODO(adamrk): this is broken
        soft_winner=tf.zeros_like(winning_vote),
        soft_winner_pres=tf.zeros_like(winning_pres),
        posterior_mixing_probs=posterior_mixing_probs,
    )
def _build(self, data):
    """Runs the primary encoder, capsule encoder/decoder and reconstructions.

    Args:
      data: input dict; images are extracted with `self._img` and labels
        with `self._label`. An optional 'labeled' entry is forwarded to the
        classification probes.

    Returns:
      The capsule decoder's result object, extended with reconstructions,
      reconstruction losses, sparsity losses, probe metrics (only when a
      label is available) and the weight-decay loss.

    Raises:
      ValueError: if `self._vote_type` or `self._pres_type` is not one of
        'enc', 'soft' or 'hard'.
    """
    input_x = self._img(data, False)
    target_x = self._img(data, prep=self._prep)

    batch_size = int(input_x.shape[0])

    primary_caps = self._primary_encoder(input_x)
    pres = primary_caps.presence

    expanded_pres = tf.expand_dims(pres, -1)
    pose = primary_caps.pose
    input_pose = tf.concat([pose, 1. - expanded_pres], -1)

    input_pres = pres
    if self._stop_grad_caps_inpt:
        input_pose = tf.stop_gradient(input_pose)
        input_pres = tf.stop_gradient(pres)

    target_pose, target_pres = pose, pres
    if self._stop_grad_caps_target:
        target_pose = tf.stop_gradient(target_pose)
        target_pres = tf.stop_gradient(target_pres)

    # skip connection from the img to the higher level capsule
    if primary_caps.feature is not None:
        input_pose = tf.concat([input_pose, primary_caps.feature], -1)

    # try to feed presence as a separate input
    # and if that works, concatenate templates to poses
    # this is necessary for set transformer
    n_templates = int(primary_caps.pose.shape[1])
    templates = self._primary_decoder.make_templates(
        n_templates, primary_caps.feature)

    try:
        if self._feed_templates:
            inpt_templates = templates
            if self._stop_grad_caps_inpt:
                inpt_templates = tf.stop_gradient(inpt_templates)

            if inpt_templates.shape[0] == 1:
                inpt_templates = snt.TileByDim(
                    [0], [batch_size])(inpt_templates)

            inpt_templates = snt.BatchFlatten(2)(inpt_templates)
            pose_with_templates = tf.concat([input_pose, inpt_templates], -1)
        else:
            pose_with_templates = input_pose

        h = self._encoder(pose_with_templates, input_pres)

    except TypeError:
        # Fallback for encoders whose call signature does not take presence.
        h = self._encoder(input_pose)

    res = self._decoder(h, target_pose, target_pres)
    res.primary_presence = primary_caps.presence

    if self._vote_type == 'enc':
        primary_dec_vote = primary_caps.pose
    elif self._vote_type == 'soft':
        primary_dec_vote = res.soft_winner
    elif self._vote_type == 'hard':
        primary_dec_vote = res.winner
    else:
        raise ValueError('Invalid vote_type="{}".'.format(self._vote_type))

    if self._pres_type == 'enc':
        primary_dec_pres = pres
    elif self._pres_type == 'soft':
        primary_dec_pres = res.soft_winner_pres
    elif self._pres_type == 'hard':
        primary_dec_pres = res.winner_pres
    else:
        raise ValueError('Invalid pres_type="{}".'.format(self._pres_type))

    res.bottom_up_rec = self._primary_decoder(
        primary_caps.pose,
        primary_caps.presence,
        template_feature=primary_caps.feature,
        img_embedding=primary_caps.img_embedding)

    res.top_down_rec = self._primary_decoder(
        res.winner,
        primary_caps.presence,
        template_feature=primary_caps.feature,
        img_embedding=primary_caps.img_embedding)

    rec = self._primary_decoder(
        primary_dec_vote,
        primary_dec_pres,
        template_feature=primary_caps.feature,
        img_embedding=primary_caps.img_embedding)

    # Per-capsule reconstructions: the batch axis is tiled res.vote.shape[1]
    # times so that dims 0 and 1 of the votes can be merged into one batch.
    tile = snt.TileByDim([0], [res.vote.shape[1]])
    tiled_presence = tile(primary_caps.presence)

    tiled_feature = primary_caps.feature
    if tiled_feature is not None:
        tiled_feature = tile(tiled_feature)

    tiled_img_embedding = tile(primary_caps.img_embedding)

    res.top_down_per_caps_rec = self._primary_decoder(
        snt.MergeDims(0, 2)(res.vote),
        snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,
        template_feature=tiled_feature,
        img_embedding=tiled_img_embedding)

    res.templates = templates
    res.template_pres = pres
    res.used_templates = rec.transformed_templates

    res.rec_mode = rec.pdf.mode()
    res.rec_mean = rec.pdf.mean()

    res.mse_per_pixel = tf.square(target_x - res.rec_mode)
    res.mse = math_ops.flat_reduce(res.mse_per_pixel)

    res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)
    res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)

    n_points = int(res.posterior_mixing_probs.shape[1])
    mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)

    (res.posterior_within_sparsity_loss,
     res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._posterior_sparsity_loss_type,
         mass_explained_by_capsule / n_points,
         num_classes=self._n_classes)

    (res.prior_within_sparsity_loss,
     res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
         self._prior_sparsity_loss_type,
         res.caps_presence_prob,
         num_classes=self._n_classes,
         within_example_constant=self._prior_within_example_constant)

    label = self._label(data)
    if label is not None:
        res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(
            mass_explained_by_capsule,
            label,
            self._n_classes,
            labeled=data.get('labeled', None))
        res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(
            res.caps_presence_prob,
            label,
            self._n_classes,
            labeled=data.get('labeled', None))
        # The probe accuracies only exist when a label is available.
        res.best_cls_acc = tf.maximum(res.prior_cls_acc, res.posterior_cls_acc)

    res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)

    if self._weight_decay > 0.0:
        # L2 on every trainable variable named like a weight matrix.
        decay_losses_list = []
        for var in tf.trainable_variables():
            if 'w:' in var.name or 'weights:' in var.name:
                decay_losses_list.append(tf.nn.l2_loss(var))
        res.weight_decay_loss = tf.reduce_sum(decay_losses_list)
    else:
        res.weight_decay_loss = 0.0

    return res
def _build(self,
           pose,
           presence=None,
           template_feature=None,
           bg_image=None,
           img_embedding=None):
    """Builds the module.

    Args:
      pose: [B, n_templates, 6] tensor.
      presence: [B, n_templates] tensor. If None, no presence term is added
        to the mixing logits, so every template is treated as present.
      template_feature: [B, n_templates, n_features] tensor; these features
        are used to change templates based on the input, if present.
      bg_image: [B, *output_size] tensor representing the background.
      img_embedding: [B, d] tensor containing image embeddings.

    Returns:
      AttrDict with `raw_templates`, `transformed_templates` and
      `mixing_logits` (both without the background slot), and `pdf`, the
      reconstruction mixture distribution.

    Raises:
      ValueError: if `self._output_pdf_type` is not 'mixture'.
    """
    batch_size, n_templates = pose.shape[:2].as_list()
    templates = self.make_templates(n_templates, template_feature)

    if templates.shape[0] == 1:
        templates = snt.TileByDim([0], [batch_size])(templates)

    # it's easier for me to think in inverse coordinates
    warper = snt.AffineGridWarper(self._output_size, self._template_size)
    warper = warper.inverse()

    grid_coords = snt.BatchApply(warper)(pose)

    resampler = snt.BatchApply(contrib_resampler.resampler)
    transformed_templates = resampler(templates, grid_coords)

    if bg_image is not None:
        bg_image = tf.expand_dims(bg_image, axis=1)
    else:
        # Learned constant background in (0, 1).
        bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
        bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

    # The background is appended as the last mixture component.
    transformed_templates = tf.concat([transformed_templates, bg_image], axis=1)

    if presence is not None:
        # The background is always present.
        presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

    if self._use_alpha_channel:
        template_mixing_logits = snt.TileByDim([0], [batch_size])(
            self._templates_alpha)
        template_mixing_logits = resampler(template_mixing_logits, grid_coords)

        bg_mixing_logit = tf.nn.softplus(
            tf.get_variable('bg_mixing_logit', initializer=[0.]))

        bg_mixing_logit = (
            tf.zeros_like(template_mixing_logits[:, :1]) + bg_mixing_logit)

        template_mixing_logits = tf.concat(
            [template_mixing_logits, bg_mixing_logit], 1)
    else:
        temperature_logit = tf.get_variable('temperature_logit', shape=[1])
        temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
        template_mixing_logits = transformed_templates / temperature

    scale = 1.
    if self._learn_output_scale:
        scale = tf.get_variable('scale', shape=[1])
        scale = tf.nn.softplus(scale) + 1e-4

    if self._output_pdf_type == 'mixture':
        # Previously safe_log(presence) was called even when presence was
        # None (the default). Skip the presence term in that case.
        if presence is not None:
            template_mixing_logits += make_brodcastable(
                math_ops.safe_log(presence), template_mixing_logits)

        rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                           [transformed_templates, scale],
                                           tfd.Normal)
    else:
        raise ValueError('Unknown pdf type: "{}".'.format(
            self._output_pdf_type))

    return AttrDict(
        raw_templates=tf.squeeze(self._templates, 0),
        transformed_templates=transformed_templates[:, :-1],
        mixing_logits=template_mixing_logits[:, :-1],
        pdf=rec_pdf)
def render_constellations(pred_points, capsule_num, canvas_size, gt_points=None, n_caps=2, gt_presence=None, pred_presence=None, caps_presence_prob=None): """Renderes predicted and ground-truth points as gaussian blobs. Args: pred_points: [B, m, 2]. capsule_num: [B, m] tensor indicating which capsule the corresponding point comes from. Plots from different capsules are plotted with different colors. Currently supported values: {0, 1, ..., 11}. canvas_size: tuple of ints gt_points: [B, k, 2]; plots ground-truth points if present. n_caps: integer, number of capsules. gt_presence: [B, k] binary tensor. pred_presence: [B, m] binary tensor. caps_presence_prob: [B, m], a tensor of presence probabilities for caps. Returns: [B, *canvas_size] tensor with plotted points """ # convert coords to be in [0, side_length] pred_points = denormalize_coords(pred_points, canvas_size, rounded=True) # render predicted points batch_size, n_points = pred_points.shape[:2].as_list() capsule_num = tf.to_float(tf.one_hot(capsule_num, depth=n_caps)) capsule_num = tf.reshape(capsule_num, [batch_size, n_points, 1, 1, n_caps, 1]) color = tf.convert_to_tensor(_COLORS[:n_caps]) color = tf.reshape(color, [1, 1, 1, 1, n_caps, 3]) * capsule_num color = tf.reduce_sum(color, -2) color = tf.squeeze(tf.squeeze(color, 3), 2) colored = render_by_scatter(canvas_size, pred_points, color, pred_presence) # Prepare a vertical separator between predicted and gt points. # Separator is composed of all supported colors and also serves as # a legend. 
# [b, h, w, 3] n_colors = _COLORS.shape[0] sep = tf.reshape(tf.convert_to_tensor(_COLORS), [1, 1, n_colors, 3]) n_tiles = int(colored.shape[2]) // n_colors sep = snt.TileByDim([0, 1, 3], [batch_size, 3, n_tiles])(sep) sep = tf.reshape(sep, [batch_size, 3, n_tiles * n_colors, 3]) pad = int(colored.shape[2]) - n_colors * n_tiles pad, r = pad // 2, pad % 2 if caps_presence_prob is not None: n_caps = int(caps_presence_prob.shape[1]) prob_pads = ([0, 0], [0, n_colors - n_caps]) caps_presence_prob = tf.pad(caps_presence_prob, prob_pads) zeros = tf.zeros([batch_size, 3, n_colors, n_tiles, 3], dtype=tf.float32) shape = [batch_size, 1, n_colors, 1, 1] caps_presence_prob = tf.reshape(caps_presence_prob, shape) prob_vals = snt.MergeDims(2, 2)(caps_presence_prob + zeros) sep = tf.concat([sep, tf.ones_like(sep[:, :1]), prob_vals], 1) sep = tf.pad(sep, [(0, 0), (1, 1), (pad, pad + r), (0, 0)], constant_values=1.) # render gt points if gt_points is not None: gt_points = denormalize_coords(gt_points, canvas_size, rounded=True) gt_rendered = render_by_scatter(canvas_size, gt_points, colors=None, gt_presence=gt_presence) colored = tf.where(tf.cast(colored, bool), colored, gt_rendered) colored = tf.concat([gt_rendered, sep, colored], 1) res = tf.clip_by_value(colored, 0., 1.) return res
def _build(self, feature, parent_transform=None, parent_presence=None):
    """Predicts capsule votes from input encodings (votes-only variant).

    Args:
      feature: Tensor of encodings whose leading dim is the batch. When
        `self._n_caps_params` is None it must already be [B, n_caps, ...].
      parent_transform: if given, used instead of the predicted
        capsule-to-canvas transform (CCR).
      parent_presence: accepted for interface compatibility; unused here.

    Returns:
      votes: matmul of the per-vote tiled CCR with the capsule-part
        transforms (CPR).
    """
    # The argument used to be discarded and replaced with a hard-coded
    # tf.ones([200, 32, 256]) debug tensor; use the real input.
    features = feature
    batch_size = features.shape.as_list()[0]
    batch_shape = [batch_size, self._n_caps]

    # Predict capsule and additional params from the input encoding.
    # [B, n_caps, n_caps_dims]
    if self._n_caps_params is not None:
        # Use separate parameters to do predictions for different capsules.
        mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
        raw_caps_params = mlp(features)

        caps_params = tf.reshape(raw_caps_params,
                                 batch_shape + [self._n_caps_params])
    else:
        assert features.shape[:2].as_list() == batch_shape
        caps_params = features

    if self._caps_dropout_rate == 0.0:
        caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
    else:
        pmf = tfd.Bernoulli(1. - self._caps_dropout_rate, dtype=tf.float32)
        caps_exist = pmf.sample(batch_shape + [1])

    caps_params = tf.concat([caps_params, caps_exist], -1)

    output_shapes = (
        [self._n_votes, self._n_transform_params],  # CPR_dynamic
        [1, self._n_transform_params],  # CCR
        [1],  # per-capsule presence
        [self._n_votes],  # per-vote-presence
        [self._n_votes],  # per-vote scale
    )

    splits = [np.prod(i).astype(np.int32) for i in output_shapes]
    n_outputs = sum(splits)

    # we don't use bias in the output layer in order to separate the static
    # and dynamic parts of the CPR
    caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
    all_params = caps_mlp(caps_params)
    all_params = tf.split(all_params, splits, -1)
    res = [
        tf.reshape(i, batch_shape + s)
        for (i, s) in zip(all_params, output_shapes)
    ]

    cpr_dynamic = res[0]

    # Only the CCR is needed here. No bias is added to it, and the presence
    # and scale outputs are not used in this variant.
    ccr = res[1]

    cpr_static = tf.get_variable(
        'cpr_static',
        shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

    # this is for hierarchical
    if parent_transform is None:
        ccr = self._make_transform(ccr)
    else:
        ccr = parent_transform

    if not self._deformations:
        cpr_dynamic = tf.zeros_like(cpr_dynamic)

    cpr = self._make_transform(cpr_dynamic + cpr_static)

    # votes = CCR @ CPR, with CCR repeated for every vote of a capsule.
    ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
    votes = tf.matmul(ccr_per_vote, cpr)
    return votes