def testTpuBfloat16OverrideExport(self):
  """Test that we can export with tf.bfloat16 dtype."""
  params = model_registry.GetParams('test.LinearModelTpuParams', 'Test')
  inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
      params,
      subgraph_filter=['tpu'],
      device_options=inference_graph_exporter.InferenceDeviceOptions(
          device='tpu',
          retain_device_placement=True,
          var_options='ON_DEVICE',
          gen_init_op=True,
          dtype_override=tf.bfloat16))
  self.assertIn('tpu', inference_graph.subgraphs)
def testExportModelDoesNotAffectFlagsOnException(self):
  initial_flags = {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS}
  params = model_registry.GetParams('test.DummyLegacyModelParams', 'Test')
  with self.assertRaises(NotImplementedError):
    inference_graph_exporter.InferenceGraphExporter.Export(
        params,
        device_options=inference_graph_exporter.InferenceDeviceOptions(
            device='tpu',
            retain_device_placement=False,
            var_options=None,
            gen_init_op=True,
            dtype_override=None))
  self.assertDictEqual(
      initial_flags,
      {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS})
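# A minimal sketch (not part of the module under test) of the flag-restore
# pattern the test above verifies: Export must leave tf.flags.FLAGS unchanged
# even when it raises. `_FlagsUnchanged` is a hypothetical helper name; the
# exporter's actual restore mechanism may differ.
import contextlib


@contextlib.contextmanager
def _FlagsUnchanged():
  """Snapshots all flag values on entry and restores them on exit."""
  snapshot = {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS}
  try:
    yield
  finally:
    # Restore every flag to its pre-entry value, even on exception.
    for name, value in snapshot.items():
      tf.flags.FLAGS[name].value = value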
def WriteInferenceGraph(self, cfg=None, prune_graph=True):
  """Generates the inference graphs for a given model.

  Args:
    cfg: Full `~.hyperparams.Params` for the model class. If present, this
      cfg will be used instead of retrieving from model_registry.
    prune_graph: If true, prune the graph to just the parts we need.

  Returns:
    InferenceGraph proto for cpu.
  """
  inference_graph_dir = os.path.join(FLAGS.logdir, 'inference_graphs')
  tf.io.gfile.makedirs(inference_graph_dir)
  tf.logging.info('Writing inference graphs to dir: %s', inference_graph_dir)
  if not cfg:
    cfg = self.model_registry.GetParams(self._model_name,
                                        FLAGS.inference_dataset_name)
  task_names = [FLAGS.model_task_name]
  if (issubclass(cfg.cls, base_model.MultiTaskModel) and
      not FLAGS.model_task_name):
    task_names = base_model.MultiTaskModel.TaskNames(cfg)

  inference_graph_proto = None
  if FLAGS.inference_graph_filename:
    # Custom inference graph.
    for task_name in task_names:
      filename_prefix = FLAGS.inference_graph_filename
      if task_name:
        filename_prefix = '%s_inference' % task_name
      filename_prefix = os.path.join(inference_graph_dir, filename_prefix)

      device = ''
      var_options = None
      if FLAGS.inference_graph_device == 'tpu':
        device = 'tpu'
        var_options = 'ON_DEVICE'
      device_options = inference_graph_exporter.InferenceDeviceOptions(
          device=device,
          retain_device_placement=False,
          var_options=var_options,
          gen_init_op=FLAGS.inference_gen_tpu_init_op,
          dtype_override=None,
          fprop_dtype_override=None)
      inference_graph_proto = (
          self.inference_graph_exporter.InferenceGraphExporter.Export(
              model_cfg=cfg,
              model_task_name=task_name,
              device_options=device_options,
              export_path=filename_prefix + '.pbtxt',
              random_seed=FLAGS.inference_graph_random_seed,
              prune_graph=prune_graph))
  else:
    for task_name in task_names:
      filename_prefix = 'inference'
      if task_name:
        filename_prefix = '%s_inference' % task_name
      filename_prefix = os.path.join(inference_graph_dir, filename_prefix)

      # Standard inference graph.
      try:
        inference_graph_proto = (
            self.inference_graph_exporter.InferenceGraphExporter.Export(
                model_cfg=cfg,
                model_task_name=task_name,
                export_path=filename_prefix + '.pbtxt',
                random_seed=FLAGS.inference_graph_random_seed,
                prune_graph=prune_graph))
      except NotImplementedError as e:
        tf.logging.error('Cannot write inference graph: %s', e)

      # TPU inference graph. Not all models support it so fail silently.
      try:
        device_options = self.inference_graph_exporter.InferenceDeviceOptions(
            device='tpu',
            retain_device_placement=False,
            var_options='ON_DEVICE',
            gen_init_op=FLAGS.inference_gen_tpu_init_op,
            dtype_override=None,
            fprop_dtype_override=None)
        self.inference_graph_exporter.InferenceGraphExporter.Export(
            model_cfg=cfg,
            model_task_name=task_name,
            device_options=device_options,
            export_path=filename_prefix + '_tpu.pbtxt',
            random_seed=FLAGS.inference_graph_random_seed,
            prune_graph=prune_graph)
      except Exception as e:  # pylint: disable=broad-except
        tf.logging.error('Error exporting TPU inference graph: %s', e)

  if FLAGS.graph_def_filename and inference_graph_proto:
    for graph_def_filename in FLAGS.graph_def_filename:
      tf.logging.info('Writing graphdef: %s', graph_def_filename)
      dir_path = os.path.dirname(graph_def_filename)
      if (not tf.io.gfile.exists(dir_path) or
          not tf.io.gfile.isdir(dir_path)):
        tf.io.gfile.makedirs(dir_path)
      with tf.io.gfile.GFile(graph_def_filename, 'w') as f:
        f.write(text_format.MessageToString(inference_graph_proto.graph_def))
  return inference_graph_proto
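# A minimal sketch of reading one of the exported graphs back, the inverse of
# the text_format.MessageToString write above. It assumes the
# `inference_graph_pb2` module generated from Lingvo's inference_graph.proto
# is importable under this path; `_ReadInferenceGraph` and its argument are
# illustrative, not part of the module above.
def _ReadInferenceGraph(path):
  """Parses a .pbtxt InferenceGraph written by WriteInferenceGraph."""
  from lingvo.core import inference_graph_pb2  # assumed import path
  graph = inference_graph_pb2.InferenceGraph()
  with tf.io.gfile.GFile(path, 'r') as f:
    text_format.Parse(f.read(), graph)
  return graph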