Example #1
def writer_pipeline(compressors, write_parallelism, record_id, output_dir,
                    suffix, args):
    prefix_name = tf.constant("{}_".format(record_id), name="prefix_string")
    compressed_batch = pipeline.join(compressors,
                                     parallel=write_parallelism,
                                     capacity=8,
                                     multi=True,
                                     name="write_input")

    for base, meta, first_ordinal, num_recs in compressed_batch:
        first_ord_as_string = string_ops.as_string(first_ordinal,
                                                   name="first_ord_as_string")
        base_key = string_ops.string_join(
            [output_dir, prefix_name, first_ord_as_string, ".", suffix],
            name="base_key_string")
        meta_key = string_ops.string_join(
            [output_dir, prefix_name, first_ord_as_string, ".metadata"],
            name="metadata_key_string")
        base_path = persona_ops.agd_file_system_buffer_writer(
            record_id=record_id,
            record_type="text" if args.protein else "base_compact",
            resource_handle=base,
            path=base_key,
            compressed=True,
            first_ordinal=first_ordinal,
            num_records=tf.to_int32(num_recs))
        meta_path = persona_ops.agd_file_system_buffer_writer(
            record_id=record_id,
            record_type="text",
            resource_handle=meta,
            path=meta_key,
            compressed=True,
            first_ordinal=first_ordinal,
            num_records=tf.to_int32(num_recs))
        yield base_path, meta_path, first_ordinal, num_recs
Example #2
        def serving_input_fn():
            receiver_1 = array_ops.placeholder(dtypes.string)
            receiver_2 = array_ops.placeholder(dtypes.string)

            receiver_tensors = {
                'rec1': receiver_1,
                u'rec2': receiver_2,
            }

            concat = string_ops.string_join([receiver_1, receiver_2])
            concat2 = array_ops.identity(concat)
            features = {
                'feature0': string_ops.string_join([concat, concat2], ':'),
                u'feature1': constant_op.constant([1])
            }

            alternate_tensors = {
                'alt_name_1': concat,
                'alt_name_2': {
                    'tensor1': concat,
                    'tensor2': concat2
                }
            }
            return export_lib.ServingInputReceiver(features, receiver_tensors,
                                                   alternate_tensors)
Example #3
def writer_pipeline(compressors, write_parallelism, record_id, output_dir,
                    compressed):
    prefix_name = tf.constant("{}_".format(record_id), name="prefix_string")

    if compressed:
        write_op = partial(persona_ops.agd_file_system_buffer_writer,
                           compressed=compressed)
        converted_compressors = [
            [a.compressed_buffer
             for a in result_item[:3]] + list(result_item[3:])
            for result_item in compressors
        ]
    else:
        write_op = persona_ops.agd_file_system_buffer_pair_writer
        converted_compressors = compressors

    compressed_batch = pipeline.join(converted_compressors,
                                     parallel=write_parallelism,
                                     capacity=8,
                                     multi=True,
                                     name="write_input")

    for base, qual, meta, first_ordinal, num_recs in compressed_batch:
        first_ord_as_string = string_ops.as_string(first_ordinal,
                                                   name="first_ord_as_string")
        base_key = string_ops.string_join(
            [output_dir, prefix_name, first_ord_as_string, ".base"],
            name="base_key_string")
        qual_key = string_ops.string_join(
            [output_dir, prefix_name, first_ord_as_string, ".qual"],
            name="qual_key_string")
        meta_key = string_ops.string_join(
            [output_dir, prefix_name, first_ord_as_string, ".metadata"],
            name="metadata_key_string")
        base_path = write_op(record_id=record_id,
                             record_type="base_compact",
                             resource_handle=base,
                             path=base_key,
                             first_ordinal=first_ordinal,
                             num_records=tf.to_int32(num_recs))
        qual_path = write_op(record_id=record_id,
                             record_type="text",
                             resource_handle=qual,
                             path=qual_key,
                             first_ordinal=first_ordinal,
                             num_records=tf.to_int32(num_recs))
        meta_path = write_op(record_id=record_id,
                             record_type="text",
                             resource_handle=meta,
                             path=meta_key,
                             first_ordinal=first_ordinal,
                             num_records=tf.to_int32(num_recs))
        yield base_path, qual_path, meta_path, first_ordinal, num_recs
Example #4
    def make_inter_writers(self, batch, output_dir, write_parallelism):
        single = pipeline.join(batch,
                               parallel=write_parallelism,
                               capacity=4,
                               multi=True,
                               name="writer_queue")
        types = get_types_for_columns(self.inter_columns)
        #print("inter col types {}".format(types))
        #types = [ "structured", "base_compact", "text", "text"]

        # no uncompressed buffer pair writer yet
        writers = []
        for buf, num_recs, record_id in single:
            w = []
            bufs = tf.unstack(buf)
            for i, b in enumerate(bufs):
                result_key = string_ops.string_join(
                    [output_dir, "/", record_id, ".", self.inter_columns[i]],
                    name="key_string")

                result = persona_ops.agd_file_system_buffer_pair_writer(
                    record_id=record_id,
                    record_type=types[i],
                    resource_handle=b,
                    path=result_key,
                    first_ordinal=0,
                    num_records=tf.to_int32(num_recs))
                w.append(result)
            w.append(record_id)
            writers.append(w)
        return writers
Example #5
 def interleave_fn(filename):
     # Test function that uses control flow. The True branch is never taken
     concat = string_ops.string_join([filename, "abc"])
     return control_flow_ops.cond(
         math_ops.equal(filename, "abc"),
         lambda: reader_ops.TextLineDataset(concat),
         lambda: reader_ops.TextLineDataset(filename))
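A minimal usage sketch (not part of the original test): a function with this shape is normally handed to tf.data's interleave transformation. The file names below are hypothetical placeholders, and `dataset_ops` is the internal module already used in these examples.

def make_interleaved_dataset():
    # Hypothetical text files; each name is mapped through interleave_fn above.
    filenames = dataset_ops.Dataset.from_tensor_slices(["a.txt", "b.txt"])
    # cycle_length=2 cycles between the two per-file TextLineDatasets.
    return filenames.interleave(interleave_fn, cycle_length=2)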
Example #6
def writer_pipeline(compressors, write_parallelism, record_id, output_dir):
    prefix_name = tf.constant("{}_".format(record_id), name="prefix_string")
    compressed_batch = pipeline.join(compressors,
                                     parallel=write_parallelism,
                                     capacity=8,
                                     multi=True,
                                     name="write_input")

    types = ['base_compact', 'text', 'text', 'structured']
    exts = ['.base', '.qual', '.metadata', '.results']
    for chunk_stacked, first_ordinal, num_recs in compressed_batch:
        chunks = tf.unstack(chunk_stacked)
        first_ord_as_string = string_ops.as_string(first_ordinal,
                                                   name="first_ord_as_string")

        paths = []
        for i, chunk in enumerate(chunks):
            key = string_ops.string_join(
                [output_dir, prefix_name, first_ord_as_string, exts[i]],
                name="key_string")
            paths.append(
                persona_ops.agd_file_system_buffer_writer(
                    record_id=record_id,
                    record_type=types[i],
                    resource_handle=chunk,
                    path=key,
                    compressed=True,
                    first_ordinal=first_ordinal,
                    num_records=tf.to_int32(num_recs)))
        yield paths + [first_ordinal, num_recs]
Example #7
def _make_writers(compressed_batch, output_dir, write_parallelism):

    compressed_single = pipeline.join(compressed_batch,
                                      parallel=write_parallelism,
                                      capacity=8,
                                      multi=True)

    for buf, num_recs, first_ordinal, record_id in compressed_single:

        first_ord_as_string = string_ops.as_string(first_ordinal,
                                                   name="first_ord_as_string")
        result_key = string_ops.string_join(
            [output_dir, "/", record_id, "_", first_ord_as_string, ".results"],
            name="base_key_string")

        result = persona_ops.agd_file_system_buffer_writer(
            record_id=record_id,
            record_type="structured",
            resource_handle=buf,
            path=result_key,
            compressed=True,
            first_ordinal=first_ordinal,
            num_records=tf.to_int32(num_recs))

        yield result  # writes out the file path key (full path)
Example #8
 def testStateSaverScopeNames(self):
   batch_size = constant_op.constant(2)
   sqss_scope_name = "unique_scope_name_for_sqss"
   num_unroll = 2
   length = 3
   key = string_ops.string_join([
       "key_", string_ops.as_string(
           math_ops.cast(10000 * random_ops.random_uniform(()), dtypes.int32))
   ])
   padded_length = 4
   sequences = {
       "seq1": np.random.rand(padded_length, 5),
       "seq2": np.random.rand(padded_length, 4, 2)
   }
   context = {"context1": [3, 4]}
   initial_states = {
       "state1": np.random.rand(6, 7),
       "state2": np.random.rand(8)
   }
   state_saver = sqss.SequenceQueueingStateSaver(
       batch_size=batch_size,
       num_unroll=num_unroll,
       input_length=length,
       input_key=key,
       input_sequences=sequences,
       input_context=context,
       initial_states=initial_states,
       name=sqss_scope_name)
   prefetch_op = state_saver.prefetch_op
   next_batch = state_saver.next_batch
   self.assertTrue(
       state_saver.barrier.barrier_ref.name.startswith("%s/" %
                                                       sqss_scope_name))
   self.assertTrue(prefetch_op.name.startswith("%s/" % sqss_scope_name))
   self.assertTrue(next_batch.key.name.startswith("%s/" % sqss_scope_name))
Example #9
 def input_fn():
     start = random_ops.random_uniform((),
                                       minval=0,
                                       maxval=sequence_length,
                                       dtype=dtypes.int32,
                                       seed=seed)
     # Concatenate lyrics_list so inputs and labels wrap when start > 0.
     lyrics_list_concat = lyrics_list + lyrics_list
     inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                    [sequence_length])
     indices = array_ops.constant([[i, 0]
                                   for i in range(sequence_length)],
                                  dtype=dtypes.int64)
     dense_shape = [sequence_length, 1]
     inputs = sparse_tensor.SparseTensor(indices=indices,
                                         values=inputs_dense,
                                         dense_shape=dense_shape)
     table = lookup.string_to_index_table_from_tensor(
         mapping=list(vocab), default_value=-1, name='lookup')
     labels = table.lookup(
         array_ops.slice(lyrics_list_concat, [start + 1],
                         [sequence_length]))
     input_key = string_ops.string_join([
         'key_',
         string_ops.as_string(
             random_ops.random_uniform((),
                                       minval=0,
                                       maxval=10000000,
                                       dtype=dtypes.int32,
                                       seed=seed))
     ])
     return {
         'lyrics': inputs,
         input_key_column_name: input_key
     }, labels
Example #10
 def input_fn():
   start = random_ops.random_uniform(
       (), minval=0, maxval=sequence_length, dtype=dtypes.int32, seed=seed)
   # Concatenate lyrics_list so inputs and labels wrap when start > 0.
   lyrics_list_concat = lyrics_list + lyrics_list
   inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                  [sequence_length])
   indices = array_ops.constant(
       [[i, 0] for i in range(sequence_length)], dtype=dtypes.int64)
   dense_shape = [sequence_length, 1]
   inputs = sparse_tensor.SparseTensor(
       indices=indices, values=inputs_dense, dense_shape=dense_shape)
   table = lookup.string_to_index_table_from_tensor(
       mapping=list(vocab), default_value=-1, name='lookup')
   labels = table.lookup(
       array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
   input_key = string_ops.string_join([
       'key_', string_ops.as_string(
           random_ops.random_uniform(
               (),
               minval=0,
               maxval=10000000,
               dtype=dtypes.int32,
               seed=seed))
   ])
   return {'lyrics': inputs, input_key_column_name: input_key}, labels
Example #11
 def testStateSaverScopeNames(self):
   batch_size = constant_op.constant(2)
   sqss_scope_name = "unique_scope_name_for_sqss"
   num_unroll = 2
   length = 3
   key = string_ops.string_join([
       "key_", string_ops.as_string(
           math_ops.cast(10000 * random_ops.random_uniform(()), dtypes.int32))
   ])
   padded_length = 4
   sequences = {
       "seq1": np.random.rand(padded_length, 5),
       "seq2": np.random.rand(padded_length, 4, 2)
   }
   context = {"context1": [3, 4]}
   initial_states = {
       "state1": np.random.rand(6, 7),
       "state2": np.random.rand(8)
   }
   state_saver = sqss.SequenceQueueingStateSaver(
       batch_size=batch_size,
       num_unroll=num_unroll,
       input_length=length,
       input_key=key,
       input_sequences=sequences,
       input_context=context,
       initial_states=initial_states,
       name=sqss_scope_name)
   prefetch_op = state_saver.prefetch_op
   next_batch = state_saver.next_batch
   self.assertTrue(
       state_saver.barrier.barrier_ref.name.startswith("%s/" %
                                                       sqss_scope_name))
   self.assertTrue(prefetch_op.name.startswith("%s/" % sqss_scope_name))
   self.assertTrue(next_batch.key.name.startswith("%s/" % sqss_scope_name))
Example #12
 def get_updates(self, loss, params):
     del params
     return [
         string_ops.string_join([
             constant_op.constant(expected_train_result),
             string_ops.as_string(loss, precision=3)
         ])
     ]
Example #13
 def setUp(self):
   super(BatchSequencesWithStatesTest, self).setUp()
   self.value_length = 4
   ind1 = np.array([
       [0, 0],
       [1, 0], [1, 3], [1, 4],
       [3, 2], [3, 3]])
   val1 = np.array([0, 10, 13, 14, 32, 33])
   shape1 = np.array([self.value_length, 6])
   sp_tensor1 = sparse_tensor.SparseTensor(
       array_ops.constant(ind1, dtypes.int64),
       array_ops.constant(val1, dtypes.int64),
       array_ops.constant(shape1, dtypes.int64))
   ind2 = np.array([
       [0, 0, 1],
       [0, 1, 0],
       [0, 1, 2],
       [1, 0, 3],
       [1, 1, 0],
       [1, 1, 1],
       [1, 1, 2],
       [1, 2, 2]])
   val2 = np.array([1, 10, 12, 103, 150, 149, 150, 122])
   shape2 = np.array([self.value_length, 3, 4])
   sp_tensor2 = sparse_tensor.SparseTensor(
       array_ops.constant(ind2, dtypes.int64),
       array_ops.constant(val2, dtypes.int64),
       array_ops.constant(shape2, dtypes.int64))
   sp_tensor3 = sparse_tensor.SparseTensor(
       array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64),
       array_ops.constant([7, 15, 2], dtypes.int64),
       array_ops.constant([5, 12], dtypes.int64)
   )
   self.sp_tensor3_expected = sparse_tensor.SparseTensorValue(
       [[0, 1, 9], [0, 2, 2], [0, 2, 10], [1, 1, 9], [1, 2, 2], [1, 2, 10]],
       [7, 15, 2, 7, 15, 2],
       [2, 5, 12]
   )
   self.batch_size = 2
   self.key = string_ops.string_join([
       "key_", string_ops.as_string(
           math_ops.cast(10000 * random_ops.random_uniform(()), dtypes.int32))
   ])
   self.sequences = {
       "seq1": np.random.rand(self.value_length, 5),
       "seq2": np.random.rand(self.value_length, 4, 2),
       "seq3": sp_tensor1,
       "seq4": sp_tensor2}
   self.context = {
       "context1": [3, 4],
       "sp_context": sp_tensor3}
   self.initial_states = {
       "state1": np.random.rand(6, 7),
       "state2": np.random.rand(8)
   }
Example #14
 def setUp(self):
   super(BatchSequencesWithStatesTest, self).setUp()
   self.value_length = 4
   ind1 = np.array([
       [0, 0],
       [1, 0], [1, 3], [1, 4],
       [3, 2], [3, 3]])
   val1 = np.array([0, 10, 13, 14, 32, 33])
   shape1 = np.array([self.value_length, 6])
   sp_tensor1 = sparse_tensor.SparseTensor(
       array_ops.constant(ind1, dtypes.int64),
       array_ops.constant(val1, dtypes.int64),
       array_ops.constant(shape1, dtypes.int64))
   ind2 = np.array([
       [0, 0, 1],
       [0, 1, 0],
       [0, 1, 2],
       [1, 0, 3],
       [1, 1, 0],
       [1, 1, 1],
       [1, 1, 2],
       [1, 2, 2]])
   val2 = np.array([1, 10, 12, 103, 150, 149, 150, 122])
   shape2 = np.array([self.value_length, 3, 4])
   sp_tensor2 = sparse_tensor.SparseTensor(
       array_ops.constant(ind2, dtypes.int64),
       array_ops.constant(val2, dtypes.int64),
       array_ops.constant(shape2, dtypes.int64))
   sp_tensor3 = sparse_tensor.SparseTensor(
       array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64),
       array_ops.constant([7, 15, 2], dtypes.int64),
       array_ops.constant([5, 12], dtypes.int64)
   )
   self.sp_tensor3_expected = sparse_tensor.SparseTensorValue(
       [[0, 1, 9], [0, 2, 2], [0, 2, 10], [1, 1, 9], [1, 2, 2], [1, 2, 10]],
       [7, 15, 2, 7, 15, 2],
       [2, 5, 12]
   )
   self.batch_size = 2
   self.key = string_ops.string_join([
       "key_", string_ops.as_string(
           math_ops.cast(10000 * random_ops.random_uniform(()), dtypes.int32))
   ])
   self.sequences = {
       "seq1": np.random.rand(self.value_length, 5),
       "seq2": np.random.rand(self.value_length, 4, 2),
       "seq3": sp_tensor1,
       "seq4": sp_tensor2}
   self.context = {
       "context1": [3, 4],
       "sp_context": sp_tensor3}
   self.initial_states = {
       "state1": np.random.rand(6, 7),
       "state2": np.random.rand(8)
   }
Example #15
 def input_fn():
   start = random_ops.random_uniform(
       (), minval=0, maxval=(np.pi * 2.0), dtype=dtypes.float32, seed=seed)
   sin_curves = math_ops.sin(
       math_ops.linspace(start, (sequence_length - 1) * increment,
                         sequence_length + 1))
   inputs = array_ops.slice(sin_curves, [0], [sequence_length])
   labels = array_ops.slice(sin_curves, [1], [sequence_length])
   input_key = string_ops.string_join([
       'key_',
       string_ops.as_string(math_ops.cast(10000 * start, dtypes.int32))
   ])
   return {'inputs': inputs, input_key_column_name: input_key}, labels
Example #16
    def make_merge_pipeline(self, args, record_name, chunks_to_merge, bpp):

        types = [dtypes.int32] + [dtypes.string] * len(self.inter_columns)
        shapes = [tensor_shape.scalar()
                  ] + [tensor_shape.vector(2)] * len(self.inter_columns)
        q = data_flow_ops.FIFOQueue(
            capacity=8,  # big because who cares
            dtypes=types,
            shapes=shapes,
            name="merge_output_queue")

        #bpp = persona_ops.buffer_pair_pool(size=0, bound=False, name="local_read_merge_buffer_list_pool")

        if args.order_by == location_value:
            merge = persona_ops.agd_merge
        else:
            merge = persona_ops.agd_merge_metadata

        merge_op = merge(chunk_size=args.chunk,
                         buffer_pair_pool=bpp,
                         chunk_group_handles=chunks_to_merge,
                         output_buffer_queue_handle=q.queue_ref,
                         name="agd_local_merge")

        tf.train.queue_runner.add_queue_runner(
            tf.train.queue_runner.QueueRunner(q, [merge_op]))

        # num_recs, results, base, qual, meta
        #num_recs, results, base, qual, meta = q.dequeue()
        val = q.dequeue()
        num_recs = val[0]

        record_name_constant = constant_op.constant(record_name)
        first_ordinal = tf.Variable(-1 * args.chunk,
                                    dtype=dtypes.int64,
                                    name="first_ordinal")
        first_ord = first_ordinal.assign_add(math_ops.to_int64(
            args.chunk, name="first_ord_cast_to_64"),
                                             use_locking=True)
        first_ord_str = string_ops.as_string(first_ord,
                                             name="first_ord_string")
        file_name = string_ops.string_join(
            [args.dataset_dir, "/", record_name_constant, first_ord_str],
            name="file_name_string_joiner")

        out_tuple = val[1:] + [record_name, first_ord, num_recs, file_name]

        return out_tuple
Example #17
 def input_fn():
   random_sequence = random_ops.random_uniform(
       [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
   labels = array_ops.slice(random_sequence, [0], [sequence_length])
   inputs = math_ops.to_float(
       array_ops.slice(random_sequence, [1], [sequence_length]))
   input_key = string_ops.string_join([
       'key_', string_ops.as_string(
           random_ops.random_uniform(
               (),
               minval=0,
               maxval=10000000,
               dtype=dtypes.int32,
               seed=seed))
   ])
   return {'inputs': inputs, input_key_column_name: input_key}, labels
Example #18
 def setUp(self):
   super(BatchSequencesWithStatesTest, self).setUp()
   self.value_length = 4
   self.batch_size = 2
   self.key = string_ops.string_join([
       "key_", string_ops.as_string(
           math_ops.cast(10000 * random_ops.random_uniform(()), dtypes.int32))
   ])
   self.sequences = {
       "seq1": np.random.rand(self.value_length, 5),
       "seq2": np.random.rand(self.value_length, 4, 2)
   }
   self.context = {"context1": [3, 4]}
   self.initial_states = {
       "state1": np.random.rand(6, 7),
       "state2": np.random.rand(8)
   }
Example #19
  def testValidPipelineWithRangeDataset(self, shuffle):
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
Example #20
  def testValidPipelineWithRangeDataset(self, shuffle):
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
Example #21
 def setUp(self):
     super(BatchSequencesWithStatesTest, self).setUp()
     self.value_length = 4
     self.batch_size = 2
     self.key = string_ops.string_join([
         "key_",
         string_ops.as_string(
             math_ops.cast(10000 * random_ops.random_uniform(()),
                           dtypes.int32))
     ])
     self.sequences = {
         "seq1": np.random.rand(self.value_length, 5),
         "seq2": np.random.rand(self.value_length, 4, 2)
     }
     self.context = {"context1": [3, 4]}
     self.initial_states = {
         "state1": np.random.rand(6, 7),
         "state2": np.random.rand(8)
     }
Example #22
 def input_fn():
     random_sequence = random_ops.random_uniform(
         [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
     labels = array_ops.slice(random_sequence, [0],
                              [sequence_length])
     inputs = math_ops.to_float(
         array_ops.slice(random_sequence, [1], [sequence_length]))
     input_key = string_ops.string_join([
         'key_',
         string_ops.as_string(
             random_ops.random_uniform((),
                                       minval=0,
                                       maxval=10000000,
                                       dtype=dtypes.int32,
                                       seed=seed))
     ])
     return {
         'inputs': inputs,
         input_key_column_name: input_key
     }, labels
Example #23
 def input_fn():
     start = random_ops.random_uniform((),
                                       minval=0,
                                       maxval=(np.pi * 2.0),
                                       dtype=dtypes.float32,
                                       seed=seed)
     sin_curves = math_ops.sin(
         math_ops.linspace(start, (sequence_length - 1) * increment,
                           sequence_length + 1))
     inputs = array_ops.slice(sin_curves, [0], [sequence_length])
     labels = array_ops.slice(sin_curves, [1], [sequence_length])
     input_key = string_ops.string_join([
         'key_',
         string_ops.as_string(
             math_ops.cast(10000 * start, dtypes.int32))
     ])
     return {
         'inputs': inputs,
         input_key_column_name: input_key
     }, labels
Example #24
def compressed_writer_pipeline(converters, write_parallelism, record_id,
                               output_dir, compress_parallelism):
    def make_compress():
        to_compress_batch = tf.train.batch_join_pdq(
            tuple(converters),
            1,
            num_dq_ops=compress_parallelism,
            name="pre_compression_queue")
        bp = persona_ops.buffer_pool(size=0,
                                     bound=False,
                                     name="compression_buffer_pool")
        for converted_handle, first_ordinal, num_recs in to_compress_batch:
            yield persona_ops.buffer_list_compressor(
                buffer_list_size=3,
                buffer_pool=bp,
                buffer_list=converted_handle), first_ordinal, num_recs

    prefix_name = tf.Variable("{}_".format(record_id),
                              dtype=dtypes.string,
                              name="prefix_string")
    compressed_batch = tf.train.batch_join_pdq(
        tuple(make_compress()),
        1,
        num_dq_ops=write_parallelism,
        name="post_compression_pre_write_queue")

    for compressed_set, first_ordinal, num_recs in compressed_batch:
        first_ord_as_string = string_ops.as_string(first_ordinal,
                                                   name="first_ord_as_string")
        file_key = string_ops.string_join([prefix_name, first_ord_as_string],
                                          name="file_key_string")
        file_path_out = persona_ops.column_writer(
            record_id=record_id,
            compressed=True,
            record_types=record_type,
            outdir=output_dir,
            columns=compressed_set,
            file_path=file_key,
            first_ordinal=first_ordinal,
            num_recs=tf.to_int32(num_recs))
        yield file_path_out, first_ordinal, num_recs
Example #25
def writer_pipeline(converters, write_parallelism, record_id, output_dir):
    prefix_name = tf.Variable("{}_".format(record_id),
                              dtype=dtypes.string,
                              name="prefix_string")
    converted_batch = tf.train.batch_join_pdq(tuple(converters),
                                              1,
                                              num_dq_ops=write_parallelism,
                                              name="converted_batch_queue")
    for converted_handle, first_ordinal, num_recs in converted_batch:
        first_ord_as_string = string_ops.as_string(first_ordinal,
                                                   name="first_ord_as_string")
        file_key = string_ops.string_join([prefix_name, first_ord_as_string],
                                          name="file_key_string")
        file_path, _ = persona_ops.agd_write_columns(
            record_id=record_id,
            record_type=record_type,
            column_handle=converted_handle,
            output_dir=output_dir,
            file_path=file_key,
            first_ordinal=first_ordinal,
            num_records=tf.to_int32(num_recs))
        yield file_path, first_ordinal, num_recs
Example #26
      def input_fn():
        input_key = string_ops.string_join([
            'key_', string_ops.as_string(
                random_ops.random_uniform(
                    (),
                    minval=0,
                    maxval=10000000,
                    dtype=dtypes.int32,
                    seed=seed))
        ])
        features = {}
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.to_float(
            array_ops.slice(random_sequence, [1], [sequence_length]))
        features = {'inputs': inputs, input_key_column_name: input_key}

        if mode == model_fn_lib.ModeKeys.INFER:
          input_examples = array_ops.placeholder(dtypes.string)
          features[input_feature_key] = input_examples
          labels = None
        return features, labels
Example #27
            def input_fn():
                input_key = string_ops.string_join([
                    'key_',
                    string_ops.as_string(
                        random_ops.random_uniform((),
                                                  minval=0,
                                                  maxval=10000000,
                                                  dtype=dtypes.int32,
                                                  seed=seed))
                ])
                features = {}
                random_sequence = random_ops.random_uniform(
                    [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
                labels = array_ops.slice(random_sequence, [0],
                                         [sequence_length])
                inputs = math_ops.to_float(
                    array_ops.slice(random_sequence, [1], [sequence_length]))
                features = {'inputs': inputs, input_key_column_name: input_key}

                if mode == model_fn_lib.ModeKeys.INFER:
                    input_examples = array_ops.placeholder(dtypes.string)
                    features[input_feature_key] = input_examples
                    labels = None
                return features, labels
Example #28
  def testStringJoin(self):
    input0 = ["a", "b"]
    input1 = "a"
    input2 = [["b"], ["c"]]

    with self.cached_session():
      output = string_ops.string_join([input0, input1])
      self.assertAllEqual(output.eval(), [b"aa", b"ba"])

      output = string_ops.string_join([input0, input1], separator="--")
      self.assertAllEqual(output.eval(), [b"a--a", b"b--a"])

      output = string_ops.string_join([input0, input1, input0], separator="--")
      self.assertAllEqual(output.eval(), [b"a--a--a", b"b--a--b"])

      output = string_ops.string_join([input1] * 4, separator="!")
      self.assertEqual(output.eval(), b"a!a!a!a")

      output = string_ops.string_join([input2] * 2, separator="")
      self.assertAllEqual(output.eval(), [[b"bb"], [b"cc"]])

      with self.assertRaises(ValueError):  # Inconsistent shapes
        string_ops.string_join([input0, input2]).eval()
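For reference, a minimal sketch of the same joins against the public TF 2.x API (`tf.strings.join` is the exported counterpart of `string_ops.string_join`); the inputs mirror the test above and run eagerly.

import tensorflow as tf

# Broadcasting a scalar against a vector, with an explicit separator.
print(tf.strings.join([["a", "b"], "a"], separator="--").numpy())
# [b'a--a' b'b--a']

# Joining a scalar with itself four times.
print(tf.strings.join(["a"] * 4, separator="!").numpy())
# b'a!a!a!a'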
Example #29
 def _train_op_fn(loss):
     return string_ops.string_join([
         constant_op.constant(expected_train_result),
         string_ops.as_string(loss, precision=3)
     ])
Example #30
 def _train_op_fn(loss):
   return string_ops.string_join(
       [constant_op.constant(expected_train_result),
        string_ops.as_string(loss, precision=3)])
Example #31
 def fn(entry_lt):
   op = string_ops.string_join([entry_lt, 'world'])
   return core.LabeledTensor(op, [])
Example #32
  def testStateSaverWithTwoSimpleSteps(self):
    with self.cached_session() as sess:
      batch_size_value = 2
      batch_size = constant_op.constant(batch_size_value)
      num_unroll = 2
      length = 3
      key = string_ops.string_join([
          "key_", string_ops.as_string(
              math_ops.cast(10000 * random_ops.random_uniform(()),
                            dtypes.int32))
      ])
      padded_length = 4
      sequences = {
          "seq1": np.random.rand(padded_length, 5),
          "seq2": np.random.rand(padded_length, 4, 2)
      }
      context = {"context1": [3, 4]}
      initial_states = {
          "state1": np.random.rand(6, 7),
          "state2": np.random.rand(8)
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states,
          capacity=100)

      initial_key_value_0, _ = sess.run((key, state_saver.prefetch_op))
      initial_key_value_1, _ = sess.run((key, state_saver.prefetch_op))

      initial_key_value_0 = initial_key_value_0.decode("ascii")
      initial_key_value_1 = initial_key_value_1.decode("ascii")

      # Step 1
      next_batch = state_saver.next_batch
      (key_value, next_key_value, seq1_value, seq2_value, context1_value,
       state1_value, state2_value, length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.context["context1"],
            next_batch.state("state1"), next_batch.state("state2"),
            next_batch.length,
            next_batch.save_state("state1", next_batch.state("state1") + 1),
            next_batch.save_state("state2", next_batch.state("state2") - 1)))

      expected_first_keys = set(
          ("00000_of_00002:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))
      expected_second_keys = set(
          ("00001_of_00002:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))
      expected_final_keys = set(
          ("STOP:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))

      self.assertEqual(set(key_value), expected_first_keys)
      self.assertEqual(set(next_key_value), expected_second_keys)
      self.assertAllEqual(context1_value,
                          np.tile(context["context1"], (batch_size_value, 1)))
      self.assertAllEqual(seq1_value,
                          np.tile(sequences["seq1"][np.newaxis, 0:2, :],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(seq2_value,
                          np.tile(sequences["seq2"][np.newaxis, 0:2, :, :],
                                  (batch_size_value, 1, 1, 1)))
      self.assertAllEqual(state1_value,
                          np.tile(initial_states["state1"],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(state2_value,
                          np.tile(initial_states["state2"],
                                  (batch_size_value, 1)))
      self.assertAllEqual(length_value, [2, 2])

      # Step 2
      (key_value, next_key_value, seq1_value, seq2_value, context1_value,
       state1_value, state2_value, length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.context["context1"],
            next_batch.state("state1"), next_batch.state("state2"),
            next_batch.length,
            next_batch.save_state("state1", next_batch.state("state1") + 1),
            next_batch.save_state("state2", next_batch.state("state2") - 1)))

      self.assertEqual(set(key_value), expected_second_keys)
      self.assertEqual(set(next_key_value), expected_final_keys)
      self.assertAllEqual(context1_value,
                          np.tile(context["context1"], (batch_size_value, 1)))
      self.assertAllEqual(seq1_value,
                          np.tile(sequences["seq1"][np.newaxis, 2:4, :],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(seq2_value,
                          np.tile(sequences["seq2"][np.newaxis, 2:4, :, :],
                                  (batch_size_value, 1, 1, 1)))
      self.assertAllEqual(state1_value, 1 + np.tile(initial_states["state1"],
                                                    (batch_size_value, 1, 1)))
      self.assertAllEqual(state2_value, -1 + np.tile(initial_states["state2"],
                                                     (batch_size_value, 1)))
      self.assertAllEqual(length_value, [1, 1])

      # Finished.  Let's make sure there's nothing left in the barrier.
      self.assertEqual(0, state_saver.barrier.ready_size().eval())
Example #33
  def save(self, file_prefix):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    sharded_suffix = "_temp_%s/part" % uuid.uuid4().hex

    with ops.device("cpu:0"):
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])

    num_shards = len(self._single_device_savers)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saver) in enumerate(
        sorted(self._single_device_savers.items())):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                        num_shards_tensor)
      sharded_prefixes.append(shard_prefix)
      with ops.device(device):
        # _SingleDeviceSaver will use the CPU device when necessary, but initial
        # read operations should be placed on the SaveableObject's device.
        sharded_saves.append(saver.save(shard_prefix))

    with ops.control_dependencies(sharded_saves):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        return gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, file_prefix, delete_old_dirs=True)
Example #34
 def minimize(self, loss, global_step):
   del global_step
   return string_ops.string_join(
       [constant_op.constant(expected_train_result),
        string_ops.as_string(loss, precision=3)])
Example #35
 def fn(entry_lt):
     op = string_ops.string_join([entry_lt, 'world'])
     return core.LabeledTensor(op, [])
Example #36
 def minimize(self, loss, global_step):
   del global_step
   return string_ops.string_join(
       [constant_op.constant(expected_train_result),
        string_ops.as_string(loss, precision=3)])
Example #37
    def save(self, file_prefix, options=None):
        """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
        options = options or checkpoint_options.CheckpointOptions()
        for callback in self._before_save_callbacks:
            callback()

        # IMPLEMENTATION DETAILS: most clients should skip.
        #
        # Suffix for any well-formed "checkpoint_prefix", when sharded.
        # Transformations:
        # * Users pass in "save_path" in save() and restore().  Say "myckpt".
        # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
        #
        # Example:
        #   During runtime, a temporary directory is first created, which contains
        #   files
        #
        #     <train dir>/myckpt_temp/
        #        part-?????-of-?????{.index, .data-00000-of-00001}
        #
        #   Before .save() finishes, they will be (hopefully, atomically) renamed to
        #
        #     <train dir>/
        #        myckpt{.index, .data-?????-of-?????}
        #
        #   Filesystems with eventual consistency (such as S3) don't need a
        #   temporary location. Using a temporary directory in those cases might
        #   cause situations where files are not available during copy.
        #
        # Users only need to interact with the user-specified prefix, which is
        # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
        # prefix directly, instead of any physical pathname.  (On failure and
        # subsequent restore, an outdated and orphaned temporary directory can be
        # safely removed.)
        with ops.device("CPU"):
            sharded_suffix = array_ops.where(
                string_ops.regex_full_match(file_prefix, "^s3://.*"),
                constant_op.constant(".part"),
                constant_op.constant("_temp_%s/part" % uuid.uuid4().hex))
            tmp_checkpoint_prefix = string_ops.string_join(
                [file_prefix, sharded_suffix])

        num_shards = len(self._single_device_savers)
        sharded_saves = []
        sharded_prefixes = []
        num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
        last_device = None
        for shard, (device, saver) in enumerate(
                sorted(self._single_device_savers.items())):
            last_device = device
            with ops.device(saveable_object_util.set_cpu0(device)):
                shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                                num_shards_tensor)
            sharded_prefixes.append(shard_prefix)
            with ops.device(device):
                # _SingleDeviceSaver will use the CPU device when necessary, but initial
                # read operations should be placed on the SaveableObject's device.
                sharded_saves.append(saver.save(shard_prefix, options))

        with ops.control_dependencies(sharded_saves):
            # Merge on the io_device if specified, otherwise co-locates the merge op
            # with the last device used.
            merge_device = (options.experimental_io_device
                            or saveable_object_util.set_cpu0(last_device))
            with ops.device(merge_device):
                # V2 format write path consists of a metadata merge step.  Once merged,
                # attempts to delete the temporary directory, "<user-fed prefix>_temp".
                return gen_io_ops.merge_v2_checkpoints(sharded_prefixes,
                                                       file_prefix,
                                                       delete_old_dirs=True)
Example #38
 def reduce_func(key, dataset):
     shard_filename = string_ops.string_join(
         [filename, string_ops.as_string(key)])
     writer = writers.TFRecordWriter(shard_filename)
     writer.write(dataset.map(lambda _, x: x))
     return dataset_ops.Dataset.from_tensors(shard_filename)
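A sketch of how a reduce_func like the one above is typically wired into a sharding pipeline, following the pattern in the TFRecordWriter documentation; `NUM_SHARDS`, `filename`, and the input records here are hypothetical.

import tensorflow as tf

NUM_SHARDS = 4
filename = "/tmp/shard"
# Hypothetical dataset of already-serialized records.
dataset = tf.data.Dataset.from_tensor_slices([b"record-%d" % i for i in range(20)])

def reduce_func(key, dataset):
    # Same shape as the example above, via the public API.
    shard_filename = tf.strings.join([filename, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(shard_filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(shard_filename)

# Tag each record with its index, bucket by shard, and reduce each bucket.
dataset = dataset.enumerate()
dataset = dataset.apply(tf.data.experimental.group_by_window(
    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max))
shard_files = list(dataset.as_numpy_iterator())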
Example #39
def registered_saver_filename(filename_tensor, saver_name):
  return string_ops.string_join(
      [filename_tensor, constant_op.constant(f"-{saver_name}")])
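A quick eager sketch (hypothetical values) of what the helper above produces; with no separator, string_join simply concatenates, so the saver name is appended to the checkpoint filename with a leading dash.

# Assuming registered_saver_filename and the internal imports above are in scope.
path = registered_saver_filename(constant_op.constant("/tmp/ckpt"), "my_saver")
print(path.numpy())  # b'/tmp/ckpt-my_saver'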
Example #40
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None,
                    max_request_pipelining_per_worker=1,
                    data_transfer_protocol=None,
                    target_workers="AUTO"):
    """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with the
  tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. Currently `element_spec` must be explicitly
  specified, and match the dataset registered under `dataset_id`. `element_spec`
  defaults to `None` so that in the future we can support automatically
  discovering the `element_spec` by querying the tf.data service.

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation.
  See the documentation for `tf.data.experimental.service.distribute` for more
  detail about how `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Can be either "parallel_epochs" to have each
      tf.data worker process a copy of the dataset, or "distributed_epoch" to
      split a single iteration of the dataset across all the workers.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
      elements produced by the dataset. This argument is only required inside a
      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
      for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible
      for multiple datasets to share the same job. The default behavior is that
      the dataset creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from `0`
      to `num_consumers`. Must be specified alongside `num_consumers`. When
      specified, consumers will read from the job in a strict round-robin order,
      instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead of
      the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out of
      sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
      `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using gRPC.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
      tf.data service workers. `"AUTO"` works well for most cases, while users
      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
      data copy if every TF worker colocates with a tf.data service worker.
      Defaults to `"AUTO"`.

    EASL:
    max_request_pipelining_per_worker: (Optional.) Added by EASL to increase the
      number of parallel requests a client can send to a single worker. Defaults
      to 1, preserving the original behaviour.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
    _check_job_name(job_name)
    if job_name is not None:
        job_name = string_ops.string_join(
            ["dataset_id=",
             string_ops.as_string(dataset_id), job_name], "/")

    return _from_dataset_id(
        processing_mode=processing_mode,
        service=service,
        dataset_id=dataset_id,
        element_spec=element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        max_request_pipelining_per_worker=max_request_pipelining_per_worker,
        data_transfer_protocol=data_transfer_protocol,
        target_workers=target_workers)
Example #41
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pickup the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()
Example #42
  def testStateSaverWithTwoSimpleSteps(self):
    with self.test_session() as sess:
      batch_size_value = 2
      batch_size = constant_op.constant(batch_size_value)
      num_unroll = 2
      length = 3
      key = string_ops.string_join([
          "key_", string_ops.as_string(
              math_ops.cast(10000 * random_ops.random_uniform(()),
                            dtypes.int32))
      ])
      padded_length = 4
      sequences = {
          "seq1": np.random.rand(padded_length, 5),
          "seq2": np.random.rand(padded_length, 4, 2)
      }
      context = {"context1": [3, 4]}
      initial_states = {
          "state1": np.random.rand(6, 7),
          "state2": np.random.rand(8)
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states,
          capacity=100)

      initial_key_value_0, _ = sess.run((key, state_saver.prefetch_op))
      initial_key_value_1, _ = sess.run((key, state_saver.prefetch_op))

      initial_key_value_0 = initial_key_value_0.decode("ascii")
      initial_key_value_1 = initial_key_value_1.decode("ascii")

      # Step 1
      next_batch = state_saver.next_batch
      (key_value, next_key_value, seq1_value, seq2_value, context1_value,
       state1_value, state2_value, length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.context["context1"],
            next_batch.state("state1"), next_batch.state("state2"),
            next_batch.length,
            next_batch.save_state("state1", next_batch.state("state1") + 1),
            next_batch.save_state("state2", next_batch.state("state2") - 1)))

      expected_first_keys = set(
          ("00000_of_00002:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))
      expected_second_keys = set(
          ("00001_of_00002:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))
      expected_final_keys = set(
          ("STOP:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))

      self.assertEqual(set(key_value), expected_first_keys)
      self.assertEqual(set(next_key_value), expected_second_keys)
      self.assertAllEqual(context1_value,
                          np.tile(context["context1"], (batch_size_value, 1)))
      self.assertAllEqual(seq1_value,
                          np.tile(sequences["seq1"][np.newaxis, 0:2, :],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(seq2_value,
                          np.tile(sequences["seq2"][np.newaxis, 0:2, :, :],
                                  (batch_size_value, 1, 1, 1)))
      self.assertAllEqual(state1_value,
                          np.tile(initial_states["state1"],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(state2_value,
                          np.tile(initial_states["state2"],
                                  (batch_size_value, 1)))
      self.assertAllEqual(length_value, [2, 2])

      # Step 2
      (key_value, next_key_value, seq1_value, seq2_value, context1_value,
       state1_value, state2_value, length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.context["context1"],
            next_batch.state("state1"), next_batch.state("state2"),
            next_batch.length,
            next_batch.save_state("state1", next_batch.state("state1") + 1),
            next_batch.save_state("state2", next_batch.state("state2") - 1)))

      self.assertEqual(set(key_value), expected_second_keys)
      self.assertEqual(set(next_key_value), expected_final_keys)
      self.assertAllEqual(context1_value,
                          np.tile(context["context1"], (batch_size_value, 1)))
      self.assertAllEqual(seq1_value,
                          np.tile(sequences["seq1"][np.newaxis, 2:4, :],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(seq2_value,
                          np.tile(sequences["seq2"][np.newaxis, 2:4, :, :],
                                  (batch_size_value, 1, 1, 1)))
      self.assertAllEqual(state1_value, 1 + np.tile(initial_states["state1"],
                                                    (batch_size_value, 1, 1)))
      self.assertAllEqual(state2_value, -1 + np.tile(initial_states["state2"],
                                                     (batch_size_value, 1)))
      self.assertAllEqual(length_value, [1, 1])

      # Finished.  Let's make sure there's nothing left in the barrier.
      self.assertEqual(0, state_saver.barrier.ready_size().eval())