Example 1
  def testTpuBfloat16OverrideExport(self):
    """Test that we can export with tf.bfloat16 dtype."""
    params = model_registry.GetParams('test.LinearModelTpuParams', 'Test')
    inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
        params,
        subgraph_filter=['tpu'],
        device_options=inference_graph_exporter.InferenceDeviceOptions(
            device='tpu',
            retain_device_placement=True,
            var_options='ON_DEVICE',
            gen_init_op=True,
            dtype_override=tf.bfloat16))
    self.assertIn('tpu', inference_graph.subgraphs)

 def testExportModelDoesNotAffectFlagsOnException(self):
   initial_flags = {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS}
   params = model_registry.GetParams('test.DummyLegacyModelParams', 'Test')
   with self.assertRaises(NotImplementedError):
     inference_graph_exporter.InferenceGraphExporter.Export(
         params,
         device_options=inference_graph_exporter.InferenceDeviceOptions(
             device='tpu',
             retain_device_placement=False,
             var_options=None,
             gen_init_op=True,
             dtype_override=None))
   self.assertDictEqual(initial_flags,
                        {k: tf.flags.FLAGS[k].value for k in tf.flags.FLAGS})
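
Both tests above exercise the same entry point, InferenceGraphExporter.Export, configured through an InferenceDeviceOptions tuple. The sketch below is not part of the original tests: assuming the same lingvo modules (model_registry, inference_graph_exporter) are available and that a model named 'test.LinearModelParams' is registered (a hypothetical name), a plain CPU export following the same call pattern might look like this.

# Hedged sketch of a CPU export mirroring the calls in the tests above.
# 'test.LinearModelParams' and the export path are hypothetical placeholders.
from lingvo import model_registry
from lingvo.core import inference_graph_exporter

params = model_registry.GetParams('test.LinearModelParams', 'Test')
# Without explicit device_options, Export targets the default (CPU) device,
# as in the standard-graph branch of WriteInferenceGraph below.
inference_graph = inference_graph_exporter.InferenceGraphExporter.Export(
    params, export_path='/tmp/inference.pbtxt')
print(list(inference_graph.subgraphs.keys()))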
Example 3
    def WriteInferenceGraph(self, cfg=None, prune_graph=True):
        """Generates the inference graphs for a given model.

    Args:
      cfg: Full `~.hyperparams.Params` for the model class. If present, this cfg
        will be used instead of retrieving from model_registry.
      prune_graph: If true, prune the graph to just the parts we need.

    Returns:
      InferenceGraph proto for cpu.
    """
        inference_graph_dir = os.path.join(FLAGS.logdir, 'inference_graphs')
        tf.io.gfile.makedirs(inference_graph_dir)
        tf.logging.info('Writing inference graphs to dir: %s',
                        inference_graph_dir)

        if not cfg:
            cfg = self.model_registry.GetParams(self._model_name,
                                                FLAGS.inference_dataset_name)

        task_names = [FLAGS.model_task_name]
        if (issubclass(cfg.cls, base_model.MultiTaskModel)
                and not FLAGS.model_task_name):
            task_names = base_model.MultiTaskModel.TaskNames(cfg)

        inference_graph_proto = None

        if FLAGS.inference_graph_filename:
            # Custom inference graph.
            for task_name in task_names:
                filename_prefix = FLAGS.inference_graph_filename
                if task_name:
                    filename_prefix = '%s_inference' % task_name
                filename_prefix = os.path.join(inference_graph_dir,
                                               filename_prefix)

                device = ''
                var_options = None
                if FLAGS.inference_graph_device == 'tpu':
                    device = 'tpu'
                    var_options = 'ON_DEVICE'
                device_options = inference_graph_exporter.InferenceDeviceOptions(
                    device=device,
                    retain_device_placement=False,
                    var_options=var_options,
                    gen_init_op=FLAGS.inference_gen_tpu_init_op,
                    dtype_override=None,
                    fprop_dtype_override=None)
                inference_graph_proto = (
                    self.inference_graph_exporter.InferenceGraphExporter.
                    Export(model_cfg=cfg,
                           model_task_name=task_name,
                           device_options=device_options,
                           export_path=filename_prefix + '.pbtxt',
                           random_seed=FLAGS.inference_graph_random_seed,
                           prune_graph=prune_graph))
        else:
            for task_name in task_names:
                filename_prefix = 'inference'
                if task_name:
                    filename_prefix = '%s_inference' % task_name
                filename_prefix = os.path.join(inference_graph_dir,
                                               filename_prefix)

                # Standard inference graph.
                try:
                    inference_graph_proto = (
                        self.inference_graph_exporter.InferenceGraphExporter.
                        Export(model_cfg=cfg,
                               model_task_name=task_name,
                               export_path=filename_prefix + '.pbtxt',
                               random_seed=FLAGS.inference_graph_random_seed,
                               prune_graph=prune_graph))
                except NotImplementedError as e:
                    tf.logging.error('Cannot write inference graph: %s', e)

                # TPU inference graph. Not all models support it, so log the
                # error and continue instead of raising.
                try:
                    device_options = self.inference_graph_exporter.InferenceDeviceOptions(
                        device='tpu',
                        retain_device_placement=False,
                        var_options='ON_DEVICE',
                        gen_init_op=FLAGS.inference_gen_tpu_init_op,
                        dtype_override=None,
                        fprop_dtype_override=None)
                    self.inference_graph_exporter.InferenceGraphExporter.Export(
                        model_cfg=cfg,
                        model_task_name=task_name,
                        device_options=device_options,
                        export_path=filename_prefix + '_tpu.pbtxt',
                        random_seed=FLAGS.inference_graph_random_seed,
                        prune_graph=prune_graph)
                except Exception as e:  # pylint: disable=broad-except
                    tf.logging.error(
                        'Error exporting TPU inference graph: %s', e)

        if FLAGS.graph_def_filename and inference_graph_proto:
            for graph_def_filename in FLAGS.graph_def_filename:
                tf.logging.info('Writing graphdef: %s', graph_def_filename)
                dir_path = os.path.dirname(graph_def_filename)
                if (not tf.io.gfile.exists(dir_path)
                        or not tf.io.gfile.isdir(dir_path)):
                    tf.io.gfile.makedirs(dir_path)
                with tf.io.gfile.GFile(graph_def_filename, 'w') as f:
                    f.write(
                        text_format.MessageToString(
                            inference_graph_proto.graph_def))

        return inference_graph_proto
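
When FLAGS.graph_def_filename is set, the final loop above also dumps the exported graph_def as a text proto. The sketch below is not from the original file; it only shows, with a placeholder path and standard TensorFlow/protobuf calls, how such a dump could be read back.

# Hedged sketch: parsing a GraphDef text proto written by the loop above.
# The path is a placeholder; only standard TensorFlow / protobuf APIs are used.
import tensorflow as tf
from google.protobuf import text_format

graph_def_path = '/tmp/logdir/inference_graphs/model.graphdef'  # placeholder
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(graph_def_path, 'r') as f:
  text_format.Parse(f.read(), graph_def)
print('Loaded GraphDef with %d nodes' % len(graph_def.node))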