def testExportWithRandomSeeds(self):
  """Checks how the export-time `random_seed` affects inference outputs.

  With no seed, consecutive runs (on one Predictor, or on a freshly created
  one) produce different outputs. With a fixed seed, outputs are identical
  across runs and Predictors, and a different seed yields a different output.
  """
  params = model_registry.GetParams('test.LinearModelParams', 'Test')

  def _ExportGraph(**export_kwargs):
    # Every export in this test only keeps the 'default' subgraph.
    return inference_graph_exporter.InferenceGraphExporter.Export(
        params, subgraph_filter=['default'], **export_kwargs)

  def _RunOnce(pred):
    # Single-element unpacking also checks exactly one output is returned.
    [output] = pred.Run(['output'], input=3)
    return output

  # Default -- use random_seed = None.
  graph = _ExportGraph()
  pred = predictor.Predictor(graph)
  no_op_seed_1 = _RunOnce(pred)
  no_op_seed_2 = _RunOnce(pred)
  self.assertNotEqual(no_op_seed_1, no_op_seed_2)
  pred = predictor.Predictor(graph)
  no_op_seed_3 = _RunOnce(pred)
  self.assertNotEqual(no_op_seed_1, no_op_seed_3)

  # Use a fixed random_seed.
  graph = _ExportGraph(random_seed=1234)
  pred = predictor.Predictor(graph)
  fixed_op_seed_1 = _RunOnce(pred)
  fixed_op_seed_2 = _RunOnce(pred)
  self.assertEqual(fixed_op_seed_1, fixed_op_seed_2)
  pred = predictor.Predictor(graph)
  fixed_op_seed_3 = _RunOnce(pred)
  self.assertEqual(fixed_op_seed_1, fixed_op_seed_3)

  # A different seed gives different results.
  graph = _ExportGraph(random_seed=1235)
  pred = predictor.Predictor(graph)
  fixed_op_seed_4 = _RunOnce(pred)
  self.assertNotEqual(fixed_op_seed_1, fixed_op_seed_4)
def testInvalidFetchWithoutValidateFetchesReturnsNone(self):
  """With validate_fetches=False an unknown fetch yields None, not an error."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  results = pred.Run(['fetch1', 'nonexistent'],
                     feed1=[12345],
                     validate_fetches=False)
  fetch1, nonexistent = results
  self.assertEqual(12345, fetch1)
  self.assertIsNone(nonexistent)
def testPredictorNoLoadGraphDefFromInferenceGraph(self):
  """Runs a Predictor built with load_graph_def_from_inference_graph=False."""
  task_params = DummyModel.Params().Set(name='test')
  p = base_model.SingleTaskModel.Params(task_params)
  p.input = base_input_generator.BaseInputGenerator.Params().Set(name='test')
  task = p.Instantiate().GetTask()
  pred = predictor.Predictor(
      task.Inference(), load_graph_def_from_inference_graph=False)
  # A single fetch name (not a list) returns the bare value.
  fetch1 = pred.Run('fetch1', feed1=[12345])
  self.assertEqual(12345, fetch1)
def __init__(self,
             checkpoint,
             output_dir,
             inference_graph=None,
             inference_subgraph_name='',
             device_type='cpu',
             output_num_shards=1,
             output_shard_id=0,
             max_inputs=0,
             input_id_filter=None,
             tf_master='local',
             inference_threads=1,
             batch_size=64,
             prediction_step_interval=3000):
  """Constructor.

  Blocks until `checkpoint` is available (polling every
  `_RETRY_SLEEP_SECONDS`), then creates the predictor, a thread pool with
  `inference_threads` workers and one lock per worker.

  Args:
    checkpoint: Either a checkpoint file to load, or a directory containing
      multiple checkpoints, where the latest checkpoint will be loaded.
    output_dir: Output directory. If `checkpoint` is a directory, a
      subdirectory will be created for each checkpoint evaluated.
    inference_graph: Path to an inference graph. If not specified, it is
      inferred from the checkpoint path as
      `<logdir>/inference_graphs/inference.pbtxt` (or `inference_tpu.pbtxt`
      for `device_type='tpu'`), where `<logdir>` is the parent of the
      directory holding the checkpoint.
    inference_subgraph_name: The name of the inference subgraph to use.
      Defaults to the default subgraph.
    device_type: Device type, either cpu, gpu, or tpu.
    output_num_shards: Each replica generates one shard of output according
      to `output_shard_id`.
    output_shard_id: The output shard id in range `[0, output_num_shards -
      1]`.
    max_inputs: Only process the first n inputs. 0 means process all inputs.
    input_id_filter: If not empty, only process the input ids in the given
      list. Ids are converted to strings with `str()`.
    tf_master: tf_master for predictor session.
    inference_threads: Number of inference threads.
    batch_size: Batch size.
    prediction_step_interval: Number of steps between outputs. Only meaningful
      if `checkpoint` is a directory.

  Raises:
    ValueError: if `device_type` is 'tpu' but `FLAGS.xla_device` is not 'tpu'.
  """
  # Validate configuration before doing any (possibly blocking) work.
  if device_type == 'tpu' and FLAGS.xla_device != 'tpu':
    raise ValueError('xla_device=tpu should be set with device_type=tpu!')

  self._checkpoint = checkpoint
  self._output_dir = output_dir
  self._output_num_shards = output_num_shards
  self._output_shard_id = output_shard_id
  self._max_inputs = max_inputs
  input_id_filter = input_id_filter or []
  self._input_id_filter = [str(x) for x in input_id_filter]
  self._batch_size = batch_size
  self._prediction_step_interval = prediction_step_interval

  # Wait (indefinitely) until a checkpoint exists: for a directory, until it
  # contains a latest checkpoint; for a file prefix, until its .index exists.
  while True:
    if tf.gfile.IsDirectory(self._checkpoint):
      if tf.train.latest_checkpoint(self._checkpoint):
        break
    elif tf.gfile.Exists(self._checkpoint + '.index'):
      break
    tf.logging.log_first_n(tf.logging.INFO,
                           'Waiting for checkpoint to be available.', 1)
    time.sleep(_RETRY_SLEEP_SECONDS)

  # Use saved inference graph.
  if inference_graph:
    self._inference_graph = inference_graph
  else:
    checkpoint_dir = self._checkpoint
    if not tf.gfile.IsDirectory(checkpoint_dir):
      checkpoint_dir = os.path.dirname(checkpoint_dir)
    # Strip trailing separators first: os.path.dirname('/logdir/train/')
    # returns '/logdir/train' rather than the intended parent '/logdir'.
    logdir = os.path.dirname(checkpoint_dir.rstrip('/'))
    inference_graph_filename = 'inference.pbtxt'
    if device_type == 'tpu':
      inference_graph_filename = 'inference_tpu.pbtxt'
    self._inference_graph = os.path.join(logdir, 'inference_graphs',
                                         inference_graph_filename)

  self._predictor = predictor.Predictor(
      inference_graph=self._inference_graph,
      subgraph_name=inference_subgraph_name,
      device_type=device_type,
      tf_master=tf_master)
  self._threadpool = concurrent.futures.ThreadPoolExecutor(inference_threads)
  self._locks = [threading.Lock() for _ in range(inference_threads)]
def testInvalidFetchRaisesKeyError(self):
  """Requesting an unknown fetch raises a KeyError that names it."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  self.assertRaisesRegex(KeyError, 'nonexistent', pred.Run,
                         ['fetch1', 'nonexistent'], feed1=[12345])
def testMissingFeedRaisesInvalidArgumentError(self):
  """Running without feeding `feed1` raises an InvalidArgumentError."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  self.assertRaisesRegex(tf.errors.InvalidArgumentError, 'feed1', pred.Run,
                         ['fetch1'])
def testPredictor(self):
  """A fed value comes back unchanged through `fetch1`."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  outputs = pred.Run(['fetch1'], feed1=[12345])
  [fetch1] = outputs
  self.assertEqual(12345, fetch1)
def testPredictorSubgraph(self):
  """Runs `fetch1` on the 'subgraph2' subgraph with a two-element feed."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  feed_values = [12345, 23456]
  fetch1 = pred.Run('fetch1', feed1=feed_values, subgraph_name='subgraph2')
  self.assertAllEqual([12345, 23456], fetch1)
def testPredictorFetchShapes(self):
  """Checks `fetch1` shapes for the predictor's subgraph and 'subgraph2'."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  default_shapes = pred.fetch_shapes
  self.assertEqual([1], default_shapes.fetch1)
  subgraph2_shapes = pred.subgraph_fetch_shapes('subgraph2')
  self.assertEqual([2], subgraph2_shapes.fetch1)
def testPredictorFeedShapes(self):
  """Checks the `feed1` shape reported by `feed_shapes`."""
  graph = self._testInferenceGraph()
  pred = predictor.Predictor(graph)
  shapes = pred.feed_shapes
  self.assertEqual([1], shapes.feed1)