Example #1
        def map_func(data):
            """Replicates data if necessary."""
            data = dict(data)

            if n_replicas > 1:
                tile_by_batch = snt.TileByDim([0], [n_replicas])
                data = {k: tile_by_batch(v) for k, v in data.items()}

            if transforms is not None:
                img = data['image']

                for k, transform in transforms.items():
                    data[k] = transform(img)

            return data
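A minimal sketch of the replication step above, assuming Sonnet v1 and TF1 graph mode: snt.TileByDim([0], [n]) tiles only the batch dimension, i.e. it behaves like tf.tile with a multiples vector that is 1 everywhere except dim 0.

import numpy as np
import sonnet as snt
import tensorflow as tf  # TF1.x

x = tf.constant(np.arange(6).reshape(2, 3), dtype=tf.float32)  # [B=2, 3]

tiled_a = snt.TileByDim([0], [4])(x)    # [8, 3]
tiled_b = tf.tile(x, multiples=[4, 1])  # expected to match

with tf.Session() as sess:
    a, b = sess.run([tiled_a, tiled_b])
    assert a.shape == (8, 3) and (a == b).all()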
Example #2
  def _self_attention(self, x, presence):

    head_before = MultiHeadQKVAttention(self._n_heads)
    # share weights between the two attention stages
    head_after = head_before

    inducing_points = tf.get_variable(
        'inducing_points', shape=[1, self._n_inducing_points,
                                  int(x.shape[-1])])

    inducing_points = snt.TileByDim([0], [int(x.shape[0])])(inducing_points)

    z = head_before(inducing_points, x, x)
    y = head_after(x, z, z)
    return y
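A hedged sketch of the shape flow in _self_attention above, with a single-head dot-product attention standing in for MultiHeadQKVAttention (an assumption; the repo's module is multi-headed):

import tensorflow as tf  # TF1.x

def qkv_attention(q, k, v):
    """q: [B, n_q, d], k/v: [B, n_k, d] -> [B, n_q, d]."""
    d = tf.cast(tf.shape(q)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d)
    return tf.matmul(tf.nn.softmax(logits, axis=-1), v)

x = tf.zeros([8, 100, 32])                       # [B, n_points, d]
ind = tf.get_variable('ind', shape=[1, 16, 32])  # [1, n_inducing, d]
ind = tf.tile(ind, [8, 1, 1])                    # tile to the batch size

z = qkv_attention(ind, x, x)  # [B, 16, 32]: compress points into inducing points
y = qkv_attention(x, z, z)    # [B, 100, 32]: project back to every point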
Example #3
    def _build(self, x):
        """Applies the module.

    Args:
      x: tensor of shape [B, k, d].

    Returns:
      Tensor of shape [B, k, n_units].
    """

        shape = x.shape.as_list()  # [batch_size, n_inputs, n_dims]

        if 'w' not in self.initializers:
            stddev = 1 / math.sqrt(shape[-1])
            self.initializers['w'] = tf.truncated_normal_initializer(
                stddev=stddev)

        weights_shape = shape + [self._n_units]
        tiles = []
        for i in self._tile_dims:
            tiles.append(weights_shape[i])
            weights_shape[i] = 1

        weights = tf.get_variable('weights',
                                  shape=weights_shape,
                                  initializer=self._init('w'))

        weights = snt.TileByDim(self._tile_dims, tiles)(weights)

        x = tf.expand_dims(x, -2)
        y = tf.matmul(x, weights)
        y = tf.squeeze(y, -2)

        if self._use_bias:
            if 'b' not in self.initializers:
                self.initializers['b'] = tf.zeros_initializer()

            init = dict(b=self._init('b'))
            bias_dims = [
                i for i in range(len(shape)) if i not in self._tile_dims
            ]
            add_bias = snt.AddBias(bias_dims=bias_dims, initializers=init)
            y = add_bias(y)

        return y
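A standalone sketch of the tiled-weight matmul used in this _build (sizes are assumptions): each of the k input slots gets its own [d, n_units] weight matrix, applied in a single batched matmul.

import math
import tensorflow as tf  # TF1.x

B, k, d, n_units = 8, 5, 16, 32
x = tf.zeros([B, k, d])

w = tf.get_variable(
    'w', shape=[1, k, d, n_units],
    initializer=tf.truncated_normal_initializer(stddev=1. / math.sqrt(d)))
w = tf.tile(w, [B, 1, 1, 1])             # [B, k, d, n_units]

y = tf.matmul(tf.expand_dims(x, -2), w)  # [B, k, 1, d] @ [B, k, d, n_units]
y = tf.squeeze(y, -2)                    # [B, k, n_units]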
Example #4
def render_by_scatter(size, points, colors=None, gt_presence=None):
    """Renders point by using tf.scatter_nd."""

    if colors is None:
        colors = tf.ones(points.shape[:-1].as_list() + [3], dtype=tf.float32)

    if gt_presence is not None:
        colors *= tf.cast(tf.expand_dims(gt_presence, -1), colors.dtype)

    batch_size, n_points = points.shape[:-1].as_list()
    shape = [batch_size] + list(size) + [3]
    batch_idx = tf.reshape(tf.range(batch_size), [batch_size, 1, 1])
    batch_idx = snt.TileByDim([1], [n_points])(batch_idx)
    idx = tf.concat([batch_idx, tf.cast(points, tf.int32)], -1)

    return tf.scatter_nd(idx, colors, shape)
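A toy sketch of the index construction above (sizes and integer (y, x) points are assumptions): each point is paired with its batch index so tf.scatter_nd can write colors into [b, y, x].

import tensorflow as tf  # TF1.x

batch_size, n_points = 2, 3
points = tf.constant([[[0, 1], [2, 2], [3, 0]],
                      [[1, 1], [0, 3], [2, 2]]], tf.int32)  # [B, n, 2]

batch_idx = tf.reshape(tf.range(batch_size), [batch_size, 1, 1])
batch_idx = tf.tile(batch_idx, [1, n_points, 1])  # [B, n, 1]
idx = tf.concat([batch_idx, points], -1)          # [B, n, 3] = (b, y, x)

colors = tf.ones([batch_size, n_points, 3])
canvas = tf.scatter_nd(idx, colors, [batch_size, 4, 4, 3])  # [B, H, W, 3]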
Example #5
    def clevr_veggies_map_func(index, image, label):
        data = {'index': index, 'image': image, 'label': label}

        if n_replicas > 1:
            tile_by_batch = snt.TileByDim([0], [n_replicas])
            data = {k: tile_by_batch(v) for k, v in data.items()}

        if transforms is not None:
            img = data['image']

            for k, transform in transforms.items():
                data[k] = transform(img)

        return data
Example #6
  def _build(self, x, presence=None):

    batch_size = int(x.shape[0])
    h = snt.BatchApply(snt.Linear(self._n_dims))(x)

    args = [self._n_heads, self._layer_norm, self._dropout_rate]
    klass = SelfAttention

    if self._n_inducing_points > 0:
      args = [self._n_inducing_points] + args
      klass = InducedSelfAttention

    for _ in range(self._n_layers):
      h = klass(*args)(h, presence)

    z = snt.BatchApply(snt.Linear(self._n_output_dims))(h)

    inducing_points = tf.get_variable(
        'inducing_points', shape=[1, self._n_outputs, self._n_output_dims])
    inducing_points = snt.TileByDim([0], [batch_size])(inducing_points)

    return MultiHeadQKVAttention(self._n_heads)(inducing_points, z, z, presence)
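A brief sketch of the snt.BatchApply pattern used above (Sonnet v1 assumed): snt.Linear expects a rank-2 input, so BatchApply merges the leading dims, applies the module, and splits them back.

import sonnet as snt
import tensorflow as tf  # TF1.x

x = tf.zeros([8, 100, 32])             # [B, k, d]
h = snt.BatchApply(snt.Linear(64))(x)  # [B, k, 64]

# roughly equivalent to (with its own, separate weights):
flat = tf.reshape(x, [-1, 32])         # [B*k, d]
h2 = tf.reshape(snt.Linear(64)(flat), [8, 100, 64])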
Example #7
    def _build(self, features, parent_transform=None, parent_presence=None):
        """Builds the module.

    Args:
      features: Tensor of encodings of shape [B, n_enc_dims].
      parent_transform: Tuple of (matrix, vector).
      parent_presence: optional tensor of parent capsule presences.

    Returns:
      An AttrDict with votes, scales, presences, and auxiliary outputs.
    """
        batch_size = features.shape.as_list()[0]
        batch_shape = [batch_size, self._n_caps]

        # Predict capsule and additional params from the input encoding.
        # [B, n_caps, n_caps_dims]
        if self._n_caps_params is not None:

            # Use separate parameters to do predictions for different capsules.
            mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
            raw_caps_params = mlp(features)

            caps_params = tf.reshape(raw_caps_params,
                                     batch_shape + [self._n_caps_params])

        else:
            assert features.shape[:2].as_list() == batch_shape
            caps_params = features

        if self._caps_dropout_rate == 0.0:
            caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
        else:
            pmf = tfd.Bernoulli(probs=1. - self._caps_dropout_rate,
                                dtype=tf.float32)
            caps_exist = pmf.sample(batch_shape + [1])

        caps_params = tf.concat([caps_params, caps_exist], -1)

        output_shapes = (
            [self._n_votes, self._n_transform_params],  # CPR_dynamic
            [1, self._n_transform_params],  # CCR
            [1],  # per-capsule presence
            [self._n_votes],  # per-vote-presence
            [self._n_votes],  # per-vote scale
        )

        splits = [np.prod(i).astype(np.int32) for i in output_shapes]
        n_outputs = sum(splits)

        # we don't use bias in the output layer in order to separate the static
        # and dynamic parts of the CPR
        caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
        all_params = caps_mlp(caps_params)
        all_params = tf.split(all_params, splits, -1)
        res = [
            tf.reshape(i, batch_shape + s)
            for (i, s) in zip(all_params, output_shapes)
        ]

        cpr_dynamic = res[0]

        # add bias to all remaining outputs
        res = [snt.AddBias()(i) for i in res[1:]]
        ccr, pres_logit_per_caps, pres_logit_per_vote, scale_per_vote = res

        if self._caps_dropout_rate != 0.0:
            pres_logit_per_caps += math_ops.safe_log(caps_exist)

        cpr_static = tf.get_variable(
            'cpr_static',
            shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

        def add_noise(tensor):
            """Adds noise to tensors."""
            if self._noise_type == 'uniform':
                noise = tf.random.uniform(tensor.shape, minval=-.5,
                                          maxval=.5) * self._noise_scale

            elif self._noise_type == 'logistic':
                pdf = tfd.Logistic(0., self._noise_scale)
                noise = pdf.sample(tensor.shape)

            elif not self._noise_type:
                noise = 0.

            else:
                raise ValueError('Invalid noise type: "{}".'.format(
                    self._noise_type))

            return tensor + noise

        pres_logit_per_caps = add_noise(pres_logit_per_caps)
        pres_logit_per_vote = add_noise(pres_logit_per_vote)

        # this is for hierarchical
        if parent_transform is None:
            ccr = self._make_transform(ccr)
        else:
            ccr = parent_transform

        if not self._deformations:
            cpr_dynamic = tf.zeros_like(cpr_dynamic)

        cpr = self._make_transform(cpr_dynamic + cpr_static)

        ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
        votes = tf.matmul(ccr_per_vote, cpr)

        if parent_presence is not None:
            pres_per_caps = parent_presence
        else:
            pres_per_caps = tf.nn.sigmoid(pres_logit_per_caps)

        pres_per_vote = pres_per_caps * tf.nn.sigmoid(pres_logit_per_vote)

        if self._learn_vote_scale:
            # for numerical stability
            scale_per_vote = tf.nn.softplus(scale_per_vote + .5) + 1e-2
        else:
            scale_per_vote = tf.zeros_like(scale_per_vote) + 1.

        return AttrDict(
            vote=votes,
            scale=scale_per_vote,
            vote_presence=pres_per_vote,
            pres_logit_per_caps=pres_logit_per_caps,
            pres_logit_per_vote=pres_logit_per_vote,
            dynamic_weights_l2=tf.nn.l2_loss(cpr_dynamic) / batch_size,
            raw_caps_params=raw_caps_params,
            raw_caps_features=features,
        )
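The split-and-reshape scheme above, in isolation (made-up sizes): one dense output vector per capsule is carved into several differently shaped predictions.

import numpy as np
import tensorflow as tf  # TF1.x

batch_shape = [8, 10]  # [B, n_caps]
output_shapes = ([6, 4], [1, 4], [1], [6], [6])
splits = [int(np.prod(s)) for s in output_shapes]

all_params = tf.zeros(batch_shape + [sum(splits)])  # stand-in for the MLP output
parts = tf.split(all_params, splits, -1)
res = [tf.reshape(p, batch_shape + list(s))
       for p, s in zip(parts, output_shapes)]
# res[0]: [8, 10, 6, 4] (CPR_dynamic), res[2]: [8, 10, 1] (presence), ...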
Example #8
    def _build(self, x, presence=None):

        # x is [B, n_input_points, n_input_dims]
        batch_size, n_input_points = x.shape[:2].as_list()

        # votes and scale have shape [B, n_caps, n_input_points, n_input_dims|1]
        # since scale is a per-caps scalar and we have one vote per capsule
        vote_component_pdf = self._get_pdf(self._votes,
                                           tf.expand_dims(self._scales, -1))

        # expand along caps dimensions -> [B, 1, n_input_points, n_input_dims]
        expanded_x = tf.expand_dims(x, 1)
        vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
        # [B, n_caps, n_input_points]
        vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
        dummy_vote_log_prob = tf.zeros([batch_size, 1, n_input_points])
        dummy_vote_log_prob -= 2. * tf.log(10.)

        # [B, n_caps + 1, n_input_points]
        vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 1)

        # [B, n_caps, n_input_points]
        mixing_logits = math_ops.safe_log(self._vote_presence_prob)

        dummy_logit = tf.zeros([batch_size, 1, 1]) - 2. * tf.log(10.)
        dummy_logit = snt.TileByDim([2], [n_input_points])(dummy_logit)

        # [B, n_caps + 1, n_input_points]
        mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)
        mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
            mixing_logits, 1, keepdims=True)
        # [B, n_input_points]
        mixture_log_prob_per_point = tf.reduce_logsumexp(
            mixing_logits + vote_log_prob, 1)

        if presence is not None:
            presence = tf.cast(presence, tf.float32)
            mixture_log_prob_per_point *= presence

        # [B,]
        mixture_log_prob_per_example = tf.reduce_sum(
            mixture_log_prob_per_point, 1)

        # []
        mixture_log_prob_per_batch = tf.reduce_mean(
            mixture_log_prob_per_example)

        # [B, n_caps + 1, n_input_points]
        posterior_mixing_logits_per_point = mixing_logits + vote_log_prob

        # [B, n_input_points]
        winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :-1],
                                     1)

        batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1)
        batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

        point_idx = tf.expand_dims(tf.range(n_input_points, dtype=tf.int64), 0)
        point_idx = snt.TileByDim([0], [batch_size])(point_idx)

        idx = tf.stack([batch_idx, winning_vote_idx, point_idx], -1)
        winning_vote = tf.gather_nd(self._votes, idx)
        winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
        vote_presence = tf.greater(mixing_logits[:, :-1],
                                   mixing_logits[:, -1:])

        # the first four votes belong to the square
        is_from_capsule = winning_vote_idx // self._n_votes

        posterior_mixing_probs = tf.nn.softmax(
            posterior_mixing_logits_per_point, 1)

        dummy_vote = tf.get_variable('dummy_vote',
                                     shape=self._votes[:1, :1].shape)
        dummy_vote = snt.TileByDim([0], [batch_size])(dummy_vote)
        dummy_pres = tf.zeros([batch_size, 1, n_input_points])

        votes = tf.concat((self._votes, dummy_vote), 1)
        pres = tf.concat([self._vote_presence_prob, dummy_pres], 1)

        soft_winner = tf.reduce_sum(
            tf.expand_dims(posterior_mixing_probs, -1) * votes, 1)
        soft_winner_pres = tf.reduce_sum(posterior_mixing_probs * pres, 1)

        posterior_mixing_probs = tf.transpose(posterior_mixing_probs[:, :-1],
                                              (0, 2, 1))

        assert winning_vote.shape == x.shape

        return self.OutputTuple(
            log_prob=mixture_log_prob_per_batch,
            vote_presence=tf.cast(vote_presence, tf.float32),
            winner=winning_vote,
            winner_pres=winning_pres,
            soft_winner=soft_winner,
            soft_winner_pres=soft_winner_pres,
            posterior_mixing_probs=posterior_mixing_probs,
            is_from_capsule=is_from_capsule,
            mixing_logits=mixing_logits,
            mixing_log_prob=mixing_log_prob,
        )
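A hedged sketch (toy shapes) of the dummy-component trick above: a constant low logit, -2*log(10), i.e. an unnormalized density of 0.01, is appended so the mixture can always explain stray points.

import math
import tensorflow as tf  # TF1.x

B, n_caps, n_points = 4, 10, 20
vote_log_prob = tf.zeros([B, n_caps, n_points])  # stand-in for log p(x | vote)
mixing_logits = tf.zeros([B, n_caps, n_points])  # stand-in presence logits

dummy = tf.fill([B, 1, n_points], -2. * math.log(10.))
vote_log_prob = tf.concat([vote_log_prob, dummy], 1)
mixing_logits = tf.concat([mixing_logits, dummy], 1)

mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
    mixing_logits, 1, keepdims=True)
log_prob_per_point = tf.reduce_logsumexp(
    mixing_log_prob + vote_log_prob, 1)  # [B, n_points]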
Example #9
    def _build(self, x, presence=None):

        batch_size, n_input_points = x.shape[:2].as_list()

        # we don't know what order the initial points came in, so we need to create
        # a big mixture of all votes for every input point
        # [B, 1, n_votes, n_input_dims]
        expanded_votes = tf.expand_dims(self._votes, 1)
        expanded_scale = tf.expand_dims(tf.expand_dims(self._scales, 1), -1)
        vote_component_pdf = self._get_pdf(expanded_votes, expanded_scale)

        # [B, n_points, n_caps x n_votes, n_input_dims]
        expanded_x = tf.expand_dims(x, 2)
        vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
        # [B, n_points, n_votes]
        vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
        dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
        dummy_vote_log_prob -= 2. * tf.log(10.)
        vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

        # [B, n_points, n_votes]
        mixing_logits = math_ops.safe_log(self._vote_presence_prob)

        dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
        mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

        mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
            mixing_logits, 1, keepdims=True)

        expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
        mixture_log_prob_per_component = tf.reduce_logsumexp(
            expanded_mixing_logits + vote_log_prob, 2)

        if presence is not None:
            presence = tf.cast(presence, tf.float32)
            mixture_log_prob_per_component *= presence

        mixture_log_prob_per_example = tf.reduce_sum(
            mixture_log_prob_per_component, 1)

        mixture_log_prob_per_batch = tf.reduce_mean(
            mixture_log_prob_per_example)

        # [B, n_points, n_votes]
        posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob
        # [B, n_points]
        winning_vote_idx = tf.argmax(
            posterior_mixing_logits_per_point[:, :, :-1], 2)

        batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
        batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

        idx = tf.stack([batch_idx, winning_vote_idx], -1)
        winning_vote = tf.gather_nd(self._votes, idx)
        winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
        vote_presence = tf.greater(mixing_logits[:, :-1],
                                   mixing_logits[:, -1:])

        # the first four votes belong to the square
        is_from_capsule = winning_vote_idx // self._n_votes

        posterior_mixing_probs = tf.nn.softmax(
            posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

        assert winning_vote.shape == x.shape

        return self.OutputTuple(
            log_prob=mixture_log_prob_per_batch,
            vote_presence=tf.cast(vote_presence, tf.float32),
            winner=winning_vote,
            winner_pres=winning_pres,
            is_from_capsule=is_from_capsule,
            mixing_logits=mixing_logits,
            mixing_log_prob=mixing_log_prob,
            # TODO(adamrk): this is broken
            soft_winner=tf.zeros_like(winning_vote),
            soft_winner_pres=tf.zeros_like(winning_pres),
            posterior_mixing_probs=posterior_mixing_probs,
        )
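A toy sketch of the winner lookup above: each point's argmax vote index is paired with its batch index, and tf.gather_nd pulls out the winning vote.

import tensorflow as tf  # TF1.x

B, n_votes, n_points, d = 2, 5, 7, 2
votes = tf.zeros([B, n_votes, d])
posterior = tf.zeros([B, n_points, n_votes])  # stand-in posterior logits

winning_vote_idx = tf.argmax(posterior, 2)    # [B, n_points], int64
batch_idx = tf.expand_dims(tf.range(B, dtype=tf.int64), -1)
batch_idx = tf.tile(batch_idx, [1, n_points])

idx = tf.stack([batch_idx, winning_vote_idx], -1)  # [B, n_points, 2]
winning_vote = tf.gather_nd(votes, idx)            # [B, n_points, d]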
Example #10
def naive_log_likelihood(x, presence=None):
    """Implementation from original repo ripped wholesale"""

    batch_size, n_input_points = x.shape[:2].as_list()

    # Generate gaussian mixture pdfs...
    # [B, 1, n_votes, n_input_dims]
    expanded_votes = tf.expand_dims(_votes, 1)
    expanded_scale = tf.expand_dims(tf.expand_dims(_scales, 1), -1)
    vote_component_pdf = _get_pdf(expanded_votes, expanded_scale)

    # For each part, evaluates all capsule, vote mixture likelihoods
    # [B, n_points, n_caps x n_votes, n_input_dims]
    expanded_x = tf.expand_dims(x, 2)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)

    # Compress the mixture likelihood across part dimensions (i.e. the 2D point)
    # [B, n_points, n_caps x n_votes]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
    dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
    dummy_vote_log_prob -= 2. * tf.log(10.)
    # appending an extra dummy component [B, n_points, 1] at the end. WHY?
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

    # [B, n_points, n_caps x n_votes]
    # CONDITIONAL LOGIT a_(k,n)
    mixing_logits = math_ops.safe_log(_vote_presence_prob)

    dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

    #
    # Following seems relevant only towards compressing ll for loss.
    # REDUNDANCY
    #

    # mixing_logits -> presence (a)
    # vote_log_prob -> Gaussian value (one per vote) for each coordinate

    # BAD -> vote presence / summed vote presence
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)

    # BAD -> mixing presence (above) * each vote gaussian prob
    expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
    # Reduce to loglikelihood given k,n combination (capsule, vote)
    mixture_log_prob_per_component = tf.reduce_logsumexp(
        expanded_mixing_logits + vote_log_prob, 2)

    if presence is not None:
        presence = tf.to_float(presence)
        mixture_log_prob_per_component *= presence

    # Reduce votes to single capsule
    # ^ Misleading, reducing across all parts, multiplying log
    # likelihoods for each part _wrt all capsules_.
    mixture_log_prob_per_example = tf.reduce_sum(
        mixture_log_prob_per_component, 1)

    # Same as above but across all compressed part likelihoods in a batch.
    mixture_log_prob_per_batch = tf.reduce_mean(mixture_log_prob_per_example)

    #
    # Back from compression to argmax (routing to proper k)
    #

    # [B, n_points, n_votes]
    posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob
    # [B, n_points]
    winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :, :-1],
                                 2)

    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
    batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

    idx = tf.stack([batch_idx, winning_vote_idx], -1)
    winning_vote = tf.gather_nd(_votes, idx)
    winning_pres = tf.gather_nd(_vote_presence_prob, idx)
    vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

    # the first four votes belong to the square
    # Just assuming the votes are ordered by capsule...
    is_from_capsule = winning_vote_idx // _n_votes

    posterior_mixing_probs = tf.nn.softmax(posterior_mixing_logits_per_point,
                                           -1)[Ellipsis, :-1]

    assert winning_vote.shape == x.shape

    return OutputTuple(
        log_prob=mixture_log_prob_per_batch,
        vote_presence=tf.to_float(vote_presence),
        winner=winning_vote,
        winner_pres=winning_pres,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
        # TODO(adamrk): this is broken
        soft_winner=tf.zeros_like(winning_vote),
        soft_winner_pres=tf.zeros_like(winning_pres),
        posterior_mixing_probs=posterior_mixing_probs,
    )
Example #11
def argmax_log_likelihood(x, presence=None):
    """Most simple of the optimization schemes.

    Skip the product of closeform probability of part given _all_ data. Rather
    use the value at the argmax as a proxy for each part.
    """

    batch_size, n_input_points = x.shape[:2].as_list()

    # Generate gaussian mixture pdfs...
    # [B, 1, n_votes, n_input_dims]
    expanded_votes = tf.expand_dims(_votes, 1)
    expanded_scale = tf.expand_dims(tf.expand_dims(_scales, 1), -1)
    vote_component_pdf = _get_pdf(expanded_votes, expanded_scale)

    # For each part, evaluates all capsule, vote mixture likelihoods
    # [B, n_points, n_caps x n_votes, n_input_dims]
    expanded_x = tf.expand_dims(x, 2)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)

    # Compress the mixture likelihood across part dimensions (i.e. the 2D point)
    # [B, n_points, n_caps x n_votes]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
    dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
    dummy_vote_log_prob -= 2. * tf.log(10.)
    # appending an extra dummy component [B, n_points, 1] at the end. WHY?
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

    # [B, n_points, n_caps x n_votes]
    # CONDITIONAL LOGIT a_(k,n)
    mixing_logits = math_ops.safe_log(_vote_presence_prob)

    dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

    # BAD -> vote presence / summed vote presence
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)

    expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)

    # [B, n_points, n_votes]
    posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob
    # [B, n_points]
    winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :, :-1],
                                 2)

    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
    batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

    idx = tf.stack([batch_idx, winning_vote_idx], -1)
    winning_vote = tf.gather_nd(_votes, idx)
    winning_pres = tf.gather_nd(_vote_presence_prob, idx)
    vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:, -1:])

    # the first four votes belong to the square
    # Just assuming the votes are ordered by capsule...
    is_from_capsule = winning_vote_idx // _n_votes

    posterior_mixing_probs = tf.nn.softmax(posterior_mixing_logits_per_point,
                                           -1)[Ellipsis, :-1]

    assert winning_vote.shape == x.shape

    # this scheme never computes the full mixture log prob, hence log_prob=None
    return OutputTuple(
        log_prob=None,
        vote_presence=tf.to_float(vote_presence),
        winner=winning_vote,
        winner_pres=winning_pres,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
        # TODO(adamrk): this is broken
        soft_winner=tf.zeros_like(winning_vote),
        soft_winner_pres=tf.zeros_like(winning_pres),
        posterior_mixing_probs=posterior_mixing_probs,
    )
Example #12
    def _build(self, data):

        input_x = self._img(data, False)
        target_x = self._img(data, prep=self._prep)
        batch_size = int(input_x.shape[0])

        primary_caps = self._primary_encoder(input_x)
        pres = primary_caps.presence

        expanded_pres = tf.expand_dims(pres, -1)
        pose = primary_caps.pose
        input_pose = tf.concat([pose, 1. - expanded_pres], -1)

        input_pres = pres
        if self._stop_grad_caps_inpt:
            input_pose = tf.stop_gradient(input_pose)
            input_pres = tf.stop_gradient(pres)

        target_pose, target_pres = pose, pres
        if self._stop_grad_caps_target:
            target_pose = tf.stop_gradient(target_pose)
            target_pres = tf.stop_gradient(target_pres)

        # skip connection from the img to the higher level capsule
        if primary_caps.feature is not None:
            input_pose = tf.concat([input_pose, primary_caps.feature], -1)

        # try to feed presence as a separate input
        # and if that works, concatenate templates to poses
        # this is necessary for set transformer
        n_templates = int(primary_caps.pose.shape[1])
        templates = self._primary_decoder.make_templates(
            n_templates, primary_caps.feature)

        try:
            if self._feed_templates:
                inpt_templates = templates
                if self._stop_grad_caps_inpt:
                    inpt_templates = tf.stop_gradient(inpt_templates)

                if inpt_templates.shape[0] == 1:
                    inpt_templates = snt.TileByDim(
                        [0], [batch_size])(inpt_templates)
                inpt_templates = snt.BatchFlatten(2)(inpt_templates)
                pose_with_templates = tf.concat([input_pose, inpt_templates],
                                                -1)
            else:
                pose_with_templates = input_pose

            h = self._encoder(pose_with_templates, input_pres)

        except TypeError:
            h = self._encoder(input_pose)

        res = self._decoder(h, target_pose, target_pres)
        res.primary_presence = primary_caps.presence

        if self._vote_type == 'enc':
            primary_dec_vote = primary_caps.pose
        elif self._vote_type == 'soft':
            primary_dec_vote = res.soft_winner
        elif self._vote_type == 'hard':
            primary_dec_vote = res.winner
        else:
            raise ValueError('Invalid vote_type: "{}".'.format(
                self._vote_type))

        if self._pres_type == 'enc':
            primary_dec_pres = pres
        elif self._pres_type == 'soft':
            primary_dec_pres = res.soft_winner_pres
        elif self._pres_type == 'hard':
            primary_dec_pres = res.winner_pres
        else:
            raise ValueError('Invalid pres_type: "{}".'.format(
                self._pres_type))

        res.bottom_up_rec = self._primary_decoder(
            primary_caps.pose,
            primary_caps.presence,
            template_feature=primary_caps.feature,
            img_embedding=primary_caps.img_embedding)

        res.top_down_rec = self._primary_decoder(
            res.winner,
            primary_caps.presence,
            template_feature=primary_caps.feature,
            img_embedding=primary_caps.img_embedding)

        rec = self._primary_decoder(primary_dec_vote,
                                    primary_dec_pres,
                                    template_feature=primary_caps.feature,
                                    img_embedding=primary_caps.img_embedding)

        tile = snt.TileByDim([0], [res.vote.shape[1]])
        tiled_presence = tile(primary_caps.presence)

        tiled_feature = primary_caps.feature
        if tiled_feature is not None:
            tiled_feature = tile(tiled_feature)

        tiled_img_embedding = tile(primary_caps.img_embedding)

        res.top_down_per_caps_rec = self._primary_decoder(
            snt.MergeDims(0, 2)(res.vote),
            snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,
            template_feature=tiled_feature,
            img_embedding=tiled_img_embedding)

        res.templates = templates
        res.template_pres = pres
        res.used_templates = rec.transformed_templates

        res.rec_mode = rec.pdf.mode()
        res.rec_mean = rec.pdf.mean()

        res.mse_per_pixel = tf.square(target_x - res.rec_mode)
        res.mse = math_ops.flat_reduce(res.mse_per_pixel)

        res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)
        res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)

        n_points = int(res.posterior_mixing_probs.shape[1])
        mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs,
                                                  1)

        (res.posterior_within_sparsity_loss,
         res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
             self._posterior_sparsity_loss_type,
             mass_explained_by_capsule / n_points,
             num_classes=self._n_classes)

        (res.prior_within_sparsity_loss,
         res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
             self._prior_sparsity_loss_type,
             res.caps_presence_prob,
             num_classes=self._n_classes,
             within_example_constant=self._prior_within_example_constant)

        label = self._label(data)
        if label is not None:
            res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(
                mass_explained_by_capsule,
                label,
                self._n_classes,
                labeled=data.get('labeled', None))
            res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(
                res.caps_presence_prob,
                label,
                self._n_classes,
                labeled=data.get('labeled', None))

        res.best_cls_acc = tf.maximum(res.prior_cls_acc, res.posterior_cls_acc)

        res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)

        if self._weight_decay > 0.0:
            decay_losses_list = []
            for var in tf.trainable_variables():
                if 'w:' in var.name or 'weights:' in var.name:
                    decay_losses_list.append(tf.nn.l2_loss(var))
            res.weight_decay_loss = tf.reduce_sum(decay_losses_list)
        else:
            res.weight_decay_loss = 0.0

        return res
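The manual weight decay at the end of this _build filters trainable variables by name; a hedged standalone sketch (toy variables; tf.add_n in place of tf.reduce_sum over a list):

import tensorflow as tf  # TF1.x

w = tf.get_variable('weights', shape=[4, 4])  # name 'weights:0' -> decayed
b = tf.get_variable('b', shape=[4])           # not matched -> skipped

decay_losses = [
    tf.nn.l2_loss(v) for v in tf.trainable_variables()
    if 'w:' in v.name or 'weights:' in v.name
]
weight_decay_loss = tf.add_n(decay_losses) if decay_losses else tf.constant(0.)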
Example #13
    def _build(self,
               pose,
               presence=None,
               template_feature=None,
               bg_image=None,
               img_embedding=None):
        """Builds the module.

    Args:
      pose: [B, n_templates, 6] tensor.
      presence: [B, n_templates] tensor.
      template_feature: [B, n_templates, n_features] tensor; these features are
        used to change templates based on the input, if present.
      bg_image: [B, *output_size] tensor representing the background.
      img_embedding: [B, d] tensor containing image embeddings.

    Returns:
      [B, n_templates, *output_size, n_channels] tensor.
    """
        batch_size, n_templates = pose.shape[:2].as_list()
        templates = self.make_templates(n_templates, template_feature)

        if templates.shape[0] == 1:
            templates = snt.TileByDim([0], [batch_size])(templates)

        # it's easier for me to think in inverse coordinates
        warper = snt.AffineGridWarper(self._output_size, self._template_size)
        warper = warper.inverse()

        grid_coords = snt.BatchApply(warper)(pose)
        resampler = snt.BatchApply(contrib_resampler.resampler)
        transformed_templates = resampler(templates, grid_coords)

        if bg_image is not None:
            bg_image = tf.expand_dims(bg_image, axis=1)
        else:
            bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
            bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

        transformed_templates = tf.concat([transformed_templates, bg_image],
                                          axis=1)

        if presence is not None:
            presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

        if True:  # pylint: disable=using-constant-test

            if self._use_alpha_channel:
                template_mixing_logits = snt.TileByDim([0], [batch_size])(
                    self._templates_alpha)
                template_mixing_logits = resampler(template_mixing_logits,
                                                   grid_coords)

                bg_mixing_logit = tf.nn.softplus(
                    tf.get_variable('bg_mixing_logit', initializer=[0.]))

                bg_mixing_logit = (
                    tf.zeros_like(template_mixing_logits[:, :1]) +
                    bg_mixing_logit)

                template_mixing_logits = tf.concat(
                    [template_mixing_logits, bg_mixing_logit], 1)

            else:
                temperature_logit = tf.get_variable('temperature_logit',
                                                    shape=[1])
                temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
                template_mixing_logits = transformed_templates / temperature

        scale = 1.
        if self._learn_output_scale:
            scale = tf.get_variable('scale', shape=[1])
            scale = tf.nn.softplus(scale) + 1e-4

        if self._output_pdf_type == 'mixture':
            template_mixing_logits += make_brodcastable(
                math_ops.safe_log(presence), template_mixing_logits)

            rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                               [transformed_templates, scale],
                                               tfd.Normal)

        else:
            raise ValueError('Unknown pdf type: "{}".'.format(
                self._output_pdf_type))

        return AttrDict(raw_templates=tf.squeeze(self._templates, 0),
                        transformed_templates=transformed_templates[:, :-1],
                        mixing_logits=template_mixing_logits[:, :-1],
                        pdf=rec_pdf)
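prob.MixtureDistribution is repo-specific; a hedged stand-in using tensorflow_probability (an assumption, not the repo's class) with the same per-pixel mixture semantics: the mixing logits pick one template layer per pixel, and each layer contributes a Normal centred on its transformed template.

import tensorflow as tf  # TF1.x
import tensorflow_probability as tfp
tfd = tfp.distributions

B, K, H, W, C = 2, 4, 8, 8, 1  # K = n_templates + 1 background layer
logits = tf.zeros([B, K, H, W, C])
means = tf.zeros([B, K, H, W, C])

# MixtureSameFamily mixes over the last batch dim, so move K to the end.
mix = tfd.Categorical(logits=tf.transpose(logits, [0, 2, 3, 4, 1]))
comp = tfd.Normal(loc=tf.transpose(means, [0, 2, 3, 4, 1]), scale=1.)
pdf = tfd.MixtureSameFamily(mix, comp)

x = tf.zeros([B, H, W, C])
log_prob_per_pixel = pdf.log_prob(x)  # [B, H, W, C]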
Example #14
def render_constellations(pred_points,
                          capsule_num,
                          canvas_size,
                          gt_points=None,
                          n_caps=2,
                          gt_presence=None,
                          pred_presence=None,
                          caps_presence_prob=None):
    """Renderes predicted and ground-truth points as gaussian blobs.

  Args:
    pred_points: [B, m, 2].
    capsule_num: [B, m] tensor indicating which capsule the corresponding point
      comes from. Plots from different capsules are plotted with different
      colors. Currently supported values: {0, 1, ..., 11}.
    canvas_size: tuple of ints
    gt_points: [B, k, 2]; plots ground-truth points if present.
    n_caps: integer, number of capsules.
    gt_presence: [B, k] binary tensor.
    pred_presence: [B, m] binary tensor.
    caps_presence_prob: [B, m], a tensor of presence probabilities for caps.

  Returns:
    [B, *canvas_size] tensor with plotted points
  """

    # convert coords to be in [0, side_length]
    pred_points = denormalize_coords(pred_points, canvas_size, rounded=True)

    # render predicted points
    batch_size, n_points = pred_points.shape[:2].as_list()
    capsule_num = tf.to_float(tf.one_hot(capsule_num, depth=n_caps))
    capsule_num = tf.reshape(capsule_num,
                             [batch_size, n_points, 1, 1, n_caps, 1])

    color = tf.convert_to_tensor(_COLORS[:n_caps])
    color = tf.reshape(color, [1, 1, 1, 1, n_caps, 3]) * capsule_num
    color = tf.reduce_sum(color, -2)
    color = tf.squeeze(tf.squeeze(color, 3), 2)

    colored = render_by_scatter(canvas_size, pred_points, color, pred_presence)

    # Prepare a vertical separator between predicted and gt points.
    # Separator is composed of all supported colors and also serves as
    # a legend.
    # [b, h, w, 3]
    n_colors = _COLORS.shape[0]
    sep = tf.reshape(tf.convert_to_tensor(_COLORS), [1, 1, n_colors, 3])
    n_tiles = int(colored.shape[2]) // n_colors
    sep = snt.TileByDim([0, 1, 3], [batch_size, 3, n_tiles])(sep)
    sep = tf.reshape(sep, [batch_size, 3, n_tiles * n_colors, 3])

    pad = int(colored.shape[2]) - n_colors * n_tiles
    pad, r = pad // 2, pad % 2

    if caps_presence_prob is not None:
        n_caps = int(caps_presence_prob.shape[1])
        prob_pads = ([0, 0], [0, n_colors - n_caps])
        caps_presence_prob = tf.pad(caps_presence_prob, prob_pads)
        zeros = tf.zeros([batch_size, 3, n_colors, n_tiles, 3],
                         dtype=tf.float32)

        shape = [batch_size, 1, n_colors, 1, 1]
        caps_presence_prob = tf.reshape(caps_presence_prob, shape)

        prob_vals = snt.MergeDims(2, 2)(caps_presence_prob + zeros)
        sep = tf.concat([sep, tf.ones_like(sep[:, :1]), prob_vals], 1)

    sep = tf.pad(sep, [(0, 0), (1, 1), (pad, pad + r), (0, 0)],
                 constant_values=1.)

    # render gt points
    if gt_points is not None:
        gt_points = denormalize_coords(gt_points, canvas_size, rounded=True)

        gt_rendered = render_by_scatter(canvas_size,
                                        gt_points,
                                        colors=None,
                                        gt_presence=gt_presence)

        colored = tf.where(tf.cast(colored, bool), colored, gt_rendered)
        colored = tf.concat([gt_rendered, sep, colored], 1)

    res = tf.clip_by_value(colored, 0., 1.)
    return res
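The legend strip above uses a multi-axis tile: snt.TileByDim([0, 1, 3], [batch_size, 3, n_tiles]) equals a single tf.tile with multiples [batch_size, 3, 1, n_tiles]. A toy sketch with an assumed three-color palette:

import numpy as np
import tensorflow as tf  # TF1.x

colors = np.eye(3, dtype=np.float32)  # stand-in palette, [n_colors=3, 3]
batch_size, n_tiles = 2, 5

sep = tf.reshape(tf.constant(colors), [1, 1, 3, 3])
sep = tf.tile(sep, [batch_size, 3, 1, n_tiles])         # [B, 3, 3, 3*n_tiles]
sep = tf.reshape(sep, [batch_size, 3, n_tiles * 3, 3])  # [B, 3, n_tiles*n_colors, 3]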
Example #15
    def _build(self, feature, parent_transform=None, parent_presence=None):
        """Builds the module.

    Args:
      feature: Tensor of encodings of shape [B, n_enc_dims] (unused; the
        body overrides it with a fixed-size dummy tensor).
      parent_transform: Tuple of (matrix, vector).
      parent_presence: optional tensor of parent capsule presences.

    Returns:
      The composed votes tensor.
    """
        # the input is ignored; a fixed-size dummy tensor is used instead
        features = tf.ones([200, 32, 256])
        batch_size = features.shape.as_list()[0]
        batch_shape = [batch_size, self._n_caps]

        # Predict capsule and additional params from the input encoding.
        # [B, n_caps, n_caps_dims]
        if self._n_caps_params is not None:

            # Use separate parameters to do predictions for different capsules.
            mlp = BatchMLP(self._n_hiddens + [self._n_caps_params])
            raw_caps_params = mlp(features)
            caps_params = tf.reshape(raw_caps_params,
                                     batch_shape + [self._n_caps_params])

        else:
            assert features.shape[:2].as_list() == batch_shape
            caps_params = features

        if self._caps_dropout_rate == 0.0:
            caps_exist = tf.ones(batch_shape + [1], dtype=tf.float32)
        else:
            pmf = tfd.Bernoulli(probs=1. - self._caps_dropout_rate,
                                dtype=tf.float32)
            caps_exist = pmf.sample(batch_shape + [1])

        caps_params = tf.concat([caps_params, caps_exist], -1)

        output_shapes = (
            [self._n_votes, self._n_transform_params],  # CPR_dynamic
            [1, self._n_transform_params],  # CCR
            [1],  # per-capsule presence
            [self._n_votes],  # per-vote-presence
            [self._n_votes],  # per-vote scale
        )

        splits = [np.prod(i).astype(np.int32) for i in output_shapes]
        n_outputs = sum(splits)

        # we don't use bias in the output layer in order to separate the static
        # and dynamic parts of the CPR
        caps_mlp = BatchMLP([self._n_hiddens, n_outputs], use_bias=False)
        all_params = caps_mlp(caps_params)
        all_params = tf.split(all_params, splits, -1)
        res = [
            tf.reshape(i, batch_shape + s)
            for (i, s) in zip(all_params, output_shapes)
        ]

        cpr_dynamic = res[0]

        ccr = res[1]  # unlike Example #7, no bias is added in this variant

        cpr_static = tf.get_variable(
            'cpr_static',
            shape=[1, self._n_caps, self._n_votes, self._n_transform_params])

        # this is for hierarchical
        if parent_transform is None:
            ccr = self._make_transform(ccr)
        else:
            ccr = parent_transform

        if not self._deformations:
            cpr_dynamic = tf.zeros_like(cpr_dynamic)

        cpr = self._make_transform(cpr_dynamic + cpr_static)

        ccr_per_vote = snt.TileByDim([2], [self._n_votes])(ccr)
        votes = tf.matmul(ccr_per_vote, cpr)

        return votes
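The final composition, in isolation: the per-capsule transform (CCR) is tiled across votes and composed with the per-vote transforms (CPR) by a batched matmul (3x3 pose matrices and toy sizes are assumptions).

import tensorflow as tf  # TF1.x

B, n_caps, n_votes = 4, 10, 6

ccr = tf.eye(3, batch_shape=[B, n_caps, 1])        # per-capsule transform
cpr = tf.eye(3, batch_shape=[B, n_caps, n_votes])  # per-vote transforms

ccr_per_vote = tf.tile(ccr, [1, 1, n_votes, 1, 1])
votes = tf.matmul(ccr_per_vote, cpr)               # [B, n_caps, n_votes, 3, 3]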