Example #1
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func."""
  graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    captures = {}
    tmp_graph = CapturingGraph(captures)
    # Inherit the graph key, since this is used for matching variables in
    # optimizers.
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      with capture_tensors(captures):
        this_tape = tape.push_new_tape()
        try:
          func_outputs = func(*func_inputs, **kwds)
        finally:
          tape.pop_tape(this_tape)
        variables = this_tape.watched_variables()

        # Returning a closed-over tensor as an output does not trigger a
        # call to convert_to_tensor, so we manually capture all such tensors.
        outputs_list = _flatten(func_outputs)
        func_def_outputs = [
            _convert_to_graph_tensor(x) for x in outputs_list if x is not None
        ]

      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      output_shapes = tuple(
          x.shape if isinstance(x, ops.Tensor) else None
          for x in outputs_list)

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, ops.Tensor)]
  all_inputs = flat_inputs + list(extra_placeholders)
  all_ignored_ops = frozenset(x.op for x in all_inputs)
  fname = _inference_name(name)
  operations = tuple(x for x in tmp_graph.get_operations()
                     if x not in all_ignored_ops)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  if context.in_eager_mode():
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func)  # pylint: disable=protected-access
  return GraphModeFunction(
      fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
      func_outputs, output_shapes, variables)
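
For context, a minimal sketch of the collection-copying pattern used above, written against the public tf.Graph API (TF 1.x); the graph and collection names are illustrative, not taken from the example.

import tensorflow as tf

outer = tf.Graph()
with outer.as_default():
  tf.add_to_collection("my_collection", tf.constant(1.0, name="c"))

inner = tf.Graph()
# Copy every collection of `outer` into `inner`. get_collection_ref returns the
# mutable list backing the collection, so slice assignment replaces its
# contents in place without mutating `outer`.
for key in outer.collections:
  inner.get_collection_ref(key)[:] = outer.get_collection(key)

print(inner.get_collection("my_collection"))  # the tensor captured from `outer`
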
Example #2
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given an operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from the graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow the user to change the seed globally for
  a graph, or only for specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed).

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
  graph_seed = ops.get_default_graph().seed
  if graph_seed is not None:
    if op_seed is not None:
      return _truncate_seed(graph_seed), _truncate_seed(op_seed)
    else:
      return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
  else:
    if op_seed is not None:
      return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
    else:
      return None, None
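
A short illustration (TF 1.x) of the graph-level / op-level seed interplay that get_seed() implements; the seed values below are illustrative.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  tf.set_random_seed(1234)            # graph-level seed
  a = tf.random_uniform([])           # op-level seed derived by get_seed()
  b = tf.random_uniform([], seed=42)  # explicit op-level seed

# With a graph-level seed set, both ops get deterministic seed pairs, so two
# sessions over the same graph produce the same random values.
with tf.Session(graph=g) as sess:
  first = sess.run([a, b])
with tf.Session(graph=g) as sess:
  second = sess.run([a, b])
print(first == second)  # True
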
Example #3
  def _WeMustGoDeeper(self, msg):
    with self.assertRaisesOpError(msg):
      node_def = ops._NodeDef("op_type", "name")
      node_def_orig = ops._NodeDef("op_type_orig", "orig")
      op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
      op = ops.Operation(node_def, ops.get_default_graph(), original_op=op_orig)
      raise errors.UnauthenticatedError(node_def, op, "true_err")
  def testParallelApplyGradMean(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      accum_ops = []
      for x in elems:
        x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
        accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
      takeg_t = q.take_indexed_slices_grad(1)

      def apply_indexed_slices_grad(accum_op):
        self.evaluate(accum_op)

      threads = [
          self.checkedThread(
              target=apply_indexed_slices_grad, args=(o,)) for o in accum_ops
      ]

      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()

      val = self.evaluate(takeg_t)

      expected_val = sum(elems) / len(elems)
      self._assertEqual_nparray(
          np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
          val, sess)
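
A minimal single-threaded sketch (TF 1.x public API) of the accumulator averaging that the parallel test above exercises; the names are illustrative.

import tensorflow as tf

accum = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=tf.TensorShape([1]), name="acc")
apply_ops = [accum.apply_grad([g], local_step=0) for g in (10.0, 20.0, 30.0)]
take_avg = accum.take_grad(num_required=3)  # blocks until 3 grads are in

with tf.Session() as sess:
  for op in apply_ops:
    sess.run(op)
  print(sess.run(take_avg))  # [20.0], the mean of the applied gradients
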
Example #5
    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs)
  def testParallelAssignWithoutLocking(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      ones_t = array_ops.fill([1024, 1024], float(1))
      p = variables.Variable(array_ops.zeros([1024, 1024]))
      assigns = [
          state_ops.assign(p, math_ops.multiply(ones_t, float(i)), False)
          for i in range(1, 21)
      ]
      self.evaluate(variables.global_variables_initializer())

      def run_assign(assign_op):
        self.evaluate(assign_op)

      threads = [
          self.checkedThread(
              target=run_assign, args=(assign_op,)) for assign_op in assigns
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      vals = self.evaluate(p)

      # Assert every element is taken from one of the assignments.
      self.assertTrue((vals > 0).all())
      self.assertTrue((vals <= 20).all())
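
A small sketch (TF 1.x) of the "loop body augmented with a counter" pattern from wrapped_body above, expressed with the public tf.while_loop API.

import tensorflow as tf

def body(x):
  return x * 2.0  # the "real" loop body over the original loop vars

def wrapped_body(counter, x):
  # Prepend the incremented counter to whatever the original body returns.
  return [counter + 1, body(x)]

counter0 = tf.constant(0)
x0 = tf.constant(1.0)
final_counter, final_x = tf.while_loop(
    cond=lambda counter, x: counter < 5,
    body=wrapped_body,
    loop_vars=[counter0, x0])

with tf.Session() as sess:
  print(sess.run([final_counter, final_x]))  # [5, 32.0]
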
Example #7
    def before_run(self, run_context):
        """ Dumps graphs and loads checkpoint if there exits.

        Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object.

        Returns: A `SessionRunArgs` object containing global_step.
        """
        # We write the graph and saver_def at the first call of before_run.
        # We cannot do this in begin(), since other hooks may still change the
        # graph and add variables there. The graph is finalized after all begin
        # calls.
        if self._is_chief and self._first_call:
            training_util.write_graph(
                ops.get_default_graph().as_graph_def(add_shapes=True),
                self._checkpoint_dir,
                "graph.pbtxt")
            # dump model details "model_analysis.txt"
            dump_model_analysis(self._checkpoint_dir)  # dump model configs
            graph = ops.get_default_graph()
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph.as_graph_def(add_shapes=True),
                saver_def=self._saver.saver_def)
            if self._summary_writer is not None:
                self._summary_writer.add_graph(graph)
                self._summary_writer.add_meta_graph(meta_graph_def)
            tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
        self._first_call = False
        return tf.train.SessionRunArgs(self._global_step)
  def testParallelAssignWithLocking(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      zeros_t = array_ops.fill([1024, 1024], 0.0)
      ones_t = array_ops.fill([1024, 1024], 1.0)
      p = variables.Variable(zeros_t)
      assigns = [
          state_ops.assign(
              p, math_ops.multiply(ones_t, float(i)), use_locking=True)
          for i in range(1, 21)
      ]
      self.evaluate(p.initializer)

      def run_assign(assign_op):
        self.evaluate(assign_op)

      threads = [
          self.checkedThread(
              target=run_assign, args=(assign_op,)) for assign_op in assigns
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      vals = self.evaluate(p)

      # Assert every element is the same, and taken from one of the assignments.
      self.assertTrue(vals[0, 0] > 0)
      self.assertTrue(vals[0, 0] <= 20)
      self.assertAllEqual(vals, np.ones([1024, 1024]) * vals[0, 0])
  def testParallelUpdateWithoutLocking(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      ones_t = array_ops.fill([1024, 1024], 1.0)
      p = variables.Variable(array_ops.zeros([1024, 1024]))
      adds = [
          state_ops.assign_add(
              p, ones_t, use_locking=False) for _ in range(20)
      ]
      self.evaluate(variables.global_variables_initializer())

      def run_add(add_op):
        self.evaluate(add_op)

      threads = [
          self.checkedThread(
              target=run_add, args=(add_op,)) for add_op in adds
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      vals = self.evaluate(p)
      ones = np.ones((1024, 1024)).astype(np.float32)
      self.assertTrue((vals >= ones).all())
      self.assertTrue((vals <= ones * 20).all())
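
A stripped-down sketch (TF 1.x) of the "write the graph once in before_run" pattern from the hook above; the class name and directory are illustrative.

import tensorflow as tf

class GraphDumpHook(tf.train.SessionRunHook):
  """Writes the default graph to disk on the first call to before_run()."""

  def __init__(self, checkpoint_dir):
    self._checkpoint_dir = checkpoint_dir
    self._first_call = True

  def before_run(self, run_context):
    if self._first_call:
      tf.train.write_graph(
          tf.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
      self._first_call = False
    return None  # no extra fetches requested

one = tf.constant(1.0)  # build the graph before MonitoredSession finalizes it
with tf.train.MonitoredSession(hooks=[GraphDumpHook("/tmp/graph_dump")]) as sess:
  sess.run(one)
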
Example #10
def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  if kwargs:
    raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
  return attrs
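
A small sketch of the AttrValue protos that _parse_kwargs_as_attrs builds; only the proto construction is shown, nothing is attached to a node here.

from tensorflow.core.framework import attr_value_pb2

attrs = {
    "_noinline": attr_value_pb2.AttrValue(b=True),
    "_XlaCompile": attr_value_pb2.AttrValue(b=True),
    "_XlaScope": attr_value_pb2.AttrValue(s=b"function_my_func"),
}
for name, value in sorted(attrs.items()):
  # Each attr stores exactly one field of the proto's `value` oneof.
  print(name, value.WhichOneof("value"))
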
Example #11
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given an operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from the graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow the user to change the seed globally for
  a graph, or only for specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  @{tf.set_random_seed}.

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
  graph_seed = ops.get_default_graph().seed
  if graph_seed is not None:
    if op_seed is None:
      # pylint: disable=protected-access
      op_seed = ops.get_default_graph()._last_id
    seeds = _truncate_seed(graph_seed), _truncate_seed(op_seed)
  else:
    if op_seed is not None:
      seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
    else:
      seeds = None, None
  # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
  # be unexpected since Python docs say nondeterminism is (None, None).
  if seeds == (0, 0):
    return (0, _MAXINT32)
  return seeds
Example #12
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  orig_meta_graph, var_list = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  var_list = import_scoped_meta_graph(orig_meta_graph,
                                      graph=to_graph,
                                      import_scope=to_scope)
  return var_list
  def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='var_%d' % i)

      # Block here until all threads have constructed their graph.
      constructed_event.set()
      continue_event.wait()

      assign_c_to_v = state_ops.assign(v, c)
      v.initializer.run()
      assign_c_to_v.eval()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)
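
A usage sketch for copy_scoped_meta_graph (TF 1.x). The meta_graph module is internal, so the import path and scope names below are assumptions for illustration.

import tensorflow as tf
from tensorflow.python.framework import meta_graph

g = tf.Graph()
with g.as_default():
  with tf.variable_scope("src"):
    tf.get_variable("w", shape=[2], initializer=tf.zeros_initializer())

# Copy everything under "src" into a new "dst" scope of the same graph; this
# is allowed because the scopes differ even though the graphs are the same.
copied_vars = meta_graph.copy_scoped_meta_graph(
    from_scope="src", to_scope="dst", from_graph=g, to_graph=g)
print(sorted(copied_vars.keys()))  # the copies now live under "dst/"
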
Example #14
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
        yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
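
For reference, a short sketch of the public tf.Graph.container context manager that the overridden method above builds on; the container name is illustrative.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  with g.container("experiment0"):
    # Stateful ops created here live in the "experiment0" resource container,
    # which can later be cleared with tf.Session.reset(target, ["experiment0"]).
    v = tf.Variable(0.0, name="counter")

with tf.Session(graph=g) as sess:
  sess.run(v.initializer)
  print(sess.run(v))  # 0.0
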
Example #15
  def testIteratorStringHandleReuseTensorObject(self):
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
    initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
    structure_iterator = iterator_ops.Iterator.from_structure(
        dataset.output_types)

    created_ops = len(ops.get_default_graph().get_operations())

    self.assertIs(one_shot_iterator.string_handle(),
                  one_shot_iterator.string_handle())
    self.assertIs(initializable_iterator.string_handle(),
                  initializable_iterator.string_handle())
    self.assertIs(structure_iterator.string_handle(),
                  structure_iterator.string_handle())

    # Assert that getting the (default) string handle creates no ops.
    self.assertEqual(created_ops, len(ops.get_default_graph().get_operations()))

    # Specifying an explicit name will create a new op.
    handle_with_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo", handle_with_name.op.name)
    self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

    handle_with_same_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo_1", handle_with_same_name.op.name)
    self.assertIsNot(handle_with_name, handle_with_same_name)
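
A sketch of what the cached string handle is typically used for: feeding a feedable Iterator built with from_string_handle (TF 1.x API); the dataset is illustrative.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
one_shot = dataset.make_one_shot_iterator()
handle_t = one_shot.string_handle()  # cached; repeated calls return this tensor

handle_ph = tf.placeholder(tf.string, shape=[])
feedable = tf.data.Iterator.from_string_handle(
    handle_ph, dataset.output_types, dataset.output_shapes)
next_element = feedable.get_next()

with tf.Session() as sess:
  handle = sess.run(handle_t)
  print(sess.run(next_element, feed_dict={handle_ph: handle}))  # 1
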
Example #16
  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensor(obj):
      try:
        return (isinstance(obj, ops.Tensor) or
                isinstance(obj, variables.Variable))
      except ReferenceError:
        # If the object no longer exists, we don't care about it.
        return False

    tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj))
    outside_container_prefix = ops.get_default_graph()._container_prefix
    with IsolateTest():
      # Run the test in a new graph so that collections get cleared when it's
      # done, but inherit the container prefix so that we can print the values
      # of variables which get leaked when executing eagerly.
      ops.get_default_graph()._container_prefix = outside_container_prefix
      f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    backprop._last_zero = [None]
    backprop._shape_dtype = [None, None]
    context.get_default_context().scalar_cache().clear()
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensor(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
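
A simplified sketch of the leak-check idea above: record Tensor ids before some work runs, then report any new Tensor objects that survive a gc pass. This is a rough approximation, not a drop-in replacement for the decorator.

import gc
import tensorflow as tf

def live_tensor_ids():
  ids = set()
  for obj in gc.get_objects():
    try:
      if isinstance(obj, tf.Tensor):
        ids.add(id(obj))
    except ReferenceError:
      pass  # a weak proxy whose referent is gone; ignore it
  return ids

before = live_tensor_ids()
_ = tf.constant(1.0)  # in graph mode the default graph keeps this Tensor alive
gc.collect()
print(len(live_tensor_ids() - before))  # non-zero: the constant was not freed
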
Example #17
  def __init__(self,
               master,
               is_chief=True,
               checkpoint_dir=None,
               monitors=None,
               scaffold=None,
               config=None):
    self._graph = ops.get_default_graph()
    self._master = master
    self._checkpoint_dir = checkpoint_dir
    self._is_chief = is_chief
    self._config = config
    self._monitors = monitors or []
    self._scaffold = scaffold or Scaffold()
    # Finalize and write the graph.
    self._graph.finalize()
    # Create the session.
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        graph=ops.get_default_graph())
    self._sess = recoverable_session.RecoverableSession(self._create_session)
    # Call the begin() method of monitors.
    self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
    for monitor in self._monitors:
      monitor.begin(max_steps=None)
    # Write the graph out, note: this uses self._init_step.
    self.write_graph()
  def testRead(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default():
          with self.test_session(graph=ops.get_default_graph()) as sess:
            # Basic test: read from file 0.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[0],
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

        with ops.Graph().as_default():
          with self.test_session(graph=ops.get_default_graph()) as sess:
            # Basic test: read from file 1.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[1],
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)

        with ops.Graph().as_default():
          with self.test_session(graph=ops.get_default_graph()) as sess:
            # Basic test: read from both files.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames,
                num_epochs=num_epochs,
                batch_size=batch_size)
            self._verify_records(sess, batch_size, num_epochs=num_epochs)
            with self.assertRaises(errors.OutOfRangeError):
              self._next_actual_batch(sess)
  def testGraphWithoutVariables(self):
    export_dir = self._get_export_dir("test_graph_has_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with no variables.
    with self.test_session(graph=ops.Graph()) as sess:
      constant_5_name = constant_op.constant(5.0).name
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second graph with no variables
    with self.test_session(graph=ops.Graph()) as sess:
      constant_6_name = constant_op.constant(6.0).name
      builder.add_meta_graph(["bar"])

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph with tag "foo".
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      # Read the constant a from the graph.
      a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
      b = constant_op.constant(6.0)
      c = a * b
      self.assertEqual(30.0, sess.run(c))

    # Restore the graph with tag "bar".
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["bar"], export_dir)
      # Read the constant a from the graph.
      a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
      b = constant_op.constant(5.0)
      c = a * b
      self.assertEqual(30.0, sess.run(c))
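
A condensed sketch (TF 1.x) of the SavedModel round trip used in the test above; the export directory is illustrative and is assumed not to exist yet.

import tensorflow as tf

export_dir = "/tmp/savedmodel_demo"  # SavedModelBuilder requires a fresh path
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
  tf.constant(5.0, name="c")
  builder.add_meta_graph_and_variables(sess, ["foo"])
builder.save()

restore_graph = tf.Graph()
with restore_graph.as_default(), tf.Session(graph=restore_graph) as sess:
  tf.saved_model.loader.load(sess, ["foo"], export_dir)
  restored = restore_graph.get_tensor_by_name("c:0")
  print(sess.run(restored))  # 5.0
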
Example #20
  def testPartitionConcatenatesAlongCorrectAxis(self):

    def _part_axis_0(**unused_kwargs):
      return (2, 1, 1)

    def _part_axis_1(**unused_kwargs):
      return (1, 2, 1)

    with variable_scope.variable_scope("root"):
      v0 = variable_scope.get_variable(
          "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
      v1 = variable_scope.get_variable(
          "n1", shape=(2, 2, 2), partitioner=_part_axis_1)

    self.assertEqual(v0.get_shape(), (2, 2, 2))
    self.assertEqual(v1.get_shape(), (2, 2, 2))

    n0_0 = ops.get_default_graph().get_tensor_by_name("root/n0/part_0:0")
    n0_1 = ops.get_default_graph().get_tensor_by_name("root/n0/part_1:0")
    self.assertEqual(n0_0.get_shape(), (1, 2, 2))
    self.assertEqual(n0_1.get_shape(), (1, 2, 2))

    n1_0 = ops.get_default_graph().get_tensor_by_name("root/n1/part_0:0")
    n1_1 = ops.get_default_graph().get_tensor_by_name("root/n1/part_1:0")
    self.assertEqual(n1_0.get_shape(), (2, 1, 2))
    self.assertEqual(n1_1.get_shape(), (2, 1, 2))
  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:
      self._init_op = Scaffold._get_or_default(
          'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
    if self._ready_op is None:
      self._ready_op = Scaffold._get_or_default(
          'ready_op', ops.GraphKeys.READY_OP,
          variables.report_uninitialized_variables)
    if self._local_init_op is None:
      self._local_init_op = Scaffold._get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold._default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold._get_or_default(
          'summary_op', ops.GraphKeys.SUMMARY_OP,
          logging_ops.merge_all_summaries)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = Scaffold._get_or_default(
          'saver',
          ops.GraphKeys.SAVERS,
          lambda: training_saver.Saver(sharded=True, allow_empty=True))
    # pylint: enable=g-long-lambda
    self._saver.build()

    ops.get_default_graph().finalize()
    return self
  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._global_step_tensor is None:
      self._global_step_tensor = contrib_variables.get_or_create_global_step()
    if self._init_op is None:
      self._init_op = Scaffold._get_or_default(
          'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
    if self._ready_op is None:
      self._ready_op = Scaffold._get_or_default(
          'ready_op', ops.GraphKeys.READY_OP,
          variables.report_uninitialized_variables)
    if self._local_init_op is None:
      self._local_init_op = Scaffold._get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold._default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold._get_or_default(
          'summary_op', ops.GraphKeys.SUMMARY_OP,
          logging_ops.merge_all_summaries)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = Scaffold._get_or_default(
          'saver',
          ops.GraphKeys.SAVERS,
          lambda: training_saver.Saver(sharded=True,
                                       max_to_keep=self._keep_checkpoint_max))
    # pylint: enable=g-long-lambda

    ops.get_default_graph().finalize()
  def test_assign_stays_in_true_dtype(self, distribute):
    with get_distribute_scope(distribute):
      x = get_var(1., dtypes.float32)
      x = get_autocast_var(x, distribute)
      self.evaluate(x.initializer)
      # small_val is a value such that 1.0 + small_val == 1.0 in fp16, but not
      # in fp32
      small_val = np.finfo('float16').eps / 2
      small_tensor = constant_op.constant(small_val, dtype=dtypes.float32)
      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float16):
        # Variable should be increased, despite it appearing to be the same
        # float16 value.
        self.assertEqual(1. + small_val,
                         self.evaluate(x.assign(1. + small_tensor)))
        self.assertEqual(1., self.evaluate(x.value()))
      self.assertEqual(1. + small_val, self.evaluate(x.value()))

      self.evaluate(x.assign(1.))
      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float16):
        self.assertEqual(1. + small_val,
                         self.evaluate(x.assign_add(small_tensor)))
        self.assertEqual(1., self.evaluate(x.value()))
      self.assertEqual(1. + small_val, self.evaluate(x.value()))
  def testAccumulatorApplyAndBlockingTake(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.ConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))

      elems = [10.0, 20.0, 30.0]
      elems_ave = sum(elems) / len(elems)
      accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
      takeg_t = q.take_grad(3)

      def apply_grad():
        time.sleep(1.0)
        for accum_op in accum_ops:
          self.evaluate(accum_op)

      return_array = []

      def take_grad():
        return_array.append(self.evaluate(takeg_t))

      accum_thread = self.checkedThread(target=apply_grad)
      takeg_thread = self.checkedThread(target=take_grad)
      accum_thread.start()
      takeg_thread.start()
      accum_thread.join()
      takeg_thread.join()

      self.assertEqual([elems_ave], return_array)
  def test_operator_overloads(self, distribute):
    with get_distribute_scope(distribute):
      x = get_var(1., dtypes.float32)
      x = get_autocast_var(x, distribute)
      self.evaluate(x.initializer)

    v1 = constant_op.constant(2., dtype=dtypes.float32)
    v2 = constant_op.constant(2., dtype=dtypes.float16)

    # Because autocast variables do not yet define operator overloads, the
    # operator is defined by the non-variable tensor

    # Test variable as the LHS. Currently, this is not supported with
    # distributed autocast variables
    if not distribute:
      self.assertEqual(self.evaluate(x + v1), 3.)

      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float16):
        self.assertEqual(self.evaluate(x + v2), 3.)

    # Test variable as the RHS
    self.assertEqual(self.evaluate(v1 + x), 3.)

    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(self.evaluate(v2 + x), 3.)
  def testParallelApplyGrad(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.ConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
      takeg_t = q.take_grad(1)

      def apply_grad(accum_op):
        self.evaluate(accum_op)

      threads = [
          self.checkedThread(
              target=apply_grad, args=(o,)) for o in accum_ops
      ]

      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()

      val = self.evaluate(takeg_t)

      self.assertEqual(val, sum(elems) / len(elems))
  def testParallelTakeGrad(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.ConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
      elems = [e for e in range(10)]
      accum_ops = [q.apply_grad((np.float32(e),), local_step=e) for e in elems]
      takeg_t = q.take_grad(1)

      def apply_grad():
        for accum_op in accum_ops:
          time.sleep(1.0)
          self.evaluate(accum_op)

      apply_grad_thread = self.checkedThread(target=apply_grad)

      results = []

      def take_grad():
        results.append(self.evaluate(takeg_t))

      threads = [self.checkedThread(target=take_grad) for _ in range(10)]

      for thread in threads:
        thread.start()
      apply_grad_thread.start()

      for thread in threads:
        thread.join()
      apply_grad_thread.join()

      self.assertItemsEqual(elems, results)
  def test_read(self, distribute):
    with get_distribute_scope(distribute):
      x = get_var(1., dtypes.float32)
      x = get_autocast_var(x, distribute)
      self.evaluate(x.initializer)

      # outside of auto cast scope.
      self.assertEqual(x.dtype, dtypes.float32)
      self.assertEqual(x.value().dtype, dtypes.float32)
      self.assertEqual(x.read_value().dtype, dtypes.float32)
      self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)

      # within auto cast scope of different dtype
      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float16):
        self.assertEqual(x.dtype, dtypes.float16)
        self.assertEqual(x.value().dtype, dtypes.float16)
        self.assertEqual(x.read_value().dtype, dtypes.float16)
        self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)

      # within auto cast scope of same dtype
      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float32):
        self.assertEqual(x.dtype, dtypes.float32)
        self.assertEqual(x.value().dtype, dtypes.float32)
        self.assertEqual(x.read_value().dtype, dtypes.float32)
        self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
  def testAccumulatorApplyAndBlockingTake(self):
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))

      elems = [10.0, 20.0, 30.0]
      elems_ave = sum(elems) / len(elems)
      accum_ops = []
      for x in elems:
        x = _indexedslice(np.array([[0, x], [0, 0]]).astype(np.float32))
        accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
      takeg_t = q.take_indexed_slices_grad(3)

      results = []

      def apply_indexed_slices_grad():
        for accum_op in accum_ops:
          self.evaluate(accum_op)

      def take_grad():
        results.append(self.evaluate(takeg_t))

      accum_thread = self.checkedThread(target=apply_indexed_slices_grad)
      takeg_thread = self.checkedThread(target=take_grad)
      accum_thread.start()
      takeg_thread.start()
      accum_thread.join()
      takeg_thread.join()

      self._assertEqual_nparray([[0, elems_ave], [0, 0]], results[0], sess)
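
A minimal sketch of axis-wise variable partitioning, as exercised by the partitioner test at the start of the example above (TF 1.x API); the names are illustrative.

import tensorflow as tf

with tf.variable_scope("root", partitioner=tf.fixed_size_partitioner(2, axis=0)):
  v = tf.get_variable("n0", shape=(2, 2, 2))

# The partitioned variable still reports the full shape...
print(v.get_shape())  # (2, 2, 2)
# ...while the underlying slices are separate variables named part_0 / part_1.
g = tf.get_default_graph()
print(g.get_tensor_by_name("root/n0/part_0:0").get_shape())  # (1, 2, 2)
print(g.get_tensor_by_name("root/n0/part_1:0").get_shape())  # (1, 2, 2)
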
Example #30
  def test_graph_replace_gradients(self):
    ops.reset_default_graph()
    w = variables.VariableV1(0.0, name="w")
    y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2")
    g = gradients_impl.gradients(y, w, name="grad")[0]

    # Extract the operations.
    replacement_ts = {w.value(): g}
    original_mul1_grad = (ops.get_default_graph().
                          get_operation_by_name("grad/mul1_grad/Mul_1"))

    # Should not raise exception.
    res = ge.graph_replace(g, replacement_ts, dst_scope="res")

    # Extract the operations after graph_replace.
    result_mul1_grad = (ops.get_default_graph().
                        get_operation_by_name("res/grad/mul1_grad/Mul_1"))

    # Make sure _original_ops are as expected.
    self.assertEqual(original_mul1_grad._original_op.name, u"mul1")
    self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1")
    self.assertNotEqual(res.name, g.name)
    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      g_val, res_val = sess.run([g, res])
    self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
    self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
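
A basic sketch of ge.graph_replace (tf.contrib.graph_editor, TF 1.x only): copy the subgraph computing an output while swapping one of its inputs; the tensor names are illustrative.

import tensorflow as tf
ge = tf.contrib.graph_editor

a = tf.constant(1.0, name="a")
b = tf.constant(2.0, name="b")
y = tf.add(a, a, name="y")  # y = a + a

# Rebuild the ops feeding `y` under a new scope, reading `b` wherever `a` was
# read in the original subgraph.
y_replaced = ge.graph_replace(y, {a: b}, dst_scope="replaced")

with tf.Session() as sess:
  print(sess.run([y, y_replaced]))  # [2.0, 4.0]
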
Example #31
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            iterator,
                                            iterations,
                                            initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = input_lib.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_result = fn(ctx, iterator.get_next())
            for (name, output) in ctx.last_step_outputs.items():
                # Convert all outputs to tensors, potentially from `DistributedValues`.
                ctx.last_step_outputs[name] = self._local_results(output)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop. This is useful in cases where we might need to
        # exit these contexts and get back to the outer context to do some
        # things, e.g. to create an op which should be evaluated only once at
        # the end of the loop on the host. One such usage is in creating
        # metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)
        del self._outer_control_flow_context

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, wrap them in a Mirrored
            # container, else in a PerReplica container.
            if reduce_op is None:
                last_step_tensor_outputs_dict[name] = values.regroup(output)
            else:
                assert len(output) == 1
                last_step_tensor_outputs_dict[name] = output[0]

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
Example #32
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
    """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
  """
    if not isinstance(decoder, Decoder):
        raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                        type(decoder))

    with variable_scope.variable_scope(scope, "decoder") as varscope:
        # Determine context types.
        ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = (control_flow_util.GetContainingWhileContext(ctxt)
                         is not None)
        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is possible
        # that train steps could be wrapped in a tf.while_loop. In that scenario
        # caching prevents forward computations in loop iterations from re-reading
        # the updated weights.
        if not context.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
                maximum_iterations,
                dtype=dtypes.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_finished, initial_inputs, initial_state = decoder.initialize()

        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError(
                "maximum_iterations is required for XLA compilation.")
        if maximum_iterations is not None:
            initial_finished = math_ops.logical_or(initial_finished,
                                                   0 >= maximum_iterations)
        initial_sequence_lengths = array_ops.zeros_like(initial_finished,
                                                        dtype=dtypes.int32)
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        def _shape(batch_size, from_shape):
            if (not isinstance(from_shape, tensor_shape.TensorShape)
                    or from_shape.ndims == 0):
                return tensor_shape.TensorShape(None)
            else:
                batch_size = tensor_util.constant_value(
                    ops.convert_to_tensor(batch_size, name="batch_size"))
                return tensor_shape.TensorShape([batch_size
                                                 ]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla

        def _create_ta(s, d):
            return tensor_array_ops.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, s))

        initial_outputs_ta = nest.map_structure(_create_ta,
                                                decoder.output_size,
                                                decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            return math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = math_ops.logical_or(decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                math_ops.logical_not(finished),
                array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
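
A compact usage sketch for dynamic_decode, assuming the tf.contrib.seq2seq API of TF 1.x; the sizes and cell choice below are illustrative.

import tensorflow as tf

batch_size, max_time, depth, vocab = 2, 5, 8, 10
inputs = tf.random_uniform([batch_size, max_time, depth])
lengths = tf.constant([5, 3], dtype=tf.int32)

cell = tf.nn.rnn_cell.GRUCell(depth)
helper = tf.contrib.seq2seq.TrainingHelper(inputs, sequence_length=lengths)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell, helper, initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=tf.layers.Dense(vocab))

outputs, final_state, final_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, impute_finished=True, maximum_iterations=max_time)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(outputs.rnn_output).shape)  # (2, 5, 10)
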
Example #33
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             **kwargs):
    """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to import into. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported. graph_def and export_scope cannot both be specified.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    **kwargs: Optional keyed arguments, including meta_info_def,
      saver_def, collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
  """
    graph = graph or ops.get_default_graph()
    unbound_inputs = []
    if export_scope or clear_devices:
        if graph_def:
            new_graph_def = graph_pb2.GraphDef()
            new_graph_def.versions.CopyFrom(graph_def.versions)
            for node_def in graph_def.node:
                if _should_include_node(node_def.name, export_scope):
                    new_node_def = _node_def(node_def,
                                             export_scope,
                                             unbound_inputs,
                                             clear_devices=clear_devices)
                    new_graph_def.node.extend([new_node_def])
            graph_def = new_graph_def
        else:
            # Only do this complicated work if we want to remove a name scope.
            graph_def = graph_pb2.GraphDef()
            # pylint: disable=protected-access
            graph_def.versions.CopyFrom(graph.graph_def_versions)
            bytesize = 0
            for key in sorted(graph._nodes_by_id):
                if _should_include_node(graph._nodes_by_id[key].name,
                                        export_scope):
                    value = graph._nodes_by_id[key]
                    # pylint: enable=protected-access
                    node_def = _node_def(value.node_def,
                                         export_scope,
                                         unbound_inputs,
                                         clear_devices=clear_devices)
                    graph_def.node.extend([node_def])
                    if value.outputs:
                        assert "_output_shapes" not in graph_def.node[-1].attr
                        graph_def.node[-1].attr[
                            "_output_shapes"].list.shape.extend([
                                output.get_shape().as_proto()
                                for output in value.outputs
                            ])
                    bytesize += value.node_def.ByteSize()
                    if bytesize >= (1 << 31) or bytesize < 0:
                        raise ValueError("GraphDef cannot be larger than 2GB.")
        # It's possible that not all the inputs are in the export_scope.
        # If we would like such information included in the exported meta_graph,
        # add them to a special unbound_inputs collection.
        if unbound_inputs_col_name:
            # Clears the unbound_inputs collections.
            graph.clear_collection(unbound_inputs_col_name)
            for k in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, k)

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=export_scope)
    for v in variables:
        if _should_include_node(v, export_scope):
            var_list[ops.strip_name_scope(v.name, export_scope)] = v

    scoped_meta_graph_def = create_meta_graph_def(graph_def=graph_def,
                                                  graph=graph,
                                                  export_scope=export_scope,
                                                  **kwargs)

    if filename:
        training_util.write_graph(scoped_meta_graph_def,
                                  os.path.dirname(filename),
                                  os.path.basename(filename),
                                  as_text=as_text)

    return scoped_meta_graph_def, var_list
Example #34
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
    """Recreates a`Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates all the collections, and returns a saver
  constructed from the `saver_def` field.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                field = getattr(col_def, kind)
                if field.value and (not input_map or sorted(
                    [compat.as_str(v)
                     for v in field.value]) != sorted(input_map)):
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." %
                        ",".join([compat.as_str(v) for v in field.value]))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This helps to
        # make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        importer.import_graph_def(input_graph_def,
                                  name=(import_scope or ""),
                                  input_map=input_map,
                                  producer_op_list=producer_op_list)

        # Restores all the other collections.
        for key, col_def in meta_graph_def.collection_def.items():
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)
            if from_proto:
                assert kind == "bytes_list"
                proto_type = ops.get_collection_proto_type(key)
                for value in col_def.bytes_list.value:
                    proto = proto_type()
                    proto.ParseFromString(value)
                    graph.add_to_collection(
                        key, from_proto(proto, import_scope=import_scope))
            else:
                field = getattr(col_def, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(value, import_scope))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work around the fact
                    # that Python2 distinguishes between int and long, while Python3 has
                    # only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key, ops.prepend_name_scope(value, import_scope))

        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=import_scope)
        for v in variables:
            var_list[ops.strip_name_scope(v.name, import_scope)] = v

    return var_list
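
The export/import pair above is easiest to see as a round trip. The sketch below is a minimal, hedged example of that pairing; it assumes these helpers are importable from tensorflow.python.framework.meta_graph and uses a TF1-style graph with illustrative scope names.

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

with tf.Graph().as_default():
    with tf.variable_scope("scope_a"):
        tf.get_variable("v", shape=[], initializer=tf.zeros_initializer())
    # Export only the subgraph under "scope_a"; exported names are
    # stripped of that prefix.
    exported_meta, _ = meta_graph.export_scoped_meta_graph(
        export_scope="scope_a")

with tf.Graph().as_default():
    # Re-import under a different scope. The returned dict maps the
    # scope-stripped variable names to the re-created variables, which
    # now live under "scope_b".
    imported_vars = meta_graph.import_scoped_meta_graph(
        exported_meta, import_scope="scope_b")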
Example #35
0
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None):
    """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    graph: The `Graph` to create `MetaGraphDef` out of.
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
    # Type check.
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError("graph must be of type Graph, not %s", type(graph))
    if meta_info_def and not isinstance(
            meta_info_def, meta_graph_pb2.MetaGraphDef.MetaInfoDef):
        raise TypeError("meta_info_def must be of type MetaInfoDef, not %s",
                        type(meta_info_def))
    if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
        raise TypeError("graph_def must be of type GraphDef, not %s",
                        type(graph_def))
    if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
        raise TypeError("saver_def must be of type SaverDef, not %s",
                        type(saver_def))

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Creates a MetaGraphDef proto.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    # Adds meta_info_def.
    if meta_info_def:
        meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

    # Adds graph_def or the default.
    if not graph_def:
        meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
    else:
        meta_graph_def.graph_def.MergeFrom(graph_def)

    # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
    # pylint: disable=g-explicit-length-test
    if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
        meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
            stripped_op_list_for_graph(meta_graph_def.graph_def))
    # pylint: enable=g-explicit-length-test

    # Adds saver_def.
    if saver_def:
        meta_graph_def.saver_def.MergeFrom(saver_def)

    # Adds collection_list.
    if collection_list:
        clist = collection_list
    else:
        clist = graph.get_all_collection_keys()
    for ctype in clist:
        add_collection_def(meta_graph_def,
                           ctype,
                           graph=graph,
                           export_scope=export_scope)
    return meta_graph_def
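
A short, hedged sketch of calling this builder directly (assuming the same tensorflow.python.framework.meta_graph module as in the earlier sketch), limiting the serialized collections to global variables:

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

with tf.Graph().as_default() as g:
    tf.get_variable("w", shape=[2], initializer=tf.ones_initializer())
    mgd = meta_graph.create_meta_graph_def(
        graph=g, collection_list=[tf.GraphKeys.GLOBAL_VARIABLES])
    # mgd.graph_def holds the serialized graph (with shapes added), and
    # mgd.collection_def contains only the GLOBAL_VARIABLES collection.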
Example #36
0
def add_collection_def(meta_graph_def, key, graph=None, export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections.
    export_scope: Optional `string`. Name scope to remove.
  """
    if graph and not isinstance(graph, ops.Graph):
        raise TypeError("graph must be of type Graph, not %s", type(graph))

    if not isinstance(key, six.string_types) and not isinstance(key, bytes):
        logging.warning(
            "Only collections with string type keys will be "
            "serialized. This key has %s", type(key))
        return

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    collection_list = graph.get_collection(key)
    # Remove nodes that should not be exported from the collection list.
    collection_list = [
        x for x in collection_list if _should_include_node(x, export_scope)
    ]
    if not collection_list:
        return

    try:
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            kind = "bytes_list"
            for x in collection_list:
                # Additional type check to make sure the returned proto is indeed
                # what we expect.
                proto = to_proto(x, export_scope=export_scope)
                if proto:
                    assert isinstance(proto, proto_type)
                    getattr(col_def,
                            kind).value.append(proto.SerializeToString())
        else:
            kind = _get_kind_name(collection_list[0])
            if kind == "node_list":
                for x in collection_list:
                    if not export_scope or x.name.startswith(export_scope):
                        getattr(col_def, kind).value.append(
                            ops.strip_name_scope(x.name, export_scope))
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around the fact
                # that Python3 distinguishes between bytes and strings.
                getattr(col_def, kind).value.extend(
                    [compat.as_bytes(x) for x in collection_list])
            else:
                getattr(col_def,
                        kind).value.extend([x for x in collection_list])
    except Exception as e:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s", key, str(e))
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return
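
For a concrete picture of what this helper writes into the proto, here is a hedged sketch for a plain node-list collection (it assumes add_collection_def is importable from tensorflow.python.framework.meta_graph; the collection key is illustrative):

import tensorflow.compat.v1 as tf
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import meta_graph

with tf.Graph().as_default() as g:
    c = tf.constant(1.0, name="c")
    g.add_to_collection("my_nodes", c)
    mgd = meta_graph_pb2.MetaGraphDef()
    meta_graph.add_collection_def(mgd, "my_nodes", graph=g)
    # Tensors serialize as a node_list of names, e.g. "c:0".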
Example #37
0
def _call_for_each_replica(distribution, devices, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

  Args:
    distribution: the DistributionStrategy object.
    devices: the devices to run `fn` on (logical device 0 for each replica).
    fn: function to run (will be run once per replica, each in its own thread).
    args: positional arguments for `fn`.
    kwargs: keyword arguments for `fn`.

  Returns:
    Merged return value of `fn` across all replicas.

  Raises:
    RuntimeError: If fn() calls get_replica_context().merge_call() a different
        number of times on different replicas.
  """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    shared_variable_store = {}

    # TODO(isaprykin): Create these threads once instead of during every call.
    threads = []
    for index in range(len(devices)):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = _MirroredReplicaThread(distribution, coord, index, devices,
                                   variable_creator_fn, fn,
                                   values.select_replica(index, args),
                                   values.select_replica(index, kwargs))
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts, the `should_run` event is set on the
    # _MirroredReplicaThread (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                done = []
                if run_concurrently:
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    merge_args = values.regroup(
                        tuple(t.merge_args for t in threads))
                    merge_kwargs = values.regroup(
                        tuple(t.merge_kwargs for t in threads))
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    mtt_captured_var_scope = threads[0].captured_var_scope
                    # Capture and merge the control dependencies from all the threads.
                    mtt_captured_control_deps = set()
                    for t in threads:
                        mtt_captured_control_deps.update(
                            t.captured_control_deps)
                    with ops.name_scope(mtt_captured_name_scope),\
                        ops.control_dependencies(mtt_captured_control_deps), \
                        variable_scope.variable_scope(mtt_captured_var_scope):
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    for r, t in enumerate(threads):
                        t.merge_result = values.select_replica(r, merge_result)
    finally:
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    return values.regroup(tuple(t.main_result for t in threads))
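
The should_run / has_paused handshake described in the comment block above is easier to follow in isolation. Below is a stripped-down sketch of the same Event-based protocol with a single worker thread, using only the standard threading module; all names are illustrative and none of the real _MirroredReplicaThread machinery is involved.

import threading

class _Worker(threading.Thread):
    """Toy stand-in for a replica thread: run once, then report paused."""

    def __init__(self):
        super().__init__(daemon=True)
        self.should_run = threading.Event()
        self.has_paused = threading.Event()
        self.done = False

    def run(self):
        self.should_run.wait()    # wait for the coordinator's go-ahead
        self.should_run.clear()
        # ... run (a slice of) fn here; a merge_call would pause here ...
        self.done = True          # in this sketch, fn finishes in one step
        self.has_paused.set()     # signal the coordinator

worker = _Worker()
worker.start()
worker.should_run.set()           # coordinator: let the worker run
worker.has_paused.wait()          # coordinator: block until it pauses
assert worker.done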
Example #38
0
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
    """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).
  Returns:
    A list of lists with the first list corresponding to the compile op and the
    second a list of output tensors, indexed by `[replica_num][output_num]`.
  Raises:
    ValueError: If all replicas do not have equal numbers of input tensors.
    ValueError: If the number of inputs per replica does not match
      the number of formal parameters to `computation`.
  """
    del name
    inputs = [[]] if inputs is None else inputs

    metadata_kwargs = {}
    if device_assignment is not None:
        # Turn the Numpy array into a flattened list so we can pass it as an
        # operator attribute.
        metadata_kwargs = {
            "topology":
            device_assignment.topology.serialized(),
            "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
        }
        # TODO(phawkins): remove this case after the forward compatibility window
        # expires on 2018-10-5.
        if api_compat.forward_compatible(2018, 10, 5):
            metadata_kwargs["num_cores_per_replica"] = (
                device_assignment.num_cores_per_replica)
        else:
            metadata_kwargs["computation_shape"] = [
                device_assignment.num_cores_per_replica
            ]

    if ((not isinstance(inputs, list))
            or any(not isinstance(inp, (list, tuple)) for inp in inputs)):
        raise TypeError(
            "tpu.replicate() inputs must be a list of lists/tuples")

    num_replicas = len(inputs)

    # No replicas? Nothing to do.
    if num_replicas == 0:
        return []

    # Converts inputs to Tensors.
    inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

    # Verifies that all replicas have matching numbers and types of inputs
    input_types = [x.dtype for x in inputs[0]]
    input_arity = len(input_types)
    for i in range(num_replicas):
        if len(inputs[i]) != input_arity:
            raise ValueError("Replicas must have the same number of inputs. "
                             "Replica 0 had {} inputs, replica {} had {} "
                             "inputs.".format(input_arity, i, len(inputs[i])))

        types = [x.dtype for x in inputs[i]]
        if types != input_types:
            raise ValueError(
                "Replicas must have matching input types. Replica 0 had "
                "input types {}, replica {} had input types {}".format(
                    input_types, i, types))

    arg_error = tpu_function.check_function_argument_count(
        computation, input_arity, infeed_queue)
    if arg_error is not None:
        if infeed_queue is None:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s, but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]), arg_error))
        else:
            raise TypeError(
                "Supplied computation cannot be called with the specified inputs. "
                "You specified %d inputs: %s and %d additional inputs from infeed,"
                " but the computation needs %s" %
                (input_arity, str([i.name for i in inputs[0]]),
                 infeed_queue.number_of_tuple_elements, arg_error))

    graph = ops.get_default_graph()

    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = []
    for i in range(0, input_arity):
        replicas = [inputs[replica][i] for replica in xrange(num_replicas)]
        computation_inputs.append(
            tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))

    cluster_name = graph.unique_name("cluster")
    pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
    context = TPUReplicateContext(name=cluster_name,
                                  num_replicas=num_replicas,
                                  pivot=pivot)
    try:
        context.Enter()

        metadata = tpu_ops.tpu_replicate_metadata(num_replicas=num_replicas,
                                                  use_tpu=use_tpu,
                                                  **metadata_kwargs)

        with tpu_function.tpu_shard_context(
                num_replicas), ops.control_dependencies([metadata]):

            # Add identity ops so even unused inputs are "consumed" by the
            # computation. This is to avoid orphaned TPUReplicatedInput nodes.
            # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
            # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
            computation_inputs = [
                array_ops.identity(x, name="replicated_input_{}".format(i))
                for i, x in enumerate(computation_inputs)
            ]

            # If there is an infeed queue, adds the dequeued values to the
            # computation's inputs.
            if infeed_queue is not None:
                infeed_queue.set_number_of_shards(num_replicas)
                for t in infeed_queue.generate_dequeue_op():
                    computation_inputs.append(t)

            # Only resource variables work inside a TPU computation, so turn on
            # resource variables for the computation.
            # TODO(phawkins): consider removing this code. It will
            # be less confusing to clients if they knowingly choose to use resource
            # variables.
            # Partitioned variables is not supported (b/112311320).
            def custom_getter(getter, name, *args, **kwargs):
                """Variables on TPU have a few restrictions."""
                partitioner = kwargs["partitioner"]
                if partitioner is not None:
                    kwargs["partitioner"] = None
                    logging.warning(
                        "Partitioned variables are not supported on TPU. Got "
                        "`partitioner` that is {} for variable {}. "
                        "Setting `partitioner` to `None`.".format(
                            partitioner, name))
                return getter(name, *args, **kwargs)

            vscope = variable_scope.get_variable_scope()

            saved_use_resource = vscope.use_resource
            saved_custom_getter = vscope.custom_getter

            vscope.set_use_resource(True)
            vscope.set_custom_getter(custom_getter)

            outputs = computation(*computation_inputs)

            vscope.set_use_resource(saved_use_resource)
            vscope.set_custom_getter(saved_custom_getter)

        # If the computation returns `None`, make it an empty tuple.
        if outputs is None:
            outputs = tuple()
        # If the computation only returned one value, makes it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        # Append `no_op` here so that fetching any return value of this function
        # will trigger TPUExecute node.
        outputs += (control_flow_ops.no_op(), )
        try:
            with ops.device(core(0)):
                outputs = [
                    o if isinstance(o, ops.Operation) else
                    ops.convert_to_tensor(o) for o in outputs
                ]
        except Exception as e:
            raise ValueError(
                "TPU function return values must all either be Operations or "
                "convertible to Tensors. Got '%s'" % str(e))

        # Separates the returned Operations and Tensors.
        output_operations = [
            o for o in outputs if isinstance(o, ops.Operation)
        ]
        output_tensors = [
            o for o in outputs if not isinstance(o, ops.Operation)
        ]

        if outputs != output_tensors + output_operations:
            raise ValueError(
                "TPU functions must return zero-or more Tensor values followed by "
                "zero or more Operations.")
        output_arity = len(output_tensors)

        # Wraps outputs in Identity ops. Otherwise a replicated input copied
        # straight to an output would bypass the replicate(). This would be bad
        # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
        # be rewritten away, leading to a runtime error.
        # TODO(phawkins): extend the rewrite to elide these nodes instead.
        new_output_tensors = []
        for t in output_tensors:
            with ops.device(t.device if t.device else core(0)):
                new_output_tensors.append(array_ops.identity(t))
        output_tensors = new_output_tensors
        context.ExitResult(output_tensors)
    finally:
        context.report_unsupported_operations()
        context.Exit()
        host_compute_core = context.HostComputeCore()

    if host_compute_core:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.list.s.extend(
            [compat.as_bytes(x) for x in host_compute_core])
        metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [
        tpu_ops.tpu_replicated_output(output_tensors[i],
                                      num_replicas,
                                      name="output{}".format(i))
        for i in xrange(output_arity)
    ]

    with ops.control_dependencies([metadata]):
        if use_tpu:
            compile_status = tpu_ops.tpu_compilation_result()
            op = compile_status.op
            attr_value = attr_value_pb2.AttrValue(
                s=compat.as_bytes(cluster_name))
            op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
        else:
            compile_status = control_flow_ops.no_op(name="compilation_status")

    with ops.control_dependencies(output_operations):
        if output_arity == 0:
            # Returns a list of NoOps dependent on the replication Op, indexed by
            # [replica_num].
            return [
                compile_status,
                [
                    control_flow_ops.no_op(name="shard_%d" % i)
                    for i in range(num_replicas)
                ]
            ]
        else:
            # Wraps the outputs in identity operators so the names of any possible
            # `fetch` nodes are preserved by the replication rewrite.
            return [
                compile_status,
                [[
                    array_ops.identity(outputs[out][replica],
                                       name="output_%d_shard_%d" %
                                       (out, replica))
                    for out in xrange(output_arity)
                ] for replica in xrange(num_replicas)]
            ]
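
As a hedged usage sketch of the return structure documented above (assuming the function is reachable as tpu.split_compile_and_replicate via tensorflow.python.tpu; use_tpu=False targets the XLA CPU/GPU backend so no TPU is needed to build the graph):

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu  # assumed module path

def computation(x):
    return x * 2.0

with tf.Graph().as_default():
    # inputs is indexed by [replica_num][input_num]: two replicas, one input each.
    inputs = [[tf.constant(1.0)], [tf.constant(2.0)]]
    compile_op, per_replica_outputs = tpu.split_compile_and_replicate(
        computation, inputs=inputs, use_tpu=False)
    # per_replica_outputs[1][0] is replica 1's doubled input.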
Example #39
0
 def begin(self):
     self._loss_tensor = ops.get_default_graph().get_tensor_by_name(
         KMeansClustering.LOSS_OP_NAME + ':0')
     assert self._loss_tensor is not None
Example #40
0
  def _default_layout(self, layout: layout_lib.Layout):
    """Sets a default output layout for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific layout for ops which would have no inferred
    layout, e.g. tf.zeros.

    Caveats:

    - Currently only affects the first output of an op. Ops with multiple
      outputs are not yet supported.

    - All ops in the scope are tagged with the same layout, which might not be
      valid if their ranks differ. The current suggestion is to wrap the raw op
      whenever possible.

    Args:
      layout: A Layout for the outputs of all operations in this scope.

    Yields:
      Nothing.
    """
    previous_default = None
    previous_graph_size = None
    graph = None

    self._register_mesh(layout.mesh)
    try:
      previous_default = self._current_output_layout
      self._current_output_layout = layout.to_string().encode("utf-8")
      _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
          self._device_info, self._current_output_layout)
      if context.executing_eagerly():
        with ops.device(self.name):
          yield
      else:
        # Custom devices currently don't affect graph building, so we need a
        # separate way to indicate layouts.
        #
        # TODO(allenl): Remove this case once the DTensor device is active
        # during tracing.
        graph = ops.get_default_graph()
        previous_graph_size = len(graph.get_operations())
        yield
    finally:
      if graph is not None:
        # Tag operations added under this scope
        for operation in graph.get_operations()[previous_graph_size:]:
          # Set layout directly on the Op itself.
          operation._set_attr(  # pylint: disable=protected-access
              "_layout",
              attr_value_pb2.AttrValue(
                  list=attr_value_pb2.AttrValue.ListValue(
                      s=[self._current_output_layout])))
          operation._set_attr(  # pylint: disable=protected-access
              "_mesh",
              attr_value_pb2.AttrValue(
                  s=layout.mesh.to_string().encode("utf-8")))

      self._current_output_layout = previous_default
      if self._current_output_layout is None:
        _pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
      else:
        _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
            self._device_info, self._current_output_layout.decode("utf-8"))
Example #41
0
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
    """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen if global_step > bn_freeze_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training, as the scaling of the weights changes slowly rather than
     jumping across mini-batches.
     b) Changing the values of the corrections allows one to switch between
     using batch statistics and using the moving mean and variance, without
     requiring changes to batch_norm.


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.


  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

    g = ops.get_default_graph()
    prefix = '' if not context else context
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                        match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(match.variance_tensor +
                                     match.batch_epsilon)
        correction_scale = math_ops.divide(recip_sigma_mv,
                                           recip_sigma,
                                           name='scale_compute')
        correction_scale = array_ops.identity(correction_scale,
                                              name='correction_scale')
        correction_recip = math_ops.reciprocal(correction_scale,
                                               name='reciprocal_compute')
        correction_offset = math_ops.multiply(
            match.gamma_tensor,
            match.mean_tensor * recip_sigma -
            match.moving_mean_tensor * recip_sigma_mv,
            name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            use_mv_avg = False

        bn_decay_zero = 0.0
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())

        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')

        common.RerouteTensor(bn_decay_mean_out,
                             match.bn_decay_mean_tensor,
                             can_modify=bn_decay_mean_consumers)

        bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
        bn_decay_var_out = utils.smart_cond(use_mv_avg,
                                            lambda: bn_decay_zero,
                                            lambda: match.bn_decay_var_tensor,
                                            name='freeze_moving_var')
        common.RerouteTensor(bn_decay_var_out,
                             match.bn_decay_var_tensor,
                             can_modify=bn_decay_var_consumers)

        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
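
The correction algebra in the docstring can be sanity-checked with NumPy alone; the statistics below are made-up numbers, not values from a real model.

import numpy as np

gamma, mu_b, sigma_b = 1.5, 0.2, 2.0    # batch statistics (illustrative)
mu_mv, sigma_mv = 0.1, 1.8              # moving averages (illustrative)

correction_scale = sigma_b / sigma_mv
correction_recip = 1.0 / correction_scale
# Weights folded with the batch scale gamma/sigma_b, then multiplied by
# correction_scale, are quantized as if scaled by the smoother gamma/sigma_mv:
np.testing.assert_allclose(gamma / sigma_b * correction_scale,
                           gamma / sigma_mv)
# correction_recip on the activations restores the batch-statistics scaling:
np.testing.assert_allclose(gamma / sigma_mv * correction_recip,
                           gamma / sigma_b)
# After freezing, correction_recip becomes 1 and this offset is applied:
correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)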
Example #42
0
    def __exit__(self, unused_type, unused_value, unused_traceback):
        if context.executing_eagerly():
            return

        if self._graph is not ops.get_default_graph():
            raise RuntimeError(
                "Graph changed while trying to add control dependencies.")

        # pylint: disable=protected-access
        if hasattr(self._graph, "outer_graph"):
            outer_val = self._graph.outer_graph._add_control_dependencies
            self._graph._add_control_dependencies = outer_val
        else:
            self._graph._add_control_dependencies = False
        # pylint: enable=protected-access

        # map from resource tensor to the last op which used it
        last_op_using_resource_tensor = {}
        # set of conditional and loop exits
        ops_which_must_run = set()
        # merge which must depend on ops which use this resource
        merge_for_resource = {}

        new_operations = self._graph.get_operations()[self._n_operations:]

        # Ensures that uses of resource tensors get serialized properly and all
        # execute. This is done by keeping a map from resource tensor to the last op
        # in graph-construction order which used it (last_op_using_resource_tensor).
        #
        # Conditionals are written in TensorFlow such that every external tensor
        # accessed in the conditional goes through a switch op and every return
        # tensor (it's guaranteed that there will be at least one) goes through a
        # merge op.
        #
        # To handle conditionals, switches are handled in a special way (see
        # comments for _process_switch). Merge nodes created by TF's conditional
        # logic (as opposed to by _process_switch) are forced to run and also get a
        # control dependency added to them to ensure all stateful ops inside their
        # control flow context run.
        #
        # We also ensure that if an op is using a resource output by a switch node
        # (that is, a resource tensor for which there's a value in
        # merge_for_resource) this op will run before the merge for that resource.
        #
        # We try to add control inputs to nodes respecting their control flow
        # contexts to avoid dead nodes propagating everywhere and leading to
        # "retval[0] doesn't have value" errors. If a node gets a control dependency
        # on a dead node (i.e. a node from an untaken control flow branch) that node
        # will be marked as dead unless it's a merge node.
        #
        # TODO(apassos): serialize non-resource-taking stateful ops as well, and
        # test that it works. Support while loops. Support init_scope escaping from
        # this.
        for op in new_operations:
            # TODO(apassos) make this code safely support while loops.
            if control_flow_util.IsInWhileLoop(op):
                continue
            control_inputs = set()
            # Ensure stateful ops run
            if op_def_registry.get(op.type) is None or op_is_stateful(op):
                ops_which_must_run.add(op)
            # Ignore switches (they're handled separately)
            if op.type == "Switch" and op.inputs[
                    0].dtype == dtypes_module.resource:
                continue
            # Make merges trigger all other computation which must run
            if op.type == "Merge":
                for o in ops_which_must_run:
                    op._add_control_input(o)  # pylint: disable=protected-access
                    for inp in o.inputs:
                        input_id = ops.tensor_id(inp)
                        if input_id in last_op_using_resource_tensor:
                            last_op_using_resource_tensor[input_id] = op
                ops_which_must_run = set([op])
                continue

            resource_inputs = set()
            # Check for any resource inputs. If we find any, we update control_inputs
            # and last_op_using_resource_tensor.
            for inp in _get_resource_inputs(op):
                input_id = ops.tensor_id(inp)

                # If the op receives the same resource tensor twice as an input, we skip
                # to avoid the op getting a control dependency on itself.
                if input_id in resource_inputs:
                    continue

                resource_inputs.add(input_id)
                # Deal with switches, finally.
                if inp.op.type == "Switch":
                    self._process_switch(inp.op, ops_which_must_run,
                                         last_op_using_resource_tensor,
                                         merge_for_resource)
                is_building_function = op.graph.building_function
                # Ensure uses of resources are serialized
                if input_id in last_op_using_resource_tensor:
                    if is_building_function or (
                            last_op_using_resource_tensor[input_id].
                            _control_flow_context  # pylint: disable=protected-access
                            is op._control_flow_context):  # pylint: disable=protected-access
                        control_inputs.add(
                            last_op_using_resource_tensor[input_id])
                # Ensure merges happen after the closing of a cond block
                if input_id in merge_for_resource:
                    merge_for_resource[input_id]._add_control_input(op)  # pylint: disable=protected-access
                last_op_using_resource_tensor[input_id] = op

            if (op_is_stateful(op) and not resource_inputs
                    and op._control_flow_context is None):  # pylint: disable=protected-access
                if None in last_op_using_resource_tensor:
                    op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
                last_op_using_resource_tensor[None] = op
            control_inputs = [
                c for c in control_inputs if is_building_function or (
                    c._control_flow_context is op._control_flow_context)
            ]  # pylint: disable=protected-access
            op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

        # Ensure all ops which must run do run
        self.ops_which_must_run.update(ops_which_must_run)
        for r in nest.flatten(list(self._returned_tensors),
                              expand_composites=True):
            if self.ops_which_must_run:
                r.op._add_control_inputs(  # pylint: disable=protected-access
                    [
                        o for o in self.ops_which_must_run
                        if r.graph.building_function or
                        (o._control_flow_context is r.op._control_flow_context)  # pylint: disable=protected-access
                    ])
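
A hedged illustration of the behaviour the comments above describe, as it surfaces to users through tf.function in current TensorFlow (that tf.function routes through this exact context manager is an assumption): two stateful ops on the same resource variable are serialized in program order, so the snippet prints 5.0 deterministically.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def update():
    v.assign(2.0)        # first stateful use of the resource
    v.assign_add(3.0)    # gets a control dependency on the assign above
    return v.read_value()

print(update().numpy())  # 5.0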
Example #43
0
  def testGradientFromInsideNestedDefun(self):

    def build_graph():
      pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
      pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
      x = constant_op.constant(1.0, name="x")
      y = constant_op.constant(2.0, name="y")

      def true_fn():
        return 2.0

      def false_fn():

        def inner_true_fn():
          return x * y * 2.0

        def inner_false_fn():
          return x * 5.0

        return cond_v2.cond_v2(
            pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")

      cond_outer = cond_v2.cond_v2(
          pred_outer, true_fn, false_fn, name="outer_cond")

      # Compute grads inside a Defun.
      @function.defun
      def nesting_fn():

        @function.defun
        def inner_nesting_fn():
          return gradients_impl.gradients(cond_outer, [x, y])

        return inner_nesting_fn()

      grads = nesting_fn()

      return grads, pred_outer, pred_inner

    with ops.Graph().as_default():
      grads, pred_outer, pred_inner = build_graph()
      with self.session(graph=ops.get_default_graph()) as sess:
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: True,
                pred_inner: True
            }), [0., 0.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: True,
                pred_inner: False
            }), [0., 0.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: False,
                pred_inner: True
            }), [4., 2.])
        self.assertSequenceEqual(
            sess.run(grads, {
                pred_outer: False,
                pred_inner: False
            }), [5., 0.])
Example #44
0
  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    super(_CopyToDeviceDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._input_dataset.output_shapes,
                               self._input_dataset.output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._input_dataset.output_types,
                              self._input_dataset.output_classes))

    @function.Defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      # pylint: disable=protected-access
      ds_variant = self._input_dataset._as_variant_tensor()
      resource = core_gen_dataset_ops.anonymous_iterator(
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      with ops.control_dependencies(
          [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return core_gen_dataset_ops.iterator_to_string_handle(resource)

    @function.Defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=_init_func.captured_inputs,
          Tout=[dtypes.string],
          f=_init_func)

    self._init_func = _remote_init_func
    self._init_captured_args = _remote_init_func.captured_inputs

    @function.Defun(dtypes.string)
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle, self.output_types, self.output_shapes,
            self.output_classes)
      ret = iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    @function.Defun(dtypes.string)
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + _next_func.captured_inputs,
          Tout=self._flat_output_types,
          f=_next_func)

    self._next_func = _remote_next_func
    self._next_captured_args = _remote_next_func.captured_inputs

    @function.Defun(dtypes.string)
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    @function.Defun(dtypes.string)
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + _finalize_func.captured_inputs,
          Tout=[dtypes.int64],
          f=_finalize_func)

    self._finalize_func = _remote_finalize_func
    self._finalize_captured_args = _remote_finalize_func.captured_inputs

    g = ops.get_default_graph()
    _remote_init_func.add_to_graph(g)
    _remote_next_func.add_to_graph(g)
    _remote_finalize_func.add_to_graph(g)
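
For context, a hedged sketch of how this dataset is typically reached from user code, assuming the public wrapper is tf.data.experimental.copy_to_device; the pattern stages elements on the target device and prefetches there.

import tensorflow as tf

dataset = tf.data.Dataset.range(8)
dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
with tf.device("/gpu:0"):
    dataset = dataset.prefetch(1)  # prefetch on the target device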
Example #45
0
 def initialize_variables():
     for v, init in initializer_map.items():
         v.assign(
             lift_to_graph.lift_to_graph([init],
                                         ops.get_default_graph())[init])
Example #46
0
    def _GetLayerMatch(match_result):
        """Populates a layer match object containing ops/tensors for folding BNs.

    Args:
      match_result: Matched result from graph matcher

    Returns:
      layer_op: Matching conv/fc op prior to batch norm
      BatchNormMatch: _BatchNormMatch containing all required batch norm
      parameters.
    """
        moving_mean_tensor = None
        moving_variance_tensor = None
        bn_decay_mean_tensor = None
        bn_decay_var_tensor = None
        batch_to_space_op = None
        layer_op = match_result.get_op(layer_pattern)
        layer_tensor = match_result.get_tensor(layer_pattern)
        bn_id_op = match_result.get_op(batch_norm_identity_pattern)
        bn_op = match_result.get_op(batch_norm_pattern)
        if bn_id_op is None:
            bn_id_op = bn_op

        batch_epsilon = bn_op.get_attr('epsilon')

        # In the MatMul case, the output of batch norm is reshaped back into a
        # 2D tensor, so the output_tensor is the output of the Reshape op.
        output_tensor = bn_op.outputs[0]
        if layer_op.type == 'MatMul':
            output_reshape_op = match_result.get_op(
                matmul_bn_output_reshape_pattern)
            # If the matcher didn't match matmul_bn_output_reshape, there will be
            # another match for this 'MatMul' later, so we can skip this one.
            if output_reshape_op is None:
                return None, None
            output_tensor = output_reshape_op.outputs[0]

        # Ensure that the output tensor has consumers, otherwise this is a dangling
        # node and not a match.
        if not output_tensor.consumers():
            return None, None

        batch_to_space_op = match_result.get_op(batch_to_space_pattern)
        input_tensor = match_result.get_tensor(input_pattern)
        weight_tensor = match_result.get_tensor(weight_pattern)
        gamma_tensor = match_result.get_tensor(gamma_pattern)
        beta_tensor = match_result.get_tensor(beta_pattern)
        # FusedBatchNorm in training is different from that in inference. It takes
        # empty 'mean' and empty 'variance', and produces the mean and the variance
        # of the batch. Therefore, when is_training is true, mean_tensor and
        # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
        # respectively; when is_training is false, they point to bn_op's inputs.
        is_training = bn_op.get_attr('is_training')
        if is_training:
            # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
            # batch_variance outputs, so we need to substitute our own custom
            # gradient.
            # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
            # pylint: disable=protected-access
            bn_op._set_attr(
                '_gradient_op_type',
                attr_value_pb2.AttrValue(
                    s=compat.as_bytes('FoldFusedBatchNormGrad')))
            # pylint: enable=protected-access
            mean_tensor = bn_op.outputs[1]
            # The batch variance used during forward and backward prop is biased,
            # i.e. it is calculated as: V = sum((x(k) - mu)^2) / N. For the moving
            # average calculation, the variance is corrected by the factor N/(N-1)
            # (Bessel's correction). The variance tensor read from FusedBatchNorm
            # has Bessel's correction applied, so we undo it here.
            scope, sep, _ = bn_op.name.rpartition('/')
            g = ops.get_default_graph()
            with g.as_default(), g.name_scope(scope + sep):
                n = math_ops.cast(
                    array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
                    dtypes.float32)
                variance_tensor = math_ops.multiply(
                    bn_op.outputs[2], (n - 1) / n,
                    name='Undo_Bessel_Correction')
            # TODO(suharshs): Find a way to get rid of this inner match.
            for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
                sub_op = mul_match_result.get_op(moving_average_sub_pattern)
                if sub_op.inputs[1].name == bn_op.outputs[1].name:
                    # During training: Batch Mean is bn_op.outputs[1]
                    moving_mean_tensor = sub_op.inputs[0]
                    bn_decay_mean_tensor = mul_match_result.get_tensor(
                        bn_decay_pattern)
                if sub_op.inputs[1].name == bn_op.outputs[2].name:
                    # During training: Batch Var is bn_op.outputs[2]
                    moving_variance_tensor = sub_op.inputs[0]
                    bn_decay_var_tensor = mul_match_result.get_tensor(
                        bn_decay_pattern)
        else:
            mean_tensor = match_result.get_tensor(mean_pattern)
            variance_tensor = match_result.get_tensor(variance_pattern)

        return layer_op, _BatchNormMatch(
            layer_op=layer_op,
            bn_op=bn_op,
            output_tensor=output_tensor,
            input_tensor=input_tensor,
            weight_tensor=weight_tensor,
            gamma_tensor=gamma_tensor,
            beta_tensor=beta_tensor,
            mean_tensor=mean_tensor,
            variance_tensor=variance_tensor,
            moving_mean_tensor=moving_mean_tensor,
            moving_variance_tensor=moving_variance_tensor,
            bn_decay_mean_tensor=bn_decay_mean_tensor,
            bn_decay_var_tensor=bn_decay_var_tensor,
            batch_epsilon=batch_epsilon,
            batch_to_space_op=batch_to_space_op)
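
The Bessel-correction undo in the is_training branch is easy to verify numerically with NumPy alone (random made-up data; this mirrors the (n - 1) / n factor in the code above):

import numpy as np

x = np.random.randn(64).astype(np.float32)
n = x.size
unbiased_var = x.var(ddof=1)              # variance with Bessel's correction
biased_var = unbiased_var * (n - 1) / n   # the "Undo_Bessel_Correction" value
np.testing.assert_allclose(biased_var, x.var(ddof=0), rtol=1e-5)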
Example #47
0
def set_random_seed(seed):
    """Sets the graph-level random seed.

  Operations that rely on a random seed actually derive it from two seeds:
  the graph-level and operation-level seeds. This sets the graph-level seed.

  Its interaction with operation-level seeds is as follows:

    1. If neither the graph-level nor the operation seed is set:
      A random seed is used for this op.
    2. If the graph-level seed is set, but the operation seed is not:
      The system deterministically picks an operation seed in conjunction
      with the graph-level seed so that it gets a unique random sequence.
    3. If the graph-level seed is not set, but the operation seed is set:
      A default graph-level seed and the specified operation seed are used to
      determine the random sequence.
    4. If both the graph-level and the operation seed are set:
      Both seeds are used in conjunction to determine the random sequence.

  To illustrate the user-visible effects, consider these examples:

  To generate different sequences across sessions, set neither
  graph-level nor op-level seeds:

  ```python
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A3'
    print(sess2.run(a))  # generates 'A4'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To generate the same repeatable sequence for an op across sessions, set the
  seed for the op:

  ```python
  a = tf.random_uniform([1], seed=1)
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequence of values for 'a', but different sequences of values for 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To make the random sequences generated by all ops be repeatable across
  sessions, set a graph-level seed:

  ```python
  tf.set_random_seed(1234)
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequences of 'a' and 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B1'
    print(sess2.run(b))  # generates 'B2'
  ```

  Args:
    seed: integer.
  """
    if context.executing_eagerly():
        context.set_global_seed(seed)
    else:
        ops.get_default_graph().seed = seed
Example #48
0
def _is_old_cond():
  return isinstance(ops.get_default_graph()._get_control_flow_context(),
                    control_flow_ops.CondContext)
Example #49
0
  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Arguments:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable).  If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`.  In this case,
        an instance of `PartitionedVariable` is returned.  Available
        partitioners include `tf.compat.v1.fixed_size_partitioner` and
        `tf.compat.v1.variable_axis_size_partitioner`.  For more details, see
        the documentation of `tf.compat.v1.get_variable` and the  "Variable
        Partitioners and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments.

    Returns:
      The created variable.  Usually either a `Variable` or `ResourceVariable`
      instance.  If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """
    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        raise TypeError('Unknown keyword argument:', kwarg)
    if self._keras_style:
      return super(Layer, self).add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          trainable=trainable and self.trainable,
          constraint=constraint,
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      if isinstance(variable, tf_variables.PartitionedVariable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with ops.name_scope(self._name_scope(), skip_on_eager=False):
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        if initializer is None:
          initializer = scope.initializer
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer,
            trainable=trainable and self.trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable,
            **kwargs)

        if regularizer:
          if (ops.executing_eagerly_outside_functions()
              or _should_add_regularizer(variable, existing_variables)):
            self._handle_weight_regularization(name, variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
          if (trainable and self.trainable and
              variable not in trainable_variables):
            # A custom getter / variable scope overrode the trainable flag.
            extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
            self._trainable_weights = self._trainable_weights[
                :prev_len_trainable]
            self._non_trainable_weights += extra_trainable_vars
    return variable
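
As a hedged illustration (using the closely related `tf.keras.layers.Layer` API rather than the v1 `tf.layers` class above; the layer name and shapes are made up), `add_weight` is typically called from a layer's `build` method:

```python
import tensorflow as tf

class Scale(tf.keras.layers.Layer):
  """Multiplies its input by a single learned scalar."""

  def build(self, input_shape):
    # A trainable scalar weight; ON_READ synchronization would force
    # trainable=False, as enforced in the method above.
    self.alpha = self.add_weight(
        name="alpha", shape=(), initializer="ones", trainable=True)

  def call(self, inputs):
    return inputs * self.alpha

layer = Scale()
print(layer(tf.ones([2, 2])))  # builds the layer, then scales the input
```
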
Example #50
0
    def __init__(
            self,  # pylint: disable=super-init-not-called
            initial_value=None,
            trainable=None,
            caching_device=None,
            name=None,
            dtype=None,
            constraint=None,
            add_initializers_to=None,
            lifted_initializer_graph=None,
            **unused_kwargs):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
        if not ops.inside_function():
            # If we've been init_scope()d out of the function definition, there is
            # nothing to do here; we can't really do the capturing or conditional logic.
            resource_variable_ops.ResourceVariable.__init__(
                self,
                initial_value=initial_value,
                trainable=trainable,
                caching_device=caching_device,
                name=name,
                dtype=dtype,
                constraint=constraint)
            return
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if isinstance(initial_value, checkpointable.CheckpointInitialValue):
            self._maybe_initialize_checkpointable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value

        if trainable is None:
            trainable = True
        self._trainable = trainable
        self._save_slice_info = None
        self._initial_value = None
        self._initializer_op = None
        self._is_initialized_op = None
        self._graph_element = None
        self._cached_value = None
        # Store the graph key so optimizers know how to only retrieve variables from
        # this graph. Guaranteed to be the same as the eager graph_key.
        self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
        with ops.name_scope(name, "Variable",
                            [] if init_from_fn else [initial_value]) as name:
            # pylint: disable=protected-access
            with ops.init_scope():
                handle_name = ops._name_from_scope_name(name)
                unique_id = "%s_%d" % (handle_name, ops.uid())
                shared_name = context.shared_name(unique_id)
            with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value() if init_from_fn else initial_value,
                    name="initial_value",
                    dtype=dtype)
            with ops.init_scope():
                self._handle = resource_variable_ops.eager_safe_variable_handle(
                    initial_value=initial_value,
                    shared_name=shared_name,
                    name=name,
                    graph_mode=self._in_graph_mode)
            self._shape = initial_value.shape
            self._unique_id = unique_id
            self._handle_name = handle_name + ":0"
            self._dtype = initial_value.dtype.base_dtype
            self._constraint = constraint
            assert initial_value is not None
            if self._in_graph_mode:
                with ops.init_scope():
                    outer_graph = ops.get_default_graph()
                func_graph = ops.get_default_graph()
                function_placeholders = (func_graph.inputs +
                                         func_graph.internal_captures)
                placeholder_ops = set(
                    [tensor.op for tensor in function_placeholders])
                lifted_initializer = lift_to_graph.lift_to_graph(
                    [initial_value],
                    outer_graph,
                    disallowed_placeholders=placeholder_ops)[initial_value]
                with ops.init_scope():
                    self._initial_value = lifted_initializer
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = resource_variable_ops.assign_variable_op(
                                self._handle, lifted_initializer, name=n)
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle.device):
                            value = self._read_variable_op()
                        self._graph_element = value
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
            else:
                if add_initializers_to is not None:
                    add_initializers_to[self] = initial_value

                def assign_fn():
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        resource_variable_ops.assign_variable_op(self._handle,
                                                                 initial_value,
                                                                 name=n)
                        # Returning values to keep tf.cond happy.
                    return ops.convert_to_tensor(1)

                def not_assign_fn():
                    return ops.convert_to_tensor(0)

                # Note: this cond is always guaranteed to run because we're inside a
                # defun which will insert automatic control dependencies.
                control_flow_ops.cond(
                    resource_variable_ops.var_is_initialized_op(self._handle),
                    not_assign_fn, assign_fn)

        # After the handle has been created, set up a way to clean it up when
        # executing eagerly. We'll hold the only reference to the deleter, so that
        # when this object is garbage collected the deleter will be too. This
        # means ResourceVariables can be part of reference cycles without those
        # cycles being uncollectable.
        if not self._in_graph_mode:
            self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._handle, handle_device=self._handle.device)
        self._cached_shape_as_list = None
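
A minimal sketch of the user-facing pattern this class supports (assuming TF 2.x): a variable created inside a `tf.function` is only allowed on the first trace, and its initializer is lifted out of the traced graph, much as the code above does.

```python
import tensorflow as tf

class Counter(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      # Created once, on the first call; the initializer runs outside the
      # traced function graph.
      self.count = tf.Variable(0)
    self.count.assign_add(1)
    return self.count

c = Counter()
print(c())  # -> 1
print(c())  # -> 2
```
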
Example #51
0
 def __enter__(self):
     self._g = ops.get_default_graph()
     self._old = self._g._get_control_flow_context()  # pylint: disable=protected-access
     self._g._set_control_flow_context(self)  # pylint: disable=protected-access
Example #52
0
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
    """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the `shape` argument (if
  specified). In the case where the list length is less than the number of
  elements specified by `shape`, the last element in the list will be used
  to fill the remaining entries.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  Args:
    value:          A constant value (or list) of output type `dtype`.

    dtype:          The type of the elements of the resulting tensor.

    shape:          Optional dimensions of resulting tensor.

    name:           Optional name for the tensor.

    verify_shape:   Boolean that enables verification of a shape of values.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
    ctx = context.context()
    if not ctx.in_graph_mode():
        t = convert_to_eager_tensor(value, ctx, dtype)
        if shape is None:
            return t
        shape = tensor_shape.as_shape(shape)
        if shape == t.shape:
            return t
        if verify_shape:
            raise TypeError("Expected Tensor's shape: %s, got %s." %
                            (tuple(shape), tuple(t.shape)))
        num_t = t.shape.num_elements()
        # TODO(josh11b): Implement shape -> eager tensor conversion.
        if num_t == shape.num_elements():
            return _eager_reshape(t, shape.as_list(), ctx)
        if num_t == 1:
            if t.dtype == dtypes.bool:
                # We don't have a Fill kernel for bool dtype on GPU. So we first run
                # Fill on CPU and then copy to GPU if needed.
                with ops.device("/device:CPU:0"):
                    x = _eager_fill(shape.as_list(), t.cpu(), ctx)
                return _eager_identity(x, ctx)
            else:
                return _eager_fill(shape.as_list(), t, ctx)
        raise TypeError(
            "Eager execution of tf.constant with unsupported shape "
            "(value has %d elements, shape is %s with %d elements)." %
            (num_t, shape, shape.num_elements()))
    g = ops.get_default_graph()
    tensor_value = attr_value_pb2.AttrValue()
    tensor_value.tensor.CopyFrom(
        tensor_util.make_tensor_proto(value,
                                      dtype=dtype,
                                      shape=shape,
                                      verify_shape=verify_shape))
    dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    const_tensor = g.create_op("Const", [], [dtype_value.type],
                               attrs={
                                   "value": tensor_value,
                                   "dtype": dtype_value
                               },
                               name=name).outputs[0]
    return const_tensor
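
A short sketch of the graph-mode fill behavior the docstring describes (TF 1.x semantics; eager execution takes the `_eager_fill` path above instead): when the value list is shorter than the requested shape, the last element fills the remaining entries.

```python
# Constant 2-D tensor where the trailing entries are filled with the last value.
tensor = tf.constant([1, 2], shape=[2, 3])  # => [[1 2 2]
                                            #     [2 2 2]]
```
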
Example #53
0
    def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
        """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
        self._input_dataset = input_dataset
        self._target_device = target_device
        spec = framework_device.DeviceSpec().from_string(self._target_device)
        self._is_gpu_target = (spec.device_type == "GPU")
        self._source_device_string = source_device
        self._source_device = ops.convert_to_tensor(source_device)

        wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
            self._input_dataset._variant_tensor)  # pylint: disable=protected-access

        @function.defun()
        def _init_func():
            """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
            ds_variant = gen_dataset_ops.unwrap_dataset_variant(
                wrap_ds_variant)
            resource = gen_dataset_ops.anonymous_iterator(
                **dataset_ops.flat_structure(self._input_dataset))
            with ops.control_dependencies(
                [gen_dataset_ops.make_iterator(ds_variant, resource)]):
                return gen_dataset_ops.iterator_to_string_handle(resource)

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun()
        def _remote_init_func():
            return functional_ops.remote_call(
                target=self._source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _next_func(string_handle):
            """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
            with ops.device(self._source_device_string):
                iterator = iterator_ops.Iterator.from_string_handle(
                    string_handle, dataset_ops.get_legacy_output_types(self),
                    dataset_ops.get_legacy_output_shapes(self),
                    dataset_ops.get_legacy_output_classes(self))
            return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True})
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _finalize_func(string_handle):
            """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
            iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
                string_handle,
                **dataset_ops.flat_structure(self._input_dataset))
            with ops.control_dependencies([
                    resource_variable_ops.destroy_resource_op(
                        iterator_resource, ignore_lookup_error=True)
            ]):
                return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=self._source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
        )
        self._finalize_captured_args = self._finalize_func.captured_inputs

        g = ops.get_default_graph()
        self._init_func.add_to_graph(g)
        self._next_func.add_to_graph(g)
        self._finalize_func.add_to_graph(g)
        # pylint: enable=protected-access

        with ops.device(self._target_device):
            variant_tensor = gen_dataset_ops.generator_dataset(
                self._init_captured_args,
                self._next_captured_args,
                self._finalize_captured_args,
                init_func=self._init_func,
                next_func=self._next_func,
                finalize_func=self._finalize_func,
                **dataset_ops.flat_structure(self._input_dataset))
        super(_CopyToDeviceDataset, self).__init__(input_dataset,
                                                   variant_tensor)
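
A hedged usage sketch: a `_CopyToDeviceDataset` like the one above is what the public `tf.data.experimental.copy_to_device` transformation produces (the device strings below are illustrative).

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# Copy elements from the host to the GPU, then prefetch on the GPU.
dataset = dataset.apply(tf.data.experimental.copy_to_device("/gpu:0"))
with tf.device("/gpu:0"):
  dataset = dataset.prefetch(1)
```
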
Example #54
0
def _defun_internal(name, func, args, kwds):
    """Defines and returns graph-mode version of func."""
    container_prefix = ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
    with context.graph_mode():
        captures = {}
        tmp_graph = CapturingGraph(captures)
        # Inherit the container prefix, since this is used for error checking when
        # isolating eager execution (the container prefix at creation must match the
        # container prefix when used, and variables accessed in the defun will be
        # used in the outside context).
        tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
        # Copy the graph collections to ensure summaries and other things work. This
        # lets the function access (but not mutate) collections of the containing
        # graph, such as the global step and the summary writer collections.
        curr_graph = ops.get_default_graph()
        for collection in curr_graph.collections:
            tmp_graph.get_collection_ref(
                collection)[:] = curr_graph.get_collection(collection)
        with tmp_graph.as_default():
            func_inputs = _get_defun_inputs(args)

            with capture_tensors(captures):
                tape.push_new_tape()
                try:
                    func_outputs = func(*func_inputs, **kwds)
                finally:
                    variables = tape.pop_tape().watched_variables()

                # Returning a closed-over tensor as an output does not trigger a
                # call to convert_to_tensor, so we manually capture all such tensors.
                outputs_list = nest.flatten(func_outputs)
                func_def_outputs = [
                    _convert_to_graph_tensor(x) for x in outputs_list
                    if x is not None
                ]

            ids = list(sorted(captures.keys()))
            if ids:
                extra_inputs, extra_placeholders = zip(
                    *[captures[x] for x in ids])
            else:
                extra_inputs = []
                extra_placeholders = []
            output_shapes = tuple(
                x.shape if isinstance(x, ops.Tensor) else None
                for x in outputs_list)

    flat_inputs = [
        x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
    ]
    all_inputs = flat_inputs + list(extra_placeholders)
    all_ignored_ops = frozenset(x.op for x in all_inputs)
    fname = _inference_name(name)
    operations = tuple(x for x in tmp_graph.get_operations()
                       if x not in all_ignored_ops)
    # Register any other functions defined in the graph
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    if context.in_eager_mode():
        for f in tmp_graph._functions.values():  # pylint: disable=protected-access
            # TODO(ashankar): What about the gradient registry?
            _register(f._c_func)  # pylint: disable=protected-access
    return GraphModeFunction(fname, all_inputs, extra_inputs, tmp_graph,
                             operations, func_def_outputs, func_outputs,
                             output_shapes, variables)
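
As a rough sketch of how this helper was reached through the public API of that era (the `tf.contrib.eager.defun` wrapper is assumed), a Python function is compiled into a graph function and then called with eager tensors:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def square_sum(a, b):
  return tf.reduce_sum(a * a + b * b)

# Wraps square_sum via machinery like _defun_internal above.
fast_square_sum = tfe.defun(square_sum)
print(fast_square_sum(tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])))  # 30.0
```
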
Example #55
0
def compute_weighted_loss(
    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  """Computes the weighted loss.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `losses`, and must be broadcastable to `losses` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: the scope for the operations performed in computing the loss.
    loss_collection: the loss will be added to these collections.
    reduction: Type of reduction to apply to loss.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.

  Raises:
    ValueError: If `weights` is `None` or the shape is not compatible with
      `losses`, or if the number of dimensions (rank) of either `losses` or
      `weights` is missing.

  Note:
    When calculating the gradient of a weighted loss, contributions from
    both `losses` and `weights` are considered. If your `weights` depend
    on some model parameters but you do not want this to affect the loss
    gradient, you need to apply `tf.stop_gradient` to `weights` before
    passing them to `compute_weighted_loss`.

  @compatibility(eager)
  The `loss_collection` argument is ignored when executing eagerly. Consider
  holding on to the return value or collecting losses via a `tf.keras.Model`.
  @end_compatibility
  """
  Reduction.validate(reduction)
  with ops.name_scope(scope, "weighted_loss", (losses, weights)):
    # Save the `reduction` argument for loss normalization when distributing
    # to multiple replicas.
    # TODO(josh11b): Associate it with the returned op for more precision.
    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access

    with ops.control_dependencies((
        weights_broadcast_ops.assert_broadcastable(weights, losses),)):
      losses = ops.convert_to_tensor(losses)
      input_dtype = losses.dtype
      losses = math_ops.to_float(losses)
      weights = math_ops.to_float(weights)
      weighted_losses = math_ops.multiply(losses, weights)
      if reduction == Reduction.NONE:
        loss = weighted_losses
      else:
        loss = math_ops.reduce_sum(weighted_losses)
        if reduction == Reduction.MEAN:
          loss = _safe_mean(
              loss,
              math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
        elif (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS or
              reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS):
          loss = _safe_mean(loss, _num_present(losses, weights))
        elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
          loss = _safe_mean(loss, _num_elements(losses))

      # Convert the result back to the input type.
      loss = math_ops.cast(loss, input_dtype)
      util.add_loss(loss, loss_collection)
      return loss
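
A minimal sketch of the `SUM_BY_NONZERO_WEIGHTS` reduction described above (the `tf.compat.v1.losses` namespace and a graph-mode session are assumed):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

losses = tf.constant([1.0, 2.0, 3.0])
weights = tf.constant([1.0, 0.0, 1.0])
loss = tf.losses.compute_weighted_loss(
    losses, weights, reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

with tf.Session() as sess:
  # Weighted sum is 1*1 + 2*0 + 3*1 = 4, divided by the 2 nonzero weights.
  print(sess.run(loss))  # 2.0
```
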
Example #56
0
  def _apply_op_helper(self, op_type_name, name=None, **keywords):
    """Implementation of apply_op that returns output_structure, op."""
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e.message))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.internal_convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.internal_convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError("%s that are invalid." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.internal_convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except TypeError as err:
            if dtype is None:
              raise err
            else:
              raise TypeError(
                  "Expected %s passed to parameter '%s' of op '%s', got %s of "
                  "type '%s' instead." %
                  (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                   repr(values), type(values).__name__))
          except ValueError:
            # What type does convert_to_tensor think it has?
            try:
              observed = ops.internal_convert_to_tensor(
                  values, as_ref=input_arg.is_ref).dtype.name
            except ValueError as err:
              raise ValueError(
                  "Tried to convert '%s' to a tensor and failed. Error: %s" %
                  (input_name, err))
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr),
                                       param_name=input_name)
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr),
                                       param_name=input_name)
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
            raise TypeError(
                ("'%s' Op requires that input '%s' be a mutable tensor "
                 "(e.g.: a tf.Variable)") % (op_type_name, input_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if isinstance(value, attr_value_pb2.NameAttrList):
            attr_value.func.CopyFrom(value)
          elif isinstance(value, compat.bytes_or_text_types):
            attr_value.func.name = value
          else:
            value.add_to_graph(ops.get_default_graph())
            attr_value.func.name = value.name
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(types))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x)._as_ref for x in types]  # pylint: disable=protected-access
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
      return output_structure, op_def.is_stateful, op
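
A hedged illustration of the list-input type inference performed above: for an op such as `Pack` (exposed as `tf.stack`, whose generated wrapper is assumed to funnel through a helper like this), the element dtype is inferred from the first `Tensor` in the list and the remaining Python scalars are converted to it.

```python
import tensorflow as tf

x = tf.constant(1.0)               # float32 Tensor
stacked = tf.stack([x, 2.0, 3.0])  # 2.0 and 3.0 are converted to float32
print(stacked)                     # [1. 2. 3.]
```
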
Example #57
0
def _test(args):
    #pdb.set_trace();
    args.solver.master = ''
    container_name = ""
    checkpoint_dir = os.path.join(format(args.logdir))
    logging.error('Checkpoint_dir: %s', args.logdir)

    config = tf.ConfigProto()
    config.device_count['GPU'] = 1

    m = utils.Foo()
    m.tf_graph = tf.Graph()

    rng_data_seed = 0
    rng_action_seed = 0
    R = lambda: nav_env.get_multiplexer_class(args.navtask, rng_data_seed)
    with m.tf_graph.as_default():
        with tf.container(container_name):
            m = args.setup_to_run(m,
                                  args,
                                  is_training=False,
                                  batch_norm_is_training=args.control.
                                  force_batchnorm_is_training_at_test,
                                  summary_mode=args.control.test_mode)
            train_step_kwargs = args.setup_train_step_kwargs(
                m,
                R(),
                os.path.join(args.logdir, args.control.test_name),
                rng_seed=rng_data_seed,
                is_chief=True,
                num_steps=args.navtask.task_params.num_steps *
                args.navtask.task_params.num_goals,
                iters=args.summary.test_iters,
                train_display_interval=None,
                dagger_sample_bn_false=args.arch.dagger_sample_bn_false)

            saver = slim.learning.tf_saver.Saver(
                variables.get_variables_to_restore())

            sv = slim.learning.supervisor.Supervisor(
                graph=ops.get_default_graph(),
                logdir=None,
                init_op=m.init_op,
                summary_op=None,
                summary_writer=None,
                global_step=None,
                saver=m.saver_op)

            last_checkpoint = None
            reported = False
            while True:
                last_checkpoint_ = None
                while last_checkpoint_ is None:
                    last_checkpoint_ = slim.evaluation.wait_for_new_checkpoint(
                        checkpoint_dir,
                        last_checkpoint,
                        seconds_to_sleep=10,
                        timeout=60)
                if last_checkpoint_ is None: break

                last_checkpoint = last_checkpoint_
                checkpoint_iter = int(
                    os.path.basename(last_checkpoint).split('-')[1])

                logging.info(
                    'Starting evaluation at %s using checkpoint %s.',
                    time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()),
                    last_checkpoint)

                if (not args.control.only_eval_when_done
                        or checkpoint_iter >= args.solver.max_steps):
                    start = time.time()
                    logging.info(
                        'Starting evaluation at %s using checkpoint %s.',
                        time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()),
                        last_checkpoint)

                    with sv.managed_session(
                            args.solver.master,
                            config=config,
                            start_standard_services=False) as sess:
                        sess.run(m.init_op)
                        sv.saver.restore(sess, last_checkpoint)
                        sv.start_queue_runners(sess)
                        if args.control.reset_rng_seed:
                            train_step_kwargs['rng_data'] = [
                                np.random.RandomState(rng_data_seed),
                                np.random.RandomState(rng_data_seed)
                            ]
                            train_step_kwargs[
                                'rng_action'] = np.random.RandomState(
                                    rng_action_seed)
                        vals, _ = tf_utils.train_step_custom_online_sampling(
                            sess,
                            None,
                            m.global_step_op,
                            train_step_kwargs,
                            mode=args.control.test_mode)
                        should_stop = False

                        if checkpoint_iter >= args.solver.max_steps:
                            should_stop = True

                        if should_stop:
                            break
Example #58
0
    def __init__(self,
                 dataset=None,
                 devices=None,
                 max_buffer_size=1,
                 prefetch_buffer_size=1,
                 source_device="/cpu:0",
                 components=None,
                 element_spec=None):
        """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we set up a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A nested structure of `TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode or outside of function building
      mode.
    """
        if (not context.executing_eagerly()
                and not ops.get_default_graph()._building_function):  # pylint: disable=protected-access
            raise RuntimeError(
                "OwnedMultiDeviceIterator is only supported inside of "
                "tf.function or when eager execution is enabled.")
        if devices is None:
            raise ValueError("`devices` must be provided")
        error_message = "Either `dataset` or both `components` and "
        "`element_spec` need to be provided."

        if dataset is None:
            if (components is None or element_spec is None):
                raise ValueError(error_message)
            self._element_spec = element_spec
            self._devices = devices
            self._source_device = source_device
            self._multi_device_iterator_resource = components[0]
            self._deleter = components[1]
            self._device_iterators = components[2:]
            iterator_handles = []
            for it in self._device_iterators:
                iterator_handles.append(it._iterator_resource)  # pylint: disable=protected-access
        else:
            if (components is not None or element_spec is not None):
                raise ValueError(error_message)
            options = dataset_ops.Options()
            options.experimental_distribute.num_devices = len(devices)
            dataset = dataset.with_options(options)
            dataset = dataset._apply_options()  # pylint: disable=protected-access
            self._element_spec = dataset.element_spec
            experimental_slack = dataset.options().experimental_slack
            self._devices = devices
            self._source_device = source_device
            source_device_tensor = ops.convert_to_tensor(self._source_device)

            if prefetch_buffer_size > max_buffer_size:
                max_buffer_size = prefetch_buffer_size

            # Create the MultiDeviceIterator.
            with ops.device(self._source_device):
                self._multi_device_iterator_resource, self._deleter = (
                    gen_dataset_ops.anonymous_multi_device_iterator(
                        devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

                # The incarnation ID is used to ensure consistency between the
                # per-device iterators and the multi-device iterator.
                incarnation_id = gen_dataset_ops.multi_device_iterator_init(
                    dataset._variant_tensor,  # pylint: disable=protected-access
                    self._multi_device_iterator_resource,
                    max_buffer_size=max_buffer_size)

            prototype_device_datasets = []
            for i, device in enumerate(self._devices):
                with ops.device(device):
                    ds = _PerDeviceGenerator(
                        i, self._multi_device_iterator_resource,
                        incarnation_id, source_device_tensor,
                        dataset.element_spec)
                    prototype_device_datasets.append(ds)

            # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
            # initialize the device side of the pipeline. This would allow the
            # MultiDeviceIterator to choose, for example, to move some transformations
            # into the device side from its input. It might be useful in rewriting.
            # Create the per device iterators.
            self._device_iterators = []
            iterator_handles = []
            for i, device in enumerate(self._devices):
                with ops.device(device):
                    ds = _create_device_dataset(prototype_device_datasets[i],
                                                incarnation_id,
                                                prefetch_buffer_size,
                                                experimental_slack)
                    iterator = iter(ds)
                    self._device_iterators.append(iterator)
                    iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

            self._resource_deleter = MultiDeviceIteratorResourceDeleter(
                multi_device_iterator=self._multi_device_iterator_resource,
                iterators=iterator_handles,
                device=self._source_device,
                deleter=self._deleter)
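A hedged usage sketch for the constructor above. `OwnedMultiDeviceIterator` lives in an internal `tensorflow.python` module and is normally driven by `tf.distribute`, so the import path, device list, and dataset below are illustrative assumptions rather than a documented public API.

import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

# Assumed setup: eager execution (TF 2.x default) and two target devices.
dataset = tf.data.Dataset.range(8).batch(2)
devices = ["/gpu:0", "/gpu:1"]  # illustrative; must exist on the host

# Constructing outside eager mode or tf.function raises RuntimeError, per the
# check at the top of __init__ above.
iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
    dataset=dataset, devices=devices, max_buffer_size=2, prefetch_buffer_size=2)

# get_next() returns one element per device, in the order given by `devices`.
per_device_elements = iterator.get_next()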
Example #59
0
    def _clear_graph_state():
        # Clearing the Graph collection will prevent _PerGraphState from being
        # serialized.
        ops_lib.get_default_graph().clear_collection(
            TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)
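The helper above simply empties one graph collection so that the Python-only state it holds never reaches graph serialization. A small sketch of the same idea, assuming TF 1.x-style graph collections; the collection name and stored object are hypothetical.

import tensorflow as tf

_MY_STATE_COLLECTION = "my_non_serializable_state"  # hypothetical key

g = tf.Graph()
# Stash a Python-only object in a collection while building the graph.
g.add_to_collection(_MY_STATE_COLLECTION, object())
# ... build the rest of the graph ...
# Drop the collection before export so the object never reaches serialization
# (e.g. a MetaGraph).
g.clear_collection(_MY_STATE_COLLECTION)
assert g.get_collection(_MY_STATE_COLLECTION) == []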
Example #60
0
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(inputs):
            """Single step on the TPU device."""
            fn_result = fn(ctx, inputs)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, e.g. to create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        def rewrite_fn(*args):
            """The rewritten step fn running on TPU."""
            del args

            per_replica_inputs = multi_worker_iterator.get_next()
            replicate_inputs = []
            for replica_id in range(self._num_replicas_in_sync):
                select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
                replicate_inputs.append(
                    (nest.map_structure(select_replica, per_replica_inputs), ))

            replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

            # If run_fn has tensor outputs, tpu.replicate returns a list of lists. We
            # will flatten it in this case. If run_fn has no tensor outputs,
            # tpu.replicate returns a list of no_ops, we will keep the output as it
            # is.
            if isinstance(replicate_outputs[0], list):
                replicate_outputs = nest.flatten(replicate_outputs)

            return replicate_outputs

        # TODO(sourabhbajaj): The input to the while loop should be based on
        # the output type of the step_fn.
        assert isinstance(initial_loop_values, list)
        initial_loop_values = initial_loop_values * self._num_replicas_in_sync

        # Put the while loop op on TPU host 0.
        with ops.device(self._host_device):
            if self.steps_per_run == 1:
                replicate_outputs = rewrite_fn()
            else:
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs)

        if isinstance(replicate_outputs, list):
            # Filter out any ops from the outputs; typically this is the case
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (flattened)
            # [output0_device0, output1_device0, output2_device0,
            #  output0_device1, output1_device1, output2_device1,
            #  ...]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            output_num = len(
                last_step_tensor_outputs) // self._num_replicas_in_sync
            last_step_tensor_outputs = [
                last_step_tensor_outputs[i::output_num]
                for i in range(output_num)
            ]
        else:
            # no tensors returned.
            last_step_tensor_outputs = []

        _set_last_step_outputs(ctx, last_step_tensor_outputs)
        return ctx
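The regrouping step above relies on a single stride slice. A standalone illustration (not from the original source) with 2 replicas and 3 per-replica outputs:

# Flattened replicate outputs: all outputs of device 0, then all of device 1.
flat = ["out0_dev0", "out1_dev0", "out2_dev0",
        "out0_dev1", "out1_dev1", "out2_dev1"]
num_replicas_in_sync = 2
output_num = len(flat) // num_replicas_in_sync  # 3 outputs per replica
grouped = [flat[i::output_num] for i in range(output_num)]
# grouped == [['out0_dev0', 'out0_dev1'],
#             ['out1_dev0', 'out1_dev1'],
#             ['out2_dev0', 'out2_dev1']]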