def engineered_policies(images, logits2d, position_channels, glimpse_shape,
                        num_times, policy):
    """Engineered (hand-designed) glimpse policies.

    Builds a 2D score map, either from the logits of the predicted class or
    from Sobel edge statistics pooled over glimpse-sized windows, and returns
    the `num_times` highest-scoring locations.

    Args:
      images: A Tensor of type float32. A 4-D float tensor of shape
        [batch_size, height, width, channels].
      logits2d: 2D logits tensor of type float32 of shape
        [batch_size, height, width, classes].
      position_channels: A Tensor of type float32 containing the output of
        `utils.position_channels` called on `images`.
      glimpse_shape: (Tuple of integer) Glimpse shape.
      num_times: (Integer) Number of glimpses.
      policy: (String) 'ordered_logits', 'sobel_mean', or 'sobel_var'.
        'ordered_logits' is matched exactly. The sobel policies are matched by
        substring: the string must contain 'sobel' plus 'mean' or 'var'
        ('mean' takes precedence when both are present).

    Returns:
      locations_t: List of 2D Tensors containing policy locations.

    Raises:
      ValueError: If `policy` does not select any of the supported policies.
    """
    # Resolve the policy before building any ops so that an unsupported value
    # fails immediately with a clear message instead of leaving `metric2d`
    # unassigned further down.
    use_logits = policy == "ordered_logits"
    use_sobel = (not use_logits) and "sobel" in policy
    use_mean = use_sobel and "mean" in policy
    use_var = use_sobel and (not use_mean) and "var" in policy
    if not (use_logits or use_mean or use_var):
        raise ValueError(
            "Unsupported policy {!r}; expected 'ordered_logits', "
            "'sobel_mean' or 'sobel_var'.".format(policy))

    if use_logits:
        # Score every location with the logit of the class predicted from the
        # spatially averaged logits.
        pred_labels = tf.argmax(
            tf.reduce_mean(logits2d, axis=[1, 2]), axis=-1)
        metric2d = utils.batch_gather_nd(logits2d, pred_labels, axis=-1)
    else:
        edges = sobel_edges(images)
        # Append a trailing channel axis so the map can be average-pooled.
        edges = edges[:, :, :, tf.newaxis]
        _, orig_h, orig_w, _ = edges.shape.as_list()
        _, h, w, _ = logits2d.shape.as_list()
        ksize = [1, glimpse_shape[0], glimpse_shape[1], 1]
        # With "VALID" padding there are (orig - glimpse + 1) window positions
        # per axis; this stride keeps the pooled map to at most h x w entries.
        # float() keeps the division a true division regardless of whether the
        # module enables `from __future__ import division`.
        strides = [
            1,
            int(np.ceil((orig_h - glimpse_shape[0] + 1) / float(h))),
            int(np.ceil((orig_w - glimpse_shape[1] + 1) / float(w))),
            1,
        ]
        mean_per_glimpse = tf.nn.avg_pool(
            edges, ksize=ksize, strides=strides, padding="VALID")
        if use_mean:
            metric2d = mean_per_glimpse
        else:
            # Variance per window as E[x^2] - E[x]^2, rescaled by n / (n - 1)
            # where n is the number of pixels in a glimpse.
            n = np.prod(glimpse_shape)
            mean_sq_per_glimpse = tf.nn.avg_pool(
                tf.square(edges), ksize=ksize, strides=strides,
                padding="VALID")
            metric2d = (n / (n - 1.)) * (
                mean_sq_per_glimpse - tf.square(mean_per_glimpse))
        # Drop the channel axis added above.
        metric2d = tf.squeeze(metric2d, axis=3)

    # Keep the `num_times` best-scoring locations, highest score first.
    _, locations_t = utils.sort2d(
        metric2d, position_channels, first_k=num_times,
        direction="DESCENDING")
    # Split along the leading axis into a list of per-glimpse tensors.
    return tf.unstack(locations_t, axis=0)
def test_sort2d(self, direction):
    """Checks `utils.sort2d` on two 3x3 examples holding 0..8 in opposite order.

    Args:
      direction: Sort direction forwarded unchanged to `utils.sort2d`.
    """
    values = range(0, 9)
    forward = np.reshape(values, (3, 3))
    backward = np.reshape(values[::-1], (3, 3))
    batch = tf.convert_to_tensor(
        np.array([forward, backward]), dtype=tf.float32)
    positions = utils.position_channels(batch)

    sorted_op, argsorted_op = utils.sort2d(
        batch, positions, direction=direction)
    sorted_vals = self.evaluate(sorted_op)
    sorted_locs = self.evaluate(argsorted_op)

    # Both examples hold the same set of values, so their sorted values must
    # agree. NOTE(review): the indexing implies axis 1 of the outputs is the
    # example axis -- confirm against `utils.sort2d`.
    self.assertAllEqual(sorted_vals[:, 0], sorted_vals[:, 1])

    # The second example is the first one reversed, so each coordinate of its
    # sorted locations must be the reverse of the first example's.
    for coord in range(2):
        first_example = sorted_locs[:, 0, coord]
        second_example = sorted_locs[:, 1, coord]
        self.assertAllEqual(first_example, second_example[::-1])