def test_tas_for_tensors(self):
  a = tf.reshape(tf.range(20), [5, 4])
  tensors = [a, (a, ExampleTuple(a, a))]
  tas = nested_utils.tas_for_tensors(tensors, 5)
  nest.assert_same_structure(tensors, tas)
  # We can't pass TensorArrays to sess.run, so instead we turn them back into
  # tensors to check that they were created correctly.
  stacked = nested_utils.map_nested(lambda x: x.stack(), tas)
  with self.test_session() as sess:
    gt, out = sess.run([tensors, stacked])
  gt = nest.flatten(gt)
  out = nest.flatten(out)
  # Check that the TensorArrays were created correctly.
  for x, y in zip(gt, out):
    self.assertAllClose(x, y)
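# For context, a minimal sketch of what a helper like `tas_for_tensors` could
# look like, given the stack()-based round trip checked in the test above: it
# maps each Tensor in a nested structure to a TensorArray of the given length
# by unstacking along the leading axis. This is an illustrative assumption,
# not the library's actual implementation.
def _sketch_tas_for_tensors(tensors, length):
  def to_ta(t):
    ta = tf.TensorArray(dtype=t.dtype, size=length, dynamic_size=False)
    return ta.unstack(t)
  return nested_utils.map_nested(to_ta, tensors)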
def relaxed_resampling(log_weights, states, num_particles, batch_size,
                       temperature=0.5, random_seed=None):
  """Resample states with relaxed resampling.

  Draws soft "ancestors" using the Gumbel-Softmax distribution.

  Args:
    log_weights: A [num_particles, batch_size] Tensor representing a batch of
      batch_size logits for a num_particles-ary Categorical distribution.
    states: A nested list of [batch_size * num_particles, d] Tensors that will
      be resampled from the groups of every num_particles-th row.
    num_particles: The number of particles/samples.
    batch_size: The batch size.
    temperature: The temperature used for the relaxed one-hot distribution.
    random_seed: The random seed to pass to the resampling operations in the
      particle filter. Mainly useful for testing.

  Returns:
    resampled_states: A nested list of [batch_size * num_particles, d] Tensors
      resampled by blending particles with the relaxed one-hot samples.
  """
  # log_weights are [num_particles, batch_size], so we transpose to get a
  # set of batch_size distributions over [0, num_particles).
  resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
  resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
      temperature,
      logits=resampling_parameters)
  # Sample num_particles samples from the distribution, resulting in a
  # [num_particles, batch_size, num_particles] Tensor that represents a set of
  # [num_particles, batch_size] blending weights. The dimensions represent
  # [particle index, batch index, blending weight index].
  ancestors = resampling_dist.sample(sample_shape=num_particles,
                                     seed=random_seed)

  def map_fn(tensor):
    return _blend_tensor(ancestors, tensor, num_particles, batch_size)

  resampled_states = nested.map_nested(map_fn, states)
  return resampled_states
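# A minimal sketch of the blending step that `_blend_tensor` (defined elsewhere
# in this module) is assumed to perform, shown for illustration only and not
# necessarily the actual implementation. Assuming a particle-major layout of
# the state rows, each output particle is a convex combination of the input
# particles, weighted by the relaxed one-hot samples.
def _sketch_blend_tensor(blending_weights, tensor, num_particles, batch_size):
  # tensor: [num_particles * batch_size, d] -> [num_particles, batch_size, d].
  tensor = tf.reshape(tensor, [num_particles, batch_size, -1])
  # blending_weights: [num_particles, batch_size, num_particles].
  # out[i, j, :] = sum_k blending_weights[i, j, k] * tensor[k, j, :].
  blended = tf.einsum("ijk,kjl->ijl", blending_weights, tensor)
  return tf.reshape(blended, [num_particles * batch_size, -1])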
def test_map_nested_works_on_flat_lists(self):
  """Check that map_nested works with a flat list."""
  original = [1, 2, 3]
  expected = [2, 3, 4]
  out = nested_utils.map_nested(lambda x: x + 1, original)
  self.assertEqual(expected, out)

def test_map_nested_works_on_single_objects(self):
  """Check that map_nested works with raw objects."""
  original = 1
  expected = 2
  out = nested_utils.map_nested(lambda x: x + 1, original)
  self.assertEqual(expected, out)

def test_map_nested_works_on_nested_structures(self):
  """Check that map_nested works with nested structures."""
  original = [1, (2, 3.2, (4., ExampleTuple(5, 6)))]
  expected = [2, (3, 4.2, (5., ExampleTuple(6, 7)))]
  out = nested_utils.map_nested(lambda x: x + 1, original)
  self.assertEqual(expected, out)
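# For reference, a minimal sketch of what `map_nested` could look like given
# the behaviour exercised by the tests above: it applies the function to every
# leaf while preserving list, tuple, and namedtuple structure. This is an
# illustrative assumption, not necessarily the library's actual implementation.
def _sketch_map_nested(map_fn, nested):
  if isinstance(nested, (list, tuple)):
    mapped = [_sketch_map_nested(map_fn, x) for x in nested]
    if isinstance(nested, tuple) and hasattr(nested, '_fields'):
      return type(nested)(*mapped)  # Rebuild namedtuples field by field.
    return type(nested)(mapped)
  return map_fn(nested)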