def _prune(self, xs, retain=None):
  """Filters `xs` down to either its pinned or its unpinned parts.

  Args:
    xs: A structure like that of `self.distribution.dtype`: a dict, a
      namedtuple-like object (has `_fields` and `_asdict`), or a flat
      sequence ordered like `self.distribution._flat_resolve_names()`.
    retain: Which side to keep; must be `'pinned'` or `'unpinned'`.

  Returns:
    A structure of the same kind as `xs` holding only the retained parts.

  Raises:
    ValueError: If `retain` is neither `'pinned'` nor `'unpinned'`.
  """
  if retain not in ('pinned', 'unpinned'):
    raise ValueError('Invalid value for `retain`: {}'.format(retain))
  keep_pinned = retain == 'pinned'

  def keep(name):
    # A part is pinned iff its name is a key of `self.pins`; keep it exactly
    # when that matches the requested side.
    return (name in self.pins) == keep_pinned

  if isinstance(xs, dict):
    kept_items = [(name, part) for name, part in xs.items() if keep(name)]
    return type(xs)(kept_items)

  if hasattr(xs, '_fields') and hasattr(xs, '_asdict'):
    kept_fields = [name for name in xs._fields if keep(name)]
    kept_values = {
        name: part for name, part in xs._asdict().items() if keep(name)}
    return structural_tuple.structtuple(kept_fields)(**kept_values)

  # Flat sequences carry no names; look them up positionally.
  flat_names = self.distribution._flat_resolve_names()
  return type(xs)(
      [part for idx, part in enumerate(xs) if keep(flat_names[idx])])
def testValidNamedTuple(self):
  """Instances unpack like tuples and round-trip through `tf.nest`."""
  struct_type = structural_tuple.structtuple(['a', 'b', 'c'])
  value = struct_type(a=1, b=2, c=3)
  first, second, third = value
  self.assertAllEqualNested((1, 2, 3), (first, second, third))
  incremented = tf.nest.map_structure(lambda x: x + 1, value)
  self.assertAllEqualNested(struct_type(2, 3, 4), incremented)
def testArgsExpansion(self):
  """`nest_util.call_fn` passes a structtuple's elements as arguments."""
  def add(a, b):
    return a + b

  pair_type = structural_tuple.structtuple(['c', 'd'])
  result = nest_util.call_fn(add, pair_type(1, 2))
  self.assertEqual(3, result)
def testMake(self):
  """`_make` builds from an iterable; missing trailing fields become None."""
  pair_type = structural_tuple.structtuple(['a', 'b'])
  full = pair_type._make((1, 2))
  self.assertEqual(1, full.a)
  self.assertEqual(2, full.b)
  partial = pair_type._make((1,))
  self.assertEqual(1, partial.a)
  self.assertIsNone(partial.b)
def test_event_shape_and_structure_jd_coroutine(self):
  """Pinning parts of a coroutine JD removes them from its `event_shape`."""
  make_struct = structural_tuple.structtuple

  pinned_w = tfde.JointDistributionPinned(jd_coroutine(), w=1.)
  self.assertEqual(
      make_struct(['x', 'y', 'z'])([], [2], [2, 4, 4]), pinned_w.event_shape)

  pinned_xy = tfde.JointDistributionPinned(
      jd_coroutine(), None, 1., [2., 3])
  self.assertEqual(
      make_struct(['w', 'z'])([], [2, 4, 4]), pinned_xy.event_shape)

  pinned_x = tfde.JointDistributionPinned(jd_coroutine(), x=2.)
  self.assertEqual(
      make_struct(['w', 'y', 'z'])([], [2], [2, 4, 4]), pinned_x.event_shape)

  # Pinning with a slice of a sample keeps that slice's field name ('z').
  full_sample = jd_coroutine().sample(seed=test_util.test_seed())
  obs = full_sample[-1:]
  self.assertEqual(obs._fields, ('z', ))
  pinned_obs = tfde.JointDistributionPinned(jd_coroutine(), obs)
  self.assertEqual(
      make_struct(['w', 'x', 'y'])([], [], [2]), pinned_obs.event_shape)
def testConcatenation(self):
  """`+` merges StructTuples field-wise and also works with plain tuples."""
  first_type = structural_tuple.structtuple(['a', 'b'])
  second_type = structural_tuple.structtuple(['c', 'd'])
  ab = first_type(a=1, b=2)
  cd = second_type(c=3, d=4)

  merged_cases = (
      (ab + cd, (1, 2, 3, 4), ('a', 'b', 'c', 'd')),
      (cd + ab, (3, 4, 1, 2), ('c', 'd', 'a', 'b')),
  )
  for merged, values, fields in merged_cases:
    self.assertAllEqual(values, tuple(merged))
    self.assertAllEqual(fields, merged._fields)

  self.assertAllEqual((1, 2, 3), ab + (3,))
  self.assertAllEqual((3, 1, 2), (3,) + ab)
def _model_unflatten(self, xs):
  """Unflattens `xs` to a structure like-typed to `self.distribution`.

  The container is chosen from the underlying JD's `dtype`:
    * dict-like: keys from `self._flat_resolve_names()`, zipped with `xs`;
    * namedtuple-like: a `structtuple` over the dtype's fields that are not
      in `self.pins`;
    * otherwise: `type(dtype)(xs)`.

  Args:
    xs: Iterable of parts in flat order. Generators are accepted; `xs` is
      materialized into a tuple once up front.

  Returns:
    The parts of `xs` packed into a container like `self.distribution.dtype`.

  Raises:
    ValueError: If `dtype` is dict-like and the number of parts differs from
      the number of names returned by `self._flat_resolve_names()`.
  """
  # Materialize first: `len()` below would raise TypeError on a generator.
  xs = tuple(xs)
  # Use the underlying JD dtype to infer model structure.
  dtype = self.distribution.dtype
  if isinstance(dtype, dict):
    ks = self._flat_resolve_names()
    if len(ks) != len(xs):
      raise ValueError('Invalid xs length {}, ks={}'.format(len(xs), ks))
    return type(dtype)(zip(ks, xs))
  if hasattr(dtype, '_fields') and hasattr(dtype, '_asdict'):
    ks = [k for k in dtype._fields if k not in self.pins]
    return structural_tuple.structtuple(ks)(*xs)
  return type(dtype)(xs)
def _initialize_parameters(generator, seed=None):
  """Samples initial values for all parameters yielded by a generator.

  Args:
    generator: Python generator that yields initialization callables (which
      take a `seed` and return a (structure of) `Tensor`(s)), returns a value,
      and has no side effects. See module description.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    raw_parameters: `structtuple` with one field per yielded parameter, named
      from `parameter.name` (or `'parameter'` when unset) via
      `_get_unused_parameter_name` (presumably to avoid duplicate names --
      confirm). Each value is the initial value returned by the parameter's
      `init_fn`, mapped through `constraining_bijector.inverse` when the
      parameter has a constraining bijector.

  Raises:
    ValueError: If `generator()` does not return a generator object, or if the
      generator yields something without an `init_fn` attribute.
  """
  gen = generator()
  if not isinstance(gen, types.GeneratorType):
    raise ValueError(
        'Expected generator but saw function: {}. A generator '
        'must contain at least one `yield` statement. To define a '
        'trivial generator, which yields zero times, a `yield` '
        'statement may be placed after `return`, but must still '
        'be present.'.format(generator))

  raw_parameters = []
  parameter_names = []
  # The first `send` must be `None`; afterwards each (constrained) initial
  # value is sent back into the coroutine.
  param_value = None
  while True:
    # Only exhaustion of the coroutine itself ends the loop. Keeping the `try`
    # this narrow means a `StopIteration` leaking out of an `init_fn` or a
    # bijector propagates instead of silently truncating the parameter list.
    try:
      parameter = gen.send(param_value)
    except StopIteration:
      break
    if not hasattr(parameter, 'init_fn'):
      raise ValueError(
          'Expected generator to yield a '
          'trainable_state_util.Parameter namedtuple, but saw '
          '{} instead.'.format(parameter))

    seed, local_seed = samplers.split_seed(seed, n=2)
    # Note: this framework guarantees that the `init_fn` is only ever called
    # here, immediately after being yielded before control is returned
    # to the coroutine. This allows the coroutine to safely incorporate
    # loop-dependent state in the closure of `init_fn` if desired.
    param_value = _call_init_fn(parameter.init_fn, seed=local_seed)
    raw_value = (param_value if parameter.constraining_bijector is None
                 else parameter.constraining_bijector.inverse(param_value))
    raw_parameters.append(raw_value)
    parameter_names.append(
        _get_unused_parameter_name(parameter.name or 'parameter',
                                   parameter_names))
  return structural_tuple.structtuple(parameter_names)(*raw_parameters)
def test_dtype_and_structure_jd_named(self):
  """Pinning parts of a named JD removes them from the pinned `dtype`."""
  f32 = tf.float32

  pinned_w = tfde.JointDistributionPinned(jd_named(), w=1.)
  self.assertEqual(dict(x=f32, y=f32, z=f32), pinned_w.dtype)

  pinned_xy = tfde.JointDistributionPinned(
      jd_named_ordered(), None, 1., [2., 3])
  self.assertEqual(
      collections.OrderedDict([('w', f32), ('z', f32)]), pinned_xy.dtype)

  pinned_y = tfde.JointDistributionPinned(jd_named_ordered(), y=[2., 3])
  self.assertEqual(
      collections.OrderedDict([('w', f32), ('x', f32), ('z', f32)]),
      pinned_y.dtype)

  pinned_x = tfde.JointDistributionPinned(jd_named_namedtuple(), x=2.)
  self.assertEqual(
      structural_tuple.structtuple(['w', 'y', 'z'])(f32, f32, f32),
      pinned_x.dtype)
def test_event_shape_and_structure_jd_named(self):
  """Pinning parts of a named JD removes them from its `event_shape`."""
  pinned_w = tfde.JointDistributionPinned(jd_named(), w=1.)
  self.assertEqual(dict(x=[], y=[2], z=[2, 4, 4]), pinned_w.event_shape)

  pinned_xy = tfde.JointDistributionPinned(
      jd_named_ordered(), None, 1., [2., 3])
  self.assertEqual(
      collections.OrderedDict([('w', []), ('z', [2, 4, 4])]),
      pinned_xy.event_shape)

  pinned_y = tfde.JointDistributionPinned(jd_named_ordered(), y=[2., 3])
  self.assertEqual(
      collections.OrderedDict([('w', []), ('x', []), ('z', [2, 4, 4])]),
      pinned_y.event_shape)

  pinned_x = tfde.JointDistributionPinned(jd_named_namedtuple(), x=2.)
  self.assertEqual(
      structural_tuple.structtuple(['w', 'y', 'z'])([], [2], [2, 4, 4]),
      pinned_x.event_shape)
def testSlicing(self):
  """Slicing a StructTuple keeps the field names of the selected elements."""
  triple_type = structural_tuple.structtuple(['a', 'b', 'c'])
  full = triple_type(a=1, b=2, c=3)
  everything = full[:]
  cases = (
      (everything, (1, 2, 3), ('a', 'b', 'c')),
      (full[:2], (1, 2), ('a', 'b')),
      (full[::2], (1, 3), ('a', 'c')),
      # Re-slicing an already-sliced StructTuple still keeps field names.
      (everything[:2], (1, 2), ('a', 'b')),
  )
  for sliced, values, fields in cases:
    self.assertAllEqual(values, tuple(sliced))
    self.assertAllEqual(fields, sliced._fields)
def test_bijector(self):
  """The default event-space bijector, given pinned values, maps only the rest."""
  # Lambda argument names ('a', 'b') and `name='c'` name the JDS parts; keep them.
  sequential = tfd.JointDistributionSequential([
      tfd.Uniform(-1., 1.),
      lambda a: tfd.Uniform(
          a + tf.ones_like(a), a + tf.constant(2, a.dtype)),
      lambda b, a: tfd.Uniform(a, b, name='c'),
  ])
  seq_bijector = sequential._experimental_default_event_space_bijector(
      a=-.5, b=1.)
  unconstrained = seq_bijector.inverse((0.5, ))
  self.assertAllClose((2 / 3, ), tf.math.sigmoid(unconstrained))

  @tfd.JointDistributionCoroutine
  def model():
    root = tfd.JointDistributionCoroutine.Root
    x = yield root(tfd.Normal(0., 1., name='x'))
    y = yield root(tfd.Gamma(1., 1., name='y'))
    yield tfd.Normal(x, y, name='z')

  z_only = model.sample(seed=test_util.test_seed())[-1:]
  coroutine_bijector = model._experimental_default_event_space_bijector(
      z_only)
  mapped = coroutine_bijector.forward((1., tfp.math.softplus_inverse(2.)))
  self.assertAllCloseNested(
      structural_tuple.structtuple(['x', 'y'])(1., 2.), mapped)
def testReplaceUnknownFields(self):
  """`_replace` rejects field names the StructTuple does not define."""
  t = structural_tuple.structtuple(['a'])
  a = t()
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, r"Got unexpected field names: \['b', 'c'\]"):
    a._replace(b=1, c=2)
def testMakeTooManyValues(self):
  """`_make` rejects iterables longer than the number of fields."""
  t = structural_tuple.structtuple(['a', 'b'])
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(TypeError,
                              'Expected 2 arguments or fewer, got 3'):
    t._make([1, 2, 3])
def testMoreThan255Fields(self):
  """A structtuple can be defined with far more than 255 fields."""
  field_count = 1000
  names = ['field{}'.format(index) for index in range(field_count)]
  big_type = structural_tuple.structtuple(names)
  self.assertLen(big_type._fields, field_count)
def testInvalidIdentifierField(self):
  """Field names that are not Python identifiers are rejected."""
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              'Field names must be valid identifiers: 0'):
    structural_tuple.structtuple(['0'])
def testNonStrField(self):
  """Non-string field names are rejected with a TypeError."""
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(
      TypeError, "Field names must be strings: 1 has type <class 'int'>"):
    structural_tuple.structtuple([1])
def testUnderscoreField(self):
  """Field names starting with an underscore are rejected."""
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, 'Field names cannot start with an underscore: _a'):
    structural_tuple.structtuple(['_a'])
def testKeywordField(self):
  """Python keywords cannot be used as field names."""
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              'Field names cannot be a keyword: def'):
    structural_tuple.structtuple(['def'])
def testDuplicateConstructorArg(self):
  """Passing a field both positionally and by keyword is a TypeError."""
  t = structural_tuple.structtuple(['a'])
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(TypeError,
                              'Got multiple values for argument a'):
    t(1, a=2)
def testDuplicateField(self):
  """Repeating a field name in the definition is rejected."""
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              'Encountered duplicate field name: a'):
    structural_tuple.structtuple(['a', 'a'])
def testCacheWithUnderscores(self):
  """Field names containing underscores do not confuse the class cache."""
  make = structural_tuple.structtuple
  first = make(['a_b', 'b'])
  second = make(['a', 'b_c'])
  self.assertIsNot(first, second)
def testCacheWorks(self):
  """Identical field lists return one cached class; different ones do not."""
  make = structural_tuple.structtuple
  abc_first = make(['a', 'b', 'c'])
  abc_second = make(['a', 'b', 'c'])
  ab_only = make(['a', 'b'])
  self.assertIs(abc_first, abc_second)
  self.assertIsNot(abc_first, ab_only)
def testUnexpectedConstructorArg(self):
  """Constructing with an undefined keyword argument is a TypeError."""
  t = structural_tuple.structtuple(['a'])
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(TypeError,
                              'Got an unexpected keyword argument b'):
    t(b=2)
def test_bijector_unconstrained_shapes(self):
  """The pinned event-space bijector reports unpinned unconstrained shapes."""
  pinned = tfde.JointDistributionPinned(jd_coroutine(), x=1., y=[1., 1])
  bijector = pinned._experimental_default_event_space_bijector()
  expected = structural_tuple.structtuple(['w', 'z'])([], [2, 6])
  unconstrained_shape = bijector.inverse_event_shape(pinned.event_shape)
  self.assertEqual(expected, unconstrained_shape)
def testMissingAttribute(self):
  """Reading an undefined field raises AttributeError."""
  t = structural_tuple.structtuple(['a'])
  a = t()
  # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
  with self.assertRaisesRegex(AttributeError,
                              'StructTuple has no attribute b'):
    _ = a.b
def _model_unflatten(self, xs):
  """Rebuilds the model's sample structure from a flat sequence of parts.

  Args:
    xs: Flat iterable of parts; generators are accepted.

  Returns:
    If `self._sample_dtype` is None, a `structtuple` whose fields are
    `self._flat_resolve_names()`; otherwise `xs` packed into the structure of
    `self._sample_dtype` via `tf.nest.pack_sequence_as`.
  """
  sample_dtype = self._sample_dtype
  if sample_dtype is not None:
    # Cast `xs` as `tuple` so we can handle generators.
    return tf.nest.pack_sequence_as(sample_dtype, tuple(xs))
  field_names = self._flat_resolve_names()
  return structural_tuple.structtuple(field_names)(*xs)
def testTreeUtilIntegration(self):
  """`jax.tree_util.tree_map` over a StructTuple keeps its type."""
  triple_type = structural_tuple.structtuple(['a', 'b', 'c'])
  original = triple_type(a=1, b=2, c=3)
  incremented = jax.tree_util.tree_map(lambda leaf: leaf + 1, original)
  self.assertIsInstance(incremented, type(original))
  self.assertAllEqualNested(triple_type(2, 3, 4), incremented)