def test_use_resolution(self, is_training, use_resolution):
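        # Checks that gradients flow back only through the image resolutions
        # enabled via `use_resolution`.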
        config = dram_config.get_config()
        image_shape = (28, 28, 1)
        batch_size = 5
        output_dims = 10
        config.glimpse_model_config.output_dims = output_dims
        config.glimpse_model_config.glimpse_shape = (8, 8)
        config.glimpse_model_config.num_resolutions = 3
        images = tf.random_uniform(minval=-1,
                                   maxval=1,
                                   shape=(batch_size, ) + image_shape,
                                   dtype=tf.float32)
        locations = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
        model = glimpse_model.GlimpseNetwork(config.glimpse_model_config)
        g, endpoints = model(images,
                             locations,
                             is_training=is_training,
                             use_resolution=use_resolution)
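        # Norm of the gradient of the first output unit with respect to each
        # per-resolution input of the glimpse network.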
        gnorms = [
            tf.norm(grad)
            for grad in tf.gradients(g[:, 0], endpoints["model_input_list"])
        ]
        self.evaluate(tf.global_variables_initializer())
        gnorms = self.evaluate(gnorms)

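        # Used resolutions should receive a nonzero gradient; unused ones none.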
        for use, gnorm in zip(use_resolution, gnorms):
            if use:
                self.assertGreater(gnorm, 0.)
            else:
                self.assertEqual(gnorm, 0.)

    def test_build(self, is_training):
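        # Builds the glimpse network with all resolutions enabled and checks
        # that the output has shape (batch_size, output_dims).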
        config = dram_config.get_config()
        image_shape = (28, 28, 1)
        batch_size = 10
        location_dims = 2
        output_dims = 10
        config.glimpse_model_config.output_dims = output_dims
        config.glimpse_model_config.glimpse_shape = config.glimpse_shape
        config.glimpse_model_config.num_resolutions = config.num_resolutions

        images = tf.placeholder(shape=(batch_size, ) + image_shape,
                                dtype=tf.float32)
        locations = tf.placeholder(shape=(batch_size, 2), dtype=tf.float32)
        model = glimpse_model.GlimpseNetwork(config.glimpse_model_config)

        g, _ = model(images,
                     locations,
                     is_training,
                     use_resolution=[True] * model.num_resolutions)
        init_op = model.init_op

        with self.test_session() as sess:
            sess.run(init_op)
            self.assertEqual(
                (batch_size, output_dims),
                sess.run(g,
                         feed_dict={
                             images:
                             np.random.rand(*((batch_size, ) + image_shape)),
                             locations:
                             np.random.rand(batch_size, location_dims)
                         }).shape)
Example #3
    def __init__(self, config):
        """Init.

    Args:
      config: ConfigDict object with model parameters (see dram_config.py).
    """
        self.config = copy.deepcopy(config)
        if len(self.config.num_units_rnn_layers) != 2:
            raise ValueError(
                "num_units_rnn_layers should be a list of length 2.")
        self.cell_type = self.config.cell_type

        glimpse_model_config = self.config.glimpse_model_config
        emission_model_config = self.config.emission_model_config
        classification_model_config = self.config.classification_model_config
        classification_model_config.num_classes = self.config.num_classes

        glimpse_model_config.glimpse_shape = self.config.glimpse_shape
        glimpse_model_config.num_resolutions = self.config.num_resolutions
        self.glimpse_net = glimpse_model.GlimpseNetwork(glimpse_model_config)

        self.emission_net = emission_model.EmissionNetwork(
            emission_model_config)

        self.classification_net = classification_model.ClassificationNetwork(
            classification_model_config)

        self.rnn_layers = []
        self.zero_states = []
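        # One recurrent cell (LSTM or GRU) per entry in num_units_rnn_layers.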
        for num_units in self.config.num_units_rnn_layers:
            if self.cell_type == "lstm":
                rnn_layer = tf.nn.rnn_cell.LSTMCell(
                    num_units,
                    state_is_tuple=True,
                    activation=self.config.rnn_activation)
            elif self.cell_type == "gru":
                rnn_layer = tf.nn.rnn_cell.GRUCell(
                    num_units, activation=self.config.rnn_activation)
            else:
                raise ValueError("Unknown cell_type: %s" % self.cell_type)

            self.rnn_layers.append(rnn_layer)
            self.zero_states.append(rnn_layer.zero_state)

        self.var_list = []
        self.var_list_location = []
        self.var_list_classification = []
        self.init_op = None

    def test_apply_stop_gradient(self, apply_stop_gradient):
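        # Verifies that apply_stop_gradient blocks gradients from the network
        # output back to the input images.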
        config = dram_config.get_config()
        image_shape = (28, 28, 1)
        batch_size = 10
        config.glimpse_model_config.output_dims = 10
        config.glimpse_model_config.glimpse_shape = config.glimpse_shape
        config.glimpse_model_config.num_resolutions = config.num_resolutions
        config.glimpse_model_config.apply_stop_gradient = apply_stop_gradient

        images = tf.placeholder(shape=(batch_size, ) + image_shape,
                                dtype=tf.float32)
        locations = tf.placeholder(shape=(batch_size, 2), dtype=tf.float32)
        model = glimpse_model.GlimpseNetwork(config.glimpse_model_config)

        outputs, _ = model(images,
                           locations,
                           False,
                           use_resolution=[True] * model.num_resolutions)
        gradients = tf.gradients([outputs[0, 0]], images)

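        # With the stop gradient applied there is no path from the output to
        # the images, so tf.gradients returns [None]; otherwise the gradient
        # has the same shape as the images.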
        if apply_stop_gradient:
            self.assertEqual(gradients, [None])
        else:
            self.assertEqual(gradients[0].shape, images.shape)