コード例 #1
0
  def __init__(self, mixing_logits, component_stats, component_class,
               presence=None):
    """Stores the mixing logits and instantiates the component distribution.

    Args:
      mixing_logits: tensor [B, k, ...] with k the number of components.
      component_stats: list of tensors of shape [B, k, ...] or broadcastable
        to these shapes; after flattening they are passed positionally to
        `component_class`.
      component_class: callable; returns a distribution object.
      presence: [B, k] tensor of floats in [0, 1] or None. When given, its
        safe log (made broadcastable to `mixing_logits`) is added to the
        mixing logits.
    """
    super(MixtureDistribution, self).__init__()

    if presence is not None:
      # Presence enters the mixture as an additive log-space term.
      log_presence = make_brodcastable(safe_log(presence), mixing_logits)
      mixing_logits += log_presence

    flat_stats = nest.flatten(component_stats)

    self._mixing_logits = mixing_logits
    self._distributions = component_class(*flat_stats)
    # The un-logged presence is kept as passed in (possibly None).
    self._presence = presence
コード例 #2
0
    def _build(self,
               pose,
               presence=None,
               template_feature=None,
               bg_image=None,
               img_embedding=None):
        """Builds the module.

        Args:
            pose: [B, n_templates, 6] tensor.
            presence: [B, n_templates] tensor or None. When None, no presence
                term is added to the mixing logits.
            template_feature: [B, n_templates, n_features] tensor; these
                features are used to change templates based on the input, if
                present.
            bg_image: [B, *output_size] tensor representing the background.
                When None, a constant background built from the `bg_value`
                variable is used instead.
            img_embedding: [B, d] tensor containing image embeddings. Not
                read by this method.

        Returns:
            AttrDict with:
                raw_templates: `self._templates` with axis 0 squeezed out.
                transformed_templates: warped templates, background entry
                    (last along axis 1) removed.
                mixing_logits: mixing logits, background entry removed.
                pdf: `prob.MixtureDistribution` over templates + background.

        Raises:
            ValueError: if `self._output_pdf_type` is not 'mixture'.
        """
        batch_size, n_templates = pose.shape[:2].as_list()
        templates = self.make_templates(n_templates, template_feature)

        # Templates shared across the batch are tiled to the batch size.
        if templates.shape[0] == 1:
            templates = snt.TileByDim([0], [batch_size])(templates)

        # it's easier for me to think in inverse coordinates
        warper = snt.AffineGridWarper(self._output_size, self._template_size)
        warper = warper.inverse()

        grid_coords = snt.BatchApply(warper)(pose)
        resampler = snt.BatchApply(tf.contrib.resampler.resampler)
        transformed_templates = resampler(templates, grid_coords)

        if bg_image is not None:
            bg_image = tf.expand_dims(bg_image, axis=1)
        else:
            # Constant background: sigmoid of a single variable, broadcast to
            # the shape of one transformed template.
            bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
            bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

        # The background is appended as one extra "template" along axis 1.
        transformed_templates = tf.concat([transformed_templates, bg_image],
                                          axis=1)

        if presence is not None:
            # The background entry gets presence 1.
            presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

        if self._use_alpha_channel:
            template_mixing_logits = snt.TileByDim([0], [batch_size])(
                self._templates_alpha)
            template_mixing_logits = resampler(template_mixing_logits,
                                               grid_coords)

            bg_mixing_logit = tf.nn.softplus(
                tf.get_variable('bg_mixing_logit', initializer=[0.]))

            bg_mixing_logit = (
                tf.zeros_like(template_mixing_logits[:, :1]) +
                bg_mixing_logit)

            template_mixing_logits = tf.concat(
                [template_mixing_logits, bg_mixing_logit], 1)

        else:
            # Without an alpha channel the pixel intensities themselves,
            # divided by a positive temperature, act as mixing logits.
            temperature_logit = tf.get_variable('temperature_logit',
                                                shape=[1])
            temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
            template_mixing_logits = transformed_templates / temperature

        scale = 1.
        if self._learn_output_scale:
            scale = tf.get_variable('scale', shape=[1])
            # softplus + epsilon keeps the scale strictly positive.
            scale = tf.nn.softplus(scale) + 1e-4

        if self._output_pdf_type != 'mixture':
            raise ValueError('Unknown pdf type: "{}".'.format(
                self._output_pdf_type))

        # `presence` is optional; only add its log-space term when supplied,
        # instead of passing None into `safe_log`.
        if presence is not None:
            template_mixing_logits += make_brodcastable(
                math_ops.safe_log(presence), template_mixing_logits)

        rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                           [transformed_templates, scale],
                                           tfd.Normal)

        return AttrDict(raw_templates=tf.squeeze(self._templates, 0),
                        transformed_templates=transformed_templates[:, :-1],
                        mixing_logits=template_mixing_logits[:, :-1],
                        pdf=rec_pdf)
コード例 #3
0
ファイル: capsule.py プロジェクト: FrostbiteXSW/SCAE_Attack
    def _build(self, features, parent_transform=None, parent_presence=None):
        """Builds the module.

        Args:
            features: Tensor of encodings of shape [B, n_enc_dims].
            parent_transform: Tuple of (matrix, vector). When given, it is
                used as the capsule transform in place of the one predicted
                from `features`.
            parent_presence: Optional tensor. When given, it is used as the
                per-capsule presence in place of
                `sigmoid(pres_logit_per_caps)`.

        Returns:
            AttrDict with fields `vote`, `scale`, `vote_presence`,
            `pres_logit_per_caps`, `pres_logit_per_vote`,
            `dynamic_weights_l2`, `raw_caps_params` and `raw_caps_features`.
            `raw_caps_params` is the un-reshaped MLP output, or None when
            `self._n_caps_params` is None and `features` are used directly.

        Raises:
            ValueError: if `self._noise_type` is truthy but is neither
                'uniform' nor 'logistic'.
        """
        batch_size = features.shape.as_list()[0]
        batch_shape = [batch_size, self._n_caps]

        # Only produced when capsule params are predicted by the MLP below;
        # bound up-front so the return value is well defined on both paths.
        raw_caps_params = None

        # Predict capsule and additional params from the input encoding.
        # [B, n_caps, n_caps_dims]
        if self._n_caps_params is not None:

            # Use separate parameters to do predictions for different capsules.
            mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
            raw_caps_params = mlp(features)

            caps_params = tf.reshape(raw_caps_params,
                                     batch_shape + [self._n_caps_params])

        else:
            assert features.shape[:2].as_list() == batch_shape
            caps_params = features

        if self._caps_dropout_rate == 0.0:
            caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
        else:
            # `1 - rate` is a keep *probability*; pass it by keyword so it is
            # not taken as the distribution's first positional parameter
            # (`logits`).
            pmf = tfd.Bernoulli(probs=1. - self._caps_dropout_rate,
                                dtype=tf.float32)
            caps_exist = pmf.sample(batch_shape + [1])

        # The existence flag is fed to the MLP as an extra input feature.
        caps_params = tf.concat([caps_params, caps_exist], -1)

        output_shapes = (
            [self._n_votes, self._n_transform_params],  # CPR_dynamic
            [1, self._n_transform_params],  # CCR
            [1],  # per-capsule presence
            [self._n_votes],  # per-vote-presence
            [self._n_votes],  # per-vote scale
        )

        splits = [np.prod(i).astype(np.int32) for i in output_shapes]
        n_outputs = sum(splits)

        # we don't use bias in the output layer in order to separate the static
        # and dynamic parts of the CPR
        caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
        all_params = caps_mlp(caps_params)
        all_params = tf.split(all_params, splits, -1)
        res = [
            tf.reshape(i, batch_shape + s)
            for (i, s) in zip(all_params, output_shapes)
        ]

        cpr_dynamic = res[0]

        # add bias to all remaining outputs
        res = [snt.AddBias()(i) for i in res[1:]]
        ccr, pres_logit_per_caps, pres_logit_per_vote, scale_per_vote = res

        if self._caps_dropout_rate != 0.0:
            # Dropped capsules get the safe log of 0 added to their logit.
            pres_logit_per_caps += math_ops.safe_log(caps_exist)

        cpr_static = tf.get_variable(
            'cpr_static',
            shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

        def add_noise(tensor):
            """Adds noise of type `self._noise_type` to `tensor`."""
            if self._noise_type == 'uniform':
                noise = tf.random.uniform(tensor.shape, minval=-.5,
                                          maxval=.5) * self._noise_scale

            elif self._noise_type == 'logistic':
                pdf = tfd.Logistic(0., self._noise_scale)
                noise = pdf.sample(tensor.shape)

            elif not self._noise_type:
                noise = 0.

            else:
                raise ValueError('Invalid noise type: "{}".'.format(
                    self._noise_type))

            return tensor + noise

        pres_logit_per_caps = add_noise(pres_logit_per_caps)
        pres_logit_per_vote = add_noise(pres_logit_per_vote)

        # this is for hierarchical
        if parent_transform is None:
            ccr = self._make_transform(ccr)
        else:
            ccr = parent_transform

        if not self._deformations:
            # Keep only the static part of the CPR.
            cpr_dynamic = tf.zeros_like(cpr_dynamic)

        cpr = self._make_transform(cpr_dynamic + cpr_static)

        ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
        votes = tf.matmul(ccr_per_vote, cpr)

        if parent_presence is not None:
            pres_per_caps = parent_presence
        else:
            pres_per_caps = tf.nn.sigmoid(pres_logit_per_caps)

        pres_per_vote = pres_per_caps * tf.nn.sigmoid(pres_logit_per_vote)

        if self._learn_vote_scale:
            # for numerical stability
            scale_per_vote = tf.nn.softplus(scale_per_vote + .5) + 1e-2
        else:
            scale_per_vote = tf.zeros_like(scale_per_vote) + 1.

        return AttrDict(
            vote=votes,
            scale=scale_per_vote,
            vote_presence=pres_per_vote,
            pres_logit_per_caps=pres_logit_per_caps,
            pres_logit_per_vote=pres_logit_per_vote,
            dynamic_weights_l2=tf.nn.l2_loss(cpr_dynamic) / batch_size,
            raw_caps_params=raw_caps_params,
            raw_caps_features=features,
        )
コード例 #4
0
ファイル: capsule.py プロジェクト: FrostbiteXSW/SCAE_Attack
    def _build(self, x, presence=None):
        """Scores `x` under the mixture of votes and picks a winner per point.

        Args:
            x: tensor of shape [B, n_input_points, n_input_dims].
            presence: optional tensor; cast to float and multiplied into the
                per-point mixture log-probabilities. Presumably a
                [B, n_input_points] mask -- TODO confirm.

        Returns:
            `self.OutputTuple` with fields `log_prob`, `vote_presence`,
            `winner`, `winner_pres`, `soft_winner`, `soft_winner_pres`,
            `posterior_mixing_probs`, `is_from_capsule`, `mixing_logits` and
            `mixing_log_prob`.
        """
        # x is [B, n_input_points, n_input_dims]
        batch_size, n_input_points = x.shape[:2].as_list()

        # votes and scale have shape [B, n_caps, n_input_points, n_input_dims|1]
        # since scale is a per-caps scalar and we have one vote per capsule
        vote_component_pdf = self._get_pdf(self._votes,
                                           tf.expand_dims(self._scales, -1))

        # expand along caps dimensions -> [B, 1, n_input_points, n_input_dims]
        expanded_x = tf.expand_dims(x, 1)
        vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
        # [B, n_caps, n_input_points]
        vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
        # Extra "dummy" component with a constant log-probability of
        # -2 * log(10), i.e. log(0.01), for every point.
        dummy_vote_log_prob = tf.zeros([batch_size, 1, n_input_points])
        dummy_vote_log_prob -= 2. * tf.log(10.)

        # [B, n_caps + 1, n_input_points]
        vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 1)

        # [B, n_caps, n_input_points]
        mixing_logits = math_ops.safe_log(self._vote_presence_prob)

        # The dummy component's mixing logit is the same constant,
        # -2 * log(10), tiled over all input points.
        dummy_logit = tf.zeros([batch_size, 1, 1]) - 2. * tf.log(10.)
        dummy_logit = snt.TileByDim([2], [n_input_points])(dummy_logit)

        # [B, n_caps + 1, n_input_points]
        mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)
        # Normalised over the component axis (axis 1).
        mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
            mixing_logits, 1, keepdims=True)
        # [B, n_input_points]
        # NOTE(review): this uses the unnormalised `mixing_logits`, not
        # `mixing_log_prob` -- confirm that this is intended.
        mixture_log_prob_per_point = tf.reduce_logsumexp(
            mixing_logits + vote_log_prob, 1)

        if presence is not None:
            presence = tf.to_float(presence)
            mixture_log_prob_per_point *= presence

        # [B,]
        mixture_log_prob_per_example \
         = tf.reduce_sum(mixture_log_prob_per_point, 1)

        # []
        mixture_log_prob_per_batch = tf.reduce_mean(
            mixture_log_prob_per_example)

        # [B, n_caps + 1, n_input_points]
        posterior_mixing_logits_per_point = mixing_logits + vote_log_prob

        # [B, n_input_points]
        # The dummy component (last along axis 1) is excluded from the argmax.
        winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :-1],
                                     1)

        # Build (batch, winning component, point) index triplets for
        # `tf.gather_nd`.
        batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1)
        batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

        point_idx = tf.expand_dims(tf.range(n_input_points, dtype=tf.int64), 0)
        point_idx = snt.TileByDim([0], [batch_size])(point_idx)

        idx = tf.stack([batch_idx, winning_vote_idx, point_idx], -1)
        winning_vote = tf.gather_nd(self._votes, idx)
        winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
        # True where a component's mixing logit exceeds the dummy's logit.
        vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:,
                                                                        -1:])

        # Integer-divides the winning index by `self._n_votes`; presumably
        # this maps a vote index to its capsule index -- TODO confirm. (The
        # original comment read "the first four votes belong to the square".)
        is_from_capsule = winning_vote_idx // self._n_votes

        # Posterior over all components, including the dummy.
        posterior_mixing_probs = tf.nn.softmax(
            posterior_mixing_logits_per_point, 1)

        # A variable shaped like `self._votes[:1, :1]` stands in as the
        # dummy component's vote; it is tiled over the batch.
        dummy_vote = tf.get_variable('dummy_vote',
                                     shape=self._votes[:1, :1].shape)
        dummy_vote = snt.TileByDim([0], [batch_size])(dummy_vote)
        # The dummy component has zero presence.
        dummy_pres = tf.zeros([batch_size, 1, n_input_points])

        votes = tf.concat((self._votes, dummy_vote), 1)
        pres = tf.concat([self._vote_presence_prob, dummy_pres], 1)

        # Posterior-weighted sums over the component axis.
        soft_winner = tf.reduce_sum(
            tf.expand_dims(posterior_mixing_probs, -1) * votes, 1)
        soft_winner_pres = tf.reduce_sum(posterior_mixing_probs * pres, 1)

        # Drop the dummy component and swap the last two axes.
        posterior_mixing_probs = tf.transpose(posterior_mixing_probs[:, :-1],
                                              (0, 2, 1))

        assert winning_vote.shape == x.shape

        return self.OutputTuple(
            log_prob=mixture_log_prob_per_batch,
            vote_presence=tf.to_float(vote_presence),
            winner=winning_vote,
            winner_pres=winning_pres,
            soft_winner=soft_winner,
            soft_winner_pres=soft_winner_pres,
            posterior_mixing_probs=posterior_mixing_probs,
            is_from_capsule=is_from_capsule,
            mixing_logits=mixing_logits,
            mixing_log_prob=mixing_log_prob,
        )
コード例 #5
0
ファイル: capsule.py プロジェクト: FrostbiteXSW/SCAE_Attack
    def _build(self, x, presence=None):
        """Scores every input point under one mixture made of all votes.

        Args:
            x: input points; the first two axes are read as
                (batch, n_input_points). Presumably
                [B, n_input_points, n_input_dims] -- TODO confirm.
            presence: optional tensor; cast to float and multiplied into the
                per-point mixture log-probabilities. Presumably a
                [B, n_input_points] mask -- TODO confirm.

        Returns:
            `self.OutputTuple` with fields `log_prob`, `vote_presence`,
            `winner`, `winner_pres`, `is_from_capsule`, `mixing_logits`,
            `mixing_log_prob`, `soft_winner`, `soft_winner_pres` and
            `posterior_mixing_probs`. `soft_winner` and `soft_winner_pres`
            are all-zero placeholders.
        """

        batch_size, n_input_points = x.shape[:2].as_list()

        # we don't know what order the initial points came in, so we need to create
        # a big mixture of all votes for every input point
        # [B, 1, n_votes, n_input_dims]
        expanded_votes = tf.expand_dims(self._votes, 1)
        expanded_scale = tf.expand_dims(tf.expand_dims(self._scales, 1), -1)
        vote_component_pdf = self._get_pdf(expanded_votes, expanded_scale)

        # Insert a vote axis into x (axis 2) so it broadcasts against the
        # votes; presumably [B, n_points, 1, n_input_dims] -- TODO confirm.
        expanded_x = tf.expand_dims(x, 2)
        vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
        # [B, n_points, n_votes]
        vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
        # Extra "dummy" component with a constant log-probability of
        # -2 * log(10), i.e. log(0.01), appended along the vote axis.
        dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
        dummy_vote_log_prob -= 2. * tf.log(10.)
        vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

        # Rank 2 (it is concatenated with the [B, 1] dummy logit below);
        # presumably [B, n_votes] -- TODO confirm.
        mixing_logits = math_ops.safe_log(self._vote_presence_prob)

        dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
        mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

        # Normalised over the vote axis (axis 1).
        mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
            mixing_logits, 1, keepdims=True)

        # Despite its name this holds the normalised `mixing_log_prob`, with
        # a points axis inserted for broadcasting.
        expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
        # Reduced over the vote axis (axis 2), leaving one value per
        # (batch, point) despite the "per_component" name.
        mixture_log_prob_per_component \
         = tf.reduce_logsumexp(expanded_mixing_logits + vote_log_prob, 2)

        if presence is not None:
            presence = tf.to_float(presence)
            mixture_log_prob_per_component *= presence

        # Sum over points, then average over the batch.
        mixture_log_prob_per_example \
         = tf.reduce_sum(mixture_log_prob_per_component, 1)

        mixture_log_prob_per_batch = tf.reduce_mean(
            mixture_log_prob_per_example)

        # [B, n_points, n_votes]
        posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob
        # [B, n_points]
        # The dummy component (last along axis 2) is excluded from the argmax.
        winning_vote_idx = tf.argmax(
            posterior_mixing_logits_per_point[:, :, :-1], 2)

        # Build (batch, winning vote) index pairs for `tf.gather_nd`.
        batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
        batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

        idx = tf.stack([batch_idx, winning_vote_idx], -1)
        winning_vote = tf.gather_nd(self._votes, idx)
        winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
        # True where a vote's mixing logit exceeds the dummy's logit.
        vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:,
                                                                        -1:])

        # Integer-divides the winning index by `self._n_votes`; presumably
        # this maps a flat vote index to its capsule index -- TODO confirm.
        # (The original comment read "the first four votes belong to the
        # square".)
        is_from_capsule = winning_vote_idx // self._n_votes

        # Softmax over all components including the dummy; the dummy's
        # column is then dropped.
        posterior_mixing_probs = tf.nn.softmax(
            posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

        assert winning_vote.shape == x.shape

        return self.OutputTuple(
            log_prob=mixture_log_prob_per_batch,
            vote_presence=tf.to_float(vote_presence),
            winner=winning_vote,
            winner_pres=winning_pres,
            is_from_capsule=is_from_capsule,
            mixing_logits=mixing_logits,
            mixing_log_prob=mixing_log_prob,
            # TODO(adamrk): this is broken
            soft_winner=tf.zeros_like(winning_vote),
            soft_winner_pres=tf.zeros_like(winning_pres),
            posterior_mixing_probs=posterior_mixing_probs,
        )