def _get_new_node_defs(self):
  """Gets new NodeDefs written by the NodeFileWriter.

  The file is parsed as a sequence of records. Each record is an 8-byte
  little-endian unsigned length prefix followed by that many bytes of a
  serialized `NodeDef`. After parsing, the helper asserts that the records
  consumed exactly the bytes that were read.

  Returns:
    A list of new NodeDefs in the file written by NodeDefWriter since the last
    time this method was called. NodeDefs whose op is in the ignored set built
    below are dropped.
  """
  # The set of ignored op types does not depend on the record being parsed,
  # so build it once up front instead of once per record.
  # TODO(b/206047926): Fix or remove _Recv/_HostRecv from the ignored_ops.
  ignored_ops = {'_Recv', '_HostRecv'}
  if context.run_eager_op_as_function_enabled():
    # When running eager op as function is enabled we expect these extra nodes
    # to show up in the list of executed nodes.
    ignored_ops.update(('_Arg', '_Retval', 'NoOp'))

  node_def_bytes = self.node_file.read()
  node_defs = []
  cur_pos = 0
  while cur_pos < len(node_def_bytes):
    # Length prefix: '<Q' is a little-endian unsigned 64-bit integer.
    size_bytes = node_def_bytes[cur_pos:cur_pos + 8]
    (size,) = struct.unpack('<Q', size_bytes)
    cur_pos += 8
    node_def = node_def_pb2.NodeDef()
    node_def.ParseFromString(node_def_bytes[cur_pos:cur_pos + size])
    if node_def.op not in ignored_ops:
      node_defs.append(node_def)
    cur_pos += size
  # A truncated final record would push cur_pos past the end of the buffer.
  self.assertEqual(cur_pos, len(node_def_bytes))
  return node_defs
def testNumericHigherOrder(self):
  """Runs `_test_gradients` at order=3 on a reduce_prod of sin(x)*tan(x)."""

  def f(x):
    pointwise = math_ops.sin(x) * math_ops.tan(x)
    return math_ops.reduce_prod(
        pointwise + math_ops.reduce_sum(pointwise), axis=1)

  # Autoclustering kicks in when eager_op_as_function is enabled, and under
  # XLA the symbolic tolerances are less than under TF (Ref: b/202559426).
  # Only that combination passes explicit tolerances; otherwise the
  # `_test_gradients` defaults are used.
  tolerance_kwargs = {}
  if (context.run_eager_op_as_function_enabled() and
      test_util.is_xla_enabled()):
    tolerance_kwargs = dict(srtol=1e-6, satol=1e-3)

  _test_gradients(
      self,
      f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
      order=3,
      **tolerance_kwargs)
def _get_benchmark_name(self):
  """Copied from benchmarks_test.py.

  Walks the call stack from the outermost frame inward and takes the function
  name of the first frame whose local `self` is a `test.Benchmark`. A frame
  named "decorated" does not end the search, so a deeper Benchmark frame can
  supply the real name.

  Returns:
    The method name, with "_tfrt" appended when TFRT is enabled and
    "_eager_op_as_function" appended when eager op as function is enabled.

  Raises:
    ValueError: If no frame on the stack belongs to a `test.Benchmark`.
  """
  name = None
  for frame_info in reversed(tf_inspect.stack()):
    caller_self = frame_info[0].f_locals.get("self", None)
    if not isinstance(caller_self, test.Benchmark):
      continue
    name = frame_info[3]  # Get the method name
    # This is a hack to get around the fact that some methods might have a
    # disable_tfrt decorator around them. In that case a function called
    # 'decorated' wraps the real called function underneath, so keep walking
    # one deeper into the stack to find the real name.
    if name != "decorated":
      break

  if name is None:
    raise ValueError("Unable to determine calling Benchmark function.")

  if context.is_tfrt_enabled():
    name += "_tfrt"
  if context.run_eager_op_as_function_enabled():
    name += "_eager_op_as_function"
  return name
def use_anonymous_iterator_v3():
  """Returns whether the anonymous iterator v3 path should be used.

  The path is used when the forward-compatibility window has reached
  2022-01-06. Otherwise the result of
  `context.run_eager_op_as_function_enabled()` is returned.
  """
  compatible = forward_compat.forward_compatible(2022, 1, 6)
  if compatible:
    return compatible
  # Only consulted when the forward-compatibility check is falsy, matching the
  # short-circuit of `a or b`.
  return context.run_eager_op_as_function_enabled()