示例#1
0
  def expand(self, inputs):
    """Encodes every element of the single input via the coder's cache encoding.

    Args:
      inputs: A sequence that must contain exactly one PCollection; the
        unpacking below raises ValueError otherwise.

    Returns:
      The encoded PCollection, after `common.IncrementCounter` has been applied
      with the counter name 'cache_entries_encoded'.
    """
    (pcoll,) = inputs
    encode_step = 'Encode[%s]' % self._label
    count_step = 'Count[%s]' % self._label
    encoded = pcoll | encode_step >> beam.Map(self._coder.encode_cache)
    return encoded | count_step >> common.IncrementCounter(
        'cache_entries_encoded')
示例#2
0
def _decode_cache_impl(inputs, operation, extra_args):
  """Looks up a cached PCollection and decodes each of its entries.

  This is a PTransform-like function rather than a real PTransform. The
  PCollection it reads is supplied by the user through
  `extra_args.cache_pcoll_dict`, so no graph inputs are expected.

  Args:
    inputs: Graph inputs for this operation; asserted to be empty.
    operation: Provides `dataset_key`, `cache_key`, `label` and `coder`.
    extra_args: Provides `cache_pcoll_dict`, a nested mapping indexed first by
      dataset key and then by cache key.

  Returns:
    The decoded PCollection, after `common.IncrementCounter` has been applied
    with the counter name 'cache_entries_decoded'.
  """
  assert not inputs

  per_dataset_cache = extra_args.cache_pcoll_dict[operation.dataset_key]
  cached_pcoll = per_dataset_cache[operation.cache_key]
  step_label = operation.label
  decoded = cached_pcoll | 'Decode[%s]' % step_label >> beam.Map(
      operation.coder.decode_cache)
  return decoded | 'Count[%s]' % step_label >> common.IncrementCounter(
      'cache_entries_decoded')
示例#3
0
 def expand(self, inputs):
     """Runs all packed combiners over the single input in one global combine.

     Args:
       inputs: A sequence that must contain exactly one PCollection.

     Returns:
       The globally combined PCollection (built with `with_defaults(False)`),
       after `common.IncrementCounter` has been applied with the counter name
       'num_packed_combiners'.
     """
     (pcoll,) = inputs
     # The fanout (ceil of the square root of the number of packed combiners)
     # keeps the packed combiner from straggling during the 'reduce' phase when
     # many combine analyzers are packed together.
     combiner_count = len(self._combiners)
     fanout = int(math.ceil(math.sqrt(combiner_count)))
     packed_fn = _PackedCombinerWrapper(self._combiners, self._tf_config)
     # TODO(b/34792459): Don't set with_defaults.
     global_combine = (
         beam.CombineGlobally(packed_fn).with_fanout(fanout).with_defaults(False))
     combined = pcoll | 'InitialPackedCombineGlobally' >> global_combine
     return combined | 'Count[InitialPackedCombineGlobally]' >> (
         common.IncrementCounter('num_packed_combiners'))
示例#4
0
  def expand(self, inputs):
    """Merges packed-combiner results with a global combine.

    The wrapper is built with `is_combining_accumulators=True`. Presumably this
    means the inputs are accumulators from an earlier packed combine stage
    (TODO confirm against `_PackedCombinerWrapper`).

    Args:
      inputs: A sequence that must contain exactly one PCollection.

    Returns:
      The merged PCollection (built with `with_defaults(False)`), after
      `common.IncrementCounter` has been applied with the counter name
      'num_packed_merge_combiners'.
    """
    (pcoll,) = inputs
    merge_fn = _PackedCombinerWrapper(
        self._combiners, self._tf_config, is_combining_accumulators=True)
    # TODO(b/34792459): Don't set with_defaults.
    merge_combine = beam.CombineGlobally(merge_fn).with_defaults(False)
    merged = pcoll | 'MergePackedCombinesGlobally' >> merge_combine
    return merged | 'Count' >> common.IncrementCounter(
        'num_packed_merge_combiners')
示例#5
0
 def expand(self, inputs):
   """Saves `self._graph` as an unbound SavedModel and binds it with the inputs.

   The graph is written to a fresh temp directory under `self._base_temp_dir`.
   During the write, the graph's TABLE_INITIALIZERS collection temporarily
   holds only `self._table_initializers`. The original collection contents are
   restored afterwards, even if writing fails.

   Args:
     inputs: Passed as-is into `_BindTensors`; presumably the PCollection(s)
       whose values are bound into the SavedModel (TODO confirm).

   Returns:
     The output of `_BindTensors`, after `beam_common.IncrementCounter` has been
     applied with the counter name 'saved_models_created'.
   """
   unbound_saved_model_dir = beam_common.get_unique_temp_path(
       self._base_temp_dir)
   with self._graph.as_default():
     with tf.compat.v1.Session(graph=self._graph) as session:
       table_initializers_ref = tf.compat.v1.get_collection_ref(
           tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)
       original_table_initializers = list(table_initializers_ref)
       try:
         # Save only this transform's table initializers with the model.
         table_initializers_ref[:] = self._table_initializers
         # Initialize all variables so they can be saved.
         session.run(tf.compat.v1.global_variables_initializer())
         saved_transform_io.write_saved_transform_from_session(
             session, self._input_signature, self._output_signature,
             unbound_saved_model_dir)
       finally:
         # The collection belongs to the shared graph. Restore it even if
         # saving raised, so the graph is not left with the temporary list.
         table_initializers_ref[:] = original_table_initializers
   return (inputs
           | 'BindTensors' >> _BindTensors(self._base_temp_dir,
                                           unbound_saved_model_dir)
           | 'Count' >> beam_common.IncrementCounter('saved_models_created'))
示例#6
0
文件: impl.py 项目: jiwidi/transform
def _create_saved_model_impl(inputs, operation, extra_args):
  """Create a SavedModel from a TF Graph.

  The graph `extra_args.graph` is written as an unbound SavedModel to a fresh
  temp directory under `extra_args.base_temp_dir`. During the write, the
  graph's TABLE_INITIALIZERS collection temporarily holds only
  `operation.table_initializers`. The original contents are restored
  afterwards, even if writing fails.

  Args:
    inputs: Passed as-is into `_BindTensors`; presumably the PCollection(s)
      whose values are bound into the SavedModel (TODO confirm).
    operation: Provides `table_initializers`, `output_signature` and `label`.
    extra_args: Provides `base_temp_dir`, `graph`, `input_signature` and
      `pipeline`.

  Returns:
    The output of `_BindTensors`, after `common.IncrementCounter` has been
    applied with the counter name 'saved_models_created'.
  """
  unbound_saved_model_dir = common.get_unique_temp_path(
      extra_args.base_temp_dir)
  with extra_args.graph.as_default():
    with tf.compat.v1.Session(graph=extra_args.graph) as session:
      table_initializers_ref = tf.compat.v1.get_collection_ref(
          tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)
      original_table_initializers = list(table_initializers_ref)
      try:
        # Save only this operation's table initializers with the model.
        table_initializers_ref[:] = operation.table_initializers
        # Initialize all variables so they can be saved.
        session.run(tf.compat.v1.global_variables_initializer())
        saved_transform_io.write_saved_transform_from_session(
            session, extra_args.input_signature, operation.output_signature,
            unbound_saved_model_dir)
      finally:
        # The collection belongs to the shared graph. Restore it even if
        # saving raised, so the graph is not left with the temporary list.
        table_initializers_ref[:] = original_table_initializers
  return (inputs
          | operation.label >> _BindTensors(
              extra_args.base_temp_dir, unbound_saved_model_dir,
              extra_args.pipeline)
          | 'Count[%s]' % operation.label >>
          common.IncrementCounter('saved_models_created'))
示例#7
0
    def expand(self, pbegin):
        """Decodes the cached PCollection held by this transform.

        `pbegin` is ignored. The source data is `self._cache_pcoll`.

        Returns:
            The decoded PCollection, after `common.IncrementCounter` has been
            applied with the counter name 'cache_entries_decoded'.
        """
        del pbegin  # unused
        decode_step = 'Decode' >> beam.Map(self._coder.decode_cache)
        decoded = self._cache_pcoll | decode_step
        return decoded | 'Count' >> common.IncrementCounter(
            'cache_entries_decoded')