示例#1
0
 def test_tas_for_tensors(self):
     """tas_for_tensors should mirror the input structure and keep its values."""
     base = tf.reshape(tf.range(20), [5, 4])
     structure = [base, (base, ExampleTuple(base, base))]
     arrays = nested_utils.tas_for_tensors(structure, 5)
     nest.assert_same_structure(structure, arrays)
     # TensorArrays cannot be fetched by sess.run, so each one is stacked back
     # into an ordinary tensor and compared against the tensor it came from.
     restacked = nested_utils.map_nested(lambda ta: ta.stack(), arrays)
     with self.test_session() as sess:
         expected_vals, actual_vals = sess.run([structure, restacked])
         # Compare leaf by leaf; the structures were already checked above.
         for expected, actual in zip(nest.flatten(expected_vals),
                                     nest.flatten(actual_vals)):
             self.assertAllClose(expected, actual)
示例#2
0
 def test_tas_for_tensors(self):
   """tas_for_tensors should mirror the input structure and preserve values."""
   source = tf.reshape(tf.range(20), [5, 4])
   nested_tensors = [source, (source, ExampleTuple(source, source))]
   tensor_arrays = nested_utils.tas_for_tensors(nested_tensors, 5)
   nest.assert_same_structure(nested_tensors, tensor_arrays)

   # sess.run cannot fetch TensorArrays, so stack each one back into an
   # ordinary tensor and compare it with the tensor it was built from.
   def _to_tensor(ta):
     return ta.stack()

   restacked = nested_utils.map_nested(_to_tensor, tensor_arrays)
   with self.test_session() as sess:
     want, got = sess.run([nested_tensors, restacked])
     flat_pairs = zip(nest.flatten(want), nest.flatten(got))
     for want_leaf, got_leaf in flat_pairs:
       self.assertAllClose(want_leaf, got_leaf)
示例#3
0
文件: smc.py 项目: CV-IP/MobileNeXt
def relaxed_resampling(log_weights,
                       states,
                       num_particles,
                       batch_size,
                       temperature=0.5,
                       random_seed=None):
    """Resample states with relaxed resampling.

    Draws soft "ancestors" from the Gumbel-Softmax (RelaxedOneHotCategorical)
    distribution and blends the particle states with them, rather than
    selecting discrete ancestors.

    Args:
      log_weights: A [num_particles, batch_size] Tensor representing a batch
        of batch_size logits for a num_particles-ary Categorical distribution.
      states: A nested list of [batch_size * num_particles, d] Tensors that
        will be resampled from the groups of every num_particles-th row.
      num_particles: The number of particles/samples.
      batch_size: The batch size.
      temperature: The temperature used for the relaxed one-hot distribution.
      random_seed: The random seed to pass to the resampling operations in
        the particle filter. Mainly useful for testing.

    Returns:
      resampled_states: A nested list of [batch_size * num_particles, d]
        Tensors, each obtained by combining the corresponding input Tensor
        with the sampled relaxed one-hot ancestor weights (via
        `_blend_tensor`), not by discrete multinomial sampling.
    """
    # log_weights are [num_particles, batch_size], so transpose to get
    # batch_size distributions, each over [0, num_particles).
    resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
    resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
        temperature, logits=resampling_parameters)

    # Sample num_particles samples from the distribution, resulting in a
    # [num_particles, batch_size, num_particles] Tensor that represents a set
    # of [num_particles, batch_size] blending weights. The dimensions
    # represent [particle index, batch index, blending weight index].
    ancestors = resampling_dist.sample(sample_shape=num_particles,
                                       seed=random_seed)

    def map_fn(tensor):
        # Apply the same sampled ancestor weights to every state tensor.
        return _blend_tensor(ancestors, tensor, num_particles, batch_size)

    resampled_states = nested.map_nested(map_fn, states)
    return resampled_states
示例#4
0
文件: smc.py 项目: 812864539/models
def relaxed_resampling(log_weights, states, num_particles, batch_size,
                       temperature=0.5, random_seed=None):
  """Resample states with relaxed resampling.

  Draws soft "ancestors" from the Gumbel-Softmax (RelaxedOneHotCategorical)
  distribution and blends the particle states with them, rather than
  selecting discrete ancestors.

  Args:
    log_weights: A [num_particles, batch_size] Tensor representing a batch
      of batch_size logits for a num_particles-ary Categorical distribution.
    states: A nested list of [batch_size * num_particles, d] Tensors that will
      be resampled from the groups of every num_particles-th row.
    num_particles: The number of particles/samples.
    batch_size: The batch size.
    temperature: The temperature used for the relaxed one-hot distribution.
    random_seed: The random seed to pass to the resampling operations in
      the particle filter. Mainly useful for testing.

  Returns:
    resampled_states: A nested list of [batch_size * num_particles, d]
      Tensors, each obtained by combining the corresponding input Tensor with
      the sampled relaxed one-hot ancestor weights (via `_blend_tensor`),
      not by discrete multinomial sampling.
  """
  # log_weights are [num_particles, batch_size], so transpose to get
  # batch_size distributions, each over [0, num_particles).
  resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
  resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
      temperature,
      logits=resampling_parameters)

  # Sample num_particles samples from the distribution, resulting in a
  # [num_particles, batch_size, num_particles] Tensor that represents a set of
  # [num_particles, batch_size] blending weights. The dimensions represent
  # [particle index, batch index, blending weight index].
  ancestors = resampling_dist.sample(sample_shape=num_particles,
                                     seed=random_seed)

  def map_fn(tensor):
    # Apply the same sampled ancestor weights to every state tensor.
    return _blend_tensor(ancestors, tensor, num_particles, batch_size)

  resampled_states = nested.map_nested(map_fn, states)
  return resampled_states
示例#5
0
 def test_map_nested_works_on_flat_lists(self):
     """map_nested should apply the function to every element of a flat list."""
     def increment(item):
         return item + 1

     result = nested_utils.map_nested(increment, [1, 2, 3])
     self.assertEqual([2, 3, 4], result)
示例#6
0
 def test_map_nested_works_on_single_objects(self):
     """map_nested should apply the function directly to a bare scalar."""
     result = nested_utils.map_nested(lambda value: value + 1, 1)
     self.assertEqual(2, result)
示例#7
0
 def test_map_nested_works_on_nested_structures(self):
     """map_nested should keep nested list/tuple/namedtuple structure intact."""
     source = [1, (2, 3.2, (4., ExampleTuple(5, 6)))]
     mapped = nested_utils.map_nested(lambda leaf: leaf + 1, source)
     want = [2, (3, 4.2, (5., ExampleTuple(6, 7)))]
     self.assertEqual(want, mapped)
示例#8
0
 def test_map_nested_works_on_flat_lists(self):
   """map_nested should add one to each entry of a flat list."""
   mapped = nested_utils.map_nested(lambda n: n + 1, [1, 2, 3])
   self.assertEqual([2, 3, 4], mapped)
示例#9
0
 def test_map_nested_works_on_single_objects(self):
   """map_nested should treat a non-container value as a single leaf."""
   def bump(v):
     return v + 1

   self.assertEqual(2, nested_utils.map_nested(bump, 1))
示例#10
0
 def test_map_nested_works_on_nested_structures(self):
   """map_nested should map every leaf while preserving nesting and types."""
   inner = (4., ExampleTuple(5, 6))
   structure = [1, (2, 3.2, inner)]
   result = nested_utils.map_nested(lambda x: x + 1, structure)
   self.assertEqual([2, (3, 4.2, (5., ExampleTuple(6, 7)))], result)