def _build_input_pipeline(self, name, num_outputs):
  """Builds a shuffled, prefetched range pipeline with a saveable iterator.

  Args:
    name: Name scope under which all pipeline ops are created.
    num_outputs: Upper bound passed to `Dataset.range`, i.e. how many elements
      the pipeline yields.

  Returns:
    A `(initializer, next_element)` pair for the pipeline's iterator.
  """
  with ops.name_scope(name):
    # Shuffle through a 10-element buffer without reshuffling between passes,
    # then prefetch up to 10 elements.
    shuffled = dataset_ops.Dataset.range(num_outputs).shuffle(
        10, reshuffle_each_iteration=False)
    pipeline = shuffled.prefetch(10)
    it = pipeline.make_initializable_iterator()
    # Registering the saveable lets a Saver built later checkpoint the
    # iterator's position.
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(it))
    return it.initializer, it.get_next()
  def _build_graph(self, ds_fn):
    """Builds a graph iterating over `ds_fn()` with checkpointable state.

    Args:
      ds_fn: Zero-argument callable that returns the dataset to iterate.

    Returns:
      A tuple `(init_op, get_next, saver)`.
    """
    it = ds_fn().make_initializable_iterator()

    # Put the iterator's SaveableObject in SAVEABLE_OBJECTS so that the Saver
    # created below picks it up automatically.
    iterator_saveable = contrib_iterator_ops.make_saveable_from_iterator(it)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, iterator_saveable)

    initializer = it.initializer
    next_element = it.get_next()
    self._add_iterator_ops_to_collection(initializer, next_element)
    return initializer, next_element, saver_lib.Saver(allow_empty=True)
示例#3
0
  def _build_graph(self, ds_fn):
    """Creates an initializable, checkpointable iterator over `ds_fn()`.

    Args:
      ds_fn: Zero-argument callable that returns the dataset to iterate.

    Returns:
      A tuple `(init_op, get_next, saver)`; the saver is built with
      `allow_empty=True`.
    """
    dataset_iterator = ds_fn().make_initializable_iterator()
    # The Saver picks up anything in SAVEABLE_OBJECTS, including this
    # iterator's state.
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        contrib_iterator_ops.make_saveable_from_iterator(dataset_iterator))
    init = dataset_iterator.initializer
    element = dataset_iterator.get_next()
    self._add_iterator_ops_to_collection(init, element)
    checkpoint_saver = saver_lib.Saver(allow_empty=True)
    return init, element, checkpoint_saver
 def _build_graph_tensor_slices(self, components):
   """Builds a checkpointable iterator over the slices of `components`.

   Each flattened tensor of the next-element structure is also added to the
   "get_next" collection, so it can be looked up by collection name later
   (for example after re-importing a meta graph).

   Args:
     components: Nested structure of tensors/arrays handed to
       `Dataset.from_tensor_slices`.

   Returns:
     A tuple `(init_op, get_next)`.
   """
   dataset = dataset_ops.Dataset.from_tensor_slices(components)
   it = dataset.make_initializable_iterator()
   initializer = it.initializer
   next_element = it.get_next()
   ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                         iterator_ops.make_saveable_from_iterator(it))
   for tensor in nest.flatten(next_element):
     ops.add_to_collection("get_next", tensor)
   return initializer, next_element
 def _build_empty_graph(self, ds_fn, sparse_tensors=False):
   """Builds a dataset-less iterator shaped like `ds_fn()`, plus a Saver.

   The iterator comes from `Iterator.from_structure`, so it is not tied to a
   concrete dataset. Presumably its state is meant to be restored from a
   checkpoint before use (TODO confirm against callers).

   Args:
     ds_fn: Dataset factory; passed to `_get_output_types` and
       `_get_output_shapes` to derive the iterator's structure.
     sparse_tensors: If True, the next-element components are wrapped in a
       `SparseTensor`.

   Returns:
     A tuple `(get_next, saver)`.
   """
   it = iterator_ops.Iterator.from_structure(
       self._get_output_types(ds_fn), self._get_output_shapes(ds_fn))
   ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                         contrib_iterator_ops.make_saveable_from_iterator(it))
   next_element = (sparse_tensor.SparseTensor(*it.get_next())
                   if sparse_tensors else it.get_next())
   return next_element, saver_lib.Saver(allow_empty=True)
示例#6
0
 def _build_empty_graph(self, ds_fn, sparse_tensors=False):
     """Creates an unbound iterator matching `ds_fn()`'s structure and a Saver.

     Args:
         ds_fn: Dataset factory forwarded to `_get_output_types` and
             `_get_output_shapes`.
         sparse_tensors: When True, the iterator output is wrapped in a
             `SparseTensor`.

     Returns:
         A tuple `(get_next, saver)`.
     """
     output_types = self._get_output_types(ds_fn)
     output_shapes = self._get_output_shapes(ds_fn)
     unbound_iterator = iterator_ops.Iterator.from_structure(
         output_types, output_shapes)
     # Registering the saveable makes the iterator state part of any checkpoint
     # the Saver writes or restores.
     ops.add_to_collection(
         ops.GraphKeys.SAVEABLE_OBJECTS,
         contrib_iterator_ops.make_saveable_from_iterator(unbound_iterator))
     if not sparse_tensors:
         element = unbound_iterator.get_next()
     else:
         element = sparse_tensor.SparseTensor(*unbound_iterator.get_next())
     checkpoint_saver = saver_lib.Saver(allow_empty=True)
     return element, checkpoint_saver
 def _build_graph(start, stop):
   """Builds a checkpointable iterator over `Dataset.range(start, stop)`.

   Args:
     start: First value of the range.
     stop: Exclusive end of the range.

   Returns:
     A tuple `(init_op, get_next, saver)`.
   """
   it = dataset_ops.Dataset.range(start, stop).make_initializable_iterator()
   initializer = it.initializer
   next_element = it.get_next()
   # Expose both ops under the "iterator_ops" collection, init op first.
   for item in (initializer, next_element):
     ops.add_to_collection("iterator_ops", item)
   # Registering the iterator's SaveableObject in SAVEABLE_OBJECTS lets the
   # Saver pick it up automatically.
   ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                         contrib_iterator_ops.make_saveable_from_iterator(it))
   return initializer, next_element, saver_lib.Saver()
  def _build_graph(self, ds_fn, sparse_tensors=False):
    """Builds a checkpointable iterator graph over `ds_fn()`.

    Args:
      ds_fn: Zero-argument callable returning the dataset to iterate.
      sparse_tensors: If True, the next-element components are wrapped in a
        `SparseTensor`.

    Returns:
      A tuple `(init_op, get_next, saver)`.
    """
    it = ds_fn().make_initializable_iterator()
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                          contrib_iterator_ops.make_saveable_from_iterator(it))
    initializer = it.initializer
    next_element = (sparse_tensor.SparseTensor(*it.get_next())
                    if sparse_tensors else it.get_next())
    self._add_iterator_ops_to_collection(initializer, next_element, ds_fn,
                                         sparse_tensors)
    return initializer, next_element, saver_lib.Saver(allow_empty=True)
示例#9
0
    def _build_graph(self, ds_fn, sparse_tensors=False):
        """Creates an initializable iterator over `ds_fn()` with saveable state.

        Args:
            ds_fn: Zero-argument callable returning the dataset to iterate.
            sparse_tensors: When True, the iterator output is wrapped in a
                `SparseTensor`.

        Returns:
            A tuple `(init_op, get_next, saver)`.
        """
        dataset_iterator = ds_fn().make_initializable_iterator()
        # Make the iterator state visible to the Saver via SAVEABLE_OBJECTS.
        iterator_saveable = contrib_iterator_ops.make_saveable_from_iterator(
            dataset_iterator)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                              iterator_saveable)
        init = dataset_iterator.initializer
        if not sparse_tensors:
            element = dataset_iterator.get_next()
        else:
            element = sparse_tensor.SparseTensor(*dataset_iterator.get_next())
        self._add_iterator_ops_to_collection(init, element, ds_fn,
                                             sparse_tensors)
        checkpoint_saver = saver_lib.Saver(allow_empty=True)
        return init, element, checkpoint_saver
 def _build_graph(start, stop):
     """Returns `(init_op, get_next, saver)` for a saveable range iterator.

     Args:
         start: First value produced by `Dataset.range`.
         stop: Exclusive upper bound of the range.
     """
     range_dataset = dataset_ops.Dataset.range(start, stop)
     range_iterator = range_dataset.make_initializable_iterator()
     init = range_iterator.initializer
     element = range_iterator.get_next()
     ops.add_to_collection("iterator_ops", init)
     ops.add_to_collection("iterator_ops", element)
     # The Saver automatically includes everything in SAVEABLE_OBJECTS, so
     # adding the iterator's saveable makes its position checkpointable.
     iterator_saveable = contrib_iterator_ops.make_saveable_from_iterator(
         range_iterator)
     ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, iterator_saveable)
     checkpoint_saver = saver_lib.Saver()
     return init, element, checkpoint_saver
 def _build_graph(self, input_components, to_concatenate_components):
   """Builds a checkpointable iterator over two concatenated slice datasets.

   Args:
     input_components: Structure sliced into the leading dataset.
     to_concatenate_components: Structure sliced into the dataset appended
       after it.

   Returns:
     A tuple `(init_op, get_next)`.
   """
   first = dataset_ops.Dataset.from_tensor_slices(input_components)
   second = dataset_ops.Dataset.from_tensor_slices(to_concatenate_components)
   it = first.concatenate(second).make_initializable_iterator()
   initializer = it.initializer
   next_element = it.get_next()
   ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                         iterator_ops.make_saveable_from_iterator(it))
   # TODO(shivaniagrawal): Exporting the next-element tensors through the
   # "get_next" collection is non-intuitive; add support in meta_graph.
   for tensor in nest.flatten(next_element):
     ops.add_to_collection("get_next", tensor)
   return initializer, next_element
 def _build_graph(self, input_components, to_concatenate_components):
     """Returns `(init_op, get_next)` for a saveable concatenation iterator.

     Args:
         input_components: Structure sliced into the first dataset.
         to_concatenate_components: Structure sliced into the dataset that
             is concatenated onto the first.
     """
     head = dataset_ops.Dataset.from_tensor_slices(input_components)
     tail = dataset_ops.Dataset.from_tensor_slices(to_concatenate_components)
     concatenated = head.concatenate(tail)
     concat_iterator = concatenated.make_initializable_iterator()
     init = concat_iterator.initializer
     element = concat_iterator.get_next()
     iterator_saveable = iterator_ops.make_saveable_from_iterator(
         concat_iterator)
     ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, iterator_saveable)
     # TODO(shivaniagrawal): Publishing tensors via the "get_next" collection
     # is non-intuitive; add support in meta_graph.
     for flat_tensor in nest.flatten(element):
         ops.add_to_collection("get_next", flat_tensor)
     return init, element
 def _build_graph(self,
                  test_filenames,
                  compression_type=None,
                  build_saveable=True):
   """Builds a `TextLineDataset` iterator graph for save/restore tests.

   Args:
     test_filenames: File name(s) passed to `readers.TextLineDataset`.
     compression_type: Optional compression type forwarded to the reader.
     build_saveable: If True, the iterator's SaveableObject is registered in
       SAVEABLE_OBJECTS; if False, nothing is registered.

   Returns:
     A tuple `(init_op, get_next, saver)`. The saver is built with
     `allow_empty=True`.
   """
   it = readers.TextLineDataset(
       test_filenames, compression_type=compression_type,
       buffer_size=10).make_initializable_iterator()
   if build_saveable:
     ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                           contrib_iterator_ops.make_saveable_from_iterator(it))
   initializer = it.initializer
   next_element = it.get_next()
   # Expose both ops under "iterator_ops", init op first.
   for item in (initializer, next_element):
     ops.add_to_collection("iterator_ops", item)
   return initializer, next_element, saver_lib.Saver(allow_empty=True)
示例#14
0
    def testMultipleIterators(self):
        """Saves/restores two one-shot iterators that share one dataset.

        For every combination of `reshuffle_each_iteration` and buffer size,
        both iterators are checkpointed before any element is read. The test
        then records `num_outputs` joint steps, restores, records the same
        number of steps again, and compares the two runs with `self.match`.
        """
        range_limit = 5
        num_repeats = 2
        num_outputs = range_limit * num_repeats
        buffer_sizes = [1, 3, 5, 8, 10]

        for reshuffle_each_iteration in [True, False]:
            for buffer_size in buffer_sizes:
                with ops.Graph().as_default() as g:
                    ds = self._build_shuffle_dataset(
                        range_limit=range_limit,
                        num_repeats=num_repeats,
                        buffer_size=buffer_size,
                        # Iterator seeds are generated non-deterministically.
                        seed=None,
                        reshuffle_each_iteration=reshuffle_each_iteration)
                    iterators = [ds.make_one_shot_iterator() for _ in range(2)]
                    next_elements = [it.get_next() for it in iterators]
                    for it in iterators:
                        ops.add_to_collection(
                            ops.GraphKeys.SAVEABLE_OBJECTS,
                            contrib_iterator_ops.make_saveable_from_iterator(it))
                    saver = saver_lib.Saver(allow_empty=True)
                    with self.test_session(graph=g) as sess:
                        # Checkpoint before consuming anything; the outputs
                        # read after restoring must match the first run.
                        self._save(sess, saver)
                        expected = [
                            sess.run(next_elements) for _ in range(num_outputs)
                        ]
                        self._restore(saver, sess)
                        actual = [
                            sess.run(next_elements) for _ in range(num_outputs)
                        ]
                        self.match(expected, actual)
 def _build_graph(self,
                  range_limit=10,
                  num_repeats=5,
                  buffer_size=5,
                  seed=None,
                  reshuffle_each_iteration=None,
                  build_saveable=True):
   """Builds a range -> shuffle -> repeat iterator graph for save/restore tests.

   Args:
     range_limit: Upper bound passed to `Dataset.range`.
     num_repeats: Argument to `repeat`.
     buffer_size: Shuffle buffer size.
     seed: Optional shuffle seed.
     reshuffle_each_iteration: Forwarded to `shuffle`.
     build_saveable: If True, the iterator's SaveableObject is registered in
       SAVEABLE_OBJECTS.

   Returns:
     A tuple `(init_op, get_next, saver)`.
   """
   shuffled = dataset_ops.Dataset.range(range_limit).shuffle(
       buffer_size, seed=seed,
       reshuffle_each_iteration=reshuffle_each_iteration)
   it = shuffled.repeat(num_repeats).make_initializable_iterator()
   if build_saveable:
     ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                           contrib_iterator_ops.make_saveable_from_iterator(it))
   initializer = it.initializer
   next_element = it.get_next()
   for item in (initializer, next_element):
     ops.add_to_collection("iterator_ops", item)
   return initializer, next_element, saver_lib.Saver(allow_empty=True)
示例#16
0
 def _build_graph(self,
                  range_limit=10,
                  num_repeats=5,
                  buffer_size=5,
                  seed=None,
                  reshuffle_each_iteration=None,
                  build_saveable=True):
     """Returns `(init_op, get_next, saver)` for a shuffled, repeated range.

     Args:
         range_limit: Number of distinct values from `Dataset.range`.
         num_repeats: How many times `repeat` replays the shuffled range.
         buffer_size: Shuffle buffer size.
         seed: Optional shuffle seed.
         reshuffle_each_iteration: Passed straight through to `shuffle`.
         build_saveable: When False, the iterator is not registered in
             SAVEABLE_OBJECTS.
     """
     base = dataset_ops.Dataset.range(range_limit)
     shuffled = base.shuffle(
         buffer_size,
         seed=seed,
         reshuffle_each_iteration=reshuffle_each_iteration)
     repeated = shuffled.repeat(num_repeats)
     shuffle_iterator = repeated.make_initializable_iterator()
     if build_saveable:
         iterator_saveable = contrib_iterator_ops.make_saveable_from_iterator(
             shuffle_iterator)
         ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                               iterator_saveable)
     init = shuffle_iterator.initializer
     element = shuffle_iterator.get_next()
     ops.add_to_collection("iterator_ops", init)
     ops.add_to_collection("iterator_ops", element)
     checkpoint_saver = saver_lib.Saver(allow_empty=True)
     return init, element, checkpoint_saver
  def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop):
    """Checkpoints a `from_tensors` iterator mid-stream and resumes it.

    In a first graph, steps `[start, break_range)` are read and checked, and
    the session is saved at global step 0. A second graph is rebuilt with
    `import_meta_graph` and restored. It then reads steps
    `[break_range, stop)` and expects `OutOfRangeError` on the next read.

    Args:
      start: Lower bound of the pre-checkpoint read loop.
      break_range: Upper bound of the pre-checkpoint loop and lower bound of
        the post-restore loop.
      stop: Upper bound (exclusive) of the post-restore read loop.
    """
    path = self._iterator_checkpoint_prefix()
    step = 0
    meta_filename = path + "-%d.meta" % step

    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))

    def _check(result):
      # Compare each produced component with its source array.
      for expected_value, actual_value in zip(components, result):
        self.assertAllEqual(expected_value, actual_value)

    # Phase 1: build the pipeline, consume some elements, then checkpoint.
    with ops.Graph().as_default() as g:
      it = dataset_ops.Dataset.from_tensors(
          components).make_initializable_iterator()
      init_op = it.initializer
      next_element = it.get_next()
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS,
                            iterator_ops.make_saveable_from_iterator(it))
      # Publish the flattened next-element tensors so phase 2 can find them
      # after importing the meta graph.
      for t in nest.flatten(next_element):
        ops.add_to_collection("get_next", t)
      saver = saver_lib.Saver()
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        for _ in range(start, break_range):
          _check(sess.run(next_element))
        saver.save(sess, path, step)

    # Phase 2: re-import the saved meta graph, restore, and drain the rest.
    with ops.Graph().as_default() as g:
      saver = saver_lib.import_meta_graph(meta_filename)
      with self.test_session(graph=g) as sess:
        next_element = nest.pack_sequence_as(("a", "b", "c"),
                                             ops.get_collection("get_next"))
        saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
        for _ in range(break_range, stop):
          _check(sess.run(next_element))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)
  def testMultipleIterators(self):
    """Saves/restores two one-shot iterators built from the same dataset.

    For each `reshuffle_each_iteration` setting and buffer size, both
    iterators are checkpointed before reading. The test then collects
    `num_outputs` joint steps, restores, collects the same number of steps
    again, and compares the two runs with `self.match`.
    """
    range_limit = 5
    num_repeats = 2
    num_outputs = range_limit * num_repeats

    for reshuffle_each_iteration in [True, False]:
      for buffer_size in [1, 3, 5, 8, 10]:
        with ops.Graph().as_default() as g:
          ds = self._build_shuffle_dataset(
              range_limit=range_limit,
              num_repeats=num_repeats,
              buffer_size=buffer_size,
              # Iterator seeds are generated non-deterministically.
              seed=None,
              reshuffle_each_iteration=reshuffle_each_iteration)
          iterators = [ds.make_one_shot_iterator() for _ in range(2)]
          next_elements = [it.get_next() for it in iterators]
          for it in iterators:
            ops.add_to_collection(
                ops.GraphKeys.SAVEABLE_OBJECTS,
                contrib_iterator_ops.make_saveable_from_iterator(it))
          saver = saver_lib.Saver(allow_empty=True)
          with self.test_session(graph=g) as sess:
            # Checkpoint up front; the post-restore outputs must match the
            # first run.
            self._save(sess, saver)
            first_run = [sess.run(next_elements) for _ in range(num_outputs)]
            self._restore(saver, sess)
            second_run = [sess.run(next_elements) for _ in range(num_outputs)]
            self.match(first_run, second_run)