Example #1
  def test_nested_functions(self):
    f = def_function.function(
        lambda x: x*2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    g = def_function.function(
        lambda x: f(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    root = tracking.AutoCheckpointable()
    root.g = g
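    # `cycle` is a test helper that saves `root` to a SavedModel and loads it
    # back; the nested call from g into f has to survive the round trip.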
    imported = self.cycle(root)
    imported.g(constant_op.constant([1.0]))
Example #2
  def test_nested_functions(self, cycles):
    f = def_function.function(
        lambda x: x*2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    g = def_function.function(
        lambda x: f(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    root = tracking.AutoCheckpointable()
    root.g = g
    # TODO(vbardiovsky): Enable this test. For this to work, we must ensure that
    # concrete_function._inference_function._graph._functions contains all
    # functions that were on the graph before saving.
    imported = self.cycle(root, 1)
    imported.g(constant_op.constant([1.0]))
Example #3
def add_metric_step(defun):
  optimizer = keras.optimizer_v2.rmsprop.RMSprop()
  model = testing_utils.get_model_from_layers(
      [LayerWithMetrics(),
       keras.layers.Dense(1, kernel_initializer='zeros', activation='softmax')],
      input_shape=(10,))

  def train_step(x, y):
    with backprop.GradientTape() as tape:
      y_pred_1 = model(x)
      y_pred_2 = model(2 * x)
      y_pred = y_pred_1 + y_pred_2
      loss = keras.losses.mean_squared_error(y, y_pred)
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    assert len(model.metrics) == 2
    return [m.result() for m in model.metrics]

  if defun:
    # Optionally compile the whole train step into a single tf.function.
    train_step = def_function.function(train_step)

  x, y = array_ops.ones((10, 10)), array_ops.zeros((10, 1))
  metrics = train_step(x, y)
  assert np.allclose(metrics[0], 1.5)
  assert np.allclose(metrics[1], 1.5)
  return metrics
Example #4
 def test_table(self):
   initializer = lookup_ops.TextFileInitializer(
       self._vocab_path,
       key_dtype=dtypes.string,
       key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
       value_dtype=dtypes.int64,
       value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
   root = util.Checkpoint(table=lookup_ops.HashTable(
       initializer, default_value=-1))
   root.table_user = def_function.function(
       root.table.lookup,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
   self.assertEqual(
       2,
       self.evaluate(root.table_user(constant_op.constant("gamma"))))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, save_dir)
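   # Deleting the vocabulary file below verifies that the table's asset was
   # copied into the SavedModel at save time.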
   file_io.delete_file(self._vocab_path)
   self.assertAllClose(
       {"output_0": [2, 0]},
       _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))
   second_dir = os.path.join(self.get_temp_dir(), "second_dir")
   # Asset paths should track the location the SavedModel is loaded from.
   file_io.rename(save_dir, second_dir)
   self.assertAllClose(
       {"output_0": [2, 1]},
       _import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))
Example #5
  def testRequestNotToCompile(self):
    with self.test_scope():
      def f(x):
        with ops.device('device:CPU:0'):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = function.defun_with_attributes(
          f, attributes={'_XlaCompile': False})

      x = constant_op.constant([0.0, 2.0], name='data')

      # When function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegexpMatches(r_x.backing_device, self.device)
        self.assertRegexpMatches(r_y.backing_device, self.device)

      # When function is executed op-by-op, requested devices will be
      # respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegexpMatches(r_x.backing_device, self.device)
        self.assertRegexpMatches(r_y.backing_device, 'device:CPU:0')
Example #6
  def test_structured_output(self):

    # Use fields with non-alphabetical order
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
      return [named_tuple, input2, {"x": 0.5}]

    root = tracking.AutoCheckpointable()
    root.f = def_function.function(func)

    result = root.f(constant_op.constant(2), constant_op.constant(3))

    self.assertEqual(5, result[0].a.numpy())
    self.assertEqual(6, result[0].b.numpy())
    self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
    self.assertEqual(3, result[1].numpy())
    self.assertEqual(0.5, result[2]["x"].numpy())

    imported = self.cycle(root)

    result = imported.f(constant_op.constant(2), constant_op.constant(5))
    self.assertEqual(7, result[0].a.numpy())
    self.assertEqual(10, result[0].b.numpy())
    self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
    self.assertEqual(5, result[1].numpy())
    self.assertEqual(0.5, result[2]["x"].numpy())
Example #7
  def test_structured_inputs(self):

    def func(x, training=True):
      # x is a nested structure, we care about one particular tensor.
      _, (a, b) = x
      if training:
        return 2 * a["a"] + b
      else:
        return 7

    root = tracking.AutoCheckpointable()
    root.f = def_function.function(func)

    x = constant_op.constant(10)
    y = constant_op.constant(11)

    input1 = [6, ({"a": x}, y)]
    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.

    # Note: by only calling f(input1) before serialization, only inputs with
    # matching signature will be valid on the loaded model.
    self.assertEqual(31, root.f(input1).numpy())

    imported = self.cycle(root)

    with self.assertRaisesRegexp(AssertionError,
                                 "Could not find matching function to call.*"):
      imported.f(input2)

    self.assertEqual(31, imported.f(input1).numpy())
    self.assertEqual(32, imported.f(input3).numpy())
Example #8
  def testConstSavedModel(self):
    """Test a basic model with functions to make sure functions are inlined."""
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(variable_graph_def))
    self.assertTrue(variable_graph_def.library.function)

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    constant_graph_def = output_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(constant_graph_def.library.function)

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                          [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
Example #9
  def testConstructConcreteFunction(self):
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    func = root.f.get_concrete_function(input_data)

    input_func = convert_to_constants._construct_concrete_function(
        func, func.graph.as_graph_def())

    # Test if model has enough metadata to be frozen afterwards.
    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, self._getNumVariables(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    constant_graph_def = output_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                          [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
Example #10
  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertTrue(self._hasStatefulPartitionedCallOp(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    constant_graph_def = output_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                          [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
Example #11
  def test_functools_partial_new_default(self):
    def f(x=3, y=7):
      return x + y

    func = def_function.function(functools.partial(f, y=6))
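    # The partial overrides y's default (7 -> 6), so func() is 3 + 6 = 9 and an
    # explicit y=8 gives 3 + 8 = 11.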
    self.assertEqual(func().numpy(), 9)
    self.assertEqual(func(y=8).numpy(), 11)
Example #12
  def test_functools_partial_single_positional(self):
    def f(x, y):
      return x + y

    func = def_function.function(
        functools.partial(f, constant_op.constant(1)))
    self.assertAllEqual(func(5), 6)
Example #13
  def test_callable(self):
    class M1(tracking.AutoCheckpointable):

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
      def __call__(self, x):
        return x

    root = tracking.AutoCheckpointable()
    root.m1 = M1()
    root.m2 = tracking.AutoCheckpointable()
    root.m2.__call__ = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
            lambda x: x*3.0)
    imported = self.cycle(root)
    x = constant_op.constant(1.0)

    self.assertTrue(callable(imported.m1))
    self.assertAllEqual(root.m1(x), imported.m1(x))

    # Note: `root.m2` was not callable since `__call__` attribute was set
    # into the instance and not on the class. But after a serialization cycle
    # that starts to work.
    self.assertTrue(callable(imported.m2))
    self.assertAllEqual(root.m2.__call__(x), imported.m2(x))

    # Verify that user objects without `__call__` attribute are not callable.
    self.assertFalse(callable(imported))
Example #14
def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()
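  # Every per-device nccl_all_reduce op created below uses the same
  # shared_name, which groups them into a single collective across devices.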

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()
Example #15
  def testDecorate(self):
    func = def_function.function(lambda: 1)
    def decorator(f):
      return lambda: 1 + f()

    func._decorate(decorator)
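    # The decorated function now computes 1 + f(), i.e. 1 + 1 == 2.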
    self.assertEqual(func().numpy(), 2)
Example #16
  def test_functools_partial_keywords(self):
    def f(x, y):
      return x + y

    func = def_function.function(
        functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
    self.assertAllEqual(func(), [0.0])
Example #17
 def test_single_function_default_signature(self):
   model = tracking.AutoCheckpointable()
   model.f = def_function.function(lambda: 3., input_signature=())
   model.f()
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(model, save_dir)
   self.assertAllClose({"output_0": 3.},
                       _import_and_infer(save_dir, {}))
Example #18
 def test_non_concrete_error(self):
   root = tracking.AutoTrackable()
   root.f = def_function.function(lambda x: 2. * x)
   root.f(constant_op.constant(1.))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(
       ValueError, "Expected a TensorFlow function"):
     save.save(root, save_dir, root.f)
Example #19
 def test_non_concrete_error(self):
   root = tracking.AutoCheckpointable()
   root.f = def_function.function(lambda x: 2. * x)
   root.f(constant_op.constant(1.))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(
       ValueError, "must be converted to concrete functions"):
     save.save(root, save_dir, root.f)
Example #20
  def testGradient(self):
    matmul = def_function.function(math_ops.matmul)

    def sq(x):
      return matmul(x, x, transpose_a=True)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    grad_t, = backprop.gradients_function(sq, [0])(t)
    self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
Example #21
 def testFunctionReferenceCycles(self):
   fn = def_function.function(lambda x: 2. * x)
   fn(constant_op.constant(4.0))
   weak_fn = weakref.ref(fn)
   del fn
   # Tests that the weak reference we made to the function is now dead, which
   # means the object has been deleted. This should be true as long as the
   # function itself is not involved in a reference cycle.
   self.assertIs(None, weak_fn())
Example #22
  def testTypeInvalid(self):
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)

    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_concrete_function(root.f)
    self.assertIn('call from_concrete_function', str(error.exception))
Example #23
 def test_ambiguous_signatures(self):
   model = _ModelWithOptimizer()
   x = constant_op.constant([[3., 4.]])
   y = constant_op.constant([2.])
   model.call(x, y)
   model.second_function = def_function.function(lambda: 1.)
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(ValueError, "call.*second_function"):
     save.save(model, save_dir)
Example #24
 def test_nested_outputs(self):
   root = tracking.AutoCheckpointable()
   root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
   root.f(constant_op.constant(1.))
   to_save = root.f.get_concrete_function(constant_op.constant(1.))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(
       ValueError, "non-flat outputs"):
     save.save(root, save_dir, to_save)
Example #25
  def testGradient(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    def sq(x):
      return matmul(x, x, transpose_a=True)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    grad_t, = backprop.gradients_function(sq, [0])(t)
    self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
Example #26
 def test_capture_variables(self):
   root = tracking.AutoCheckpointable()
   root.weights = variables.Variable(2.)
   root.f = def_function.function(
       lambda x: root.weights * x,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
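   # The function captures `root.weights`; after the save/load cycle the
   # restored variable can be reassigned and the imported function sees it.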
   imported = self.cycle(root)
   self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
   imported.weights.assign(4.0)
   self.assertEqual(8., imported.f(constant_op.constant(2.)).numpy())
Example #27
 def test_captured_constant(self, cycles):
   const = array_ops.zeros([100])
   root = tracking.AutoCheckpointable()
   root.f = def_function.function(lambda: const + 1.)
   root.g = def_function.function(lambda: const + 2.)
   self.assertAllClose(array_ops.ones([100]), root.f())
   self.assertAllClose(2. * array_ops.ones([100]), root.g())
   imported = self.cycle(root, cycles)
   self.assertAllClose(array_ops.ones([100]), imported.f())
   self.assertAllClose(2. * array_ops.ones([100]), imported.g())
   # TODO(b/123408994): Use the public get_concrete_function.
   f_concrete = imported.f._list_all_concrete_functions_for_serialization()[0]
   g_concrete = imported.g._list_all_concrete_functions_for_serialization()[0]
   self.assertLen(f_concrete.captured_inputs, 1)
   self.assertLen(g_concrete.captured_inputs, 1)
   # We should be using the same captured EagerTensor in both functions, not
   # duplicating the constant.
   self.assertIs(f_concrete.captured_inputs[0],
                 g_concrete.captured_inputs[0])
Example #28
 def test_nested_inputs(self):
   root = tracking.Checkpointable()
   root.f = def_function.function(lambda x: 2. * x[0])
   root.f([constant_op.constant(1.)])
   to_export = root.f.get_concrete_function(
       [constant_op.constant(1.), constant_op.constant(2.)])
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(
       ValueError, "non-unique argument names"):
     export.export(root, export_dir, to_export)
Example #29
 def test_nested_dict_outputs(self):
   root = util.Checkpoint(
       f=def_function.function(
           lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)}))
   root.f(constant_op.constant(1.))
   to_save = root.f.get_concrete_function(constant_op.constant(1.))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(
       ValueError, "dictionary containing non-Tensor value"):
     save.save(root, save_dir, to_save)
Example #30
 def test_nested_dict_outputs(self):
   root = tracking.Checkpointable()
   root.f = def_function.function(
       lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
   root.f(constant_op.constant(1.))
   to_export = root.f.get_concrete_function(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   with self.assertRaisesRegexp(
       ValueError, "dictionary containing non-Tensor value"):
     export.export(root, export_dir, to_export)
Example #31
    def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
        """Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
        if _get_num_devices_per_worker(strategy) > 1:
            self.skipTest('b/167331966')

        def value_fn(ctx):
            return constant_op.constant(
                1, shape=(1, ctx.replica_id_in_sync_group + 1))

        per_replica_value = strategy.experimental_distribute_values_from_function(
            value_fn)

        def run(value):
            value_identity = array_ops.identity(value)
            ctx = ds_context.get_replica_context()
            return ctx._all_gather(value_identity, axis=0)

        if not pure_eager:
            run = def_function.function(run)

        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    r'Shape mismatch'):
            strategy.run(run, args=(per_replica_value, ))
Example #32
  def testSimpleReduce(self, strategy):

    def fn_eager():

      def replica_fn():
        return array_ops.ones((), dtypes.float32)

      per_replica_value = strategy.run(replica_fn)
      return strategy.reduce(
          reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None)

    fn_graph = def_function.function(fn_eager)

    # Run reduce under the strategy scope to explicitly enter
    # strategy default_device scope.
    with strategy.scope():
      self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
      self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

    # Run reduce without a strategy scope to implicitly enter
    # strategy default_device scope.
    self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
    self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
Example #33
    def testFloat(self):
        input_data = constant_op.constant(1., shape=[1])
        root = tracking.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        concrete_func = root.f.get_concrete_function(input_data)

        # Convert model.
        converter = lite.TFLiteConverterV2.from_concrete_function(
            concrete_func)
        converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
        tflite_model = converter.convert()

        # Ensures the model contains TensorFlow ops.
        # TODO(nupurgarg): Check values once there is a Python delegate interface.
        interpreter = Interpreter(model_content=tflite_model)
        with self.assertRaises(RuntimeError) as error:
            interpreter.allocate_tensors()
        self.assertIn(
            'Regular TensorFlow ops are not supported by this interpreter. Make '
            'sure you invoke the Flex delegate before inference.',
            str(error.exception))
Example #34
    def _all_gather_same_shape_and_verify(self, value_on_replica, axis,
                                          pure_eager, strategy):
        per_replica_value = strategy.experimental_distribute_values_from_function(
            lambda _: array_ops.identity(value_on_replica))
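        # Every replica receives an identical copy of `value_on_replica`.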

        def replica_fn(per_replica_value):
            ctx = ds_context.get_replica_context()
            local_value = array_ops.identity(per_replica_value)
            return ctx._all_gather(local_value, axis=axis)

        if not pure_eager:
            replica_fn = def_function.function(replica_fn)

        result = strategy.experimental_local_results(
            strategy.run(replica_fn, args=(per_replica_value, )))

        all_value = [
            value_on_replica for _ in range(strategy.num_replicas_in_sync)
        ]
        expect = array_ops.concat(all_value, axis=axis)
        expected_result = [expect] * _get_num_replicas_per_client(strategy)

        self.assertAllClose(expected_result, result)
Example #35
    def testConstSavedModel(self):
        """Test a basic model with functions to make sure functions are inlined."""
        input_data = constant_op.constant(1., shape=[1])
        root = tracking.AutoTrackable()
        root.f = def_function.function(lambda x: 2. * x)
        to_save = root.f.get_concrete_function(input_data)

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, save_dir, to_save)
        saved_model = load(save_dir)
        input_func = saved_model.signatures["serving_default"]

        variable_graph_def = input_func.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(variable_graph_def))
        self.assertTrue(variable_graph_def.library.function)

        output_func = convert_to_constants.convert_variables_to_constants_v2(
            input_func)
        constant_graph_def = output_func.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(constant_graph_def))
        self.assertFalse(constant_graph_def.library.function)

        self._testConvertedFunction(root, root.f, output_func, input_data)
Example #36
  def test_positional_arguments(self, cycles):
    def func(x, training=False, abc=7.1, defg=7.7):
      del abc
      if training:
        return 2 * x
      if defg == 7:
        return 6
      else:
        return 7

    root = tracking.AutoTrackable()
    root.f = def_function.function(func)

    self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
    self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
    self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
    self.assertEqual(6, root.f(constant_op.constant(1), defg=7.0).numpy())

    imported = self.cycle(root, cycles)

    self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
    self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
    self.assertEqual(6, imported.f(constant_op.constant(1), defg=7.0).numpy())
Example #37
    def testVariableSavedModel(self):
        """Test a basic model with Variables with saving/loading the SavedModel."""
        input_data = constant_op.constant(1., shape=[1])
        root = tracking.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        to_save = root.f.get_concrete_function(input_data)

        save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
        save(root, save_dir, to_save)
        saved_model = load(save_dir)
        concrete_func = saved_model.signatures['serving_default']

        # Convert model and ensure model is not None.
        converter = lite.TFLiteConverterV2.from_concrete_function(
            concrete_func)
        tflite_model = converter.convert()

        # Check values from converted model.
        expected_value = root.f(input_data)
        actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
        self.assertEqual(expected_value.numpy(), actual_value)
Example #38
    def replica_fn():
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process,
                                                      options.communication)

      def reduce_fn():
        value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
        per_replica_value = make_per_replica_value(value_fn, devices)
        reduced_values = collective.reduce(options.reduce_op, per_replica_value,
                                           per_replica_value)
        reduced_values = self.as_list(reduced_values)
        self.assertAllEqual(devices, [v.device for v in reduced_values])
        return [ops.convert_to_tensor(v) for v in reduced_values]

      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if "eager" in options.mode:
        got = reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(reduce_fn)()
        self.assertAllClose(got, per_replica_expect)
Example #39
  def test_subclassed_no_signature(self):

    class Subclassed(training.Model):

      def call(self, inputs):
        return inputs * 2.

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    model = Subclassed()
    with self.assertRaisesRegexp(
        ValueError, "no @tf.function-decorated methods"):
      save.save(model, save_dir)

    traced_call = def_function.function(
        model.call,
        input_signature=(tensor_spec.TensorSpec(
            (None, None),
            dtype=dtypes.float32),))
    save.save(model, save_dir, traced_call)
    self.assertAllClose({"output_0": [[8., 10.], [10., 12.]]},
                        _import_and_infer(
                            save_dir,
                            {"inputs": [[4., 5.], [5., 6.]]}))
Example #40
    def test_load_in_func_graph(self, cycles):
        root = tracking.AutoCheckpointable()
        root.v1 = variables.Variable(1.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(
            lambda x: root.v2 * x,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

        if cycles > 1:
            root = self.cycle(root, cycles - 1)
        path = tempfile.mkdtemp(prefix=self.get_temp_dir())
        save.save(root, path)

        closure = tracking.AutoCheckpointable()

        @def_function.function
        def func(x):
            if not hasattr(closure, "model"):
                closure.model = load.load(path)
            return closure.model.f(x)

        inputs = constant_op.constant(2.)
        self.assertEqual(4.0, func(inputs).numpy())
Example #41
  def testKerasLSTM(self):
    input_data = constant_op.constant(
        np.array(np.random.random_sample((10, 10, 10)), dtype=np.float32))

    model = keras.models.Sequential(
        [keras.layers.LSTM(units=10, input_shape=(10, 10))])

    run_model = def_function.function(model.__call__)
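    # Extract a concrete function with a fixed input signature to hand to the
    # converter.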
    concrete_func = run_model.get_concrete_function(
        tensor_spec.TensorSpec((10, 10, 10), dtype=dtypes.float32))

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    converter.experimental_enable_mlir_converter = True
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = concrete_func(input_data)
    # TODO(haoliang): Remove this `reshape` op since it's not necessary.
    actual_value = np.reshape(
        self._evaluateTFLiteModel(tflite_model, [input_data]), (10, 10))
    for expected, actual in zip(expected_value, actual_value):
      np.testing.assert_almost_equal(expected, actual)
Example #42
    def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
        """Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error."""
        if _get_num_devices_per_worker(strategy) > 1:
            self.skipTest('b/167331966')

        def value_fn(ctx):
            return constant_op.constant(
                1, shape=(1, ctx.replica_id_in_sync_group + 1))

        distributed_values = strategy.experimental_distribute_values_from_function(
            value_fn)
        axis = 0

        def run():
            return strategy._gather(distributed_values, axis=axis)

        error_message = 'Shape mismatch'
        if not pure_eager:
            run = def_function.function(run)

        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    error_message):
            run()
Example #43
  def test_saved_model(self):
    with self.device:
      different_values = self.device.pack(
          [constant_op.constant(-1.),
           constant_op.constant(3.)])
      m = module.Module()
      m.v = variables.Variable(different_values)
      m.f = def_function.function(lambda: m.v * 2.)
      self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
    saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(m, saved_model_path)

    context._reset_context()
    self.setUp()

    single_device_loaded = load.load(saved_model_path)
    self.assertAllClose(-2., single_device_loaded.f())
    with self.device:
      parallel_loaded = load.load(saved_model_path)
      self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
      self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
      parallel_loaded.v.assign(self.device.pack([.1, .2]))
      self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))
Example #44
  def test_complicated_partial_with_defaults(self):

    def identity(*args):
      return args

    def dynamic_unroll(core_fn,
                       input_sequence,
                       initial_state,
                       sequence_length=None,
                       parallel_iterations=1,
                       swap_memory=False):
      del core_fn
      self.assertIs(None, sequence_length)
      self.assertEqual(1, parallel_iterations)
      self.assertTrue(swap_memory)
      return input_sequence, initial_state

    input_sequence = random_ops.random_uniform([1, 1, 1])
    initial_state = random_ops.random_uniform([1, 1])

    func = def_function.function(
        functools.partial(dynamic_unroll, identity, swap_memory=True))
    func(input_sequence, initial_state)
Example #45
  def testDerivative(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x, a):
        return 2 * x + a

      xla_func = def_function.function(fn, jit_compile=True)

      with backprop.GradientTape() as tape:
        inputs = constant_op.constant([1., 2., 2., 3., 3.])
        tape.watch(inputs)
        outputs = xla_func(inputs, 1)

      self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))

      # pylint: disable=protected-access
      (forward, backward) = xla_func.get_concrete_function(
          inputs, 1)._delayed_rewrite_functions.forward_backward()

      # Check that the must-compile attribute gets correctly propagated to the
      # created derivatives.
      self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
      self.assertTrue(forward.definition.attr['_XlaMustCompile'])
Example #46
    def test_variables_mismatched_device_assignment(self):
        resolver = get_tpu_cluster_resolver()
        remote.connect_to_cluster(resolver)
        topology = tpu_strategy_util.initialize_tpu_system(resolver)

        strategy0 = tpu_lib.TPUStrategyV2(resolver)
        self.assertEqual(("/job:localhost/replica:0/task:0/device:TPU:0",
                          "/job:localhost/replica:0/task:0/device:TPU:1"),
                         strategy0.extended.worker_devices)

        with strategy0.scope():
            v = variables.Variable(1.)

        v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

        with self.cached_session():
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(v1_assign_op)
            self.assertAllEqual([1., 42.],
                                self.evaluate(
                                    strategy0.experimental_local_results(v)))

        # Second strategy has devices reversed relative to the first.
        device_assignment = device_assignment_lib.DeviceAssignment(
            topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
        strategy1 = tpu_lib.TPUStrategyV2(
            resolver, experimental_device_assignment=device_assignment)
        self.assertEqual(("/job:localhost/replica:0/task:0/device:TPU:1",
                          "/job:localhost/replica:0/task:0/device:TPU:0"),
                         strategy1.extended.worker_devices)

        v_read = strategy1.run(def_function.function(v.read_value))

        with self.cached_session():
            self.assertAllEqual(
                [42., 1.],
                self.evaluate(strategy0.experimental_local_results(v_read)))
Example #47
    def test_composite_tensor(self):
        with self.session(graph=ops.Graph()) as sess:
            sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
            operator, mat = self.operator_and_matrix(
                shapes_info, dtype, use_placeholder=use_placeholder)
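            # `shapes_info`, `dtype` and `use_placeholder` come from the
            # parameterized harness that generates this test case.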
            self.assertIsInstance(operator, composite_tensor.CompositeTensor)

            flat = nest.flatten(operator, expand_composites=True)
            unflat = nest.pack_sequence_as(operator,
                                           flat,
                                           expand_composites=True)
            self.assertIsInstance(unflat, type(operator))

            # Input the operator to a `tf.function`.
            x = self.make_x(operator, adjoint=False)
            op_y = def_function.function(lambda op: op.matmul(x))(unflat)
            mat_y = math_ops.matmul(mat, x)

            if not use_placeholder:
                self.assertAllEqual(mat_y.shape, op_y.shape)

            # Test while_loop.
            def body(op):
                return type(op)(**op.parameters),

            op_out, = while_v2.while_loop(cond=lambda _: True,
                                          body=body,
                                          loop_vars=(operator, ),
                                          maximum_iterations=3)
            loop_y = op_out.matmul(x)

            op_y_, loop_y_, mat_y_ = sess.run([op_y, loop_y, mat_y])
            self.assertAC(op_y_, mat_y_)
            self.assertAC(loop_y_, mat_y_)

            # Ensure that the `TypeSpec` can be encoded.
            nested_structure_coder.encode_structure(operator._type_spec)  # pylint: disable=protected-access
Example #48
    def testClone(self, input_signature, autograph, autograph_options,
                  implements, relax_shapes, compile_, override_function):
        original_py_function = lambda x: x

        compile_ = False
        func = def_function.function(
            func=original_py_function,
            input_signature=input_signature,
            autograph=autograph,
            experimental_implements=implements,
            experimental_autograph_options=autograph_options,
            experimental_relax_shapes=relax_shapes,
            jit_compile=compile_)

        if override_function:
            cloned_py_function = lambda x: x + 1
        else:
            cloned_py_function = original_py_function

        cloned = func._clone(python_function=cloned_py_function)

        self.assertEqual(cloned_py_function, cloned._python_function)
        self.assertEqual(func._name, cloned._name)
        self.assertEqual(input_signature, cloned._input_signature)
        self.assertEqual(autograph, cloned._autograph)
        self.assertEqual(implements, cloned._implements)
        self.assertEqual(autograph_options,
                         cloned._experimental_autograph_options)
        self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
        self.assertEqual(compile_, cloned._jit_compile)

        # This test does not run with XLA JIT support linked in so we can only check
        # the output of the function if compile is disabled.
        if not compile_:
            x = array_ops.zeros([])
            self.assertEqual(self.evaluate(cloned(x)),
                             self.evaluate(cloned_py_function(x)))
Example #49
    def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
        """Tests that RNG with dist variables can be used as tf.function's arg."""
        strat_name = type(strat).__name__
        if "CentralStorage" in strat_name:
            self.skipTest(
                "CentralStorageStrategy wraps variable updates in merge_call which "
                "can't be called inside a tf.function that doesn't cover the entire "
                "replica function (the function passed to strategy.run).")
        if "TPU" in strat_name and not jit_replica_fn:
            self.skipTest(
                "TPUStrategy requires the replica function (the function passed to "
                "strategy.run) to be decorated with tf.function")
        coord = None
        if "ParameterServer" in strat_name:
            coord = coordinator_lib.ClusterCoordinator(strat)
        shape = [3, 4]
        dtype = dtypes.int32
        with strat.scope():
            gen = rng.Generator.from_seed(1234)

            @def_function.function
            def f(gen):  # the main focus
                t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
                t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
                t = array_ops.stack([t1, t2])
                return t

            def g():
                return f(gen)

            replica_fn = def_function.function(g) if jit_replica_fn else g
            for _ in range(2):
                results = run_on_strategy(replica_fn, strat, coord)
                values = strat.experimental_local_results(results)
                n = get_num_local_replicas(strat, values)
                self.assertAllEqual(n, len(values))
                self.assertAllDifferent(values)
Example #50
  def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))
Example #51
        def replica_fn():
            CollectiveReplicaLauncher._prefer_unique_instance_key = (
                options.prefer_unique_instance_key)
            collective, devices, pid = self.make_collective(
                options.num_processes, options.gpus_per_process)

            def batch_reduce_fn():
                batch_size = len(inputs[0])
                value_dst_pairs = []
                for i in range(batch_size):

                    def value_fn(device_idx, idx=i):
                        return inputs[pid * len(devices) + device_idx][idx]

                    per_replica_value = make_per_replica_value(
                        value_fn, devices)
                    value_dst_pairs.append(
                        (per_replica_value, per_replica_value))
                reduced_values = collective.batch_reduce(
                    options.reduce_op, value_dst_pairs,
                    options.communication_options)
                reduced_values = [self.as_list(v) for v in reduced_values]
                for v in reduced_values:
                    self.assertAllEqual(devices, [t.device for t in v])
                return nest.map_structure(ops.convert_to_tensor,
                                          reduced_values)

            per_replica_expect = nest.map_structure(
                lambda x: [ops.convert_to_tensor(x)] * len(devices), expect)

            if "eager" in options.mode:
                got = batch_reduce_fn()
                self.assertAllClose(got, per_replica_expect)

            if "func_graph" in options.mode:
                got = def_function.function(batch_reduce_fn)()
                self.assertAllClose(got, per_replica_expect)
Example #52
  def test_save_invalid_custom_gradients(self):
    # The full custom gradients test is in load_test.py. This test just makes
    # sure that a warning is logged when the user has disabled custom gradients
    # and saves a model with invalid custom gradients.

    @custom_gradient.custom_gradient
    def invalid(x):
      def grad(dy):
        raise NotImplementedError
      return x, grad

    root = tracking.AutoTrackable()
    root.f = def_function.function(
        invalid, input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    options = save_options.SaveOptions(experimental_custom_gradients=None)

    with self.assertLogs(level="WARNING") as logs:
      save.save(root, save_dir, root.f, options=options)
    self.assertIn(
        "Your model contains untraceable custom gradients. This "
        "warning may become an error in the future. Please set the option "
        "tf.saved_model.SaveOption(experimental_custom_gradients=True) to get "
        "the full error message.", "".join(logs.output))
Example #53
    def testInputSignatureMissingTensorSpecsLambdaFunction(self):
        tf_func_dec = def_function.function(
            input_signature=(tensor_spec.TensorSpec([], dtypes.int32), ))
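        # The single TensorSpec covers only the first argument; tracing fails
        # while specs for arg2 and arg3 are missing.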
        error_msg = 'TensorSpecs are still required.*arg2.*arg3'
        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, **kwargs: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, arg4=4, **kwargs: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, *args: None)(1, 2, 3)

        with self.assertRaisesRegex(TypeError, error_msg):
            tf_func_dec(lambda arg1, arg2, arg3, *args, **kwargs: None)(1, 2, 3)

        self.assertEqual(
            tf_func_dec(lambda arg1, arg4=4, **kwargs: arg1 + arg4)(1).numpy(),
            5)
Example #54
    def testIteratedGradientsNestedWithVariable(self):
        def _grad(f):
            def _grad_function():
                with backprop.GradientTape() as tape:
                    primal_out = f()
                g, = tape.gradient(primal_out, tape.watched_variables())
                return g

            return _grad_function

        v = variables.Variable(2.)

        @def_function.function
        def _forward():
            return math_ops.cos(v)

        f = _forward

        two = constant_op.constant(2.)

        # `_COS_DERIVATIVES` is assumed to be defined in the surrounding test
        # module as cos followed by its successive derivatives.
        for expected in _COS_DERIVATIVES:
            self.assertAllClose(expected(two), f())
            self.assertAllClose(expected(two), def_function.function(f)())
            f = _grad(f)
Example #55
    def testVariableModel(self):
        """Test a basic model with Variables."""
        input_data = constant_op.constant(1., shape=[1])
        root = tracking.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        concrete_func = root.f.get_concrete_function(input_data)

        variable_graph_def = concrete_func.graph.as_graph_def()
        self.assertEqual(2, self._getNumVariables(variable_graph_def))

        constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
            concrete_func)
        self.assertEqual(0, self._getNumVariables(constant_graph_def))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(constant_graph_def))

        # Check value.
        expected_value = root.f(input_data)
        actual_value = self._evaluateGraphDef(constant_graph_def,
                                              concrete_func,
                                              [input_data.numpy()])
        self.assertEqual(expected_value.numpy(), actual_value)
Example #56
    def testCapturingInFunctionWhileExecutingEagerly(self):
        optimizer = gradient_descent.SGD(1.0)

        var_holder = {}

        def step():
            if not var_holder:
                var_holder["var"] = variables.Variable(1.0)
            else:
                var_holder["var"].assign(1.0)

            with backprop.GradientTape() as tape:
                loss = var_holder["var"]**2
            grad = tape.gradient(loss, var_holder["var"])
            optimizer.apply_gradients([(grad, var_holder["var"])])
            return var_holder["var"].read_value()

        compiled_step = def_function.function(step)

        self.assertEqual(float(step()), -1.0)
        self.assertEqual(float(compiled_step()), -1.0)
        # This shouldn't fail; in particular, the learning rate tensor should
        # be an EagerTensor once again, not a graph Tensor.
        self.assertEqual(float(step()), -1.0)
Example #57
    def __init__(self, dataset_fn, coordinator):
        """Makes an iterable from datasets created by the given function.

        Args:
          dataset_fn: A function that returns a `Dataset`.
          coordinator: a `ClusterCoordinator` object, used to create dataset
            resources.
        """
        def disallow_variable_creation(next_creator, **kwargs):
            raise ValueError(
                "Creating variables in `dataset_fn` is not allowed.")

        if isinstance(dataset_fn, def_function.Function):
            with variable_scope.variable_creator_scope(
                    disallow_variable_creation):
                dataset_fn = dataset_fn.get_concrete_function()
        elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
            with variable_scope.variable_creator_scope(
                    disallow_variable_creation):
                dataset_fn = def_function.function(
                    dataset_fn).get_concrete_function()
        self._dataset_fn = dataset_fn
        self._coordinator = coordinator
        self._element_spec = None
Example #58
    def restore(self, restored_tensors, restored_shapes):
        del restored_shapes  # Unused.
        restored_tensor_dict = {}
        for n, local_name in enumerate(self._local_names):
            restored_tensor_dict[local_name] = restored_tensors[n]

        def restore_from_tensors():
            restore_fn = self._trackable._restore_from_tensors  # pylint: disable=protected-access
            if (self._call_with_mapped_captures
                    and isinstance(restore_fn, core.ConcreteFunction)):
                self._call_with_mapped_captures(restore_fn,
                                                [restored_tensor_dict])
            else:
                restore_fn(restored_tensor_dict)

            # In graph mode, this wrapper function is converted into a tf.function,
            # and to ensure that _restore_from_tensors is executed, there must be at
            # least one returned tensor. `_restore_from_tensors` may return zero
            # tensors so create a dummy constant here.
            return constant_op.constant(1)

        if not ops.executing_eagerly_outside_functions():
            restore_from_tensors = def_function.function(restore_from_tensors)
        return restore_from_tensors()
Example #59
    def testAllGatherNest1D0Axis(self, strategy, pure_eager):
        """all_gather(..., axis=0,...) a nest of DistributedValues."""
        single_value = constant_op.constant([1, 2, 3])
        axis = 0

        def run():
            value_identity = array_ops.identity(single_value)
            ctx = ds_context.get_replica_context()
            return ctx._all_gather([value_identity, value_identity], axis=axis)

        if not pure_eager:
            run = def_function.function(run)

        all_value = [
            single_value for _ in range(strategy.num_replicas_in_sync)
        ]
        expect = array_ops.concat(all_value, axis=axis)
        expected_per_replica = [expect] * _get_num_devices_per_worker(strategy)

        result = strategy.run(run)
        for gathered_result in result:
            self.assertAllEqual(
                strategy.experimental_local_results(gathered_result),
                expected_per_replica)
Example #60
    def test_assets_import(self):
        file1 = self._make_asset("contents 1")
        file2 = self._make_asset("contents 2")

        root = tracking.Checkpointable()
        root.f = def_function.function(
            lambda x: 2. * x,
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        root.asset1 = tracking.TrackableAsset(file1)
        root.asset2 = tracking.TrackableAsset(file2)

        save_dir = os.path.join(self.get_temp_dir(), "save_dir")
        save.save(root, save_dir)

        file_io.delete_file(file1)
        file_io.delete_file(file2)
        load_dir = os.path.join(self.get_temp_dir(), "load_dir")
        file_io.rename(save_dir, load_dir)

        imported = load.load(load_dir)
        with open(imported.asset1.asset_path.numpy(), "r") as f:
            self.assertEqual("contents 1", f.read())
        with open(imported.asset2.asset_path.numpy(), "r") as f:
            self.assertEqual("contents 2", f.read())