def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func.

  Traces `func` inside a temporary `CapturingGraph` while in graph mode,
  records the tensors captured from outside that graph, and packages the
  result as a `GraphModeFunction`.

  Args:
    name: Passed to `_inference_name` to derive the generated function name.
    func: The callable to trace. It is called as `func(*func_inputs, **kwds)`.
    args: Positional arguments; converted to graph inputs by
      `_get_defun_inputs`.
    kwds: Keyword arguments forwarded unchanged to `func`.

  Returns:
    A `GraphModeFunction` built from the traced graph.
  """
  graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # `captures` is shared by the temporary graph and `capture_tensors` below;
    # its values are later unzipped into (extra_inputs, extra_placeholders).
    captures = {}
    tmp_graph = CapturingGraph(captures)
    # Inherit the graph key, since this is used for matching variables in
    # optimizers.
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      with capture_tensors(captures):
        # A fresh tape is pushed around the call so that the variables the
        # function touches can be read back afterwards.
        this_tape = tape.push_new_tape()
        try:
          func_outputs = func(*func_inputs, **kwds)
        finally:
          # Always pop the tape, even if tracing `func` raised.
          tape.pop_tape(this_tape)
        variables = this_tape.watched_variables()

        # Returning a closed-over tensor as an output does not trigger a
        # call to convert_to_tensor, so we manually capture all such tensors.
        outputs_list = _flatten(func_outputs)
        func_def_outputs = [
            _convert_to_graph_tensor(x) for x in outputs_list if x is not None
        ]

      # Sort the capture keys so the extra inputs have a deterministic order.
      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      # Non-tensor outputs (including None) get a None shape entry.
      output_shapes = tuple(
          x.shape if isinstance(x, ops.Tensor) else None
          for x in outputs_list)

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, ops.Tensor)]
  all_inputs = flat_inputs + list(extra_placeholders)
  # The ops producing the inputs are excluded from the function's operations.
  all_ignored_ops = frozenset(x.op for x in all_inputs)
  fname = _inference_name(name)
  operations = tuple(x for x in tmp_graph.get_operations()
                     if x not in all_ignored_ops)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  if context.in_eager_mode():
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func)  # pylint: disable=protected-access
  return GraphModeFunction(
      fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
      func_outputs, output_shapes, variables)
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow user to change the seed globally for a
  graph, or for only specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  `set_random_seed`.

  Args:
    op_seed: integer, or `None` when the op has no specific seed.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation, or `(None, None)` when neither the graph seed nor `op_seed` is
    set. The pair `(0, 0)` is never returned; `(0, 2**31 - 1)` is returned in
    its place.
  """
  graph = ops.get_default_graph()
  graph_seed = graph.seed
  if graph_seed is not None:
    if op_seed is None:
      # No op-level seed: derive one from the graph's op counter.
      # pylint: disable=protected-access
      op_seed = graph._last_id
      # pylint: enable=protected-access
    seeds = _truncate_seed(graph_seed), _truncate_seed(op_seed)
  elif op_seed is not None:
    seeds = _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
  else:
    seeds = None, None
  # Avoid (0, 0): the C++ random ops interpret that pair as a request for
  # nondeterminism, whereas the Python contract is that only (None, None)
  # means nondeterminism. Substitute the largest int32 for the second seed.
  if seeds == (0, 0):
    return (0, 2**31 - 1)
  return seeds
def _WeMustGoDeeper(self, msg):
  """Raises an `UnauthenticatedError` for an op that has an original op.

  The error is raised inside `assertRaisesOpError(msg)`, so the helper passes
  only when that check accepts the raised error for `msg`.

  Args:
    msg: Expected message handed to `assertRaisesOpError`.
  """
  with self.assertRaisesOpError(msg):
    outer_def = ops._NodeDef("op_type", "name")
    inner_def = ops._NodeDef("op_type_orig", "orig")
    inner_op = ops.Operation(inner_def, ops.get_default_graph())
    outer_op = ops.Operation(
        outer_def, ops.get_default_graph(), original_op=inner_op)
    raise errors.UnauthenticatedError(outer_def, outer_op, "true_err")
def testParallelApplyGradMean(self):
  """Applies ten sparse gradients from parallel threads and checks the mean."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    accumulator = data_flow_ops.SparseConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
    elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]

    # One apply op per element; the element sits on the diagonal.
    accum_ops = [
        accumulator.apply_indexed_slices_grad(
            _indexedslice(
                np.array([[value, 0], [0, value]]).astype(np.float32)),
            local_step=0)
        for value in elems
    ]
    take_op = accumulator.take_indexed_slices_grad(1)

    def run_apply(apply_op):
      self.evaluate(apply_op)

    workers = [
        self.checkedThread(target=run_apply, args=(apply_op,))
        for apply_op in accum_ops
    ]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    taken = self.evaluate(take_op)

    mean = sum(elems) / len(elems)
    expected = np.array([[mean, 0], [0, mean]]).astype(np.float32)
    self._assertEqual_nparray(expected, taken, sess)
def wrapped_body(loop_counter, *args):
  """Runs the user loop body and appends an incremented loop counter.

  Args:
    loop_counter: Loop counter; returned incremented by one.
    *args: Flat loop variables for this iteration.

  Returns:
    A list whose first element is `loop_counter + 1`, followed by the body
    outputs after `_tensor_array_to_flow` has been applied to them.
  """
  # Re-capture, in order, everything cond_graph already captured so that the
  # tensors show up in the same order in body_graph.external_captures.
  for captured in cond_graph.external_captures:
    ops.get_default_graph().capture(captured)

  # There is no nest.zip, so `_pack_sequence_as` is used: it flattens both
  # `orig_loop_vars` and `args`, turns the flows in `args` into TensorArrays
  # and packs them into the structure of `orig_loop_vars`.
  packed_args = _pack_sequence_as(orig_loop_vars, args)
  body_result = body(*packed_args)
  if not nest.is_sequence(body_result):
    body_result = [body_result]

  # Top-level tuples are converted to lists before the structure comparison
  # to stay compatible with the legacy while_loop.
  nest.assert_same_structure(list(body_result), list(orig_loop_vars))

  flows = _tensor_array_to_flow(body_result)

  # TODO(srbs): Update lowering code to create _Enter nodes with
  # is_constant=True for inputs that are directly passed to outputs.
  return [loop_counter + 1] + list(flows)
def testParallelAssignWithoutLocking(self):
  """Runs twenty assigns from parallel threads and bounds every element."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    ones_t = array_ops.fill([1024, 1024], float(1))
    target = variables.Variable(array_ops.zeros([1024, 1024]))
    assign_ops = []
    for factor in range(1, 21):
      scaled = math_ops.multiply(ones_t, float(factor))
      assign_ops.append(state_ops.assign(target, scaled, False))
    self.evaluate(variables.global_variables_initializer())

    def run_one(op):
      self.evaluate(op)

    workers = [
        self.checkedThread(target=run_one, args=(op,)) for op in assign_ops
    ]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    final = self.evaluate(target)

    # Every element must come from one of the assignments (values 1..20).
    self.assertTrue((final > 0).all())
    self.assertTrue((final <= 20).all())
def before_run(self, run_context): """ Dumps graphs and loads checkpoint if there exits. Called before each call to run(). Args: run_context: A `SessionRunContext` object. Returns: A `SessionRunArgs` object containing global_step. """ # We do write graph and saver_def at the first call of before_run. # We cannot do this in begin, since we let other hooks to change graph and # add variables in begin. Graph is finalized after all begin calls. if self._is_chief and self._first_call: training_util.write_graph( ops.get_default_graph().as_graph_def(add_shapes=True), self._checkpoint_dir, "graph.pbtxt") # dump model details "model_analysis.txt" dump_model_analysis(self._checkpoint_dir) # dump model configs graph = ops.get_default_graph() meta_graph_def = meta_graph.create_meta_graph_def( graph_def=graph.as_graph_def(add_shapes=True), saver_def=self._saver.saver_def) if self._summary_writer is not None: self._summary_writer.add_graph(graph) self._summary_writer.add_meta_graph(meta_graph_def) tf.logging.info("CheckpointSaverHook (before_run): dump graph...") self._first_call = False return tf.train.SessionRunArgs(self._global_step)
def testParallelAssignWithLocking(self):
  """Runs twenty locked assigns in parallel; the result must be uniform."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    zeros_t = array_ops.fill([1024, 1024], 0.0)
    ones_t = array_ops.fill([1024, 1024], 1.0)
    target = variables.Variable(zeros_t)
    assign_ops = []
    for factor in range(1, 21):
      scaled = math_ops.multiply(ones_t, float(factor))
      assign_ops.append(state_ops.assign(target, scaled, use_locking=True))
    self.evaluate(target.initializer)

    def run_one(op):
      self.evaluate(op)

    workers = [
        self.checkedThread(target=run_one, args=(op,)) for op in assign_ops
    ]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    final = self.evaluate(target)

    # All elements are equal and come from exactly one of the assignments.
    corner = final[0, 0]
    self.assertTrue(corner > 0)
    self.assertTrue(corner <= 20)
    self.assertAllEqual(final, np.ones([1024, 1024]) * corner)
def testParallelUpdateWithoutLocking(self):
  """Runs twenty unlocked assign_adds in parallel and bounds the result."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    ones_t = array_ops.fill([1024, 1024], 1.0)
    target = variables.Variable(array_ops.zeros([1024, 1024]))
    add_ops = []
    for _ in range(20):
      add_ops.append(state_ops.assign_add(target, ones_t, use_locking=False))
    self.evaluate(variables.global_variables_initializer())

    def run_one(op):
      self.evaluate(op)

    workers = [
        self.checkedThread(target=run_one, args=(op,)) for op in add_ops
    ]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    final = self.evaluate(target)
    unit = np.ones((1024, 1024)).astype(np.float32)
    # Without locking updates may be lost, so only bounds are asserted.
    self.assertTrue((final >= unit).all())
    self.assertTrue((final <= unit * 20).all())
def _parse_kwargs_as_attrs(func_name, **kwargs): """Parses **kwargs into a node's attributes.""" attrs = {} noinline = kwargs.pop("noinline", None) if noinline is not None: attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline)) compiled = kwargs.pop("compiled", None) separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None) if compiled is not None: attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled)) attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue( b=bool(separate_compiled_gradients)) # Forward _XlaScope from enclosing context (if set), otherwise create new. # pylint: disable=protected-access if "_XlaScope" in ops.get_default_graph()._attr_scope_map: attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"] else: attrs["_XlaScope"] = attr_value_pb2.AttrValue( s=("function_%s" % func_name).encode()) # pylint: enable=protected-access if kwargs: raise ValueError("Unknown keyword arguments: %s" % kwargs.keys()) return attrs
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Combines the default graph's seed with `op_seed` into the pair of seeds that
  random operations use internally, so that seeding can be controlled either
  for the whole graph or for individual operations.

  For details on how the graph-level seed interacts with op seeds, see
  `tf.set_random_seed`.

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
  graph = ops.get_default_graph()
  graph_seed = graph.seed

  if graph_seed is None:
    if op_seed is None:
      seeds = (None, None)
    else:
      seeds = (DEFAULT_GRAPH_SEED, _truncate_seed(op_seed))
  else:
    if op_seed is None:
      # pylint: disable=protected-access
      op_seed = graph._last_id
      # pylint: enable=protected-access
    seeds = (_truncate_seed(graph_seed), _truncate_seed(op_seed))

  # The C++ ops read (0, 0) as nondeterminism, which would be surprising
  # because the Python docs reserve (None, None) for that, so avoid it.
  return (0, _MAXINT32) if seeds == (0, 0) else seeds
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph is exported from `from_scope` with `export_scoped_meta_graph`
  and imported under `to_scope` with `import_scoped_meta_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`,
      the default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  source_graph = from_graph or ops.get_default_graph()
  target_graph = to_graph or ops.get_default_graph()

  if source_graph == target_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  # Only the exported meta graph is needed; the exported variable list is
  # superseded by the one returned from the import.
  exported_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=source_graph)
  return import_scoped_meta_graph(exported_meta_graph,
                                  graph=target_graph,
                                  import_scope=to_scope)
def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
  """Builds and runs a small graph in a session owned by the calling thread.

  Args:
    constructed_event: Event set once this thread has built its graph.
    continue_event: Event waited on before any op is run.
    i: Index used to give the variable a distinct name.
  """
  fours = [[4.0, 4.0, 4.0]]
  sixes = [[6.0, 6.0, 6.0]]
  with session.Session() as s:
    self.assertEqual(ops.get_default_graph(), s.graph)
    lhs = constant_op.constant(1.0, shape=[1, 2])
    rhs_two = constant_op.constant(2.0, shape=[2, 3])
    product_two = math_ops.matmul(lhs, rhs_two)
    var = variables.Variable(product_two, name='var_%d' % i)

    # Block here until all threads have constructed their graph.
    constructed_event.set()
    continue_event.wait()

    assign_two = state_ops.assign(var, product_two)
    var.initializer.run()
    assign_two.eval()
    self.assertAllEqual(fours, var.eval())

    rhs_three = constant_op.constant(3.0, shape=[2, 3])
    product_three = math_ops.matmul(lhs, rhs_three)
    assign_three = state_ops.assign(var, product_three)
    self.assertAllEqual(sixes, product_three.eval())
    # The variable keeps its old value until the assign op is actually run.
    self.assertAllEqual(fours, var.eval())
    s.run(assign_three)
    self.assertAllEqual(sixes, var.eval())
    self.assertEqual(ops.get_default_graph(), s.graph)
def container(self, container_name):
  """Returns a context manager that specifies the resource container to use.

  Sets the container on this graph and on the default graph seen inside
  `ops.init_scope()`, and restores both previous values on exit, so that the
  setting applies to created variables as well as to stateful ops.

  Args:
    container_name: container name string.

  Returns:
    A context manager for defining resource containers for stateful ops,
    yields the container name.
  """
  # pylint: disable=protected-access
  def _set_init_scope_container(value):
    with ops.init_scope():
      ops.get_default_graph()._container = value

  previous_container = self._container
  with ops.init_scope():
    previous_init_container = ops.get_default_graph()._container
  try:
    self._container = container_name
    _set_init_scope_container(container_name)
    yield self._container
  finally:
    self._container = previous_container
    _set_init_scope_container(previous_init_container)
  # pylint: enable=protected-access
def testIteratorStringHandleReuseTensorObject(self):
  """Checks that default string handles are cached and named ones are new."""
  dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
  one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
  initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
  structure_iterator = iterator_ops.Iterator.from_structure(
      dataset.output_types)

  def _op_count():
    return len(ops.get_default_graph().get_operations())

  ops_before = _op_count()

  for iterator in (one_shot_iterator, initializable_iterator,
                   structure_iterator):
    self.assertIs(iterator.string_handle(), iterator.string_handle())

  # Assert that getting the (default) string handle creates no ops.
  self.assertEqual(ops_before, _op_count())

  # Specifying an explicit name will create a new op.
  named_handle = one_shot_iterator.string_handle(name="foo")
  self.assertEqual("foo", named_handle.op.name)
  self.assertIsNot(one_shot_iterator.string_handle(), named_handle)

  # Asking for the same name again yields yet another, uniquified op.
  renamed_handle = one_shot_iterator.string_handle(name="foo")
  self.assertEqual("foo_1", renamed_handle.op.name)
  self.assertIsNot(named_handle, renamed_handle)
def decorator(self, **kwargs):
  """Finds existing Tensors, runs the test, checks for new Tensors.

  Args:
    **kwargs: Keyword arguments forwarded to the wrapped test `f`.

  Raises:
    AssertionError: If Tensors or Variables that did not exist before the test
      are still alive after it.
  """

  def _is_tensor(candidate):
    try:
      return isinstance(candidate, (ops.Tensor, variables.Variable))
    except ReferenceError:
      # If the object no longer exists, we don't care about it.
      return False

  ids_before = {id(obj) for obj in gc.get_objects() if _is_tensor(obj)}
  outer_prefix = ops.get_default_graph()._container_prefix
  with IsolateTest():
    # The test runs in a new graph so collections are cleared afterwards, but
    # it inherits the container prefix so that values of variables leaked
    # while executing eagerly can still be printed.
    ops.get_default_graph()._container_prefix = outer_prefix
    f(self, **kwargs)

  # Make an effort to clear caches, which would otherwise look like leaked
  # Tensors.
  backprop._last_zero = [None]
  backprop._shape_dtype = [None, None]
  context.get_default_context().scalar_cache().clear()
  gc.collect()

  leaked = [
      obj for obj in gc.get_objects()
      if _is_tensor(obj) and id(obj) not in ids_before
  ]
  if not leaked:
    return
  raise AssertionError(("%d Tensors not deallocated after test: %s" % (
      len(leaked),
      str(leaked),
  )))
def __init__(self,
             master,
             is_chief=True,
             checkpoint_dir=None,
             monitors=None,
             scaffold=None,
             config=None):
  """Finalizes the default graph, creates the session and starts monitors.

  Args:
    master: Stored as `self._master`.
    is_chief: Stored as `self._is_chief`.
    checkpoint_dir: Stored as `self._checkpoint_dir`.
    monitors: Optional list of monitors; `begin(max_steps=None)` is called on
      each. Defaults to an empty list.
    scaffold: Optional `Scaffold`; a default one is created when omitted.
    config: Stored as `self._config`.
  """
  self._graph = ops.get_default_graph()
  self._master = master
  self._checkpoint_dir = checkpoint_dir
  self._is_chief = is_chief
  self._config = config
  self._monitors = monitors if monitors else []
  self._scaffold = scaffold if scaffold else Scaffold()

  # Finalize and write the graph.
  self._graph.finalize()

  # Create the session.
  self._session_manager = sm.SessionManager(
      local_init_op=self._scaffold.local_init_op,
      ready_op=self._scaffold.ready_op,
      graph=ops.get_default_graph())
  self._sess = recoverable_session.RecoverableSession(self._create_session)

  # Read the starting step, then let every monitor know a run is beginning.
  self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
  for active_monitor in self._monitors:
    active_monitor.begin(max_steps=None)

  # Write the graph out, note: this uses self._init_step.
  self.write_graph()
def testRead(self):
  """Reads file 0, file 1 and both files for each batch size / epoch count."""
  for batch_size in [1, 2]:
    for num_epochs in [1, 10]:
      # Each case pairs the filenames to read with the extra positional
      # file-index argument for `_verify_records` (none when reading both).
      cases = (
          (self.test_filenames[0], (0,)),  # Basic test: read from file 0.
          (self.test_filenames[1], (1,)),  # Basic test: read from file 1.
          (self.test_filenames, ()),       # Basic test: read from both files.
      )
      for filenames, file_index in cases:
        with ops.Graph().as_default():
          with self.test_session(graph=ops.get_default_graph()) as sess:
            self.outputs = self._read_batch_features(
                filenames=filenames,
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(
                sess, batch_size, *file_index, num_epochs=num_epochs)
            # Once every record was consumed the next read must fail.
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)
def testGraphWithoutVariables(self):
  """Saves two variable-free graphs under different tags and reloads both."""
  export_dir = self._get_export_dir("test_graph_has_variables")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # Graph with no variables.
  with self.test_session(graph=ops.Graph()) as sess:
    constant_5_name = constant_op.constant(5.0).name
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Second graph with no variables
  with self.test_session(graph=ops.Graph()) as sess:
    constant_6_name = constant_op.constant(6.0).name
    builder.add_meta_graph(["bar"])

  # Save the SavedModel to disk.
  builder.save()

  # For each tag: restore the graph, read the saved constant back by name and
  # multiply it by the other factor; the product is 30 either way.
  restore_cases = (
      ("foo", constant_5_name, 6.0),
      ("bar", constant_6_name, 5.0),
  )
  for tag, saved_name, other_factor in restore_cases:
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, [tag], export_dir)
      saved = ops.get_default_graph().get_tensor_by_name(saved_name)
      other = constant_op.constant(other_factor)
      product = saved * other
      self.assertEqual(30.0, sess.run(product))
def testPartitionConcatenatesAlongCorrectAxis(self):
  """Checks full and per-partition shapes for axis-0 and axis-1 partitioning."""

  def _part_axis_0(**unused_kwargs):
    return (2, 1, 1)

  def _part_axis_1(**unused_kwargs):
    return (1, 2, 1)

  with variable_scope.variable_scope("root"):
    v0 = variable_scope.get_variable(
        "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
    v1 = variable_scope.get_variable(
        "n1", shape=(2, 2, 2), partitioner=_part_axis_1)

  # The partitioned variables still report the full shape.
  for full_var in (v0, v1):
    self.assertEqual(full_var.get_shape(), (2, 2, 2))

  # Each partition is halved along the axis chosen by its partitioner.
  expectations = (
      (("root/n0/part_0:0", "root/n0/part_1:0"), (1, 2, 2)),
      (("root/n1/part_0:0", "root/n1/part_1:0"), (2, 1, 2)),
  )
  for part_names, part_shape in expectations:
    parts = [
        ops.get_default_graph().get_tensor_by_name(part_name)
        for part_name in part_names
    ]
    for part in parts:
      self.assertEqual(part.get_shape(), part_shape)
def finalize(self):
  """Creates operations if needed and finalizes the graph.

  Every op that is still `None` is obtained through
  `Scaffold._get_or_default`; the saver is then built and the default graph
  finalized.

  Returns:
    This scaffold, to allow call chaining.
  """

  def _make_saver():
    return training_saver.Saver(sharded=True, allow_empty=True)

  lookup = Scaffold._get_or_default
  if self._init_op is None:
    self._init_op = lookup(
        'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
  if self._ready_op is None:
    self._ready_op = lookup(
        'ready_op', ops.GraphKeys.READY_OP,
        variables.report_uninitialized_variables)
  if self._local_init_op is None:
    self._local_init_op = lookup(
        'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
        Scaffold._default_local_init_op)
  if self._summary_op is None:
    self._summary_op = lookup(
        'summary_op', ops.GraphKeys.SUMMARY_OP,
        logging_ops.merge_all_summaries)
  if self._saver is None:
    self._saver = lookup('saver', ops.GraphKeys.SAVERS, _make_saver)
  # The saver is built whether it was supplied or just created.
  self._saver.build()

  ops.get_default_graph().finalize()
  return self
def finalize(self):
  """Creates operations if needed and finalizes the graph.

  The global step tensor and every op that is still `None` are filled in
  first; the default graph is finalized afterwards.
  """

  def _make_saver():
    return training_saver.Saver(
        sharded=True, max_to_keep=self._keep_checkpoint_max)

  lookup = Scaffold._get_or_default
  if self._global_step_tensor is None:
    self._global_step_tensor = contrib_variables.get_or_create_global_step()
  if self._init_op is None:
    self._init_op = lookup(
        'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
  if self._ready_op is None:
    self._ready_op = lookup(
        'ready_op', ops.GraphKeys.READY_OP,
        variables.report_uninitialized_variables)
  if self._local_init_op is None:
    self._local_init_op = lookup(
        'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
        Scaffold._default_local_init_op)
  if self._summary_op is None:
    self._summary_op = lookup(
        'summary_op', ops.GraphKeys.SUMMARY_OP,
        logging_ops.merge_all_summaries)
  if self._saver is None:
    self._saver = lookup('saver', ops.GraphKeys.SAVERS, _make_saver)

  ops.get_default_graph().finalize()
def test_assign_stays_in_true_dtype(self, distribute):
  """Checks assign/assign_add keep float32 precision under float16 autocast.

  Args:
    distribute: Passed to `get_distribute_scope` and `get_autocast_var`.
  """
  with get_distribute_scope(distribute):
    var = get_var(1., dtypes.float32)
    var = get_autocast_var(var, distribute)
    self.evaluate(var.initializer)

    # small_val is a value such that 1.0 + small_val == 1.0 in fp16, but not
    # in fp32
    small_val = np.finfo('float16').eps / 2
    small_tensor = constant_op.constant(small_val, dtype=dtypes.float32)
    bumped = 1. + small_val

    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      # The variable must be increased even though the value looks unchanged
      # when read as float16.
      assigned = self.evaluate(var.assign(1. + small_tensor))
      self.assertEqual(bumped, assigned)
      self.assertEqual(1., self.evaluate(var.value()))
    self.assertEqual(bumped, self.evaluate(var.value()))

    # Reset, then repeat the check with assign_add.
    self.evaluate(var.assign(1.))
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      added = self.evaluate(var.assign_add(small_tensor))
      self.assertEqual(bumped, added)
      self.assertEqual(1., self.evaluate(var.value()))
    self.assertEqual(bumped, self.evaluate(var.value()))
def testAccumulatorApplyAndBlockingTake(self):
  """Takes a gradient of three while another thread applies them after 1s."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    accumulator = data_flow_ops.ConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
    elems = [10.0, 20.0, 30.0]
    mean = sum(elems) / len(elems)
    apply_ops = []
    for value in elems:
      apply_ops.append(accumulator.apply_grad((value,), local_step=0))
    take_op = accumulator.take_grad(3)

    def applier():
      time.sleep(1.0)
      for apply_op in apply_ops:
        self.evaluate(apply_op)

    taken = []

    def taker():
      taken.append(self.evaluate(take_op))

    applier_thread = self.checkedThread(target=applier)
    taker_thread = self.checkedThread(target=taker)
    applier_thread.start()
    taker_thread.start()
    applier_thread.join()
    taker_thread.join()

    self.assertEqual([mean], taken)
def test_operator_overloads(self, distribute):
  """Adds an autocast variable to float32/float16 constants on either side.

  Args:
    distribute: Passed to `get_distribute_scope` and `get_autocast_var`; when
      truthy the variable-as-LHS checks are skipped.
  """
  with get_distribute_scope(distribute):
    var = get_var(1., dtypes.float32)
    var = get_autocast_var(var, distribute)
    self.evaluate(var.initializer)

  two_f32 = constant_op.constant(2., dtype=dtypes.float32)
  two_f16 = constant_op.constant(2., dtype=dtypes.float16)

  # Because autocast variables do not yet define operator overloads, the
  # operator is defined by the non-variable tensor

  # Variable as the LHS; not currently supported with distributed autocast
  # variables.
  if not distribute:
    self.assertEqual(self.evaluate(var + two_f32), 3.)
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(self.evaluate(var + two_f16), 3.)

  # Variable as the RHS.
  self.assertEqual(self.evaluate(two_f32 + var), 3.)
  with ops.get_default_graph()._enable_auto_casting_variables(
      dtypes.float16):
    self.assertEqual(self.evaluate(two_f16 + var), 3.)
def testParallelApplyGrad(self):
  """Applies ten gradients from parallel threads and checks their mean."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    accumulator = data_flow_ops.ConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
    elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
    apply_ops = []
    for value in elems:
      apply_ops.append(accumulator.apply_grad((value,), local_step=0))
    take_op = accumulator.take_grad(1)

    def run_apply(apply_op):
      self.evaluate(apply_op)

    workers = [
        self.checkedThread(target=run_apply, args=(apply_op,))
        for apply_op in apply_ops
    ]
    for worker in workers:
      worker.start()
    for worker in workers:
      worker.join()

    taken = self.evaluate(take_op)

    self.assertEqual(taken, sum(elems) / len(elems))
def testParallelTakeGrad(self):
  """Ten takers each receive one of ten gradients applied one second apart."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    accumulator = data_flow_ops.ConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
    elems = list(range(10))
    apply_ops = []
    for step in elems:
      apply_ops.append(
          accumulator.apply_grad((np.float32(step),), local_step=step))
    take_op = accumulator.take_grad(1)

    def applier():
      for apply_op in apply_ops:
        time.sleep(1.0)
        self.evaluate(apply_op)

    applier_thread = self.checkedThread(target=applier)

    taken = []

    def taker():
      taken.append(self.evaluate(take_op))

    taker_threads = [self.checkedThread(target=taker) for _ in range(10)]

    # The takers are started before the applier.
    for worker in taker_threads:
      worker.start()
    applier_thread.start()

    for worker in taker_threads:
      worker.join()
    applier_thread.join()

    self.assertItemsEqual(elems, taken)
def test_read(self, distribute):
  """Checks the dtype seen by every read path inside/outside autocast scopes.

  Args:
    distribute: Passed to `get_distribute_scope` and `get_autocast_var`.
  """
  with get_distribute_scope(distribute):
    var = get_var(1., dtypes.float32)
    var = get_autocast_var(var, distribute)
    self.evaluate(var.initializer)

    def _assert_reads_as(expected_dtype):
      self.assertEqual(var.dtype, expected_dtype)
      self.assertEqual(var.value().dtype, expected_dtype)
      self.assertEqual(var.read_value().dtype, expected_dtype)
      self.assertEqual(array_ops.identity(var).dtype, expected_dtype)

    # outside of auto cast scope.
    _assert_reads_as(dtypes.float32)

    # within auto cast scope of different dtype
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      _assert_reads_as(dtypes.float16)

    # within auto cast scope of same dtype
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float32):
      _assert_reads_as(dtypes.float32)
def testAccumulatorApplyAndBlockingTake(self):
  """Takes a sparse gradient of three while another thread applies them."""
  # Each thread must keep its own device stack, otherwise device scopes are
  # not nested properly.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    accumulator = data_flow_ops.SparseConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
    elems = [10.0, 20.0, 30.0]
    mean = sum(elems) / len(elems)

    # One apply op per element; the element sits at position [0, 1].
    apply_ops = [
        accumulator.apply_indexed_slices_grad(
            _indexedslice(
                np.array([[0, value], [0, 0]]).astype(np.float32)),
            local_step=0)
        for value in elems
    ]
    take_op = accumulator.take_indexed_slices_grad(3)

    taken = []

    def applier():
      for apply_op in apply_ops:
        self.evaluate(apply_op)

    def taker():
      taken.append(self.evaluate(take_op))

    applier_thread = self.checkedThread(target=applier)
    taker_thread = self.checkedThread(target=taker)
    applier_thread.start()
    taker_thread.start()
    applier_thread.join()
    taker_thread.join()

    self._assertEqual_nparray([[0, mean], [0, 0]], taken[0], sess)
def test_graph_replace_gradients(self):
  """Checks graph_replace on a gradient tensor keeps `_original_op` links."""
  ops.reset_default_graph()
  w = variables.VariableV1(0.0, name="w")
  squared = math_ops.multiply(w, w, name="mul1")
  cubed = math_ops.multiply(squared, w, name="mul2")
  grad = gradients_impl.gradients(cubed, w, name="grad")[0]

  # Look up the gradient op of "mul1" before the replacement.
  replacement_ts = {w.value(): grad}
  grad_op_before = ops.get_default_graph().get_operation_by_name(
      "grad/mul1_grad/Mul_1")

  # Should not raise exception.
  replaced = ge.graph_replace(grad, replacement_ts, dst_scope="res")

  # Look up the corresponding op created by graph_replace.
  grad_op_after = ops.get_default_graph().get_operation_by_name(
      "res/grad/mul1_grad/Mul_1")

  # Make sure _original_ops are as expected.
  self.assertEqual(grad_op_before._original_op.name, u"mul1")
  self.assertEqual(grad_op_after._original_op.name, u"res/mul1")
  self.assertNotEqual(replaced.name, grad.name)
  with session.Session() as sess:
    sess.run(variables.global_variables_initializer())
    grad_val, replaced_val = sess.run([grad, replaced])
  self.assertNear(grad_val, 0.0, ERROR_TOLERANCE)
  self.assertNear(replaced_val, 0.0, ERROR_TOLERANCE)
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                        initial_loop_values=None):
  """Runs `fn` for `iterations` steps inside a single `while_loop`.

  Args:
    fn: callable invoked as `fn(ctx, iterator.get_next())` on every step.
    iterator: object whose `get_next()` yields the per-step input.
    iterations: loop bound; the loop runs while the counter is below it.
    initial_loop_values: optional structure of initial loop values; it is
      flattened with `nest.flatten` and appended to the loop variables.

  Returns:
    The `MultiStepContext`, with `run_op` set to a group of the loop results
    and the last step outputs restored to their original dict structure.
  """
  flat_initial_values = nest.flatten(
      {} if initial_loop_values is None else initial_loop_values)

  ctx = input_lib.MultiStepContext()

  def body(i, *args):
    """A wrapper around `fn` to create the while loop body."""
    del args
    step_result = fn(ctx, iterator.get_next())
    # Convert all outputs to tensors, potentially from `DistributedValues`.
    for (name, output) in ctx.last_step_outputs.items():
      ctx.last_step_outputs[name] = self._local_results(output)
    flat_outputs = nest.flatten(ctx.last_step_outputs)
    with ops.control_dependencies([step_result]):
      return [i + 1] + flat_outputs

  def cond(i, *args):
    """Keeps looping until `iterations` steps have run."""
    return i < iterations

  # Remember the control flow context that is active before `fn` runs inside
  # the while_loop. Code that needs to escape the loop (for example to create
  # an op that must be evaluated only once, on the host, after the loop, such
  # as a metric's value op) can get back to this outer context.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  counter = constant_op.constant(0)
  loop_result = control_flow_ops.while_loop(
      cond,
      body,
      [counter] + flat_initial_values,
      name="",
      parallel_iterations=1,
      back_prop=False,
      swap_memory=False,
      return_same_structure=True)
  del self._outer_control_flow_context

  ctx.run_op = control_flow_ops.group(loop_result)

  # Everything after the counter is a flat list of last-step outputs; pack it
  # back into the dict structure of `ctx.last_step_outputs`.
  outputs_by_name = nest.pack_sequence_as(ctx.last_step_outputs,
                                          loop_result[1:])

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = outputs_by_name[name]
    if reduce_op is not None:
      # A reduce op was recorded: exactly one value is expected; unwrap it.
      assert len(output) == 1
      outputs_by_name[name] = output[0]
    else:
      # No reduce op was recorded: regroup the per-replica values.
      outputs_by_name[name] = values.regroup(output)

  ctx._set_last_step_outputs(outputs_by_name)  # pylint: disable=protected-access
  return ctx
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar, or
      if it is `None` while building inside an XLA context.
  """
  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # Determine context types.
    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
    in_while_loop = (
        control_flow_util.GetContainingWhileContext(ctxt) is not None)
    # Properly cache variable values inside the while_loop.
    # Don't set a caching device when running in a loop, since it is possible
    # that train steps could be wrapped in a tf.while_loop. In that scenario
    # caching prevents forward computations in loop iterations from re-reading
    # the updated weights.
    if not context.executing_eagerly() and not in_while_loop:
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    initial_finished, initial_inputs, initial_state = decoder.initialize()

    # Used as the replacement output for finished entries when
    # `impute_finished` is set (see `body` below).
    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    if is_xla and maximum_iterations is None:
      raise ValueError("maximum_iterations is required for XLA compilation.")
    if maximum_iterations is not None:
      # With a non-positive iteration budget every entry starts out finished.
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)
    initial_time = constant_op.constant(0, dtype=dtypes.int32)

    def _shape(batch_size, from_shape):
      # Element shape for an output TensorArray: unknown when `from_shape` is
      # not a TensorShape or is scalar, otherwise the statically known batch
      # size (possibly None) prepended to `from_shape`.
      if (not isinstance(from_shape, tensor_shape.TensorShape) or
          from_shape.ndims == 0):
        return tensor_shape.TensorShape(None)
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

    # A fixed-size TensorArray is only used when building for XLA with a
    # known `maximum_iterations`.
    dynamic_size = maximum_iterations is None or not is_xla

    def _create_ta(s, d):
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0 if dynamic_size else maximum_iterations,
          dynamic_size=dynamic_size,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      # Keep looping while at least one entry is not finished.
      return math_ops.logical_not(math_ops.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      if decoder.tracks_own_finished:
        # The decoder's own finished flags are used as-is.
        next_finished = decoder_finished
      else:
        # Otherwise an entry that was finished stays finished.
        next_finished = math_ops.logical_or(decoder_finished, finished)
      # Entries that were not finished on entry get their length set to
      # `time + 1`; finished entries keep their recorded length.
      next_sequence_lengths = array_ops.where(
          math_ops.logical_not(finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

    # Indices follow the order of `loop_vars` above.
    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

    # `finalize` is optional: a decoder that raises NotImplementedError keeps
    # the stacked outputs and loop state unchanged.
    try:
      final_outputs, final_state = decoder.finalize(
          final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
      pass

    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  return final_outputs, final_state, final_sequence_lengths
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported. graph_def and export_scope cannot both be specified.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    **kwargs: Optional keyed arguments, including meta_info_def,
      saver_def, collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
  """
  graph = graph or ops.get_default_graph()
  unbound_inputs = []

  if export_scope or clear_devices:
    if graph_def:
      # Filter and rewrite the nodes of the GraphDef supplied by the caller.
      filtered_graph_def = graph_pb2.GraphDef()
      filtered_graph_def.versions.CopyFrom(graph_def.versions)
      for src_node in graph_def.node:
        if not _should_include_node(src_node.name, export_scope):
          continue
        rewritten = _node_def(src_node, export_scope, unbound_inputs,
                              clear_devices=clear_devices)
        filtered_graph_def.node.extend([rewritten])
      graph_def = filtered_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      nodes_by_id = graph._nodes_by_id  # pylint: disable=protected-access
      total_bytes = 0
      for node_id in sorted(nodes_by_id):
        op = nodes_by_id[node_id]
        if not _should_include_node(op.name, export_scope):
          continue
        rewritten = _node_def(op.node_def, export_scope, unbound_inputs,
                              clear_devices=clear_devices)
        graph_def.node.extend([rewritten])
        if op.outputs:
          appended_attrs = graph_def.node[-1].attr
          assert "_output_shapes" not in appended_attrs
          appended_attrs["_output_shapes"].list.shape.extend(
              [output.get_shape().as_proto() for output in op.outputs])
        total_bytes += op.node_def.ByteSize()
        if total_bytes >= (1 << 31) or total_bytes < 0:
          raise ValueError("GraphDef cannot be larger than 2GB.")

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Start from an empty collection, then record every unbound input.
      graph.clear_collection(unbound_inputs_col_name)
      for unbound_name in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, unbound_name)

  # Map scope-stripped variable names to the exported global variables.
  scoped_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                          scope=export_scope)
  var_list = {
      ops.strip_name_scope(v.name, export_scope): v
      for v in scoped_variables
      if _should_include_node(v, export_scope)
  }

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      **kwargs)

  if filename:
    directory, basename = os.path.dirname(filename), os.path.basename(filename)
    training_util.write_graph(
        scoped_meta_graph_def, directory, basename, as_text=as_text)

  return scoped_meta_graph_def, var_list
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph and recreates all the collections.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key != unbound_inputs_col_name:
        continue
      kind = col_def.WhichOneof("kind")
      field = getattr(col_def, kind)
      unbound_names = [compat.as_str(v) for v in field.value]
      # Every recorded unbound input must be supplied through input_map.
      inputs_covered = bool(input_map) and (
          sorted(unbound_names) == sorted(input_map))
      if field.value and not inputs_covered:
        raise ValueError("Graph contains unbound inputs: %s. Must "
                         "provide these inputs through input_map." %
                         ",".join(unbound_names))
      break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  with graph.as_default():
    meta_info_def = meta_graph_def.meta_info_def
    if meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_info_def.stripped_op_list
    else:
      producer_op_list = None

    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for the nodes. This helps
    # to make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or ""),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue

      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        # Collections with a registered proto converter are stored as
        # serialized protos.
        assert kind == "bytes_list"
        proto_type = ops.get_collection_proto_type(key)
        for serialized in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(serialized)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
        continue

      field = getattr(col_def, kind)
      if kind == "node_list":
        for node_name in field.value:
          element = graph.as_graph_element(
              ops.prepend_name_scope(node_name, import_scope))
          graph.add_to_collection(key, element)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for number in field.value:
          graph.add_to_collection(key, int(number))
      else:
        for item in field.value:
          graph.add_to_collection(
              key, ops.prepend_name_scope(item, import_scope))

    # Map scope-stripped names to the variables under `import_scope`.
    imported_variables = graph.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES, scope=import_scope)
    var_list = {
        ops.strip_name_scope(v.name, import_scope): v
        for v in imported_variables
    }

  return var_list
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None):
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer. If not given, the `GraphDef` of
      `graph` (with shapes added) is used.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. If empty or `None`, all
      collection keys of `graph` are used.
    graph: The `Graph` to create `MetaGraphDef` out of. If `None`, the default
      graph is used.
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # Type check. The offending type is interpolated into the message with `%`;
  # passing it as a second exception argument would leave "%s" unformatted.
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if meta_info_def:
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def,
  # unless the supplied meta_info_def already carried one.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()
  for ctype in clist:
    add_collection_def(meta_graph_def, ctype,
                       graph=graph,
                       export_scope=export_scope)
  return meta_graph_def
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Keys that are not strings/bytes are skipped with a warning. Any error raised
  while serializing the collection is logged as a warning and the partially
  written `collection_def[key]` entry is removed.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. If `None`, the default
      graph is used.
    export_scope: Optional `string`. Name scope to remove.

  Raises:
    TypeError: If `graph` is given but is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Interpolate the type with `%`; passing it as a second exception argument
    # would leave "%s" unformatted in the message.
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          assert isinstance(proto, proto_type)
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      # The CollectionDef field is chosen from the first item's kind.
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(collection_list)
  except Exception as e:  # pylint: disable=broad-except
    # Best effort: a collection that cannot be serialized is dropped rather
    # than failing the whole export.
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s", key, str(e))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return
def _call_for_each_replica(distribution, devices, fn, args, kwargs): """Run `fn` in separate threads, once per replica/worker device. Args: distribution: the DistributionStrategy object. devices: the devices to run `fn` on (logical device 0 for each replica). fn: function to run (will be run once per replica, each in its own thread). args: positional arguments for `fn` kwargs: keyword arguments for `fn`. Returns: Merged return value of `fn` across all replicas. Raises: RuntimeError: If fn() calls get_replica_context().merge_call() a different number of times from the available devices. """ # TODO(josh11b): Add this option once we add synchronization to variable # creation. Until then, this is pretty unsafe to use. run_concurrently = False if not context.executing_eagerly(): # Needed for per-thread device, etc. contexts in graph mode. ops.get_default_graph().switch_to_thread_local() coord = coordinator.Coordinator( clean_stop_exception_types=(_RequestedStop, )) shared_variable_store = {} # TODO(isaprykin): Create these threads once instead of during every call. threads = [] for index in range(len(devices)): variable_creator_fn = shared_variable_creator.make_fn( shared_variable_store, index) t = _MirroredReplicaThread(distribution, coord, index, devices, variable_creator_fn, fn, values.select_replica(index, args), values.select_replica(index, kwargs)) threads.append(t) for t in threads: t.start() # When `fn` starts `should_run` event is set on _MirroredReplicaThread # (`MRT`) threads. The execution waits until # `MRT.has_paused` is set, which indicates that either `fn` is # complete or a `get_replica_context().merge_call()` is called. If `fn` is # complete, then `MRT.done` is set to True. Otherwise, arguments # of `get_replica_context().merge_call` from all paused threads are grouped # and the `merge_fn` is performed. Results of the # `get_replica_context().merge_call` are then set to `MRT.merge_result`. 
# Each such `get_replica_context().merge_call` call returns the # `MRT.merge_result` for that thread when `MRT.should_run` event # is reset again. Execution of `fn` resumes. try: with coord.stop_on_exception(): all_done = False while not all_done and not coord.should_stop(): done = [] if run_concurrently: for t in threads: t.should_run.set() for t in threads: t.has_paused.wait() t.has_paused.clear() if coord.should_stop(): return None done.append(t.done) else: for t in threads: t.should_run.set() t.has_paused.wait() t.has_paused.clear() if coord.should_stop(): return None done.append(t.done) if coord.should_stop(): return None all_done = all(done) if not all_done: if any(done): raise RuntimeError( "Some replicas made a different number of " "replica_context().merge_call() calls.") # get_replica_context().merge_call() case merge_args = values.regroup( tuple(t.merge_args for t in threads)) merge_kwargs = values.regroup( tuple(t.merge_kwargs for t in threads)) # We capture the name_scope of the MRT when we call merge_fn # to ensure that if we have opened a name scope in the MRT, # it will be respected when executing the merge function. We only # capture the name_scope from the first MRT and assume it is # the same for all other MRTs. mtt_captured_name_scope = threads[0].captured_name_scope mtt_captured_var_scope = threads[0].captured_var_scope # Capture and merge the control dependencies from all the threads. mtt_captured_control_deps = set() for t in threads: mtt_captured_control_deps.update( t.captured_control_deps) with ops.name_scope(mtt_captured_name_scope),\ ops.control_dependencies(mtt_captured_control_deps), \ variable_scope.variable_scope(mtt_captured_var_scope): merge_result = threads[0].merge_fn( distribution, *merge_args, **merge_kwargs) for r, t in enumerate(threads): t.merge_result = values.select_replica(r, merge_result) finally: for t in threads: t.should_run.set() coord.join(threads) return values.regroup(tuple(t.main_result for t in threads))
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
  """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).

  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.
    An empty list is returned when `inputs` is an empty list (no replicas).

  Raises:
    TypeError: If `inputs` is not a list of lists/tuples, or if `computation`
      cannot be called with the supplied inputs.
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If replicas do not have matching input types.
    ValueError: If the values returned by `computation` are not Tensors
      followed by Operations, or cannot be converted to Tensors.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
    }
    # TODO(phawkins): remove this case after the forward compatibility window
    # expires on 2018-10-5.
    if api_compat.forward_compatible(2018, 10, 5):
      metadata_kwargs["num_cores_per_replica"] = (
          device_assignment.num_cores_per_replica)
    else:
      metadata_kwargs["computation_shape"] = [
          device_assignment.num_cores_per_replica
      ]

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in inputs[i]]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (input_arity, str(
              [i.name
               for i in inputs[0]]), infeed_queue.number_of_tuple_elements,
                                             arg_error))

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  computation_inputs = []
  for i in range(0, input_arity):
    replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
    computation_inputs.append(
        tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

  cluster_name = graph.unique_name("cluster")
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
      computation_inputs = [
          array_ops.identity(x, name="replicated_input_{}".format(i))
          for i, x in enumerate(computation_inputs)
      ]

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        for t in infeed_queue.generate_dequeue_op():
          computation_inputs.append(t)

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      # Partitioned variables is not supported (b/112311320).
      def custom_getter(getter, name, *args, **kwargs):
        """Variables on TPU have a few restrictions."""
        # `.get` tolerates callers that do not pass `partitioner` at all.
        partitioner = kwargs.get("partitioner", None)
        if partitioner is not None:
          kwargs["partitioner"] = None
          logging.warning(
              "Partitioned variables are not supported on TPU. Got "
              "`partitioner` that is {} for variable {}. "
              "Setting `partitioner` to `None`."
              .format(partitioner, name))
        return getter(name, *args, **kwargs)

      vscope = variable_scope.get_variable_scope()

      saved_use_resource = vscope.use_resource
      saved_custom_getter = vscope.custom_getter

      vscope.set_use_resource(True)
      vscope.set_custom_getter(custom_getter)

      # Restore the caller's variable scope settings even when `computation`
      # raises, so a failed build does not leave the scope modified.
      try:
        outputs = computation(*computation_inputs)
      finally:
        vscope.set_use_resource(saved_use_resource)
        vscope.set_custom_getter(saved_custom_getter)

    # If the computation returns `None`, make it an empty tuple.
    if outputs is None:
      outputs = tuple()
    # If the computation only returned one value, makes it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    # Append `no_op` here so that fetching any return value of this function
    # will trigger TPUExecute node.
    outputs += (control_flow_ops.no_op(),)
    try:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          "convertible to Tensors. Got '%s'" % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU functions must return zero-or more Tensor values followed by "
          "zero or more Operations.")
    output_arity = len(output_tensors)

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else core(0)):
        new_output_tensors.append(array_ops.identity(t))
    output_tensors = new_output_tensors
    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                           name="output{}".format(i))
             for i in xrange(output_arity)]

  with ops.control_dependencies([metadata]):
    if use_tpu:
      compile_status = tpu_ops.tpu_compilation_result()
      op = compile_status.op
      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
    else:
      compile_status = control_flow_ops.no_op(name="compilation_status")

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # Returns a list of NoOps dependent on the replication Op, indexed by
      # [replica_num].
      return [
          compile_status, [
              control_flow_ops.no_op(name="shard_%d" % i)
              for i in range(num_replicas)
          ]
      ]
    else:
      # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
      return [
          compile_status, [[
              array_ops.identity(
                  outputs[out][replica],
                  name="output_%d_shard_%d" % (out, replica))
              for out in xrange(output_arity)
          ]
                           for replica in xrange(num_replicas)]
      ]
def begin(self):
  """Looks up the k-means loss tensor in the default graph and caches it.

  The tensor name is `KMeansClustering.LOSS_OP_NAME` with the `:0` output
  suffix appended; the result is stored on `self._loss_tensor`.
  """
  loss_tensor_name = KMeansClustering.LOSS_OP_NAME + ':0'
  default_graph = ops.get_default_graph()
  self._loss_tensor = default_graph.get_tensor_by_name(loss_tensor_name)
  assert self._loss_tensor is not None
def _default_layout(self, layout: layout_lib.Layout):
  """Applies `layout` as the default output layout for ops in the scope.

  Note: This is an internal helper method, not a user-facing API.

  Useful for requesting a specific layout for ops which would otherwise have
  no inferred layout, e.g. tf.zeros.

  Caveats:

  - Only the first output of an op is affected. Ops with multiple outputs are
    not supported yet.

  - Every op in the scope is tagged with the same layout. This might not be
    valid when ranks differ. The current suggestion is: wrap only the raw op
    whenever possible.

  Args:
    layout: A Layout for the outputs of all operations in this scope.

  Yields:
    Nothing.
  """
  saved_layout = None
  first_new_op_index = None
  traced_graph = None

  self._register_mesh(layout.mesh)
  try:
    saved_layout = self._current_output_layout
    self._current_output_layout = layout.to_string().encode("utf-8")
    _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
        self._device_info, self._current_output_layout)
    if not context.executing_eagerly():
      # Custom devices currently don't affect graph building, so we need a
      # separate way to indicate layouts.
      #
      # TODO(allenl): Remove this case once the DTensor device is active
      # during tracing.
      traced_graph = ops.get_default_graph()
      first_new_op_index = len(traced_graph.get_operations())
      yield
    else:
      with ops.device(self.name):
        yield
  finally:
    if traced_graph is not None:
      # Tag every operation that was added to the graph under this scope.
      ops_added_in_scope = traced_graph.get_operations()[first_new_op_index:]
      for new_op in ops_added_in_scope:
        # The layout and mesh are recorded directly on the op as attributes.
        layout_attr = attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(
                s=[self._current_output_layout]))
        new_op._set_attr("_layout", layout_attr)  # pylint: disable=protected-access
        mesh_attr = attr_value_pb2.AttrValue(
            s=layout.mesh.to_string().encode("utf-8"))
        new_op._set_attr("_mesh", mesh_attr)  # pylint: disable=protected-access

    # Restore whatever default layout was active before entering the scope.
    self._current_output_layout = saved_layout
    if self._current_output_layout is None:
      _pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
    else:
      _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
          self._device_info, self._current_output_layout.decode("utf-8"))
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
  """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen once the quantization step is greater than or equal
     to freeze_batch_norm_delay. If freeze_batch_norm_delay is None, batch
     norm is never frozen.

     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jump across mini-batches
     b) Changing the values of the corrections allows for one to switch between
     using batch statistics to using moving mean and average, without requiring
     changes to batch_norm

     As a side effect, the consumers of the batch norm decay tensors
     (`match.bn_decay_mean_tensor` and `match.bn_decay_var_tensor`) are
     rerouted to a conditional that yields 0.0 once batch norm is frozen.

  Args:
    context: The scope under which we look for batch norm params. Used as the
      prefix of the name scope for the created ops; a falsy value means no
      prefix.
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance, or None to never
      freeze.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context
  with g.name_scope(prefix + 'batch_norm_correction'):
    # rsqrt(var + eps) == 1 / sigma, for the moving and the batch statistics.
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    # (1 / sigma_mv) / (1 / sigma_b) == sigma_b / sigma_mv.
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0

    # Snapshot the consumers before building the cond, so that only the
    # pre-existing consumers are rerouted to the cond's output.
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    common.RerouteTensor(
        bn_decay_mean_out,
        match.bn_decay_mean_tensor,
        can_modify=bn_decay_mean_consumers)

    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
    bn_decay_var_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_var_tensor,
        name='freeze_moving_var')
    common.RerouteTensor(
        bn_decay_var_out,
        match.bn_decay_var_tensor,
        can_modify=bn_decay_var_consumers)

    # Frozen: recip becomes 1 and the offset is applied. Not frozen: recip is
    # 1/correction_scale and the offset is 0.
    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
def __exit__(self, unused_type, unused_value, unused_traceback):
  """Adds control dependencies to the ops created since the scope was entered.

  Does nothing when executing eagerly. Otherwise it restores the graph's
  `_add_control_dependencies` flag, walks the operations added after index
  `self._n_operations`, chains ops that use the same resource tensor with
  control inputs, and finally makes the ops of the returned tensors depend on
  every op recorded in `self.ops_which_must_run`.

  Args:
    unused_type: Unused.
    unused_value: Unused.
    unused_traceback: Unused.

  Raises:
    RuntimeError: If the default graph is no longer `self._graph`.
  """
  if context.executing_eagerly():
    return

  if self._graph is not ops.get_default_graph():
    raise RuntimeError(
        "Graph changed while trying to add control dependencies.")

  # pylint: disable=protected-access
  if hasattr(self._graph, "outer_graph"):
    outer_val = self._graph.outer_graph._add_control_dependencies
    self._graph._add_control_dependencies = outer_val
  else:
    self._graph._add_control_dependencies = False
  # pylint: enable=protected-access

  # map from resource tensor to the last op which used it
  last_op_using_resource_tensor = {}
  # set of conditional and loop exits
  ops_which_must_run = set()
  # merge which must depend on ops which use this resource
  merge_for_resource = {}

  new_operations = self._graph.get_operations()[self._n_operations:]

  # Ensures that uses of resource tensors get serialized properly and all
  # execute. This is done by keeping a map from resource tensor to the last op
  # in graph-construction order which used it (last_op_using_resource_tensor).
  #
  # Conditionals are written in TensorFlow such that every external tensor
  # accessed in the conditional goes through a switch op and every return
  # tensor (it's guaranteed that there will be at least one) goes through a
  # merge op.
  #
  # To handle conditionals, switches are handled in a special way (see
  # comments for _process_switch). Merge nodes created by TF's conditional
  # logic (as opposed to by _process_switch) are forced to run and also get a
  # control dependency added to them to ensure all stateful ops inside their
  # control flow context run.
  #
  # We also ensure that if an op is using a resource output by a switch node
  # (that is, a resource tensor for which there's a value in
  # merge_for_resource) this op will run before the merge for that resource.
  #
  # We try to add control inputs to nodes respecting their control flow
  # contexts to avoid dead nodes propagating everywhere and leading to
  # "retval[0] doesn't have value" errors. If a node gets a control dependency
  # on a dead node (i.e. a node from an untaken control flow branch) that node
  # will be marked as dead unless it's a merge node.
  #
  # TODO(apassos): serialize non-resource-taking stateful ops as well, and
  # test that it works. Support while loops. Support init_scope escaping from
  # this.
  for op in new_operations:
    # TODO(apassos) make this code safely support while loops.
    if control_flow_util.IsInWhileLoop(op):
      continue
    control_inputs = set()
    # Ensure stateful ops run
    if op_def_registry.get(op.type) is None or op_is_stateful(op):
      ops_which_must_run.add(op)
    # Ignore switches (they're handled separately)
    if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
      continue
    # Make merges trigger all other computation which must run
    if op.type == "Merge":
      for o in ops_which_must_run:
        op._add_control_input(o)  # pylint: disable=protected-access
        # The merge now stands in as the last user of any tracked resource
        # that `o` consumed.
        for inp in o.inputs:
          input_id = ops.tensor_id(inp)
          if input_id in last_op_using_resource_tensor:
            last_op_using_resource_tensor[input_id] = op
      ops_which_must_run = set([op])
      continue

    resource_inputs = set()
    # Check for any resource inputs. If we find any, we update control_inputs
    # and last_op_using_resource_tensor.
    for inp in _get_resource_inputs(op):
      input_id = ops.tensor_id(inp)

      # If the op receives the same resource tensor twice as an input, we skip
      # to avoid the op getting a control dependency on itself.
      if input_id in resource_inputs:
        continue

      resource_inputs.add(input_id)
      # Deal with switches, finally.
      if inp.op.type == "Switch":
        self._process_switch(inp.op, ops_which_must_run,
                             last_op_using_resource_tensor,
                             merge_for_resource)
      is_building_function = op.graph.building_function
      # Ensure uses of resources are serialized
      if input_id in last_op_using_resource_tensor:
        if is_building_function or (
            last_op_using_resource_tensor[input_id]
            ._control_flow_context  # pylint: disable=protected-access
            is op._control_flow_context):  # pylint: disable=protected-access
          control_inputs.add(last_op_using_resource_tensor[input_id])
      # Ensure merges happen after the closing of a cond block
      if input_id in merge_for_resource:
        merge_for_resource[input_id]._add_control_input(op)  # pylint: disable=protected-access
      last_op_using_resource_tensor[input_id] = op
    # Stateful ops without resource inputs and outside any control flow
    # context are chained together under the `None` key.
    if (op_is_stateful(op) and not resource_inputs and
        op._control_flow_context is None):  # pylint: disable=protected-access
      if None in last_op_using_resource_tensor:
        op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
      last_op_using_resource_tensor[None] = op
    # NOTE(review): `is_building_function` is only assigned inside the
    # resource-input loop above. The filter below evaluates it only when
    # `control_inputs` is non-empty, and `control_inputs` is only populated
    # after that assignment, so it is always bound when actually read.
    control_inputs = [
        c for c in control_inputs
        if is_building_function or
        (c._control_flow_context is op._control_flow_context)
    ]  # pylint: disable=protected-access
    op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

  # Ensure all ops which must run do run
  self.ops_which_must_run.update(ops_which_must_run)
  for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
    if self.ops_which_must_run:
      r.op._add_control_inputs(  # pylint: disable=protected-access
          [
              o for o in self.ops_which_must_run
              if r.graph.building_function or
              (o._control_flow_context is r.op._control_flow_context)  # pylint: disable=protected-access
          ])
def testGradientFromInsideNestedDefun(self):
  """Checks gradients of a nested cond_v2 computed inside nested defuns."""

  def build_graph():
    pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
    pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
    x = constant_op.constant(1.0, name="x")
    y = constant_op.constant(2.0, name="y")

    def true_fn():
      return 2.0

    def false_fn():

      def inner_true_fn():
        return x * y * 2.0

      def inner_false_fn():
        return x * 5.0

      return cond_v2.cond_v2(
          pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")

    cond_outer = cond_v2.cond_v2(
        pred_outer, true_fn, false_fn, name="outer_cond")

    # Compute grads inside a Defun.
    @function.defun
    def nesting_fn():

      @function.defun
      def inner_nesting_fn():
        return gradients_impl.gradients(cond_outer, [x, y])

      return inner_nesting_fn()

    return nesting_fn(), pred_outer, pred_inner

  # (pred_outer value, pred_inner value, expected [d/dx, d/dy]), checked in
  # this order.
  cases = (
      (True, True, [0., 0.]),
      (True, False, [0., 0.]),
      (False, True, [4., 2.]),
      (False, False, [5., 0.]),
  )

  with ops.Graph().as_default():
    grads, pred_outer, pred_inner = build_graph()
    with self.session(graph=ops.get_default_graph()) as sess:
      for outer_value, inner_value, expected_grads in cases:
        feed_dict = {pred_outer: outer_value, pred_inner: inner_value}
        self.assertSequenceEqual(sess.run(grads, feed_dict), expected_grads)
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
  """Constructs a _CopyToDeviceDataset.

  Builds three function pairs (init / next / finalize). Each pair consists of
  a local function plus a `remote_call` wrapper targeting the source device;
  the wrappers and their captured inputs are stored on `self`, and the
  wrappers are added to the default graph.

  Args:
    input_dataset: `Dataset` to be copied
    target_device: The name of the device to which elements would be copied.
    source_device: Device where input_dataset would be placed.
  """
  super(_CopyToDeviceDataset, self).__init__(input_dataset)
  self._input_dataset = input_dataset
  self._target_device = target_device
  spec = framework_device.DeviceSpec().from_string(self._target_device)
  self._is_gpu_target = (spec.device_type == "GPU")
  # The source device is kept both as the raw string (used with `ops.device`
  # below) and as a tensor (used as the `remote_call` target).
  self._source_device_string = source_device
  self._source_device = ops.convert_to_tensor(source_device)

  self._flat_output_shapes = nest.flatten(
      sparse.as_dense_shapes(self._input_dataset.output_shapes,
                             self._input_dataset.output_classes))
  self._flat_output_types = nest.flatten(
      sparse.as_dense_types(self._input_dataset.output_types,
                            self._input_dataset.output_classes))

  @function.Defun()
  def _init_func():
    """Creates an iterator for the input dataset.

    Returns:
      A `string` tensor that encapsulates the iterator created.
    """
    # pylint: disable=protected-access
    ds_variant = self._input_dataset._as_variant_tensor()
    resource = core_gen_dataset_ops.anonymous_iterator(
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
    # The string handle is only produced after the iterator has been made.
    with ops.control_dependencies(
        [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
      return core_gen_dataset_ops.iterator_to_string_handle(resource)

  @function.Defun()
  def _remote_init_func():
    """Invokes `_init_func` on the source device via `remote_call`."""
    return functional_ops.remote_call(
        target=self._source_device,
        args=_init_func.captured_inputs,
        Tout=[dtypes.string],
        f=_init_func)

  self._init_func = _remote_init_func
  self._init_captured_args = _remote_init_func.captured_inputs

  @function.Defun(dtypes.string)
  def _next_func(string_handle):
    """Calls get_next for created iterator.

    Args:
      string_handle: An iterator string handle created by _init_func
    Returns:
      The elements generated from `input_dataset`
    """
    with ops.device(self._source_device_string):
      iterator = iterator_ops.Iterator.from_string_handle(
          string_handle, self.output_types, self.output_shapes,
          self.output_classes)
    ret = iterator.get_next()
    return nest.flatten(sparse.serialize_sparse_tensors(ret))

  @function.Defun(dtypes.string)
  def _remote_next_func(string_handle):
    """Invokes `_next_func` on the source device via `remote_call`."""
    return functional_ops.remote_call(
        target=self._source_device,
        args=[string_handle] + _next_func.captured_inputs,
        Tout=self._flat_output_types,
        f=_next_func)

  self._next_func = _remote_next_func
  self._next_captured_args = _remote_next_func.captured_inputs

  @function.Defun(dtypes.string)
  def _finalize_func(string_handle):
    """Destroys the iterator resource created.

    Args:
      string_handle: An iterator string handle created by _init_func
    Returns:
      Tensor constant 0
    """
    iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
    # The returned constant depends on the destroy op, so fetching it forces
    # the destroy to run.
    with ops.control_dependencies([
        resource_variable_ops.destroy_resource_op(
            iterator_resource, ignore_lookup_error=True)]):
      return array_ops.constant(0, dtypes.int64)

  @function.Defun(dtypes.string)
  def _remote_finalize_func(string_handle):
    """Invokes `_finalize_func` on the source device via `remote_call`."""
    return functional_ops.remote_call(
        target=self._source_device,
        args=[string_handle] + _finalize_func.captured_inputs,
        Tout=[dtypes.int64],
        f=_finalize_func)

  self._finalize_func = _remote_finalize_func
  self._finalize_captured_args = _remote_finalize_func.captured_inputs

  # Add the remote wrappers to the current default graph.
  g = ops.get_default_graph()
  _remote_init_func.add_to_graph(g)
  _remote_next_func.add_to_graph(g)
  _remote_finalize_func.add_to_graph(g)
def initialize_variables():
  """Assigns every variable in `initializer_map` its lifted initial value.

  Each initializer tensor is lifted into the current default graph and the
  lifted counterpart is assigned to the matching variable.
  """
  for variable, init_tensor in initializer_map.items():
    lifted = lift_to_graph.lift_to_graph(
        [init_tensor], ops.get_default_graph())
    variable.assign(lifted[init_tensor])
def _GetLayerMatch(match_result):
  """Populates a layer match object containing ops/tensors for folding BNs.

  The pattern objects, `graph` and `moving_avg_mul_matcher` used below are not
  defined in this function; they are resolved from the enclosing scope.

  When the batch norm op has `is_training` set, this function also mutates the
  graph: it sets `_gradient_op_type` on the batch norm op and creates ops that
  undo Bessel's correction on the batch variance.

  Args:
    match_result: Matched result from graph matcher

  Returns:
    layer_op: Matching conv/fc op prior to batch norm
    BatchNormMatch: _BatchNormMatch containing all required batch norm
    parameters.

    `(None, None)` is returned instead when the match should be skipped: a
    'MatMul' layer without a matched output reshape op, or an output tensor
    with no consumers.
  """
  moving_mean_tensor = None
  moving_variance_tensor = None
  bn_decay_mean_tensor = None
  bn_decay_var_tensor = None
  batch_to_space_op = None
  layer_op = match_result.get_op(layer_pattern)
  layer_tensor = match_result.get_tensor(layer_pattern)
  # NOTE(review): bn_id_op is assigned here but not read again in this
  # function.
  bn_id_op = match_result.get_op(batch_norm_identity_pattern)
  bn_op = match_result.get_op(batch_norm_pattern)
  if bn_id_op is None:
    bn_id_op = bn_op

  batch_epsilon = bn_op.get_attr('epsilon')

  # In the MatMul case, the output of batch norm is reshaped back into a
  # 2D tensor, so the output_tensor is the output of the Reshape op.
  output_tensor = bn_op.outputs[0]
  if layer_op.type == 'MatMul':
    output_reshape_op = match_result.get_op(
        matmul_bn_output_reshape_pattern)
    # If the matcher didn't match matmul_bn_output_reshape, there will be
    # another match for this 'MatMul' later, so we can skip this one.
    if output_reshape_op is None:
      return None, None
    output_tensor = output_reshape_op.outputs[0]

  # Ensure that the output tensor has consumers, otherwise this is a dangling
  # node and not a match.
  if not output_tensor.consumers():
    return None, None

  batch_to_space_op = match_result.get_op(batch_to_space_pattern)
  input_tensor = match_result.get_tensor(input_pattern)
  weight_tensor = match_result.get_tensor(weight_pattern)
  gamma_tensor = match_result.get_tensor(gamma_pattern)
  beta_tensor = match_result.get_tensor(beta_pattern)
  # FusedBatchNorm in training is different from that in inference. It takes
  # empty 'mean' and empty 'variance', and produces the mean and the variance
  # of the batch. Therefore, when is_training is true, mean_tensor and
  # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
  # respectively; when is_training is false, they point to bn_op's inputs.
  is_training = bn_op.get_attr('is_training')
  if is_training:
    # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
    # batch_variance outputs, so we need to substitute our own custom
    # gradient.
    # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
    # pylint: disable=protected-access
    bn_op._set_attr(
        '_gradient_op_type',
        attr_value_pb2.AttrValue(
            s=compat.as_bytes('FoldFusedBatchNormGrad')))
    # pylint: enable=protected-access

    mean_tensor = bn_op.outputs[1]
    # The batch variance used during forward and backward prop is biased,
    # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average
    # calculation, the variance is corrected by the term N/(N-1) (Bessel's
    # correction). The variance tensor read from FusedBatchNorm has Bessel's
    # correction applied, so we undo it here.
    # The new ops are created under the same name scope as bn_op.
    scope, sep, _ = bn_op.name.rpartition('/')
    g = ops.get_default_graph()
    with g.as_default(), g.name_scope(scope + sep):
      # n is the number of layer-output elements per element of the mean.
      n = math_ops.cast(
          array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
          dtypes.float32)
      variance_tensor = math_ops.multiply(
          bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
    # TODO(suharshs): Find a way to get rid of this inner match.
    for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
      sub_op = mul_match_result.get_op(moving_average_sub_pattern)
      if sub_op.inputs[1].name == bn_op.outputs[1].name:
        # During training: Batch Mean is bn_op.outputs[1]
        moving_mean_tensor = sub_op.inputs[0]
        bn_decay_mean_tensor = mul_match_result.get_tensor(
            bn_decay_pattern)
      if sub_op.inputs[1].name == bn_op.outputs[2].name:
        # During training: Batch Var is bn_op.outputs[2]
        moving_variance_tensor = sub_op.inputs[0]
        bn_decay_var_tensor = mul_match_result.get_tensor(
            bn_decay_pattern)
  else:
    mean_tensor = match_result.get_tensor(mean_pattern)
    variance_tensor = match_result.get_tensor(variance_pattern)

  return layer_op, _BatchNormMatch(
      layer_op=layer_op,
      bn_op=bn_op,
      output_tensor=output_tensor,
      input_tensor=input_tensor,
      weight_tensor=weight_tensor,
      gamma_tensor=gamma_tensor,
      beta_tensor=beta_tensor,
      mean_tensor=mean_tensor,
      variance_tensor=variance_tensor,
      moving_mean_tensor=moving_mean_tensor,
      moving_variance_tensor=moving_variance_tensor,
      bn_decay_mean_tensor=bn_decay_mean_tensor,
      bn_decay_var_tensor=bn_decay_var_tensor,
      batch_epsilon=batch_epsilon,
      batch_to_space_op=batch_to_space_op)
def set_random_seed(seed):
  """Sets the graph-level random seed.

  Operations that rely on a random seed actually derive it from two seeds:
  the graph-level and operation-level seeds. This sets the graph-level seed.

  Its interactions with operation-level seeds are as follows:

    1. If neither the graph-level nor the operation seed is set:
      A random seed is used for this op.
    2. If the graph-level seed is set, but the operation seed is not:
      The system deterministically picks an operation seed in conjunction
      with the graph-level seed so that it gets a unique random sequence.
    3. If the graph-level seed is not set, but the operation seed is set:
      A default graph-level seed and the specified operation seed are used to
      determine the random sequence.
    4. If both the graph-level and the operation seed are set:
      Both seeds are used in conjunction to determine the random sequence.

  To illustrate the user-visible effects, consider these examples:

  To generate different sequences across sessions, set neither
  graph-level nor op-level seeds:

  ```python
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A3'
    print(sess2.run(a))  # generates 'A4'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To generate the same repeatable sequence for an op across sessions, set the
  seed for the op:

  ```python
  a = tf.random_uniform([1], seed=1)
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequence of values for 'a', but different sequences of values for 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To make the random sequences generated by all ops be repeatable across
  sessions, set a graph-level seed:

  ```python
  tf.set_random_seed(1234)
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequences of 'a' and 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B1'
    print(sess2.run(b))  # generates 'B2'
  ```

  When executing eagerly the seed is stored as the context's global seed;
  otherwise it is stored on the current default graph.

  Args:
    seed: integer.
  """
  if not context.executing_eagerly():
    ops.get_default_graph().seed = seed
    return
  context.set_global_seed(seed)
def _is_old_cond():
  """Returns True if the default graph's control flow context is a CondContext."""
  default_graph = ops.get_default_graph()
  current_ctxt = default_graph._get_control_flow_context()  # pylint: disable=protected-access
  return isinstance(current_ctxt, control_flow_ops.CondContext)
def add_weight(self,
               name,
               shape,
               dtype=None,
               initializer=None,
               regularizer=None,
               trainable=None,
               constraint=None,
               use_resource=None,
               synchronization=vs.VariableSynchronization.AUTO,
               aggregation=vs.VariableAggregation.NONE,
               partitioner=None,
               **kwargs):
  """Adds a new variable to the layer, or gets an existing one; returns it.

  Arguments:
    name: variable name.
    shape: variable shape.
    dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
    initializer: initializer instance (callable).
    regularizer: regularizer instance (callable).
    trainable: whether the variable should be part of the layer's
      "trainable_variables" (e.g. variables, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    constraint: constraint instance (callable).
    use_resource: Whether to use `ResourceVariable`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: (optional) partitioner instance (callable).  If
      provided, when the requested variable is created it will be split
      into multiple partitions according to `partitioner`.  In this case,
      an instance of `PartitionedVariable` is returned.  Available
      partitioners include `tf.compat.v1.fixed_size_partitioner` and
      `tf.compat.v1.variable_axis_size_partitioner`.  For more details, see
      the documentation of `tf.compat.v1.get_variable` and the  "Variable
      Partitioners and Sharding" section of the API guide.
    **kwargs: Additional keyword arguments. The only accepted key is
      `experimental_autocast`.

  Returns:
    The created variable.  Usually either a `Variable` or `ResourceVariable`
    instance.  If `partitioner` is not `None`, a `PartitionedVariable`
    instance is returned.

  Raises:
    RuntimeError: If called with partitioned variable regularization and
      eager execution is enabled.
    ValueError: When trainable has been set to True with synchronization
      set as `ON_READ`.
    TypeError: If `kwargs` contains any key other than
      `experimental_autocast`.
  """
  for kwarg in kwargs:
    if kwarg != 'experimental_autocast':
      raise TypeError('Unknown keyword argument:', kwarg)

  if self._keras_style:
    # NOTE(review): on this path the caller's `synchronization` and
    # `aggregation` arguments are not forwarded; AUTO / NONE are always
    # passed. Confirm this is intended.
    return super(Layer, self).add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        trainable=trainable and self.trainable,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=vs.VariableSynchronization.AUTO,
        aggregation=vs.VariableAggregation.NONE,
        partitioner=partitioner,
        **kwargs)

  if synchronization == vs.VariableSynchronization.ON_READ:
    if trainable:
      raise ValueError(
          'Synchronization value can be set to '
          'VariableSynchronization.ON_READ only for non-trainable variables. '
          'You have specified trainable=True and '
          'synchronization=VariableSynchronization.ON_READ.')
    else:
      # Set trainable to be false when variable is to be synced on read.
      trainable = False
  elif trainable is None:
    trainable = True

  def _should_add_regularizer(variable, existing_variable_set):
    """Returns True if `variable` (or any of its partitions) is new."""
    if isinstance(variable, tf_variables.PartitionedVariable):
      for var in variable:
        if var in existing_variable_set:
          return False
      return True
    else:
      return variable not in existing_variable_set

  init_graph = None
  if not context.executing_eagerly():
    default_graph = ops.get_default_graph()
    if default_graph.building_function:
      with ops.init_scope():
        # Retrieve the variables from the graph into which variables
        # will be lifted; if initialization ops will be lifted into
        # the eager context, then there is nothing to retrieve, since variable
        # collections are not supported when eager execution is enabled.
        if not context.executing_eagerly():
          init_graph = ops.get_default_graph()
          existing_variables = set(tf_variables.global_variables())
    else:
      # Initialization ops will not be lifted out of the default graph.
      init_graph = default_graph
      existing_variables = set(tf_variables.global_variables())

  if dtype is None:
    dtype = self.dtype or dtypes.float32

  self._set_scope(None)
  reuse = self.built or self._reuse
  # Remember how many trainable weights existed before this call so that any
  # weights appended by it can be identified below.
  prev_len_trainable = len(self._trainable_weights)
  with vs.variable_scope(
      self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
    self._current_scope = scope
    with ops.name_scope(self._name_scope(), skip_on_eager=False):
      use_resource = (use_resource or
                      self._use_resource_variables or
                      scope.use_resource)
      if initializer is None:
        initializer = scope.initializer
      variable = super(Layer, self).add_weight(
          name,
          shape,
          dtype=dtypes.as_dtype(dtype),
          initializer=initializer,
          trainable=trainable and self.trainable,
          constraint=constraint,
          partitioner=partitioner,
          use_resource=use_resource,
          synchronization=synchronization,
          aggregation=aggregation,
          getter=vs.get_variable,
          **kwargs)

      if regularizer:
        # `existing_variables` is only bound on the graph paths above; the
        # first operand short-circuits before it is read otherwise.
        if (ops.executing_eagerly_outside_functions()
            or _should_add_regularizer(variable, existing_variables)):
          self._handle_weight_regularization(name, variable, regularizer)

      if init_graph is not None:
        # Handle edge case where a custom getter has overridden `trainable`.
        # There is one known occurrence of this, in unit test
        # testBasicRNNCellNotTrainable in
        # contrib.rnn.python.kernel_tests.core_rnn_cell_test
        with init_graph.as_default():
          trainable_variables = tf_variables.trainable_variables()
        if (trainable and self.trainable and
            variable not in trainable_variables):
          # A custom getter / variable scope overrode the trainable flag.
          extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
          self._trainable_weights = self._trainable_weights[
              :prev_len_trainable]
          self._non_trainable_weights += extra_trainable_vars
  return variable
def __init__(
        self,  # pylint: disable=super-init-not-called
        initial_value=None,
        trainable=None,
        caching_device=None,
        name=None,
        dtype=None,
        constraint=None,
        add_initializers_to=None,
        lifted_initializer_graph=None,
        **unused_kwargs):
    """Creates a variable.

    When called outside of a function definition this simply delegates to
    `ResourceVariable.__init__`. Inside a function definition the handle is
    created in the `init_scope()` context, and initialization is either lifted
    out of the function graph (legacy graph mode) or performed conditionally
    inside the function via a `cond` on `var_is_initialized_op`.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must
        have a shape specified unless `validate_shape` is set to False. Can
        also be a callable with no argument that returns the initial value
        when called. (Note that initializer functions from init_ops.py must
        first be bound to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable. `None` is treated as `True`.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional
        statements. Only forwarded on the outside-of-function path.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object
        convertible to a Tensor).
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
        Not referenced anywhere in this method body.
      **unused_kwargs: Accepted and ignored.

    Raises:
      ValueError: If `initial_value` is `None` (inside a function definition),
        or if `constraint` is given but is not callable.
    """
    if not ops.inside_function():
        # If we've been init_scope()d out of the function definition nothing to do
        # here; we can't really do the capturing or conditional logic.
        resource_variable_ops.ResourceVariable.__init__(
            self, initial_value=initial_value, trainable=trainable,
            caching_device=caching_device, name=name, dtype=dtype,
            constraint=constraint)
        return
    # Graph-vs-eager is decided by the context the function is being traced
    # from (the init scope), not by the function graph itself.
    with ops.init_scope():
        self._in_graph_mode = not context.executing_eagerly()
    if initial_value is None:
        raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
        raise ValueError("The `constraint` argument must be a callable.")

    # Unwrap checkpoint-provided initial values, remembering the restore UID.
    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
        self._maybe_initialize_checkpointable()
        self._update_uid = initial_value.checkpoint_position.restore_uid
        initial_value = initial_value.wrapped_value

    if trainable is None:
        trainable = True
    self._trainable = trainable
    self._save_slice_info = None
    self._initial_value = None
    self._initializer_op = None
    self._is_initialized_op = None
    self._graph_element = None
    self._cached_value = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph. Guaranteed to be the same as the eager graph_key.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.name_scope(name, "Variable",
                        [] if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        with ops.init_scope():
            handle_name = ops._name_from_scope_name(name)
            unique_id = "%s_%d" % (handle_name, ops.uid())
            shared_name = context.shared_name(unique_id)
        # The initial value tensor is built in the current (function) graph,
        # with any enclosing device setting cleared.
        with ops.name_scope("Initializer"), ops.device(None):
            initial_value = ops.convert_to_tensor(
                initial_value() if init_from_fn else initial_value,
                name="initial_value", dtype=dtype)
        # The handle itself is created in the init scope.
        with ops.init_scope():
            self._handle = resource_variable_ops.eager_safe_variable_handle(
                initial_value=initial_value,
                shared_name=shared_name,
                name=name,
                graph_mode=self._in_graph_mode)
        self._shape = initial_value.shape
        self._unique_id = unique_id
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint
        assert initial_value is not None
        if self._in_graph_mode:
            # Legacy graph mode: copy the initializer out of the function graph
            # into the outer graph and build the usual initializer ops there.
            with ops.init_scope():
                outer_graph = ops.get_default_graph()
            func_graph = ops.get_default_graph()
            # Function inputs and captures may not be lifted; they are passed
            # as disallowed placeholders.
            function_placeholders = (
                func_graph.inputs + func_graph.internal_captures)
            placeholder_ops = set(
                [tensor.op for tensor in function_placeholders])
            lifted_initializer = lift_to_graph.lift_to_graph(
                [initial_value], outer_graph,
                disallowed_placeholders=placeholder_ops)[initial_value]
            with ops.init_scope():
                self._initial_value = lifted_initializer
                with ops.name_scope("IsInitialized"):
                    self._is_initialized_op = (
                        resource_variable_ops.var_is_initialized_op(
                            self._handle))
                # Always true here given the assert above.
                if initial_value is not None:
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        self._initializer_op = resource_variable_ops.assign_variable_op(
                            self._handle, lifted_initializer, name=n)
                with ops.name_scope("Read"), ops.colocate_with(
                        self._handle):
                    # Manually assign reads to the handle's device to avoid log
                    # messages.
                    with ops.device(self._handle.device):
                        value = self._read_variable_op()
                    self._graph_element = value
                ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
        else:
            # Eager outer context: initialize inside the function, guarded by
            # an is-initialized check so the assignment happens only once.
            if add_initializers_to is not None:
                add_initializers_to[self] = initial_value

            def assign_fn():
                with ops.name_scope("Assign") as n, ops.colocate_with(
                        self._handle):
                    resource_variable_ops.assign_variable_op(self._handle,
                                                             initial_value,
                                                             name=n)
                # Returning values to keep tf.cond happy.
                return ops.convert_to_tensor(1)

            def not_assign_fn():
                return ops.convert_to_tensor(0)

            # Note: this cond is always guaranteed to run because we're inside a
            # defun which will insert automatic control dependencies.
            control_flow_ops.cond(
                resource_variable_ops.var_is_initialized_op(self._handle),
                not_assign_fn, assign_fn)

    # After the handle has been created, set up a way to clean it up when
    # executing eagerly. We'll hold the only reference to the deleter, so that
    # when this object is garbage collected the deleter will be too. This
    # means ResourceVariables can be part of reference cycles without those
    # cycles being uncollectable.
    if not self._in_graph_mode:
        self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._handle, handle_device=self._handle.device)
    self._cached_shape_as_list = None
def __enter__(self):
    """Installs `self` as the default graph's control flow context.

    The default graph is stored in `self._g` and the control flow context it
    had before this call is stored in `self._old`. Nothing is returned, so
    `with ... as x` binds `None`.
    """
    graph = ops.get_default_graph()
    self._g = graph
    # Remember the previous context before overwriting it.
    self._old = graph._get_control_flow_context()  # pylint: disable=protected-access
    graph._set_control_flow_context(self)  # pylint: disable=protected-access
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
    """Creates a constant tensor.

    The tensor is filled with values of type `dtype`, taken from `value` and,
    optionally, laid out according to `shape`.

    `value` may be a single constant or a list of values of type `dtype`. When
    it is a list, its length must not exceed the number of elements implied by
    `shape` (if given). If the list is shorter than that, its last element is
    used to fill the remaining entries.

    `shape` is optional; when omitted the shape of `value` is used. When
    `dtype` is omitted the type is inferred from `value`.

    For example:

    ```python
    # Constant 1-D Tensor populated with value list.
    tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

    # Constant 2-D tensor populated with scalar value -1.
    tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                                 [-1. -1. -1.]]
    ```

    Args:
      value: A constant value (or list) of output type `dtype`.
      dtype: The type of the elements of the resulting tensor.
      shape: Optional dimensions of resulting tensor.
      name: Optional name for the tensor.
      verify_shape: Boolean that enables verification of a shape of values.

    Returns:
      A Constant Tensor.

    Raises:
      TypeError: if shape is incorrectly specified or unsupported.
    """
    ctx = context.context()

    if ctx.in_graph_mode():
        # Graph mode: emit a `Const` op carrying the value as an attr.
        graph = ops.get_default_graph()
        value_attr = attr_value_pb2.AttrValue()
        value_attr.tensor.CopyFrom(
            tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape, verify_shape=verify_shape))
        dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
        const_op = graph.create_op(
            "Const", [], [dtype_attr.type],
            attrs={"value": value_attr, "dtype": dtype_attr},
            name=name)
        return const_op.outputs[0]

    # Eager mode: convert first, then reconcile with the requested shape.
    eager_tensor = convert_to_eager_tensor(value, ctx, dtype)
    if shape is None:
        return eager_tensor

    shape = tensor_shape.as_shape(shape)
    if shape == eager_tensor.shape:
        return eager_tensor
    if verify_shape:
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), tuple(eager_tensor.shape)))

    value_count = eager_tensor.shape.num_elements()
    # TODO(josh11b): Implement shape -> eager tensor conversion.
    if value_count == shape.num_elements():
        # Same number of elements, different layout: reshape.
        return _eager_reshape(eager_tensor, shape.as_list(), ctx)

    if value_count == 1:
        # A single element is broadcast across the requested shape.
        if eager_tensor.dtype == dtypes.bool:
            # We don't have a Fill kernel for bool dtype on GPU. So we first
            # run Fill on CPU and then copy to GPU if needed.
            with ops.device("/device:CPU:0"):
                filled = _eager_fill(shape.as_list(), eager_tensor.cpu(), ctx)
            return _eager_identity(filled, ctx)
        return _eager_fill(shape.as_list(), eager_tensor, ctx)

    raise TypeError(
        "Eager execution of tf.constant with unsupported shape "
        "(value has %d elements, shape is %s with %d elements)." %
        (value_count, shape, shape.num_elements()))
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Builds three concrete functions (init / next / finalize), each of which
    wraps an inner function in a `functional_ops.remote_call` targeting
    `source_device`, registers them with the default graph, and feeds them to
    a `generator_dataset` op created under `target_device`.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    # True when the parsed target device type is exactly "GPU".
    self._is_gpu_target = (spec.device_type == "GPU")
    # The source device is kept both as the raw string (for `ops.device`) and
    # as a tensor (for `remote_call`'s `target`).
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrapped here and unwrapped inside `_init_func`, which closes over it.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
        """Creates an iterator for the input dataset.

        Returns:
          A `string` tensor that encapsulates the iterator created.
        """
        ds_variant = gen_dataset_ops.unwrap_dataset_variant(
            wrap_ds_variant)
        resource = gen_dataset_ops.anonymous_iterator(
            **dataset_ops.flat_structure(self._input_dataset))
        # The string handle is only produced after make_iterator has run.
        with ops.control_dependencies(
                [gen_dataset_ops.make_iterator(ds_variant, resource)]):
            return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
        # Invoke the concrete init function via remote_call, forwarding its
        # captured inputs as arguments.
        return functional_ops.remote_call(
            target=self._source_device,
            args=init_func_concrete.captured_inputs,
            Tout=[dtypes.string],
            f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
        """Calls get_next for created iterator.

        Args:
          string_handle: An iterator string handle created by _init_func
        Returns:
          The elements generated from `input_dataset`
        """
        with ops.device(self._source_device_string):
            iterator = iterator_ops.Iterator.from_string_handle(
                string_handle,
                dataset_ops.get_legacy_output_types(self),
                dataset_ops.get_legacy_output_shapes(self),
                dataset_ops.get_legacy_output_classes(self))
        return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
        # The string handle is passed first, followed by the inner function's
        # captured inputs.
        return functional_ops.remote_call(
            target=self._source_device,
            args=[string_handle] + next_func_concrete.captured_inputs,
            Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
            f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
        """Destroys the iterator resource created.

        Args:
          string_handle: An iterator string handle created by _init_func
        Returns:
          Tensor constant 0
        """
        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle,
            **dataset_ops.flat_structure(self._input_dataset))
        # Lookup errors are ignored when destroying the resource.
        with ops.control_dependencies([
                resource_variable_ops.destroy_resource_op(
                    iterator_resource, ignore_lookup_error=True)
        ]):
            return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal(
    )  # pylint: disable=protected-access

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
        return functional_ops.remote_call(
            target=self._source_device,
            args=[string_handle] + finalize_func_concrete.captured_inputs,
            Tout=[dtypes.int64],
            f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three remote wrappers with the current default graph
    # before creating the op that references them.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-scope

    # The generator dataset op itself is created under the target device.
    with ops.device(self._target_device):
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **dataset_ops.flat_structure(self._input_dataset))
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
def _defun_internal(name, func, args, kwds):
    """Traces `func` in a temporary graph and returns a `GraphModeFunction`.

    Args:
      name: Name passed to `_inference_name` to derive the function name.
      func: The Python callable to trace.
      args: Positional arguments; converted to function inputs by
        `_get_defun_inputs`.
      kwds: Keyword arguments forwarded to `func` unchanged.

    Returns:
      A `GraphModeFunction` built from the traced operations, inputs
      (explicit plus captured), outputs, output shapes and the variables
      watched by the tape while `func` ran.
    """
    outer_container_prefix = ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
    with context.graph_mode():
        captures = {}
        tmp_graph = CapturingGraph(captures)
        # Inherit the container prefix, since this is used for error checking
        # when isolating eager execution (the container prefix at creation
        # must match the container prefix when used, and variables accessed in
        # the defun will be used in the outside context).
        tmp_graph._container_prefix = outer_container_prefix  # pylint: disable=protected-access
        # Copy the graph collections to ensure summaries and other things
        # work. This lets the function access (but not mutate) collections of
        # the containing graph, such as the global step and the summary writer
        # collections.
        curr_graph = ops.get_default_graph()
        for collection_name in curr_graph.collections:
            collection_items = curr_graph.get_collection(collection_name)
            tmp_graph.get_collection_ref(collection_name)[:] = collection_items
        with tmp_graph.as_default():
            func_inputs = _get_defun_inputs(args)

            with capture_tensors(captures):
                tape.push_new_tape()
                try:
                    func_outputs = func(*func_inputs, **kwds)
                finally:
                    # The tape is popped even when `func` raises.
                    variables = tape.pop_tape().watched_variables()

                # Returning a closed-over tensor as an output does not trigger
                # a call to convert_to_tensor, so we manually capture all such
                # tensors.
                outputs_list = nest.flatten(func_outputs)
                func_def_outputs = [
                    _convert_to_graph_tensor(output)
                    for output in outputs_list
                    if output is not None
                ]

            # Captures are ordered by sorted key so the extra inputs and their
            # placeholders line up.
            capture_ids = sorted(captures)
            if not capture_ids:
                extra_inputs = []
                extra_placeholders = []
            else:
                extra_inputs, extra_placeholders = zip(
                    *[captures[capture_id] for capture_id in capture_ids])
            output_shapes = tuple(
                [output.shape if isinstance(output, ops.Tensor) else None
                 for output in outputs_list])

    flat_inputs = [
        tensor for tensor in nest.flatten(func_inputs)
        if isinstance(tensor, ops.Tensor)
    ]
    all_inputs = flat_inputs + list(extra_placeholders)
    # Ops producing the inputs are excluded from the function body.
    all_ignored_ops = frozenset(tensor.op for tensor in all_inputs)
    fname = _inference_name(name)
    operations = tuple(
        op for op in tmp_graph.get_operations() if op not in all_ignored_ops)
    # Register any other functions defined in the graph
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    if context.in_eager_mode():
        for nested_function in tmp_graph._functions.values():  # pylint: disable=protected-access
            # TODO(ashankar): What about the gradient registry?
            _register(nested_function._c_func)  # pylint: disable=protected-access
    return GraphModeFunction(
        fname, all_inputs, extra_inputs, tmp_graph, operations,
        func_def_outputs, func_outputs, output_shapes, variables)
def compute_weighted_loss(
        losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
        reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
    """Computes the weighted loss.

    The losses are multiplied element-wise by `weights`, reduced according to
    `reduction`, cast back to the dtype of the input `losses`, and added to
    `loss_collection`.

    Args:
      losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `losses`, and must be broadcastable to `losses` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `losses`
        dimension).
      scope: the scope for the operations performed in computing the loss.
      loss_collection: the loss will be added to these collections.
      reduction: Type of reduction to apply to loss.

    Returns:
      Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
      `NONE`, this has the same shape as `losses`; otherwise, it is scalar.

    Raises:
      ValueError: If `weights` is `None` or the shape is not compatible with
        `losses`, or if the number of dimensions (rank) of either `losses` or
        `weights` is missing.

    Note:
      When calculating the gradient of a weighted loss contributions from
      both `losses` and `weights` are considered. If your `weights` depend
      on some model parameters but you do not want this to affect the loss
      gradient, you need to apply `tf.stop_gradient` to `weights` before
      passing them to `compute_weighted_loss`.

    @compatibility(eager)
    The `loss_collection` argument is ignored when executing eagerly. Consider
    holding on to the return value or collecting losses via a `tf.keras.Model`.
    @end_compatibility
    """
    Reduction.validate(reduction)
    with ops.name_scope(scope, "weighted_loss", (losses, weights)):
        # Save the `reduction` argument for loss normalization when
        # distributing to multiple replicas.
        # TODO(josh11b): Associate it with the returned op for more precision.
        ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access

        broadcast_assertion = weights_broadcast_ops.assert_broadcastable(
            weights, losses)
        with ops.control_dependencies((broadcast_assertion,)):
            losses = ops.convert_to_tensor(losses)
            input_dtype = losses.dtype
            # All arithmetic is done in float and cast back at the end.
            losses = math_ops.to_float(losses)
            weights = math_ops.to_float(weights)
            weighted_losses = math_ops.multiply(losses, weights)

            loss = weighted_losses
            if reduction != Reduction.NONE:
                loss = math_ops.reduce_sum(weighted_losses)
                if reduction == Reduction.MEAN:
                    weight_total = math_ops.reduce_sum(
                        array_ops.ones_like(losses) * weights)
                    loss = _safe_mean(loss, weight_total)
                elif reduction in (Reduction.SUM_BY_NONZERO_WEIGHTS,
                                   Reduction.SUM_OVER_NONZERO_WEIGHTS):
                    loss = _safe_mean(loss, _num_present(losses, weights))
                elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
                    loss = _safe_mean(loss, _num_elements(losses))

            # Convert the result back to the input type.
            loss = math_ops.cast(loss, input_dtype)
            util.add_loss(loss, loss_collection)
            return loss
def _apply_op_helper(self, op_type_name, name=None, **keywords):
    """Implementation of apply_op.

    Looks up the registered OpDef for `op_type_name`, converts the keyword
    arguments into input tensors and attr protos (inferring type and number
    attrs from the inputs where possible), and adds the op to the graph
    determined from the inputs.

    Args:
      op_type_name: Name of the registered op type to create.
      name: Optional name for the op; defaults to `op_type_name`.
      **keywords: Values for the op's inputs and attrs, keyed by name. A name
        that clashes with a Python keyword or built-in may be passed with a
        trailing underscore.

    Returns:
      A tuple `(output_structure, is_stateful, op)` where `output_structure`
      has one entry per output arg (`None` for a single tensor, or the list
      length for list outputs), `is_stateful` is `op_def.is_stateful`, and
      `op` is the created operation.

    Raises:
      RuntimeError: If the op type is not registered, or the graph cannot be
        determined from the inputs.
      NotImplementedError: If the op is deprecated at the graph's producer
        version.
      TypeError: On missing/unexpected arguments or input/attr type mismatches.
      ValueError: On list length or attr value constraint violations.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
        # pylint: enable=protected-access
    except AssertionError as e:
        # Format the exception itself: `e.message` does not exist on
        # Python 3 exceptions and would raise AttributeError here.
        raise RuntimeError(
            "Cannot determine graph for Op '%s' due to: %s"
            % (op_type_name, e))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
        producer = g.graph_def_versions.producer
        if producer >= deprecation_version:
            raise NotImplementedError(
                ("Op %s is not available in GraphDef version %d. "
                 "It has been removed in version %d. %s.")
                % (op_type_name, producer, deprecation_version,
                   op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
        if attr_def.type != "type":
            continue
        key = attr_def.name
        if attr_def.HasField("default_value"):
            default_type_attr_map[key] = dtypes.as_dtype(
                attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a keyword or built-in
                # for Python so we use the name + _ instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if it contains constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have the
                # same dtype, try to cast non-Tensor elements to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                    # dtype still not found, prefer using the default dtype
                    # from the attr.
                    if (dtype is None and
                            input_arg.type_attr in default_type_attr_map):
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                try:
                    if not input_arg.is_ref and dtype:
                        dtype = dtypes.as_dtype(dtype).base_dtype
                    values = ops.internal_convert_n_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype if dtype else None,
                        preferred_dtype=default_dtype,
                        as_ref=input_arg.is_ref)
                    if input_arg.number_attr and len(
                            set(v.dtype.base_dtype for v in values)) > 1:
                        raise TypeError()  # All types should match.
                except (TypeError, ValueError):
                    # What types does the conversion function think values have?
                    observed_types = []
                    for value in values:
                        try:
                            converted_value = ops.internal_convert_to_tensor(
                                value, as_ref=input_arg.is_ref)
                            observed_types.append(
                                converted_value.dtype.base_dtype.name)
                        except (TypeError, ValueError):
                            observed_types.append(
                                "<NOT CONVERTIBLE TO TENSOR>")
                    observed = ", ".join(observed_types)

                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                        % (input_name, op_type_name, observed))
                    if input_arg.number_attr:
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, dtype.name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
                        else:
                            raise TypeError(
                                "%s that don't all match." % prefix)
                    else:
                        raise TypeError("%s that are invalid." % prefix)

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert
                # non-Tensor arguments to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]
                elif input_arg.type_attr in default_type_attr_map:
                    # The dtype could not be inferred solely from the inputs,
                    # so we prefer the attr's default, so code that adds a new
                    # attr with a default is backwards compatible.
                    default_dtype = default_type_attr_map[input_arg.type_attr]

                try:
                    values = ops.internal_convert_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype,
                        as_ref=input_arg.is_ref,
                        preferred_dtype=default_dtype)
                except TypeError as err:
                    if dtype is None:
                        raise err
                    else:
                        raise TypeError(
                            "Expected %s passed to parameter '%s' of op '%s', got %s of "
                            "type '%s' instead." %
                            (dtypes.as_dtype(dtype).name, input_arg.name,
                             op_type_name, repr(values),
                             type(values).__name__))
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    try:
                        observed = ops.internal_convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                    except ValueError as err:
                        raise ValueError(
                            "Tried to convert '%s' to a tensor and failed. Error: %s" %
                            (input_name, err))
                    prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                              (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError(
                            "%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
                    else:
                        # Update the maps with the default, if needed.
                        k = input_arg.type_attr
                        if k in default_type_attr_map:
                            if k not in attrs:
                                attrs[k] = default_type_attr_map[k]
                                if k not in inferred_from:
                                    inferred_from[k] = "Default in OpDef"

                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix,
                             dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any([bt != base_types[0] for bt in base_types]):
                    raise TypeError(
                        "All tensors passed to '%s' of '%s' Op "
                        "must have the same type." %
                        (input_name, op_type_name))
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr>
                    # already has an inferred value.
                    if (base_types and
                            base_types[0] != attrs[input_arg.type_attr]):
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now
                    # setting the <type-attr> based on this input
                    if not base_types:
                        raise TypeError(
                            "Don't know how to infer type variable from empty input "
                            "list passed to input '%s' of '%s' Op." %
                            (input_name, op_type_name))
                    attrs[input_arg.type_attr] = base_types[0]
                    inferred_from[input_arg.type_attr] = input_name
                    type_attr = _Attr(op_def, input_arg.type_attr)
                    _SatisfiesTypeConstraint(base_types[0], type_attr,
                                             param_name=input_name)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        assert False, "Unreachable"
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type,
                            _Attr(op_def, input_arg.type_attr),
                            param_name=input_name)
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name,
                             ", ".join(dtypes.as_dtype(x).name
                                       for x in attr_value),
                             ", ".join(dtypes.as_dtype(x).name
                                       for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type,
                            _Attr(op_def, input_arg.type_list_attr),
                            param_name=input_name)
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            # Ref inputs keep their ref dtypes; all others are recorded by
            # base dtype.
            if input_arg.is_ref:
                if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                    raise TypeError(
                        ("'%s' Op requires that input '%s' be a mutable tensor "
                         "(e.g.: a tf.Variable)") % (op_type_name, input_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]
            attr_value = attr_value_pb2.AttrValue()
            # An explicit None selects the OpDef's default value, if any.
            if attr_def.HasField("default_value") and value is None:
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue
            if attr_def.type.startswith("list("):
                if not _IsListValue(value):
                    raise TypeError("Expected list for attr " + key)
                if attr_def.has_minimum:
                    if len(value) < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed list of length %d "
                            "less than minimum %d." %
                            (key, op_type_name, len(value), attr_def.minimum))
                attr_value.list.SetInParent()
            if attr_def.type == "string":
                attr_value.s = _MakeStr(value, key)
                if attr_def.HasField("allowed_values"):
                    if attr_value.s not in attr_def.allowed_values.list.s:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                            (key, op_type_name, compat.as_text(attr_value.s),
                             '", "'.join(map(compat.as_text,
                                             attr_def.allowed_values.list.s))))
            elif attr_def.type == "list(string)":
                attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                if attr_def.HasField("allowed_values"):
                    for x in attr_value.list.s:
                        if x not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                                (key, op_type_name, compat.as_text(x),
                                 '", "'.join(map(compat.as_text,
                                                 attr_def.allowed_values.list.s))))
            elif attr_def.type == "int":
                attr_value.i = _MakeInt(value, key)
                if attr_def.has_minimum:
                    if attr_value.i < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                            (key, op_type_name, attr_value.i,
                             attr_def.minimum))
            elif attr_def.type == "list(int)":
                attr_value.list.i.extend([_MakeInt(x, key) for x in value])
            elif attr_def.type == "float":
                attr_value.f = _MakeFloat(value, key)
            elif attr_def.type == "list(float)":
                attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
            elif attr_def.type == "bool":
                attr_value.b = _MakeBool(value, key)
            elif attr_def.type == "list(bool)":
                attr_value.list.b.extend([_MakeBool(x, key) for x in value])
            elif attr_def.type == "type":
                attr_value.type = _MakeType(value, attr_def)
            elif attr_def.type == "list(type)":
                attr_value.list.type.extend(
                    [_MakeType(x, attr_def) for x in value])
            elif attr_def.type == "shape":
                attr_value.shape.CopyFrom(_MakeShape(value, key))
            elif attr_def.type == "list(shape)":
                attr_value.list.shape.extend(
                    [_MakeShape(x, key) for x in value])
            elif attr_def.type == "tensor":
                attr_value.tensor.CopyFrom(_MakeTensor(value, key))
            elif attr_def.type == "list(tensor)":
                attr_value.list.tensor.extend(
                    [_MakeTensor(x, key) for x in value])
            elif attr_def.type == "func":
                if isinstance(value, attr_value_pb2.NameAttrList):
                    attr_value.func.CopyFrom(value)
                elif isinstance(value, compat.bytes_or_text_types):
                    attr_value.func.name = value
                else:
                    # Otherwise the value is expected to provide
                    # `add_to_graph` and `name`.
                    value.add_to_graph(ops.get_default_graph())
                    attr_value.func.name = value.name
            else:
                raise TypeError("Unrecognized Attr type " + attr_def.type)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_types = []
        output_structure = []
        for arg in op_def.output_arg:
            types = []
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                if arg.type_attr:
                    types = [_AttrValue(attr_protos, arg.type_attr).type] * n
                else:
                    types = [arg.type] * n
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                types = [t.type]
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                types = t.list.type
                output_structure.append(len(types))
            else:
                types = [arg.type]
                output_structure.append(None)
            if arg.is_ref:
                types = [dtypes.as_dtype(x)._as_ref for x in types]  # pylint: disable=protected-access
            output_types.extend(types)

        # Anything left over was neither an input nor an attr.
        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords.keys())))

        # NOTE(mrry): We add an explicit colocation constraint between
        # the newly created op and any of its reference-typed inputs.
        must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                                if arg.is_ref]
        with _MaybeColocateWith(must_colocate_inputs):
            # Add Op to graph
            op = g.create_op(op_type_name, inputs, output_types, name=scope,
                             input_types=input_types, attrs=attr_protos,
                             op_def=op_def)
        return output_structure, op_def.is_stateful, op
def _test(args):
  """Evaluates checkpoints from `args.logdir` as they are written.

  Builds the model in test mode, then loops: waits for a new checkpoint,
  reads the step count from the checkpoint file name, and runs one round of
  `tf_utils.train_step_custom_online_sampling` in `args.control.test_mode`
  against it. When `args.control.only_eval_when_done` is set, checkpoints
  earlier than `args.solver.max_steps` are skipped. The loop ends once a
  checkpoint at or past `args.solver.max_steps` has been handled.

  Args:
    args: Experiment configuration. This function reads its `logdir`,
      `navtask`, `control`, `solver`, `summary` and `arch` fields, calls its
      `setup_to_run` and `setup_train_step_kwargs` hooks, and overwrites
      `args.solver.master` with ''.
  """
  args.solver.master = ''
  container_name = ""
  checkpoint_dir = os.path.join(format(args.logdir))
  logging.error('Checkpoint_dir: %s', args.logdir)

  config = tf.ConfigProto()
  config.device_count['GPU'] = 1

  m = utils.Foo()
  m.tf_graph = tf.Graph()

  rng_data_seed = 0
  rng_action_seed = 0

  with m.tf_graph.as_default():
    with tf.container(container_name):
      m = args.setup_to_run(
          m, args, is_training=False,
          batch_norm_is_training=(
              args.control.force_batchnorm_is_training_at_test),
          summary_mode=args.control.test_mode)
      train_step_kwargs = args.setup_train_step_kwargs(
          m,
          nav_env.get_multiplexer_class(args.navtask, rng_data_seed),
          os.path.join(args.logdir, args.control.test_name),
          rng_seed=rng_data_seed,
          is_chief=True,
          num_steps=(args.navtask.task_params.num_steps *
                     args.navtask.task_params.num_goals),
          iters=args.summary.test_iters,
          train_display_interval=None,
          dagger_sample_bn_false=args.arch.dagger_sample_bn_false)

      # NOTE(review): the resulting Saver is never used (the Supervisor below
      # is given `m.saver_op`). The call is kept in case constructing it has
      # graph side effects that restoring relies on -- confirm and drop.
      slim.learning.tf_saver.Saver(variables.get_variables_to_restore())

      sv = slim.learning.supervisor.Supervisor(
          graph=ops.get_default_graph(), logdir=None, init_op=m.init_op,
          summary_op=None, summary_writer=None, global_step=None,
          saver=m.saver_op)

      last_checkpoint = None
      while True:
        # A timeout yields None, so keep polling until a checkpoint newer
        # than `last_checkpoint` shows up.
        new_checkpoint = None
        while new_checkpoint is None:
          new_checkpoint = slim.evaluation.wait_for_new_checkpoint(
              checkpoint_dir, last_checkpoint, seconds_to_sleep=10,
              timeout=60)
        last_checkpoint = new_checkpoint

        # The step count is the text after the LAST '-' in the file name, so
        # prefixes that themselves contain '-' are still parsed correctly.
        checkpoint_iter = int(
            os.path.basename(last_checkpoint).rsplit('-', 1)[-1])
        is_final = checkpoint_iter >= args.solver.max_steps

        if not args.control.only_eval_when_done or is_final:
          logging.info(
              'Starting evaluation at %s using checkpoint %s.',
              time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()),
              last_checkpoint)

          with sv.managed_session(
              args.solver.master, config=config,
              start_standard_services=False) as sess:
            sess.run(m.init_op)
            sv.saver.restore(sess, last_checkpoint)
            sv.start_queue_runners(sess)
            if args.control.reset_rng_seed:
              # Fresh generators so every checkpoint sees the same samples.
              train_step_kwargs['rng_data'] = [
                  np.random.RandomState(rng_data_seed),
                  np.random.RandomState(rng_data_seed)
              ]
              train_step_kwargs['rng_action'] = np.random.RandomState(
                  rng_action_seed)
            tf_utils.train_step_custom_online_sampling(
                sess, None, m.global_step_op, train_step_kwargs,
                mode=args.control.test_mode)
        else:
          logging.info('Skipping evaluation of checkpoint %s.',
                       last_checkpoint)

        # The final checkpoint is always evaluated above, so stop here.
        if is_final:
          break
def __init__(self,
             dataset=None,
             devices=None,
             max_buffer_size=1,
             prefetch_buffer_size=1,
             source_device="/cpu:0",
             components=None,
             element_spec=None):
  """Constructs an owned MultiDeviceIterator object.

  Exactly one of two construction modes must be used: pass `dataset` to
  build a new iterator, or pass both `components` and `element_spec` to
  rebuild one from existing tensors.

  Args:
    dataset: The input dataset to be iterated over.
    devices: The list of devices to fetch data to. Required.
    max_buffer_size: Maximum size of the host side per device buffer to keep.
      In order to prevent deadlocks, if `prefetch_buffer_size` is greater
      than `max_buffer_size`, `max_buffer_size` is raised to
      `prefetch_buffer_size`.
    prefetch_buffer_size: if > 1, then we setup a buffer on each device to
      prefetch into.
    source_device: The host device to place the `dataset` on.
    components: Tensor components to construct the MultiDeviceIterator from:
      the multi-device iterator resource, its deleter, and then the
      per-device iterators, in that order.
    element_spec: A nested structure of `TypeSpec` objects that represents
      the type specification of elements of the iterator.

  Raises:
    RuntimeError: If executed in graph mode or outside of function building
      mode.
    ValueError: If `devices` is None, or if the arguments do not match
      exactly one of the two construction modes.
  """
  if (not context.executing_eagerly() and
      not ops.get_default_graph()._building_function):  # pylint: disable=protected-access
    raise RuntimeError(
        "OwnedMultiDeviceIterator is only supported inside of "
        "tf.function or when eager execution is enabled.")
  if devices is None:
    raise ValueError("`devices` must be provided")
  # The parentheses are required: without them, a literal continued on the
  # next line is a separate, discarded statement and the message is cut off.
  error_message = ("Either `dataset` or both `components` and "
                   "`element_spec` need to be provided.")

  if dataset is None:
    if components is None or element_spec is None:
      raise ValueError(error_message)
    self._element_spec = element_spec
    self._devices = devices
    self._source_device = source_device
    self._multi_device_iterator_resource = components[0]
    self._deleter = components[1]
    self._device_iterators = components[2:]
    iterator_handles = [
        it._iterator_resource  # pylint: disable=protected-access
        for it in self._device_iterators
    ]
  else:
    if components is not None or element_spec is not None:
      raise ValueError(error_message)
    options = dataset_ops.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._element_spec = dataset.element_spec
    experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    source_device_tensor = ops.convert_to_tensor(self._source_device)

    if prefetch_buffer_size > max_buffer_size:
      max_buffer_size = prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      self._multi_device_iterator_resource, self._deleter = (
          gen_dataset_ops.anonymous_multi_device_iterator(
              devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

      # The incarnation ID is used to ensure consistency between the
      # per-device iterators and the multi-device iterator.
      incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=max_buffer_size)

    prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i, self._multi_device_iterator_resource, incarnation_id,
            source_device_tensor, dataset.element_spec)
        prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    iterator_handles = []
    for device, prototype in zip(self._devices, prototype_device_datasets):
      with ops.device(device):
        ds = _create_device_dataset(prototype, incarnation_id,
                                    prefetch_buffer_size, experimental_slack)
        iterator = iter(ds)
        self._device_iterators.append(iterator)
        iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

  # Both construction modes end with `iterator_handles` populated.
  self._resource_deleter = MultiDeviceIteratorResourceDeleter(
      multi_device_iterator=self._multi_device_iterator_resource,
      iterators=iterator_handles,
      device=self._source_device,
      deleter=self._deleter)
def _clear_graph_state():
  """Empties the graph-state collection on the current default graph.

  Clearing the Graph collection will prevent _PerGraphState from being
  serialized.
  """
  default_graph = ops_lib.get_default_graph()
  default_graph.clear_collection(TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
def _experimental_run_steps_on_iterator(self, fn, multi_worker_iterator,
                                        iterations, initial_loop_values=None):
  """Builds ops that run `fn` on every replica, optionally repeated.

  Args:
    fn: Step function called as `fn(ctx, inputs)`, where `ctx` is the
      `MultiStepContext` created here and `inputs` is one replica's slice of
      `multi_worker_iterator.get_next()`.
    multi_worker_iterator: Iterator whose `get_next()` supplies the
      per-replica inputs for each step.
    iterations: Repeat count passed to `training_loop.repeat`; not used when
      `self.steps_per_run == 1`.
    initial_loop_values: Optional structure of initial loop values. It is
      flattened and tiled once per replica before being passed to
      `training_loop.repeat`. Defaults to no values.

  Returns:
    The `MultiStepContext`, with `run_op` set to a group of the replicated
    outputs and its last step outputs set via `_set_last_step_outputs`.
  """
  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(inputs):
    """Single step on the TPU device."""
    fn_result = fn(ctx, inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if flat_last_step_outputs:
      # Return copies of the recorded outputs that depend on the step itself.
      with ops.control_dependencies([fn_result]):
        return [
            array_ops.identity(f) for f in flat_last_step_outputs
        ]
    else:
      return fn_result

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  def rewrite_fn(*args):
    """The rewritten step fn running on TPU."""
    # Loop-carried values are accepted for signature compatibility only.
    del args
    per_replica_inputs = multi_worker_iterator.get_next()
    replicate_inputs = []
    for replica_id in range(self._num_replicas_in_sync):
      # The lambda is consumed by `map_structure` within this same iteration,
      # so it presumably never observes a later `replica_id` -- TODO confirm.
      select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
      # Each replica's inputs are wrapped in a one-element tuple.
      replicate_inputs.append(
          (nest.map_structure(select_replica, per_replica_inputs),))
    replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

    # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
    # will flatten it in this case. If run_fn has no tensor outputs,
    # tpu.replicate returns a list of no_ops, we will keep the output as it
    # is.
    if isinstance(replicate_outputs[0], list):
      replicate_outputs = nest.flatten(replicate_outputs)
    return replicate_outputs

  # TODO(sourabhbajaj): The input to while loop should be based on the
  # output type of the step_fn
  assert isinstance(initial_loop_values, list)
  # One copy of the flattened initial values per replica.
  initial_loop_values = initial_loop_values * self._num_replicas_in_sync

  # Put the while loop op on TPU host 0.
  with ops.device(self._host_device):
    if self.steps_per_run == 1:
      # A single step needs no loop; call the step body directly.
      replicate_outputs = rewrite_fn()
    else:
      replicate_outputs = training_loop.repeat(
          iterations, rewrite_fn, initial_loop_values)

  # The outer context was only captured for the calls above; remove it.
  # NOTE(review): the attribute stays set if the block above raises.
  del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs)

  if isinstance(replicate_outputs, list):
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]
    # Outputs are currently of the structure (flattened)
    # [output0_device0, output1_device0, output2_device0,
    #  output0_device1, output1_device1, output2_device1,
    #  ...]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    # Number of tensor outputs per replica; it is also the slicing stride.
    output_num = len(
        last_step_tensor_outputs) // self._num_replicas_in_sync
    last_step_tensor_outputs = [
        last_step_tensor_outputs[i::output_num] for i in range(output_num)
    ]
  else:
    # no tensors returned.
    last_step_tensor_outputs = []

  _set_last_step_outputs(ctx, last_step_tensor_outputs)
  return ctx