Пример #1
0
def engineered_policies(images, logits2d, position_channels, glimpse_shape,
                        num_times, policy):
    """Select glimpse locations with a hand-engineered (non-learned) policy.

    A per-location score map is built according to `policy`, then the
    `num_times` highest-scoring locations are picked with `utils.sort2d`.

    Supported policies:
      * 'ordered_logits': score is the logit of the predicted class, where the
        predicted class is the argmax of the spatially averaged logits.
      * any string containing 'sobel' and 'mean' (e.g. 'sobel_mean'): score is
        the mean of the Sobel edge response inside each glimpse window.
      * any string containing 'sobel' and 'var' (e.g. 'sobel_var'): score is
        the unbiased variance of the Sobel edge response inside each glimpse
        window. If the string contains both 'mean' and 'var', 'mean' wins.

    Args:
      images: A Tensor of type float32. A 4-D float tensor of shape
        [batch_size, height, width, channels].
      logits2d: 2D logits tensor of type float32 of shape
        [batch_size, height, width, classes]. Its static height and width are
        also used to choose the pooling strides for the Sobel policies.
      position_channels: A Tensor of type float32 containing the output of
        `utils.position_channels` called on `images`.
      glimpse_shape: (Tuple of integer) Glimpse shape as (height, width).
      num_times: (Integer) Number of glimpses.
      policy: (String) 'ordered_logits', 'sobel_mean', or 'sobel_var'.

    Returns:
      locations_t: List of 2D Tensors containing policy locations, obtained by
        unstacking the locations returned by `utils.sort2d` along axis 0
        (presumably one entry per glimpse -- confirm against `utils.sort2d`).

    Raises:
      ValueError: If `policy` is not one of the supported policies.
    """
    # Validate before touching any tensors: previously an unsupported policy
    # left `metric2d` unassigned and surfaced as an unrelated
    # UnboundLocalError further down.
    is_sobel = "sobel" in policy
    is_supported = policy == "ordered_logits" or (
        is_sobel and ("mean" in policy or "var" in policy))
    if not is_supported:
        raise ValueError(
            "Unsupported policy {!r}: expected 'ordered_logits', "
            "'sobel_mean' or 'sobel_var'.".format(policy))

    if policy == "ordered_logits":
        # Predicted class per example = argmax over spatially averaged logits.
        pred_labels = tf.argmax(tf.reduce_mean(logits2d, axis=[1, 2]), axis=-1)
        # Keep only the predicted class' logit at every spatial location.
        metric2d = utils.batch_gather_nd(logits2d, pred_labels, axis=-1)
    else:
        edges = sobel_edges(images)
        # Add a trailing channel axis so the edge map can be pooled
        # (assumes `sobel_edges` returns a rank-3 tensor -- confirm).
        edges = edges[:, :, :, tf.newaxis]
        _, orig_h, orig_w, _ = edges.shape.as_list()
        _, h, w, _ = logits2d.shape.as_list()
        ksize = [1, glimpse_shape[0], glimpse_shape[1], 1]
        # Strides are derived from the ratio between the number of valid
        # glimpse positions in the image and the logits grid size.
        strides = [
            1,
            int(np.ceil((orig_h - glimpse_shape[0] + 1) / h)),
            int(np.ceil((orig_w - glimpse_shape[1] + 1) / w)), 1
        ]
        mean_per_glimpse = tf.nn.avg_pool(edges,
                                          ksize=ksize,
                                          strides=strides,
                                          padding="VALID")
        if "mean" in policy:
            metric2d = mean_per_glimpse
        else:
            # Unbiased variance: n / (n - 1) * (E[x^2] - E[x]^2), with n the
            # number of pixels in a glimpse window.
            n = np.prod(glimpse_shape)
            mean_sq_per_glimpse = tf.nn.avg_pool(tf.square(edges),
                                                 ksize=ksize,
                                                 strides=strides,
                                                 padding="VALID")
            metric2d = (n / (n - 1.)) * (mean_sq_per_glimpse -
                                         tf.square(mean_per_glimpse))

        # Drop the channel axis that was added for pooling.
        metric2d = tf.squeeze(metric2d, axis=3)

    _, locations_t = utils.sort2d(metric2d,
                                  position_channels,
                                  first_k=num_times,
                                  direction="DESCENDING")
    locations_t = tf.unstack(locations_t, axis=0)

    return locations_t
Пример #2
0
    def test_sort2d(self, direction):
        """Checks `utils.sort2d` on a batch of two mirrored 3x3 grids.

        The first example holds the values 0..8 in row-major order and the
        second holds the same values reversed, so both examples must sort to
        the same values while their sorted positions appear in opposite order.

        Args:
          direction: Sort direction forwarded to `utils.sort2d`.
        """
        values = range(0, 9)
        forward_grid = np.reshape(values, (3, 3))
        backward_grid = np.reshape(values[::-1], (3, 3))
        batch = tf.convert_to_tensor(np.array([forward_grid, backward_grid]),
                                     dtype=tf.float32)

        positions = utils.position_channels(batch)
        sorted_op, argsorted_op = utils.sort2d(batch,
                                               positions,
                                               direction=direction)
        sorted_vals = self.evaluate(sorted_op)
        argsorted_vals = self.evaluate(argsorted_op)

        # Both examples contain the same set of values, hence identical
        # sorted sequences.
        first_sorted = sorted_vals[:, 0]
        second_sorted = sorted_vals[:, 1]
        self.assertAllEqual(first_sorted, second_sorted)

        # The second example is the first one reversed, so for each of the two
        # spatial coordinates the sorted positions must be mirrored.
        for coord in range(2):
            first_positions = argsorted_vals[:, 0, coord]
            second_positions = argsorted_vals[:, 1, coord]
            self.assertAllEqual(first_positions, second_positions[::-1])