def test_import_unrecognized_class(self):
        """Checks loading a SavedModel that returns a locally defined composite.

        The model is saved, the class-id registry is cleared, and loading is
        expected to raise `ValueError`. Loading is then retried twice: once
        after calling `as_composite` on a fresh instance, and once after
        clearing the registry again and calling `register_composite` on the
        class.
        """
        path = self.create_tempdir().full_path
        saved_model_dir = os.path.join(path, 'saved_model1')

        class Normal(tfd.Normal):  # Note, same name as tfd.Normal, but diff type.
            pass

        class Model(tf.Module):
            @tf.function(input_signature=())
            def make_dist(self):
                return tfp.experimental.as_composite(Normal(0, 1))

        m1 = Model()
        tf.saved_model.save(m1, saved_model_dir)

        # Eliminate cached classes, forcing breakage on load.
        clsid_registry.clear()
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists on `unittest.TestCase` in Python 3.12.
        with self.assertRaisesRegex(
                ValueError,
                r'For user-defined.*decorated.*register_composite'):
            tf.saved_model.load(saved_model_dir)

        tfp.experimental.as_composite(Normal(0, 1))
        # Now warmed-up, loading should work.
        m2 = tf.saved_model.load(saved_model_dir)
        self.evaluate(m2.make_dist().sample())

        # Eliminate cached classes again, but now register Normal as if it had
        # been decorated from the beginning.
        clsid_registry.clear()
        self.assertEqual(Normal, tfp.experimental.register_composite(Normal))

        # Loading should work again.
        m3 = tf.saved_model.load(saved_model_dir)
        self.evaluate(m3.make_dist().sample())
示例#2
0
  def test_import_uncached_class(self):
    """Saves a model taking a composite input, then reloads it uncached.

    After the class-id registry is cleared, the SavedModel is loaded and its
    `make_dist` function is called with a freshly built composite; the
    resulting sample is evaluated.
    """
    export_dir = os.path.join(self.create_tempdir().full_path, 'saved_model1')
    # Type spec of a reference composite, used as the traced input signature.
    dist_spec = normal_composite(0, [1, 2])._type_spec  # pylint: disable=protected-access

    class Model(tf.Module):

      @tf.function(input_signature=(dist_spec,))
      def make_dist(self, d):
        return normal_composite(d.sample(), 1, validate_args=True)

    tf.saved_model.save(Model(), export_dir)

    # Eliminate cached classes, forcing auto-regen of class on load.
    clsid_registry.clear()
    restored = tf.saved_model.load(export_dir)

    input_dist = normal_composite(.3, [.5, .9])
    output_dist = restored.make_dist(input_dist)
    self.evaluate(output_dist.sample())
    def test_import_unrecognized_class(self):
        """Checks loading a SavedModel that returns a custom composite class.

        A local `Normal` distribution (delegating sampling to `tfd.Normal`) is
        wrapped with `as_composite` inside a saved `tf.function`. After the
        class-id registry is cleared, loading is expected to raise
        `ValueError`. Loading is then retried twice: once after calling
        `as_composite` on a fresh instance, and once after clearing the
        registry again and calling `register_composite` on the class.
        """
        path = self.create_tempdir().full_path
        saved_model_dir = os.path.join(path, 'saved_model1')

        class Normal(tfd.Distribution):  # Same name as tfd.Normal, but diff type.
            def __init__(self, loc, scale):
                # Sampling is delegated to this wrapped distribution.
                self.dist = tfd.Normal(loc, scale)
                super(Normal, self).__init__(dtype=None,
                                             reparameterization_type=None,
                                             validate_args=False,
                                             allow_nan_stats=False)

            def _sample_n(self, n, seed=None, **kwargs):
                return self.dist._sample_n(n, seed=seed, **kwargs)  # pylint: disable=protected-access

        class Model(tf.Module):
            @tf.function(input_signature=())
            def make_dist(self):
                return tfp.experimental.as_composite(Normal(0, 1))

        m1 = Model()
        tf.saved_model.save(m1, saved_model_dir)

        # Eliminate cached classes, forcing breakage on load.
        clsid_registry.clear()
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists on `unittest.TestCase` in Python 3.12.
        with self.assertRaisesRegex(
                ValueError,
                r'For user-defined.*decorated.*register_composite'):
            tf.saved_model.load(saved_model_dir)

        tfp.experimental.as_composite(Normal(0, 1))
        # Now warmed-up, loading should work.
        m2 = tf.saved_model.load(saved_model_dir)
        self.evaluate(m2.make_dist().sample())

        # Eliminate cached classes again, but now register Normal as if it had
        # been decorated from the beginning.
        clsid_registry.clear()
        self.assertEqual(Normal, tfp.experimental.register_composite(Normal))

        # Loading should work again.
        m3 = tf.saved_model.load(saved_model_dir)
        self.evaluate(m3.make_dist().sample())
示例#4
0
  def test_import_unrecognized_class(self):
    """Checks loading a SavedModel that returns a locally defined composite.

    The model is saved, the class-id registry is cleared, and loading is
    expected to raise `ValueError`. After `as_composite` is called on a fresh
    instance of the local class, loading is retried and the sample evaluated.
    """
    path = self.create_tempdir().full_path
    saved_model_dir = os.path.join(path, 'saved_model1')

    class Normal(tfd.Normal):  # Note, same name as tfd.Normal, but diff type.
      pass

    class Model(tf.Module):

      @tf.function(input_signature=())
      def make_dist(self):
        return tfp.experimental.as_composite(Normal(0, 1))

    m1 = Model()
    tf.saved_model.save(m1, saved_model_dir)

    # Eliminate cached classes, forcing breakage on load.
    clsid_registry.clear()
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which no longer exists on `unittest.TestCase` in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, r'For non-builtin.*call `as_composite` before'):
      tf.saved_model.load(saved_model_dir)

    tfp.experimental.as_composite(Normal(0, 1))
    # Now warmed-up, loading should work.
    m2 = tf.saved_model.load(saved_model_dir)
    self.evaluate(m2.make_dist().sample())