def testGetGraphFromValidInputs(self):
  """Inputs built in one graph resolve to that graph from any default graph."""
  source_graph = ops.Graph()
  with source_graph.as_default():
    inputs = [constant_op.constant(0.0), constant_op.constant(1.0)]
    # Checked first with no explicit graph, then with the matching graph.
    for explicit_graph in ((), (source_graph,)):
      self.assertIs(source_graph,
                    ops_lib.get_graph_from_inputs(inputs, *explicit_graph))
  # The result must not depend on whichever graph is currently the default.
  with ops.Graph().as_default():
    for explicit_graph in ((), (source_graph,)):
      self.assertIs(source_graph,
                    ops_lib.get_graph_from_inputs(inputs, *explicit_graph))
def testGetGraphFromInvalidInputs(self):
  """Mismatched or mixed-graph inputs are rejected with a ValueError."""
  first_graph = ops.Graph()
  with first_graph.as_default():
    tensors = [constant_op.constant(0.0), constant_op.constant(1.0)]

  other_graph = ops.Graph()
  # All inputs live in first_graph, so naming other_graph must fail.
  with self.assertRaisesRegexp(ValueError, "not from the passed-in graph"):
    ops_lib.get_graph_from_inputs(tensors, other_graph)

  # Mix in a tensor from other_graph so the inputs span two graphs.
  with other_graph.as_default():
    tensors.append(constant_op.constant(2.0))
  with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
    ops_lib.get_graph_from_inputs(tensors)

  # Naming either graph explicitly is also rejected for mixed inputs.
  for candidate in (first_graph, other_graph):
    with self.assertRaisesRegexp(ValueError, "not from the passed-in graph"):
      ops_lib.get_graph_from_inputs(tensors, candidate)
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables. Shared and local resources, local variables and tables
  are initialized in either case.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      session.run(
          resources.initialize_resources(resources.shared_resources() +
                                         resources.local_resources()))
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.global_variables_initializer())
      session.run(variables.local_variables_initializer())
      session.run(data_flow_ops.initialize_all_tables())
      coord = coordinator.Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        for f in feed_dicts:
          yield session.run(output_dict, f)
      finally:
        # Reached on exhaustion, on error, or when the caller closes the
        # generator early; stop and wait for the queue-runner threads.
        coord.request_stop()
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Lazily evaluate `output_dict` once per feed in `feed_dicts`.

  Variables are restored from `restore_checkpoint_path` when one is given;
  otherwise global variables are initialized. Shared/local resources, local
  variables and tables are initialized in either case.

  Args:
    output_dict: `dict` of string name -> `Tensor`; every tensor must belong
      to the same graph.
    feed_dicts: Iterable of feed `dict`s, one per evaluation.
    restore_checkpoint_path: Optional path of a checkpoint to restore from.

  Yields:
    For each feed in `feed_dicts`, a `dict` keyed like `output_dict` holding
    the evaluated value of each corresponding `Tensor`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as default_graph, tf_session.Session('') as session:
    all_resources = (resources.shared_resources() +
                     resources.local_resources())
    session.run(resources.initialize_resources(all_resources))

    if restore_checkpoint_path:
      _restore_from_checkpoint(session, default_graph, restore_checkpoint_path)
    else:
      session.run(variables.global_variables_initializer())

    # These initializers run whether or not a checkpoint was restored.
    for make_init_op in (variables.local_variables_initializer,
                         data_flow_ops.initialize_all_tables):
      session.run(make_init_op())

    stop_coord = coordinator.Coordinator()
    runner_threads = None
    try:
      runner_threads = queue_runner.start_queue_runners(
          session, coord=stop_coord)
      for feed in feed_dicts:
        yield session.run(output_dict, feed)
    finally:
      # Runs on exhaustion, on error, or when the consumer closes the
      # generator early.
      stop_coord.request_stop()
      if runner_threads:
        stop_coord.join(runner_threads, stop_grace_period_secs=120)
def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables. Local variables and tables are initialized in either
  case.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Returns:
    A list of dicts of values read from `output_dict` tensors, one item in the
    list for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.initialize_all_variables())
      session.run(variables.initialize_local_variables())
      session.run(data_flow_ops.initialize_all_tables())
      coord = Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        return [_run_dict(session, output_dict, f) for f in feed_dicts]
      finally:
        coord.request_stop()
        # Previously the queue-runner threads were never joined and could
        # still be running when the enclosing `with` closed the session.
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
def testGetGraphFromEmptyInputs(self):
  """With no inputs, the current default graph is returned."""
  fresh_graph = ops.Graph()
  with fresh_graph.as_default():
    no_inputs = []
    self.assertIs(fresh_graph, ops_lib.get_graph_from_inputs(no_inputs))