def _build(self, data):
  """Encodes the input, decodes it, and attaches sparsity losses.

  Args:
    data: dict of tensors; `self._input_key` selects the input and
      `self._presence_key` (optional) selects a presence tensor.

  Returns:
    The decoder's result object, augmented with posterior/prior
    within- and between-example sparsity losses.
  """
  features = data[self._input_key]

  presence = None
  if self._presence_key:
    presence = data[self._presence_key]

  # Flatten possibly-nested inputs; presence rides along as an extra input.
  encoder_inputs = nest.flatten(features)
  if presence is not None:
    encoder_inputs.append(presence)

  embedding = self._encoder(*encoder_inputs)
  outputs = self._decoder(embedding, *encoder_inputs)

  num_points = int(outputs.posterior_mixing_probs.shape[1])
  # Total posterior mixing mass claimed by each capsule across points.
  per_capsule_mass = tf.reduce_sum(outputs.posterior_mixing_probs, 1)

  posterior_losses = _capsule.sparsity_loss(
      self._posterior_sparsity_loss_type,
      per_capsule_mass / num_points,
      num_classes=self._n_classes)
  outputs.posterior_within_sparsity_loss = posterior_losses[0]
  outputs.posterior_between_sparsity_loss = posterior_losses[1]

  prior_losses = _capsule.sparsity_loss(
      self._prior_sparsity_loss_type,
      outputs.caps_presence_prob,
      num_classes=self._n_classes,
      within_example_constant=self._prior_within_example_constant)
  outputs.prior_within_sparsity_loss = prior_losses[0]
  outputs.prior_between_sparsity_loss = prior_losses[1]

  return outputs
def _build(self, data):
  """Builds the full image-level capsule autoencoder graph.

  Runs the primary (part) encoder on the input image, feeds the resulting
  poses/presences through the object-level encoder/decoder, reconstructs
  the image with the primary decoder, and attaches reconstruction,
  sparsity, classification-probe and weight-decay losses to the result.

  Args:
    data: dict of input tensors; consumed by `self._img` and `self._label`,
      and may contain a 'labeled' mask for semi-supervised probes.

  Returns:
    The decoder's result object augmented with reconstructions and losses.

  Raises:
    ValueError: if `self._vote_type` or `self._pres_type` is not one of
      'enc', 'soft', 'hard'.
  """
  input_x = self._img(data, False)
  target_x = self._img(data, prep=self._prep)
  batch_size = int(input_x.shape[0])

  primary_caps = self._primary_encoder(input_x)
  pres = primary_caps.presence

  # Append (1 - presence) as an extra pose coordinate so the object
  # encoder can condition on how confident each part capsule is.
  expanded_pres = tf.expand_dims(pres, -1)
  pose = primary_caps.pose
  input_pose = tf.concat([pose, 1. - expanded_pres], -1)

  input_pres = pres
  if self._stop_grad_caps_inpt:
    # Block gradients flowing from the object encoder into the part caps.
    input_pose = tf.stop_gradient(input_pose)
    input_pres = tf.stop_gradient(pres)

  target_pose, target_pres = pose, pres
  if self._stop_grad_caps_target:
    target_pose = tf.stop_gradient(target_pose)
    target_pres = tf.stop_gradient(target_pres)

  # skip connection from the img to the higher level capsule
  if primary_caps.feature is not None:
    input_pose = tf.concat([input_pose, primary_caps.feature], -1)

  # try to feed presence as a separate input
  # and if that works, concatenate templates to poses
  # this is necessary for set transformer
  n_templates = int(primary_caps.pose.shape[1])
  templates = self._primary_decoder.make_templates(n_templates,
                                                   primary_caps.feature)

  try:
    if self._feed_templates:
      inpt_templates = templates
      if self._stop_grad_caps_inpt:
        inpt_templates = tf.stop_gradient(inpt_templates)

      # Templates may be shared across the batch (leading dim 1); tile
      # them per-example before flattening into the pose vector.
      if inpt_templates.shape[0] == 1:
        inpt_templates = snt.TileByDim([0], [batch_size])(inpt_templates)
      inpt_templates = snt.BatchFlatten(2)(inpt_templates)
      pose_with_templates = tf.concat([input_pose, inpt_templates], -1)
    else:
      pose_with_templates = input_pose

    h = self._encoder(pose_with_templates, input_pres)
  except TypeError:
    # Encoder does not accept a separate presence input; fall back to
    # pose-only (without templates).
    h = self._encoder(input_pose)

  res = self._decoder(h, target_pose, target_pres)
  res.primary_presence = primary_caps.presence

  # Choose which pose the primary decoder reconstructs from.
  if self._vote_type == 'enc':
    primary_dec_vote = primary_caps.pose
  elif self._vote_type == 'soft':
    primary_dec_vote = res.soft_winner
  elif self._vote_type == 'hard':
    primary_dec_vote = res.winner
  else:
    # NOTE: fixed a stray quote in the original message
    # ('Invalid vote_type="{}"".').
    raise ValueError('Invalid vote_type="{}".'.format(self._vote_type))

  # Choose which presence the primary decoder uses.
  if self._pres_type == 'enc':
    primary_dec_pres = pres
  elif self._pres_type == 'soft':
    primary_dec_pres = res.soft_winner_pres
  elif self._pres_type == 'hard':
    primary_dec_pres = res.winner_pres
  else:
    raise ValueError('Invalid pres_type="{}".'.format(self._pres_type))

  # Diagnostic reconstructions: from encoded part caps (bottom-up) and
  # from the winning object votes (top-down).
  res.bottom_up_rec = self._primary_decoder(
      primary_caps.pose, primary_caps.presence,
      template_feature=primary_caps.feature,
      img_embedding=primary_caps.img_embedding)

  res.top_down_rec = self._primary_decoder(
      res.winner, primary_caps.presence,
      template_feature=primary_caps.feature,
      img_embedding=primary_caps.img_embedding)

  # Reconstruction actually used for the training losses below.
  rec = self._primary_decoder(primary_dec_vote, primary_dec_pres,
                              template_feature=primary_caps.feature,
                              img_embedding=primary_caps.img_embedding)

  # Per-object-capsule reconstructions: tile per-example tensors so each
  # object capsule's votes form their own batch entry.
  tile = snt.TileByDim([0], [res.vote.shape[1]])
  tiled_presence = tile(primary_caps.presence)

  tiled_feature = primary_caps.feature
  if tiled_feature is not None:
    tiled_feature = tile(tiled_feature)

  tiled_img_embedding = tile(primary_caps.img_embedding)

  res.top_down_per_caps_rec = self._primary_decoder(
      snt.MergeDims(0, 2)(res.vote),
      snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,
      template_feature=tiled_feature,
      img_embedding=tiled_img_embedding)

  res.templates = templates
  res.template_pres = pres
  res.used_templates = rec.transformed_templates

  res.rec_mode = rec.pdf.mode()
  res.rec_mean = rec.pdf.mean()

  # Reconstruction losses against the (preprocessed) target image.
  res.mse_per_pixel = tf.square(target_x - res.rec_mode)
  res.mse = math_ops.flat_reduce(res.mse_per_pixel)

  res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)
  res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)

  n_points = int(res.posterior_mixing_probs.shape[1])
  # Posterior mixing mass claimed by each object capsule across points.
  mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)

  (res.posterior_within_sparsity_loss,
   res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
       self._posterior_sparsity_loss_type,
       mass_explained_by_capsule / n_points,
       num_classes=self._n_classes)

  (res.prior_within_sparsity_loss,
   res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
       self._prior_sparsity_loss_type, res.caps_presence_prob,
       num_classes=self._n_classes,
       within_example_constant=self._prior_within_example_constant)

  # Optional linear classification probes on capsule statistics; gradients
  # into the capsules themselves are handled inside `probe`.
  label = self._label(data)
  if label is not None:
    (res.posterior_cls_xe,
     res.posterior_cls_acc) = probe.classification_probe(
         mass_explained_by_capsule, label, self._n_classes,
         labeled=data.get('labeled', None))
    res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(
        res.caps_presence_prob, label, self._n_classes,
        labeled=data.get('labeled', None))
    res.best_cls_acc = tf.maximum(res.prior_cls_acc, res.posterior_cls_acc)

  res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)

  if self._weight_decay > 0.0:
    # L2-penalize only weight matrices (Sonnet names them 'w' or
    # 'weights'), not biases.
    decay_losses = [
        tf.nn.l2_loss(var) for var in tf.trainable_variables()
        if 'w:' in var.name or 'weights:' in var.name
    ]
    res.weight_decay_loss = tf.reduce_sum(decay_losses)
  else:
    res.weight_decay_loss = 0.0

  return res