def test_build(self, use_prev_locations, policy, is_training):
    """Checks that EmissionNetwork emits one location per batch element.

    Builds the emission model on a placeholder state, optionally conditioning
    on randomly drawn previous locations. It then verifies that the returned
    locations evaluate to shape (batch_size, location_dims).

    Args:
      use_prev_locations: If True, a random `prev_locations` tensor is passed
        to the model; otherwise `prev_locations=None` is passed.
      policy: Forwarded unchanged to the model call as `policy`.
      is_training: Forwarded unchanged to the model call as `is_training`.
    """
    cfg = dram_config.get_config()
    n_state = 64
    n_batch = 10
    n_loc = 2

    state_ph = tf.placeholder(shape=(n_batch, n_state), dtype=tf.float32)
    emitter = emission_model.EmissionNetwork(cfg.emission_model_config)

    # Random previous locations are drawn only when the test case asks for them.
    prev = (
        tf.convert_to_tensor(np.random.rand(n_batch, n_loc), dtype=tf.float32)
        if use_prev_locations else None)

    sampled, _ = emitter(
        state_ph,
        location_scale=1,
        prev_locations=prev,
        policy=policy,
        is_training=is_training)
    initializer = emitter.init_op

    with self.test_session() as sess:
        sess.run(initializer)
        feed = {state_ph: np.random.rand(n_batch, n_state)}
        self.assertEqual((n_batch, n_loc), sess.run(sampled, feed_dict=feed).shape)
def test_use_resolution(self, is_training, use_resolution):
    """Checks that only enabled resolutions receive gradient.

    Builds a GlimpseNetwork with 3 resolutions and 8x8 glimpses on random
    images at all-zero locations. For each entry of
    `endpoints["model_input_list"]`, it takes the norm of the gradient of the
    first output unit. That norm must be positive where `use_resolution` is
    True and exactly zero where it is False.

    Args:
      is_training: Forwarded to the model call as `is_training`.
      use_resolution: Sequence of booleans, one per entry of
        `endpoints["model_input_list"]` (presumably one per resolution, i.e.
        length 3 here -- TODO confirm against the parameterization).
    """
    config = dram_config.get_config()
    image_shape = (28, 28, 1)
    batch_size = 5
    output_dims = 10
    config.glimpse_model_config.output_dims = output_dims
    # Fixed glimpse geometry for this test. The earlier copies from
    # config.glimpse_shape / config.num_resolutions were overwritten
    # immediately and have been dropped.
    config.glimpse_model_config.glimpse_shape = (8, 8)
    config.glimpse_model_config.num_resolutions = 3

    images = tf.random_uniform(
        minval=-1,
        maxval=1,
        shape=(batch_size,) + image_shape,
        dtype=tf.float32)
    locations = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
    model = glimpse_model.GlimpseNetwork(config.glimpse_model_config)
    g, endpoints = model(
        images,
        locations,
        is_training=is_training,
        use_resolution=use_resolution)

    gnorms = [
        tf.norm(grad)
        for grad in tf.gradients(g[:, 0], endpoints["model_input_list"])
    ]
    self.evaluate(tf.global_variables_initializer())
    gnorms = self.evaluate(gnorms)

    # zip() would silently truncate on a mismatch and skip some checks.
    self.assertEqual(len(use_resolution), len(gnorms))
    for use, gnorm in zip(use_resolution, gnorms):
        if use:
            self.assertGreater(gnorm, 0.)
        else:
            self.assertEqual(gnorm, 0.)
def test_build(self, is_training):
    """Checks that GlimpseNetwork outputs shape (batch_size, output_dims).

    The model is built on placeholder images and locations with every
    resolution enabled, and is then evaluated on random inputs.

    Args:
      is_training: Forwarded to the model call as `is_training`.
    """
    config = dram_config.get_config()
    image_shape = (28, 28, 1)
    batch_size = 10
    location_dims = 2
    output_dims = 10
    config.glimpse_model_config.output_dims = output_dims
    config.glimpse_model_config.glimpse_shape = config.glimpse_shape
    config.glimpse_model_config.num_resolutions = config.num_resolutions

    images = tf.placeholder(
        shape=(batch_size,) + image_shape, dtype=tf.float32)
    # Use location_dims so the placeholder and the fed array cannot diverge.
    locations = tf.placeholder(
        shape=(batch_size, location_dims), dtype=tf.float32)
    model = glimpse_model.GlimpseNetwork(config.glimpse_model_config)
    g, _ = model(
        images,
        locations,
        is_training=is_training,
        use_resolution=[True] * model.num_resolutions)
    init_op = model.init_op

    with self.test_session() as sess:
        sess.run(init_op)
        self.assertEqual(
            (batch_size, output_dims),
            sess.run(
                g,
                feed_dict={
                    images: np.random.rand(*((batch_size,) + image_shape)),
                    locations: np.random.rand(batch_size, location_dims),
                }).shape)
def test_build(self, is_training):
    """Checks that DRAMNetwork yields per-class logits at the final time step.

    Runs the full network for two time steps on a small batch of random
    constant images. It then verifies that the last element of the first
    returned value evaluates to shape (batch_size, num_classes).

    Args:
      is_training: Forwarded to the network call as `is_training`.
    """
    config = dram_config.get_config()
    steps = 2
    img_shape = (28, 28, 1)
    n_classes = 10
    config.num_classes = n_classes
    config.num_units_rnn_layers = [10, 10]
    config.num_times = steps
    n_batch = 3

    input_shape = (n_batch,) + img_shape
    images = tf.constant(np.random.rand(*input_shape), dtype=tf.float32)

    network = dram.DRAMNetwork(config)
    outputs = network(images, num_times=steps, is_training=is_training)
    per_step_logits = outputs[0]

    self.evaluate(network.init_op)
    final_logits = self.evaluate(per_step_logits[-1])
    self.assertEqual((n_batch, n_classes), final_logits.shape)
def test_build(self):
    """Checks that ClassificationNetwork maps a state to (batch, num_classes) logits.

    The classifier is built on a placeholder state of width `input_dims`,
    then evaluated on random input.
    """
    config = dram_config.get_config()
    batch_size = 10
    input_dims = 64
    num_classes = 10
    # Use the local constant so the config and the expected shape stay in sync.
    config.classification_model_config.num_classes = num_classes

    state = tf.placeholder(shape=(batch_size, input_dims), dtype=tf.float32)
    model = classification_model.ClassificationNetwork(
        config.classification_model_config)
    logits, _ = model(state)
    init_op = model.init_op

    with self.test_session() as sess:
        sess.run(init_op)
        self.assertEqual(
            (batch_size, num_classes),
            sess.run(
                logits,
                feed_dict={
                    state: np.random.rand(batch_size, input_dims),
                }).shape)
def test_apply_stop_gradient(self, apply_stop_gradient):
    """Checks whether gradients flow from the glimpse output back to the images.

    With `apply_stop_gradient` enabled, `tf.gradients` must report no
    connection (`[None]`). Otherwise, the gradient must match the shape of
    the image placeholder.

    Args:
      apply_stop_gradient: Value written to
        `config.glimpse_model_config.apply_stop_gradient`.
    """
    cfg = dram_config.get_config()
    img_shape = (28, 28, 1)
    n_batch = 10

    glimpse_cfg = cfg.glimpse_model_config
    glimpse_cfg.output_dims = 10
    glimpse_cfg.glimpse_shape = cfg.glimpse_shape
    glimpse_cfg.num_resolutions = cfg.num_resolutions
    glimpse_cfg.apply_stop_gradient = apply_stop_gradient

    images = tf.placeholder(shape=(n_batch,) + img_shape, dtype=tf.float32)
    locations = tf.placeholder(shape=(n_batch, 2), dtype=tf.float32)
    network = glimpse_model.GlimpseNetwork(glimpse_cfg)
    outputs, _ = network(
        images,
        locations,
        is_training=False,
        use_resolution=[True] * network.num_resolutions)

    grads = tf.gradients([outputs[0, 0]], images)
    if not apply_stop_gradient:
        self.assertEqual(grads[0].shape, images.shape)
        return
    self.assertEqual(grads, [None])