def test_import_unrecognized_class(self):
  """Loading fails for an unknown user-defined class until it is made known.

  Saves a model whose traced `make_dist` returns `as_composite` of a locally
  defined `Normal` (a `tfd.Normal` subclass), clears `clsid_registry`, and
  checks that:
    1. loading raises a `ValueError` whose message mentions
       `register_composite`;
    2. after calling `as_composite` on a fresh instance, loading succeeds;
    3. after clearing the registry again, `register_composite(Normal)` returns
       `Normal` and loading succeeds once more.
  """
  path = self.create_tempdir().full_path
  saved_path = os.path.join(path, 'saved_model1')

  class Normal(tfd.Normal):  # Note, same name as tfd.Normal, but diff type.
    pass

  class Model(tf.Module):

    @tf.function(input_signature=())
    def make_dist(self):
      return tfp.experimental.as_composite(Normal(0, 1))

  m1 = Model()
  tf.saved_model.save(m1, saved_path)

  # Eliminate cached classes, forcing breakage on load.
  clsid_registry.clear()
  # `assertRaisesRegex`: the `...Regexp` alias was removed in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, r'For user-defined.*decorated.*register_composite'):
    tf.saved_model.load(saved_path)

  tfp.experimental.as_composite(Normal(0, 1))
  # Now warmed-up, loading should work.
  m2 = tf.saved_model.load(saved_path)
  self.evaluate(m2.make_dist().sample())

  # Eliminate cached classes again, but now register Normal as if it had been
  # decorated from the beginning.
  clsid_registry.clear()
  self.assertEqual(Normal, tfp.experimental.register_composite(Normal))

  # Loading should work again.
  m3 = tf.saved_model.load(saved_path)
  self.evaluate(m3.make_dist().sample())
def test_import_uncached_class(self):
  """A saved model taking/returning composite distributions survives a cache wipe.

  The traced `make_dist` accepts a `normal_composite` argument and returns a
  new `normal_composite`. `clsid_registry` is cleared between saving and
  loading, and the reloaded function must still accept a freshly built
  composite and produce a sampleable result.
  """
  export_dir = os.path.join(self.create_tempdir().full_path, 'saved_model1')
  # Type spec of a composite Normal, used as the traced function's signature.
  arg_spec = normal_composite(0, [1, 2])._type_spec  # pylint: disable=protected-access

  class Model(tf.Module):

    @tf.function(input_signature=(arg_spec,))
    def make_dist(self, d):
      return normal_composite(d.sample(), 1, validate_args=True)

  exported = Model()
  tf.saved_model.save(exported, export_dir)

  # Eliminate cached classes, forcing auto-regen of class on load.
  clsid_registry.clear()
  restored = tf.saved_model.load(export_dir)

  dist = normal_composite(.3, [.5, .9])
  self.evaluate(restored.make_dist(dist).sample())
def test_import_unrecognized_class(self):
  """Loading fails for an unknown user-defined class until it is made known.

  The locally defined `Normal` subclasses `tfd.Distribution` directly and
  delegates sampling to a wrapped `tfd.Normal`. The test saves a model whose
  traced `make_dist` returns `as_composite(Normal(0, 1))`, clears
  `clsid_registry`, and checks that:
    1. loading raises a `ValueError` whose message mentions
       `register_composite`;
    2. after calling `as_composite` on a fresh instance, loading succeeds;
    3. after clearing the registry again, `register_composite(Normal)` returns
       `Normal` and loading succeeds once more.
  """
  path = self.create_tempdir().full_path
  saved_path = os.path.join(path, 'saved_model1')

  class Normal(tfd.Distribution):  # Same name as tfd.Normal, but diff type.

    def __init__(self, loc, scale):
      self.dist = tfd.Normal(loc, scale)
      super(Normal, self).__init__(
          dtype=None, reparameterization_type=None, validate_args=False,
          allow_nan_stats=False)

    def _sample_n(self, n, seed=None, **kwargs):
      # Sampling is delegated to the wrapped `tfd.Normal`.
      return self.dist._sample_n(n, seed=seed, **kwargs)  # pylint: disable=protected-access

  class Model(tf.Module):

    @tf.function(input_signature=())
    def make_dist(self):
      return tfp.experimental.as_composite(Normal(0, 1))

  m1 = Model()
  tf.saved_model.save(m1, saved_path)

  # Eliminate cached classes, forcing breakage on load.
  clsid_registry.clear()
  # `assertRaisesRegex`: the `...Regexp` alias was removed in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, r'For user-defined.*decorated.*register_composite'):
    tf.saved_model.load(saved_path)

  tfp.experimental.as_composite(Normal(0, 1))
  # Now warmed-up, loading should work.
  m2 = tf.saved_model.load(saved_path)
  self.evaluate(m2.make_dist().sample())

  # Eliminate cached classes again, but now register Normal as if it had been
  # decorated from the beginning.
  clsid_registry.clear()
  self.assertEqual(Normal, tfp.experimental.register_composite(Normal))

  # Loading should work again.
  m3 = tf.saved_model.load(saved_path)
  self.evaluate(m3.make_dist().sample())
def test_import_unrecognized_class(self):
  """Loading fails for an unknown user-defined class until it is warmed up.

  Saves a model whose traced `make_dist` returns `as_composite` of a locally
  defined `Normal` (a `tfd.Normal` subclass), clears `clsid_registry`, and
  checks that loading raises a `ValueError` telling the user to call
  `as_composite` first. It then calls `as_composite` on a fresh instance and
  checks that loading succeeds.
  """
  path = self.create_tempdir().full_path
  saved_path = os.path.join(path, 'saved_model1')

  class Normal(tfd.Normal):  # Note, same name as tfd.Normal, but diff type.
    pass

  class Model(tf.Module):

    @tf.function(input_signature=())
    def make_dist(self):
      return tfp.experimental.as_composite(Normal(0, 1))

  m1 = Model()
  tf.saved_model.save(m1, saved_path)

  # Eliminate cached classes, forcing breakage on load.
  clsid_registry.clear()
  # `assertRaisesRegex`: the `...Regexp` alias was removed in Python 3.12.
  with self.assertRaisesRegex(
      ValueError, r'For non-builtin.*call `as_composite` before'):
    tf.saved_model.load(saved_path)

  tfp.experimental.as_composite(Normal(0, 1))
  # Now warmed-up, loading should work.
  m2 = tf.saved_model.load(saved_path)
  self.evaluate(m2.make_dist().sample())