Ejemplo n.º 1
0
  def _build(self, data):
    """Encodes the input, decodes it, and attaches sparsity losses.

    Args:
      data: dict holding the input under `self._input_key` and, when
        `self._presence_key` is set, a presence tensor under that key.

    Returns:
      The decoder result, augmented with posterior and prior
      within/between-example sparsity losses.
    """
    input_struct = data[self._input_key]

    presence_tensor = None
    if self._presence_key:
      presence_tensor = data[self._presence_key]

    # Flatten any nested input structure into a positional argument list,
    # optionally followed by the presence tensor.
    model_args = nest.flatten(input_struct)
    if presence_tensor is not None:
      model_args.append(presence_tensor)

    latent = self._encoder(*model_args)
    res = self._decoder(latent, *model_args)

    num_points = int(res.posterior_mixing_probs.shape[1])
    # Total mixing mass each capsule claims over all input points.
    capsule_mass = tf.reduce_sum(res.posterior_mixing_probs, 1)

    posterior_losses = _capsule.sparsity_loss(
        self._posterior_sparsity_loss_type,
        capsule_mass / num_points,
        num_classes=self._n_classes)
    res.posterior_within_sparsity_loss = posterior_losses[0]
    res.posterior_between_sparsity_loss = posterior_losses[1]

    prior_losses = _capsule.sparsity_loss(
        self._prior_sparsity_loss_type,
        res.caps_presence_prob,
        num_classes=self._n_classes,
        within_example_constant=self._prior_within_example_constant)
    res.prior_within_sparsity_loss = prior_losses[0]
    res.prior_between_sparsity_loss = prior_losses[1]

    return res
Ejemplo n.º 2
0
    def _build(self, data):
        """Runs the two-level capsule autoencoder forward pass.

        Encodes the image into primary capsules, routes them through the
        object-capsule encoder/decoder, reconstructs the image, and attaches
        reconstruction, sparsity and weight-decay losses plus optional
        classification probes to the result.

        Args:
            data: batch dict; must yield an image via `self._img` and may
                yield a label via `self._label` and a 'labeled' mask.

        Returns:
            The object-decoder result, augmented with reconstructions,
            losses and diagnostic tensors.

        Raises:
            ValueError: if `self._vote_type` or `self._pres_type` is not
                one of 'enc', 'soft' or 'hard'.
        """
        input_x = self._img(data, False)
        target_x = self._img(data, prep=self._prep)
        batch_size = int(input_x.shape[0])

        primary_caps = self._primary_encoder(input_x)
        pres = primary_caps.presence

        # Append (1 - presence) as an extra pose coordinate so the object
        # encoder can see how confident each primary capsule is.
        expanded_pres = tf.expand_dims(pres, -1)
        pose = primary_caps.pose
        input_pose = tf.concat([pose, 1. - expanded_pres], -1)

        input_pres = pres
        if self._stop_grad_caps_inpt:
            input_pose = tf.stop_gradient(input_pose)
            input_pres = tf.stop_gradient(pres)

        target_pose, target_pres = pose, pres
        if self._stop_grad_caps_target:
            target_pose = tf.stop_gradient(target_pose)
            target_pres = tf.stop_gradient(target_pres)

        # skip connection from the img to the higher level capsule
        if primary_caps.feature is not None:
            input_pose = tf.concat([input_pose, primary_caps.feature], -1)

        # try to feed presence as a separate input
        # and if that works, concatenate templates to poses
        # this is necessary for set transformer
        n_templates = int(primary_caps.pose.shape[1])
        templates = self._primary_decoder.make_templates(
            n_templates, primary_caps.feature)

        try:
            if self._feed_templates:
                inpt_templates = templates
                if self._stop_grad_caps_inpt:
                    inpt_templates = tf.stop_gradient(inpt_templates)

                # Templates may be shared across the batch; tile if needed.
                if inpt_templates.shape[0] == 1:
                    inpt_templates = snt.TileByDim(
                        [0], [batch_size])(inpt_templates)
                inpt_templates = snt.BatchFlatten(2)(inpt_templates)
                pose_with_templates = tf.concat([input_pose, inpt_templates],
                                                -1)
            else:
                pose_with_templates = input_pose

            h = self._encoder(pose_with_templates, input_pres)

        except TypeError:
            # Encoder does not accept a separate presence argument.
            h = self._encoder(input_pose)

        res = self._decoder(h, target_pose, target_pres)
        res.primary_presence = primary_caps.presence

        # Choose which poses are fed back into the image decoder.
        if self._vote_type == 'enc':
            primary_dec_vote = primary_caps.pose
        elif self._vote_type == 'soft':
            primary_dec_vote = res.soft_winner
        elif self._vote_type == 'hard':
            primary_dec_vote = res.winner
        else:
            # Fixed malformed message: was 'Invalid vote_type="{}"".'
            raise ValueError('Invalid vote_type="{}".'.format(
                self._vote_type))

        if self._pres_type == 'enc':
            primary_dec_pres = pres
        elif self._pres_type == 'soft':
            primary_dec_pres = res.soft_winner_pres
        elif self._pres_type == 'hard':
            primary_dec_pres = res.winner_pres
        else:
            # Fixed malformed message: was 'Invalid pres_type="{}"".'
            raise ValueError('Invalid pres_type="{}".'.format(
                self._pres_type))

        # Diagnostic reconstructions from the raw encoder poses and from the
        # winning top-down votes.
        res.bottom_up_rec = self._primary_decoder(
            primary_caps.pose,
            primary_caps.presence,
            template_feature=primary_caps.feature,
            img_embedding=primary_caps.img_embedding)

        res.top_down_rec = self._primary_decoder(
            res.winner,
            primary_caps.presence,
            template_feature=primary_caps.feature,
            img_embedding=primary_caps.img_embedding)

        # Reconstruction used for the training losses below.
        rec = self._primary_decoder(primary_dec_vote,
                                    primary_dec_pres,
                                    template_feature=primary_caps.feature,
                                    img_embedding=primary_caps.img_embedding)

        # Per-object-capsule reconstruction: tile batch-level tensors so the
        # (batch, n_caps) leading dims can be merged into one.
        tile = snt.TileByDim([0], [res.vote.shape[1]])
        tiled_presence = tile(primary_caps.presence)

        tiled_feature = primary_caps.feature
        if tiled_feature is not None:
            tiled_feature = tile(tiled_feature)

        tiled_img_embedding = tile(primary_caps.img_embedding)

        res.top_down_per_caps_rec = self._primary_decoder(
            snt.MergeDims(0, 2)(res.vote),
            snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,
            template_feature=tiled_feature,
            img_embedding=tiled_img_embedding)

        res.templates = templates
        res.template_pres = pres
        res.used_templates = rec.transformed_templates

        res.rec_mode = rec.pdf.mode()
        res.rec_mean = rec.pdf.mean()

        res.mse_per_pixel = tf.square(target_x - res.rec_mode)
        res.mse = math_ops.flat_reduce(res.mse_per_pixel)

        res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)
        res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)

        n_points = int(res.posterior_mixing_probs.shape[1])
        mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs,
                                                  1)

        (res.posterior_within_sparsity_loss,
         res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
             self._posterior_sparsity_loss_type,
             mass_explained_by_capsule / n_points,
             num_classes=self._n_classes)

        (res.prior_within_sparsity_loss,
         res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
             self._prior_sparsity_loss_type,
             res.caps_presence_prob,
             num_classes=self._n_classes,
             within_example_constant=self._prior_within_example_constant)

        label = self._label(data)
        if label is not None:
            res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(
                mass_explained_by_capsule,
                label,
                self._n_classes,
                labeled=data.get('labeled', None))
            res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(
                res.caps_presence_prob,
                label,
                self._n_classes,
                labeled=data.get('labeled', None))

            # BUGFIX: best_cls_acc was computed unconditionally, but both
            # accuracies only exist when a label is present; previously this
            # raised AttributeError for unlabeled batches.
            res.best_cls_acc = tf.maximum(res.prior_cls_acc,
                                          res.posterior_cls_acc)

        res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)

        if self._weight_decay > 0.0:
            # L2-penalize weight matrices only (biases excluded by name).
            decay_losses_list = [
                tf.nn.l2_loss(var) for var in tf.trainable_variables()
                if 'w:' in var.name or 'weights:' in var.name
            ]
            res.weight_decay_loss = tf.reduce_sum(decay_losses_list)
        else:
            res.weight_decay_loss = 0.0

        return res