Example #1
  def _build(self, x):
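    """Encodes an image into capsule poses, features and presence logits.

    Assumes `tf` (TensorFlow), `snt` (Sonnet) and the project's `math_ops`
    module (providing `geometric_transform`) are imported at module level.
    """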
    batch_size = x.shape[0]
    img_embedding = self._encoder(x)

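    # Each capsule predicts pose parameters, optional per-capsule features and
    # a single presence logit.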
    splits = [self._n_caps_dims, self._n_features, 1]  # 1 for presence
    n_dims = sum(splits)

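    # 'linear': flatten the encoder output and predict parameters for all
    # capsules with one fully connected layer.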
    if self._encoder_type == 'linear':
      n_outputs = self._n_caps * n_dims

      h = snt.BatchFlatten()(img_embedding)
      h = snt.Linear(n_outputs)(h)

    else:
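      # Convolutional variants first add a learnable bias over the spatial and
      # channel dimensions of the encoder output.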
      h = snt.AddBias(bias_dims=[1, 2, 3])(img_embedding)

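      # 'conv': a 1x1 convolution predicts capsule parameters at every spatial
      # location; the predictions are then averaged over space.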
      if self._encoder_type == 'conv':
        h = snt.Conv2D(n_dims * self._n_caps, 1, 1)(h)
        h = tf.reduce_mean(h, (1, 2))
        h = tf.reshape(h, [batch_size, self._n_caps, n_dims])

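      # 'conv_att': predict capsule parameters plus one attention logit per
      # capsule at every location, then pool locations with a softmax-weighted
      # sum instead of a plain mean.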
      elif self._encoder_type == 'conv_att':
        h = snt.Conv2D(n_dims * self._n_caps + self._n_caps, 1, 1)(h)
        h = snt.MergeDims(1, 2)(h)
        h, a = tf.split(h, [n_dims * self._n_caps, self._n_caps], -1)

        h = tf.reshape(h, [batch_size, -1, n_dims, self._n_caps])
        a = tf.nn.softmax(a, 1)
        a = tf.reshape(a, [batch_size, -1, 1, self._n_caps])
        h = tf.reduce_sum(h * a, 1)

      else:
        raise ValueError('Invalid encoder type="{}".'.format(
            self._encoder_type))

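    # Arrange the predictions as [batch_size, n_caps, n_dims] before splitting.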
    h = tf.reshape(h, [batch_size, self._n_caps, n_dims])

    pose, feature, pres_logit = tf.split(h, splits, -1)
    if self._n_features == 0:
      feature = None

    pres_logit = tf.squeeze(pres_logit, -1)
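    # Optionally perturb the presence logits with uniform noise in
    # [-noise_scale / 2, noise_scale / 2].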
    if self._noise_scale > 0.:
      pres_logit += ((tf.random.uniform(pres_logit.shape) - .5)
                     * self._noise_scale)

    # Presence probabilities via a sigmoid; pose parameters are passed through
    # a geometric (optionally similarity) transform.
    pres = tf.nn.sigmoid(pres_logit)
    pose = math_ops.geometric_transform(pose, self._similarity_transform)
    return self.OutputTuple(pose, feature, pres, pres_logit, img_embedding)

  def _make_transform(self, params):
    # Converts raw pose parameters into a transformation matrix.
    return math_ops.geometric_transform(params,
                                        self._similarity_transform,
                                        nonlinear=True,
                                        as_matrix=True)