Example #1
def get_paths(base_dir, parser):
  """Gets a list of Paths in a given directory.

  Args:
    base_dir: directory.
    parser: a function which gets the raw Path and can augment it with
      information such as the export_version, or ignore the path by returning
      None.  An example parser may extract the export version from a path
      such as "/tmp/exports/100" an another may extract from a full file
      name such as "/tmp/checkpoint-99.out".

  Returns:
    A list of Paths contained in the base directory with the parsing function
    applied.
    By default the following fields are populated,
      - Path.path
    The parsing function is responsible for populating,
      - Path.export_version
  """
  raw_paths = gfile.ListDirectory(base_dir)
  paths = []
  for r in raw_paths:
    p = parser(Path(os.path.join(compat.as_str_any(base_dir),
                                 compat.as_str_any(r)),
                    None))
    if p:
      paths.append(p)
  return sorted(paths)
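A minimal usage sketch for the function above (hypothetical directory and parser; it follows the documented contract of returning None to skip a path and populating export_version otherwise):

def export_version_parser(path):
  # Keep only paths whose basename is an integer export version.
  basename = os.path.basename(compat.as_str_any(path.path))
  if not basename.isdigit():
    return None
  return path._replace(export_version=int(basename))

# Yields sorted Paths for e.g. "/tmp/exports/100", "/tmp/exports/101", ...
paths = get_paths('/tmp/exports', export_version_parser)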
Example #2
def get_matching_files_v2(pattern):
  """Returns a list of files that match the given pattern(s).

  Args:
    pattern: string or iterable of strings. The glob pattern(s).

  Returns:
    A list of strings containing filenames that match the given pattern(s).

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    if isinstance(pattern, six.string_types):
      return [
          # Convert the filenames to string from bytes.
          compat.as_str_any(matching_filename)
          for matching_filename in pywrap_tensorflow.GetMatchingFiles(
              compat.as_bytes(pattern), status)
      ]
    else:
      return [
          # Convert the filenames to string from bytes.
          compat.as_str_any(matching_filename)
          for single_filename in pattern
          for matching_filename in pywrap_tensorflow.GetMatchingFiles(
              compat.as_bytes(single_filename), status)
      ]
Example #3
def get_matching_files_v2(pattern):
  """Returns a list of files that match the given pattern(s).

  Args:
    pattern: string or iterable of strings. The glob pattern(s).

  Returns:
    A list of strings containing filenames that match the given pattern(s).

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  if isinstance(pattern, six.string_types):
    return [
        # Convert the filenames to string from bytes.
        compat.as_str_any(matching_filename)
        for matching_filename in pywrap_tensorflow.GetMatchingFiles(
            compat.as_bytes(pattern))
    ]
  else:
    return [
        # Convert the filenames to string from bytes.
        compat.as_str_any(matching_filename)  # pylint: disable=g-complex-comprehension
        for single_filename in pattern
        for matching_filename in pywrap_tensorflow.GetMatchingFiles(
            compat.as_bytes(single_filename))
    ]
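A short usage sketch (hypothetical patterns): the function accepts either a single glob string or an iterable of them, and always returns `str` filenames because each result passes through compat.as_str_any.

files = get_matching_files_v2('/tmp/data/*.tfrecord')
more = get_matching_files_v2(['/tmp/a/*.txt', '/tmp/b/*.txt'])
assert all(isinstance(f, str) for f in files + more)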
Example #4
  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          six.reraise(*ex)
        elif ex is not None:
          # NOTE(touts): This is bogus if request_stop() is not called
          # from the exception handler that raised ex.
          six.reraise(*sys.exc_info())
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex[1]),
                         compat.as_str_any(ex[1]))
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % self._exc_info_to_raise)
            except ValueError:
              # Record this error so it kills the coordinator properly.
              # NOTE(touts): As above, this is bogus if request_stop() is not
              # called from the exception handler that raised ex.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()
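The docstring's requirement that `ex` be passed from inside an exception handler can be exercised with a minimal sketch (assuming the public `tf.compat.v1.train.Coordinator` API):

import threading
import tensorflow as tf

coord = tf.compat.v1.train.Coordinator()

def worker():
  try:
    raise RuntimeError('simulated failure')
  except RuntimeError as ex:
    # Must be called inside the handler so sys.exc_info() is still valid.
    coord.request_stop(ex)

thread = threading.Thread(target=worker)
thread.start()
try:
  coord.join([thread])  # re-raises the RuntimeError recorded by request_stop()
except RuntimeError:
  print('exception propagated from the worker via join()')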
Example #5
  def test_add_pruned_collection_proto_in_bytes_list(self):
    # Note: This also tests _is_removed_mentioned().
    collection_name = 'proto_collection'
    base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    base_meta_graph_def.collection_def[collection_name].bytes_list.value.extend(
        [compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
         compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node2'))),
         compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))),
         compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4'))),
         compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/a/a_1'))),
         compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/b/b_1')))
        ])

    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
    meta_graph_transform._add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)

    collection = meta_graph_def.collection_def[collection_name]

    expected_values = [
        compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
        compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))),
        compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/a/a_1'))),
    ]
    self.assertEqual(expected_values, collection.bytes_list.value[:])
Example #6
 def parser(path):
   # Modify the path object for RegEx match for Windows Paths
   if os.name == 'nt':
     match = re.match("^" + compat.as_str_any(base_dir).replace('\\', '/') +
                      "/(\\d+)$",
                      compat.as_str_any(path.path).replace('\\', '/'))
   else:
     match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
                      compat.as_str_any(path.path))
   if not match:
     return None
   return path._replace(export_version=int(match.group(1)))
Example #7
 def test_asset_path_returned(self):
   root = tracking.AutoTrackable()
   root.path = tracking.TrackableAsset(self._vocab_path)
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   root.get_asset = def_function.function(lambda: root.path.asset_path)
   save.save(root, save_dir, signatures=root.get_asset.get_concrete_function())
   second_dir = os.path.join(self.get_temp_dir(), "second_dir")
   file_io.rename(save_dir, second_dir)
   imported_path = _import_and_infer(second_dir, {})["output_0"]
   self.assertIn(compat.as_str_any(second_dir),
                 compat.as_str_any(imported_path))
Example #8
        def _export_eval_result(self, eval_result, checkpoint_path):
            """Export `eval_result` according to strategies in `EvalSpec`."""
            export_dir_base = os.path.join(
                compat.as_str_any(self._estimator.model_dir),
                compat.as_str_any('export'))

            for strategy in self._eval_spec.export_strategies:
                strategy.export(self._estimator,
                                os.path.join(
                                    compat.as_str_any(export_dir_base),
                                    compat.as_str_any(strategy.name)),
                                checkpoint_path=checkpoint_path,
                                eval_result=eval_result)
Example #9
 def test_asset_path_returned(self):
     root = tracking.AutoTrackable()
     root.path = tracking.Asset(self._vocab_path)
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     root.get_asset = def_function.function(lambda: root.path.asset_path)
     save.save(root,
               save_dir,
               signatures=root.get_asset.get_concrete_function())
     second_dir = os.path.join(self.get_temp_dir(), "second_dir")
     file_io.rename(save_dir, second_dir)
     imported_path = _import_and_infer(second_dir, {})["output_0"]
     self.assertIn(compat.as_str_any(second_dir),
                   compat.as_str_any(imported_path))
Example #10
 def parser(path):
     # Modify the path object for RegEx match for Windows Paths
     if os.name == "nt":
         match = re.match(
             r"^" + compat.as_str_any(base_dir).replace("\\", "/") +
             r"/(\d+)$",
             compat.as_str_any(path.path).replace("\\", "/"))
     else:
         match = re.match(r"^" + compat.as_str_any(base_dir) + r"/(\d+)$",
                          compat.as_str_any(path.path))
     if not match:
         return None
     return path._replace(export_version=int(match.group(1)))
Example #11
def _add_mean(op, context):

  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)
  axis_ind = context.consts[op.inputs[1].name]

  input_shape = context.shape_dict[input_name]
  output_shape = context.shape_dict[output_name]

  if context.use_dfs_shape_infer:
    status = interpret_shape(input_name, context)
  else:
    status = False

  if status:
    labeled_shape = context.dim_labels[input_name]
    if isinstance(axis_ind, np.ndarray):
      axis = ''
      for i in axis_ind:
        if input_shape[i] != 1:
          axis += labeled_shape[i]
      axis = ''.join(sorted(axis))
    else:
      axis = labeled_shape[axis_ind]
    assert axis in ['S', 'C', 'H', 'W', 'CHW', 'HW'], (
        'Axis value %s not supported. '
        'Reduction supported along C, H, W, HW, CHW dimensions only.' % axis)
  else:
    if len(input_shape) == 4 and (
        np.array_equal(axis_ind, np.array([0, 1, 2])) or
        np.array_equal(axis_ind, np.array([1, 2]))):
      axis = 'HW'
    else:
      assert False, ('Mean axis case not handled currently. '
                     'Input shape = {}, output shape = {}, axis_ind = {}'.
                     format(str(input_shape), str(output_shape), str(axis_ind)))

  mode = 'avg'
  # The simple case; reduction along a non-sequence axis
  if axis != 'S':
    context.builder.add_reduce(output_name, input_name, output_name, axis, mode)
  # Need to permute, reduce and then permute back
  else:
    context.builder.add_permute(
        output_name, (1, 0, 2, 3), input_name, output_name + '_swap_Seq_C')
    context.builder.add_reduce(
        output_name, output_name + '_swap_Seq_C',
        output_name + '_pre_permute', 'C', mode)
    context.builder.add_permute(
        output_name, (1, 0, 2, 3), output_name + '_pre_permute', output_name)
  context.translated[output_name] = True
Example #12
    def _export_eval_result(self, eval_result, checkpoint_path):
      """Export `eval_result` according to strategies in `EvalSpec`."""
      export_dir_base = os.path.join(
          compat.as_str_any(self._estimator.model_dir),
          compat.as_str_any('export'))

      for strategy in self._eval_spec.export_strategies:
        strategy.export(
            self._estimator,
            os.path.join(
                compat.as_str_any(export_dir_base),
                compat.as_str_any(strategy.name)),
            checkpoint_path=checkpoint_path,
            eval_result=eval_result)
Example #13
        def _export_eval_result(self, eval_result, is_the_final_export):
            """Export `eval_result` according to exporters in `EvalSpec`."""
            export_dir_base = os.path.join(
                compat.as_str_any(self._estimator.model_dir),
                compat.as_str_any('export'))

            for exporter in self._eval_spec.exporters:
                exporter.export(estimator=self._estimator,
                                export_path=os.path.join(
                                    compat.as_str_any(export_dir_base),
                                    compat.as_str_any(exporter.name)),
                                checkpoint_path=eval_result.checkpoint_path,
                                eval_result=eval_result.metrics,
                                is_the_final_export=is_the_final_export)
Example #14
    def _export_eval_result(self, eval_result, is_the_final_export):
      """Export `eval_result` according to exporters in `EvalSpec`."""
      export_dir_base = os.path.join(
          compat.as_str_any(self._estimator.model_dir),
          compat.as_str_any('export'))

      for exporter in self._eval_spec.exporters:
        exporter.export(
            estimator=self._estimator,
            export_path=os.path.join(
                compat.as_str_any(export_dir_base),
                compat.as_str_any(exporter.name)),
            checkpoint_path=eval_result.checkpoint_path,
            eval_result=eval_result.metrics,
            is_the_final_export=is_the_final_export)
Example #15
def strided_slice(op, context):

  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)
  make_tensor(op.inputs[0], context)

  [x, y] = context.session.run([input_name, output_name],
                               feed_dict=context.input_feed_dict)

  assert op.inputs[1].name in context.consts, \
      'Strided Slice: begin index must be a constant'
  assert op.inputs[2].name in context.consts, \
      'Strided Slice: end index must be a constant'
  assert op.inputs[3].name in context.consts, \
      'Strided Slice: strides must be a constant'

  begin = context.consts[compat.as_str_any(op.inputs[1].name)]
  end = context.consts[compat.as_str_any(op.inputs[2].name)]
  strides = context.consts[compat.as_str_any(op.inputs[3].name)]
  begin_mask = op.get_attr('begin_mask')
  end_mask = op.get_attr('end_mask')
  ellipsis_mask = op.get_attr('ellipsis_mask')
  new_axis_mask = op.get_attr('new_axis_mask')
  shrink_axis_mask = op.get_attr('shrink_axis_mask')

  input_shape = context.shape_dict[input_name]

  if len(input_shape) == 1 and len(begin) == 1 and len(end) == 1 and \
      len(strides) == 1:
    if begin_mask:
      begin[0] = 0
    if end_mask:
      end[0] = input_shape[0]
    context.builder.add_slice(
        output_name, input_name, output_name,
        'channel', begin[0], end[0], strides[0])
  elif len(x.shape) == 4 and len(y.shape) == 3 and x.shape[:3] == y.shape:
    context.builder.add_slice(
        output_name, input_name, output_name,
        'channel', begin[-1], end[-1], 1)
  elif input_name in context.consts:
    # This means all the inputs to the strided-slice layer are constant.
    add_const(context, output_name, y, output_name)
  elif np.array_equal(np.squeeze(x), np.squeeze(y)):
    skip(op, context)
  else:
    assert False, 'Strided Slice case not handled'
  context.translated[output_name] = True
Example #16
def lrn(op, context):
  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)

  input_shape = context.shape_dict[input_name]
  C = input_shape[-1]
  alpha = op.get_attr('alpha')
  beta = op.get_attr('beta')
  bias = op.get_attr('bias')
  depth_radius = op.get_attr('depth_radius')
  context.builder.add_lrn(output_name, input_name, output_name,
                          alpha=alpha * C,
                          beta=beta,
                          local_size=depth_radius,
                          k=bias)
  context.translated[output_name] = True
Example #17
def mul(op, context):
  output_name = compat.as_str_any(op.outputs[0].name)

  # input_names: names of input tensors
  input_names = [make_tensor(ts, context) for ts in op.inputs]
  # input_shapes: shapes of input tensors
  input_shapes = [context.shape_dict[ts.name] for ts in op.inputs]
  mult_input_names = input_names

  # For rank-4 inputs, CoreML only allows [1], [C], [1,H,W] blobs to be
  # broadcasted in elementwise operations. To handle other broadcasting cases,
  # (e.g. [1,1,W] --> [C,H,W]), we insert up-sampling layers
  input_ranks = [len(shape) for shape in input_shapes]
  if 4 in input_ranks:
    broadcasted_shape4 = _get_broadcasted_shape4(input_shapes)
    for idx, in_name in enumerate(input_names):
      input_shape = input_shapes[idx]
      axis = _broadcast_axis(broadcasted_shape4, input_shape)
      if axis is not None:
        # add upsample layer
        upsampled_in_name = in_name + '__upsampled'
        mult_input_names[idx] = upsampled_in_name
        input_axis_dim = 1 if axis >= len(input_shape) else input_shape[axis]
        scale = broadcasted_shape4[axis] // input_axis_dim
        if axis == 1:
          context.builder.add_upsample(
              upsampled_in_name, scale, 1, in_name, upsampled_in_name)
        else:
          context.builder.add_upsample(
              upsampled_in_name, 1, scale, in_name, upsampled_in_name)

  context.builder.add_elementwise(
      output_name, mult_input_names, output_name, 'MULTIPLY')
  context.translated[output_name] = True
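The upsampling workaround for unsupported broadcasts can be pictured with plain numpy (a sketch of the shape arithmetic only; the scale factor here is hypothetical):

import numpy as np

x = np.zeros((1, 1, 5))  # a [1, 1, W] blob CoreML cannot broadcast directly
# One upsample along H turns it into [1, H, W], a shape CoreML can broadcast
# natively in elementwise operations.
upsampled = np.repeat(x, 2, axis=1)
assert upsampled.shape == (1, 2, 5)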
Example #18
def _is_removed_mentioned(s, removed_op_names):
    """Determine whether any removed op is mentioned in the given object.

  This relies on the string representation of the object.  This is used for
  proto messages that may mention ops by name in nested fields.  The string
  representation of the proto includes those field values, so this string
  search approach is sufficient.

  Args:
    s: an object to search for removed op names.
    removed_op_names: An iterable of names of ops that were removed.

  Returns:
    True if any removed op is mentioned in the given object, False otherwise.
  """
    # A common approach taken by some of the transforms in gtt is to add new nodes
    # that have the same prefix as the node they are removing. For example, if
    # the original node name was /foo, they may remove that node and add in
    # /foo/bar. This regex ensures that we handle these two nodes
    # as separate entities.  It matches on nodes having names in the form of
    # '/foo/bar_x' as well as nodes having names in the form of 'foo.'
    s_names = _re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', compat.as_str_any(s))
    for removed_op_name in removed_op_names:
        for s_name in s_names:
            if s_name.endswith(removed_op_name):
                return True
    return False
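The prefix-versus-suffix distinction described in the comment can be checked directly (a small sketch with hypothetical op names):

import re

s_names = [n for n in
           re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', 'save /foo/bar_x')
           if n]
# s_names == ['save', '/foo/bar_x']: '/foo' is a prefix of '/foo/bar_x' but
# not a suffix of any token, so removing op '/foo' leaves '/foo/bar_x' alone.
assert not any(name.endswith('/foo') for name in s_names)
assert any(name.endswith('bar_x') for name in s_names)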
Example #19
def _normalize_outputs(outputs, function_name, signature_key):
  """Construct an output dictionary from unnormalized function outputs."""
  if isinstance(outputs, collections.Mapping):
    for key, value in outputs.items():
      if not isinstance(value, ops.Tensor):
        raise ValueError(
            ("Got a dictionary containing non-Tensor value {} for key {} "
             "in the output of the function {} used to generate a SavedModel "
             "signature. Dictionaries outputs for functions used as signatures "
             "should have one Tensor output per string key.")
            .format(value, key, compat.as_str_any(function_name)))
    return outputs
  else:
    original_outputs = outputs
    if not isinstance(outputs, collections.Sequence):
      outputs = [outputs]
    if not _is_flat(outputs):
      raise ValueError(
          ("Got non-flat outputs '{}' from '{}' for SavedModel "
           "signature '{}'. Signatures have one Tensor per output, so "
           "to have predictable names Python functions used to generate "
           "these signatures should avoid outputting Tensors in nested "
           "structures.")
          .format(original_outputs, function_name, signature_key))
    return {("output_{}".format(output_index)): output
            for output_index, output
            in enumerate(outputs)}
Example #20
def get_summary_description(node_def):
  """Given a TensorSummary node_def, retrieve its SummaryDescription.

  When a Summary op is instantiated, a SummaryDescription of associated
  metadata is stored in its NodeDef. This method retrieves the description.

  Args:
    node_def: the node_def_pb2.NodeDef of a TensorSummary op

  Returns:
    a summary_pb2.SummaryDescription

  Raises:
    ValueError: if the node is not a summary op.

  @compatibility(eager)
  Not compatible with eager execution. To write TensorBoard
  summaries under eager execution, use `tf.contrib.summary` instead.
  @end_compatibility
  """

  if node_def.op != 'TensorSummary':
    raise ValueError("Can't get_summary_description on %s" % node_def.op)
  description_str = _compat.as_str_any(node_def.attr['description'].s)
  summary_description = SummaryDescription()
  _json_format.Parse(description_str, summary_description)
  return summary_description
Example #21
def get_summary_description(node_def):
    """Given a TensorSummary node_def, retrieve its SummaryDescription.

  When a Summary op is instantiated, a SummaryDescription of associated
  metadata is stored in its NodeDef. This method retrieves the description.

  Args:
    node_def: the node_def_pb2.NodeDef of a TensorSummary op

  Returns:
    a summary_pb2.SummaryDescription

  Raises:
    ValueError: if the node is not a summary op.

  @compatibility(eager)
  Not compatible with eager execution. To write TensorBoard
  summaries under eager execution, use `tf.contrib.summary` instead.
  @end_compatibility
  """

    if node_def.op != 'TensorSummary':
        raise ValueError("Can't get_summary_description on %s" % node_def.op)
    description_str = _compat.as_str_any(node_def.attr['description'].s)
    summary_description = SummaryDescription()
    _json_format.Parse(description_str, summary_description)
    return summary_description
Example #22
def join(path, *paths):
    r"""Join one or more path components intelligently.

  TensorFlow-specific filesystems will be joined
  like a URL (using "/" as the path separator) on all platforms:

  On Windows or Linux/Unix-like:
  >>> tf.io.gfile.join("gcs://folder", "file.py")
  'gcs://folder/file.py'

  >>> tf.io.gfile.join("ram://folder", "file.py")
  'ram://folder/file.py'

  But the native filesystem is handled just like os.path.join:

  >>> path = tf.io.gfile.join("folder", "file.py")
  >>> if os.name == "nt":
  ...   expected = "folder\\file.py"  # Windows
  ... else:
  ...   expected = "folder/file.py"  # Linux/Unix-like
  >>> path == expected
  True

  Args:
    path: string, path to a directory
    paths: string, additional paths to concatenate

  Returns:
    path: the joined path.
  """
    # os.path.join won't take mixed bytes/str, so don't overwrite the incoming `path` var
    path_ = compat.as_str_any(compat.path_to_str(path))
    if "://" in path_[1:]:
        return urljoin(path, *paths)
    return os.path.join(path, *paths)
Example #23
def list_directory_v2(path):
    """Returns a list of entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    path: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
    if not is_directory(path):
        raise errors.NotFoundError(
            node_def=None,
            op=None,
            message="Could not find directory {}".format(path))
    with errors.raise_exception_on_not_ok_status() as status:
        # Convert each element to string, since the return values of the
        # vector of string should be interpreted as strings, not bytes.
        return [
            compat.as_str_any(filename)
            for filename in pywrap_tensorflow.GetChildren(
                compat.as_bytes(path), status)
        ]
Example #24
def list_directory_v2(path):
  """Returns a list of entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    path: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
  if not is_directory(path):
    raise errors.NotFoundError(
        node_def=None,
        op=None,
        message="Could not find directory {}".format(path))

  # Convert each element to string, since the return values of the
  # vector of string should be interpreted as strings, not bytes.
  return [
      compat.as_str_any(filename)
      for filename in pywrap_tensorflow.GetChildren(compat.as_bytes(path))
  ]
Example #25
def _is_removed_mentioned(s, removed_op_names):
  """Determine whether any removed op is mentioned in the given object.

  This relies on the string representation of the object.  This is used for
  proto messages that may mention ops by name in nested fields.  The string
  representation of the proto includes those field values, so this string
  search approach is sufficient.

  Args:
    s: an object to search for removed op names.
    removed_op_names: An iterable of names of ops that were removed.

  Returns:
    True if any removed op is mentioned in the given object, False otherwise.
  """
  # A common approach taken by some of the transforms in gtt is to add new nodes
  # that have the same prefix as the node they are removing. For example, if
  # the original node name was /foo, they may remove that node and add in
  # /foo/bar. This regex ensures that we handle these two nodes
  # as separate entities.  It matches on nodes having names in the form of
  # '/foo/bar_x' as well as nodes having names in the form of 'foo.'
  s_names = _re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', compat.as_str_any(s))
  for removed_op_name in removed_op_names:
    for s_name in s_names:
      if s_name.endswith(removed_op_name):
        return True
  return False
Example #26
def _normalize_outputs(outputs, function_name, signature_key):
    """Construct an output dictionary from unnormalized function outputs."""
    if isinstance(outputs, collections.Mapping):
        for key, value in outputs.items():
            if not isinstance(value, ops.Tensor):
                raise ValueError((
                    "Got a dictionary containing non-Tensor value {} for key {} "
                    "in the output of the function {} used to generate a SavedModel "
                    "signature. Dictionaries outputs for functions used as signatures "
                    "should have one Tensor output per string key.").format(
                        value, key, compat.as_str_any(function_name)))
        return outputs
    else:
        original_outputs = outputs
        if not isinstance(outputs, collections.Sequence):
            outputs = [outputs]
        if not _is_flat(outputs):
            raise ValueError(
                ("Got non-flat outputs '{}' from '{}' for SavedModel "
                 "signature '{}'. Signatures have one Tensor per output, so "
                 "to have predictable names Python functions used to generate "
                 "these signatures should avoid outputting Tensors in nested "
                 "structures.").format(original_outputs, function_name,
                                       signature_key))
        return {("output_{}".format(output_index)): output
                for output_index, output in enumerate(outputs)}
Example #27
def list_directory(dirname):
  """Returns a list of entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    dirname: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
  if not is_directory(dirname):
    raise errors.NotFoundError(None, None, "Could not find directory")
  with errors.raise_exception_on_not_ok_status() as status:
    # Convert each element to string, since the return values of the
    # vector of string should be interpreted as strings, not bytes.
    return [
        compat.as_str_any(filename)
        for filename in pywrap_tensorflow.GetChildren(
            compat.as_bytes(dirname), status)
    ]
Example #28
def identity(op, context, input_name=None):
  is_network_output = False
  for out in op.outputs:
    if out.name in context.output_names:
      is_network_output = True
      break
  if input_name is None:
    input_name = compat.as_str_any(op.inputs[0].name)
  for out in op.outputs:
    output_name = compat.as_str_any(out.name)
    if op.inputs[0].op.type != 'Const':
      if is_network_output:
        context.builder.add_activation(
            output_name, 'LINEAR', input_name, output_name, [1.0, 0])
      else:
        skip(op, context)
    context.translated[output_name] = True
Example #29
def real_div(op, context):
  output_name = compat.as_str_any(op.outputs[0].name)
  input_names = []
  for inp in op.inputs:
    input_names.append(make_tensor(inp, context))
  add_tensor_div(
      context.builder, output_name, input_names[0], input_names[1], output_name)
  context.translated[output_name] = True
Example #30
def IsTensorFlowEventsFile(path):
  """Check the path name to see if it is probably a TF Events file."""
  if 'tfevents' not in compat.as_str_any(os.path.basename(path)):
    return False
  if _CNS_DELETED_FILE_PATTERN.search(path):
    logging.info('Ignoring deleted Colossus file: %s', path)
    return False
  return True
Example #31
def check(op, context):
    for inp in op.inputs:
        inp_name = compat.as_str_any(inp.name)
        assert inp_name in context.translated, (
            'No translation found for {}'.format(inp_name))
    for out in op.outputs:
        assert out.name in context.shape_dict, (
            'Shape for {} is not fully defined'.format(out.name))
Example #32
def translation_required(op, context):
    for out in op.outputs:
        out_name = compat.as_str_any(out.name)
        if out_name in context.translated:
            continue
        else:
            return True
    return False
Example #33
def sub(op, context):
  assert len(op.inputs) == 2, 'Sub op currently supports only two inputs'
  output_name = compat.as_str_any(op.outputs[0].name)
  input_1_name = make_tensor(op.inputs[0], context)
  input_2_name = make_tensor(op.inputs[1], context)
  add_tensor_sub(
      context.builder, output_name, input_1_name, input_2_name, output_name)
  context.translated[output_name] = True
Example #34
  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    ex = self._filter_exception(ex)
    with self._lock:
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex[1]),
                         compat.as_str_any(ex[1]))
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % self._exc_info_to_raise)
            except ValueError:
              # Record this error so it kills the coordinator properly.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()
Example #35
  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    ex = self._filter_exception(ex)
    with self._lock:
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex[1]),
                         compat.as_str_any(ex[1]))
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % self._exc_info_to_raise)
            except ValueError:
              # Record this error so it kills the coordinator properly.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()
Example #36
def product(op, context):

  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)
  start_ind = context.consts[op.inputs[1].name]

  assert start_ind == 0, 'Prod: only start index = 0 case supported'

  input_shape = context.shape_dict[input_name]

  if len(input_shape) == 1:
    axis = 'C'
  else:
    assert False, 'Prod axis case not handled currently'

  mode = 'prod'
  context.translated[output_name] = True
  context.builder.add_reduce(output_name, input_name, output_name, axis, mode)
Example #37
def mirror_pad(op, context):
  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)

  paddings = context.consts[op.inputs[1].name]
  top = paddings[1][0]
  bottom = paddings[1][1]
  left = paddings[2][0]
  right = paddings[2][1]

  assert compat.as_str_any(op.get_attr('mode')) != 'SYMMETRIC', \
      'symmetric mode is not supported by Core ML'

  context.translated[output_name] = True
  context.builder.add_padding(
      output_name, left, right, top, bottom,
      input_name=input_name, output_name=output_name,
      padding_type='reflection')
Example #38
    def testMixedStrTypes(self):
        temp_dir = compat.as_bytes(test.get_temp_dir())

        for sub_dir in ['str', b'bytes', u'unicode']:
            base_dir = os.path.join((temp_dir if isinstance(sub_dir, bytes)
                                     else temp_dir.decode()), sub_dir)
            self.assertFalse(gfile.Exists(base_dir))
            gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
            gc.get_paths(base_dir, _create_parser(base_dir))
Example #39
def argmax(op, context):
  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)

  input_shape = context.shape_dict[input_name]
  axis_tensor = compat.as_str_any(op.inputs[1].name)
  if axis_tensor in context.consts:
    axis_tf = context.consts[axis_tensor]
  else:
    assert False, 'ArgMax: Axis tensor not found in the list of Consts'
  if len(input_shape) == 4 and axis_tf == 3:
    axis = 'C'
  else:
    assert False, 'ArgMax: Axis translation case not handled currently'

  context.builder.add_reduce(
      output_name, input_name, output_name, axis, 'argmax')
  context.translated[output_name] = True
Example #40
def extract_image_patches(op, context):
  # use a big convolution layer (that has weights!) for this op
  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)
  ksizes = op.get_attr('ksizes')
  padding_type = compat.as_str_any(op.get_attr('padding'))
  if padding_type == 'VALID':
    padding_type = 'valid'
  elif padding_type == 'SAME':
    padding_type = 'same'
  else:
    raise NotImplementedError('%s not implemented' % padding_type)
  strides = op.get_attr('strides')
  rates = op.get_attr('rates')
  assert rates == [1] * len(rates), 'Only supports when rates are all 1s'
  kh, kw = ksizes[1], ksizes[2]
  sh, sw = strides[1], strides[2]

  c_in = context.shape_dict[input_name][-1]
  n_filters = kh * kw * c_in
  W = np.zeros((kh, kw, c_in, n_filters))
  for i_h in range(kh):
    for i_w in range(kw):
      for i_c in range(c_in):
        idx = i_c + (i_w * c_in) + (i_h * c_in * kw)
        W[i_h, i_w, i_c, idx] = 1

  context.builder.add_convolution(name=output_name,
                                  kernel_channels=c_in,
                                  output_channels=n_filters,
                                  height=kh,
                                  width=kw,
                                  stride_height=sh,
                                  stride_width=sw,
                                  border_mode=padding_type,
                                  groups=1,
                                  W=W,
                                  b=None,
                                  has_bias=False,
                                  is_deconv=False,
                                  output_shape=None,
                                  input_name=input_name,
                                  output_name=output_name)
  context.translated[output_name] = True
Example #41
    def testMixedStrTypes(self):
        temp_dir = compat.as_bytes(test.get_temp_dir())

        for sub_dir in ["str", b"bytes", u"unicode"]:
            base_dir = os.path.join((temp_dir if isinstance(sub_dir, bytes)
                                     else temp_dir.decode()), sub_dir)
            self.assertFalse(gfile.Exists(base_dir))
            gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
            gc._get_paths(base_dir, _create_parser(base_dir))
            gfile.DeleteRecursively(base_dir)
Example #42
def list_directory(dirname):
    """Returns a list of entries contained within a directory.

  The list is in arbitrary order. It does not contain the special entries "."
  and "..".

  Args:
    dirname: string, path to a directory

  Returns:
    [filename1, filename2, ... filenameN] as strings

  Raises:
    errors.NotFoundError if directory doesn't exist
  """
    if not is_directory(dirname):
        raise errors.NotFoundError(None, None, "Could not find directory")
    file_list = get_matching_files(
        os.path.join(compat.as_str_any(dirname), "*"))
    return [
        compat.as_str_any(pywrap_tensorflow.Basename(compat.as_bytes(filename)))
        for filename in file_list
    ]
Example #43
  def testMixedStrTypes(self):
    temp_dir = compat.as_bytes(test.get_temp_dir())

    for sub_dir in ["str", b"bytes", u"unicode"]:
      base_dir = os.path.join(
          (temp_dir
           if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir)
      self.assertFalse(gfile.Exists(base_dir))
      gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
      gc.get_paths(base_dir, _create_parser(base_dir))
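The conversions under test reduce to a few invariants of the compat helpers (a quick sketch using the internal module path seen throughout these examples):

from tensorflow.python.util import compat

assert compat.as_bytes(u'42') == b'42'
assert compat.as_str_any(b'42') == '42'  # bytes are decoded
assert compat.as_str_any('42') == '42'   # str passes through unchanged
assert compat.as_str_any(42) == '42'     # other types fall back to str()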
Example #44
def walk_v2(top, topdown=True, onerror=None):
    """Recursive directory tree generator for directories.

  Args:
    top: string, a Directory name
    topdown: bool, Traverse pre order if True, post order if False.
    onerror: optional handler for errors. Should be a function, it will be
      called with the error as argument. Rethrowing the error aborts the walk.
      Errors that happen while listing directories are ignored.

  Yields:
    Each yield is a 3-tuple:  the pathname of a directory, followed by lists of
    all its subdirectories and leaf files. That is, each yield looks like:
    `(dirname, [subdirname, subdirname, ...], [filename, filename, ...])`.
    Each item is a string.
  """
    def _make_full_path(parent, item):
        # Since `join` discards paths before one that starts with the path
        # separator (https://docs.python.org/3/library/os.path.html#join),
        # we have to manually handle that case as `/` is a valid character on GCS.
        if item[0] == os.sep:
            return "".join([join(parent, ""), item])
        return join(parent, item)

    top = compat.as_str_any(compat.path_to_str(top))
    try:
        listing = list_directory(top)
    except errors.NotFoundError as err:
        if onerror:
            onerror(err)
        else:
            return

    files = []
    subdirs = []
    for item in listing:
        full_path = _make_full_path(top, item)
        if is_directory(full_path):
            subdirs.append(item)
        else:
            files.append(item)

    here = (top, subdirs, files)

    if topdown:
        yield here

    for subdir in subdirs:
        for subitem in walk_v2(_make_full_path(top, subdir),
                               topdown,
                               onerror=onerror):
            yield subitem

    if not topdown:
        yield here
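A brief usage sketch (hypothetical directory; `onerror` receives any NotFoundError raised while listing a directory):

def log_and_skip(err):
  print('skipping walk:', err)

for dirname, subdirs, files in walk_v2('/tmp/logs', onerror=log_and_skip):
  print(dirname, subdirs, files)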
Example #45
    def request_stop(self, ex=None):
        """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
        with self._lock:
            if not self._stop_event.is_set():
                if ex and self._exc_info_to_raise is None:
                    if isinstance(ex, tuple):
                        logging.info("Error reported to Coordinator: %s", compat.as_str_any(ex[1]))
                        self._exc_info_to_raise = ex
                    else:
                        logging.info("Error reported to Coordinator: %s", compat.as_str_any(ex))
                        self._exc_info_to_raise = sys.exc_info()
                self._stop_event.set()
Example #46
def relu6(op, context):
  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)

  relu_output_name = 'relu_' + output_name
  context.builder.add_activation(
      relu_output_name, 'RELU', input_name, relu_output_name)
  neg_output_name = relu_output_name + '_neg'
  # negate it
  context.builder.add_activation(
      neg_output_name, 'LINEAR', relu_output_name, neg_output_name, [-1.0, 0])
  # apply threshold
  clip_output_name = relu_output_name + '_clip'
  context.builder.add_unary(
      clip_output_name, neg_output_name, clip_output_name, 'threshold',
      alpha=-6.0)
  # negate it back
  context.builder.add_activation(
      output_name, 'LINEAR', clip_output_name, output_name, [-1.0, 0])
  context.translated[output_name] = True
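The four layers compose to min(relu(x), 6), i.e. relu6. The same arithmetic in a numpy sketch (assuming the 'threshold' unary computes max(x, alpha)):

import numpy as np

x = np.array([-3.0, 2.0, 8.0])
relu = np.maximum(x, 0.0)        # RELU activation
neg = -relu                      # LINEAR with params [-1.0, 0]
clipped = np.maximum(neg, -6.0)  # unary 'threshold' with alpha=-6.0
relu6 = -clipped                 # LINEAR with params [-1.0, 0]
assert np.array_equal(relu6, np.clip(x, 0.0, 6.0))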
Example #47
def placeholder(op, context):
  context.translated[compat.as_str_any(op.outputs[0].name)] = True
  try:
    inname = op.inputs[0].name
    # chain together no-ops here
    if inname in context.out_name_to_in_name:
      context.out_name_to_in_name[op.outputs[0].name] = (
          context.out_name_to_in_name[op.inputs[0].name])
    else:
      context.out_name_to_in_name[op.outputs[0].name] = op.inputs[0].name
  except:
    print('Skipping name of placeholder')
Example #48
def one_hot(op, context):
  input_name = compat.as_str_any(op.inputs[0].name)
  output_name = compat.as_str_any(op.outputs[0].name)

  depth = context.consts[compat.as_str_any(op.inputs[1].name)]
  on_value = context.consts[compat.as_str_any(op.inputs[2].name)]
  off_value = context.consts[compat.as_str_any(op.inputs[3].name)]

  n_dims = depth
  W = np.ones((depth, depth)) * off_value
  for i in range(depth):
    W[i, i] = on_value
  context.builder.add_embedding(name=output_name,
                                W=W,
                                b=None,
                                input_dim=n_dims,
                                output_channels=n_dims,
                                has_bias=False,
                                input_name=input_name,
                                output_name=output_name)
  context.translated[output_name] = True
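The embedding matrix built above is a scaled identity, so looking up row i yields exactly the one-hot vector for i (a small numpy sketch):

import numpy as np

depth, on_value, off_value = 4, 1.0, 0.0
W = np.ones((depth, depth)) * off_value
for i in range(depth):
  W[i, i] = on_value
# An embedding lookup of index 2 returns row 2 of W: its one-hot encoding.
assert np.array_equal(W[2], np.array([0.0, 0.0, 1.0, 0.0]))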
Example #49
def path_to_str(path):
  """Returns the file system path representation of a `PathLike` object,
  or the object unchanged if it is not path-like.

  Args:
    path: An object that can be converted to path representation.

  Returns:
    A `str` object.
  """
  if hasattr(path, "__fspath__"):
    path = as_str_any(path.__fspath__())
  return path
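A short usage sketch (POSIX paths assumed):

import pathlib

assert path_to_str(pathlib.Path('/tmp/model')) == '/tmp/model'
assert path_to_str('/tmp/model') == '/tmp/model'  # non-PathLike passes through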
Example #50
 def testListDirectory(self):
   dir_path = os.path.join(self._base_dir, "test_dir")
   file_io.create_dir(dir_path)
   files = [b"file1.txt", b"file2.txt", b"file3.txt"]
   for name in files:
     file_path = os.path.join(dir_path, compat.as_str_any(name))
     file_io.FileIO(file_path, mode="w").write("testing")
   subdir_path = os.path.join(dir_path, "sub_dir")
   file_io.create_dir(subdir_path)
   subdir_file_path = os.path.join(subdir_path, "file4.txt")
   file_io.FileIO(subdir_file_path, mode="w").write("testing")
   dir_list = file_io.list_directory(dir_path)
   self.assertItemsEqual(files + [b"sub_dir"], dir_list)
Example #51
def IsTensorFlowEventsFile(path):
  """Check the path name to see if it is probably a TF Events file.

  Args:
    path: A file path to check if it is an event file.

  Raises:
    ValueError: If the path is an empty string.

  Returns:
    If path is formatted like a TensorFlowEventsFile.
  """
  if not path:
    raise ValueError('Path must be a nonempty string')
  return 'tfevents' in compat.as_str_any(os.path.basename(path))
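A quick behavioral sketch (hypothetical file names):

IsTensorFlowEventsFile('/logs/events.out.tfevents.1596123.host')  # True
IsTensorFlowEventsFile('/logs/checkpoint')                        # False
IsTensorFlowEventsFile('')                                        # raises ValueError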
Example #52
  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle.
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    self._dtype = dtype
    self._session = session
    self._auto_gc_enabled = True
Example #53
def walk_v2(top, topdown=True, onerror=None):
  """Recursive directory tree generator for directories.

  Args:
    top: string, a Directory name
    topdown: bool, Traverse pre order if True, post order if False.
    onerror: optional handler for errors. Should be a function, it will be
      called with the error as argument. Rethrowing the error aborts the walk.

  Errors that happen while listing directories are ignored.

  Yields:
    Each yield is a 3-tuple:  the pathname of a directory, followed by lists of
    all its subdirectories and leaf files.
    (dirname, [subdirname, subdirname, ...], [filename, filename, ...])
    as strings
  """
  top = compat.as_str_any(top)
  try:
    listing = list_directory(top)
  except errors.NotFoundError as err:
    if onerror:
      onerror(err)
    else:
      return

  files = []
  subdirs = []
  for item in listing:
    full_path = os.path.join(top, item)
    if is_directory(full_path):
      subdirs.append(item)
    else:
      files.append(item)

  here = (top, subdirs, files)

  if topdown:
    yield here

  for subdir in subdirs:
    for subitem in walk_v2(os.path.join(top, subdir), topdown, onerror=onerror):
      yield subitem

  if not topdown:
    yield here
Example #54
def get_matching_files(filename):
  """Returns a list of files that match the given pattern.

  Args:
    filename: string, the pattern

  Returns:
    Returns a list of strings containing filenames that match the given pattern.

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    # Convert each element to string, since the return values of the
    # vector of string should be interpreted as strings, not bytes.
    return [compat.as_str_any(matching_filename)
            for matching_filename in pywrap_tensorflow.GetMatchingFiles(
                compat.as_bytes(filename), status)]
Example #55
def _is_removed_mentioned(s, removed_op_names):
  """Determine whether any removed op is mentioned in the given object.

  This relies on the string representation of the object.  This is used for
  proto messages that may mention ops by name in nested fields.  The string
  representation of the proto includes those field values, so this string
  search approach is sufficient.

  Args:
    s: an object to search for removed op names.
    removed_op_names: An iterable of names of ops that were removed.

  Returns:
    True if any removed op is mentioned in the given object, False otherwise.
  """
  for removed_op_name in removed_op_names:
    if removed_op_name in compat.as_str_any(s):
      return True
  return False
Example #56
def walk(top, in_order=True):
  """Recursive directory tree generator for directories.

  Args:
    top: string, a Directory name
    in_order: bool, Traverse in order if True, post order if False.

  Errors that happen while listing directories are ignored.

  Yields:
    Each yield is a 3-tuple:  the pathname of a directory, followed by lists of
    all its subdirectories and leaf files.
    (dirname, [subdirname, subdirname, ...], [filename, filename, ...])
    as strings
  """
  top = compat.as_str_any(top)
  try:
    listing = list_directory(top)
  except errors.NotFoundError:
    return

  files = []
  subdirs = []
  for item in listing:
    full_path = os.path.join(top, item)
    if is_directory(full_path):
      subdirs.append(item)
    else:
      files.append(item)

  here = (top, subdirs, files)

  if in_order:
    yield here

  for subdir in subdirs:
    for subitem in walk(os.path.join(top, subdir), in_order):
      yield subitem

  if not in_order:
    yield here
Example #57
def get_summary_description(node_def):
  """Given a TensorSummary node_def, retrieve its SummaryDescription.

  When a Summary op is instantiated, a SummaryDescription of associated
  metadata is stored in its NodeDef. This method retrieves the description.

  Args:
    node_def: the node_def_pb2.NodeDef of a TensorSummary op

  Returns:
    a summary_pb2.SummaryDescription

  Raises:
    ValueError: if the node is not a summary op.
  """

  if node_def.op != 'TensorSummary':
    raise ValueError("Can't get_summary_description on %s" % node_def.op)
  description_str = _compat.as_str_any(node_def.attr['description'].s)
  summary_description = SummaryDescription()
  _json_format.Parse(description_str, summary_description)
  return summary_description
Example #58
 def readline(self):
   r"""Reads the next line from the file. Leaves the '\n' at the end."""
   self._prereadline_check()
   return compat.as_str_any(self._read_buf.ReadLineAsString())
Example #59
 def _get_device_name(handle):
   """The device name encoded in the handle."""
   handle_str = compat.as_str_any(handle)
   return pydev.canonical_name(handle_str.split(";")[-1])
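Given the "tensor_name;unique_id;device_name" handle format described in Example #52, a usage sketch (hypothetical handle; the exact output is whatever pydev.canonical_name produces):

handle = 'out:0;42;/job:worker/replica:0/task:1/device:GPU:0'
# The device name is the last ';'-separated field of the handle.
device = _get_device_name(handle)  # canonical form of the device string above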