Example #1
    def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
        """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
        super(_CopyToDeviceDataset, self).__init__(input_dataset)
        self._input_dataset = input_dataset
        self._target_device = target_device
        spec = framework_device.DeviceSpec().from_string(self._target_device)
        self._is_gpu_target = (spec.device_type == "GPU")
        self._source_device_string = source_device
        self._source_device = ops.convert_to_tensor(source_device)

        @function.defun()
        def _init_func():
            """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
            # pylint: disable=protected-access
            ds_variant = self._input_dataset._as_variant_tensor()
            resource = gen_dataset_ops.anonymous_iterator(
                **dataset_ops.flat_structure(self._input_dataset))
            with ops.control_dependencies(
                [gen_dataset_ops.make_iterator(ds_variant, resource)]):
                return gen_dataset_ops.iterator_to_string_handle(resource)

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun()
        def _remote_init_func():
            return functional_ops.remote_call(
                target=self._source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _next_func(string_handle):
            """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
            with ops.device(self._source_device_string):
                iterator = iterator_ops.Iterator.from_string_handle(
                    string_handle, self.output_types, self.output_shapes,
                    self.output_classes)
            return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _finalize_func(string_handle):
            """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
            iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
                string_handle,
                **dataset_ops.flat_structure(self._input_dataset))
            with ops.control_dependencies([
                    resource_variable_ops.destroy_resource_op(
                        iterator_resource, ignore_lookup_error=True)
            ]):
                return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._finalize_captured_args = self._finalize_func.captured_inputs

        g = ops.get_default_graph()
        self._init_func.add_to_graph(g)
        self._next_func.add_to_graph(g)
        self._finalize_func.add_to_graph(g)
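
For context, this class backs the public `tf.data.experimental.copy_to_device` transformation; it is not constructed directly. A minimal usage sketch (the device string is illustrative and assumes a GPU is present):

import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# copy_to_device wraps the input in a _CopyToDeviceDataset targeting the
# given device; iteration should then happen on that device.
dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))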
Example #2
 def testTensorDatasetSpec(self):
     self._testDatasetSpec(constant_op.constant(37.0),
                           tensor_spec.TensorSpec([], dtypes.float32))
Example #3
 def testOptionalDatasetSpec(self):
     self._testDatasetSpec(
         optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalSpec(
             tensor_spec.TensorSpec([], dtypes.float32)))
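
For reference, an `Optional` built this way behaves as below; `tf.experimental.Optional` is the public alias of `optional_ops.Optional` in recent TF releases (assumed here):

import tensorflow as tf

opt = tf.experimental.Optional.from_value(37.0)
print(opt.has_value())   # tf.Tensor(True, shape=(), dtype=bool)
print(opt.get_value())   # tf.Tensor(37.0, shape=(), dtype=float32)
print(opt.element_spec)  # TensorSpec(shape=(), dtype=tf.float32, name=None)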
Example #4
 def _component_specs(self):
     return [tensor_spec.TensorSpec((), dtypes.variant)]
Example #5
    def _model_fn():
      """Compute fit/eval/predict for the TPU."""
      is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
      is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
      is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

      # During train/eval, we infeed our features as well as labels.
      if is_training or is_test:
        infeed_layers = self.model._input_layers + self.model._output_layers
      else:
        infeed_layers = self.model._input_layers

      # Generate our infeed operation to read features & labels.
      infeed_tensors = tpu_ops.infeed_dequeue_tuple(
          dtypes=[spec.dtype for spec in input_specs],
          shapes=[spec.shape for spec in input_specs],
          name='infeed-%s' % self.execution_mode)

      assert len(infeed_tensors) == len(infeed_layers), (
          'Infeed inputs did not match model: %s vs %s' %
          (infeed_layers, infeed_tensors))

      tpu_targets = []
      tpu_inputs = []

      # Sort infeed outputs into inputs and labels for calling our Keras model.
      for tensor, layer in zip(infeed_tensors, infeed_layers):
        if layer in self.model._input_layers:
          tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor))
        if layer in self.model._output_layers:
          tpu_targets.append(tensor)

      # Call our model with our infeed inputs (re-using the weights).
      model_outputs = self.model(tpu_inputs)
      child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
      if is_training or is_test:
        child_model.compile(
            optimizer=self.model.optimizer,
            loss=self.model.loss,
            loss_weights=self.model.loss_weights,
            metrics=self.model.metrics,
            weighted_metrics=self.model.weighted_metrics,
            target_tensors=tpu_targets,
        )

      # Compute our outfeed depending on the execution mode
      if is_training:
        child_model._make_train_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in child_model.train_function.outputs
        ]
        return [
            child_model.train_function.updates_op,
            tpu_ops.outfeed_enqueue_tuple(
                child_model.train_function.outputs, name='outfeed-enqueue-train')
        ]
      elif is_test:
        child_model._make_test_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in child_model.test_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                child_model.test_function.outputs, name='outfeed-enqueue-test')
        ]
      elif is_predict:
        child_model._make_predict_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in child_model.predict_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                child_model.predict_function.outputs,
                name='outfeed-enqueue-predict',
            )
        ]
      else:
        assert False, 'Unexpected execution mode: %s' % self.execution_mode
Example #6
    def __init__(self,
                 dataset_id,
                 processing_mode,
                 address,
                 protocol,
                 job_name=None,
                 max_outstanding_requests=None,
                 task_refresh_interval_hint_ms=None):
        """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A string specifying the policy for how data should be
        processed by tf.data workers. Currently, the only supported value is
        "parallel_epochs".
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
      job_name: (Optional.) The name of the job. This argument makes it possible
        for multiple datasets to share the same job. The default behavior is
        that the dataset creates anonymous, exclusively owned jobs.
      max_outstanding_requests: (Optional.) A limit on how many elements may be
        requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
    """

        if job_name is None:
            job_name = ""
        if max_outstanding_requests is None:
            max_outstanding_requests = dataset_ops.AUTOTUNE
        if task_refresh_interval_hint_ms is None:
            task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

        self._dataset_id = ops.convert_to_tensor(dataset_id,
                                                 dtype=dtypes.int64,
                                                 name="dataset_id")
        self._processing_mode = ops.convert_to_tensor(processing_mode,
                                                      dtype=dtypes.string,
                                                      name="processing_mode")
        self._address = ops.convert_to_tensor(address,
                                              dtype=dtypes.string,
                                              name="address")
        self._protocol = ops.convert_to_tensor(protocol,
                                               dtype=dtypes.string,
                                               name="protocol")
        self._job_name = ops.convert_to_tensor(job_name,
                                               dtype=dtypes.string,
                                               name="job_name")
        self._max_outstanding_requests = ops.convert_to_tensor(
            max_outstanding_requests,
            dtype=dtypes.int64,
            name="max_outstanding_requests")
        # Datasets executed by the tf.data service produce compressed elements
        # represented by scalar DT_VARIANTs.
        self._element_spec = tensor_spec.TensorSpec(shape=(),
                                                    dtype=dtypes.variant)

        variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
            dataset_id=self._dataset_id,
            processing_mode=self._processing_mode,
            address=self._address,
            protocol=self._protocol,
            job_name=self._job_name,
            max_outstanding_requests=self._max_outstanding_requests,
            task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
            iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(),
            **self._flat_structure)
        super(_DataServiceDatasetV2, self).__init__(variant_tensor)
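
In practice this dataset is created through the public `tf.data.experimental.service.distribute` transformation rather than instantiated directly. A minimal sketch (the dispatcher address and job name are placeholders):

import tensorflow as tf

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs",
        service="grpc://localhost:5000",
        job_name="shared_job"))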
Example #7
 def testDatasetSpecInnerSpec(self):
   inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
   ds_spec = dataset_ops.DatasetSpec(inner_spec)
   self.assertEqual(ds_spec.element_spec, inner_spec)
Example #8
 class Adder(module.Module):
     @def_function.function(input_signature=[
         tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
     ])
     def add(self, x):
         return x + x + 1.
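
Because the signature uses `shape=None`, a single concrete function accepts float32 tensors of any rank. A usage sketch with the public aliases of the modules above (`tf.Module`, `tf.function`, `tf.saved_model.save`); the save path is illustrative:

import tensorflow as tf

adder = Adder()
adder.add(tf.constant(2.0))           # scalar input
adder.add(tf.constant([1.0, 2.0]))    # vector input reuses the same trace
tf.saved_model.save(adder, "/tmp/adder")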
Example #9
 def compute_output_signature(self, input_spec):
   output_shape = self.compute_output_shape(input_spec.shape.as_list())
   output_dtype = K.floatx() if self._output_mode == TFIDF else dtypes.int64
   return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
Example #10
 def testCustomMapping(self):
   elem = CustomMap(foo=constant_op.constant(37.))
   spec = structure.type_spec_from_value(elem)
   self.assertIsInstance(spec, CustomMap)
   self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))
Example #11
 class ObjWithDefaultSignature(checkpoint.Checkpoint):
     @def_function.function(input_signature=[
         tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
     ])
     def _default_save_signature(self, x):
         return x + x + 1
Example #12
class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):

  def _save_model_dir(self, dirname='saved_model'):
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  @parameterized.parameters(
      {
          'model_builder': functional_model,
          'uses_learning_phase': True,
          'optimizer_cls': adadelta.Adadelta,
          'train_before_export': True},
      {
          'model_builder': functional_model,
          'uses_learning_phase': True,
          'optimizer_cls': training_module.AdadeltaOptimizer,
          'train_before_export': False},
      {
          'model_builder': functional_model,
          'uses_learning_phase': False,
          'optimizer_cls': None,
          'train_before_export': False},
      {
          'model_builder': sequential_model,
          'uses_learning_phase': True,
          'optimizer_cls': training_module.AdadeltaOptimizer,
          'train_before_export': True},
      {
          'model_builder': sequential_model,
          'uses_learning_phase': True,
          'optimizer_cls': adadelta.Adadelta,
          'train_before_export': False},
      {
          'model_builder': sequential_model,
          'uses_learning_phase': False,
          'optimizer_cls': None,
          'train_before_export': False},
      {
          'model_builder': sequential_model_without_input_shape,
          'uses_learning_phase': True,
          'optimizer_cls': training_module.AdadeltaOptimizer,
          'train_before_export': False})
  def testSaveAndLoadSavedModelExport(
      self, model_builder, uses_learning_phase, optimizer_cls,
      train_before_export):
    optimizer = None if optimizer_cls is None else optimizer_cls()

    saved_model_dir = self._save_model_dir()

    np.random.seed(130)
    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 3))

    model = model_builder(uses_learning_phase)
    if optimizer is not None:
      model.compile(
          loss='mse',
          optimizer=optimizer,
          metrics=['mae'])
      if train_before_export:
        model.train_on_batch(input_arr, target_arr)

      ref_loss, ref_mae = model.evaluate(input_arr, target_arr)

    ref_predict = model.predict(input_arr)

    # Export SavedModel
    keras_saved_model.export_saved_model(model, saved_model_dir)

    input_name = model.input_names[0]
    output_name = model.output_names[0]
    target_name = output_name + '_target'

    # Load predict graph, and test predictions
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs, _ = load_model(sess, saved_model_dir,
                                      mode_keys.ModeKeys.PREDICT)

      predictions = sess.run(outputs[output_name],
                             {inputs[input_name]: input_arr})
      self.assertAllClose(ref_predict, predictions, atol=1e-05)

    if optimizer:
      # Load eval graph, and test predictions, loss and metric values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, _ = load_model(sess, saved_model_dir,
                                        mode_keys.ModeKeys.TEST)

        # First obtain the loss and predictions, and run the metric update op by
        # feeding in the inputs and targets.
        metrics_name = 'mae' if tf2.enabled() else 'mean_absolute_error'
        metrics_update_op_key = 'metrics/' + metrics_name + '/update_op'
        metrics_value_op_key = 'metrics/' + metrics_name + '/value'

        loss, predictions, _ = sess.run(
            (outputs['loss'], outputs['predictions/' + output_name],
             outputs[metrics_update_op_key]), {
                 inputs[input_name]: input_arr,
                 inputs[target_name]: target_arr
             })

        # The metric value should be run after the update op, to ensure that it
        # reflects the correct value.
        metric_value = sess.run(outputs[metrics_value_op_key])

        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertAllClose(ref_loss, loss, atol=1e-05)
        self.assertAllClose(ref_mae, metric_value, atol=1e-05)
        self.assertAllClose(ref_predict, predictions, atol=1e-05)

      # Load train graph, and check for the train op, and prediction values
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, meta_graph_def = load_model(
            sess, saved_model_dir, mode_keys.ModeKeys.TRAIN)
        self.assertEqual(int(train_before_export),
                         sess.run(training_module.get_global_step()))
        self.assertIn('loss', outputs)
        self.assertIn(metrics_update_op_key, outputs)
        self.assertIn(metrics_value_op_key, outputs)
        self.assertIn('predictions/' + output_name, outputs)

        # Train for a step
        train_op = get_train_op(meta_graph_def)
        train_outputs, _ = sess.run(
            [outputs, train_op], {inputs[input_name]: input_arr,
                                  inputs[target_name]: target_arr})
        self.assertEqual(int(train_before_export) + 1,
                         sess.run(training_module.get_global_step()))

        if uses_learning_phase:
          self.assertAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)
        else:
          self.assertNotAllClose(
              [[0, 0, 0]], train_outputs['predictions/' + output_name],
              atol=1e-05)

  def testSaveAndLoadSavedModelWithCustomObject(self):
    saved_model_dir = self._save_model_dir()
    with session.Session(graph=ops.Graph()) as sess:
      def relu6(x):
        return keras.backend.relu(x, max_value=6)
      inputs = keras.layers.Input(shape=(1,))
      outputs = keras.layers.Activation(relu6)(inputs)
      model = keras.models.Model(inputs, outputs)
      keras_saved_model.export_saved_model(
          model, saved_model_dir, custom_objects={'relu6': relu6})
    with session.Session(graph=ops.Graph()) as sess:
      inputs, outputs, _ = load_model(sess, saved_model_dir,
                                      mode_keys.ModeKeys.PREDICT)
      input_name = model.input_names[0]
      output_name = model.output_names[0]
      predictions = sess.run(
          outputs[output_name], {inputs[input_name]: [[7], [-3], [4]]})
      self.assertAllEqual([[6], [0], [4]], predictions)

  def testAssertModelCloneSameObjectsIgnoreOptimizer(self):
    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 3))

    model_graph = ops.Graph()
    clone_graph = ops.Graph()

    # Create two models with the same layers but different optimizers.
    with session.Session(graph=model_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(3)(x)
      model = keras.models.Model(inputs, x)

      model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
      model.train_on_batch(input_arr, target_arr)

    with session.Session(graph=clone_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(3)(x)
      clone = keras.models.Model(inputs, x)
      clone.compile(loss='mse', optimizer=optimizer_v1.RMSprop(lr=0.0001))
      clone.train_on_batch(input_arr, target_arr)

    keras_saved_model._assert_same_non_optimizer_objects(
        model, model_graph, clone, clone_graph)

  def testAssertModelCloneSameObjectsThrowError(self):
    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 3))

    model_graph = ops.Graph()
    clone_graph = ops.Graph()

    # Create two models with the same layers but different optimizers.
    with session.Session(graph=model_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(3)(x)
      model = keras.models.Model(inputs, x)

      model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
      model.train_on_batch(input_arr, target_arr)

    with session.Session(graph=clone_graph):
      inputs = keras.layers.Input(shape=(3,))
      x = keras.layers.Dense(2)(inputs)
      x = keras.layers.Dense(4)(x)
      x = keras.layers.Dense(3)(x)
      clone = keras.models.Model(inputs, x)
      clone.compile(loss='mse', optimizer=optimizer_v1.RMSprop(lr=0.0001))
      clone.train_on_batch(input_arr, target_arr)

    # The clone has an extra Dense layer, so the comparison is expected to
    # fail. (The exact exception type raised is an assumption here.)
    with self.assertRaises(ValueError):
      keras_saved_model._assert_same_non_optimizer_objects(
          model, model_graph, clone, clone_graph)

  def testSaveSequentialModelWithoutInputShapes(self):
    model = sequential_model_without_input_shape(True)
    # A Sequential model that hasn't been built should raise an error.
    with self.assertRaisesRegex(
        ValueError, 'Weights for sequential model have not yet been created'):
      keras_saved_model.export_saved_model(model, '')

    # Even with an input_signature, the model's weights have not been created.
    with self.assertRaisesRegex(
        ValueError, 'Weights for sequential model have not yet been created'):
      saved_model_dir = self._save_model_dir()
      keras_saved_model.export_saved_model(
          model,
          saved_model_dir,
          input_signature=tensor_spec.TensorSpec(
              shape=(10, 11, 12, 13, 14), dtype=dtypes.float32,
              name='spec_input'))

  @parameterized.parameters(
      {
          'model_builder': sequential_model_without_input_shape,
          'input_signature': [tensor_spec.TensorSpec(shape=[None, 3],
                                                     dtype=dtypes.float32)]},
      {
          'model_builder': subclassed_model,
          'input_signature': [tensor_spec.TensorSpec(shape=[None, 3],
                                                     dtype=dtypes.float32)]})
  def testServingOnly(self, model_builder, input_signature):
    if context.executing_eagerly():
      saved_model_dir = self._save_model_dir()
      input_arr = np.random.random((5, 3)).astype(np.float32)
      model = model_builder()
      ref_predict = model.predict(input_arr)

      keras_saved_model.export_saved_model(
          model,
          saved_model_dir,
          serving_only=True,
          input_signature=input_signature)

      # Load predict graph, and test predictions
      with session.Session(graph=ops.Graph()) as sess:
        inputs, outputs, _ = load_model(sess, saved_model_dir,
                                        mode_keys.ModeKeys.PREDICT)
        predictions = sess.run(outputs[next(iter(outputs.keys()))],
                               {inputs[next(iter(inputs.keys()))]: input_arr})
        self.assertAllClose(ref_predict, predictions, atol=1e-05)
Example #13
    def __init__(self,
                 input_shape=None,
                 batch_size=None,
                 dtype=None,
                 input_tensor=None,
                 sparse=False,
                 name=None,
                 ragged=False,
                 **kwargs):
        strategy = distribution_strategy_context.get_strategy()
        if strategy and batch_size is not None and \
            distributed_training_utils.global_batch_size_supported(strategy):
            if batch_size % strategy.num_replicas_in_sync != 0:
                raise ValueError(
                    'The `batch_size` argument ({}) must be divisible by '
                    'the number of replicas ({})'.format(
                        batch_size, strategy.num_replicas_in_sync))
            batch_size = batch_size // strategy.num_replicas_in_sync

        if 'batch_input_shape' in kwargs:
            batch_input_shape = kwargs.pop('batch_input_shape')
            if input_shape and batch_input_shape:
                raise ValueError('Only provide the input_shape OR '
                                 'batch_input_shape argument to '
                                 'InputLayer, not both at the same time.')
            batch_size = batch_input_shape[0]
            input_shape = batch_input_shape[1:]
        if kwargs:
            raise ValueError('Unrecognized keyword arguments:', kwargs.keys())

        if sparse and ragged:
            raise ValueError(
                'Cannot set both sparse and ragged to True in a Keras input.')

        if not name:
            prefix = 'input'
            name = prefix + '_' + str(backend.get_uid(prefix))

        if not dtype:
            if input_tensor is None:
                dtype = backend.floatx()
            else:
                dtype = backend.dtype(input_tensor)
        elif input_tensor is not None and input_tensor.dtype != dtype:
            raise ValueError(
                '`input_tensor.dtype` differs from `dtype`: %s vs. %s' %
                (input_tensor.dtype, dtype))
        super(InputLayer, self).__init__(dtype=dtype, name=name)
        self.built = True
        self.sparse = sparse
        self.ragged = ragged
        self.batch_size = batch_size
        self.supports_masking = True

        if isinstance(input_shape, tensor_shape.TensorShape):
            input_shape = tuple(input_shape.as_list())
        elif isinstance(input_shape, int):
            input_shape = (input_shape, )

        if input_tensor is None:
            if input_shape is not None:
                batch_input_shape = (batch_size, ) + tuple(input_shape)
            else:
                batch_input_shape = None
            graph = backend.get_graph()
            with graph.as_default():
                input_tensor = backend.placeholder(shape=batch_input_shape,
                                                   dtype=dtype,
                                                   name=self.name,
                                                   sparse=sparse,
                                                   ragged=ragged)

            self.is_placeholder = True
            self._batch_input_shape = batch_input_shape
        else:
            if keras_tensor.keras_tensors_enabled():
                if not isinstance(input_tensor, keras_tensor.KerasTensor):
                    input_tensor = keras_tensor.keras_tensor_from_tensor(
                        input_tensor)
            else:
                if not tf_utils.is_symbolic_tensor(input_tensor):
                    raise ValueError(
                        'You should not pass an EagerTensor to `Input`. '
                        'For example, instead of creating an '
                        'InputLayer, you should instantiate your model and '
                        'directly call it on your input.')
            self.is_placeholder = False
            try:
                self._batch_input_shape = tuple(input_tensor.shape.as_list())
            except ValueError:
                # If the shape cannot be represented as a tuple (e.g. unknown rank)
                self._batch_input_shape = None
        # Create an input node.
        input_tensor._keras_mask = None
        node_module.Node(layer=self, outputs=input_tensor)

        # Store type spec
        if isinstance(
                input_tensor,
            (composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
            self._type_spec = input_tensor._type_spec  # pylint: disable=protected-access
        else:
            self._type_spec = tensor_spec.TensorSpec(shape=input_tensor.shape,
                                                     dtype=input_tensor.dtype,
                                                     name=self.name)
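
This constructor is normally reached through `tf.keras.Input`, which builds an `InputLayer` and returns its symbolic output. A minimal sketch:

import tensorflow as tf

x = tf.keras.Input(shape=(3,), dtype="float32", name="features")
# The underlying InputLayer stores a TensorSpec with shape (None, 3) as its
# type spec, as in the else-branch above.
print(x.shape)  # (None, 3)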
Example #14
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None,
               expected_output_shape=None,
               validate_training=True,
               adapt_data=None):
    """Test routine for a layer with a single input and single output.

  Arguments:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Whether to attempt to validate training on this layer.
      This might be set to False for non-differentiable layers that output
      string or integer values.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer. This is only relevant for PreprocessingLayers.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if `input_shape is None`.
  """
    if input_data is None:
        if input_shape is None:
            raise ValueError('input_shape is None')
        if not input_dtype:
            input_dtype = 'float32'
        input_data_shape = list(input_shape)
        for i, e in enumerate(input_data_shape):
            if e is None:
                input_data_shape[i] = np.random.randint(1, 4)
        input_data = 10 * np.random.random(input_data_shape)
        if input_dtype[:5] == 'float':
            input_data -= 0.5
        input_data = input_data.astype(input_dtype)
    elif input_shape is None:
        input_shape = input_data.shape
    if input_dtype is None:
        input_dtype = input_data.dtype
    if expected_output_dtype is None:
        expected_output_dtype = input_dtype

    # instantiation
    kwargs = kwargs or {}
    layer = layer_cls(**kwargs)

    # Test adapt, if data was passed.
    if adapt_data is not None:
        layer.adapt(adapt_data)

    # test get_weights, set_weights at layer level
    weights = layer.get_weights()
    layer.set_weights(weights)

    # test instantiation from weights
    if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
        kwargs['weights'] = weights
        layer = layer_cls(**kwargs)

    # test in functional API
    x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
    y = layer(x)
    if backend.dtype(y) != expected_output_dtype:
        raise AssertionError(
            'When testing layer %s, for input %s, found output '
            'dtype=%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, backend.dtype(y), expected_output_dtype,
             kwargs))

    def assert_shapes_equal(expected, actual):
        """Asserts that the output shape from the layer matches the actual shape."""
        if len(expected) != len(actual):
            raise AssertionError(
                'When testing layer %s, for input %s, found output_shape='
                '%s but expected to find %s.\nFull kwargs: %s' %
                (layer_cls.__name__, x, actual, expected, kwargs))

        for expected_dim, actual_dim in zip(expected, actual):
            if isinstance(expected_dim, tensor_shape.Dimension):
                expected_dim = expected_dim.value
            if isinstance(actual_dim, tensor_shape.Dimension):
                actual_dim = actual_dim.value
            if expected_dim is not None and expected_dim != actual_dim:
                raise AssertionError(
                    'When testing layer %s, for input %s, found output_shape='
                    '%s but expected to find %s.\nFull kwargs: %s' %
                    (layer_cls.__name__, x, actual, expected, kwargs))

    if expected_output_shape is not None:
        assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape),
                            y.shape)

    # check shape inference
    model = models.Model(x, y)
    computed_output_shape = tuple(
        layer.compute_output_shape(
            tensor_shape.TensorShape(input_shape)).as_list())
    computed_output_signature = layer.compute_output_signature(
        tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype))
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    assert_shapes_equal(computed_output_shape, actual_output_shape)
    assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
    if computed_output_signature.dtype != actual_output.dtype:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_dtype='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual_output.dtype,
             computed_output_signature.dtype, kwargs))
    if expected_output is not None:
        np.testing.assert_allclose(actual_output,
                                   expected_output,
                                   rtol=1e-3,
                                   atol=1e-6)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = models.Model.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)

    # test training mode (e.g. useful for dropout tests)
    # Rebuild the model to avoid the graph being reused between predict() and
    # train_on_batch(). See b/120160788 for more details. This should be
    # mitigated after 2.0.
    if validate_training:
        model = models.Model(x, layer(x))
        if _thread_local_data.run_eagerly is not None:
            model.compile('rmsprop',
                          'mse',
                          weighted_metrics=['acc'],
                          run_eagerly=should_run_eagerly())
        else:
            model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
        model.train_on_batch(input_data, actual_output)

    # test as first layer in Sequential API
    layer_config = layer.get_config()
    layer_config['batch_input_shape'] = input_shape
    layer = layer.__class__.from_config(layer_config)

    # Test adapt, if data was passed.
    if adapt_data is not None:
        layer.adapt(adapt_data)

    model = models.Sequential()
    model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
    model.add(layer)
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    for expected_dim, actual_dim in zip(computed_output_shape,
                                        actual_output_shape):
        if expected_dim is not None:
            if expected_dim != actual_dim:
                raise AssertionError(
                    'When testing layer %s **after deserialization**, '
                    'for input %s, found output_shape='
                    '%s but expected to find inferred shape %s.\nFull kwargs: %s'
                    % (layer_cls.__name__, x, actual_output_shape,
                       computed_output_shape, kwargs))
    if expected_output is not None:
        np.testing.assert_allclose(actual_output,
                                   expected_output,
                                   rtol=1e-3,
                                   atol=1e-6)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = models.Sequential.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)

    # for further checks in the caller function
    return actual_output
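
A typical invocation of this routine, with illustrative kwargs and shapes (assuming `layers` is the Keras layers module imported by this file):

layer_test(
    layers.Dense,
    kwargs={'units': 3},
    input_shape=(2, 4),
    expected_output_shape=(None, 3))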
Example #15
  def testGraphDefToTf(self):
    """Tests the basic flow of `tf.mlir.experimental.convert_graph_def`

        with tf-standard-pipeline converting all the way to the TF dialect.
    """

    tensor_shape = (10, 10)

    @def_function.function(
        input_signature=(
            tensor_spec.TensorSpec(shape=tensor_shape, dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=tensor_shape, dtype=dtypes.float32),
        ))
    def add_func(lhs, rhs):
      return math_ops.add(lhs, rhs)

    tf_graph_def = add_func.get_concrete_function().graph.as_graph_def()

    mlir_tf = import_graphdef(
        tf_graph_def,
        "tf-standard-pipeline",
        False,
        input_names=["lhs", "rhs"],
        input_data_types=["DT_FLOAT", "DT_FLOAT"],
        input_data_shapes=["10,10", "10,10"],
        output_names=["Add"])
    # Check whether the mlir-function signature has the mentioned
    # inputs and outputs.
    self.assertRegex(
        mlir_tf,
        r"func @main\(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>")
    self.assertRegex(mlir_tf, r'inputs = "lhs,rhs"')
    self.assertRegex(mlir_tf, r'outputs = "Add"')

    # Same check with scalar input (empty input shape).
    mlir_tf = import_graphdef(
        tf_graph_def,
        "tf-standard-pipeline",
        False,
        input_names=["lhs", "rhs"],
        input_data_types=["DT_FLOAT", "DT_FLOAT"],
        input_data_shapes=["", ""],
        output_names=["Add"])
    self.assertRegex(mlir_tf,
                     r"func @main\(%arg0: tensor<f32>, %arg1: tensor<f32>")

    # Test the invalid case where the number of input names doesn't match the
    # number of input data types.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Length of input node array and data type doesn't match"):

      import_graphdef(
          tf_graph_def,
          "tf-standard-pipeline",
          False,
          input_names=["lhs"],
          input_data_types=["DT_FLOAT", "DT_FLOAT"],
          input_data_shapes=["10,10", "10,10"],
          output_names=["Add"])

    # Test the invalid case where the input shapes argument is wrong.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Dimensions must be equal"):

      import_graphdef(
          tf_graph_def,
          "tf-standard-pipeline",
          False,
          input_names=["lhs", "rhs"],
          input_data_types=["DT_FLOAT", "DT_FLOAT"],
          input_data_shapes=["10,11", "10,10"],
          output_names=["Add"])
Example #16
 def _flat_tensor_specs(self):
   # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
   # but a `SparseTensorSpec` can also represent a batch of boxed
   # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
   # etc.), so the flat shape must be unknown.
   return [tensor_spec.TensorSpec(None, dtypes.variant)]
Example #17
def _dtype_to_spec(d):
    if not isinstance(d, type_spec.TypeSpec):
        d = tensor_spec.TensorSpec(None, d)
    return d
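
The helper normalizes a bare dtype into a rank-agnostic `TensorSpec` and passes anything that is already a `TypeSpec` through unchanged:

_dtype_to_spec(dtypes.float32)
# -> TensorSpec(shape=None, dtype=tf.float32, name=None)

_dtype_to_spec(tensor_spec.TensorSpec([2], dtypes.int32))
# -> returned as-is, since it is already a TypeSpec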
Example #18
    def __init__(self, input_dataset, features, num_parallel_calls,
                 deterministic):
        self._input_dataset = input_dataset
        if not structure.are_compatible(
                input_dataset.element_spec,
                tensor_spec.TensorSpec([None], dtypes.string)):
            raise TypeError(
                "Input dataset should be a dataset of vectors of "
                f"strings. Instead it is `{input_dataset.element_spec}`.")
        self._num_parallel_calls = num_parallel_calls
        if deterministic is None:
            self._deterministic = "default"
        elif deterministic:
            self._deterministic = "true"
        else:
            self._deterministic = "false"
        # pylint: disable=protected-access
        self._features = parsing_ops._prepend_none_dimension(features)
        # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
        params = parsing_ops._ParseOpParams.from_features(
            self._features, [
                parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
                parsing_ops.FixedLenFeature,
                parsing_ops.FixedLenSequenceFeature, parsing_ops.RaggedFeature
            ])
        # pylint: enable=protected-access
        self._sparse_keys = params.sparse_keys
        self._sparse_types = params.sparse_types
        self._ragged_keys = params.ragged_keys
        self._ragged_value_types = params.ragged_value_types
        self._ragged_split_types = params.ragged_split_types
        self._dense_keys = params.dense_keys
        self._dense_defaults = params.dense_defaults_vec
        self._dense_shapes = params.dense_shapes_as_proto
        self._dense_types = params.dense_types
        input_dataset_shape = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)

        self._element_spec = {}

        for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
            self._element_spec[key] = sparse_tensor.SparseTensorSpec(
                input_dataset_shape.concatenate([None]), value_type)

        for (key, value_type, dense_shape) in zip(params.dense_keys,
                                                  params.dense_types,
                                                  params.dense_shapes):
            self._element_spec[key] = tensor_spec.TensorSpec(
                input_dataset_shape.concatenate(dense_shape), value_type)

        for (key, value_type, splits_type) in zip(params.ragged_keys,
                                                  params.ragged_value_types,
                                                  params.ragged_split_types):
            self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
                input_dataset_shape.concatenate([None]), value_type, 1,
                splits_type)

        variant_tensor = (
            gen_experimental_dataset_ops.parse_example_dataset_v2(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._num_parallel_calls,
                self._dense_defaults,
                self._sparse_keys,
                self._dense_keys,
                self._sparse_types,
                self._dense_shapes,
                deterministic=self._deterministic,
                ragged_keys=self._ragged_keys,
                ragged_value_types=self._ragged_value_types,
                ragged_split_types=self._ragged_split_types,
                **self._flat_structure))
        super(_ParseExampleDataset, self).__init__(input_dataset,
                                                   variant_tensor)
Example #19
 def compute_output_signature(self, input_spec):
   output_shape = self.compute_output_shape(input_spec.shape.as_list())
   output_dtype = dtypes.int64
   return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
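
For a layer with this method, signature propagation might look as follows; the shapes and the `layer` instance are hypothetical:

# Assuming `layer` maps each input row to int64 indices without changing
# the outer shape:
in_spec = tensor_spec.TensorSpec(shape=(None, 1), dtype=dtypes.string)
out_spec = layer.compute_output_signature(in_spec)
# out_spec == tensor_spec.TensorSpec(shape=(None, 1), dtype=dtypes.int64)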
Example #20
 def _component_specs(self):
   return (
       tensor_spec.TensorSpec([], dtypes.resource),
       tensor_spec.TensorSpec([], dtypes.variant),
   )
Example #21
 class M1(tracking.Checkpointable):
     @def_function.function(
         input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
     def __call__(self, x):
         return x
Example #22
class DefFunctionTest(test.TestCase, parameterized.TestCase):
    def testNoVariables(self):
        @def_function.function
        def fn(x):
            return 2 * x

        self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)

    def testFailIfVariablesAreCreatedMoreThanOnce(self):
        @def_function.function
        def fn(x):
            return variables.Variable(1.0) + x

        with self.assertRaises(ValueError):
            fn(1.0)

    def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
        state = []

        @def_function.function
        def fn(x):
            state.append(variables.Variable(1.0))
            return state[-1] + x

        with self.assertRaises(ValueError):
            fn(1.0)

    def testRange(self):
        @def_function.function
        def f(unused_x):
            return 1.0

        self.assertAllEqual(f(range(5)), 1.0)

    def testCorrectVariableCreation(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(2.0))
            return state[0] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
        self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

    def testFunctionInitializer(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(lambda: 2.0))
            return state[0] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)

    def testFunctionMultipleVariableInitializer(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(lambda: 2.0))
                state.append(variables.Variable(lambda: 5.0))
            return state[0] * x, state[1] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])

    def testFunctionInitializationFunction(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(2.0))
            return state[0] * x

        init_fn = fn.get_initialization_function(constant_op.constant(1.0))
        self.assertLen(state, 1)
        self.assertFalse(
            resource_variable_ops.var_is_initialized_op(state[0].handle))
        init_fn()
        self.assertEqual(state[0].numpy(), 2.0)

    def testVariableInitializerNotConstant(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(2.0 * x))
            return state[0] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
        self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

    def testLegacyGraphModeVariables(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            state = []

            @def_function.function
            def fn(x):
                if not state:
                    state.append(variables.Variable(2.0))
                return state[0] * x

            result = fn(3.0)

            self.evaluate(variables.global_variables_initializer())
            self.assertAllEqual(sess.run(state[0]), 2.0)
            self.assertAllEqual(self.evaluate(result), 6.0)

    def testLegacyGraphModeVariablesNonTrivialInitializer(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            state = []

            @def_function.function
            def fn(x):
                if not state:
                    two = constant_op.constant(2.0)
                    four = two * two
                    two_again = math_ops.sqrt(four)
                    state.append(variables.Variable(two_again + four))
                return state[0] * x

            result = fn(3.0)

            self.evaluate(variables.global_variables_initializer())
            self.assertAllEqual(sess.run(state[0]), 6.0)
            self.assertAllEqual(self.evaluate(result), 18.0)

    def testLegacyGraphModeInputDependentInitializerFails(self):
        with ops.Graph().as_default():
            state = []

            @def_function.function
            def fn(x):
                if not state:
                    state.append(variables.Variable(2.0 * x))
                return state[0] * x

            with self.assertRaisesRegexp(lift_to_graph.UnliftableError,
                                         r'transitively.* mul .* x'):
                fn(constant_op.constant(3.0))

    def testMethod(self):
        class MyModel(object):
            def __init__(self):
                self.var = None

            @def_function.function
            def apply(self, x):
                if self.var is None:
                    self.var = variables.Variable(2.0)
                return self.var * x

        m0 = MyModel()
        self.assertAllEqual(m0.apply(3.0), 6.0)
        # Calling twice to exercise that we do not recreate variables.
        m0.var.assign(3.0)
        self.assertAllEqual(m0.apply(3.0), 9.0)

        m1 = MyModel()
        self.assertAllEqual(m1.apply(3.0), 6.0)

    def test_functools_partial(self):
        self.assertAllClose(
            3.,
            def_function.function(functools.partial(lambda x, y: x + y, 1.))(
                constant_op.constant(2.)))

    def test_functools_partial_new_default(self):
        def f(x=3, y=7):
            return x + y

        func = def_function.function(functools.partial(f, y=6))
        self.assertEqual(func().numpy(), 9)
        self.assertEqual(func(y=8).numpy(), 11)

    def test_functools_partial_keywords(self):
        def f(x, y):
            return x + y

        func = def_function.function(
            functools.partial(f,
                              x=array_ops.zeros([1]),
                              y=array_ops.zeros([1])))
        self.assertAllEqual(func(), [0.0])

    def test_functools_partial_single_positional(self):
        def f(x, y):
            return x + y

        func = def_function.function(
            functools.partial(f, constant_op.constant(1)))
        self.assertAllEqual(func(5), 6)

    def test_complicated_partial_with_defaults(self):
        def identity(*args):
            return args

        def dynamic_unroll(core_fn,
                           input_sequence,
                           initial_state,
                           sequence_length=None,
                           parallel_iterations=1,
                           swap_memory=False):
            del core_fn
            self.assertIs(None, sequence_length)
            self.assertEqual(1, parallel_iterations)
            self.assertTrue(swap_memory)
            return input_sequence, initial_state

        input_sequence = random_ops.random_uniform([1, 1, 1])
        initial_state = random_ops.random_uniform([1, 1])

        func = def_function.function(
            functools.partial(dynamic_unroll, identity, swap_memory=True))
        func(input_sequence, initial_state)

    def test_unspecified_default_argument(self):
        wrapped = def_function.function(
            lambda x, y=2: x + y,
            input_signature=[tensor_spec.TensorSpec((), dtypes.int32)])
        self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())

    def test_concrete_function_from_signature(self):
        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        def compute(x):
            return 2. * x

        concrete = compute.get_concrete_function()
        self.assertAllClose(1., concrete(constant_op.constant(0.5)))
        concrete = compute.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.float32))
        self.assertAllClose(4., concrete(constant_op.constant(2.)))
        signature_args, _ = concrete.structured_input_signature
        self.assertEqual(
            signature_args,
            (tensor_spec.TensorSpec(None, dtypes.float32, name='x'), ))

    @test_util.run_in_graph_and_eager_modes
    def test_variable_naming(self):
        class HasVars(module.Module):
            def __init__(self):
                self.x = None
                self.y = None
                self.z = None

            @def_function.function
            def make_x(self):
                if self.x is None:
                    self.x = variables.Variable(1., name='v')

            def make_y(self):
                if self.y is None:
                    self.y = variables.Variable(1., name='v')

            def make_z(self):
                if self.z is None:
                    with ops.name_scope('z_scope', skip_on_eager=False):
                        self.z = variables.Variable(1., name='z')

        root = HasVars()
        root.make_x()
        root.make_y()
        root.make_z()
        self.assertEqual('v:0', root.x.name)
        self.assertEqual('z_scope/z:0', root.z.name)

    def test_concrete_function_keyword_arguments(self):
        @def_function.function
        def f(x):
            return x

        conc = f.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.float32, 'y'))
        conc(y=constant_op.constant(3.0))
        signature_args, _ = conc.structured_input_signature
        self.assertEqual('y', signature_args[0].name)

        conc = f.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.float32))
        conc(x=constant_op.constant(3.0))
        signature_args, _ = conc.structured_input_signature
        self.assertEqual('x', signature_args[0].name)

        @def_function.function
        def g(x):
            return x[0]

        conc = g.get_concrete_function(
            [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2])
        conc(z=constant_op.constant(3.0))
        signature_args, _ = conc.structured_input_signature
        self.assertEqual('z', signature_args[0][0].name)

    def test_error_inner_capture(self):
        @def_function.function
        def f(inputs):
            num_steps, _ = inputs.shape[:2]
            outputs = []
            for t in math_ops.range(num_steps):
                outputs.append(inputs[t])
            return outputs

        with self.assertRaisesRegexp(
                errors.InaccessibleTensorError,
                'defined in another function or code block'):
            f(array_ops.zeros(shape=(8, 42, 3)))

    def testRuntimeErrorNotSticky(self):
        @def_function.function
        def fail(i):
            control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])

        fail(constant_op.constant(0))  # OK
        with self.assertRaises(errors.InvalidArgumentError):
            fail(constant_op.constant(1))  # InvalidArgument: "ick"
        fail(constant_op.constant(0))  # OK

    def testUnderscoreName(self):
        @def_function.function
        def f(_):
            return _ + _

        self.assertAllEqual(2.0, f(constant_op.constant(1.0)))

    def test_serialization_signature_cache(self):
        @def_function.function
        def f(x, y):
            return x, y

        f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
        f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))

        signatures_args = set()
        concrete_functions = f._list_all_concrete_functions_for_serialization()
        for concrete_function in concrete_functions:
            args, kwargs = concrete_function.structured_input_signature
            signatures_args.add(args)
            self.assertEqual(dict(), kwargs)

        self.assertEqual(
            signatures_args,
            set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'),
                  tensor_spec.TensorSpec([1], dtypes.float32, name='y')),
                 (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'),
                  tensor_spec.TensorSpec([1], dtypes.int32, name='y')))))

    @test_util.assert_no_garbage_created
    def testFunctionReferenceCycles(self):
        fn = def_function.function(lambda x: 2. * x)
        fn(constant_op.constant(4.0))
        weak_fn = weakref.ref(fn)
        del fn
        # Tests that the weak reference we made to the function is now dead, which
        # means the object has been deleted. This should be true as long as the
        # function itself is not involved in a reference cycle.
        self.assertIs(None, weak_fn())

    @test_util.assert_no_garbage_created
    def testMethodReferenceCycles(self):
        has_decorated_method = _HasDecoratedMethod()
        has_decorated_method.f(constant_op.constant(5.))
        weak_fn = weakref.ref(has_decorated_method.f)
        del has_decorated_method
        # Tests that the weak reference we made to the function is now dead, which
        # means the object has been deleted. This should be true as long as the
        # function itself is not involved in a reference cycle.
        self.assertIs(None, weak_fn())

    @test_util.assert_no_new_pyobjects_executing_eagerly
    def testErrorMessageWhenGraphTensorIsPassedToEager(self):
        @def_function.function
        def failing_function():
            a = constant_op.constant(1.)

            with ops.init_scope():
                _ = a + a

        with self.assertRaisesRegexp(
                TypeError,
                re.compile('An op outside of the function.*passed.*Const',
                           re.DOTALL)):
            failing_function()

    def testNonUniqueNamesGetConcreteFunction(self):
        @def_function.function
        def non_unique_arg_names(x, **kwargs):
            a, b, c = x
            d = kwargs['d']
            return a + b + c + d

        concrete = non_unique_arg_names.get_concrete_function(
            (tensor_spec.TensorSpec(None, dtypes.float32),
             tensor_spec.TensorSpec(None, dtypes.float32),
             tensor_spec.TensorSpec(None, dtypes.float32)),
            d=tensor_spec.TensorSpec(None, dtypes.float32))
        self.assertAllClose(
            10.,
            concrete(x=constant_op.constant(1.),
                     x_1=constant_op.constant(2.),
                     x_2=constant_op.constant(3.),
                     d=constant_op.constant(4.)))
        self.assertAllClose(
            10.,
            concrete(constant_op.constant(1.), constant_op.constant(2.),
                     constant_op.constant(3.), constant_op.constant(4.)))

    def testVariableCreatorScope(self):
        created_variables = []
        captured_variables = []

        @def_function.function
        def f():
            if not created_variables:
                created_variables.append(variables.Variable(1.))
            return created_variables[0] + 1.

        def capture_creator(next_creator, **kwargs):
            created = next_creator(**kwargs)
            captured_variables.append(created)
            return created

        with variable_scope.variable_creator_scope(capture_creator):
            f()
        self.assertEqual(created_variables, captured_variables)

    def testVarAlreadyInitializedNoClobbering(self):
        v_holder = []

        @def_function.function
        def add_var(x):
            if not v_holder:
                v = variables.Variable([1., 2.])
                v_holder.append(v)
                already_initialized = variables.Variable(3.)
                with ops.init_scope():
                    already_initialized.assign(10.)
                v_holder.append(already_initialized)
            return v_holder[0] + v_holder[1] + x

        add_var.get_concrete_function(constant_op.constant(2.))
        self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))

    def testSameVariableTwice(self):
        v = variables.Variable(1.0)

        @def_function.function
        def add(a, b):
            return a + b

        self.assertAllEqual(add(v, v), 2.0)

    def testVariableUpdate(self):
        v1 = variables.Variable(1.0)
        v2 = variables.Variable(2.0)
        v3 = variables.Variable(4, dtype=dtypes.int32)

        trace_count = [0]

        @def_function.function
        def double_variable(x):
            trace_count[0] += 1
            x.assign_add(x.read_value())

        self.assertEqual(trace_count[0], 0)
        double_variable(v1)
        self.assertEqual(trace_count[0], 1)
        self.assertEqual(self.evaluate(v1), 2.0)
        double_variable(v2)
        self.assertEqual(trace_count[0], 1 if ops.Tensor._USE_EQUALITY else 2)
        self.assertEqual(self.evaluate(v2), 4.0)
        double_variable(v3)
        self.assertEqual(trace_count[0], 2 if ops.Tensor._USE_EQUALITY else 3)
        self.assertEqual(self.evaluate(v3), 8)
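
    # Editor's sketch (hedged, not from the original test): the trace counts
    # above follow from the cache keying on each variable's dtype and shape, so
    # only the int32 variable forces a second trace when
    # ops.Tensor._USE_EQUALITY is enabled. A standalone version:
    def _sketch_retrace_on_dtype_change(self):
        traces = []

        @def_function.function
        def double(x):
            traces.append(None)  # one append per trace, not per call
            x.assign_add(x.read_value())

        double(variables.Variable(1.0))
        double(variables.Variable(2.0))  # same dtype/shape: trace reused
        double(variables.Variable(1, dtype=dtypes.int32))  # retrace on dtype
        return len(traces)  # expected 2 under the assumptions above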

    def testShapeCache(self):
        @def_function.function
        def func(x):
            return 2 * x

        func_a = func.get_concrete_function(
            tensor_spec.TensorSpec([None], dtypes.int32))
        func_b = func.get_concrete_function(
            tensor_spec.TensorSpec([None], dtypes.int32))

        self.assertIs(func_a, func_b)
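
    # Editor's note (hedged): equal TensorSpecs map to the same cache entry, so
    # get_concrete_function returns the identical object, hence assertIs.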

    def testInitializationInNestedCall(self):
        v_holder = []

        @def_function.function
        def add_var(x):
            if not v_holder:
                v = variables.Variable([1., 2.])
                v_holder.append(v)
                already_initialized = variables.Variable(3.)
                with ops.init_scope():
                    already_initialized.assign(10.)
                v_holder.append(already_initialized)
            return v_holder[0] + v_holder[1] + x

        @def_function.function
        def wrapper(x):
            return add_var(x)

        self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
        v_holder[1].assign(11.)
        self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))

    @test_util.run_gpu_only
    def testDeviceAnnotationRespected(self):
        a = []

        @def_function.function()
        def create_variable():
            with ops.init_scope():
                initial_value = random_ops.random_uniform((2, 2),
                                                          maxval=1000000,
                                                          dtype=dtypes.int64)

            if not a:
                with ops.device('CPU:0'):
                    a.append(
                        resource_variable_ops.ResourceVariable(initial_value))

            return a[0].read_value()

        create_variable()
        self.assertRegexpMatches(a[0].device, 'CPU')

    @test_util.run_gpu_only
    def testDeviceAnnotationForInitializerRespected(self):
        a = []
        initial_value = []

        def initial_value_fn():
            initial_value.append(random_ops.random_uniform((2, 3)))
            return initial_value[0]

        @def_function.function()
        def create_variable():
            with ops.init_scope():
                if not a:
                    a.append(variables.Variable(initial_value_fn))

        with ops.device('CPU:0'):
            create_variable()
        self.assertRegexpMatches(a[0].device, 'CPU')
        self.assertRegexpMatches(initial_value[0].device, 'CPU')

    def testDecorate(self):
        func = def_function.function(lambda: 1)

        def decorator(f):
            return lambda: 1 + f()

        func._decorate(decorator)
        self.assertEqual(func().numpy(), 2)

    @parameterized.parameters(*itertools.product(
        (None, (tensor_spec.TensorSpec([]), )),  # input_signature
        (True, False),  # autograph
        (None, converter.Feature.ALL),  # autograph_options
        (None, 'foo.bar'),  # implements
        (None, True, False),  # relax_shapes
        (True, False),  # compile
        (True, False),  # override_function
    ))
    def testClone(self, input_signature, autograph, autograph_options,
                  implements, relax_shapes, compile_, override_function):
        original_py_function = lambda x: x

        compile_ = False
        func = def_function.function(
            func=original_py_function,
            input_signature=input_signature,
            autograph=autograph,
            experimental_implements=implements,
            experimental_autograph_options=autograph_options,
            experimental_relax_shapes=relax_shapes,
            experimental_compile=compile_)

        if override_function:
            cloned_py_function = lambda x: x + 1
        else:
            cloned_py_function = original_py_function

        cloned = func._clone(python_function=cloned_py_function)

        self.assertEqual(cloned_py_function, cloned._python_function)
        self.assertEqual(func._name, cloned._name)
        self.assertEqual(input_signature, cloned._input_signature)
        self.assertEqual(autograph, cloned._autograph)
        self.assertEqual(implements, cloned._implements)
        self.assertEqual(autograph_options,
                         cloned._experimental_autograph_options)
        self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
        self.assertEqual(compile_, cloned._experimental_compile)

        # This test does not run with XLA JIT support linked in so we can only check
        # the output of the function if compile is disabled.
        if not compile_:
            x = array_ops.zeros([])
            self.assertEqual(self.evaluate(cloned(x)),
                             self.evaluate(cloned_py_function(x)))

    def testLiftPlaceholderInitializedVariable(self):
        with ops.Graph().as_default():
            var_list = []

            @def_function.function
            def use_variable():
                if not var_list:
                    initial_value = array_ops.placeholder(shape=[],
                                                          dtype=dtypes.float32)
                    v = variables.Variable(initial_value)
                    var_list.append(v)
                return var_list[0] + 1.

            var_plus_one = use_variable()
            with self.session() as session:
                init_op = var_list[0].initializer
                session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
                self.assertEqual(3., session.run(var_plus_one))

    def testDecorate_rejectedAfterTrace(self):
        func = def_function.function(lambda: 1)
        self.assertEqual(func().numpy(), 1)
        msg = 'Functions cannot be decorated after they have been traced.'
        with self.assertRaisesRegexp(ValueError, msg):
            func._decorate(lambda f: f)

    def testGetConcreteFunctionGraphLifetime(self):
        @def_function.function
        def func():
            pass

        graph = func.get_concrete_function().graph
        del func

        # If the graph is deleted, then an exception is raised on reading `captures`
        self.assertEmpty(graph.captures)

    @parameterized.parameters(*itertools.product(
        (None, (tensor_spec.TensorSpec([]), )),  # input_signature
        (True, False),  # autograph
        (None, converter.Feature.ALL),  # autograph_options
        (None, 'foo.bar'),  # implements
        (None, True, False),  # relax_shapes
    ))
    def test_pickle(self, input_signature, autograph, autograph_options,
                    implements, relax_shapes):
        """@function objects can be pickled and unpickled."""
        original_py_function = undecorated_function

        func = def_function.function(
            func=original_py_function,
            input_signature=input_signature,
            autograph=autograph,
            experimental_implements=implements,
            experimental_autograph_options=autograph_options,
            experimental_relax_shapes=relax_shapes,
        )

        cloned = pickle.loads(pickle.dumps(func))

        self.assertEqual(func._name, cloned._name)
        self.assertEqual(input_signature, cloned._input_signature)
        self.assertEqual(autograph, cloned._autograph)
        self.assertEqual(implements, cloned._implements)
        self.assertEqual(autograph_options,
                         cloned._experimental_autograph_options)
        self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)

        x = array_ops.ones([])
        self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x)))

    def test_frequent_retracing_warning(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        @def_function.function
        def f(x):
            return x

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            self.assertEmpty(logs.output)
            f(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])
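
    # Editor's note (hedged, not part of the original test): the warning fires
    # because each distinct Python int becomes part of the cache key. A common
    # mitigation is pinning an input_signature so every call shares one trace:
    def _sketch_single_trace_with_input_signature(self):
        @def_function.function(
            input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
        def f(x):
            return x

        for i in range(5):
            f(constant_op.constant(i))  # one concrete function for all calls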

    def test_frequent_retracing_warning_lambda(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        f = def_function.function(lambda x: x)

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            f(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_method(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        class Foo(object):
            @def_function.function
            def f(self, x):
                return x

        f = Foo().f

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            f(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_two_independent_tf_functions(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        @def_function.function
        def f(x):
            return x

        @def_function.function
        def g(x):
            return x

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            g(1)
            g(2)
            g(3)
            g(4)
            g(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_nested(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        @def_function.function
        def inner(x):
            return x + 1

        @def_function.function
        def outer1(x):
            return inner(x) * 2

        @def_function.function
        def outer2(x):
            return inner(x) * 3

        with self.assertLogs(level='WARN') as logs:
            inner(1)
            inner(2)
            inner(3)
            inner(4)

            outer1(5)
            outer1(6)
            outer1(7)
            outer1(8)

            outer2(9)
            outer2(10)
            outer2(11)
            outer2(12)

            self.assertEmpty(logs.output)

            outer2(13)

            self.assertLen(logs.output, 1)
            self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_on_reinstantiation(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        with self.assertLogs(level='WARN') as logs:
            for i in range(5):

                @def_function.function
                def f(x):
                    return x

                f(i)

                if i < 4:
                    self.assertEmpty(logs.output)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])
Example No. 23
class DefFunctionTest(test.TestCase, parameterized.TestCase):
    def testNoVariables(self):
        @def_function.function
        def fn(x):
            return 2 * x

        self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)

    def testFailIfVariablesAreCreatedMoreThanOnce(self):
        @def_function.function
        def fn(x):
            return variables.Variable(1.0) + x

        with self.assertRaises(ValueError):
            fn(1.0)

    def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
        state = []

        @def_function.function
        def fn(x):
            state.append(variables.Variable(1.0))
            return state[-1] + x

        with self.assertRaises(ValueError):
            fn(1.0)

    def testRange(self):
        @def_function.function
        def f(unused_x):
            return 1.0

        self.assertAllEqual(f(range(5)), 1.0)

    def testCorrectVariableCreation(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(2.0))
            return state[0] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
        self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

    def testFunctionInitializer(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(lambda: 2.0))
            return state[0] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)

    def testFunctionMultipleVariableInitializer(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(lambda: 2.0))
                state.append(variables.Variable(lambda: 5.0))
            return state[0] * x, state[1] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])

    def testFunctionInitializationFunction(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(2.0))
            return state[0] * x

        init_fn = fn.get_initialization_function(constant_op.constant(1.0))
        self.assertLen(state, 1)
        self.assertFalse(
            resource_variable_ops.var_is_initialized_op(state[0].handle))
        init_fn()
        self.assertEqual(state[0].numpy(), 2.0)

    def testVariableInitializerNotConstant(self):

        state = []

        @def_function.function
        def fn(x):
            if not state:
                state.append(variables.Variable(2.0 * x))
            return state[0] * x

        self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
        self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

    def testLegacyGraphModeVariables(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            state = []

            @def_function.function
            def fn(x):
                if not state:
                    state.append(variables.Variable(2.0))
                return state[0] * x

            result = fn(3.0)

            self.evaluate(variables.global_variables_initializer())
            self.assertAllEqual(sess.run(state[0]), 2.0)
            self.assertAllEqual(self.evaluate(result), 6.0)

    def testLegacyGraphModeVariablesNonTrivialInitializer(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            state = []

            @def_function.function
            def fn(x):
                if not state:
                    two = constant_op.constant(2.0)
                    four = two * two
                    two_again = math_ops.sqrt(four)
                    state.append(variables.Variable(two_again + four))
                return state[0] * x

            result = fn(3.0)

            self.evaluate(variables.global_variables_initializer())
            self.assertAllEqual(sess.run(state[0]), 6.0)
            self.assertAllEqual(self.evaluate(result), 18.0)

    def testLegacyGraphModeInputDependentInitializerFails(self):
        with ops.Graph().as_default():
            state = []

            @def_function.function
            def fn(x):
                if not state:
                    state.append(variables.Variable(2.0 * x))
                return state[0] * x

            with self.assertRaisesRegex(lift_to_graph.UnliftableError,
                                        r'transitively.* mul .* x'):
                fn(constant_op.constant(3.0))

    def testMethod(self):
        class MyModel(object):
            def __init__(self):
                self.var = None

            @def_function.function
            def apply(self, x):
                if self.var is None:
                    self.var = variables.Variable(2.0)
                return self.var * x

        m0 = MyModel()
        self.assertAllEqual(m0.apply(3.0), 6.0)
        # Calling twice to exercise that we do not recreate variables.
        m0.var.assign(3.0)
        self.assertAllEqual(m0.apply(3.0), 9.0)

        m1 = MyModel()
        self.assertAllEqual(m1.apply(3.0), 6.0)

    @unittest.expectedFailure
    def testMethodAllowDynamicVariableWithoutGuards(self):
        class Foo:
            def __init__(self):
                self._var = 0

            def __call__(self, val):
                self.compute(val)
                return self._var

            @def_function.function
            def compute(self, val):
                self._var = variables.Variable(val)

        def_function.ALLOW_DYNAMIC_VARIABLE_CREATION = True
        foo = Foo()
        self.assertAllEqual(foo(0.3), 0.3)
        self.assertAllEqual(
            foo(0.9), 0.9,
            'https://github.com/tensorflow/tensorflow/issues/27120')

    def testMethodAllowDynamicVariable(self):
        class Foo:
            def __init__(self):
                self._flag_keyed_vars = {}
                self.trace_count = 0

            def __call__(self, var_creation_flag):
                self.compute(var_creation_flag)
                return self._flag_keyed_vars[var_creation_flag]

            @def_function.function
            def compute(self, var_creation_flag):
                self.trace_count += 1
                if var_creation_flag not in self._flag_keyed_vars:
                    if var_creation_flag:
                        self._flag_keyed_vars[
                            var_creation_flag] = variables.Variable(1.0)
                    else:
                        self._flag_keyed_vars[
                            var_creation_flag] = variables.Variable(2.0)

        def_function.ALLOW_DYNAMIC_VARIABLE_CREATION = True
        foo = Foo()
        self.assertAllEqual(foo(True), 1.0)
        self.assertEqual(foo.trace_count, 2)
        self.assertAllEqual(foo(True), 1.0)
        self.assertEqual(foo.trace_count, 2)
        self.assertAllEqual(foo(False), 2.0)
        self.assertEqual(foo.trace_count, 3)
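
    # Editor's note (hedged): the first call appears to trace twice because
    # variable creation adds an initialization trace; with
    # ALLOW_DYNAMIC_VARIABLE_CREATION enabled, the later retrace for the new
    # flag value may create a fresh variable instead of raising (contrast the
    # next test, where the same flow fails with the switch off).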

    def testMethodNotAllowDynamicVariable(self):
        class Foo:
            def __init__(self):
                self._flag_keyed_vars = {}
                self.trace_count = 0

            def __call__(self, var_creation_flag):
                self.compute(var_creation_flag)
                return self._flag_keyed_vars[var_creation_flag]

            @def_function.function
            def compute(self, var_creation_flag):
                self.trace_count += 1
                if var_creation_flag not in self._flag_keyed_vars:
                    if var_creation_flag:
                        self._flag_keyed_vars[
                            var_creation_flag] = variables.Variable(1.0)
                    else:
                        self._flag_keyed_vars[
                            var_creation_flag] = variables.Variable(2.0)

        def_function.ALLOW_DYNAMIC_VARIABLE_CREATION = False
        foo = Foo()
        self.assertAllEqual(foo(True), 1.0)
        self.assertEqual(foo.trace_count, 2)
        self.assertAllEqual(foo(True), 1.0)
        self.assertEqual(foo.trace_count, 2)
        msg = 'singleton tf.Variable.*on the first call'
        with self.assertRaisesRegex(ValueError, msg):
            foo(False)
        self.assertEqual(foo.trace_count, 3)

    def testMethodExtensionType(self):
        class MaskedTensor(extension_type.ExtensionType):
            values: ops.Tensor
            mask: ops.Tensor

            @def_function.function
            def with_default(self, default_value):
                return array_ops.where_v2(self.mask, self.values,
                                          default_value)

            @def_function.function
            def sum(self):
                # Use a loop & conditional to test that autograph works correctly.
                result = 0
                for i in range(array_ops.size(self.values)):
                    if self.mask[i]:
                        result += self.values[i]
                return result

        mt = MaskedTensor([1, 2, 3], [True, False, True])
        self.assertAllEqual(mt.with_default(-1), [1, -1, 3])
        self.assertAllEqual(mt.sum(), 4)

    def test_functools_partial(self):
        self.assertAllClose(
            3.,
            def_function.function(functools.partial(lambda x, y: x + y, 1.))(
                constant_op.constant(2.)))

    def test_functools_partial_new_default(self):
        def f(x=3, y=7):
            return x + y

        func = def_function.function(functools.partial(f, y=6))
        self.assertEqual(func().numpy(), 9)
        self.assertEqual(func(y=8).numpy(), 11)

    def test_functools_partial_keywords(self):
        def f(x, y):
            return x + y

        func = def_function.function(
            functools.partial(f,
                              x=array_ops.zeros([1]),
                              y=array_ops.zeros([1])))
        self.assertAllEqual(func(), [0.0])

    def test_functools_partial_single_positional(self):
        def f(x, y):
            return x + y

        func = def_function.function(
            functools.partial(f, constant_op.constant(1)))
        self.assertAllEqual(func(5), 6)

    def test_complicated_partial_with_defaults(self):
        def identity(*args):
            return args

        def dynamic_unroll(core_fn,
                           input_sequence,
                           initial_state,
                           sequence_length=None,
                           parallel_iterations=1,
                           swap_memory=False):
            del core_fn
            self.assertIs(None, sequence_length)
            self.assertEqual(1, parallel_iterations)
            self.assertTrue(swap_memory)
            return input_sequence, initial_state

        input_sequence = random_ops.random_uniform([1, 1, 1])
        initial_state = random_ops.random_uniform([1, 1])

        func = def_function.function(
            functools.partial(dynamic_unroll, identity, swap_memory=True))
        func(input_sequence, initial_state)

    def test_unspecified_default_argument(self):
        wrapped = def_function.function(
            lambda x, y=2: x + y,
            input_signature=[tensor_spec.TensorSpec((), dtypes.int32)])
        self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())

    def test_concrete_function_from_signature(self):
        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        def compute(x):
            return 2. * x

        concrete = compute.get_concrete_function()
        self.assertAllClose(1., concrete(constant_op.constant(0.5)))
        concrete = compute.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.float32))
        self.assertAllClose(4., concrete(constant_op.constant(2.)))
        signature_args, _ = concrete.structured_input_signature
        self.assertEqual(
            signature_args,
            (tensor_spec.TensorSpec(None, dtypes.float32, name='x'), ))

    def testInputSignatureMissingTensorSpecsMethod(self):
        class MyModule(module.Module):
            def f1(self, arg1, arg2, arg3):
                pass

            def f2(self, arg1, arg2, arg3, **kwargs):
                pass

            def f3(self, arg1, arg2, arg3, arg4=4, **kwargs):
                pass

            def f4(self, arg1, arg2, arg3, *args):
                pass

            def f5(self, arg1, arg2, arg3, *args, **kwargs):
                pass

            def f6(self, arg1, arg4=4, **kwargs):
                return arg1 + arg4

        m = MyModule()
        tf_func_dec = def_function.function(
            input_signature=(tensor_spec.TensorSpec([], dtypes.int32), ))
        error_msg = 'TensorSpecs are still required.*arg2.*arg3'
        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(m.f1)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(m.f2)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(m.f3)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(m.f4)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(m.f5)(1, 2, 3)

        self.assertEqual(tf_func_dec(m.f6)(1).numpy(), 5)

    def testInputSignatureMissingTensorSpecsFunction(self):
        tf_func_dec = def_function.function(
            input_signature=(tensor_spec.TensorSpec([], dtypes.int32), ))
        error_msg = 'TensorSpecs are still required.*arg2.*arg3'

        # pylint: disable=unused-argument
        def f1(arg1, arg2, arg3):
            pass

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(f1)(1, 2, 3)

        def f2(arg1, arg2, arg3, **kwargs):
            pass

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(f2)(1, 2, 3)

        def f3(arg1, arg2, arg3, arg4=4, **kwargs):
            pass

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(f3)(1, 2, 3)

        def f4(arg1, arg2, arg3, *args):
            pass

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(f4)(1, 2, 3)

        def f5(arg1, arg2, arg3, *args, **kwargs):
            pass

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(f5)(1, 2, 3)
        # pylint: enable=unused-argument

        def f6(arg1, arg4=4, **kwargs):
            return arg1 + arg4

        self.assertEqual(tf_func_dec(f6)(1).numpy(), 5)
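
    # Editor's summary (hedged, not from the original tests): an
    # input_signature must provide a TensorSpec for every parameter without a
    # bound default; parameters that do have defaults (like arg4 above) may be
    # left out. A minimal self-contained version of the passing case:
    def _sketch_signature_covers_required_args_only(self):
        @def_function.function(
            input_signature=(tensor_spec.TensorSpec([], dtypes.int32), ))
        def g(arg1, arg4=4):
            return arg1 + arg4

        return g(1)  # evaluates to 5, mirroring the f6 assertions above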

    def testInputSignatureMissingTensorSpecsLambdaFunction(self):
        tf_func_dec = def_function.function(
            input_signature=(tensor_spec.TensorSpec([], dtypes.int32), ))
        error_msg = 'TensorSpecs are still required.*arg2.*arg3'
        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, **kwargs: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, arg4=4, **kwargs: None)(1, 2,
                                                                         3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, *args: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, *args, **kwargs: None)(1, 2,
                                                                        3)

        self.assertEqual(
            tf_func_dec(lambda arg1, arg4=4, **kwargs: arg1 + arg4)(1).numpy(),
            5)

    @parameterized.named_parameters(('_method', 'method'),
                                    ('_function', 'function'),
                                    ('_lambda_function', 'lambda_function'))
    def testInputSignaturePartialFuncMissingTensorSpecs(self, func_type):
        if func_type == 'method':

            class MyModule(module.Module):
                def f(self, arg1, arg2, arg3, arg4=4):
                    return arg1 + arg2 + arg3 + arg4

            f = MyModule().f
        elif func_type == 'function':

            def f(arg1, arg2, arg3, arg4=4):
                return arg1 + arg2 + arg3 + arg4
        else:  # lambda_function
            f = lambda arg1, arg2, arg3, arg4=4: arg1 + arg2 + arg3 + arg4

        tf_func_dec = def_function.function(
            input_signature=(tensor_spec.TensorSpec([], dtypes.int32), ))
        with self.assertRaisesRegex(TypeError,
                                    'TensorSpecs are still required.*arg3'):
            tf_func_dec(functools.partial(f, 1))(2, 3)

        with self.assertRaisesRegex(
                TypeError, 'TensorSpecs are still required.*arg2.*arg3'):
            tf_func_dec(functools.partial(f, arg4=5))(1, 2, 3)

        with self.assertRaisesRegex(TypeError,
                                    'TensorSpecs are still required.*arg3'):
            tf_func_dec(functools.partial(f, 1, arg4=5))(2, 3)

        self.assertAllEqual(
            tf_func_dec(functools.partial(f, 1, 2, arg4=5))(3),
            array_ops.constant(11))

    @test_util.run_in_graph_and_eager_modes
    def test_variable_naming(self):
        class HasVars(module.Module):
            def __init__(self):
                self.x = None
                self.y = None
                self.z = None

            @def_function.function
            def make_x(self):
                if self.x is None:
                    self.x = variables.Variable(1., name='v')

            def make_y(self):
                if self.y is None:
                    self.y = variables.Variable(1., name='v')

            def make_z(self):
                if self.z is None:
                    with ops.name_scope('z_scope', skip_on_eager=False):
                        self.z = variables.Variable(1., name='z')

        root = HasVars()
        root.make_x()
        root.make_y()
        root.make_z()
        self.assertEqual('v:0', root.x.name)
        self.assertEqual('z_scope/z:0', root.z.name)

    def test_concrete_function_keyword_arguments(self):
        @def_function.function
        def f(x):
            return x

        conc = f.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.float32, 'y'))
        conc(y=constant_op.constant(3.0))
        signature_args, _ = conc.structured_input_signature
        self.assertEqual('y', signature_args[0].name)

        conc = f.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.float32))
        conc(x=constant_op.constant(3.0))
        signature_args, _ = conc.structured_input_signature
        self.assertEqual('x', signature_args[0].name)

        @def_function.function
        def g(x):
            return x[0]

        conc = g.get_concrete_function(
            [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2])
        conc(z=constant_op.constant(3.0))
        signature_args, _ = conc.structured_input_signature
        self.assertEqual('z', signature_args[0][0].name)

    def test_error_inner_capture(self):
        @def_function.function
        def f(inputs):
            num_steps, _ = inputs.shape[:2]
            outputs = []
            for t in math_ops.range(num_steps):
                outputs.append(inputs[t])
            return outputs

        with self.assertRaisesRegex(
                errors.InaccessibleTensorError,
                'defined in another function or code block'):
            f(array_ops.zeros(shape=(8, 42, 3)))

    def testRuntimeErrorNotSticky(self):
        @def_function.function
        def fail(i):
            control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])

        fail(constant_op.constant(0))  # OK
        with self.assertRaises(errors.InvalidArgumentError):
            fail(constant_op.constant(1))  # InvalidArgument: "ick"
        fail(constant_op.constant(0))  # OK

    def testUnderscoreName(self):
        @def_function.function
        def f(_):
            return _ + _

        self.assertAllEqual(2.0, f(constant_op.constant(1.0)))

    def test_serialization_signature_cache(self):
        @def_function.function
        def f(x, y):
            return x, y

        f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
        f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))

        signatures_args = set()
        concrete_functions = f._list_all_concrete_functions_for_serialization()
        for concrete_function in concrete_functions:
            args, kwargs = concrete_function.structured_input_signature
            signatures_args.add(args)
            self.assertEqual(dict(), kwargs)

        self.assertEqual(
            signatures_args,
            set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'),
                  tensor_spec.TensorSpec([1], dtypes.float32, name='y')),
                 (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'),
                  tensor_spec.TensorSpec([1], dtypes.int32, name='y')))))

    @test_util.assert_no_garbage_created
    def testFunctionReferenceCycles(self):
        fn = def_function.function(lambda x: 2. * x)
        fn(constant_op.constant(4.0))
        weak_fn = weakref.ref(fn)
        del fn
        # Tests that the weak reference we made to the function is now dead, which
        # means the object has been deleted. This should be true as long as the
        # function itself is not involved in a reference cycle.
        self.assertIs(None, weak_fn())

    @test_util.assert_no_garbage_created
    def testMethodReferenceCycles(self):
        has_decorated_method = _HasDecoratedMethod()
        has_decorated_method.f(constant_op.constant(5.))
        weak_fn = weakref.ref(has_decorated_method.f)
        del has_decorated_method
        # Tests that the weak reference we made to the function is now dead, which
        # means the object has been deleted. This should be true as long as the
        # function itself is not involved in a reference cycle.
        self.assertIs(None, weak_fn())

    @test_util.assert_no_new_pyobjects_executing_eagerly
    def testErrorMessageWhenGraphTensorIsPassedToEager(self):
        @def_function.function
        def failing_function():
            a = constant_op.constant(1.)

            with ops.init_scope():
                _ = a + a

        with self.assertRaisesRegex(
                TypeError,
                re.compile('An op outside of the function.*passed.*Const',
                           re.DOTALL)):
            failing_function()

    def testNonUniqueNamesGetConcreteFunction(self):
        @def_function.function
        def non_unique_arg_names(x, **kwargs):
            a, b, c = x
            d = kwargs['d']
            return a + b + c + d

        concrete = non_unique_arg_names.get_concrete_function(
            (tensor_spec.TensorSpec(None, dtypes.float32),
             tensor_spec.TensorSpec(None, dtypes.float32),
             tensor_spec.TensorSpec(None, dtypes.float32)),
            d=tensor_spec.TensorSpec(None, dtypes.float32))
        self.assertAllClose(
            10.,
            concrete(x=constant_op.constant(1.),
                     x_1=constant_op.constant(2.),
                     x_2=constant_op.constant(3.),
                     d=constant_op.constant(4.)))
        self.assertAllClose(
            10.,
            concrete(constant_op.constant(1.), constant_op.constant(2.),
                     constant_op.constant(3.), constant_op.constant(4.)))

    def testVariableCreatorScope(self):
        created_variables = []
        captured_variables = []

        @def_function.function
        def f():
            if not created_variables:
                created_variables.append(variables.Variable(1.))
            return created_variables[0] + 1.

        def capture_creator(next_creator, **kwargs):
            created = next_creator(**kwargs)
            captured_variables.append(created)
            return created

        with variable_scope.variable_creator_scope(capture_creator):
            f()
        self.assertEqual(created_variables, captured_variables)

    def testVarAlreadyInitializedNoClobbering(self):
        v_holder = []

        @def_function.function
        def add_var(x):
            if not v_holder:
                v = variables.Variable([1., 2.])
                v_holder.append(v)
                already_initialized = variables.Variable(3.)
                with ops.init_scope():
                    already_initialized.assign(10.)
                v_holder.append(already_initialized)
            return v_holder[0] + v_holder[1] + x

        add_var.get_concrete_function(constant_op.constant(2.))
        self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))

    def testSameVariableTwice(self):
        v = variables.Variable(1.0)

        @def_function.function
        def add(a, b):
            return a + b

        self.assertAllEqual(add(v, v), 2.0)

    def testVariableUpdate(self):
        v1 = variables.Variable(1.0)
        v2 = variables.Variable(2.0)
        v3 = variables.Variable(4, dtype=dtypes.int32)

        trace_count = [0]

        @def_function.function
        def double_variable(x):
            trace_count[0] += 1
            x.assign_add(x.read_value())

        self.assertEqual(trace_count[0], 0)
        double_variable(v1)
        self.assertEqual(trace_count[0], 1)
        self.assertEqual(self.evaluate(v1), 2.0)
        double_variable(v2)
        # No retracing because v2's data type and shape are the same as v1
        self.assertEqual(trace_count[0], 1)
        self.assertEqual(self.evaluate(v2), 4.0)
        double_variable(v3)
        # Retracing because of data type change
        self.assertEqual(trace_count[0], 2)
        self.assertEqual(self.evaluate(v3), 8)

    def testShapeCache(self):
        @def_function.function
        def func(x):
            return 2 * x

        func_a = func.get_concrete_function(
            tensor_spec.TensorSpec([None], dtypes.int32))
        func_b = func.get_concrete_function(
            tensor_spec.TensorSpec([None], dtypes.int32))

        self.assertIs(func_a, func_b)

    def testCacheWithinSaveContext(self):
        @def_function.function
        def func(x):
            return 2 * x

        func_a = func.get_concrete_function(constant_op.constant(2.))
        func_b = func.get_concrete_function(constant_op.constant(2.))

        self.assertIs(func_a, func_b)

        with save_context.save_context(
                save_options.SaveOptions(
                    experimental_variable_policy=save_options.VariablePolicy.
                    EXPAND_DISTRIBUTED_VARIABLES)):
            func_c = func.get_concrete_function(constant_op.constant(2.))

        with save_context.save_context(
                save_options.SaveOptions(
                    experimental_variable_policy=save_options.VariablePolicy.
                    NONE)):
            func_d = func.get_concrete_function(constant_op.constant(2.))

        self.assertIsNot(func_a, func_c)
        self.assertIsNot(func_a, func_d)
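
    # Editor's note (hedged): the concrete-function cache appears to key on the
    # active save context as well, so tracing under different SaveOptions
    # yields distinct concrete functions, as the assertions above show.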

    def testInitializationInNestedCall(self):
        v_holder = []

        @def_function.function
        def add_var(x):
            if not v_holder:
                v = variables.Variable([1., 2.])
                v_holder.append(v)
                already_initialized = variables.Variable(3.)
                with ops.init_scope():
                    already_initialized.assign(10.)
                v_holder.append(already_initialized)
            return v_holder[0] + v_holder[1] + x

        @def_function.function
        def wrapper(x):
            return add_var(x)

        self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
        v_holder[1].assign(11.)
        self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))

    @test_util.run_gpu_only
    def testDeviceAnnotationRespected(self):
        a = []

        @def_function.function()
        def create_variable():
            with ops.init_scope():
                initial_value = random_ops.random_uniform((2, 2),
                                                          maxval=1000000,
                                                          dtype=dtypes.int64)

            if not a:
                with ops.device('CPU:0'):
                    a.append(
                        resource_variable_ops.ResourceVariable(initial_value))

            return a[0].read_value()

        create_variable()
        self.assertRegex(a[0].device, 'CPU')

    @test_util.run_gpu_only
    def testDeviceAnnotationForInitializerRespected(self):
        a = []
        initial_value = []

        def initial_value_fn():
            initial_value.append(random_ops.random_uniform((2, 3)))
            return initial_value[0]

        @def_function.function()
        def create_variable():
            with ops.init_scope():
                if not a:
                    a.append(variables.Variable(initial_value_fn))

        with ops.device('CPU:0'):
            create_variable()
        self.assertRegex(a[0].device, 'CPU')
        self.assertRegex(initial_value[0].device, 'CPU')

    def testDecorate(self):
        func = def_function.function(lambda: 1)

        def decorator(f):
            return lambda: 1 + f()

        func._decorate(decorator)
        self.assertEqual(func().numpy(), 2)

    @parameterized.parameters(*itertools.product(
        (None, (tensor_spec.TensorSpec([]), )),  # input_signature
        (True, False),  # autograph
        (None, converter.Feature.ALL),  # autograph_options
        (None, 'foo.bar'),  # implements
        (None, True, False),  # relax_shapes
        (True, False),  # compile
        (True, False),  # override_function
    ))
    def testClone(self, input_signature, autograph, autograph_options,
                  implements, relax_shapes, compile_, override_function):
        original_py_function = lambda x: x

        compile_ = False
        func = def_function.function(
            func=original_py_function,
            input_signature=input_signature,
            autograph=autograph,
            experimental_implements=implements,
            experimental_autograph_options=autograph_options,
            experimental_relax_shapes=relax_shapes,
            jit_compile=compile_)

        if override_function:
            cloned_py_function = lambda x: x + 1
        else:
            cloned_py_function = original_py_function

        cloned = func._clone(python_function=cloned_py_function)

        self.assertEqual(cloned_py_function, cloned._python_function)
        self.assertEqual(func._name, cloned._name)
        self.assertEqual(input_signature, cloned._input_signature)
        self.assertEqual(autograph, cloned._autograph)
        self.assertEqual(implements, cloned._implements)
        self.assertEqual(autograph_options,
                         cloned._experimental_autograph_options)
        self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
        self.assertEqual(compile_, cloned._jit_compile)

        # This test does not run with XLA JIT support linked in so we can only check
        # the output of the function if compile is disabled.
        if not compile_:
            x = array_ops.zeros([])
            self.assertEqual(self.evaluate(cloned(x)),
                             self.evaluate(cloned_py_function(x)))

    def testLiftPlaceholderInitializedVariable(self):
        with ops.Graph().as_default():
            var_list = []

            @def_function.function
            def use_variable():
                if not var_list:
                    initial_value = array_ops.placeholder(shape=[],
                                                          dtype=dtypes.float32)
                    v = variables.Variable(initial_value)
                    var_list.append(v)
                return var_list[0] + 1.

            var_plus_one = use_variable()
            with self.session() as session:
                init_op = var_list[0].initializer
                session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
                self.assertEqual(3., session.run(var_plus_one))

    def testDecorate_rejectedAfterTrace(self):
        func = def_function.function(lambda: 1)
        self.assertEqual(func().numpy(), 1)
        msg = 'Functions cannot be decorated after they have been traced.'
        with self.assertRaisesRegex(ValueError, msg):
            func._decorate(lambda f: f)

    def testGetConcreteFunctionGraphLifetime(self):
        @def_function.function
        def func():
            pass

        graph = func.get_concrete_function().graph
        del func

        # If the graph is deleted, then an exception is raised on reading `captures`
        self.assertEmpty(graph.captures)

    @parameterized.parameters(*itertools.product(
        (None, (tensor_spec.TensorSpec([]), )),  # input_signature
        (True, False),  # autograph
        (None, converter.Feature.ALL),  # autograph_options
        (None, 'foo.bar'),  # implements
        (None, True, False),  # relax_shapes
    ))
    def test_pickle(self, input_signature, autograph, autograph_options,
                    implements, relax_shapes):
        """@function objects can be pickled and unpickled."""
        original_py_function = undecorated_function

        func = def_function.function(
            func=original_py_function,
            input_signature=input_signature,
            autograph=autograph,
            experimental_implements=implements,
            experimental_autograph_options=autograph_options,
            experimental_relax_shapes=relax_shapes,
        )

        cloned = pickle.loads(pickle.dumps(func))

        self.assertEqual(func._name, cloned._name)
        self.assertEqual(input_signature, cloned._input_signature)
        self.assertEqual(autograph, cloned._autograph)
        self.assertEqual(implements, cloned._implements)
        self.assertEqual(autograph_options,
                         cloned._experimental_autograph_options)
        self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)

        x = array_ops.ones([])
        self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x)))

    def test_frequent_retracing_warning(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        @def_function.function
        def f(x):
            return x

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            self.assertEmpty(logs.output)
            f(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_lambda(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        f = def_function.function(lambda x: x)

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            f(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_method(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        class Foo(object):
            @def_function.function
            def f(self, x):
                return x

        f = Foo().f

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            f(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_two_independent_tf_functions(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        @def_function.function
        def f(x):
            return x

        @def_function.function
        def g(x):
            return x

        with self.assertLogs(level='WARN') as logs:
            f(1)
            f(2)
            f(3)
            f(4)
            g(1)
            g(2)
            g(3)
            g(4)
            g(5)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_nested(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        @def_function.function
        def inner(x):
            return x + 1

        @def_function.function
        def outer1(x):
            return inner(x) * 2

        @def_function.function
        def outer2(x):
            return inner(x) * 3

        with self.assertLogs(level='WARN') as logs:
            inner(1)
            inner(2)
            inner(3)
            inner(4)

            outer1(5)
            outer1(6)
            outer1(7)
            outer1(8)

            outer2(9)
            outer2(10)
            outer2(11)
            outer2(12)

            self.assertEmpty(logs.output)

            outer2(13)

            self.assertLen(logs.output, 1)
            self.assertIn('Tracing is expensive', logs.output[0])

    def test_frequent_retracing_warning_on_reinstantiation(self):
        if sys.version_info[0] < 3:
            self.skipTest(
                'self.assertLogs() call is not available in Python 2.')

        with self.assertLogs(level='WARN') as logs:
            for i in range(5):

                @def_function.function
                def f(x):
                    return x

                f(i)

                if i < 4:
                    self.assertEmpty(logs.output)

        self.assertLen(logs.output, 1)
        self.assertIn('Tracing is expensive', logs.output[0])

    def test_restored_function_retracing_warning(self):
        class Foo(Checkpoint):
            @def_function.function
            def __call__(self, x):
                return x

        f_flexible = Foo()
        _ = f_flexible.__call__.get_concrete_function(
            tensor_spec.TensorSpec(shape=[None], dtype=dtypes.int32))
        tmp_dir = self.create_tempdir()
        save(f_flexible, tmp_dir.full_path)
        restored_f_flexible = load(tmp_dir.full_path)

        f_fixed_shape = Foo()

        with self.assertLogs(level='WARN') as logs:
            restored_f_flexible(constant_op.constant([1], dtypes.int32))
            restored_f_flexible(constant_op.constant([1, 2], dtypes.int32))
            restored_f_flexible(constant_op.constant([1, 2, 3], dtypes.int32))
            restored_f_flexible(
                constant_op.constant([1, 2, 3, 4], dtypes.int32))
            restored_f_flexible(
                constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
            self.assertEmpty(logs.output)

            f_fixed_shape(constant_op.constant([1], dtypes.int32))
            f_fixed_shape(constant_op.constant([1, 2], dtypes.int32))
            f_fixed_shape(constant_op.constant([1, 2, 3], dtypes.int32))
            f_fixed_shape(constant_op.constant([1, 2, 3, 4], dtypes.int32))
            f_fixed_shape(constant_op.constant([1, 2, 3, 4, 5], dtypes.int32))
            self.assertLen(logs.output, 1)
            self.assertIn('Tracing is expensive', logs.output[0])

    def test_retracing_warning_limits(self):
        @def_function.function
        def my_func(x):
            return x

        with self.assertLogs(level='WARN') as logs:
            for i in range(10):
                my_func(i)

            self.assertLen(logs.output, 2)

    def test_experimental_get_tracing_count_function(self):
        @def_function.function
        def double(a):
            return a + a

        double(constant_op.constant(1))
        double(constant_op.constant(2))
        self.assertAllEqual(double.experimental_get_tracing_count(), 1)
        double(constant_op.constant('a'))
        self.assertAllEqual(double.experimental_get_tracing_count(), 2)

    def test_experimental_get_tracing_count_method(self):
        class TestClass():
            @def_function.function
            def testDouble(self, a):
                return a + a

        obj1 = TestClass()
        obj1.testDouble(constant_op.constant(1))
        obj1.testDouble(constant_op.constant(2))
        obj1.testDouble(constant_op.constant(1.1))
        self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(),
                            2)
        obj2 = TestClass()
        obj2.testDouble(constant_op.constant(1))
        obj2.testDouble(constant_op.constant(1.1))
        obj2.testDouble(constant_op.constant('a'))
        self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(),
                            3)
        self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(),
                            2)
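An aside on the two tracing-count tests above: the same counter is exposed on
the public `tf.function` API, and each class instance gets its own trace
cache, which is why `obj1` and `obj2` report independent counts. A minimal
public-API sketch (TF 2.x assumed):

import tensorflow as tf

@tf.function
def double(a):
    return a + a

double(tf.constant(1))
double(tf.constant(2))                            # same dtype/shape: trace reused
print(double.experimental_get_tracing_count())    # 1
double(tf.constant('a'))                          # new dtype: second trace
print(double.experimental_get_tracing_count())    # 2

class Pair:
    @tf.function
    def double(self, a):
        return a + a

x, y = Pair(), Pair()
x.double(tf.constant(1))
y.double(tf.constant(1.0))
print(x.double.experimental_get_tracing_count())  # 1 -- caches are per instance
print(y.double.experimental_get_tracing_count())  # 1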
Example No. 24
    def _save_model(self, model, saved_dir):
        call = model.__call__.get_concrete_function(
            tensor_spec.TensorSpec(None))
        saved_model.save(model, saved_dir, signatures=call)
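For context, a hedged round-trip sketch of what this helper enables, using the
public `tf.saved_model` API (the `Scaler` module and the path are illustrative
assumptions, not part of the snippet):

import tensorflow as tf

class Scaler(tf.Module):           # hypothetical model for illustration
    @tf.function
    def __call__(self, x):
        return x * 2.0

model = Scaler()
call = model.__call__.get_concrete_function(tf.TensorSpec(None))
tf.saved_model.save(model, "/tmp/scaler", signatures=call)

restored = tf.saved_model.load("/tmp/scaler")
print(restored(tf.constant(3.0)))  # tf.Tensor(6.0, shape=(), dtype=float32)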
Example No. 25
    def element_spec(self):
        return tensor_spec.TensorSpec([], dtypes.string)
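This implements the public `element_spec` contract on datasets; for
comparison, a tiny public-API sketch (TF 2.x assumed):

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
print(ds.element_spec)  # TensorSpec(shape=(), dtype=tf.string, name=None)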
Example No. 26
class SparseTensorSpecTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
    def assertAllTensorsEqual(self, list1, list2):
        self.assertLen(list1, len(list2))
        for (t1, t2) in zip(list1, list2):
            self.assertAllEqual(t1, t2)

    def testConstruction(self):
        spec1 = sparse_tensor.SparseTensorSpec()
        self.assertEqual(spec1._shape.rank, None)
        self.assertEqual(spec1._dtype, dtypes.float32)

        spec2 = sparse_tensor.SparseTensorSpec([None, None], dtypes.string)
        self.assertEqual(spec2._shape.as_list(), [None, None])
        self.assertEqual(spec2._dtype, dtypes.string)

    def testValueType(self):
        spec1 = sparse_tensor.SparseTensorSpec()
        self.assertEqual(spec1.value_type, sparse_tensor.SparseTensor)

    @parameterized.parameters([
        (sparse_tensor.SparseTensorSpec(), (tensor_shape.TensorShape(None),
                                            dtypes.float32)),
        (sparse_tensor.SparseTensorSpec(shape=[5, None, None]),
         (tensor_shape.TensorShape([5, None, None]), dtypes.float32)),
        (sparse_tensor.SparseTensorSpec(dtype=dtypes.int32),
         (tensor_shape.TensorShape(None), dtypes.int32)),
    ])  # pyformat: disable
    def testSerialize(self, st_spec, expected):
        serialization = st_spec._serialize()
        # TensorShape has an unconventional definition of equality, so we can't use
        # assertEqual directly here.  But repr() is deterministic and lossless for
        # the expected values, so we can use that instead.
        self.assertEqual(repr(serialization), repr(expected))

    @parameterized.parameters([
        (sparse_tensor.SparseTensorSpec(dtype=dtypes.string), [
            tensor_spec.TensorSpec([None, None], dtypes.int64),
            tensor_spec.TensorSpec([None], dtypes.string),
            tensor_spec.TensorSpec([None], dtypes.int64)
        ]),
        (sparse_tensor.SparseTensorSpec(shape=[5, None, None]), [
            tensor_spec.TensorSpec([None, 3], dtypes.int64),
            tensor_spec.TensorSpec([None], dtypes.float32),
            tensor_spec.TensorSpec([3], dtypes.int64)
        ]),
    ])
    def testComponentSpecs(self, st_spec, expected):
        self.assertEqual(st_spec._component_specs, expected)

    @parameterized.parameters([
        {
            "st_spec": sparse_tensor.SparseTensorSpec(),
            "indices": [[0, 1], [10, 8]],
            "values": [3.0, 5.0],
            "dense_shape": [100, 100]
        },
        {
            "st_spec": sparse_tensor.SparseTensorSpec([100, None, None]),
            "indices": [[0, 1, 3], [10, 8, 2]],
            "values": [3.0, 5.0],
            "dense_shape": [100, 20, 20]
        },
    ])
    def testToFromComponents(self, st_spec, indices, values, dense_shape):
        st = sparse_tensor.SparseTensor(indices, values, dense_shape)
        actual_components = st_spec._to_components(st)
        self.assertAllTensorsEqual(actual_components,
                                   [indices, values, dense_shape])
        st_reconstructed = st_spec._from_components(actual_components)
        self.assertAllEqual(st.indices, st_reconstructed.indices)
        self.assertAllEqual(st.values, st_reconstructed.values)
        self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)

    @test_util.run_v1_only("SparseTensorValue is deprecated in v2")
    def testFromNumpyComponents(self):
        indices = np.array([[0], [8]])
        values = np.array([1.0, 9.0])
        dense_shape = np.array([100])
        spec = sparse_tensor.SparseTensorSpec()
        st = spec._from_components([indices, values, dense_shape])
        self.assertIsInstance(st, sparse_tensor.SparseTensorValue)
        self.assertAllEqual(st.indices, indices)
        self.assertAllEqual(st.values, values)
        self.assertAllEqual(st.dense_shape, dense_shape)

    @parameterized.parameters([
        sparse_tensor.SparseTensorSpec(dtype=dtypes.string),
        sparse_tensor.SparseTensorSpec(shape=[5, None, None]),
    ])
    def testFlatTensorSpecs(self, st_spec):
        self.assertEqual(st_spec._flat_tensor_specs,
                         [tensor_spec.TensorSpec(None, dtypes.variant)])

    @parameterized.parameters([
        {
            "st_spec": sparse_tensor.SparseTensorSpec(),
            "indices": [[0, 1], [10, 8]],
            "values": [3.0, 5.0],
            "dense_shape": [100, 100]
        },
        {
            "st_spec": sparse_tensor.SparseTensorSpec([100, None, None]),
            "indices": [[0, 1, 3], [10, 8, 2]],
            "values": [3.0, 5.0],
            "dense_shape": [100, 20, 20]
        },
    ])
    def testToFromTensorList(self, st_spec, indices, values, dense_shape):
        st = sparse_tensor.SparseTensor(indices, values, dense_shape)
        tensor_list = st_spec._to_tensor_list(st)
        st_reconstructed = st_spec._from_tensor_list(tensor_list)
        self.assertAllEqual(st.indices, st_reconstructed.indices)
        self.assertAllEqual(st.values, st_reconstructed.values)
        self.assertAllEqual(st.dense_shape, st_reconstructed.dense_shape)

    @parameterized.parameters([
        (sparse_tensor.SparseTensorSpec([2, None], dtypes.float32), 32,
         sparse_tensor.SparseTensorSpec([32, 2, None], dtypes.float32)),
        (sparse_tensor.SparseTensorSpec([4, None], dtypes.float32), None,
         sparse_tensor.SparseTensorSpec([None, 4, None], dtypes.float32)),
        (sparse_tensor.SparseTensorSpec([2], dtypes.float32), 32,
         sparse_tensor.SparseTensorSpec([32, 2], dtypes.float32)),
    ])
    def testBatch(self, spec, batch_size, expected):
        self.assertEqual(spec._batch(batch_size), expected)

    @parameterized.parameters([
        (sparse_tensor.SparseTensorSpec([32, None, None], dtypes.float32),
         sparse_tensor.SparseTensorSpec([None, None], dtypes.float32)),
        (sparse_tensor.SparseTensorSpec([None, None, None], dtypes.float32),
         sparse_tensor.SparseTensorSpec([None, None], dtypes.float32)),
        (sparse_tensor.SparseTensorSpec([32, 2], dtypes.float32),
         sparse_tensor.SparseTensorSpec([2], dtypes.float32)),
    ])
    def testUnbatch(self, spec, expected):
        self.assertEqual(spec._unbatch(), expected)
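The `_to_components`/`_from_components` hooks exercised above are private; at
the public level the spec most often appears as a `tf.function` input
signature. A sketch (TF 2.x assumed):

import tensorflow as tf

@tf.function(input_signature=[tf.SparseTensorSpec([None, 10], tf.float32)])
def densify(st):
    # Rebuilds the dense view from the (indices, values, dense_shape)
    # components that the spec machinery above decomposes and reassembles.
    return tf.sparse.to_dense(st)

st = tf.sparse.SparseTensor(indices=[[0, 1], [2, 3]],
                            values=[3.0, 5.0],
                            dense_shape=[4, 10])
print(densify(st).shape)  # (4, 10)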
Example No. 27
    def testDatasetDatasetSpec(self):
        self._testDatasetSpec(
            dataset_ops.Dataset.from_tensor_slices(
                constant_op.constant([1, 2, 3])),
            dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))
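The same check through the public API (a sketch; `tf.data.DatasetSpec.from_value`
derives the spec asserted above, and the printed repr is approximate):

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3]))
print(tf.data.DatasetSpec.from_value(ds))
# DatasetSpec(TensorSpec(shape=(), dtype=tf.int32, name=None), TensorShape([]))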
Example No. 28
    def testFlatTensorSpecs(self, st_spec):
        self.assertEqual(st_spec._flat_tensor_specs,
                         [tensor_spec.TensorSpec(None, dtypes.variant)])
Example No. 29
    def test_unspecified_default_argument(self):
        wrapped = def_function.function(
            lambda x, y=2: x + y,
            input_signature=[tensor_spec.TensorSpec((), dtypes.int32)])
        self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())
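The point of this test: the signature covers only `x`, so `y` keeps its Python
default and is baked into the single trace as a constant. A public-API sketch
(TF 2.x assumed):

import tensorflow as tf

wrapped = tf.function(
    lambda x, y=10: x + y,
    input_signature=[tf.TensorSpec((), tf.int32)])
print(wrapped(tf.constant(1)).numpy())  # 11 -- y is fixed at its default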
Example No. 30
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_spec):
        self._element_spec = element_spec

        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func.get_concrete_function()

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func.get_concrete_function()
        self._init_captured_args = self._init_func.captured_inputs

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=structure.get_flat_tensor_types(
                        self._element_spec),
                    output_shapes=structure.get_flat_tensor_shapes(
                        self._element_spec)))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=structure.get_flat_tensor_types(
                    self._element_spec),
                output_shapes=structure.get_flat_tensor_shapes(
                    self._element_spec))

        next_func_concrete = _next_func.get_concrete_function()

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True},
            autograph=False)  # Pure graph code.
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=structure.get_flat_tensor_types(self._element_spec),
                f=next_func_concrete)

        self._next_func = _remote_next_func.get_concrete_function()
        self._next_captured_args = self._next_func.captured_inputs

        self._incarnation_id_index = -1
        for i, arg in enumerate(self._next_captured_args):
            if arg is incarnation_id:
                self._incarnation_id_index = i

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func.get_concrete_function()

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func.get_concrete_function()
        self._finalize_captured_args = self._finalize_func.captured_inputs

        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **self._flat_structure)
        super(_PerDeviceGenerator, self).__init__(variant_tensor)
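This generator is the per-device plumbing behind MultiDeviceIterator; a rough
sketch of the pattern it supports, driven through that class in eager mode
(the device list and element count here are assumptions):

import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

dataset = tf.data.Dataset.range(8)
# Elements are sharded round-robin across the listed devices; each shard is
# fed by a _PerDeviceGenerator like the one defined above.
mdi = multi_device_iterator_ops.MultiDeviceIterator(dataset, devices=["/cpu:0"])
for _ in range(4):
    (element,) = mdi.get_next()  # one element per listed device
    print(element.numpy())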