def test_nested_functions(self):
  f = def_function.function(
      lambda x: x*2.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  g = def_function.function(
      lambda x: f(x) + 1.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  root = tracking.AutoCheckpointable()
  root.g = g
  imported = self.cycle(root)
  imported.g(constant_op.constant([1.0]))
def test_nested_functions(self, cycles):
  f = def_function.function(
      lambda x: x*2.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  g = def_function.function(
      lambda x: f(x) + 1.0,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

  root = tracking.AutoCheckpointable()
  root.g = g
  # TODO(vbardiovsky): Enable this test. For this to work, we must ensure that
  # concrete_function._inference_function._graph._functions contains all
  # functions that were on the graph before saving.
  imported = self.cycle(root, 1)
  imported.g(constant_op.constant([1.0]))
def add_metric_step(defun):
  optimizer = keras.optimizer_v2.rmsprop.RMSprop()
  model = testing_utils.get_model_from_layers([
      LayerWithMetrics(),
      keras.layers.Dense(1, kernel_initializer='zeros', activation='softmax')
  ], input_shape=(10,))

  def train_step(x, y):
    with backprop.GradientTape() as tape:
      y_pred_1 = model(x)
      y_pred_2 = model(2 * x)
      y_pred = y_pred_1 + y_pred_2
      loss = keras.losses.mean_squared_error(y, y_pred)
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    assert len(model.metrics) == 2
    return [m.result() for m in model.metrics]

  if defun:
    train_step = def_function.function(train_step)

  x, y = array_ops.ones((10, 10)), array_ops.zeros((10, 1))
  metrics = train_step(x, y)
  assert np.allclose(metrics[0], 1.5)
  assert np.allclose(metrics[1], 1.5)
  return metrics
def test_table(self):
  initializer = lookup_ops.TextFileInitializer(
      self._vocab_path,
      key_dtype=dtypes.string,
      key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
      value_dtype=dtypes.int64,
      value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
  root = util.Checkpoint(
      table=lookup_ops.HashTable(initializer, default_value=-1))
  root.table_user = def_function.function(
      root.table.lookup,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
  self.assertEqual(
      2, self.evaluate(root.table_user(constant_op.constant("gamma"))))
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(root, save_dir)
  file_io.delete_file(self._vocab_path)
  self.assertAllClose(
      {"output_0": [2, 0]},
      _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))
  second_dir = os.path.join(self.get_temp_dir(), "second_dir")
  # Asset paths should track the location the SavedModel is loaded from.
  file_io.rename(save_dir, second_dir)
  self.assertAllClose(
      {"output_0": [2, 1]},
      _import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))
def testRequestNotToCompile(self):
  with self.test_scope():

    def f(x):
      with ops.device('device:CPU:0'):
        y = 2.0 * x
      return x, y

    wholly_compiled_f = def_function.function(f)
    op_by_op_f = function.defun_with_attributes(
        f, attributes={'_XlaCompile': False})

    x = constant_op.constant([0.0, 2.0], name='data')

    # When function is wholly compiled, all outputs will be on the
    # device on which it is run.
    r_x, r_y = wholly_compiled_f(x)
    self.assertAllEqual([0.0, 2.0], r_x)
    self.assertAllEqual([0.0, 4.0], r_y)
    if context.executing_eagerly():
      # backing_device is only available for eager tensors.
      self.assertRegexpMatches(r_x.backing_device, self.device)
      self.assertRegexpMatches(r_y.backing_device, self.device)

    # When function is executed op-by-op, requested devices will be
    # respected.
    r_x, r_y = op_by_op_f(x)
    self.assertAllEqual([0.0, 2.0], r_x)
    self.assertAllEqual([0.0, 4.0], r_y)
    if context.executing_eagerly():
      # backing_device is only available for eager tensors.
      self.assertRegexpMatches(r_x.backing_device, self.device)
      self.assertRegexpMatches(r_y.backing_device, 'device:CPU:0')
def test_structured_output(self):
  # Use fields with non-alphabetical order
  named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

  def func(input1, input2):
    named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
    return [named_tuple, input2, {"x": 0.5}]

  root = tracking.AutoCheckpointable()
  root.f = def_function.function(func)

  result = root.f(constant_op.constant(2), constant_op.constant(3))
  self.assertEqual(5, result[0].a.numpy())
  self.assertEqual(6, result[0].b.numpy())
  self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
  self.assertEqual(3, result[1].numpy())
  self.assertEqual(0.5, result[2]["x"].numpy())

  imported = self.cycle(root)

  result = imported.f(constant_op.constant(2), constant_op.constant(5))
  self.assertEqual(7, result[0].a.numpy())
  self.assertEqual(10, result[0].b.numpy())
  self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
  self.assertEqual(5, result[1].numpy())
  self.assertEqual(0.5, result[2]["x"].numpy())
def test_structured_inputs(self):

  def func(x, training=True):
    # x is a nested structure, we care about one particular tensor.
    _, (a, b) = x
    if training:
      return 2 * a["a"] + b
    else:
      return 7

  root = tracking.AutoCheckpointable()
  root.f = def_function.function(func)

  x = constant_op.constant(10)
  y = constant_op.constant(11)

  input1 = [6, ({"a": x}, y)]
  input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
  input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.

  # Note: by only calling f(input1) before serialization, only inputs with
  # matching signature will be valid on the loaded model.
  self.assertEqual(31, root.f(input1).numpy())

  imported = self.cycle(root)

  with self.assertRaisesRegexp(
      AssertionError, "Could not find matching function to call.*"):
    imported.f(input2)

  self.assertEqual(31, imported.f(input1).numpy())
  self.assertEqual(32, imported.f(input3).numpy())
def testConstSavedModel(self):
  """Test a basic model with functions to make sure functions are inlined."""
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.f = def_function.function(lambda x: 2. * x)
  to_save = root.f.get_concrete_function(input_data)

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save(root, save_dir, to_save)
  saved_model = load(save_dir)
  input_func = saved_model.signatures["serving_default"]

  variable_graph_def = input_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(variable_graph_def))
  self.assertTrue(variable_graph_def.library.function)

  output_func = convert_to_constants.convert_variables_to_constants_v2(
      input_func)
  constant_graph_def = output_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(constant_graph_def))
  self.assertFalse(constant_graph_def.library.function)

  # Check value.
  expected_value = root.f(input_data)
  actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                        [input_data.numpy()])
  self.assertEqual(expected_value.numpy(), actual_value)
def testConstructConcreteFunction(self):
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
  func = root.f.get_concrete_function(input_data)

  input_func = convert_to_constants._construct_concrete_function(
      func, func.graph.as_graph_def())

  # Test if model has enough metadata to be frozen afterwards.
  variable_graph_def = input_func.graph.as_graph_def()
  self.assertEqual(2, self._getNumVariables(variable_graph_def))

  output_func = convert_to_constants.convert_variables_to_constants_v2(
      input_func)
  constant_graph_def = output_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(constant_graph_def))
  self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

  # Check value.
  expected_value = root.f(input_data)
  actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                        [input_data.numpy()])
  self.assertEqual(expected_value.numpy(), actual_value)
def testVariableSavedModel(self):
  """Test a basic model with Variables with saving/loading the SavedModel."""
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
  to_save = root.f.get_concrete_function(input_data)

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save(root, save_dir, to_save)
  saved_model = load(save_dir)
  input_func = saved_model.signatures["serving_default"]

  variable_graph_def = input_func.graph.as_graph_def()
  self.assertTrue(self._hasStatefulPartitionedCallOp(variable_graph_def))

  output_func = convert_to_constants.convert_variables_to_constants_v2(
      input_func)
  constant_graph_def = output_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(constant_graph_def))
  self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

  # Check value.
  expected_value = root.f(input_data)
  actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                        [input_data.numpy()])
  self.assertEqual(expected_value.numpy(), actual_value)
def test_functools_partial_new_default(self):

  def f(x=3, y=7):
    return x + y

  func = def_function.function(functools.partial(f, y=6))
  self.assertEqual(func().numpy(), 9)
  self.assertEqual(func(y=8).numpy(), 11)
def test_functools_partial_single_positional(self):

  def f(x, y):
    return x + y

  func = def_function.function(
      functools.partial(f, constant_op.constant(1)))
  self.assertAllEqual(func(5), 6)
def test_callable(self):

  class M1(tracking.AutoCheckpointable):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def __call__(self, x):
      return x

  root = tracking.AutoCheckpointable()
  root.m1 = M1()
  root.m2 = tracking.AutoCheckpointable()
  root.m2.__call__ = def_function.function(
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
          lambda x: x*3.0)
  imported = self.cycle(root)
  x = constant_op.constant(1.0)

  self.assertTrue(callable(imported.m1))
  self.assertAllEqual(root.m1(x), imported.m1(x))

  # Note: `root.m2` was not callable since `__call__` attribute was set
  # into the instance and not on the class. But after a serialization cycle
  # that starts to work.
  self.assertTrue(callable(imported.m2))
  self.assertAllEqual(root.m2.__call__(x), imported.m2(x))

  # Verify that user objects without `__call__` attribute are not callable.
  self.assertFalse(callable(imported))
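# Added background note (not part of the original test): Python resolves
# special methods such as __call__ on the type, not the instance, which is why
# assigning `root.m2.__call__` as an instance attribute does not make
# `root.m2` itself callable. After the save/load cycle, the restored object is
# rebuilt with a class-level __call__, so `callable(imported.m2)` is True.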
def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    res = []
    for t in tensors:
      _check_device(t)
      with ops.device(t.device):
        res.append(
            gen_nccl_ops.nccl_all_reduce(
                input=t,
                reduction=reduction,
                num_devices=len(tensors),
                shared_name=shared_name))
    return res

  if context.executing_eagerly():
    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
  else:
    return _all_reduce()
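# A minimal usage sketch for the helper above (an assumption, not part of the
# original file): the public wrappers in nccl_ops, such as all_sum, are thin
# shims over _apply_all_reduce. Requires at least two GPU devices.
#
#   from tensorflow.python.framework import ops
#   from tensorflow.python.ops import nccl_ops
#
#   tensors = []
#   for device in ('/gpu:0', '/gpu:1'):
#     with ops.device(device):
#       tensors.append(constant_op.constant([1.0, 2.0]))
#   # Returns one tensor per device, each holding the elementwise sum
#   # [2.0, 4.0].
#   summed = nccl_ops.all_sum(tensors)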
def testDecorate(self):
  func = def_function.function(lambda: 1)

  def decorator(f):
    return lambda: 1 + f()

  func._decorate(decorator)
  self.assertEqual(func().numpy(), 2)
def test_functools_partial_keywords(self):

  def f(x, y):
    return x + y

  func = def_function.function(
      functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
  self.assertAllEqual(func(), [0.0])
def test_single_function_default_signature(self):
  model = tracking.AutoCheckpointable()
  model.f = def_function.function(lambda: 3., input_signature=())
  model.f()
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(model, save_dir)
  self.assertAllClose({"output_0": 3.},
                      _import_and_infer(save_dir, {}))
def test_non_concrete_error(self):
  root = tracking.AutoTrackable()
  root.f = def_function.function(lambda x: 2. * x)
  root.f(constant_op.constant(1.))
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(
      ValueError, "Expected a TensorFlow function"):
    save.save(root, save_dir, root.f)
def test_non_concrete_error(self):
  root = tracking.AutoCheckpointable()
  root.f = def_function.function(lambda x: 2. * x)
  root.f(constant_op.constant(1.))
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(
      ValueError, "must be converted to concrete functions"):
    save.save(root, save_dir, root.f)
def testGradient(self):
  matmul = def_function.function(math_ops.matmul)

  def sq(x):
    return matmul(x, x, transpose_a=True)

  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  grad_t, = backprop.gradients_function(sq, [0])(t)
  self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
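# Added worked check for the expected gradient above: gradients_function
# differentiates sum(sq(x)) = sum(x^T x), and d sum(x^T x) / dx[a, j] equals
# 2 * sum_j x[a, j] for every column j, i.e. twice the row sum broadcast
# across the row. For [[1., 2.], [3., 4.]] the row sums are 3 and 7, giving
# [[6, 6], [14, 14]].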
def testFunctionReferenceCycles(self):
  fn = def_function.function(lambda x: 2. * x)
  fn(constant_op.constant(4.0))
  weak_fn = weakref.ref(fn)
  del fn
  # Tests that the weak reference we made to the function is now dead, which
  # means the object has been deleted. This should be true as long as the
  # function itself is not involved in a reference cycle.
  self.assertIs(None, weak_fn())
def testTypeInvalid(self):
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
  with self.assertRaises(ValueError) as error:
    _ = lite.TFLiteConverterV2.from_concrete_function(root.f)
  self.assertIn('call from_concrete_function', str(error.exception))
def test_ambiguous_signatures(self):
  model = _ModelWithOptimizer()
  x = constant_op.constant([[3., 4.]])
  y = constant_op.constant([2.])
  model.call(x, y)
  model.second_function = def_function.function(lambda: 1.)
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(ValueError, "call.*second_function"):
    save.save(model, save_dir)
def test_nested_outputs(self):
  root = tracking.AutoCheckpointable()
  root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
  root.f(constant_op.constant(1.))
  to_save = root.f.get_concrete_function(constant_op.constant(1.))
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(
      ValueError, "non-flat outputs"):
    save.save(root, save_dir, to_save)
def testGradient(self):
  # TODO(b/121134877): Remove the autograph override.
  matmul = def_function.function(math_ops.matmul, autograph=False)

  def sq(x):
    return matmul(x, x, transpose_a=True)

  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  grad_t, = backprop.gradients_function(sq, [0])(t)
  self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
def test_capture_variables(self):
  root = tracking.AutoCheckpointable()
  root.weights = variables.Variable(2.)
  root.f = def_function.function(
      lambda x: root.weights * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  imported = self.cycle(root)
  self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
  imported.weights.assign(4.0)
  self.assertEqual(8., imported.f(constant_op.constant(2.)).numpy())
def test_captured_constant(self, cycles):
  const = array_ops.zeros([100])
  root = tracking.AutoCheckpointable()
  root.f = def_function.function(lambda: const + 1.)
  root.g = def_function.function(lambda: const + 2.)
  self.assertAllClose(array_ops.ones([100]), root.f())
  self.assertAllClose(2. * array_ops.ones([100]), root.g())
  imported = self.cycle(root, cycles)
  self.assertAllClose(array_ops.ones([100]), imported.f())
  self.assertAllClose(2. * array_ops.ones([100]), imported.g())
  # TODO(b/123408994): Use the public get_concrete_function.
  f_concrete = imported.f._list_all_concrete_functions_for_serialization()[0]
  g_concrete = imported.g._list_all_concrete_functions_for_serialization()[0]
  self.assertLen(f_concrete.captured_inputs, 1)
  self.assertLen(g_concrete.captured_inputs, 1)
  # We should be using the same captured EagerTensor in both functions, not
  # duplicating the constant.
  self.assertIs(f_concrete.captured_inputs[0],
                g_concrete.captured_inputs[0])
def test_nested_inputs(self):
  root = tracking.Checkpointable()
  root.f = def_function.function(lambda x: 2. * x[0])
  root.f([constant_op.constant(1.)])
  to_export = root.f.get_concrete_function(
      [constant_op.constant(1.), constant_op.constant(2.)])
  export_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(
      ValueError, "non-unique argument names"):
    export.export(root, export_dir, to_export)
def test_nested_dict_outputs(self):
  root = util.Checkpoint(
      f=def_function.function(
          lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)}))
  root.f(constant_op.constant(1.))
  to_save = root.f.get_concrete_function(constant_op.constant(1.))
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(
      ValueError, "dictionary containing non-Tensor value"):
    save.save(root, save_dir, to_save)
def test_nested_dict_outputs(self):
  root = tracking.Checkpointable()
  root.f = def_function.function(
      lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
  root.f(constant_op.constant(1.))
  to_export = root.f.get_concrete_function(constant_op.constant(1.))
  export_dir = os.path.join(self.get_temp_dir(), "saved_model")
  with self.assertRaisesRegexp(
      ValueError, "dictionary containing non-Tensor value"):
    export.export(root, export_dir, to_export)
def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
  """Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
  if _get_num_devices_per_worker(strategy) > 1:
    self.skipTest('b/167331966')

  def value_fn(ctx):
    return constant_op.constant(
        1, shape=(1, ctx.replica_id_in_sync_group + 1))

  per_replica_value = strategy.experimental_distribute_values_from_function(
      value_fn)

  def run(value):
    value_identity = array_ops.identity(value)
    ctx = ds_context.get_replica_context()
    return ctx._all_gather(value_identity, axis=0)

  if not pure_eager:
    run = def_function.function(run)

  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              r'Shape mismatch'):
    strategy.run(run, args=(per_replica_value,))
def testSimpleReduce(self, strategy):

  def fn_eager():

    def replica_fn():
      return array_ops.ones((), dtypes.float32)

    per_replica_value = strategy.run(replica_fn)
    return strategy.reduce(
        reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None)

  fn_graph = def_function.function(fn_eager)

  # Run reduce under the strategy scope to explicitly enter
  # strategy default_device scope.
  with strategy.scope():
    self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
    self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

  # Run reduce without a strategy scope to implicitly enter
  # strategy default_device scope.
  self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
  self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)
def testFloat(self):
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
  concrete_func = root.f.get_concrete_function(input_data)

  # Convert model.
  converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
  converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS])
  tflite_model = converter.convert()

  # Ensures the model contains TensorFlow ops.
  # TODO(nupurgarg): Check values once there is a Python delegate interface.
  interpreter = Interpreter(model_content=tflite_model)
  with self.assertRaises(RuntimeError) as error:
    interpreter.allocate_tensors()
  self.assertIn(
      'Regular TensorFlow ops are not supported by this interpreter. Make '
      'sure you invoke the Flex delegate before inference.',
      str(error.exception))
def _all_gather_same_shape_and_verify(self, value_on_replica, axis,
                                      pure_eager, strategy):
  per_replica_value = strategy.experimental_distribute_values_from_function(
      lambda _: array_ops.identity(value_on_replica))

  def replica_fn(per_replica_value):
    ctx = ds_context.get_replica_context()
    local_value = array_ops.identity(per_replica_value)
    return ctx._all_gather(local_value, axis=axis)

  if not pure_eager:
    replica_fn = def_function.function(replica_fn)

  result = strategy.experimental_local_results(
      strategy.run(replica_fn, args=(per_replica_value,)))

  all_value = [
      value_on_replica for _ in range(strategy.num_replicas_in_sync)
  ]
  expect = array_ops.concat(all_value, axis=axis)
  expected_result = [expect] * _get_num_replicas_per_client(strategy)

  self.assertAllClose(expected_result, result)
def testConstSavedModel(self):
  """Test a basic model with functions to make sure functions are inlined."""
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.f = def_function.function(lambda x: 2. * x)
  to_save = root.f.get_concrete_function(input_data)

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save(root, save_dir, to_save)
  saved_model = load(save_dir)
  input_func = saved_model.signatures["serving_default"]

  variable_graph_def = input_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(variable_graph_def))
  self.assertTrue(variable_graph_def.library.function)

  output_func = convert_to_constants.convert_variables_to_constants_v2(
      input_func)
  constant_graph_def = output_func.graph.as_graph_def()
  self.assertEqual(0, self._getNumVariables(constant_graph_def))
  self.assertFalse(constant_graph_def.library.function)

  self._testConvertedFunction(root, root.f, output_func, input_data)
def test_positional_arguments(self, cycles):

  def func(x, training=False, abc=7.1, defg=7.7):
    del abc
    if training:
      return 2 * x
    if defg == 7:
      return 6
    else:
      return 7

  root = tracking.AutoTrackable()
  root.f = def_function.function(func)

  self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
  self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
  self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
  self.assertEqual(6, root.f(constant_op.constant(1), defg=7.0).numpy())

  imported = self.cycle(root, cycles)

  self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
  self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
  self.assertEqual(6, imported.f(constant_op.constant(1), defg=7.0).numpy())
def testVariableSavedModel(self):
  """Test a basic model with Variables with saving/loading the SavedModel."""
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
  to_save = root.f.get_concrete_function(input_data)

  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save(root, save_dir, to_save)
  saved_model = load(save_dir)
  concrete_func = saved_model.signatures['serving_default']

  # Convert model and ensure model is not None.
  converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
  tflite_model = converter.convert()

  # Check values from converted model.
  expected_value = root.f(input_data)
  actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
  self.assertEqual(expected_value.numpy(), actual_value)
def replica_fn():
  collective, devices, pid = self.make_collective(options.num_processes,
                                                  options.gpus_per_process,
                                                  options.communication)

  def reduce_fn():
    value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
    per_replica_value = make_per_replica_value(value_fn, devices)
    reduced_values = collective.reduce(options.reduce_op, per_replica_value,
                                       per_replica_value)
    reduced_values = self.as_list(reduced_values)
    self.assertAllEqual(devices, [v.device for v in reduced_values])
    return [ops.convert_to_tensor(v) for v in reduced_values]

  per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

  if "eager" in options.mode:
    got = reduce_fn()
    self.assertAllClose(got, per_replica_expect)

  if "func_graph" in options.mode:
    got = def_function.function(reduce_fn)()
    self.assertAllClose(got, per_replica_expect)
def test_subclassed_no_signature(self):

  class Subclassed(training.Model):

    def call(self, inputs):
      return inputs * 2.

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  model = Subclassed()
  with self.assertRaisesRegexp(
      ValueError, "no @tf.function-decorated methods"):
    save.save(model, save_dir)

  traced_call = def_function.function(
      model.call,
      input_signature=(tensor_spec.TensorSpec((None, None),
                                              dtype=dtypes.float32),))
  save.save(model, save_dir, traced_call)
  self.assertAllClose({"output_0": [[8., 10.], [10., 12.]]},
                      _import_and_infer(
                          save_dir, {"inputs": [[4., 5.], [5., 6.]]}))
def test_load_in_func_graph(self, cycles):
  root = tracking.AutoCheckpointable()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(
      lambda x: root.v2 * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  if cycles > 1:
    root = self.cycle(root, cycles - 1)
  path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, path)

  closure = tracking.AutoCheckpointable()

  @def_function.function
  def func(x):
    if not hasattr(closure, "model"):
      closure.model = load.load(path)
    return closure.model.f(x)

  inputs = constant_op.constant(2.)
  self.assertEqual(4.0, func(inputs).numpy())
def testKerasLSTM(self):
  input_data = constant_op.constant(
      np.array(np.random.random_sample((10, 10, 10)), dtype=np.float32))

  model = keras.models.Sequential(
      [keras.layers.LSTM(units=10, input_shape=(10, 10))])

  run_model = def_function.function(model.__call__)
  concrete_func = run_model.get_concrete_function(
      tensor_spec.TensorSpec((10, 10, 10), dtype=dtypes.float32))

  # Convert model.
  converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
  converter.experimental_enable_mlir_converter = True
  tflite_model = converter.convert()

  # Check values from converted model.
  expected_value = concrete_func(input_data)
  # TODO(haoliang): Remove this `reshape` op since it's not necessary.
  actual_value = np.reshape(
      self._evaluateTFLiteModel(tflite_model, [input_data]), (10, 10))
  for expected, actual in zip(expected_value, actual_value):
    np.testing.assert_almost_equal(expected, actual)
def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
  """Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error."""
  if _get_num_devices_per_worker(strategy) > 1:
    self.skipTest('b/167331966')

  def value_fn(ctx):
    return constant_op.constant(
        1, shape=(1, ctx.replica_id_in_sync_group + 1))

  distributed_values = strategy.experimental_distribute_values_from_function(
      value_fn)
  axis = 0

  def run():
    return strategy._gather(distributed_values, axis=axis)

  error_message = 'Shape mismatch'

  if not pure_eager:
    run = def_function.function(run)

  with self.assertRaisesRegex(errors.InvalidArgumentError, error_message):
    run()
def test_saved_model(self):
  with self.device:
    different_values = self.device.pack(
        [constant_op.constant(-1.), constant_op.constant(3.)])
    m = module.Module()
    m.v = variables.Variable(different_values)
    m.f = def_function.function(lambda: m.v * 2.)
    self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
  saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(m, saved_model_path)

  context._reset_context()
  self.setUp()

  single_device_loaded = load.load(saved_model_path)
  self.assertAllClose(-2., single_device_loaded.f())
  with self.device:
    parallel_loaded = load.load(saved_model_path)
    self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
    self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
    parallel_loaded.v.assign(self.device.pack([.1, .2]))
    self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))
def test_complicated_partial_with_defaults(self):

  def identity(*args):
    return args

  def dynamic_unroll(core_fn,
                     input_sequence,
                     initial_state,
                     sequence_length=None,
                     parallel_iterations=1,
                     swap_memory=False):
    del core_fn
    self.assertIs(None, sequence_length)
    self.assertEqual(1, parallel_iterations)
    self.assertTrue(swap_memory)
    return input_sequence, initial_state

  input_sequence = random_ops.random_uniform([1, 1, 1])
  initial_state = random_ops.random_uniform([1, 1])

  func = def_function.function(
      functools.partial(dynamic_unroll, identity, swap_memory=True))
  func(input_sequence, initial_state)
def testDerivative(self):
  with ops.device('device:{}:0'.format(self.device)):

    def fn(x, a):
      return 2 * x + a

    xla_func = def_function.function(fn, jit_compile=True)

    with backprop.GradientTape() as tape:
      inputs = constant_op.constant([1., 2., 2., 3., 3.])
      tape.watch(inputs)
      outputs = xla_func(inputs, 1)

    self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))

    # pylint: disable=protected-access
    (forward, backward) = xla_func.get_concrete_function(
        inputs, 1)._delayed_rewrite_functions.forward_backward()

    # Check that the must-compile attribute gets correctly propagated to the
    # created derivatives.
    self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
    self.assertTrue(forward.definition.attr['_XlaMustCompile'])
def test_variables_mismatched_device_assignment(self):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)

  strategy0 = tpu_lib.TPUStrategyV2(resolver)
  self.assertEqual(
      ("/job:localhost/replica:0/task:0/device:TPU:0",
       "/job:localhost/replica:0/task:0/device:TPU:1"),
      strategy0.extended.worker_devices)

  with strategy0.scope():
    v = variables.Variable(1.)

  v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

  with self.cached_session():
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(v1_assign_op)
    self.assertAllEqual(
        [1., 42.],
        self.evaluate(strategy0.experimental_local_results(v)))

  # Second strategy has devices reversed relative to the first.
  device_assignment = device_assignment_lib.DeviceAssignment(
      topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
  strategy1 = tpu_lib.TPUStrategyV2(
      resolver, experimental_device_assignment=device_assignment)
  self.assertEqual(
      ("/job:localhost/replica:0/task:0/device:TPU:1",
       "/job:localhost/replica:0/task:0/device:TPU:0"),
      strategy1.extended.worker_devices)

  v_read = strategy1.run(def_function.function(v.read_value))

  with self.cached_session():
    self.assertAllEqual(
        [42., 1.],
        self.evaluate(strategy0.experimental_local_results(v_read)))
def test_composite_tensor(self):
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    self.assertIsInstance(operator, composite_tensor.CompositeTensor)

    flat = nest.flatten(operator, expand_composites=True)
    unflat = nest.pack_sequence_as(operator, flat, expand_composites=True)
    self.assertIsInstance(unflat, type(operator))

    # Input the operator to a `tf.function`.
    x = self.make_x(operator, adjoint=False)
    op_y = def_function.function(lambda op: op.matmul(x))(unflat)
    mat_y = math_ops.matmul(mat, x)

    if not use_placeholder:
      self.assertAllEqual(mat_y.shape, op_y.shape)

    # Test while_loop.
    def body(op):
      return type(op)(**op.parameters),

    op_out, = while_v2.while_loop(
        cond=lambda _: True,
        body=body,
        loop_vars=(operator,),
        maximum_iterations=3)
    loop_y = op_out.matmul(x)

    op_y_, loop_y_, mat_y_ = sess.run([op_y, loop_y, mat_y])
    self.assertAC(op_y_, mat_y_)
    self.assertAC(loop_y_, mat_y_)

    # Ensure that the `TypeSpec` can be encoded.
    nested_structure_coder.encode_structure(operator._type_spec)  # pylint: disable=protected-access
def testClone(self, input_signature, autograph, autograph_options, implements,
              relax_shapes, compile_, override_function):
  original_py_function = lambda x: x

  # Force compile off regardless of the parameterized value; see the note at
  # the end of this test about XLA JIT support.
  compile_ = False
  func = def_function.function(
      func=original_py_function,
      input_signature=input_signature,
      autograph=autograph,
      experimental_implements=implements,
      experimental_autograph_options=autograph_options,
      experimental_relax_shapes=relax_shapes,
      jit_compile=compile_)

  if override_function:
    cloned_py_function = lambda x: x + 1
  else:
    cloned_py_function = original_py_function

  cloned = func._clone(python_function=cloned_py_function)

  self.assertEqual(cloned_py_function, cloned._python_function)
  self.assertEqual(func._name, cloned._name)
  self.assertEqual(input_signature, cloned._input_signature)
  self.assertEqual(autograph, cloned._autograph)
  self.assertEqual(implements, cloned._implements)
  self.assertEqual(autograph_options, cloned._experimental_autograph_options)
  self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
  self.assertEqual(compile_, cloned._jit_compile)

  # This test does not run with XLA JIT support linked in so we can only check
  # the output of the function if compile is disabled.
  if not compile_:
    x = array_ops.zeros([])
    self.assertEqual(
        self.evaluate(cloned(x)), self.evaluate(cloned_py_function(x)))
def testDistVarAsTFFunArg(self, strat, jit_replica_fn):
  """Tests that RNG with dist variables can be used as tf.function's arg."""
  strat_name = type(strat).__name__
  if "CentralStorage" in strat_name:
    self.skipTest(
        "CentralStorageStrategy wraps variable updates in merge_call which "
        "can't be called inside a tf.function that doesn't cover the entire "
        "replica function (the function passed to strategy.run).")
  if "TPU" in strat_name and not jit_replica_fn:
    self.skipTest(
        "TPUStrategy requires the replica function (the function passed to "
        "strategy.run) to be decorated with tf.function")
  coord = None
  if "ParameterServer" in strat_name:
    coord = coordinator_lib.ClusterCoordinator(strat)
  shape = [3, 4]
  dtype = dtypes.int32
  with strat.scope():
    gen = rng.Generator.from_seed(1234)

  @def_function.function
  def f(gen):  # the main focus
    t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
    t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
    t = array_ops.stack([t1, t2])
    return t

  def g():
    return f(gen)

  replica_fn = def_function.function(g) if jit_replica_fn else g

  for _ in range(2):
    results = run_on_strategy(replica_fn, strat, coord)
    values = strat.experimental_local_results(results)
    n = get_num_local_replicas(strat, values)
    self.assertAllEqual(n, len(values))
    self.assertAllDifferent(values)
def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
  """Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error."""
  if _is_tpu_strategy(strategy):
    self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

  if (isinstance(strategy, CollectiveAllReduceStrategy) and
      _get_num_replicas_per_client(strategy) > 1):
    self.skipTest('b/167331966')

  if strategy.num_replicas_in_sync <= 1:
    self.skipTest('Test for more than 1 replica only.')

  def value_fn(ctx):
    return constant_op.constant(
        1, shape=(1, ctx.replica_id_in_sync_group + 1))

  per_replica_value = strategy.experimental_distribute_values_from_function(
      value_fn)

  def run(value):
    value_identity = array_ops.identity(value)
    ctx = ds_context.get_replica_context()
    return ctx.all_gather(value_identity, axis=0)

  if not pure_eager:
    run = def_function.function(run)

  if isinstance(strategy, CollectiveAllReduceStrategy):
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r'Shape mismatch'):
      strategy.run(run, args=(per_replica_value,))
  elif isinstance(strategy,
                  (mirrored_strategy.MirroredStrategy,
                   central_storage_strategy.CentralStorageStrategy)):
    with self.assertRaisesRegex(
        (errors.InvalidArgumentError, ValueError),
        r'Dimension \d in both shapes must be equal'):
      strategy.run(run, args=(per_replica_value,))
def replica_fn():
  CollectiveReplicaLauncher._prefer_unique_instance_key = (
      options.prefer_unique_instance_key)
  collective, devices, pid = self.make_collective(options.num_processes,
                                                  options.gpus_per_process)

  def batch_reduce_fn():
    batch_size = len(inputs[0])
    value_dst_pairs = []
    for i in range(batch_size):

      def value_fn(device_idx, idx=i):
        return inputs[pid * len(devices) + device_idx][idx]

      per_replica_value = make_per_replica_value(value_fn, devices)
      value_dst_pairs.append((per_replica_value, per_replica_value))
    reduced_values = collective.batch_reduce(options.reduce_op,
                                             value_dst_pairs,
                                             options.communication_options)
    reduced_values = [self.as_list(v) for v in reduced_values]
    for v in reduced_values:
      self.assertAllEqual(devices, [t.device for t in v])
    return nest.map_structure(ops.convert_to_tensor, reduced_values)

  per_replica_expect = nest.map_structure(
      lambda x: [ops.convert_to_tensor(x)] * len(devices), expect)

  if "eager" in options.mode:
    got = batch_reduce_fn()
    self.assertAllClose(got, per_replica_expect)

  if "func_graph" in options.mode:
    got = def_function.function(batch_reduce_fn)()
    self.assertAllClose(got, per_replica_expect)
def test_save_invalid_custom_gradients(self):
  # The full custom gradients test is in load_test.py. This test just makes
  # sure that a warning is logged when the user has disabled custom gradients
  # and saves a model with invalid custom gradients.

  @custom_gradient.custom_gradient
  def invalid(x):

    def grad(dy):
      raise NotImplementedError

    return x, grad

  root = tracking.AutoTrackable()
  root.f = def_function.function(
      invalid,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  options = save_options.SaveOptions(experimental_custom_gradients=None)

  with self.assertLogs(level="WARNING") as logs:
    save.save(root, save_dir, root.f, options=options)
  self.assertIn(
      "Your model contains untraceable custom gradients. This "
      "warning may become an error in the future. Please set the option "
      "tf.saved_model.SaveOption(experimental_custom_gradients=True) to get "
      "the full error message.", "".join(logs.output))
def testInputSignatureMissingTensorSpecsLambdaFunction(self):
  tf_func_dec = def_function.function(
      input_signature=(tensor_spec.TensorSpec([], dtypes.int32),))
  error_msg = 'TensorSpecs are still required.*arg2.*arg3'
  with self.assertRaisesRegex(TypeError, error_msg):
    tf_func_dec(lambda arg1, arg2, arg3: None)(1, 2, 3)

  with self.assertRaisesRegex(TypeError, error_msg):
    tf_func_dec(lambda arg1, arg2, arg3, **kwargs: None)(1, 2, 3)

  with self.assertRaisesRegex(TypeError, error_msg):
    tf_func_dec(lambda arg1, arg2, arg3, arg4=4, **kwargs: None)(1, 2, 3)

  with self.assertRaisesRegex(TypeError, error_msg):
    tf_func_dec(lambda arg1, arg2, arg3, *args: None)(1, 2, 3)

  with self.assertRaisesRegex(TypeError, error_msg):
    tf_func_dec(lambda arg1, arg2, arg3, *args, **kwargs: None)(1, 2, 3)

  # arg1 is covered by the single TensorSpec; arg4 keeps its default of 4,
  # so the call succeeds with 1 + 4 = 5.
  self.assertEqual(
      tf_func_dec(lambda arg1, arg4=4, **kwargs: arg1 + arg4)(1).numpy(), 5)
def testIteratedGradientsNestedWithVariable(self):

  def _grad(f):

    def _grad_function():
      with backprop.GradientTape() as tape:
        primal_out = f()
      g, = tape.gradient(primal_out, tape.watched_variables())
      return g

    return _grad_function

  v = variables.Variable(2.)

  @def_function.function
  def _forward():
    return math_ops.cos(v)

  f = _forward
  two = constant_op.constant(2.)

  for expected in _COS_DERIVATIVES:
    self.assertAllClose(expected(two), f())
    self.assertAllClose(expected(two), def_function.function(f)())
    f = _grad(f)
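# For reference, a plausible definition of the _COS_DERIVATIVES helper
# referenced above (an assumption -- it is defined elsewhere in the original
# test file): the successive derivatives of cos that the nested tapes should
# reproduce.
#
#   _COS_DERIVATIVES = [
#       math_ops.cos,
#       lambda x: -math_ops.sin(x),
#       lambda x: -math_ops.cos(x),
#       math_ops.sin,
#       math_ops.cos,
#   ]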
def testVariableModel(self):
  """Test a basic model with Variables."""
  input_data = constant_op.constant(1., shape=[1])
  root = tracking.AutoTrackable()
  root.v1 = variables.Variable(3.)
  root.v2 = variables.Variable(2.)
  root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
  concrete_func = root.f.get_concrete_function(input_data)

  variable_graph_def = concrete_func.graph.as_graph_def()
  self.assertEqual(2, self._getNumVariables(variable_graph_def))

  constant_graph_def = convert_to_constants.convert_variables_to_constants_v2(
      concrete_func)
  self.assertEqual(0, self._getNumVariables(constant_graph_def))
  self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

  # Check value.
  expected_value = root.f(input_data)
  actual_value = self._evaluateGraphDef(constant_graph_def, concrete_func,
                                        [input_data.numpy()])
  self.assertEqual(expected_value.numpy(), actual_value)
def testCapturingInFunctionWhileExecutingEagerly(self):
  optimizer = gradient_descent.SGD(1.0)

  var_holder = {}

  def step():
    if not var_holder:
      var_holder["var"] = variables.Variable(1.0)
    else:
      var_holder["var"].assign(1.0)

    with backprop.GradientTape() as tape:
      loss = var_holder["var"]**2
    grad = tape.gradient(loss, var_holder["var"])
    optimizer.apply_gradients([(grad, var_holder["var"])])
    return var_holder["var"].read_value()

  compiled_step = def_function.function(step)

  self.assertEqual(float(step()), -1.0)
  self.assertEqual(float(compiled_step()), -1.0)
  # This shouldn't fail; in particular, the learning rate tensor should
  # be an EagerTensor once again, not a graph Tensor.
  self.assertEqual(float(step()), -1.0)
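# Added arithmetic note for the expected value above: the variable is (re)set
# to 1.0 on each step, loss = var**2 yields grad = 2.0 at var = 1.0, and SGD
# with learning rate 1.0 updates the variable to 1.0 - 2.0 = -1.0, both
# eagerly and inside the tf.function.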
def __init__(self, dataset_fn, coordinator):
  """Makes an iterable from datasets created by the given function.

  Args:
    dataset_fn: A function that returns a `Dataset`.
    coordinator: a `ClusterCoordinator` object, used to create dataset
      resources.
  """

  def disallow_variable_creation(next_creator, **kwargs):
    raise ValueError("Creating variables in `dataset_fn` is not allowed.")

  if isinstance(dataset_fn, def_function.Function):
    with variable_scope.variable_creator_scope(disallow_variable_creation):
      dataset_fn = dataset_fn.get_concrete_function()
  elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
    with variable_scope.variable_creator_scope(disallow_variable_creation):
      dataset_fn = def_function.function(dataset_fn).get_concrete_function()
  self._dataset_fn = dataset_fn
  self._coordinator = coordinator
  self._element_spec = None
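# A minimal usage sketch (an assumption, not part of the original file): this
# wrapper backs ClusterCoordinator.create_per_worker_dataset, which is the
# intended entry point; construction of `coordinator` from a
# ParameterServerStrategy is elided here.
#
#   def dataset_fn():
#     return dataset_ops.DatasetV2.from_tensor_slices([1., 2., 3.]).repeat()
#
#   per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
#   per_worker_iterator = iter(per_worker_dataset)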
def restore(self, restored_tensors, restored_shapes):
  del restored_shapes  # Unused.
  restored_tensor_dict = {}
  for n, local_name in enumerate(self._local_names):
    restored_tensor_dict[local_name] = restored_tensors[n]

  def restore_from_tensors():
    restore_fn = self._trackable._restore_from_tensors  # pylint: disable=protected-access
    if (self._call_with_mapped_captures and
        isinstance(restore_fn, core.ConcreteFunction)):
      self._call_with_mapped_captures(restore_fn, [restored_tensor_dict])
    else:
      restore_fn(restored_tensor_dict)

    # In graph mode, this wrapper function is converted into a tf.function,
    # and to ensure that _restore_from_tensors is executed, there must be at
    # least one returned tensor. `_restore_from_tensors` may return zero
    # tensors so create a dummy constant here.
    return constant_op.constant(1)

  if not ops.executing_eagerly_outside_functions():
    restore_from_tensors = def_function.function(restore_from_tensors)
  return restore_from_tensors()
def testAllGatherNest1D0Axis(self, strategy, pure_eager):
  """all_gather(..., axis=0,...) a nest of DistributedValues."""
  single_value = constant_op.constant([1, 2, 3])
  axis = 0

  def run():
    value_identity = array_ops.identity(single_value)
    ctx = ds_context.get_replica_context()
    return ctx._all_gather([value_identity, value_identity], axis=axis)

  if not pure_eager:
    run = def_function.function(run)

  all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
  expect = array_ops.concat(all_value, axis=axis)
  expected_per_replica = [expect] * _get_num_devices_per_worker(strategy)
  result = strategy.run(run)

  for gathered_result in result:
    self.assertAllEqual(
        strategy.experimental_local_results(gathered_result),
        expected_per_replica)
def test_assets_import(self):
  file1 = self._make_asset("contents 1")
  file2 = self._make_asset("contents 2")

  root = tracking.Checkpointable()
  root.f = def_function.function(
      lambda x: 2. * x,
      input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
  root.asset1 = tracking.TrackableAsset(file1)
  root.asset2 = tracking.TrackableAsset(file2)

  save_dir = os.path.join(self.get_temp_dir(), "save_dir")
  save.save(root, save_dir)

  file_io.delete_file(file1)
  file_io.delete_file(file2)
  load_dir = os.path.join(self.get_temp_dir(), "load_dir")
  file_io.rename(save_dir, load_dir)

  imported = load.load(load_dir)
  with open(imported.asset1.asset_path.numpy(), "r") as f:
    self.assertEquals("contents 1", f.read())
  with open(imported.asset2.asset_path.numpy(), "r") as f:
    self.assertEquals("contents 2", f.read())