def saccader_pretraining_loss(model, images, is_training):
    """Saccader pretraining loss.

  Args:
    model: Callable saccader model object.
    images: (4D Tensor) input images.
    is_training: (Boolen) training or inference mode.

  Returns:
    Pretraining loss for the model location weights.
  """
    _, _, _, endpoints = model(images,
                               num_times=12,
                               is_training=is_training,
                               policy="learned",
                               stop_gradient_after_representation=True)

    location_scale = endpoints["location_scale"]
    logits2d = endpoints["logits2d"]
    locations_logits2d_t = endpoints["locations_logits2d_t"]
    batch_size, height, width, _ = logits2d.shape.as_list()
    num_times = len(locations_logits2d_t)
    target_locations_t = saccader.engineered_policies(
        images,
        logits2d,
        utils.position_channels(logits2d) * location_scale,
        model.glimpse_shape,
        num_times,
        policy="ordered_logits")

    one_hot_t = []
    for loc in target_locations_t:
        one_hot_t.append(
            tf.reshape(
                utils.onehot2d(logits2d,
                               tf.stop_gradient(loc) / location_scale),
                (batch_size, height * width)))

    locations_logits_t = [
        tf.reshape(locations_logits2d_t[t], (batch_size, height * width))
        for t in range(num_times)
    ]
    pretrain_loss = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(
            onehot_labels=tf.concat(one_hot_t, axis=0),
            logits=tf.concat(locations_logits_t, axis=0),
            loss_collection=None,
            reduction=tf.losses.Reduction.NONE))
    return pretrain_loss
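For context, a minimal training sketch in the same TF1 style. This is an illustration, not repo code: the `config` object and the `saccader.Saccader` constructor arguments are hypothetical, while `var_list_attention_network` is the attribute populated once the model is called inside the loss (see Example #4).

import tensorflow as tf

images = tf.placeholder(tf.float32, shape=(32, 224, 224, 3))
model = saccader.Saccader(config, variable_scope="saccader")  # hypothetical args
loss = saccader_pretraining_loss(model, images, is_training=True)
# Only the attention/location weights receive gradients; the representation
# network is frozen via stop_gradient_after_representation=True in the loss.
optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
train_op = optimizer.minimize(loss, var_list=model.var_list_attention_network)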
Example #2
    def test_position_channels(self):
        corner_locations = [
            (-1, -1),  # Upper left.
            (-1, 1),  # Upper right.
            (1, -1),  # Lower left.
            (1, 1),  # Lower right.
        ]
        batch_size = len(corner_locations)
        images = _construct_images(batch_size)
        channels = utils.position_channels(images)
        # Corner positions.
        upper_left = channels[0][0, 0]  # Should be position [-1, -1].
        upper_right = channels[1][0, -1]  # Should be position [-1, 1].
        lower_left = channels[2][-1, 0]  # Should be position [1, -1].
        lower_right = channels[3][-1, -1]  # Should be position [1, 1].

        corners = (upper_left, upper_right, lower_left, lower_right)

        corner_locations = tf.convert_to_tensor(corner_locations,
                                                dtype=tf.float32)
        glimpses = tf.image.extract_glimpse(channels,
                                            size=(1, 1),
                                            offsets=corner_locations,
                                            centered=True,
                                            normalized=True)

        # Check shape.
        self.assertEqual(channels.shape.as_list(),
                         images.shape.as_list()[:-1] + [2])
        corners, glimpses, corner_locations = self.evaluate(
            (corners, glimpses, corner_locations))
        glimpses = np.squeeze(glimpses)

        # Check the corner values.
        self.assertEqual(tuple(corners[0]), tuple(corner_locations[0]))
        self.assertEqual(tuple(corners[1]), tuple(corner_locations[1]))
        self.assertEqual(tuple(corners[2]), tuple(corner_locations[2]))
        self.assertEqual(tuple(corners[3]), tuple(corner_locations[3]))
        # Check match with extract_glimpse function.
        self.assertEqual(tuple(corners[0]), tuple(glimpses[0]))
        self.assertEqual(tuple(corners[1]), tuple(glimpses[1]))
        self.assertEqual(tuple(corners[2]), tuple(glimpses[2]))
        self.assertEqual(tuple(corners[3]), tuple(glimpses[3]))
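For intuition, here is a NumPy sketch (an illustration, not the repo's implementation) of the coordinate grid utils.position_channels is expected to produce: two channels holding each pixel's (y, x) position, linearly spaced over [-1, 1], which is exactly what the corner assertions above rely on.

import numpy as np

def position_channels_np(height, width):
    # y runs from -1 (top row) to 1 (bottom row); x from -1 (left) to 1 (right).
    ys = np.linspace(-1., 1., height)
    xs = np.linspace(-1., 1., width)
    yy, xx = np.meshgrid(ys, xs, indexing="ij")
    return np.stack([yy, xx], axis=-1)  # Shape [height, width, 2].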
Example #3
    def test_sort2d(self, direction):
        elements = range(0, 9)
        x = tf.convert_to_tensor(
            np.array([np.reshape(elements, (3, 3)),
                      np.reshape(elements[::-1], (3, 3))]),
            dtype=tf.float32)

        ref_indices = utils.position_channels(x)
        sorted_x, argsorted_x = utils.sort2d(x,
                                             ref_indices,
                                             direction=direction)
        sorted_x = self.evaluate(sorted_x)
        argsorted_x = self.evaluate(argsorted_x)
        # Both examples contain the same elements, so the sorted values
        # should be equal.
        self.assertAllEqual(sorted_x[:, 0], sorted_x[:, 1])

        # The examples are in reverse order, so their sort indices should be
        # reversed relative to each other.
        ndims = 2
        for i in range(ndims):
            self.assertAllEqual(argsorted_x[:, 0, i],
                                argsorted_x[:, 1, i][::-1])
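A standalone usage sketch for utils.sort2d mirroring the test above (the "DESCENDING" direction string is an assumption matching the test's parameterization): sort a batch of 2-D maps and recover the (y, x) position each sorted value came from.

x = tf.convert_to_tensor(
    np.array([[[3., 1., 2.],
               [9., 7., 8.],
               [6., 4., 5.]]]), dtype=tf.float32)
ref_indices = utils.position_channels(x)
sorted_x, argsorted_x = utils.sort2d(x, ref_indices, direction="DESCENDING")
# sorted_x holds the values in order; argsorted_x holds their 2-D positions.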
Example #4
    def __call__(self,
                 images,
                 num_times,
                 is_training=False,
                 policy="learned",
                 stop_gradient_after_representation=False):

        endpoints = {}
        reuse = bool(self.var_list)
        with tf.variable_scope(self.variable_scope + "/representation_network",
                               reuse=reuse):
            representation_logits, endpoints_ = self.representation_network(
                images, is_training)

        if not self.var_list_representation_network:
            self.var_list_representation_network = (
                self.representation_network.var_list)

        self.glimpse_shape = self.representation_network.receptive_field
        glimpse_size = tf.cast(self.glimpse_shape[0], dtype=tf.float32)
        image_size = tf.cast(tf.shape(images)[1], dtype=tf.float32)
        # Ensure glimpses within image.
        location_scale = 1. - glimpse_size / image_size
        endpoints["location_scale"] = location_scale
        endpoints["representation_network"] = endpoints_
        endpoints["representation_network"]["logits"] = representation_logits
        features2d = endpoints_["features2d"]  # size [batch, 28, 28, 2048]
        logits2d = endpoints_["logits2d"]  # size [batch, 28, 28, 1001]
        what_features2d = endpoints_["features2d_lowd"]
        endpoints["logits2d"] = logits2d
        endpoints["features2d"] = features2d
        endpoints["what_features2d"] = what_features2d

        # Freeze the representation network weights.
        if stop_gradient_after_representation:
            features2d = tf.stop_gradient(features2d)
            logits2d = tf.stop_gradient(logits2d)
            what_features2d = tf.stop_gradient(what_features2d)

        # Attention network.
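        # Snapshot existing globals so the attention-network variables created
        # below can be collected as the set difference.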
        variables_before = set(tf.global_variables())
        with tf.variable_scope(self.variable_scope, reuse=reuse):
            where_features2d = build_attention_network(
                features2d, self.config.attention_groups,
                self.config.attention_layers_per_group, is_training)
            endpoints["where_features2d"] = where_features2d
            # Mix what and where features.
            mixed_features2d = tf.layers.conv2d(
                tf.concat([where_features2d, what_features2d], axis=-1),
                filters=512,
                kernel_size=1,
                strides=1,
                activation=None,
                use_bias=True,
                name="mixed_features2d",
                padding="same")
        endpoints["mixed_features2d"] = mixed_features2d
        variables_after = set(tf.global_variables())
        if not self.var_list_attention_network:
            self.var_list_attention_network = list(variables_after -
                                                   variables_before)
        # Unrolling the model in time.
        classification_logits_t = []
        locations_t = []
        best_locations_t = []
        locations_logits2d_t = []
        batch_size, height, width, _ = mixed_features2d.shape.as_list()
        cell_state = tf.zeros((batch_size, height, width, 1), dtype=tf.float32)
        # Engineered policies.
        if policy in ["ordered_logits", "sobel_mean", "sobel_var"]:
            locations_t = engineered_policies(
                images, logits2d,
                utils.position_channels(logits2d) * location_scale,
                self.glimpse_shape, num_times, policy)

            best_locations_t = locations_t
            classification_logits_t = [
                gather_2d(logits2d, locations / location_scale)
                for locations in locations_t
            ]
            # Run once to create the cell variables (the output is unused).
            with tf.name_scope("time%d" % 0):
                with tf.variable_scope(self.variable_scope):
                    self.saccader_cell(mixed_features2d,
                                       cell_state,
                                       logits2d,
                                       is_training=is_training,
                                       policy="random")

        # Other policies
        elif policy in ["learned", "random", "center"]:
            for t in range(num_times):
                endpoints["time%d" % t] = {}
                with tf.name_scope("time%d" % t):
                    with tf.variable_scope(self.variable_scope):
                        logits, cell_state, endpoints_ = self.saccader_cell(
                            mixed_features2d,
                            cell_state,
                            logits2d,
                            is_training=is_training,
                            policy=policy)
                    cell_outputs = endpoints_["cell_outputs"]
                    endpoints["time%d" % t].update(endpoints_)
                    classification_logits_t.append(logits)
                    # Convert to the glimpse center location in image space.
                    locations_t.append(
                        cell_outputs["locations"] * location_scale)
                    best_locations_t.append(
                        cell_outputs["best_locations"] * location_scale)
                    locations_logits2d_t.append(
                        cell_outputs["locations_logits2d"])
            endpoints["locations_logits2d_t"] = locations_logits2d_t
        else:
            raise ValueError(
                "policy must be one of 'learned', 'random', 'center', "
                "'ordered_logits', 'sobel_mean', or 'sobel_var'")

        if not self.var_list_saccader_cell:
            self.var_list_saccader_cell = self.saccader_cell.var_list
        self.collect_variables()
        endpoints["classification_logits_t"] = classification_logits_t
        logits = tf.reduce_mean(classification_logits_t, axis=0)
        return (logits, locations_t, best_locations_t, endpoints)
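A hedged inference sketch against the signature above; it assumes a constructed model instance and an images batch as in the earlier examples. It runs six glimpses with the engineered "ordered_logits" policy and takes predictions from the time-averaged logits.

logits, locations_t, best_locations_t, endpoints = model(
    images,
    num_times=6,
    is_training=False,
    policy="ordered_logits",
    stop_gradient_after_representation=True)
# `logits` is the mean of the per-glimpse classification logits.
predictions = tf.argmax(logits, axis=-1)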
Example #5
    def __call__(self,
                 mixed_features2d,
                 cell_state,
                 logits2d,
                 is_training=False,
                 policy="learned"):
        """Builds Saccader cell.

    Args:
      mixed_features2d: 4-D Tensor of shape [batch, height, width, channels].
      cell_state: 4-D Tensor of shape [batch, height, width, 1] with cell state.
      logits2d: 4-D Tensor of shape [batch, height, width, channels].
      is_training: (Boolean) To indicate training or inference modes.
      policy: (String) 'learned': uses learned policy, 'random': uses random
        policy, or 'center': uses center look policy.
    Returns:
      logits: Model logits.
      cell_state: New cell state.
      endpoints: Dictionary with cell parameters.
    """
        batch_size, height, width, channels = mixed_features2d.shape.as_list()
        reuse = bool(self.var_list)
        position_channels = utils.position_channels(mixed_features2d)

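        # Snapshot existing globals so the variables this cell creates below
        # can be collected as the set difference afterwards.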
        variables_before = set(tf.global_variables())
        with tf.variable_scope("saccader_cell", reuse=reuse):
            # Compute 2D weights of features across space.
            features_space_logits = tf.layers.dense(
                mixed_features2d,
                units=1,
                use_bias=False,
                name="attention_weights") / tf.math.sqrt(float(channels))

            # Mask previously used locations.
            features_space_logits += cell_state * -1.e5
            features_space_weights = utils.softmax2d(features_space_logits)

            # Compute 1D weights of features across channels.
            features_channels_logits = tf.reduce_sum(
                mixed_features2d * features_space_weights, axis=[1, 2])
            features_channels_weights = tf.nn.softmax(
                features_channels_logits, axis=1)

            # Compute location probability.
            locations_logits2d = tf.reduce_sum(
                (mixed_features2d *
                 features_channels_weights[:, tf.newaxis, tf.newaxis, :]),
                axis=-1,
                keepdims=True)

            locations_logits2d += cell_state * -1.e5  # Mask used locations.
            locations_prob2d = utils.softmax2d(locations_logits2d)

        variables_after = set(tf.global_variables())
        # Compute best locations.
        locations_logits = tf.reshape(locations_logits2d, (batch_size, -1))
        all_positions = tf.reshape(position_channels,
                                   [batch_size, height * width, 2])

        best_locations_labels = tf.argmax(locations_logits, axis=-1)
        best_locations = utils.batch_gather_nd(all_positions,
                                               best_locations_labels,
                                               axis=1)

        # Sample locations.
        if policy == "learned":
            if is_training:
                dist = tfp.distributions.Categorical(logits=locations_logits)
                locations_labels = dist.sample()
                # At training samples location from the learned distribution.
                locations = utils.batch_gather_nd(all_positions,
                                                  locations_labels,
                                                  axis=1)
                # Ensures range [-1., 1.]
                locations = tf.clip_by_value(locations, -1., 1)
                tf.logging.info("Sampling locations.")
                tf.logging.info(
                    "==================================================")
            else:
                # At inference uses the mean value for the location.
                locations = best_locations
                locations_labels = best_locations_labels
        elif policy == "random":
            # Use random policy for location.
            locations = tf.random_uniform(shape=(batch_size, 2),
                                          minval=-1.,
                                          maxval=1.)
            locations_labels = None
        elif policy == "center":
            # Use center look policy.
            locations = tf.zeros(shape=(batch_size, 2))
            locations_labels = None

        # Update cell_state.
        cell_state += utils.onehot2d(cell_state, locations)
        cell_state = tf.clip_by_value(cell_state, 0, 1)
        # Extract logits from the 2-D logits map.
        if self.soft_attention:
            logits = tf.reduce_sum(logits2d * locations_prob2d, axis=[1, 2])
        else:
            logits = gather_2d(logits2d, locations)
        endpoints = {}
        endpoints["cell_outputs"] = {
            "locations": locations,
            "locations_labels": locations_labels,
            "best_locations": best_locations,
            "best_locations_labels": best_locations_labels,
            "locations_logits2d": locations_logits2d,
            "locations_prob2d": locations_prob2d,
            "cell_state": cell_state,
            "features_space_logits": features_space_logits,
            "features_space_weights": features_space_weights,
            "features_channels_logits": features_channels_logits,
            "features_channels_weights": features_channels_weights,
            "locations_logits": locations_logits,
            "all_positions": all_positions,
        }
        if not reuse:
            self.collect_variables(list(variables_after - variables_before))

        return logits, cell_state, endpoints
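A sketch of unrolling the cell by hand, mirroring the loop in the Saccader model of Example #4 (it assumes a constructed saccader_cell plus the mixed_features2d and logits2d tensors from that example): the state starts at zeros, and each step masks the locations already visited.

batch_size, height, width, _ = mixed_features2d.shape.as_list()
cell_state = tf.zeros((batch_size, height, width, 1), dtype=tf.float32)
for t in range(num_times):
    logits, cell_state, endpoints_ = saccader_cell(
        mixed_features2d,
        cell_state,
        logits2d,
        is_training=False,
        policy="learned")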
Example #6
    def __call__(self, images, locations, is_training, use_resolution):
        """Builds glimpse network.

    Args:
      images: 4-D Tensor of shape [batch, height, width, channels].
      locations: 2D Tensor of shape [batch, 2] with glimpse locations.
      is_training: (Boolean) training or inference mode.
      use_resolution: (List of Boolean of size num_resolutions) Indicates which
        resolutions to use from high (small receptive field)
        to low (wide receptive field).

    Returns:
      output: Network output reflecting representation learned from glimpses
        and locations.
      endpoints: Dictionary with activations at different layers.
    """
        reuse = bool(self.var_list)

        tf.logging.info("Build Glimpse Network")
        endpoints = {}

        # Append position channels.
        images_with_position = tf.concat(
            [images, utils.position_channels(images)], axis=3)

        images_glimpses_list = self.extract_glimpses(images_with_position,
                                                     locations)
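        # Keep only the image channels (drop the appended position channels).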
        endpoints["images_glimpses_list"] = [
            g[:, :, :, 0:3] for g in images_glimpses_list
        ]
        endpoints["model_input_list"] = images_glimpses_list
        # Concatenate along channels axis.
        images_glimpses_list_ = []
        for use, g in zip(use_resolution, images_glimpses_list):
            if not use:
                # If masking is required, use the spatial mean per channel.
                images_glimpses_list_.append(0. * g + tf.stop_gradient(
                    tf.reduce_mean(g, axis=[1, 2], keepdims=True)))
            else:
                images_glimpses_list_.append(g)

        images_glimpses_list = images_glimpses_list_

        images_glimpses = tf.concat(images_glimpses_list, axis=3)
        net = images_glimpses

        if self.network_type == "wrn":
            with tf.variable_scope("glimpse_network", reuse=reuse):
                output, endpoints_ = model_utils.build_wide_residual_network(
                    net,
                    self.output_dims,
                    residual_blocks_per_group=self.residual_blocks_per_group,
                    number_groups=self.number_groups,
                    init_conv_channels=self.init_conv_channels,
                    widening_factor=self.widening_factor,
                    dropout_rate=self.dropout_rate,
                    expand_rate=2,
                    conv_size=3,
                    is_training=is_training,
                    activation=self.activation,
                    regularizer=self.regularizer,
                    normalization_type=self.normalization_type,
                    zero_pad=self.zero_pad,
                    global_average_pool=self.global_average_pool)
        else:
            network = nets_factory.get_network_fn(self.network_type,
                                                  num_classes=self.output_dims,
                                                  is_training=is_training)
            output, endpoints_ = network(net,
                                         scope="glimpse_network",
                                         reuse=reuse)
            if self.output_dims is None:
                # Global average of activations.
                output = tf.reduce_mean(output, [1, 2])

        endpoints.update(endpoints_)

        if not reuse:
            self.collect_variables()

        return output, endpoints
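A hedged usage sketch, assuming a constructed glimpse_network with three resolutions (the resolution count here is an assumption): a center look that keeps the two highest resolutions and masks the lowest one.

batch_size = tf.shape(images)[0]
locations = tf.zeros((batch_size, 2))  # Image center in [-1, 1] coordinates.
output, endpoints = glimpse_network(
    images,
    locations,
    is_training=False,
    use_resolution=[True, True, False])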