def test_batch_gather_nd(self, axis):
  """Compares `utils.batch_gather_nd` against a gather-and-mask reference.

  The reference calls `tf.gather` with all sampled indices for every batch
  element. It then multiplies by a broadcast identity matrix, which keeps only
  the slice whose index belongs to the same batch element. Finally it sums
  `axis` away. The two results must agree.

  Args:
    axis: Axis of the input to gather along. A reference mask is only built
      for axis in {1, 2, 3, -1}.
  """
  batch_size = 10
  feature_dims = [50, 40, 30]
  values = tf.random.uniform([batch_size] + feature_dims, minval=0, maxval=1)

  # `axis` counts the leading batch dimension, so positive axes are shifted
  # by one when indexing `feature_dims`; -1 already names the last entry.
  dim_index = axis - 1 if axis > 0 else axis
  gather_indices = tf.random.uniform(
      (batch_size,),
      minval=0,
      maxval=feature_dims[dim_index],
      dtype=tf.int32)
  actual = utils.batch_gather_nd(values, gather_indices, axis)

  # Identity matrix broadcast so its second dimension lines up with the
  # gathered axis: entry [b, ..., k, ...] is 1 only when k == b.
  if axis == 1:
    mask = tf.eye(batch_size, batch_size)[:, :, tf.newaxis, tf.newaxis]
  elif axis == 2:
    mask = tf.eye(batch_size, batch_size)[:, tf.newaxis, :, tf.newaxis]
  elif axis in (3, -1):
    mask = tf.eye(batch_size, batch_size)[:, tf.newaxis, tf.newaxis, :]

  gathered_all = tf.gather(values, gather_indices, axis=axis)
  expected = tf.reduce_sum(mask * gathered_all, axis=axis)

  total_abs_error = tf.reduce_sum(tf.math.abs(actual - expected))
  self.assertAlmostEqual(self.evaluate(total_abs_error), 0., 9)
def engineered_policies(images, logits2d, position_channels, glimpse_shape,
                        num_times, policy):
  """Engineered (non-learned) glimpse policies.

  Builds a per-position score map `metric2d` and passes it to `utils.sort2d`
  with `first_k=num_times` and `direction="DESCENDING"`. Presumably this ranks
  positions by score and keeps the top `num_times`; confirm against
  `utils.sort2d`. The ranked locations are unstacked along axis 0.

  Supported policies:
    'ordered_logits': the score at each position is the logit of the class
      with the highest spatially averaged logit.
    policies containing 'sobel' and 'mean': the score is the mean Sobel edge
      response inside each glimpse window.
    policies containing 'sobel' and 'var': the score is the (Bessel-corrected)
      variance of the Sobel edge response inside each glimpse window.
    If a 'sobel' policy contains both 'mean' and 'var', 'mean' wins.

  Args:
    images: A Tensor of type float32. A 4-D float tensor of shape
      [batch_size, height, width, channels].
    logits2d: 2D logits tensor of type float32 of shape
      [batch_size, height, width, classes].
    position_channels: A Tensor of type float32 containing the output of
      `utils.position_channels` called on `images`.
    glimpse_shape: (Tuple of integer) Glimpse shape.
    num_times: (Integer) Number of glimpses.
    policy: (String) 'ordered_logits', 'sobel_mean', or 'sobel_var'.

  Returns:
    locations_t: List of 2D Tensors containing policy locations.

  Raises:
    ValueError: If `policy` is not one of the supported policies.
  """
  # Validate before building any ops. Otherwise an unsupported policy
  # previously left `metric2d` unbound and failed later with UnboundLocalError.
  is_sobel = "sobel" in policy
  if policy != "ordered_logits" and not (
      is_sobel and ("mean" in policy or "var" in policy)):
    raise ValueError(
        "Unsupported policy %r; expected 'ordered_logits', 'sobel_mean' or "
        "'sobel_var'." % (policy,))

  if policy == "ordered_logits":
    # Predicted class per image from spatially averaged logits; then the
    # logit of that class at every position (rank-3 [batch, h, w] result).
    pred_labels = tf.argmax(tf.reduce_mean(logits2d, axis=[1, 2]), axis=-1)
    metric2d = utils.batch_gather_nd(logits2d, pred_labels, axis=-1)
  elif is_sobel:
    edges = sobel_edges(images)
    # Add a singleton channel axis so the result can be pooled.
    edges = edges[:, :, :, tf.newaxis]
    _, orig_h, orig_w, _ = edges.shape.as_list()
    _, h, w, _ = logits2d.shape.as_list()
    ksize = [1, glimpse_shape[0], glimpse_shape[1], 1]
    # Strides are chosen so the VALID-padded pooled map has roughly the same
    # h x w spatial grid as `logits2d`.
    strides = [
        1,
        int(np.ceil((orig_h - glimpse_shape[0] + 1) / h)),
        int(np.ceil((orig_w - glimpse_shape[1] + 1) / w)),
        1,
    ]
    mean_per_glimpse = tf.nn.avg_pool(
        edges, ksize=ksize, strides=strides, padding="VALID")

    if "mean" in policy:
      metric2d = mean_per_glimpse
    else:  # "var" in policy (guaranteed by the validation above).
      n = np.prod(glimpse_shape)
      # Unbiased variance: n / (n - 1) * (E[x^2] - E[x]^2).
      # NOTE(review): a 1x1 glimpse (n == 1) divides by zero here.
      var_per_glimpse = (n / (n - 1.)) * (
          tf.nn.avg_pool(
              tf.square(edges), ksize=ksize, strides=strides,
              padding="VALID") - tf.square(mean_per_glimpse))
      metric2d = var_per_glimpse

    # Drop the singleton channel axis so both branches yield [batch, h, w].
    metric2d = tf.squeeze(metric2d, axis=3)

  _, locations_t = utils.sort2d(
      metric2d, position_channels, first_k=num_times, direction="DESCENDING")
  locations_t = tf.unstack(locations_t, axis=0)
  return locations_t
def __call__(self,
             mixed_features2d,
             cell_state,
             logits2d,
             is_training=False,
             policy="learned"):
  """Builds Saccader cell.

  One call does four things:
    1. Computes a spatial attention over `mixed_features2d`.
    2. Uses it to derive channel weights and a location distribution.
    3. Picks a location according to `policy` and marks it in `cell_state`.
    4. Reads out logits from `logits2d`. In soft attention mode the logits
       are weighted by the location distribution; otherwise they are gathered
       at the chosen location.

  Args:
    mixed_features2d: 4-D Tensor of shape [batch, height, width, channels].
    cell_state: 4-D Tensor of shape [batch, height, width, 1] with cell state.
      It is scaled by -1e5 and added to the attention and location logits, so
      positions holding 1 are effectively masked out.
    logits2d: 4-D Tensor of shape [batch, height, width, channels].
    is_training: (Boolean) To indicate training or inference modes.
    policy: (String) 'learned': uses learned policy, 'random': uses random
      policy, or 'center': uses center look policy.

  Returns:
    logits: Model logits.
    cell_state: New cell state, with the chosen location added and values
      clipped to [0, 1].
    endpoints: Dictionary with cell parameters under the key "cell_outputs".

  Raises:
    ValueError: If `policy` is not 'learned', 'random' or 'center'.
  """
  # Validate early. Previously an unknown policy skipped every branch and
  # failed later with UnboundLocalError on `locations`.
  if policy not in ("learned", "random", "center"):
    raise ValueError(
        "Unsupported policy %r; expected 'learned', 'random' or 'center'." %
        (policy,))

  batch_size, height, width, channels = mixed_features2d.shape.as_list()
  # Reuse variables once `self.var_list` is non-empty. It is presumably
  # populated by `collect_variables`, which is called at the end of the first
  # build (verify).
  reuse = bool(self.var_list)
  position_channels = utils.position_channels(mixed_features2d)

  variables_before = set(tf.global_variables())
  with tf.variable_scope("saccader_cell", reuse=reuse):
    # Compute 2D weights of features across space.
    features_space_logits = tf.layers.dense(
        mixed_features2d, units=1, use_bias=False,
        name="attention_weights") / tf.math.sqrt(float(channels))
    features_space_logits += (cell_state * -1.e5)  # Mask used locations.
    features_space_weights = utils.softmax2d(features_space_logits)

    # Compute 1D weights of features across channels.
    features_channels_logits = tf.reduce_sum(
        mixed_features2d * features_space_weights, axis=[1, 2])
    features_channels_weights = tf.nn.softmax(features_channels_logits, axis=1)

    # Compute location probability.
    locations_logits2d = tf.reduce_sum(
        (mixed_features2d *
         features_channels_weights[:, tf.newaxis, tf.newaxis, :]),
        axis=-1,
        keepdims=True)
    locations_logits2d += (cell_state * -1e5)  # Mask used locations.
    locations_prob2d = utils.softmax2d(locations_logits2d)
  variables_after = set(tf.global_variables())

  # Compute best locations: flatten the spatial grid and take the argmax.
  locations_logits = tf.reshape(locations_logits2d, (batch_size, -1))
  all_positions = tf.reshape(position_channels,
                             [batch_size, height * width, 2])
  best_locations_labels = tf.argmax(locations_logits, axis=-1)
  best_locations = utils.batch_gather_nd(
      all_positions, best_locations_labels, axis=1)

  # Sample locations.
  if policy == "learned":
    if is_training:
      # At training samples location from the learned distribution.
      dist = tfp.distributions.Categorical(logits=locations_logits)
      locations_labels = dist.sample()
      locations = utils.batch_gather_nd(
          all_positions, locations_labels, axis=1)
      # Ensures range [-1., 1.]
      locations = tf.clip_by_value(locations, -1., 1)
      tf.logging.info("Sampling locations.")
      tf.logging.info("==================================================")
    else:
      # At inference uses the mean value for the location.
      locations = best_locations
      locations_labels = best_locations_labels
  elif policy == "random":
    # Use random policy for location.
    locations = tf.random_uniform(
        shape=(batch_size, 2), minval=-1., maxval=1.)
    locations_labels = None
  else:  # policy == "center" (guaranteed by the validation above).
    # Use center look policy.
    locations = tf.zeros(shape=(batch_size, 2))
    locations_labels = None

  # Update cell_state: mark the chosen location as used and keep it in [0, 1].
  cell_state += utils.onehot2d(cell_state, locations)
  cell_state = tf.clip_by_value(cell_state, 0, 1)

  # Extract logits from the 2D logits.
  if self.soft_attention:
    # Expected logits under the location distribution (ignores `locations`).
    logits = tf.reduce_sum(logits2d * locations_prob2d, axis=[1, 2])
  else:
    logits = gather_2d(logits2d, locations)

  endpoints = {}
  endpoints["cell_outputs"] = {
      "locations": locations,
      "locations_labels": locations_labels,
      "best_locations": best_locations,
      "best_locations_labels": best_locations_labels,
      "locations_logits2d": locations_logits2d,
      "locations_prob2d": locations_prob2d,
      "cell_state": cell_state,
      "features_space_logits": features_space_logits,
      "features_space_weights": features_space_weights,
      "features_channels_logits": features_channels_logits,
      "features_channels_weights": features_channels_weights,
      "locations_logits": locations_logits,
      "all_positions": all_positions,
  }
  if not reuse:
    # First build: record the variables created inside "saccader_cell".
    self.collect_variables(list(variables_after - variables_before))
  return logits, cell_state, endpoints