def saccader_pretraining_loss(model, images, is_training):
    """Saccader pretraining loss.

    Runs the model with the "learned" policy for 12 glimpses (with gradients
    stopped after the representation network) and scores the location logits
    of every glimpse against the locations chosen by the "ordered_logits"
    engineered policy, using softmax cross entropy over the flattened
    spatial positions.

    Args:
      model: Callable saccader model object.
      images: (4D Tensor) input images.
      is_training: (Boolean) training or inference mode.

    Returns:
      Pretraining loss for the model location weights.
    """
    _, _, _, endpoints = model(
        images,
        num_times=12,
        is_training=is_training,
        policy="learned",
        stop_gradient_after_representation=True)

    scale = endpoints["location_scale"]
    class_logits2d = endpoints["logits2d"]
    predicted_logits2d_t = endpoints["locations_logits2d_t"]

    batch_size, height, width, _ = class_logits2d.shape.as_list()
    flat_shape = (batch_size, height * width)
    num_glimpses = len(predicted_logits2d_t)

    # Target locations come from the engineered policy.
    target_locations_t = saccader.engineered_policies(
        images,
        class_logits2d,
        utils.position_channels(class_logits2d) * scale,
        model.glimpse_shape,
        num_glimpses,
        policy="ordered_logits")

    # One-hot encode every target location over the flattened 2D grid; the
    # targets are treated as constants (no gradient flows through them).
    target_one_hot_t = [
        tf.reshape(
            utils.onehot2d(class_logits2d, tf.stop_gradient(target) / scale),
            flat_shape)
        for target in target_locations_t
    ]

    # Flatten the predicted location logits of every time step the same way.
    predicted_logits_t = [
        tf.reshape(step_logits2d, flat_shape)
        for step_logits2d in predicted_logits2d_t
    ]

    # Stack all time steps along the batch axis and average the
    # per-example cross entropy.
    per_example_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.concat(target_one_hot_t, axis=0),
        logits=tf.concat(predicted_logits_t, axis=0),
        loss_collection=None,
        reduction=tf.losses.Reduction.NONE)
    return tf.reduce_mean(per_example_loss)
def test_position_channels(self):
    """Checks the corner values and the shape of the position channels."""
    expected_corners = [
        (-1, -1),  # Upper left.
        (-1, 1),  # Upper right.
        (1, -1),  # Lower left.
        (1, 1),  # Lower right.
    ]
    images = _construct_images(len(expected_corners))
    channels = utils.position_channels(images)

    # Corner pixel (row, col) read from each batch element, listed in the
    # same order as the expected corner positions above.
    corner_pixels = [(0, 0), (0, -1), (-1, 0), (-1, -1)]
    corners = tuple(
        channels[example][row, col]
        for example, (row, col) in enumerate(corner_pixels))

    corner_locations = tf.convert_to_tensor(expected_corners,
                                            dtype=tf.float32)
    glimpses = tf.image.extract_glimpse(
        channels,
        size=(1, 1),
        offsets=corner_locations,
        centered=True,
        normalized=True)

    # Check shape: same as the images but with 2 channels.
    expected_shape = images.shape.as_list()[:-1] + [2]
    self.assertEqual(channels.shape.as_list(), expected_shape)

    corners, glimpses, corner_locations = self.evaluate(
        (corners, glimpses, corner_locations))
    glimpses = np.squeeze(glimpses)

    # Check correct corners.
    for corner, location in zip(corners, corner_locations):
        self.assertEqual(tuple(corner), tuple(location))

    # Check match with extract_glimpse function.
    for corner, glimpse in zip(corners, glimpses):
        self.assertEqual(tuple(corner), tuple(glimpse))
def test_sort2d(self, direction):
    """Checks sort2d on two examples holding the same values in reverse order."""
    values = range(0, 9)
    forward = np.reshape(values, (3, 3))
    backward = np.reshape(values[::-1], (3, 3))
    x = tf.convert_to_tensor(np.array([forward, backward]),
                             dtype=tf.float32)
    ref_indices = utils.position_channels(x)

    sorted_op, argsorted_op = utils.sort2d(x, ref_indices,
                                           direction=direction)
    sorted_values = self.evaluate(sorted_op)
    sorted_indices = self.evaluate(argsorted_op)

    # Examples include same elements. So sorted examples should be equal.
    self.assertAllEqual(sorted_values[:, 0], sorted_values[:, 1])

    # Examples are in reverse order. So, indices should be reversed.
    for axis in range(2):
        first_example = sorted_indices[:, 0, axis]
        second_example = sorted_indices[:, 1, axis]
        self.assertAllEqual(first_example, second_example[::-1])
def __call__(self,
             images,
             num_times,
             is_training=False,
             policy="learned",
             stop_gradient_after_representation=False):
    """Builds the Saccader model unrolled over `num_times` glimpses.

    Args:
      images: 4-D Tensor of input images. Axis 1 is used as the image size
        when computing the location scale.
      num_times: (Integer) number of glimpses.
      is_training: (Boolean) training or inference mode.
      policy: (String) 'learned', 'random' or 'center' run the saccader cell
        at every time step with that policy; 'ordered_logits', 'sobel_mean'
        or 'sobel_var' take the locations from `engineered_policies`.
      stop_gradient_after_representation: (Boolean) if True, gradients are
        stopped at the representation network outputs (features2d, logits2d
        and the low-dimensional "what" features).

    Returns:
      logits: Mean over time of the per-glimpse classification logits.
      locations_t: List with the glimpse locations at each time.
      best_locations_t: List with the best glimpse locations at each time.
      endpoints: Dictionary with intermediate tensors.

    Raises:
      ValueError: If `policy` is not one of the supported policies.
    """
    engineered_policy_names = ("ordered_logits", "sobel_mean", "sobel_var")
    cell_policy_names = ("learned", "random", "center")

    endpoints = {}
    # Variables are reused once they have been created by a previous call.
    reuse = bool(self.var_list)
    with tf.variable_scope(self.variable_scope + "/representation_network",
                           reuse=reuse):
        representation_logits, endpoints_ = self.representation_network(
            images, is_training)

    if not self.var_list_representation_network:
        self.var_list_representation_network = (
            self.representation_network.var_list)

    self.glimpse_shape = self.representation_network.receptive_field
    glimpse_size = tf.cast(self.glimpse_shape[0], dtype=tf.float32)
    image_size = tf.cast(tf.shape(images)[1], dtype=tf.float32)
    # Ensure glimpses within image.
    location_scale = 1. - glimpse_size / image_size
    endpoints["location_scale"] = location_scale
    endpoints["representation_network"] = endpoints_
    endpoints["representation_network"]["logits"] = representation_logits
    features2d = endpoints_["features2d"]  # size [batch, 28, 28, 2048]
    logits2d = endpoints_["logits2d"]  # size [batch, 28, 28, 1001]
    what_features2d = endpoints_["features2d_lowd"]
    endpoints["logits2d"] = logits2d
    endpoints["features2d"] = features2d
    endpoints["what_features2d"] = what_features2d

    # Freeze the representation network weights.
    if stop_gradient_after_representation:
        features2d = tf.stop_gradient(features2d)
        logits2d = tf.stop_gradient(logits2d)
        what_features2d = tf.stop_gradient(what_features2d)

    # Attention network. The global variable set is snapshotted before and
    # after so the variables created here can be collected by difference.
    variables_before = set(tf.global_variables())
    with tf.variable_scope(self.variable_scope, reuse=reuse):
        where_features2d = build_attention_network(
            features2d,
            self.config.attention_groups,
            self.config.attention_layers_per_group,
            is_training)
        endpoints["where_features2d"] = where_features2d

        # Mix what and where features.
        mixed_features2d = tf.layers.conv2d(
            tf.concat([where_features2d, what_features2d], axis=-1),
            filters=512,
            kernel_size=1,
            strides=1,
            activation=None,
            use_bias=True,
            name="mixed_features2d",
            padding="same")
    endpoints["mixed_features2d"] = mixed_features2d
    variables_after = set(tf.global_variables())
    if not self.var_list_attention_network:
        self.var_list_attention_network = list(
            variables_after - variables_before)

    # Unrolling the model in time.
    classification_logits_t = []
    locations_t = []
    best_locations_t = []
    locations_logits2d_t = []
    batch_size, height, width, _ = mixed_features2d.shape.as_list()
    cell_state = tf.zeros((batch_size, height, width, 1), dtype=tf.float32)

    # Engineered policies.
    if policy in engineered_policy_names:
        locations_t = engineered_policies(
            images,
            logits2d,
            utils.position_channels(logits2d) * location_scale,
            self.glimpse_shape,
            num_times,
            policy)

        best_locations_t = locations_t
        # Engineered locations are already scaled, so undo the scaling
        # before gathering from the 2D logits.
        classification_logits_t = [
            gather_2d(logits2d, locations / location_scale)
            for locations in locations_t
        ]
        # Run for 1 time to create variables (but output is unused).
        with tf.name_scope("time%d" % 0):
            with tf.variable_scope(self.variable_scope):
                self.saccader_cell(
                    mixed_features2d,
                    cell_state,
                    logits2d,
                    is_training=is_training,
                    policy="random")

    # Other policies
    elif policy in cell_policy_names:
        for t in range(num_times):
            endpoints["time%d" % t] = {}
            with tf.name_scope("time%d" % t):
                with tf.variable_scope(self.variable_scope):
                    logits, cell_state, endpoints_ = self.saccader_cell(
                        mixed_features2d,
                        cell_state,
                        logits2d,
                        is_training=is_training,
                        policy=policy)

                cell_outputs = endpoints_["cell_outputs"]
                endpoints["time%d" % t].update(endpoints_)
                classification_logits_t.append(logits)
                # Convert to center glimpse location on images space.
                locations_t.append(
                    cell_outputs["locations"] * location_scale)
                best_locations_t.append(
                    cell_outputs["best_locations"] * location_scale)
                locations_logits2d_t.append(
                    cell_outputs["locations_logits2d"])
        endpoints["locations_logits2d_t"] = locations_logits2d_t
    else:
        # List every policy this method actually accepts and report the
        # offending value.
        raise ValueError(
            "policy can be either 'learned', 'random', 'center', "
            "'ordered_logits', 'sobel_mean', or 'sobel_var'; got %r" %
            (policy,))

    if not self.var_list_saccader_cell:
        self.var_list_saccader_cell = self.saccader_cell.var_list
    self.collect_variables()
    endpoints["classification_logits_t"] = classification_logits_t
    logits = tf.reduce_mean(classification_logits_t, axis=0)
    return (logits, locations_t, best_locations_t, endpoints)
def __call__(self,
             mixed_features2d,
             cell_state,
             logits2d,
             is_training=False,
             policy="learned"):
    """Builds Saccader cell.

    Args:
      mixed_features2d: 4-D Tensor of shape [batch, height, width, channels].
      cell_state: 4-D Tensor of shape [batch, height, width, 1] with cell
        state.
      logits2d: 4-D Tensor of shape [batch, height, width, channels].
      is_training: (Boolean) To indicate training or inference modes.
      policy: (String) 'learned': uses learned policy, 'random': uses random
        policy, or 'center': uses center look policy.

    Returns:
      logits: Model logits.
      cell_state: New cell state.
      endpoints: Dictionary with cell parameters.

    Raises:
      ValueError: If `policy` is not 'learned', 'random' or 'center'.
    """
    batch_size, height, width, channels = mixed_features2d.shape.as_list()
    # Variables are reused once they have been created by a previous call.
    reuse = bool(self.var_list)
    position_channels = utils.position_channels(mixed_features2d)

    # Snapshot the global variables so the ones created by this cell can be
    # collected by set difference below.
    variables_before = set(tf.global_variables())
    with tf.variable_scope("saccader_cell", reuse=reuse):
        # Compute 2D weights of features across space.
        features_space_logits = tf.layers.dense(
            mixed_features2d,
            units=1,
            use_bias=False,
            name="attention_weights") / tf.math.sqrt(float(channels))

        features_space_logits += (cell_state * -1.e5)  # Mask used locations.
        features_space_weights = utils.softmax2d(features_space_logits)

        # Compute 1D weights of features across channels.
        features_channels_logits = tf.reduce_sum(
            mixed_features2d * features_space_weights, axis=[1, 2])
        features_channels_weights = tf.nn.softmax(
            features_channels_logits, axis=1)

        # Compute location probability.
        locations_logits2d = tf.reduce_sum(
            (mixed_features2d *
             features_channels_weights[:, tf.newaxis, tf.newaxis, :]),
            axis=-1,
            keepdims=True)
        locations_logits2d += (cell_state * -1e5)  # Mask used locations.
        locations_prob2d = utils.softmax2d(locations_logits2d)

    variables_after = set(tf.global_variables())

    # Compute best locations (argmax of the location logits).
    locations_logits = tf.reshape(locations_logits2d, (batch_size, -1))
    all_positions = tf.reshape(position_channels,
                               [batch_size, height * width, 2])

    best_locations_labels = tf.argmax(locations_logits, axis=-1)
    best_locations = utils.batch_gather_nd(
        all_positions, best_locations_labels, axis=1)

    # Sample locations.
    if policy == "learned":
        if is_training:
            dist = tfp.distributions.Categorical(logits=locations_logits)
            locations_labels = dist.sample()
            # At training samples location from the learned distribution.
            locations = utils.batch_gather_nd(
                all_positions, locations_labels, axis=1)
            # Ensures range [-1., 1.]
            locations = tf.clip_by_value(locations, -1., 1)
            tf.logging.info("Sampling locations.")
            tf.logging.info(
                "==================================================")
        else:
            # At inference uses the highest-scoring (argmax) location.
            locations = best_locations
            locations_labels = best_locations_labels
    elif policy == "random":
        # Use random policy for location.
        locations = tf.random_uniform(
            shape=(batch_size, 2), minval=-1., maxval=1.)
        locations_labels = None
    elif policy == "center":
        # Use center look policy.
        locations = tf.zeros(shape=(batch_size, 2))
        locations_labels = None
    else:
        # Without this branch `locations` would be unbound and the failure
        # would surface below as an unrelated-looking UnboundLocalError.
        raise ValueError(
            "policy can be either 'learned', 'random', or 'center'; "
            "got %r" % (policy,))

    # Update cell_state: mark the visited location and keep values in [0, 1].
    cell_state += utils.onehot2d(cell_state, locations)
    cell_state = tf.clip_by_value(cell_state, 0, 1)

    # Extract logits from the 2D logits.
    if self.soft_attention:
        logits = tf.reduce_sum(logits2d * locations_prob2d, axis=[1, 2])
    else:
        logits = gather_2d(logits2d, locations)

    endpoints = {}
    endpoints["cell_outputs"] = {
        "locations": locations,
        "locations_labels": locations_labels,
        "best_locations": best_locations,
        "best_locations_labels": best_locations_labels,
        "locations_logits2d": locations_logits2d,
        "locations_prob2d": locations_prob2d,
        "cell_state": cell_state,
        "features_space_logits": features_space_logits,
        "features_space_weights": features_space_weights,
        "features_channels_logits": features_channels_logits,
        "features_channels_weights": features_channels_weights,
        "locations_logits": locations_logits,
        "all_positions": all_positions,
    }
    # Only the first (non-reusing) call registers the new variables.
    if not reuse:
        self.collect_variables(list(variables_after - variables_before))

    return logits, cell_state, endpoints
def __call__(self, images, locations, is_training, use_resolution):
    """Builds glimpse network.

    Args:
      images: 4-D Tensor of shape [batch, height, width, channels].
      locations: 2D Tensor of shape [batch, 2] with glimpse locations.
      is_training: (Boolean) training or inference mode.
      use_resolution: (List of Boolean of size num_resolutions) Indicates
        which resolutions to use from high (small receptive field) to low
        (wide receptive field).

    Returns:
      output: Network output reflecting representation learned from glimpses
        and locations.
      endpoints: Dictionary with activations at different layers.
    """
    # Variables are reused once they have been created by a previous call.
    reuse = bool(self.var_list)

    tf.logging.info("Build Glimpse Network")

    # Append position channels.
    images_with_position = tf.concat(
        [images, utils.position_channels(images)], axis=3)
    glimpses_per_resolution = self.extract_glimpses(images_with_position,
                                                    locations)
    endpoints = {
        # Only the first three channels of every glimpse are kept here.
        "images_glimpses_list": [
            glimpse[:, :, :, 0:3] for glimpse in glimpses_per_resolution
        ],
        "model_input_list": glimpses_per_resolution,
    }

    # If masking is required, use the spatial mean per channel; otherwise
    # pass the glimpse through unchanged.
    network_inputs = [
        glimpse if use else 0. * glimpse + tf.stop_gradient(
            tf.reduce_mean(glimpse, axis=[1, 2], keepdims=True))
        for use, glimpse in zip(use_resolution, glimpses_per_resolution)
    ]

    # Concatenate along channels axis.
    net = tf.concat(network_inputs, axis=3)

    if self.network_type == "wrn":
        with tf.variable_scope("glimpse_network", reuse=reuse):
            output, endpoints_ = model_utils.build_wide_residual_network(
                net,
                self.output_dims,
                residual_blocks_per_group=self.residual_blocks_per_group,
                number_groups=self.number_groups,
                init_conv_channels=self.init_conv_channels,
                widening_factor=self.widening_factor,
                dropout_rate=self.dropout_rate,
                expand_rate=2,
                conv_size=3,
                is_training=is_training,
                activation=self.activation,
                regularizer=self.regularizer,
                normalization_type=self.normalization_type,
                zero_pad=self.zero_pad,
                global_average_pool=self.global_average_pool)
    else:
        network_fn = nets_factory.get_network_fn(
            self.network_type,
            num_classes=self.output_dims,
            is_training=is_training)
        output, endpoints_ = network_fn(
            net, scope="glimpse_network", reuse=reuse)
        # NOTE(review): the pooling below is kept on the non-"wrn" path
        # only; confirm this nesting against the intended layout.
        if self.output_dims is None:
            # Global average of activations.
            output = tf.reduce_mean(output, [1, 2])

    endpoints.update(endpoints_)
    # Only the first (non-reusing) call registers the new variables.
    if not reuse:
        self.collect_variables()
    return output, endpoints