def test_use_resolution(self, is_training, use_resolution):
  """Checks gradients reach exactly the resolutions enabled in use_resolution.

  Builds a GlimpseNetwork with 3 resolutions of 8x8 glimpses on random
  images, takes the gradient of the first output unit w.r.t. every entry of
  endpoints["model_input_list"], and asserts that the gradient norm is
  positive where use_resolution is True and exactly zero where it is False.

  Args:
    is_training: forwarded to the model call.
    use_resolution: sequence of booleans, one per resolution. NOTE(review):
      presumably of length num_resolutions (3) -- confirm, since zip() below
      silently ignores surplus entries on either side.
  """
  config = dram_config.get_config()
  image_shape = (28, 28, 1)
  batch_size = 5
  output_dims = 10

  glimpse_config = config.glimpse_model_config
  glimpse_config.output_dims = output_dims
  # Fixed glimpse geometry for this test, independent of the top-level
  # config's glimpse_shape / num_resolutions.
  glimpse_config.glimpse_shape = (8, 8)
  glimpse_config.num_resolutions = 3

  images = tf.random_uniform(
      minval=-1, maxval=1, shape=(batch_size,) + image_shape,
      dtype=tf.float32)
  locations = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
  model = glimpse_model.GlimpseNetwork(glimpse_config)
  g, endpoints = model(
      images, locations, is_training=is_training,
      use_resolution=use_resolution)

  grads = tf.gradients(g[:, 0], endpoints["model_input_list"])
  gnorms = [tf.norm(grad) for grad in grads]

  self.evaluate(tf.global_variables_initializer())
  gnorms = self.evaluate(gnorms)
  for use, gnorm in zip(use_resolution, gnorms):
    if use:
      self.assertGreater(gnorm, 0.)
    else:
      self.assertEqual(gnorm, 0.)
def test_build(self, is_training):
  """Checks that GlimpseNetwork builds and emits (batch_size, output_dims).

  The glimpse sub-config inherits glimpse_shape and num_resolutions from the
  top-level DRAM config, every resolution is enabled, and the graph is run
  once on random images and locations.

  Args:
    is_training: forwarded positionally to the model call.
  """
  config = dram_config.get_config()
  image_shape = (28, 28, 1)
  batch_size = 10
  location_dims = 2
  output_dims = 10

  glimpse_config = config.glimpse_model_config
  glimpse_config.output_dims = output_dims
  glimpse_config.glimpse_shape = config.glimpse_shape
  glimpse_config.num_resolutions = config.num_resolutions

  images = tf.placeholder(
      shape=(batch_size,) + image_shape, dtype=tf.float32)
  # location_dims drives both the placeholder and the fed array so the two
  # cannot disagree.
  locations = tf.placeholder(
      shape=(batch_size, location_dims), dtype=tf.float32)
  model = glimpse_model.GlimpseNetwork(glimpse_config)
  g, _ = model(images, locations, is_training,
               use_resolution=[True] * model.num_resolutions)
  # model.init_op is read only after the call above, i.e. after the model
  # has been connected to the graph.
  init_op = model.init_op

  with self.test_session() as sess:
    sess.run(init_op)
    outputs = sess.run(
        g,
        feed_dict={
            images: np.random.rand(batch_size, *image_shape),
            locations: np.random.rand(batch_size, location_dims),
        })
  self.assertEqual((batch_size, output_dims), outputs.shape)
def __init__(self, config):
  """Init.

  Args:
    config: ConfigDict object with model parameters (see dram_config.py).

  Raises:
    ValueError: if `config.num_units_rnn_layers` does not have length 2, or
      if `config.cell_type` is neither "lstm" nor "gru".
  """
  # Deep copy so the sub-config edits below never leak into the caller's
  # config object.
  self.config = copy.deepcopy(config)
  if len(self.config.num_units_rnn_layers) != 2:
    raise ValueError(
        "num_units_rnn_layers should be a list of length 2.")
  self.cell_type = self.config.cell_type
  # Fail fast on an unsupported cell type, before any sub-network is built.
  if self.cell_type not in ("lstm", "gru"):
    raise ValueError(
        "cell_type should be 'lstm' or 'gru', got %r." % (self.cell_type,))

  glimpse_model_config = self.config.glimpse_model_config
  emission_model_config = self.config.emission_model_config
  classification_model_config = self.config.classification_model_config
  # Propagate shared top-level settings into the (copied) sub-configs.
  classification_model_config.num_classes = self.config.num_classes
  glimpse_model_config.glimpse_shape = self.config.glimpse_shape
  glimpse_model_config.num_resolutions = self.config.num_resolutions

  self.glimpse_net = glimpse_model.GlimpseNetwork(glimpse_model_config)
  self.emission_net = emission_model.EmissionNetwork(
      emission_model_config)
  self.classification_net = classification_model.ClassificationNetwork(
      classification_model_config)

  self.rnn_layers = []
  self.zero_states = []
  for num_units in self.config.num_units_rnn_layers:
    if self.cell_type == "lstm":
      rnn_layer = tf.nn.rnn_cell.LSTMCell(
          num_units, state_is_tuple=True,
          activation=self.config.rnn_activation)
    else:  # "gru" -- the only other value accepted by the check above.
      rnn_layer = tf.nn.rnn_cell.GRUCell(
          num_units, activation=self.config.rnn_activation)
    self.rnn_layers.append(rnn_layer)
    # Stores the cell's bound `zero_state` method, not a state tensor.
    # NOTE(review): presumably called later with (batch_size, dtype).
    self.zero_states.append(rnn_layer.zero_state)

  self.var_list = []
  self.var_list_location = []
  self.var_list_classification = []
  self.init_op = None
def test_apply_stop_gradient(self, apply_stop_gradient):
  """Checks whether gradients can flow from the glimpse output to the images.

  With apply_stop_gradient enabled, tf.gradients of one output element
  w.r.t. the image placeholder must be [None]. Otherwise, the single
  gradient must have the same static shape as the images.

  Args:
    apply_stop_gradient: value written to the glimpse sub-config's
      apply_stop_gradient field.
  """
  batch_size = 10
  image_shape = (28, 28, 1)

  config = dram_config.get_config()
  glimpse_cfg = config.glimpse_model_config
  glimpse_cfg.output_dims = 10
  glimpse_cfg.glimpse_shape = config.glimpse_shape
  glimpse_cfg.num_resolutions = config.num_resolutions
  glimpse_cfg.apply_stop_gradient = apply_stop_gradient

  image_ph = tf.placeholder(
      shape=(batch_size,) + image_shape, dtype=tf.float32)
  location_ph = tf.placeholder(shape=(batch_size, 2), dtype=tf.float32)
  network = glimpse_model.GlimpseNetwork(glimpse_cfg)
  every_resolution = [True] * network.num_resolutions
  glimpse_out, _ = network(image_ph, location_ph, False,
                           use_resolution=every_resolution)

  first_unit = glimpse_out[0, 0]
  grads = tf.gradients([first_unit], image_ph)

  if not apply_stop_gradient:
    self.assertEqual(grads[0].shape, image_ph.shape)
    return
  self.assertEqual(grads, [None])