Example #1
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style APIs (tf.layers, tf.variable_scope) are used below.
import tensorflow_probability as tfp

from saccader import utils  # Assumed import path for the `utils` helpers referenced below.

 def test_batch_gather_nd(self, axis):
     batch_size = 10
     axis_dims = [50, 40, 30]
     x = tf.random.uniform([
         batch_size,
     ] + axis_dims, minval=0, maxval=1)
     indices = tf.random.uniform(
         (batch_size, ),
         minval=0,
         maxval=axis_dims[axis - 1 if axis > 0 else axis],
         dtype=tf.int32)
     x_gathered = utils.batch_gather_nd(x, indices, axis)
     if axis == 1:
         identity = tf.eye(batch_size, batch_size)[:, :, tf.newaxis,
                                                   tf.newaxis]
     if axis == 2:
         identity = tf.eye(batch_size, batch_size)[:, tf.newaxis, :,
                                                   tf.newaxis]
     if axis == 3 or axis == -1:
         identity = tf.eye(batch_size, batch_size)[:, tf.newaxis,
                                                   tf.newaxis, :]
     x_gathered2 = tf.reduce_sum(identity *
                                 tf.gather(x, indices, axis=axis),
                                 axis=axis)
     diff = tf.reduce_sum(tf.math.abs(x_gathered - x_gathered2))
     self.assertAlmostEqual(self.evaluate(diff), 0., 9)
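
# `utils.batch_gather_nd` is not defined in this excerpt. A minimal sketch of
# the semantics the parameterized test above checks (its `axis` argument is
# presumably supplied by a decorator not shown here): for every batch element
# b, return the slice of x[b] taken at index indices[b] along `axis`. This is
# an illustrative re-implementation, not the library's actual code.
def batch_gather_nd_sketch(x, indices, axis):
    """For each batch element b, gathers x[b] at position indices[b] along `axis`."""
    ndims = x.shape.ndims
    axis = axis % ndims  # Support negative axes such as -1.
    # Move the gather axis next to the batch axis so each (batch, index) pair
    # can be looked up with a single gather_nd call.
    perm = [0, axis] + [d for d in range(1, ndims) if d != axis]
    x_t = tf.transpose(x, perm)  # [batch, axis_dim, ...remaining dims]
    batch_idx = tf.cast(tf.range(tf.shape(x)[0]), indices.dtype)
    return tf.gather_nd(x_t, tf.stack([batch_idx, indices], axis=1))
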
def engineered_policies(images, logits2d, position_channels, glimpse_shape,
                        num_times, policy):
    """Engineered policies.

  Args:
    images: A Tensor of type float32. A 4-D float tensor of shape
      [batch_size, height, width, channels].
    logits2d: 2D logits tensor of type float32 of shape
      [batch_size, height, width, classes].
    position_channels: A Tensor of type float32 containing the output of
      `utils.position_channels` called on `images`.
    glimpse_shape: (Tuple of integer) Glimpse shape.
    num_times: (Integer) Number of glimpses.
    policy: (String) 'ordered logits', 'sobel_mean', or 'sobel_var'.


  Returns:
    locations_t: List of 2D Tensors containing policy locations.

  """
    if policy == "ordered_logits":
        pred_labels = tf.argmax(tf.reduce_mean(logits2d, axis=[1, 2]), axis=-1)
        metric2d = utils.batch_gather_nd(logits2d, pred_labels, axis=-1)

    elif "sobel" in policy:
        edges = sobel_edges(images)
        edges = edges[:, :, :, tf.newaxis]
        _, orig_h, orig_w, _ = edges.shape.as_list()
        _, h, w, _ = logits2d.shape.as_list()
        ksize = [1, glimpse_shape[0], glimpse_shape[1], 1]
        strides = [
            1,
            int(np.ceil((orig_h - glimpse_shape[0] + 1) / h)),
            int(np.ceil((orig_w - glimpse_shape[1] + 1) / w)), 1
        ]
        mean_per_glimpse = tf.nn.avg_pool(edges,
                                          ksize=ksize,
                                          strides=strides,
                                          padding="VALID")
        if "mean" in policy:
            metric2d = mean_per_glimpse
        elif "var" in policy:
            n = np.prod(glimpse_shape)
            var_per_glimpse = (n /
                               (n - 1.)) * (tf.nn.avg_pool(tf.square(edges),
                                                           ksize=ksize,
                                                           strides=strides,
                                                           padding="VALID") -
                                            tf.square(mean_per_glimpse))
            metric2d = var_per_glimpse

        metric2d = tf.squeeze(metric2d, axis=3)
    _, locations_t = utils.sort2d(metric2d,
                                  position_channels,
                                  first_k=num_times,
                                  direction="DESCENDING")
    locations_t = tf.unstack(locations_t, axis=0)

    return locations_t
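
# `sobel_edges` is also not defined in this excerpt. A minimal stand-in with
# the shape the code above expects ([batch, height, width], later expanded via
# tf.newaxis) could compute a per-pixel edge-strength map like this; it is an
# illustrative sketch, not necessarily the original helper:
def sobel_edges_sketch(images):
    """Mean over channels of the Sobel gradient magnitude of `images`."""
    grads = tf.image.sobel_edges(images)  # [batch, h, w, channels, 2]
    magnitude = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=-1))
    return tf.reduce_mean(magnitude, axis=-1)  # [batch, h, w]
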
    def __call__(self,
                 mixed_features2d,
                 cell_state,
                 logits2d,
                 is_training=False,
                 policy="learned"):
        """Builds Saccader cell.

    Args:
      mixed_features2d: 4-D Tensor of shape [batch, height, width, channels].
      cell_state: 4-D Tensor of shape [batch, height, width, 1] with cell state.
      logits2d: 4-D Tensor of shape [batch, height, width, channels].
      is_training: (Boolean) To indicate training or inference modes.
      policy: (String) 'learned': uses learned policy, 'random': uses random
        policy, or 'center': uses center look policy.
    Returns:
      logits: Model logits.
      cell_state: New cell state.
      endpoints: Dictionary with cell parameters.
    """
        batch_size, height, width, channels = mixed_features2d.shape.as_list()
        reuse = bool(self.var_list)
        position_channels = utils.position_channels(mixed_features2d)

        variables_before = set(tf.global_variables())
        with tf.variable_scope("saccader_cell", reuse=reuse):
            # Compute 2D weights of features across space.
            features_space_logits = tf.layers.dense(
                mixed_features2d,
                units=1,
                use_bias=False,
                name="attention_weights") / tf.math.sqrt(float(channels))

            features_space_logits += cell_state * -1.e5  # Mask used locations.
            features_space_weights = utils.softmax2d(features_space_logits)

            # Compute 1D weights of features across channels.
            features_channels_logits = tf.reduce_sum(mixed_features2d *
                                                     features_space_weights,
                                                     axis=[1, 2])
            features_channels_weights = tf.nn.softmax(features_channels_logits,
                                                      axis=1)

            # Compute location probability.
            locations_logits2d = tf.reduce_sum(
                (mixed_features2d *
                 features_channels_weights[:, tf.newaxis, tf.newaxis, :]),
                axis=-1,
                keepdims=True)

            locations_logits2d += (cell_state * -1e5)  # Mask used locations.
            locations_prob2d = utils.softmax2d(locations_logits2d)

        variables_after = set(tf.global_variables())
        # Compute best locations.
        locations_logits = tf.reshape(locations_logits2d, (batch_size, -1))
        all_positions = tf.reshape(position_channels,
                                   [batch_size, height * width, 2])

        best_locations_labels = tf.argmax(locations_logits, axis=-1)
        best_locations = utils.batch_gather_nd(all_positions,
                                               best_locations_labels,
                                               axis=1)

        # Sample locations.
        if policy == "learned":
            if is_training:
                # At training time, sample the location from the learned
                # distribution over grid positions.
                dist = tfp.distributions.Categorical(logits=locations_logits)
                locations_labels = dist.sample()
                locations = utils.batch_gather_nd(all_positions,
                                                  locations_labels,
                                                  axis=1)
                # Ensure locations stay within the valid range [-1., 1.].
                locations = tf.clip_by_value(locations, -1., 1.)
                tf.logging.info("Sampling locations.")
                tf.logging.info(
                    "==================================================")
            else:
                # At inference time, use the most likely (argmax) location.
                locations = best_locations
                locations_labels = best_locations_labels
        elif policy == "random":
            # Use random policy for location.
            locations = tf.random_uniform(shape=(batch_size, 2),
                                          minval=-1.,
                                          maxval=1.)
            locations_labels = None
        elif policy == "center":
            # Use center look policy.
            locations = tf.zeros(shape=(batch_size, 2))
            locations_labels = None

        # Update cell_state.
        cell_state += utils.onehot2d(cell_state, locations)
        cell_state = tf.clip_by_value(cell_state, 0, 1)
        #########################################################################
        # Extract logits from the 2D logits.
        if self.soft_attention:
            logits = tf.reduce_sum(logits2d * locations_prob2d, axis=[1, 2])
        else:
            logits = gather_2d(logits2d, locations)
        ############################################################
        endpoints = {}
        endpoints["cell_outputs"] = {
            "locations": locations,
            "locations_labels": locations_labels,
            "best_locations": best_locations,
            "best_locations_labels": best_locations_labels,
            "locations_logits2d": locations_logits2d,
            "locations_prob2d": locations_prob2d,
            "cell_state": cell_state,
            "features_space_logits": features_space_logits,
            "features_space_weights": features_space_weights,
            "features_channels_logits": features_channels_logits,
            "features_channels_weights": features_channels_weights,
            "locations_logits": locations_logits,
            "all_positions": all_positions,
        }
        if not reuse:
            self.collect_variables(list(variables_after - variables_before))

        return logits, cell_state, endpoints
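
# A hypothetical driver loop for the __call__ above (which presumably belongs
# to a Saccader cell class whose definition is not shown; names such as `cell`
# and `num_times` are illustrative, not taken from this excerpt). Each glimpse
# feeds the updated cell_state back in so already-attended locations stay masked.
def run_cell_sketch(cell, mixed_features2d, logits2d, num_times, is_training):
    batch_size, height, width, _ = mixed_features2d.shape.as_list()
    cell_state = tf.zeros([batch_size, height, width, 1])  # Nothing visited yet.
    logits_t = []
    for _ in range(num_times):
        logits, cell_state, _ = cell(mixed_features2d,
                                     cell_state,
                                     logits2d,
                                     is_training=is_training,
                                     policy="learned")
        logits_t.append(logits)
    # One simple way to combine per-glimpse predictions is to average them.
    return tf.reduce_mean(tf.stack(logits_t, axis=0), axis=0)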