Ejemplo n.º 1
0
    def testDetermineReadyTensorsAndTableInitializersRaises(
            self, create_graph_fn, feeds, replaced_tensors_ready, fetches,
            error_msg_regex):
        """Test that determine_ready_tensors_and_table_initializers raises.

    Args:
      create_graph_fn: A function that adds ops to a graph and returns a dict
          mapping tensor names to `Tensor` or `SparseTensor`s.
      feeds: A list of keys in the dict returned by create_graph_fn that are fed
          in the main run (but not the table initialization run).
      replaced_tensors_ready: A dict mapping keys in the dict returned by
          create_graph_fn to bools indicating whether that tensor is ready to
          be replaced in this phase.
      fetches: A list of keys in the dict returned by create_graph_fn to
          determine ready status for.
      error_msg_regex: A regex that the expected error message must match.
    """
        name_to_tensor = create_graph_fn()
        feed_tensors = [name_to_tensor[key] for key in feeds]
        fetch_tensors = [name_to_tensor[key] for key in fetches]
        ready_by_tensor = {}
        for key, is_ready in replaced_tensors_ready.items():
            ready_by_tensor[name_to_tensor[key]] = is_ready
        with self.assertRaisesRegexp(ValueError, error_msg_regex):
            graph_tools.determine_ready_tensors_and_table_initializers(
                fetch_tensors, feed_tensors, ready_by_tensor)
Ejemplo n.º 2
0
    def testDetermineReadyTensorsAndTableInitializers(
            self, create_graph_fn, feeds, replaced_tensors_ready,
            should_be_ready, num_ready_table_initializers):
        """Test determine_ready_tensors_and_table_initializers.

    Args:
      create_graph_fn: A function that adds ops to a graph and returns a dict
          mapping tensor names to `Tensor` or `SparseTensor`s.
      feeds: A list of keys in the dict returned by create_graph_fn that are fed
          in the main run (but not the table initialization run).
      replaced_tensors_ready: A dict mapping keys in the dict returned by
          create_graph_fn to bools indicating whether that tensor is ready to
          be replaced in this phase.
      should_be_ready: A dict mapping keys in the dict returned by
          create_graph_fn to bools indicating whether the corresponding tensor
          can be calculated in this phase.
      num_ready_table_initializers: The number of table initializers that are
          ready to run in the table initialization run of this phase.
    """
        name_to_tensor = create_graph_fn()
        feed_tensors = [name_to_tensor[key] for key in feeds]
        ready_by_tensor = {
            name_to_tensor[key]: is_ready
            for key, is_ready in replaced_tensors_ready.items()
        }
        # Fetch every tensor mentioned in should_be_ready; the ones flagged True
        # are the tensors we expect the graph analysis to report as ready.
        fetch_tensors = []
        expected_ready_tensors = []
        for key, is_ready in should_be_ready.items():
            tensor = name_to_tensor[key]
            fetch_tensors.append(tensor)
            if is_ready:
                expected_ready_tensors.append(tensor)
        ready_table_initializers, ready_tensors = (
            graph_tools.determine_ready_tensors_and_table_initializers(
                fetch_tensors, feed_tensors, ready_by_tensor))
        self.assertEqual(len(ready_table_initializers),
                         num_ready_table_initializers)
        self.assertCountEqual(ready_tensors, expected_ready_tensors)
Ejemplo n.º 3
0
def create_phases(inputs):
    """Returns a list of `Phase`s describing how to execute the pipeline.

  The default graph is assumed to contain some `Analyzer`s which must be
  executed by doing a full pass over the dataset, and passing the inputs for
  that analyzer into some implementation, then taking the results and replacing
  the `Analyzer`s outputs with constants in the graph containing these results.

  The execution plan is described by a list of `Phase`s.  Each phase contains
  a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in
  that phase, together with a list of ops, which are the table initializers that
  are ready to run in that phase.

  An `Analyzer` or op is ready to run when all its dependencies in the graph
  have been computed.  Thus if the graph is constructed by

  def preprocessing_fn(inputs):
    x = inputs['x']
    scaled_0 = x - tft.min(x)
    scaled_0_1 = scaled_0 / tft.max(scaled_0)

  Then the first phase will contain the analyzer corresponding to the call to
  `min`, because `x` is an input and so is ready to compute in the first phase,
  while the second phase will contain the analyzer corresponding to the call to
  `max` since `scaled_0` depends on the result of the call to `tft.min` which
  is computed in the first phase.

  More generally, we define a level for each op and each `Analyzer` by walking
  the graph, assigning to each operation the max level of its inputs, to each
  `Tensor` the level of its operation, unless it's the output of an `Analyzer`
  in which case we assign the level of its `Analyzer` plus one.

  The above description omits the role of `FunctionApplication`s.  A
  `FunctionApplication` is a hint to create_phases about the control flow of the
  graph.  Because control flow ops can introduce circular dependencies (and
  other circumstances such as mutable reference introduce similar problems) we
  allow users to construct a `FunctionApplication` which is a hint that the
  outputs `Tensor`s depend only on the input `Tensor`s.  `FunctionApplication`s
  are also needed to collect table initializers to determine which phase a table
  initializer is ready to run in.

  Args:
    inputs: A dict whose keys are strings and values are `Tensor` or
        `SparseTensor`s.

  Returns:
    A list of `Phase`s.

  Raises:
    ValueError: if the graph cannot be analyzed.
  """
    # Materialize the values so we have a re-iterable sequence (dict.values()
    # is a one-shot-safe view on Python 3 but this function iterates phases
    # repeatedly with the same feeds).
    feed_tensors = list(inputs.values())

    remaining_analyzers = tf.get_collection(analyzers.ANALYZER_COLLECTION)
    # Tracks readiness of each analyzer output; outputs become ready once their
    # analyzer has been assigned to a phase below.
    analyzer_output_ready = {}
    for analyzer in remaining_analyzers:
        for tensor in analyzer.outputs:
            analyzer_output_ready[tensor] = False

    # Construct `AnalyzerInfo`s, removing any tensors that are analyzer outputs
    # from the ASSET_FILEPATHS collection.  These tensors will be replaced and
    # the replacements will be added to the ASSET_FILEPATHS.  Setting
    # AnalyzerOutputInfo.is_asset instructs the implementation to do this.
    asset_filepaths_collection = tf.get_collection_ref(
        tf.GraphKeys.ASSET_FILEPATHS)
    # OrderedDict used as an ordered set: preserves the collection's order
    # while allowing O(1) pop() as analyzer outputs claim their entries.
    asset_filepaths = collections.OrderedDict(
        (tensor, True)
        for tensor in tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))

    phases = []
    while remaining_analyzers:
        # Gather the inputs of every not-yet-scheduled analyzer and ask the
        # graph analysis which of them (and which table initializers) are
        # computable given the current set of ready analyzer outputs.
        analyzer_inputs = []
        for analyzer in remaining_analyzers:
            analyzer_inputs.extend(analyzer.inputs)
        ready_init_ops, ready_analyzer_inputs = (
            graph_tools.determine_ready_tensors_and_table_initializers(
                tf.get_default_graph(), analyzer_inputs, feed_tensors,
                analyzer_output_ready))
        ready_analyzer_inputs = set(ready_analyzer_inputs)

        new_remaining_analyzers = []
        analyzer_infos = []
        for analyzer in remaining_analyzers:
            # An analyzer joins this phase only when *all* of its inputs are
            # ready; otherwise it is deferred to a later phase.
            if all(tensor in ready_analyzer_inputs
                   for tensor in analyzer.inputs):
                input_tensor_names = [
                    tensor.name for tensor in analyzer.inputs
                ]
                # pop() both records whether the output was an asset filepath
                # and removes it so it is not re-added to the collection below.
                output_infos = [
                    AnalyzerOutputInfo(tensor.name,
                                       asset_filepaths.pop(tensor, False))
                    for tensor in analyzer.outputs
                ]
                analyzer_infos.append(
                    AnalyzerInfo(analyzer.name, input_tensor_names,
                                 analyzer.spec, output_infos))

                # Mark outputs ready so dependent analyzers can be scheduled in
                # the next phase.
                for tensor in analyzer.outputs:
                    analyzer_output_ready[tensor] = True
            else:
                new_remaining_analyzers.append(analyzer)
        phases.append(Phase(analyzer_infos, ready_init_ops))

        # Each pass must schedule at least one analyzer; otherwise the
        # remaining analyzers can never become ready (e.g. a circular
        # dependency) and looping again would never terminate.  Raise instead
        # of using `assert`, which is stripped under `python -O` and would
        # silently turn this condition into an infinite loop.
        if len(new_remaining_analyzers) >= len(remaining_analyzers):
            raise ValueError(
                'The graph could not be analyzed: no analyzer became ready in '
                'the last phase, which indicates an unresolvable (e.g. '
                'circular) dependency among the remaining analyzers.')
        remaining_analyzers = new_remaining_analyzers

    # Rewrite ASSET_FILEPATHS in place: keep only the entries that were not
    # claimed as analyzer outputs, preserving their original order.
    del asset_filepaths_collection[:]
    asset_filepaths_collection.extend(six.iterkeys(asset_filepaths))

    return phases