Пример #1
0
def load_module(handle, tags=None, load_options=None):
    """Returns a callable `handle` unchanged, or loads `handle` via module_v2.

    Args:
      handle: either a callable, which is returned as-is, or a handle that is
        passed to `module_v2.load`.
      tags: optional tags forwarded to `module_v2.load`. Must be None when
        `handle` is callable.
      load_options: optional options forwarded to `module_v2.load`. Must be
        None when `handle` is callable. When falsy, the result of
        `load_context.get_load_options()` is used instead, provided one of the
        two `load_context` modules can be imported.

    Returns:
      `handle` itself if it is callable, otherwise the result of
      `module_v2.load(handle, tags=tags, options=...)`.

    Raises:
      ValueError: if `handle` is callable and `tags` or `load_options` is set.
    """
    if callable(handle):
        if tags is not None:
            raise ValueError("Passing a callable handle is mutually exclusive "
                             "with setting tags.")
        if load_options is not None:
            raise ValueError("Passing a callable handle is mutually exclusive "
                             "with setting load_options.")
        return handle

    # Fall back to load_context.get_load_options() when no options are given.
    # The first import location is unavailable before TF2.5 and the second
    # before TF2.4; with neither, the caller's `load_options` is used as-is.
    # pylint: disable=g-import-not-at-top
    # pylint: disable=g-direct-tensorflow-import
    try:
        from tensorflow.python.keras.saving.saved_model import load_context
        options = load_options or load_context.get_load_options()
    except ImportError:  # Expected before TF2.5.
        try:
            from tensorflow.python.saved_model import load_context
            options = load_options or load_context.get_load_options()
        except ImportError:  # Expected before TF2.4.
            options = load_options
    return module_v2.load(handle, tags=tags, options=options)
Пример #2
0
def load_module(handle, tags=None):
    """Returns a callable `handle` unchanged, or loads it with `module_v2.load`.

    Args:
      handle: a callable, returned as-is, or a handle passed to
        `module_v2.load`.
      tags: optional tags forwarded to `module_v2.load`. Must be None when
        `handle` is callable.

    Raises:
      ValueError: if `handle` is callable and `tags` is not None.
    """
    if not callable(handle):
        return module_v2.load(handle, tags=tags)
    if tags is not None:
        raise ValueError("Passing a callable handle is mutually exclusive "
                         "with setting tags.")
    return handle
Пример #3
0
 def test_load(self, module_name, tags, is_hub_module_v1):
     """Exports a plus-one model named `module_name`, reloads it, checks format.

     'hub_module_v1_mini' is written as a hub module v1; any other name is
     written as a SavedModel v2. The loaded object's `_is_hub_module_v1` flag
     must equal `is_hub_module_v1`.
     """
     export_dir = os.path.join(self.get_temp_dir(), module_name)
     save_plus_one = (_save_plus_one_hub_module_v1
                      if module_name == 'hub_module_v1_mini' else
                      _save_plus_one_saved_model_v2)
     save_plus_one(export_dir)
     loaded = module_v2.load(export_dir, tags)
     self.assertEqual(loaded._is_hub_module_v1, is_hub_module_v1)
Пример #4
0
    def __init__(self, handle, trainable=False, arguments=None, **kwargs):
        """Initializes the layer around a callable or a loadable handle.

        Args:
          handle: a callable, used directly as the layer's function, or a
            handle passed to `module_v2.load()`, whose result must be callable.
            The raw value is also kept as `self._handle`.
          trainable: passed through to the base Layer constructor.
          arguments: if not None, stored as `self._arguments`; if None, this
            constructor does not set that attribute.
          **kwargs: passed through to the base Layer constructor after
            "output_shape" (if present) is removed and stored as a tuple in
            `self._output_shape`.

        Raises:
          ValueError: if the object loaded from `handle` is not callable, or if
            an element of its `regularization_losses` is not callable.
        """
        # Note: for compatibility with keras-model serialization this layer is
        # json-serializable. If you add or change arguments here, please also update
        # the `get_config` method.
        self._handle = handle

        # Resolve the handle to a callable `func`.
        # NOTE: The name _func gets baked into object-based checkpoints.
        if callable(handle):
            self._func = handle
        else:
            self._func = module_v2.load(handle)
            if not callable(self._func):
                raise ValueError("Non-callable result from hub.load('%s')" %
                                 str(handle))
        # TODO(b/124219898): We should do shape inference on the callable.
        # "output_shape" is popped so that it is not forwarded to the base
        # Layer constructor below.
        if "output_shape" in kwargs:
            self._output_shape = tuple(kwargs.pop("output_shape"))

        # Initialize an empty layer, then add_weight() etc. as needed.
        super(KerasLayer, self).__init__(trainable=trainable, **kwargs)

        # Add trainable and non-trainable weights from the callable.
        if hasattr(self._func, "trainable_variables"):
            for v in self._func.trainable_variables:
                self._add_existing_weight(v, trainable=True)
            # Set of id()s of the trainable variables, so that the loop over
            # `variables` below does not add them a second time.
            # NOTE(review): id() instead of the variables themselves is
            # presumably used because tf.Variable is unhashable in TF2 — confirm.
            trainable_variables = {
                id(v)
                for v in self._func.trainable_variables
            }
        else:
            trainable_variables = set()
        if hasattr(self._func, "variables"):
            for v in self._func.variables:
                if id(v) not in trainable_variables:
                    self._add_existing_weight(v, trainable=False)

        # Forward the callable's regularization losses (if any).
        if hasattr(self._func, "regularization_losses"):
            for l in self._func.regularization_losses:
                if not callable(l):
                    raise ValueError(
                        "hub.KerasLayer(obj) expects obj.regularization_losses to be an "
                        "iterable of callables, each returning a scalar loss term."
                    )
                # NOTE(review): _call_loss_if_trainable is defined elsewhere;
                # judging by its name it makes `l` contribute only while the
                # layer is trainable — confirm.
                self.add_loss(
                    self._call_loss_if_trainable(l))  # Supports callables.

        # Prepare to call `func`.
        # Record whether func.__call__ accepts a `training` argument, either as
        # a regular parameter or as a keyword-only parameter.
        self._func_fullargspec = tf_inspect.getfullargspec(self._func.__call__)
        self._func_wants_training = ("training" in self._func_fullargspec.args
                                     or "training"
                                     in self._func_fullargspec.kwonlyargs)
        if arguments is not None:
            self._arguments = arguments
Пример #5
0
    def __init__(self, handle, trainable=False, arguments=None, **kwargs):
        """Wraps a callable, or a handle that loads to one, as a Keras layer.

        Args:
          handle: a callable, used directly, or a handle passed to
            `module_v2.load()`, whose result must be callable.
          trainable: passed through to the base Layer constructor.
          arguments: stored as `self._arguments`; a falsy value (e.g. None) is
            replaced by an empty dict.
          **kwargs: passed through to the base Layer constructor after
            "output_shape" (if present) is removed and stored as a tuple in
            `self._output_shape`.

        Raises:
          ValueError: if the object loaded from `handle` is not callable, or if
            an element of its `regularization_losses` is not callable.
        """
        # Resolve the handle to a callable `func`.
        if callable(handle):
            self._func = handle
        else:
            self._func = module_v2.load(handle)
            if not callable(self._func):
                raise ValueError("Non-callable result from hub.load('%s')" %
                                 str(handle))

        # Set self._{non,}_trainable_weights and then call Layer.__init__.
        # This together with @no_automatic_dependency_tracking above preserves
        # func.trainable_variables independent of tf.Variable(..., trainable=...).
        #
        # Variables are matched by id() rather than by value. In TF2 a
        # tf.Variable is unhashable and its `==` is elementwise, so putting
        # variables into a set (or testing `in` against one) does not work.
        if hasattr(self._func, "trainable_variables"):
            self._trainable_weights = list(self._func.trainable_variables)
            trainable_variable_ids = {
                id(v) for v in self._func.trainable_variables
            }
        else:
            self._trainable_weights = []
            trainable_variable_ids = set()
        if hasattr(self._func, "variables"):
            self._non_trainable_weights = [
                v for v in self._func.variables
                if id(v) not in trainable_variable_ids
            ]
        else:
            self._non_trainable_weights = []

        # TODO(b/124219898): We should be able to get the embedding dimension from
        # the restored model.
        # "output_shape" is popped so it is not forwarded to Layer.__init__.
        if "output_shape" in kwargs:
            self._output_shape = tuple(kwargs.pop("output_shape"))

        super(KerasLayer, self).__init__(trainable=trainable, **kwargs)

        # Prepare to call `func`: record whether its __call__ accepts a
        # `training` argument, as a regular or a keyword-only parameter.
        self._func_fullargspec = tf_inspect.getfullargspec(self._func.__call__)
        self._func_wants_training = ("training" in self._func_fullargspec.args
                                     or "training"
                                     in self._func_fullargspec.kwonlyargs)
        self._arguments = arguments or {}

        # Forward the callable's regularization losses (if any).
        if hasattr(self._func, "regularization_losses"):
            for loss in self._func.regularization_losses:
                if not callable(loss):
                    raise ValueError(
                        "hub.KerasLayer(obj) expects obj.regularization_losses to be an "
                        "iterable of callables, each returning a scalar loss term."
                    )
                self.add_loss(loss)  # Supports callables.
Пример #6
0
 def test_load_sparse(self):
     """Loads a sparse-input plus-one hub module v1 and checks its output.

     Skipped on TF versions whose load_v1_in_v2 mishandled sparse tensors.
     """
     # str.startswith accepts a tuple of prefixes.
     if tf.__version__.startswith(('1.', '2.0.')):
         # Note the trailing space: the two literals are concatenated, and
         # without it the message read "...correctlyin TensorFlow...".
         self.skipTest(
             'load_v1_in_v2 did not handle sparse tensors correctly '
             'in TensorFlow version %r.' % (tf.__version__,))
     export_dir = os.path.join(self.get_temp_dir(), 'sparse')
     _save_sparse_plus_one_hub_module_v1(export_dir)
     m = module_v2.load(export_dir)
     self.assertTrue(m._is_hub_module_v1)
     plus_one = m.signatures['default']
     st = tf.sparse.from_dense([[1.0, 2.0, 0.0], [0.0, 3.0, 0.0]])
     # The SparseTensor is fed to the signature as its three components.
     actual = plus_one(default_indices=st.indices,
                       default_values=st.values,
                       default_dense_shape=st.dense_shape)['default']
     expected = [2.0, 3.0, 4.0]
     self.assertAllEqual(actual.values, expected)
Пример #7
0
 def test_load_ragged(self):
     """Loads a ragged-input plus-one hub module v1 and checks its output.

     Skipped on TF versions whose load_v1_in_v2 mishandled composite tensors.
     """
     # str.startswith accepts a tuple of prefixes.
     if tf.__version__.startswith(('1.', '2.0.', '2.1.', '2.2.', '2.3.')):
         # Note the trailing space: the two literals are concatenated, and
         # without it the message read "...correctlyin TensorFlow...".
         self.skipTest(
             'load_v1_in_v2 did not handle composite tensors correctly '
             'in TensorFlow version %r.' % (tf.__version__,))
     export_dir = os.path.join(self.get_temp_dir(), 'ragged')
     _save_ragged_plus_one_hub_module_v1(export_dir)
     m = module_v2.load(export_dir)
     self.assertTrue(m._is_hub_module_v1)
     plus_one = m.signatures['default']
     rt = tf.ragged.constant([[1.0, 8.0], [3.0]])
     # The RaggedTensor is fed to the signature as its flat values and
     # row_splits.
     actual = plus_one(default_component_0=rt.values,
                       default_component_1=rt.row_splits)['default']
     expected = [2.0, 9.0, 4.0]
     self.assertAllEqual(actual.values, expected)
Пример #8
0
 def test_load_without_string(self):
     """A non-string handle (here the int 0) makes module_v2.load fail."""
     self.assertRaisesRegex(ValueError, 'Expected a string, got.*',
                            module_v2.load, 0)
Пример #9
0
 def test_load_with_old_tensorflow_raises_error(self, tf_v1_mock):
     """module_v2.load raises NotImplementedError without v1 `saved_model`.

     `tf_v1_mock` is presumably injected by a mock.patch decorator that is not
     visible in this snippet; clearing its `saved_model` attribute simulates
     an older TensorFlow.
     """
     tf_v1_mock.saved_model = None
     self.assertRaises(NotImplementedError, module_v2.load, 'dummy_module_name')
Пример #10
0
 def test_load(self, module_name, tags, is_hub_module_v1):
     """Loads the test-data module `module_name` and checks its format flag."""
     loaded = module_v2.load(test_utils.get_test_data_path(module_name), tags)
     self.assertEqual(loaded._is_hub_module_v1, is_hub_module_v1)
Пример #11
0
    def __init__(self, handle, trainable=False, arguments=None, **kwargs):
        """Initializes the layer around a callable or a loadable handle.

        Args:
          handle: a callable, used directly as the layer's function, or a
            handle passed to `module_v2.load()`, whose result must be callable.
            The raw value is also kept as `self._handle`.
          trainable: passed through to the base Layer constructor.
          arguments: if not None, stored (wrapped in NoDependency) as
            `self._arguments`; if None, this constructor does not set that
            attribute.
          **kwargs: passed through to the base Layer constructor after
            "output_shape" (if present) is removed, converted with
            `_convert_nest_to_shapes`, and stored (wrapped in NoDependency) as
            `self._output_shape`.

        Raises:
          ValueError: if the object loaded from `handle` is not callable, or if
            an element of its `regularization_losses` is not callable.
        """
        # Note: for compatibility with keras-model serialization this layer is
        # json-serializable. If you add or change arguments here, please also update
        # the `get_config` method.
        self._handle = handle

        # Resolve the handle to a callable `func`.
        # NOTE: The name _func gets baked into object-based checkpoints.
        if callable(handle):
            self._func = handle
        else:
            self._func = module_v2.load(handle)
            if not callable(self._func):
                raise ValueError("Non-callable result from hub.load('%s')" %
                                 str(handle))
        # TODO(b/142213824): Remove setting shapes when shape inference works.
        if "output_shape" in kwargs:
            # Autograph chokes on _convert_nest_to_shapes(), so we call it here
            # and not from within call(). The result is marked NoDependency
            # to avoid autoconversion to a trackable _DictWrapper, because that
            # upsets json.dumps() when saving the result of get_config().
            # Popping also keeps "output_shape" out of the kwargs forwarded to
            # the base Layer constructor below.
            self._output_shape = data_structures.NoDependency(
                _convert_nest_to_shapes(kwargs.pop("output_shape")))

        # Initialize an empty layer, then add_weight() etc. as needed.
        super(KerasLayer, self).__init__(trainable=trainable, **kwargs)

        # Add trainable and non-trainable weights from the callable.
        if hasattr(self._func, "trainable_variables"):
            for v in self._func.trainable_variables:
                self._add_existing_weight(v, trainable=True)
            # Set of id()s of the trainable variables, so that the loop over
            # `variables` below does not add them a second time.
            # NOTE(review): id() instead of the variables themselves is
            # presumably used because tf.Variable is unhashable in TF2 — confirm.
            trainable_variables = {
                id(v)
                for v in self._func.trainable_variables
            }
        else:
            trainable_variables = set()
        if hasattr(self._func, "variables"):
            for v in self._func.variables:
                if id(v) not in trainable_variables:
                    self._add_existing_weight(v, trainable=False)

        # Forward the callable's regularization losses (if any).
        if hasattr(self._func, "regularization_losses"):
            for l in self._func.regularization_losses:
                if not callable(l):
                    raise ValueError(
                        "hub.KerasLayer(obj) expects obj.regularization_losses to be an "
                        "iterable of callables, each returning a scalar loss term."
                    )
                # NOTE(review): _call_loss_if_trainable is defined elsewhere;
                # judging by its name it makes `l` contribute only while the
                # layer is trainable — confirm.
                self.add_loss(
                    self._call_loss_if_trainable(l))  # Supports callables.

        # Prepare to call `func`.
        # Record whether func.__call__ accepts a `training` argument, either as
        # a regular parameter or as a keyword-only parameter.
        self._func_fullargspec = tf_inspect.getfullargspec(self._func.__call__)
        self._func_wants_training = ("training" in self._func_fullargspec.args
                                     or "training"
                                     in self._func_fullargspec.kwonlyargs)
        if arguments is not None:
            # The attribute is marked NoDependency to avoid autoconversion to a
            # trackable _DictWrapper, because that upsets json.dumps() when saving
            # the result of get_config().
            self._arguments = data_structures.NoDependency(arguments)