示例#1
0
 def testGetGraphFromValidInputs(self):
   """Tensors from a single graph resolve to that graph in any context."""
   source_graph = ops.Graph()
   with source_graph.as_default():
     values = [constant_op.constant(x) for x in (0.0, 1.0)]

   def _check_resolution():
     # The inferred lookup and the explicit-graph lookup must both agree.
     self.assertIs(source_graph, ops_lib.get_graph_from_inputs(values))
     self.assertIs(source_graph,
                   ops_lib.get_graph_from_inputs(values, source_graph))

   _check_resolution()
   # A different ambient default graph must not change the result.
   with ops.Graph().as_default():
     _check_resolution()
示例#2
0
 def testGetGraphFromValidInputs(self):
     """get_graph_from_inputs returns the graph the inputs were built in."""
     owner = ops.Graph()
     with owner.as_default():
         values = [constant_op.constant(0.0), constant_op.constant(1.0)]

     def resolve(*graph_arg):
         # Optionally forwards an explicit graph as the second argument.
         return ops_lib.get_graph_from_inputs(values, *graph_arg)

     self.assertIs(owner, resolve())
     self.assertIs(owner, resolve(owner))
     # Same expectations while an unrelated graph is the default.
     with ops.Graph().as_default():
         self.assertIs(owner, resolve())
         self.assertIs(owner, resolve(owner))
示例#3
0
 def testGetGraphFromInvalidInputs(self):
   """Inputs that disagree with the requested graph raise ValueError."""
   g0 = ops.Graph()
   with g0.as_default():
     values = [constant_op.constant(0.0), constant_op.constant(1.0)]
   g1 = ops.Graph()
   # Every input lives in g0, so asking for g1 explicitly must fail.
   with self.assertRaisesRegex(ValueError, "not from the passed-in graph"):
     ops_lib.get_graph_from_inputs(values, g1)
   # Add a g1 tensor so the inputs now span two graphs.
   with g1.as_default():
     values.append(constant_op.constant(2.0))
   with self.assertRaisesRegex(ValueError, "must be from the same graph"):
     ops_lib.get_graph_from_inputs(values)
   # With mixed inputs, neither graph may be passed explicitly.
   for graph in (g0, g1):
     with self.assertRaisesRegex(ValueError, "not from the passed-in graph"):
       ops_lib.get_graph_from_inputs(values, graph)
示例#4
0
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
    """Run `output_dict` tensors with each input in `feed_dicts`.

    If `restore_checkpoint_path` is supplied, restore from checkpoint.
    Otherwise, init all variables.

    Args:
      output_dict: A `dict` mapping string names to `Tensor` objects to run.
        Tensors must all be from the same graph.
      feed_dicts: Iterable of `dict` objects of input values to feed.
      restore_checkpoint_path: A string containing the path to a checkpoint to
        restore.

    Yields:
      A sequence of dicts of values read from `output_dict` tensors, one item
      yielded for each item in `feed_dicts`. Keys are the same as
      `output_dict`, values are the results read from the corresponding
      `Tensor` in `output_dict`.

    Raises:
      ValueError: if `output_dict` or `feed_dicts` is None or empty. Because
        this is a generator, the error surfaces on the first `next()` call.
    """
    # Wrap the value in a 1-tuple so that an empty tuple is formatted instead
    # of being unpacked by `%`. Unpacking would raise TypeError, not ValueError.
    if not output_dict:
        raise ValueError('output_dict is invalid: %s.' % (output_dict,))
    if not feed_dicts:
        raise ValueError('feed_dicts is invalid: %s.' % (feed_dicts,))

    graph = contrib_ops.get_graph_from_inputs(output_dict.values())
    with graph.as_default() as g:
        with tf_session.Session('') as session:
            session.run(
                resources.initialize_resources(resources.shared_resources() +
                                               resources.local_resources()))
            if restore_checkpoint_path:
                _restore_from_checkpoint(session, g, restore_checkpoint_path)
            else:
                session.run(variables.global_variables_initializer())
            session.run(variables.local_variables_initializer())
            session.run(data_flow_ops.initialize_all_tables())
            coord = coordinator.Coordinator()
            threads = None
            try:
                threads = queue_runner.start_queue_runners(session,
                                                           coord=coord)
                for f in feed_dicts:
                    yield session.run(output_dict, f)
            finally:
                # Stop the queue runners and wait (bounded) for their threads
                # before the enclosing `with` closes the session.
                coord.request_stop()
                if threads:
                    coord.join(threads, stop_grace_period_secs=120)
示例#5
0
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty. Because this
      is a generator, the error surfaces on the first `next()` call.
  """
  # Wrap the value in a 1-tuple so that an empty tuple is formatted instead of
  # being unpacked by `%`. Unpacking would raise TypeError, not ValueError.
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % (output_dict,))
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % (feed_dicts,))

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      session.run(
          resources.initialize_resources(resources.shared_resources() +
                                         resources.local_resources()))
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.global_variables_initializer())
      session.run(variables.local_variables_initializer())
      session.run(data_flow_ops.initialize_all_tables())
      coord = coordinator.Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        for f in feed_dicts:
          yield session.run(output_dict, f)
      finally:
        # Stop the queue runners and wait (bounded) for their threads before
        # the enclosing `with` closes the session.
        coord.request_stop()
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
示例#6
0
 def testGetGraphFromInvalidInputs(self):
     """Inputs that disagree with the requested graph raise ValueError."""
     g0 = ops.Graph()
     with g0.as_default():
         values = [constant_op.constant(0.0), constant_op.constant(1.0)]
     g1 = ops.Graph()
     # Every input lives in g0, so asking for g1 explicitly must fail.
     with self.assertRaisesRegex(ValueError,
                                 "not from the passed-in graph"):
         ops_lib.get_graph_from_inputs(values, g1)
     # Add a g1 tensor so the inputs now span two graphs.
     with g1.as_default():
         values.append(constant_op.constant(2.0))
     with self.assertRaisesRegex(ValueError,
                                 "must be from the same graph"):
         ops_lib.get_graph_from_inputs(values)
     # With mixed inputs, neither graph may be passed explicitly.
     for graph in (g0, g1):
         with self.assertRaisesRegex(ValueError,
                                     "not from the passed-in graph"):
             ops_lib.get_graph_from_inputs(values, graph)
def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Returns:
    A list of dicts of values read from `output_dict` tensors, one item in the
    list for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  # Wrap the value in a 1-tuple so that an empty tuple is formatted instead of
  # being unpacked by `%`. Unpacking would raise TypeError, not ValueError.
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % (output_dict,))
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % (feed_dicts,))

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())

  with graph.as_default() as g:
    with tf_session.Session('') as session:
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.initialize_all_variables())
      session.run(variables.initialize_local_variables())
      session.run(data_flow_ops.initialize_all_tables())
      coord = Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        return [_run_dict(session, output_dict, f) for f in feed_dicts]
      finally:
        coord.request_stop()
        # Join the queue-runner threads (bounded wait) so none of them outlive
        # the session, which the enclosing `with` closes next.
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
示例#8
0
def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Returns:
    A list of dicts of values read from `output_dict` tensors, one item in the
    list for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  # Wrap the value in a 1-tuple so that an empty tuple is formatted instead of
  # being unpacked by `%`. Unpacking would raise TypeError, not ValueError.
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % (output_dict,))
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % (feed_dicts,))

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())

  with graph.as_default() as g:
    with tf_session.Session('') as session:
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.initialize_all_variables())
      session.run(variables.initialize_local_variables())
      session.run(data_flow_ops.initialize_all_tables())
      coord = Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        return [_run_dict(session, output_dict, f) for f in feed_dicts]
      finally:
        coord.request_stop()
        # Join the queue-runner threads (bounded wait) so none of them outlive
        # the session, which the enclosing `with` closes next.
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
示例#9
0
 def testGetGraphFromEmptyInputs(self):
   """With no inputs, the current default graph is returned."""
   with ops.Graph().as_default() as expected:
     resolved = ops_lib.get_graph_from_inputs([])
     self.assertIs(expected, resolved)
示例#10
0
 def testGetGraphFromEmptyInputs(self):
     """An empty input list falls back to the active default graph."""
     no_inputs = []
     with ops.Graph().as_default() as active_graph:
         self.assertIs(active_graph, ops_lib.get_graph_from_inputs(no_inputs))