Ejemplo n.º 1
0
    def _prune(self, xs, retain=None):
        """Returns `xs` with only its pinned, or only its unpinned, parts kept.

        Args:
          xs: A structure like that of `self.distribution.dtype`: a `dict`, a
            namedtuple-like object exposing `_fields` and `_asdict`, or a flat
            sequence whose positions line up with
            `self.distribution._flat_resolve_names()`.
          retain: Either `'pinned'` (keep names present in `self.pins`) or
            `'unpinned'` (keep names absent from `self.pins`).

        Returns:
          A pruned copy of `xs`. Dicts and flat sequences are rebuilt with
          `type(xs)`; namedtuple-like inputs come back as a
          `structural_tuple.structtuple` over the retained field names.

        Raises:
          ValueError: If `retain` is neither `'pinned'` nor `'unpinned'`.
        """
        if retain not in ('pinned', 'unpinned'):
            raise ValueError('Invalid value for `retain`: {}'.format(retain))
        want_pinned = retain == 'pinned'

        def keep(name):
            # Keep a name exactly when its pinned-ness matches the request.
            return (name in self.pins) == want_pinned

        if isinstance(xs, dict):
            # Rebuild with the caller's mapping type.
            kept_items = [(name, value) for name, value in xs.items()
                          if keep(name)]
            return type(xs)(kept_items)

        if hasattr(xs, '_fields') and hasattr(xs, '_asdict'):
            kept_fields = [name for name in xs._fields if keep(name)]
            pruned_type = structural_tuple.structtuple(kept_fields)
            kept_values = {name: value
                           for name, value in xs._asdict().items()
                           if keep(name)}
            return pruned_type(**kept_values)

        # Flat sequence: names come from the model's flat ordering.
        flat_names = self.distribution._flat_resolve_names()
        kept = []
        for position, value in enumerate(xs):
            if keep(flat_names[position]):
                kept.append(value)
        return type(xs)(kept)
Ejemplo n.º 2
0
 def testValidNamedTuple(self):
   abc_type = structural_tuple.structtuple(['a', 'b', 'c'])
   instance = abc_type(a=1, b=2, c=3)
   # Positional unpacking follows field order.
   first, second, third = instance
   self.assertAllEqualNested((1, 2, 3), (first, second, third))
   # tf.nest must traverse the struct tuple as a structure.
   expected = abc_type(2, 3, 4)
   incremented = tf.nest.map_structure(lambda v: v + 1, instance)
   self.assertAllEqualNested(expected, incremented)
  def testArgsExpansion(self):
    # `call_fn` spreads a struct tuple's values positionally into `fn`, even
    # though the field names (`c`, `d`) differ from `fn`'s parameters.
    cd_type = structural_tuple.structtuple(['c', 'd'])

    def add(a, b):
      return a + b

    result = nest_util.call_fn(add, cd_type(1, 2))
    self.assertEqual(3, result)
 def testMake(self):
   ab_type = structural_tuple.structtuple(['a', 'b'])

   full = ab_type._make((1, 2))
   self.assertEqual(1, full.a)
   self.assertEqual(2, full.b)

   # Fields with no supplied value come back as None.
   partial = ab_type._make((1,))
   self.assertEqual(1, partial.a)
   self.assertIs(None, partial.b)
Ejemplo n.º 5
0
    def test_event_shape_and_structure_jd_coroutine(self):
        structtuple = structural_tuple.structtuple

        # Pinning `w` by keyword removes it from the event structure.
        expected = structtuple(['x', 'y', 'z'])([], [2], [2, 4, 4])
        actual = tfde.JointDistributionPinned(jd_coroutine(), w=1.).event_shape
        self.assertEqual(expected, actual)

        # Positional pins follow model order: `None` leaves `w` unpinned and
        # the omitted trailing value leaves `z` unpinned.
        expected = structtuple(['w', 'z'])([], [2, 4, 4])
        actual = tfde.JointDistributionPinned(
            jd_coroutine(), None, 1., [2., 3]).event_shape
        self.assertEqual(expected, actual)

        expected = structtuple(['w', 'y', 'z'])([], [2], [2, 4, 4])
        actual = tfde.JointDistributionPinned(jd_coroutine(), x=2.).event_shape
        self.assertEqual(expected, actual)

        # A struct-tuple slice of a sample pins exactly the fields it carries.
        observed = jd_coroutine().sample(seed=test_util.test_seed())[-1:]
        self.assertEqual(observed._fields, ('z', ))
        expected = structtuple(['w', 'x', 'y'])([], [], [2])
        actual = tfde.JointDistributionPinned(
            jd_coroutine(), observed).event_shape
        self.assertEqual(expected, actual)
  def testConcatenation(self):
    ab_type = structural_tuple.structtuple(['a', 'b'])
    cd_type = structural_tuple.structtuple(['c', 'd'])
    ab = ab_type(a=1, b=2)
    cd = cd_type(c=3, d=4)

    def check_concat(combined, values, fields):
      # Adding two struct tuples concatenates both values and field names.
      self.assertAllEqual(values, tuple(combined))
      self.assertAllEqual(fields, combined._fields)

    check_concat(ab + cd, (1, 2, 3, 4), ('a', 'b', 'c', 'd'))
    check_concat(cd + ab, (3, 4, 1, 2), ('c', 'd', 'a', 'b'))

    # Mixing with a plain tuple on either side still concatenates values.
    self.assertAllEqual((1, 2, 3), ab + (3,))
    self.assertAllEqual((3, 1, 2), (3,) + ab)
 def _model_unflatten(self, xs):
   """Unflattens `xs` to a structure like-typed to `self.distribution`.

   The target structure is inferred from the underlying joint distribution's
   `dtype`.

   Args:
     xs: Iterable of flat values in flat model order. Generators are
       accepted; they are materialized once up front.

   Returns:
     If `self.distribution.dtype` is a `dict`: a mapping of the same type
     keyed by `self._flat_resolve_names()`. If it is namedtuple-like: a
     `structural_tuple.structtuple` over its fields that are not in
     `self.pins`. Otherwise: `type(self.distribution.dtype)(xs)`.

   Raises:
     ValueError: If the dtype is a `dict` and the number of values in `xs`
       differs from the number of resolved names.
   """
   # `xs` may be a generator; `len()` below needs a concrete sequence.
   xs = tuple(xs)
   # Use the underlying JD dtype to infer model structure.
   dtype = self.distribution.dtype
   if isinstance(dtype, dict):
     ks = self._flat_resolve_names()
     if len(ks) != len(xs):
       raise ValueError('Invalid xs length {}, ks={}'.format(len(xs), ks))
     return type(dtype)(zip(ks, xs))
   if hasattr(dtype, '_fields') and hasattr(dtype, '_asdict'):
     # Pinned fields are not part of the unpinned structure.
     ks = [k for k in dtype._fields if k not in self.pins]
     return structural_tuple.structtuple(ks)(*xs)
   return type(dtype)(xs)
Ejemplo n.º 8
0
def _initialize_parameters(generator, seed=None):
    """Samples initial values for all parameters yielded by a generator.

    Args:
      generator: Python callable returning a generator that yields
        `trainable_state_util.Parameter`-like objects (with `init_fn`,
        `constraining_bijector` and `name` attributes), returns a value, and
        has no side effects. See module description.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      raw_parameters: `structural_tuple.structtuple` with one field per
        yielded parameter, named from the parameter's `name` (default
        `'parameter'`) as disambiguated by `_get_unused_parameter_name`. Each
        value is the `init_fn` result (a `Tensor` or structure of `Tensor`s),
        passed through `constraining_bijector.inverse` when a bijector is set.

    Raises:
      ValueError: If `generator()` does not return a generator object, or if
        the generator yields an object without an `init_fn` attribute.
    """
    gen = generator()
    if not isinstance(gen, types.GeneratorType):
        raise ValueError(
            'Expected generator but saw function: {}. A generator '
            'must contain at least one `yield` statement. To define a '
            'trivial generator, which yields zero times, a `yield` '
            'statement may be placed after `return`, but must still '
            'be present.'.format(generator))

    raw_parameters = []
    parameter_names = []
    # The most recent initial value is sent back into the coroutine.
    param_value = None
    while True:
        # Only resuming the generator may signal the end of the parameter
        # list. A StopIteration escaping from `init_fn` or a bijector must
        # not be mistaken for exhaustion and silently truncate the result.
        try:
            parameter = gen.send(param_value)
        except StopIteration:
            break
        if not hasattr(parameter, 'init_fn'):
            raise ValueError(
                'Expected generator to yield a '
                'trainable_state_util.Parameter namedtuple, but saw '
                '{} instead.'.format(parameter))
        seed, local_seed = samplers.split_seed(seed, n=2)
        # Note: this framework guarantees that the `init_fn` is only ever called
        # here, immediately after being yielded before control is returned
        # to the coroutine. This allows the coroutine to safely incorporate
        # loop-dependent state in the closure of `init_fn` if desired.
        param_value = _call_init_fn(parameter.init_fn, seed=local_seed)
        # The coroutine receives `param_value`; the collected value has the
        # constraining bijector's inverse applied, if any.
        raw_value = (param_value
                     if parameter.constraining_bijector is None else
                     parameter.constraining_bijector.inverse(param_value))
        raw_parameters.append(raw_value)
        parameter_names.append(
            _get_unused_parameter_name(parameter.name or 'parameter',
                                       parameter_names))
    return structural_tuple.structtuple(parameter_names)(*raw_parameters)
Ejemplo n.º 9
0
 def test_dtype_and_structure_jd_named(self):
   f32 = tf.float32

   # Plain dict model pinned by keyword.
   expected = dict(x=f32, y=f32, z=f32)
   actual = tfde.JointDistributionPinned(jd_named(), w=1.).dtype
   self.assertEqual(expected, actual)

   # Ordered model pinned positionally; `None` leaves `w` unpinned.
   expected = collections.OrderedDict([('w', f32), ('z', f32)])
   actual = tfde.JointDistributionPinned(
       jd_named_ordered(), None, 1., [2., 3]).dtype
   self.assertEqual(expected, actual)

   expected = collections.OrderedDict([('w', f32), ('x', f32), ('z', f32)])
   actual = tfde.JointDistributionPinned(jd_named_ordered(), y=[2., 3]).dtype
   self.assertEqual(expected, actual)

   # Namedtuple model: the result is a struct tuple over unpinned fields.
   expected = structural_tuple.structtuple(['w', 'y', 'z'])(f32, f32, f32)
   actual = tfde.JointDistributionPinned(jd_named_namedtuple(), x=2.).dtype
   self.assertEqual(expected, actual)
Ejemplo n.º 10
0
 def test_event_shape_and_structure_jd_named(self):
   # Plain dict model pinned by keyword.
   expected = dict(x=[], y=[2], z=[2, 4, 4])
   actual = tfde.JointDistributionPinned(jd_named(), w=1.).event_shape
   self.assertEqual(expected, actual)

   # Ordered model pinned positionally; `None` leaves `w` unpinned.
   expected = collections.OrderedDict([('w', []), ('z', [2, 4, 4])])
   actual = tfde.JointDistributionPinned(
       jd_named_ordered(), None, 1., [2., 3]).event_shape
   self.assertEqual(expected, actual)

   expected = collections.OrderedDict(
       [('w', []), ('x', []), ('z', [2, 4, 4])])
   actual = tfde.JointDistributionPinned(
       jd_named_ordered(), y=[2., 3]).event_shape
   self.assertEqual(expected, actual)

   # Namedtuple model: the result is a struct tuple over unpinned fields.
   expected = structural_tuple.structtuple(['w', 'y', 'z'])(
       [], [2], [2, 4, 4])
   actual = tfde.JointDistributionPinned(
       jd_named_namedtuple(), x=2.).event_shape
   self.assertEqual(expected, actual)
Ejemplo n.º 11
0
  def testSlicing(self):
    abc_type = structural_tuple.structtuple(['a', 'b', 'c'])
    instance = abc_type(a=1, b=2, c=3)

    def check_slice(sliced, values, fields):
      # A slice keeps both the selected values and their field names.
      self.assertAllEqual(values, tuple(sliced))
      self.assertAllEqual(fields, sliced._fields)

    whole = instance[:]
    check_slice(whole, (1, 2, 3), ('a', 'b', 'c'))
    check_slice(instance[:2], (1, 2), ('a', 'b'))
    check_slice(instance[::2], (1, 3), ('a', 'c'))
    # Slicing an already-sliced struct tuple behaves the same way.
    check_slice(whole[:2], (1, 2), ('a', 'b'))
Ejemplo n.º 12
0
    def test_bijector(self):
        # Checks `_experimental_default_event_space_bijector` when some parts
        # are pinned: the resulting bijector acts only on the unpinned parts.
        jd = tfd.JointDistributionSequential([
            tfd.Uniform(-1., 1.), lambda a: tfd.Uniform(
                a + tf.ones_like(a), a + tf.constant(2, a.dtype)),
            lambda b, a: tfd.Uniform(a, b, name='c')
        ])
        # NOTE(review): `a=` / `b=` below presumably work because
        # JointDistributionSequential derives component names from the lambda
        # argument names; do not rename those arguments without checking.
        bij = jd._experimental_default_event_space_bijector(a=-.5, b=1.)
        # With a=-.5 and b=1., `c` ~ Uniform(-.5, 1.). The point 0.5 lies 2/3
        # of the way through that interval, and the assertion expects the
        # inverse to land where the sigmoid equals that fraction.
        self.assertAllClose((2 / 3, ), tf.math.sigmoid(bij.inverse((0.5, ))))

        @tfd.JointDistributionCoroutine
        def model():
            root = tfd.JointDistributionCoroutine.Root
            x = yield root(tfd.Normal(0., 1., name='x'))
            y = yield root(tfd.Gamma(1., 1., name='y'))
            yield tfd.Normal(x, y, name='z')

        # Pin `z` using the trailing slice of a sample, leaving `x` and `y`
        # free.
        bij = model._experimental_default_event_space_bijector(
            model.sample(seed=test_util.test_seed())[-1:])
        # `x` passes through unchanged. `y` is expected to map from
        # softplus-inverse space back to the positive value 2.
        self.assertAllCloseNested(
            structural_tuple.structtuple(['x', 'y'])(1., 2.),
            bij.forward((1., tfp.math.softplus_inverse(2.))))
Ejemplo n.º 13
0
 def testReplaceUnknownFields(self):
   # `_replace` must reject, and list, every name that is not a field.
   # Uses `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated
   # and was removed from `unittest` in Python 3.12.
   t = structural_tuple.structtuple(['a'])
   a = t()
   with self.assertRaisesRegex(
       ValueError, r'Got unexpected field names: \[\'b\', \'c\'\]'):
     a._replace(b=1, c=2)
Ejemplo n.º 14
0
 def testMakeTooManyValues(self):
   # `_make` rejects more values than there are fields. Uses
   # `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated.
   t = structural_tuple.structtuple(['a', 'b'])
   with self.assertRaisesRegex(TypeError,
                               'Expected 2 arguments or fewer, got 3'):
     t._make([1, 2, 3])
Ejemplo n.º 15
0
 def testMoreThan255Fields(self):
   # Presumably guards against the historical 255-argument limit on Python
   # calls; a struct tuple type must support far more fields than that.
   num_fields = 1000
   field_names = ['field' + str(n) for n in range(num_fields)]
   wide_type = structural_tuple.structtuple(field_names)
   self.assertLen(wide_type._fields, num_fields)
Ejemplo n.º 16
0
 def testInvalidIdentifierField(self):
   # Field names must be valid Python identifiers. Uses `assertRaisesRegex`;
   # the `assertRaisesRegexp` alias is deprecated.
   with self.assertRaisesRegex(ValueError,
                               'Field names must be valid identifiers: 0'):
     structural_tuple.structtuple(['0'])
Ejemplo n.º 17
0
 def testNonStrField(self):
   # Non-string field names are a TypeError. Uses `assertRaisesRegex`; the
   # `assertRaisesRegexp` alias is deprecated.
   with self.assertRaisesRegex(
       TypeError, 'Field names must be strings: 1 has type <class \'int\'>'):
     structural_tuple.structtuple([1])
Ejemplo n.º 18
0
 def testUnderscoreField(self):
   # Leading underscores are reserved (e.g. `_fields`, `_make`). Uses
   # `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated.
   with self.assertRaisesRegex(
       ValueError, 'Field names cannot start with an underscore: _a'):
     structural_tuple.structtuple(['_a'])
Ejemplo n.º 19
0
 def testKeywordField(self):
   # Python keywords are rejected as field names. Uses `assertRaisesRegex`;
   # the `assertRaisesRegexp` alias is deprecated.
   with self.assertRaisesRegex(ValueError,
                               'Field names cannot be a keyword: def'):
     structural_tuple.structtuple(['def'])
Ejemplo n.º 20
0
 def testDuplicateConstructorArg(self):
   # Supplying a field both positionally and by keyword is a TypeError. Uses
   # `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated.
   t = structural_tuple.structtuple(['a'])
   with self.assertRaisesRegex(TypeError,
                               'Got multiple values for argument a'):
     t(1, a=2)
Ejemplo n.º 21
0
 def testDuplicateField(self):
   # Repeated field names are rejected. Uses `assertRaisesRegex`; the
   # `assertRaisesRegexp` alias is deprecated.
   with self.assertRaisesRegex(ValueError,
                               'Encountered duplicate field name: a'):
     structural_tuple.structtuple(['a', 'a'])
Ejemplo n.º 22
0
 def testCacheWithUnderscores(self):
   # The type cache must tell these field lists apart even though their
   # names contain underscores.
   first = structural_tuple.structtuple(['a_b', 'b'])
   second = structural_tuple.structtuple(['a', 'b_c'])
   self.assertIsNot(first, second)
Ejemplo n.º 23
0
 def testCacheWorks(self):
   # Identical field lists return the very same type object; different field
   # lists return different types.
   abc_first = structural_tuple.structtuple(['a', 'b', 'c'])
   abc_again = structural_tuple.structtuple(['a', 'b', 'c'])
   ab_only = structural_tuple.structtuple(['a', 'b'])
   self.assertIs(abc_first, abc_again)
   self.assertIsNot(abc_first, ab_only)
Ejemplo n.º 24
0
 def testUnexpectedConstructorArg(self):
   # Unknown keyword arguments are a TypeError. Uses `assertRaisesRegex`; the
   # `assertRaisesRegexp` alias is deprecated.
   t = structural_tuple.structtuple(['a'])
   with self.assertRaisesRegex(TypeError,
                               'Got an unexpected keyword argument b'):
     t(b=2)
Ejemplo n.º 25
0
 def test_bijector_unconstrained_shapes(self):
   pinned = tfde.JointDistributionPinned(jd_coroutine(), x=1., y=[1., 1])
   bijector = pinned._experimental_default_event_space_bijector()
   # With `x` and `y` pinned, `w` stays scalar while `z`'s [2, 4, 4] event
   # maps to an unconstrained event of shape [2, 6].
   expected = structural_tuple.structtuple(['w', 'z'])([], [2, 6])
   actual = bijector.inverse_event_shape(pinned.event_shape)
   self.assertEqual(expected, actual)
Ejemplo n.º 26
0
 def testMissingAttribute(self):
   # Reading an undefined field raises AttributeError. Uses
   # `assertRaisesRegex`; the `assertRaisesRegexp` alias is deprecated.
   t = structural_tuple.structtuple(['a'])
   a = t()
   with self.assertRaisesRegex(AttributeError,
                               'StructTuple has no attribute b'):
     _ = a.b
Ejemplo n.º 27
0
 def _model_unflatten(self, xs):
   """Packs the flat values `xs` back into this model's sample structure.

   Args:
     xs: Iterable of flat values; generators are accepted.

   Returns:
     If `self._sample_dtype` is None: a `structural_tuple.structtuple` keyed
     by `self._flat_resolve_names()`. Otherwise: `xs` packed into the
     structure of `self._sample_dtype` via `tf.nest.pack_sequence_as`.
   """
   if self._sample_dtype is not None:
     # `pack_sequence_as` needs a real sequence, so materialize generators.
     return tf.nest.pack_sequence_as(self._sample_dtype, tuple(xs))
   field_names = self._flat_resolve_names()
   return structural_tuple.structtuple(field_names)(*xs)
Ejemplo n.º 28
0
 def testTreeUtilIntegration(self):
   abc_type = structural_tuple.structtuple(['a', 'b', 'c'])
   original = abc_type(a=1, b=2, c=3)
   # JAX must treat the struct tuple as a pytree and rebuild the same type.
   incremented = jax.tree_util.tree_map(lambda leaf: leaf + 1, original)
   self.assertIsInstance(incremented, type(original))
   self.assertAllEqualNested(abc_type(2, 3, 4), incremented)