Code example #1
  def testExportWithRandomSeeds(self):
    """Test the effect of setting random seeds on export."""
    params = model_registry.GetParams('test.LinearModelParams', 'Test')
    # Default -- use random_seed = None.
    inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
        params, subgraph_filter=['default'])
    pred = predictor.Predictor(inference_graph)
    [no_op_seed_1] = pred.Run(['output'], input=3)
    [no_op_seed_2] = pred.Run(['output'], input=3)
    self.assertNotEqual(no_op_seed_1, no_op_seed_2)
    pred = predictor.Predictor(inference_graph)
    [no_op_seed_3] = pred.Run(['output'], input=3)
    self.assertNotEqual(no_op_seed_1, no_op_seed_3)

    # Use a fixed random_seed.
    inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
        params, subgraph_filter=['default'], random_seed=1234)
    pred = predictor.Predictor(inference_graph)
    [fixed_op_seed_1] = pred.Run(['output'], input=3)
    [fixed_op_seed_2] = pred.Run(['output'], input=3)
    self.assertEqual(fixed_op_seed_1, fixed_op_seed_2)
    pred = predictor.Predictor(inference_graph)
    [fixed_op_seed_3] = pred.Run(['output'], input=3)
    self.assertEqual(fixed_op_seed_1, fixed_op_seed_3)

    # A different seed gives different results.
    inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
        params, subgraph_filter=['default'], random_seed=1235)
    pred = predictor.Predictor(inference_graph)
    [fixed_op_seed_4] = pred.Run(['output'], input=3)
    self.assertNotEqual(fixed_op_seed_1, fixed_op_seed_4)
Code example #2
  def testInvalidFetchWithoutValidateFetchesReturnsNone(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    fetch1, nonexistent = pred.Run(['fetch1', 'nonexistent'],
                                   feed1=[12345],
                                   validate_fetches=False)
    self.assertEqual(12345, fetch1)
    self.assertIsNone(nonexistent)
Code example #3
  def testPredictorNoLoadGraphDefFromInferenceGraph(self):
    p = base_model.SingleTaskModel.Params(
        DummyModel.Params().Set(name='test'))
    p.input = base_input_generator.BaseInputGenerator.Params().Set(
        name='test')
    pred = predictor.Predictor(p.Instantiate().GetTask().Inference(),
                               load_graph_def_from_inference_graph=False)
    fetch1 = pred.Run('fetch1', feed1=[12345])
    self.assertEqual(12345, fetch1)
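
Most of the remaining snippets call self._testInferenceGraph(), a test fixture that is not shown in this listing. Below is a minimal sketch of what such a helper might look like, pieced together from the params construction in code example #3 and the Export call in code example #1; the single-argument Export call (no subgraph_filter or random_seed) is an assumption, not taken verbatim from the source.

  def _testInferenceGraph(self):
    # Hypothetical fixture: build params for the dummy test model (as in code
    # example #3) and export an inference graph for it (as in code example #1).
    # The returned proto is what predictor.Predictor consumes.
    p = base_model.SingleTaskModel.Params(
        DummyModel.Params().Set(name='test'))
    p.input = base_input_generator.BaseInputGenerator.Params().Set(
        name='test')
    return inference_graph_exporter.InferenceGraphExporter.Export(p)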
Code example #4
  def __init__(self,
               checkpoint,
               output_dir,
               inference_graph=None,
               inference_subgraph_name='',
               device_type='cpu',
               output_num_shards=1,
               output_shard_id=0,
               max_inputs=0,
               input_id_filter=None,
               tf_master='local',
               inference_threads=1,
               batch_size=64,
               prediction_step_interval=3000):
    """Constructor.

    Args:
      checkpoint: Either a checkpoint file to load, or a directory containing
        multiple checkpoints, where the latest checkpoint will be loaded.
      output_dir: Output directory. If `checkpoint` is a directory, a
        subdirectory will be created for each checkpoint evaluated.
      inference_graph: Path to an inference graph. If not specified, will be
        inferred from the checkpoint path.
      inference_subgraph_name: The name of the inference subgraph to use.
        Defaults to the default subgraph.
      device_type: Device type, either cpu, gpu, or tpu.
      output_num_shards: Each replica generates one shard of output according to
        `output_shard_id`.
      output_shard_id: The output shard id in range `[0, output_num_shards -
        1]`.
      max_inputs: Only process the first n inputs. 0 means process all inputs.
      input_id_filter: If not empty, only process the input ids in the given
        list.
      tf_master: tf_master for predictor session.
      inference_threads: Number of inference threads.
      batch_size: Batch size.
      prediction_step_interval: Number of steps between outputs. Only meaningful
        if `checkpoint` is a directory.
    """
    self._checkpoint = checkpoint
    self._output_dir = output_dir
    self._output_num_shards = output_num_shards
    self._output_shard_id = output_shard_id
    self._max_inputs = max_inputs
    input_id_filter = input_id_filter or []
    self._input_id_filter = [str(x) for x in input_id_filter]
    self._batch_size = batch_size
    self._prediction_step_interval = prediction_step_interval

    if device_type == 'tpu' and FLAGS.xla_device != 'tpu':
      raise ValueError('xla_device=tpu should be set with device_type=tpu!')

    while True:
      if tf.gfile.IsDirectory(self._checkpoint):
        if tf.train.latest_checkpoint(self._checkpoint):
          break
      elif tf.gfile.Exists(self._checkpoint + '.index'):
        break
      tf.logging.log_first_n(tf.logging.INFO,
                             'Waiting for checkpoint to be available.', 1)
      time.sleep(_RETRY_SLEEP_SECONDS)

    # Use saved inference graph.
    if inference_graph:
      self._inference_graph = inference_graph
    else:
      checkpoint_dir = self._checkpoint
      if not tf.gfile.IsDirectory(checkpoint_dir):
        checkpoint_dir = os.path.dirname(checkpoint_dir)
      logdir = os.path.dirname(checkpoint_dir)
      inference_graph_filename = 'inference.pbtxt'
      if device_type == 'tpu':
        inference_graph_filename = 'inference_tpu.pbtxt'
      self._inference_graph = os.path.join(logdir, 'inference_graphs',
                                           inference_graph_filename)
    self._predictor = predictor.Predictor(
        inference_graph=self._inference_graph,
        subgraph_name=inference_subgraph_name,
        device_type=device_type,
        tf_master=tf_master)
    self._threadpool = concurrent.futures.ThreadPoolExecutor(
        inference_threads)
    self._locks = [threading.Lock() for _ in range(inference_threads)]
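
When no inference_graph argument is given, the constructor above derives the graph path from the checkpoint location: it walks up from the checkpoint (or checkpoint directory) to the log directory and looks for inference_graphs/inference.pbtxt, or inference_graphs/inference_tpu.pbtxt when device_type is 'tpu'. A small sketch of that derivation with a made-up path, purely for illustration:

import os

checkpoint = '/tmp/my_logdir/train/ckpt-00012345'  # hypothetical path
checkpoint_dir = os.path.dirname(checkpoint)       # /tmp/my_logdir/train
logdir = os.path.dirname(checkpoint_dir)           # /tmp/my_logdir
inference_graph = os.path.join(logdir, 'inference_graphs', 'inference.pbtxt')
# -> /tmp/my_logdir/inference_graphs/inference.pbtxt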
Code example #5
  def testInvalidFetchRaisesKeyError(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    with self.assertRaisesRegex(KeyError, 'nonexistent'):
      pred.Run(['fetch1', 'nonexistent'], feed1=[12345])
Code example #6
  def testMissingFeedRaisesInvalidArgumentError(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    with self.assertRaisesRegex(tf.errors.InvalidArgumentError, 'feed1'):
      pred.Run(['fetch1'])
Code example #7
  def testPredictor(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    [fetch1] = pred.Run(['fetch1'], feed1=[12345])
    self.assertEqual(12345, fetch1)
Code example #8
  def testPredictorSubgraph(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    fetch1 = pred.Run('fetch1',
                      feed1=[12345, 23456],
                      subgraph_name='subgraph2')
    self.assertAllEqual([12345, 23456], fetch1)
Code example #9
  def testPredictorFetchShapes(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    self.assertEqual([1], pred.fetch_shapes.fetch1)
    self.assertEqual([2], pred.subgraph_fetch_shapes('subgraph2').fetch1)
Code example #10
  def testPredictorFeedShapes(self):
    pred = predictor.Predictor(self._testInferenceGraph())
    self.assertEqual([1], pred.feed_shapes.feed1)
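
The shape accessors shown in code examples #9 and #10 can be combined with Run to sanity-check a feed before executing the graph. The following test is a hedged sketch in the same style; it does not appear in the source and reuses the _testInferenceGraph fixture and the fetch pattern from code example #7.

  def testFeedShapeMatchesInput(self):
    # Hypothetical test: confirm the declared feed shape, then run the graph
    # with a matching feed and check the fetched value.
    pred = predictor.Predictor(self._testInferenceGraph())
    self.assertEqual([1], pred.feed_shapes.feed1)
    [fetch1] = pred.Run(['fetch1'], feed1=[12345])
    self.assertEqual(12345, fetch1)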