Example 1
    def test_patches_masks(self):
        batch_size = 16
        patch_size = 8
        images = _construct_images(batch_size)
        image_size = images.shape.as_list()[1]
        location_scale = 1. - float(patch_size) / float(image_size)

        locations = tf.convert_to_tensor(2 * np.random.rand(batch_size, 2) - 1,
                                         dtype=tf.float32)
        locations = tf.clip_by_value(locations, -location_scale,
                                     location_scale)
        locations_t = [locations, locations]

        # Construct masks.
        masks = utils.patches_masks(locations_t,
                                    image_size,
                                    patch_size=patch_size)
        patches_masks = utils.extract_glimpse(masks,
                                              size=(patch_size, patch_size),
                                              offsets=locations)

        # Check that the mask value at the patch locations is 0.
        self.assertEqual(self.evaluate(tf.reduce_sum(patches_masks)), 0)

        # Check that the area where the mask is 0 equals the specified patch area.
        patches_area = self.evaluate(tf.reduce_sum(1. - masks, axis=[1, 2, 3]))
        self.assertAllEqual(patches_area,
                            np.array([patch_size**2] * batch_size))
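
Every example on this page calls _construct_images, a test helper that is not shown here. A minimal sketch of such a helper follows; the 32x32 image size, 3 channels, and the tensorflow.compat.v1 import are assumptions for illustration, not the original implementation.

import numpy as np
import tensorflow.compat.v1 as tf


def _construct_images(batch_size, image_size=32, channels=3):
    # Hypothetical helper: returns a batch of random float images for the
    # glimpse and mask tests to operate on.
    images_np = np.random.rand(batch_size, image_size, image_size,
                               channels).astype(np.float32)
    return tf.convert_to_tensor(images_np, dtype=tf.float32)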
Example 2
    def test_extract_glimpse(self):
        batch_size = 50
        glimpse_shape = (8, 8)
        images = _construct_images(batch_size)
        location_scale = 1. - float(glimpse_shape[0]) / float(
            images.shape.as_list()[1])
        locations = tf.convert_to_tensor(
            2 * np.random.rand(batch_size, 2) - 1, dtype=tf.float32)
        locations = tf.clip_by_value(locations, -location_scale, location_scale)
        images_glimpse1 = utils.extract_glimpse(
            images, size=glimpse_shape, offsets=locations)
        images_glimpse2 = tf.image.extract_glimpse(
            images,
            size=glimpse_shape,
            offsets=locations,
            centered=True,
            normalized=True)
        diff = tf.reduce_sum(tf.abs(images_glimpse1 - images_glimpse2))
        self.assertEqual(self.evaluate(diff), 0)
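
The clip to location_scale above is what keeps every glimpse fully inside the image. A small NumPy sketch of the idea; the 32-pixel image size is an assumption, and the offset-to-pixel mapping is the approximate convention implied by centered=True, normalized=True.

import numpy as np

image_size = 32          # assumed height/width of the test images
glimpse = 8
location_scale = 1. - float(glimpse) / float(image_size)   # 0.75 here

for offset in (-location_scale, 0., location_scale):
    # A normalized, centered offset in [-1, 1] maps (approximately) to a
    # pixel-space window center of (offset + 1) / 2 * image_size.
    center = (offset + 1.) / 2. * image_size
    lo, hi = center - glimpse / 2., center + glimpse / 2.
    print(offset, (lo, hi))   # windows [0, 8], [12, 20], [24, 32]: all in-bounds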
Example 3
    def test_extract_glimpses_at_boundaries(self):
        glimpse_shape = (8, 8)
        locations = tf.constant([[-1, -1], [-1, 1], [1, -1], [1, 1]],
                                dtype=tf.float32)
        images = _construct_images(4)

        glimpse = utils.extract_glimpse(images,
                                        size=glimpse_shape,
                                        offsets=locations)
        glimpses_np, images_np = self.evaluate((glimpse, images))

        true_glimpse = np.stack(
            (np.pad(images_np[0, :4, :4, :], [[4, 0], [4, 0], [0, 0]], "edge"),
             np.pad(images_np[1, :4, -4:, :], [[4, 0], [0, 4], [0, 0]],
                    "edge"),
             np.pad(images_np[2, -4:, :4, :], [[0, 4], [4, 0], [0, 0]],
                    "edge"),
             np.pad(images_np[3, -4:, -4:, :], [[0, 4], [0, 4], [0, 0]],
                    "edge")), 0)

        self.assertAllEqual(glimpses_np, true_glimpse)
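
The expected glimpses above are built with np.pad in "edge" mode, which replicates border pixels into the out-of-image half of each corner glimpse; the assertion therefore implies that utils.extract_glimpse does the same at the boundaries. A toy illustration of that padding mode:

import numpy as np

corner = np.array([[1, 2],
                   [3, 4]])
# "edge" padding repeats the nearest border value instead of filling zeros.
print(np.pad(corner, [[2, 0], [2, 0]], "edge"))
# [[1 1 1 2]
#  [1 1 1 2]
#  [1 1 1 2]
#  [3 3 3 4]]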
Example 4
    def extract_glimpses(self, images, locations):
        """Extracts fovea-like glimpses.

    Args:
      images: 4-D Tensor of shape [batch, height, width, channels].
      locations: 2D Tensor of shape [batch, 2] with glimpse locations. Locations
        are in the interval of [-1, 1] where points:
        (-1, -1): upper left corner.
        (-1, 1): upper right corner.
        (1, 1): lower right corner.
        (1, -1): lower left corner.

    Returns:
      glimpses: 5D tensor of size [batch, # glimpses, height, width, channels].
    """
        # Get multi-resolution fields of view (the first is at full resolution).
        image_shape = tf.cast(tf.shape(images)[1:3], dtype=tf.float32)
        start = tf.cast(self.glimpse_shape[0],
                        dtype=tf.float32) / image_shape[0]
        fields_of_view = tf.cast(tf.lin_space(start, 1., self.num_resolutions),
                                 dtype=tf.float32)
        receptive_fields = [self.glimpse_shape] + [
            tf.cast(fields_of_view[i] * image_shape, dtype=tf.int32)
            for i in range(1, self.num_resolutions)
        ]
        images_glimpses_list = []
        for field in receptive_fields:
            # Extract a glimpse with specific shape and scale.
            images_glimpse = utils.extract_glimpse(images,
                                                   size=field,
                                                   offsets=locations)
            # Bigger receptive fields have lower resolution.
            images_glimpse = tf.image.resize_images(images_glimpse,
                                                    size=self.glimpse_shape)
            # Stop gradient
            if self.apply_stop_gradient:
                images_glimpse = tf.stop_gradient(images_glimpse)
            images_glimpses_list.append(images_glimpse)
        return images_glimpses_list
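
The field-of-view schedule above can be traced by hand. The sketch below mirrors the receptive_fields computation in plain NumPy; the 32x32 image, (8, 8) glimpse_shape, and num_resolutions=3 are assumed values, not taken from the model.

import numpy as np

glimpse_shape = (8, 8)
num_resolutions = 3
image_size = 32.

start = glimpse_shape[0] / image_size
fields_of_view = np.linspace(start, 1., num_resolutions)   # [0.25, 0.625, 1.0]
receptive_fields = [glimpse_shape] + [
    (int(f * image_size), int(f * image_size)) for f in fields_of_view[1:]
]
print(receptive_fields)   # [(8, 8), (20, 20), (32, 32)]
# Each field is then extracted around the same location and resized back to
# (8, 8), so larger fields of view end up at lower resolution.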
Example 5
    def __call__(self,
                 images_saccader,
                 images_classnet,
                 num_times,
                 is_training_saccader=False,
                 is_training_classnet=False,
                 policy="learned",
                 stop_gradient_after_representation=False):

        logits, locations_t, best_locations_t, endpoints = Saccader.__call__(
            self,
            images_saccader,
            num_times,
            is_training=is_training_saccader,
            policy=policy,
            stop_gradient_after_representation=
            stop_gradient_after_representation)

        self.glimpse_shape_saccader = self.glimpse_shape
        image_size_saccader = images_saccader.shape.as_list()[1]
        image_size_classnet = images_classnet.shape.as_list()[1]
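        # If the classnet glimpse shape is unspecified, scale the saccader
        # glimpse so it covers the same image fraction on the classnet input.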
        if self.glimpse_shape_classnet[0] < 0:
            self.glimpse_shape_classnet = tuple([
                int(image_size_classnet / image_size_saccader *
                    self.glimpse_shape[0])
            ] * 2)
        self.glimpse_shape = self.glimpse_shape_classnet

        images_glimpse_t = []
        for locations in locations_t:
            images_glimpse = utils.extract_glimpse(
                images_classnet,
                size=self.glimpse_shape_classnet,
                offsets=locations)
            images_glimpse_t.append(images_glimpse)

        batch_size = images_classnet.shape.as_list()[0]
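        # Stack the glimpses from all time steps along the batch dimension so
        # the classification network processes them in one forward pass.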
        images_glimpse_t = tf.concat(images_glimpse_t, axis=0)

        variables_before = set(tf.global_variables())
        reuse = bool(self.var_list_classnet)
        with tf.variable_scope(self.variable_scope_classnet, reuse=reuse):
            if self.classnet_type == "nasnet":
                classnet_config = nasnet.large_imagenet_config()
                classnet_config.use_aux_head = 0
                classnet_config.drop_path_keep_prob = 1.0
                with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
                    classnet_logits, endpoints_ = nasnet.build_nasnet_large(
                        images_glimpse_t,
                        self.num_classes,
                        is_training=is_training_classnet,
                        config=classnet_config)
            elif self.classnet_type == "resnet_v2_50":
                network = nets_factory.get_network_fn(
                    "resnet_v2_50",
                    self.num_classes,
                    is_training=is_training_classnet)
                classnet_logits, endpoints_ = network(images_glimpse_t)

        endpoints["classnet"] = endpoints_
        variables_after = set(tf.global_variables())
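        # Undo the time/batch stacking and average the logits over time steps.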
        logits_t = tf.reshape(classnet_logits, (num_times, batch_size, -1))
        logits = tf.reduce_mean(logits_t, axis=0)
        if not reuse:
            self.var_list_saccader = self.var_list_classification + self.var_list_location
            self.var_list_classnet = [
                v for v in list(variables_after - variables_before)
                if "global_step" not in v.op.name
            ]
            self.var_list.extend(self.var_list_classnet)
            self.init_op = tf.variables_initializer(var_list=self.var_list)

        return logits, locations_t, best_locations_t, endpoints
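
The variables_before / variables_after bookkeeping above collects exactly the variables created by the classification network. A standalone sketch of that pattern follows; the scope name and dense layer are illustrative only and not part of the model.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

variables_before = set(tf.global_variables())
with tf.variable_scope("classnet_demo"):
    _ = tf.layers.dense(tf.zeros((1, 4)), units=2)
variables_after = set(tf.global_variables())

# Keep only what the sub-network created, excluding any global step counter.
new_variables = [
    v for v in variables_after - variables_before
    if "global_step" not in v.op.name
]
print(sorted(v.op.name for v in new_variables))
# ['classnet_demo/dense/bias', 'classnet_demo/dense/kernel']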