Example 1
 def testListDevices(self):
     with session.Session() as sess:
         devices = sess.list_devices()
         self.assertTrue(
             '/job:localhost/replica:0/task:0/device:CPU:0'
             in set([d.name for d in devices]), devices)
Example 2
 def testRunAndPartialRunDirect(self):
     self.RunTestRunAndPartialRun(session.Session())
Example 3
 def testPartialRunUnspecifiedFetchDirect(self):
     self.RunTestPartialRunUnspecifiedFetch(session.Session())
Example 4
    def test1Workers2Period(self):
        num_workers = 1
        communication_period = 2
        num_ps = 1
        cluster, workers, _ = create_local_cluster(num_workers=num_workers,
                                                   num_ps=num_ps)

        sessions, graphs, train_ops, savers = _get_workers(
            num_workers, communication_period, workers, 1.0)

        var_0 = graphs[0].get_tensor_by_name("v0:0")
        var_1 = graphs[0].get_tensor_by_name("v1:0")
        global_step = training_util.get_global_step(graphs[0])
        var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
        var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
        # Verify the initialized value.
        self.assertAllEqual(0.0, sessions[0].run(var_0))
        self.assertAllEqual(1.0, sessions[0].run(var_1))
        self.assertAllEqual(0.0, sessions[0].run(var_0_g))
        self.assertAllEqual(1.0, sessions[0].run(var_1_g))
        self.assertAllEqual(0, sessions[0].run(global_step))

        sessions[0].run(train_ops[0])

        self.assertAllEqual(1.0, sessions[0].run(var_0))
        self.assertAllEqual(2.0, sessions[0].run(var_1))
        self.assertAllEqual(0.0, sessions[0].run(var_0_g))
        self.assertAllEqual(1.0, sessions[0].run(var_1_g))
        self.assertAllEqual(0, sessions[0].run(global_step))

        # iteration 2, global variable update
        sessions[0].run(train_ops[0])

        self.assertAllEqual(0.0, sessions[0].run(var_0))
        self.assertAllEqual(1.0, sessions[0].run(var_1))
        self.assertAllEqual(2.0, sessions[0].run(var_0_g))
        self.assertAllEqual(3.0, sessions[0].run(var_1_g))
        self.assertAllEqual(1, sessions[0].run(global_step))

        # iteration 3
        sessions[0].run(train_ops[0])

        self.assertAllEqual(1.0, sessions[0].run(var_0))
        self.assertAllEqual(2.0, sessions[0].run(var_1))
        self.assertAllEqual(2.0, sessions[0].run(var_0_g))
        self.assertAllEqual(3.0, sessions[0].run(var_1_g))
        self.assertAllEqual(1, sessions[0].run(global_step))
        sessions[0].run(train_ops[0])

        # Save; the saved data will be the global variable values.
        outfile = os.path.join(test.get_temp_dir(), "model")
        savers[0].save(sessions[0]._sess._sess._sess._sess, save_path=outfile)
        ops.reset_default_graph()  # restore on a new graph
        with session.Session() as sess:
            v0 = variable_scope.get_variable(initializer=0.0, name="v0")
            v1 = variable_scope.get_variable(initializer=1.0, name="v1")
            sess.run(variables.local_variables_initializer())
            saver_opt = saver.Saver(var_list=[v1, v0])
            saver_opt.restore(sess, outfile)
            self.assertAllEqual(2.0, sess.run(v0))
            self.assertAllEqual(3.0, sess.run(v1))
Example 5
 def testConcurrentPartialRunDirect(self):
     self.RunTestConcurrentPartialRun(session.Session())
Example 6
 def testConstructWrapper(self):
     local_cli_wrapper.LocalCLIDebugWrapperSession(session.Session(),
                                                   log_usage=False)
Example 7
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).

  Returns:
    The frozen `GraphDef` (or an integer error code on failure).
  """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the check is based on a
            # heuristic, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    print(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                    return -1
                # Models that have been frozen previously do not contain Variables.
                elif _has_no_variables(sess):
                    print(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    )
                    return 0
                else:
                    raise e

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
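
The sketch below shows how this function might be invoked. It is illustrative
only: every path, the node name, and the my_graph_def variable are hypothetical
placeholders rather than values from the original source.

# Hypothetical invocation of freeze_graph_with_def_protos (defined above).
frozen = freeze_graph_with_def_protos(
    input_graph_def=my_graph_def,          # a GraphDef built or loaded elsewhere
    input_saver_def=None,
    input_checkpoint="/tmp/model/ckpt-0",  # placeholder checkpoint prefix
    output_node_names="output_node",       # placeholder node name
    restore_op_name=None,                  # unused by the updated loading code
    filename_tensor_name=None,             # unused by the updated loading code
    output_graph="/tmp/model/frozen.pb",   # placeholder output path
    clear_devices=True,
    initializer_nodes="")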
Example 8
 def _test_session(self, target):
     config = config_pb2.ConfigProto(allow_soft_placement=True)
     config.graph_options.optimizer_options.opt_level = -1
     with session.Session(graph=None, config=config, target=target) as sess:
         yield sess
Example 9
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.
    
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            "Input checkpoint ' + input_checkpoint + ' does not exist!")
        
    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')
        
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''
    
    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping element)
                        # so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)
            
            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def
Example 10
    def testComplexCodeView(self):
        ops.reset_default_graph()
        outfile = os.path.join(test.get_temp_dir(), 'dump')
        opts = (builder(builder.trainable_variables_parameter())
                .with_file_output(outfile)
                .with_accounted_types(['.*'])
                .with_node_names(
                    show_name_regexes=['.*model_analyzer_testlib.py.*'])
                .account_displayed_op_only(False)
                .select(['params', 'float_ops'])
                .build())

        with profile_context.ProfileContext(test.get_temp_dir(),
                                            trace_steps=[],
                                            dump_steps=[]) as pctx:
            with session.Session(
                    config=self._no_rewrite_session_config()) as sess:
                x = lib.BuildFullModel()

                self.evaluate(variables.global_variables_initializer())
                pctx.trace_next_step()
                _ = self.evaluate(x)
                tfprof_node = pctx.profiler.profile_python(options=opts)

                # pylint: disable=line-too-long
                with gfile.Open(outfile, 'r') as f:
                    lines = f.read().split('\n')
                    self.assertGreater(len(lines), 5)
                    result = '\n'.join([l[:min(len(l), 80)] for l in lines])
                    self.assertTrue(
                        compat.as_text(
                            lib.CheckAndRemoveDoc(result)).startswith(
                                'node name | # parameters | # float_ops'))

                self.assertLess(0, tfprof_node.total_exec_micros)
                self.assertEqual(2844, tfprof_node.total_parameters)
                # The graph is modified when MKL is enabled, so total_float_ops
                # will be different.
                if test_util.IsMklEnabled():
                    self.assertLess(101600, tfprof_node.total_float_ops)
                else:
                    self.assertLess(145660, tfprof_node.total_float_ops)
                self.assertEqual(8, len(tfprof_node.children))
                self.assertEqual('_TFProfRoot', tfprof_node.name)
                self.assertEqual('model_analyzer_testlib.py:63:BuildFullModel',
                                 tfprof_node.children[0].name)
                self.assertEqual(
                    'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
                    tfprof_node.children[1].name)
                self.assertEqual('model_analyzer_testlib.py:67:BuildFullModel',
                                 tfprof_node.children[2].name)
                self.assertEqual(
                    'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
                    tfprof_node.children[3].name)
                self.assertEqual('model_analyzer_testlib.py:69:BuildFullModel',
                                 tfprof_node.children[4].name)
                self.assertEqual('model_analyzer_testlib.py:70:BuildFullModel',
                                 tfprof_node.children[5].name)
                self.assertEqual(
                    'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
                    tfprof_node.children[6].name)
                self.assertEqual('model_analyzer_testlib.py:72:BuildFullModel',
                                 tfprof_node.children[7].name)
Example 11
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.")

  logging.info("Initializing the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is written this way for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      if tpu_name in _LOCAL_MASTERS:
        job = None
      else:
        # Explicitly place tpu.initialize_system on the first worker to avoid
        # errors from the output node matching multiple devices.
        job = "worker/replica:0/task:0"
      return tpu.initialize_system(job=job)

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device. As
    # in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM should
    # work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
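
As a hedged sketch of how this initializer is typically driven; the TPU name
"my-tpu" and the printed Topology attributes are assumptions for illustration,
not part of the original snippet.

# Hypothetical driver for initialize_tpu_system (defined above).
resolver = TPUClusterResolver(tpu="my-tpu")  # placeholder TPU name
topology = initialize_tpu_system(resolver)
# The returned Topology describes the shape of the TPU cluster.
print(topology.mesh_shape, topology.num_tasks)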
Example 12
 def testCondMissingArg2(self):
     with ops.Graph().as_default():
         with session.Session():
             x = constant_op.constant(1)
             with self.assertRaises(TypeError):
                 control_flow_ops.cond(True, lambda: x)
Example 13
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """

  # Deallocate all TPU buffers by clearing out eager context caches and
  # triggering garbage collection, to avoid keeping invalid TPU buffers around
  # after the TPU system is reinitialized.
  logging.info("Deallocate tpu buffers before initializing tpu system.")
  context.context()._clear_caches()  # pylint: disable=protected-access
  context.context().clear_kernel_cache()
  gc.collect()

  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the init
    # ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is written this way for the following non-intuitive reasons.
  # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
  # DistributedTPURewritePass. This pass actually adds real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place tpu.initialize_system on the first worker to avoid
    # errors from the output node matching multiple devices.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the data
      # in infeed. In TF2, we don't need to do this because infeed is no longer
      # used, so user can recover from TPU compilation failures more smoothly.
      # Same for the cancellation of a TPU execution.
      return tpu.initialize_system(
          job=job,
          compilation_failure_closes_chips=False,
          tpu_cancellation_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Clear out the eager context caches since the memory is invalid now.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object as the tf.tpu.Topology object has to be
      # constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
Example 14
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution or if run in a
        tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the init
    # ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("You are shutting down a TPU system %s that has not been "
                    "initialized." % tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is written this way for the following non-intuitive reasons.
    # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # shutdown the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid
      # errors from the output node matching multiple devices.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.functions. You "
        "should call shutdown_tpu_system outside of your tf.function.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]
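
A minimal sketch pairing shutdown with the matching initializer above; the
resolver construction and TPU name are placeholders, not part of the original
snippet.

# Hypothetical lifecycle: initialize, run TPU work, then shut down to clear
# all caches (including the compilation cache).
resolver = TPUClusterResolver(tpu="my-tpu")  # placeholder TPU name
initialize_tpu_system(resolver)
# ... run TPU computations here ...
shutdown_tpu_system(resolver)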
Example 15
def freeze_model(checkpoint_path: str,
                 meta_graph_def: meta_graph_pb2.MetaGraphDef,
                 output_prefix: str, signature_def_key: str,
                 variables_to_feed: List[str]) -> Tuple[str, str]:
    """Freeze a `MetaGraphDef` in preparation for tfcompile`.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen.  If `None`, then we will extract all the
      variables in the graph and mark them as to-feed.  The default behavior is
      an empty tuple: all variables must be frozen.
  Returns:
    a pair containing the path to the frozen model and the path to the config.
  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
    if _pywrap_tfcompile_import_error:
        raise _pywrap_tfcompile_import_error  # pylint: disable=raising-bad-type

    signature_def_map = meta_graph_def.signature_def
    if signature_def_key not in signature_def_map:
        raise ValueError(
            f"Unable to find signature_def_key '{signature_def_key}' in signature "
            'def map of `meta_graph_def`. Available keys: '
            f'{list(signature_def_map.keys())}')
    signature_def = signature_def_map[signature_def_key]
    if not signature_def.outputs:
        raise ValueError(
            f'Signature key {signature_def_key} must have outputs, but saw none:\n'
            f'{str(signature_def)}')

    file_io.recursive_create_dir(output_prefix)
    if logging.get_verbosity() >= logging.INFO:
        original_graph_def_location = os.path.join(output_prefix,
                                                   'original_graph.pb')
        with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
            graph_writer.write(meta_graph_def.graph_def.SerializeToString())

    # This updates graph_def in place.
    _replace_input_placeholders_with_default_values(meta_graph_def.graph_def,
                                                    signature_def)

    graph_def = _optimize_graph(meta_graph_def, signature_def)

    all_variables = _get_variable_nodes_from_graph_def(graph_def)
    if variables_to_feed is None:
        variable_nodes_to_feed = list(all_variables.values())
    else:
        not_in_graph = set(variables_to_feed).difference(list(all_variables))
        if not_in_graph:
            raise ValueError(
                'Asked to feed variables that were not found in graph: '
                f'{not_in_graph}. Variables contained in the graph: '
                f'{list(all_variables)}')
        variable_nodes_to_feed = [
            all_variables[name] for name in variables_to_feed
        ]

    if logging.get_verbosity() >= logging.INFO:
        prefrozen_graph_def_location = os.path.join(output_prefix,
                                                    'prefrozen_graph.pb')
        with file_io.FileIO(prefrozen_graph_def_location,
                            'wb') as graph_writer:
            graph_writer.write(graph_def.SerializeToString())

    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
        restorer = saver_lib.import_meta_graph(meta_graph_def,
                                               clear_devices=True)
        if restorer is not None:
            restorer.restore(sess, checkpoint_path)
        graph_def.CopyFrom(
            graph_util.convert_variables_to_constants(
                sess,
                graph_def,
                output_node_names=[
                    _parse_tensor_name(n.name)[0]
                    for n in signature_def.outputs.values()
                ],
                variable_names_blacklist=[
                    n.name for n, _ in variable_nodes_to_feed
                ],
            ))

    signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

    frozen_graph_def_location = os.path.join(output_prefix, 'frozen_graph.pb')
    config_pbtxt_location = os.path.join(output_prefix, 'config.pbtxt')
    logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
    with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
        graph_writer.write(graph_def.SerializeToString())
    config = _signature_to_tf2xla_config(
        signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
    logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
    with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
        config_writer.write(str(config))
    return frozen_graph_def_location, config_pbtxt_location
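
A hedged invocation sketch for freeze_model; the paths and signature key are
placeholders, and meta_graph_def is assumed to have been read from a SavedModel
elsewhere.

# Hypothetical call to freeze_model (defined above).
frozen_path, config_path = freeze_model(
    checkpoint_path="/tmp/model/ckpt",     # placeholder checkpoint prefix
    meta_graph_def=meta_graph_def,         # loaded from a SavedModel elsewhere
    output_prefix="/tmp/frozen_out",       # placeholder output directory
    signature_def_key="serving_default",
    variables_to_feed=[])                  # empty list: freeze all variables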
Example 16
def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
                                                  m,
                                                  k,
                                                  n,
                                                  adjoint_a,
                                                  adjoint_b,
                                                  use_gpu,
                                                  skip_dense=False):
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = True

    # Configurable for benchmarking:
    # config.intra_op_parallelism_threads = 100
    # config.gpu_options.per_process_gpu_memory_fraction = 0.3

    np.random.seed([6, 117])  # Reproducibility
    x = np.random.rand(m, k).astype(np.float32)
    x[x < thresh] = 0
    y = np.random.randn(k, n).astype(np.float32)
    if adjoint_a:
        x = x.T
    if adjoint_b:
        y = y.T

    def _timer(sess, ops_fn, iterations):
        # Warm-up run
        sess.run(ops_fn(10, sess))

        # Timing run
        start = time.time()
        sess.run(ops_fn(iterations, sess))
        end = time.time()

        return (end - start) / (1.0 * iterations)  # Average runtime per iteration

    # Using regular matmul, marking one of the matrices as dense.
    if skip_dense:
        delta_dense = float("nan")
    else:
        with session.Session(config=config, graph=ops.Graph()) as sess:
            if not use_gpu:
                with ops.device("/cpu:0"):
                    x_t = constant_op.constant(x)
                    y_t = constant_op.constant(y)
                    ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
                        x_t, y_t, adjoint_a, adjoint_b)
            else:
                with ops.device("/device:GPU:0"):
                    x_t = constant_op.constant(x)
                    y_t = constant_op.constant(y)
                    ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
                        x_t, y_t, adjoint_a, adjoint_b)
            delta_dense = _timer(sess, ops_fn, 200)

    # Using sparse_tensor_dense_matmul.
    with session.Session("", config=config, graph=ops.Graph()) as sess:
        if not use_gpu:
            with ops.device("/cpu:0"):
                x_ind = constant_op.constant(
                    np.vstack(np.where(x)).astype(np.int64).T)
                x_val = constant_op.constant(x[np.where(x)])
                x_shape = constant_op.constant(
                    np.array(x.shape).astype(np.int64))
                y_t = constant_op.constant(y)
                ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
                    x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
        else:
            with ops.device("/device:GPU:0"):
                x_ind = constant_op.constant(
                    np.vstack(np.where(x)).astype(np.int64).T)
                x_val = constant_op.constant(x[np.where(x)])
                x_shape = constant_op.constant(
                    np.array(x.shape).astype(np.int64))
                y_t = constant_op.constant(y)
                ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
                    x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
        delta_sparse = _timer(sess, ops_fn, 200)

    print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" %
          (1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse,
           delta_sparse / delta_dense))
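
Since the benchmark zeroes every entry of x below thresh, thresh=0.99 leaves
roughly 1% of the sparse operand nonzero. A hypothetical invocation might look
like this:

# Illustrative only: ~1%-dense 1000x1000 sparse operand times a dense
# 1000x1000 matrix, on CPU, timing both the dense and sparse paths.
sparse_tensor_dense_vs_dense_matmul_benchmark(
    thresh=0.99, m=1000, k=1000, n=1000,
    adjoint_a=False, adjoint_b=False, use_gpu=False)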
Example 17
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 clear_devices,
                                 initializer_nodes,
                                 optimize_graph=True,
                                 variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError('Input checkpoint "' + input_checkpoint +
                         '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')

        if optimize_graph:
            logging.info('Graph Rewriter optimizations enabled')
            # Compare (major, minor) tuples; lexicographic string comparison
            # of versions is unreliable (e.g. '1.10.0' < '1.5.0' is True).
            if tuple(int(v) for v in tf.__version__.split('.')[:2]) < (1, 5):
                rewrite_options = rewriter_config_pb2.RewriterConfig(
                    optimize_tensor_layout=rewriter_config_pb2.RewriterConfig.ON)
            else:
                rewrite_options = rewriter_config_pb2.RewriterConfig(
                    layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
            rewrite_options.optimizers.append('pruning')
            rewrite_options.optimizers.append('constfold')
            rewrite_options.optimizers.append('layout')
            graph_options = tf.GraphOptions(rewrite_options=rewrite_options,
                                            infer_shapes=True)
        else:
            logging.info('Graph Rewriter optimizations disabled')
            graph_options = tf.GraphOptions()
        config = tf.ConfigProto(graph_options=graph_options)
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example it's
                        # 'global_step' or a similar housekeeping element) so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            variable_names_blacklist = (variable_names_blacklist.split(',')
                                        if variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
Example 18
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=False,
                   force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method should be used for all functional tests.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/gpu:0`. Otherwise, if `use_gpu`
    is True, TensorFlow tries to run as many ops on the GPU as possible. If both
    `force_gpu` and `use_gpu` are False, all ops are pinned to the CPU.

    Example:

      class MyOperatorTest(test_util.TensorFlowTestCase):
        def testMyOperator(self):
          with self.test_session(use_gpu=True):
            valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
            result = MyOperator(valid_input).eval()
            self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
            invalid_input = [-1.0, 2.0, 7.0]
            with self.assertRaisesOpError("negative input not supported"):
              MyOperator(invalid_input).eval()

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/gpu:0`.

    Returns:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    def prepare_config(config):
      if config is None:
        config = config_pb2.ConfigProto()
        config.allow_soft_placement = not force_gpu
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
      elif force_gpu and config.allow_soft_placement:
        # CopyFrom returns None, so copy into a fresh proto before mutating.
        new_config = config_pb2.ConfigProto()
        new_config.CopyFrom(config)
        new_config.allow_soft_placement = False
        config = new_config
      return config

    if graph is None:
      if self._cached_session is None:
        self._cached_session = session.Session(graph=None,
                                               config=prepare_config(config))
      sess = self._cached_session
      with sess.graph.as_default(), sess.as_default():
        if force_gpu:
          with sess.graph.device("/gpu:0"):
            yield sess
        elif use_gpu:
          yield sess
        else:
          with sess.graph.device(graph_util.pin_to_cpu):
            yield sess
    else:
      with session.Session(graph=graph, config=prepare_config(config)) as sess:
        if force_gpu:
          with sess.graph.device("/gpu:0"):
            yield sess
        elif use_gpu:
          yield sess
        else:
          with sess.graph.device(graph_util.pin_to_cpu):
            yield sess
Example 19
    def export_savedmodel(self,
                          export_dir_base,
                          serving_input_receiver_fn,
                          assets_extra=None,
                          as_text=False,
                          checkpoint_path=None):
        """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session.  Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys.  One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument.  This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory.  The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
      export_dir_base: A string containing a directory in which to create
        timestamped subdirectories containing exported SavedModels.
      serving_input_receiver_fn: A function that takes no argument and
        returns a `ServingInputReceiver`.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported SavedModel, or `None` if no extra assets are needed.
      as_text: whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export.  If `None` (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if no serving_input_receiver_fn is provided, no export_outputs
          are provided, or no checkpoint can be found.
    """
        if serving_input_receiver_fn is None:
            raise ValueError('serving_input_receiver_fn must be defined.')

        with ops.Graph().as_default() as g:
            training.create_global_step(g)
            serving_input_receiver = serving_input_receiver_fn()

            # Call the model_fn and collect the export_outputs.
            estimator_spec = self._call_model_fn(
                features=serving_input_receiver.features,
                labels=None,
                mode=model_fn_lib.ModeKeys.PREDICT)

            # Build the SignatureDefs from receivers and all outputs
            signature_def_map = build_all_signature_defs(
                serving_input_receiver.receiver_tensors,
                estimator_spec.export_outputs)

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 self._model_dir)

            export_dir = get_timestamped_export_dir(export_dir_base)

            # TODO(soergel): Consider whether MonitoredSession makes sense here
            with tf_session.Session() as session:

                saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
                    variables._all_saveable_objects(),  # pylint: disable=protected-access
                    sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                # TODO(b/36111876): replace legacy_init_op with main_op mechanism
                # pylint: disable=protected-access
                local_init_op = (
                    estimator_spec.scaffold.local_init_op
                    or monitored_session.Scaffold._default_local_init_op())
                # pylint: enable=protected-access

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=local_init_op)
                builder.save(as_text)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

            return export_dir
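
A minimal sketch of calling this export method; the receiver function, feature
spec, module aliases, and export directory are assumptions for illustration,
not taken from the original snippet.

# Hypothetical usage of export_savedmodel on an estimator instance.
def my_serving_input_receiver_fn():
    # A single float feature named "x"; placeholder spec.
    inputs = {"x": array_ops.placeholder(dtypes.float32, [None, 1], name="x")}
    return export.ServingInputReceiver(features=inputs, receiver_tensors=inputs)

export_dir = estimator.export_savedmodel(
    export_dir_base="/tmp/exports",        # placeholder base directory
    serving_input_receiver_fn=my_serving_input_receiver_fn)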
Example 20
    def testSinglePartitionedVariable(self):
        """Ensures partitioned variables fail cleanly with freeze graph."""
        checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                         "saved_checkpoint")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # Create a graph with partitioned variables. When weights are
        # partitioned into a single partition, the weights variable is followed
        # by an identity -> identity chain (an additional identity node).
        partitioner = partitioned_variables.fixed_size_partitioner(1)
        with ops.Graph().as_default():
            with variable_scope.variable_scope("part",
                                               partitioner=partitioner):
                batch_size, height, width, depth = 5, 128, 128, 3
                input1 = array_ops.zeros((batch_size, height, width, depth),
                                         name="input1")
                input2 = array_ops.zeros((batch_size, height, width, depth),
                                         name="input2")

                num_nodes = depth
                filter1 = variable_scope.get_variable("filter",
                                                      [num_nodes, num_nodes])
                filter2 = array_ops.reshape(filter1,
                                            [1, 1, num_nodes, num_nodes])
                conv = nn.conv2d(input=input1,
                                 filter=filter2,
                                 strides=[1, 1, 1, 1],
                                 padding="SAME")
                node = math_ops.add(conv, input2, name="test/add")
                node = nn.relu6(node, name="test/relu6")

            # Save graph and checkpoints.
            sess = session.Session()
            sess.run(variables.global_variables_initializer())

            saver = saver_lib.Saver()
            checkpoint_path = saver.save(sess,
                                         checkpoint_prefix,
                                         global_step=0,
                                         latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

            # Ensure this graph has partition variables.
            self.assertTrue([
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ])

        # Test that freezing the graph doesn't crash.
        output_node_names = "save/restore_all"
        output_graph_path = os.path.join(self.get_temp_dir(),
                                         output_graph_name)

        return_value = freeze_graph.freeze_graph_with_def_protos(
            input_graph_def=sess.graph_def,
            input_saver_def=None,
            input_checkpoint=checkpoint_path,
            output_node_names=output_node_names,
            restore_op_name="save/restore_all",  # default value
            filename_tensor_name="save/Const:0",  # default value
            output_graph=output_graph_path,
            clear_devices=False,
            initializer_nodes="")
        self.assertEqual(return_value, -1)
Example 21
 def benchmark_performance_graph(self):
     with ops.get_default_graph().as_default():
         with session_lib.Session(config=_config):
             self._benchmark_performance_with_standard_cudnn_impl()
Example 22
    def _testFreezeGraph(self, saver_write_version):

        checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                         "saved_checkpoint")
        checkpoint_meta_graph_file = os.path.join(self.get_temp_dir(),
                                                  "saved_checkpoint.meta")
        checkpoint_state_name = "checkpoint_state"
        input_graph_name = "input_graph.pb"
        output_graph_name = "output_graph.pb"

        # We'll create an input graph that has a single variable containing 1.0,
        # and that then multiplies it by 2.
        with ops.Graph().as_default():
            variable_node = variables.VariableV1(1.0, name="variable_node")
            output_node = math_ops.multiply(variable_node,
                                            2.0,
                                            name="output_node")
            sess = session.Session()
            init = variables.global_variables_initializer()
            sess.run(init)
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
            saver = saver_lib.Saver(write_version=saver_write_version)
            checkpoint_path = saver.save(sess,
                                         checkpoint_prefix,
                                         global_step=0,
                                         latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

        # We save out the graph to disk, and then call the const conversion
        # routine.
        input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
        input_saver_def_path = ""
        input_binary = False
        output_node_names = "output_node"
        restore_op_name = "save/restore_all"
        filename_tensor_name = "save/Const:0"
        output_graph_path = os.path.join(self.get_temp_dir(),
                                         output_graph_name)
        clear_devices = False
        input_meta_graph = checkpoint_meta_graph_file

        freeze_graph.freeze_graph(input_graph_path,
                                  input_saver_def_path,
                                  input_binary,
                                  checkpoint_path,
                                  output_node_names,
                                  restore_op_name,
                                  filename_tensor_name,
                                  output_graph_path,
                                  clear_devices,
                                  "",
                                  "",
                                  input_meta_graph,
                                  checkpoint_version=saver_write_version)

        # Now we make sure the variable has been converted to a constant, and
        # that the graph still produces the expected result.
        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                _ = importer.import_graph_def(output_graph_def, name="")

            self.assertEqual(4, len(output_graph_def.node))
            for node in output_graph_def.node:
                self.assertNotEqual("VariableV2", node.op)
                self.assertNotEqual("Variable", node.op)

            with session.Session() as sess:
                output_node = sess.graph.get_tensor_by_name("output_node:0")
                output = sess.run(output_node)
                self.assertNear(2.0, output, 0.00001)
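
As a companion to the test above, here is a minimal sketch, not part of the original test, of a reusable helper that loads a frozen GraphDef such as output_graph.pb back into a fresh graph for inference. It assumes the TF 1.x API used throughout these examples, and the path is illustrative:

import tensorflow as tf


def load_frozen_graph(path):
    """Reads a binary frozen GraphDef and imports it into a new Graph."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph


# Usage: fetch the frozen output node, exactly as the test does above.
graph = load_frozen_graph("/tmp/output_graph.pb")  # illustrative path
with tf.Session(graph=graph) as sess:
    print(sess.run(graph.get_tensor_by_name("output_node:0")))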
Example no. 23
    def testPartialRunIncompleteDirect(self):
        self.RunTestPartialRunIncomplete(session.Session())
Example no. 24
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants."""

    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    # Remove all explicit device specifications from the nodes. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Example no. 25
    def testManyPartialRunDirect(self):
        self.RunTestManyPartialRun(session.Session())
Example no. 26
    def testMinOption(self):
        ops.reset_default_graph()

        def check_min(nodes, mm=0, mam=0, mcm=0, mb=0, mpb=0, mrb=0, mob=0):
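            # Abbreviations: mm/mam/mcm are minimum total/accelerator/CPU
            # exec micros; mb/mpb/mrb/mob are minimum requested/peak/
            # residual/output bytes. A zero threshold disables that check.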
            for n in nodes:
                if mm > 0:
                    self.assertGreaterEqual(n.exec_micros, mm)
                if mam > 0:
                    self.assertGreaterEqual(n.accelerator_exec_micros, mam)
                if mcm > 0:
                    self.assertGreaterEqual(n.cpu_exec_micros, mcm)
                if mb > 0:
                    self.assertGreaterEqual(n.requested_bytes, mb)
                if mpb > 0:
                    self.assertGreaterEqual(n.peak_bytes, mpb)
                if mrb > 0:
                    self.assertGreaterEqual(n.residual_bytes, mrb)
                if mob > 0:
                    self.assertGreaterEqual(n.output_bytes, mob)
                check_min(n.children, mm, mam, mcm, mb, mpb, mrb, mob)

        with session.Session() as sess:
            x = lib.BuildSmallModel()
            sess.run(variables.global_variables_initializer())
            run_meta = config_pb2.RunMetadata()
            _ = sess.run(x,
                         options=config_pb2.RunOptions(
                             trace_level=config_pb2.RunOptions.FULL_TRACE),
                         run_metadata=run_meta)

            min_val = random.randint(0, 10000)

            opts = builder(builder.time_and_memory(
                min_micros=min_val)).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mm=min_val)

            opts = builder(
                builder.time_and_memory(min_accelerator_micros=min_val)
            ).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mam=min_val)

            opts = builder(builder.time_and_memory(
                min_cpu_micros=min_val)).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mcm=min_val)

            opts = builder(builder.time_and_memory(
                min_bytes=min_val)).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mb=min_val)

            opts = builder(builder.time_and_memory(
                min_peak_bytes=min_val)).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mpb=min_val)

            opts = builder(builder.time_and_memory(
                min_residual_bytes=min_val)).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mrb=min_val)

            opts = builder(builder.time_and_memory(
                min_output_bytes=min_val)).with_empty_output().build()
            tfprof_node = model_analyzer.profile(sess.graph,
                                                 run_meta=run_meta,
                                                 options=opts)
            check_min(tfprof_node.children, mob=min_val)
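
The seven blocks above all follow the same option-building pattern. A condensed sketch of that pattern, assuming the same builder and model_analyzer imports as the test (threshold values are illustrative), is:

# Keep only nodes above both a time threshold and a memory threshold.
opts = (builder(builder.time_and_memory(min_micros=1000,
                                        min_bytes=1 << 20))
        .with_empty_output()  # suppress stdout; inspect the returned proto
        .build())
tfprof_node = model_analyzer.profile(sess.graph, run_meta=run_meta,
                                     options=opts)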
Example no. 27
    def testPartialRunMissingPlaceholderFeedExceptionDirect(self):
        self.RunTestPartialRunMissingPlaceholderFeedException(
            session.Session())
Example no. 28
    def testSelectEverythingDetail(self):
        ops.reset_default_graph()
        dev = '/device:GPU:0' if test.is_gpu_available() else '/device:CPU:0'
        outfile = os.path.join(test.get_temp_dir(), 'dump')
        opts = (builder(
            builder.trainable_variables_parameter()).with_file_output(
                outfile).with_accounted_types(['.*']).select([
                    'micros', 'bytes', 'params', 'float_ops', 'occurrence',
                    'device', 'op_types', 'input_shapes'
                ]).build())

        with profile_context.ProfileContext(test.get_temp_dir(),
                                            trace_steps=[],
                                            dump_steps=[]) as pctx:
            with session.Session() as sess, ops.device(dev):
                x = lib.BuildSmallModel()

                sess.run(variables.global_variables_initializer())
                pctx.trace_next_step()
                pctx.dump_next_step()
                _ = sess.run(x)

                pctx.profiler.profile_name_scope(options=opts)

                with gfile.Open(outfile, 'r') as f:
                    # pylint: disable=line-too-long
                    dump_str = f.read()
                    outputs = dump_str.split('\n')

                    self.assertEqual(
                        outputs[0],
                        'node name | # parameters | # float_ops | requested bytes | total execution time | accelerator execution time | cpu execution time | assigned devices | op types | op count (run|defined) | input shapes'
                    )
                    for o in outputs[1:]:
                        if o.find('Conv2D ') > 0:
                            metrics = o[o.find('(') + 1:o.find(')')].split(',')
                            # Make sure time is profiled.
                            gap = 1 if test.is_gpu_available() else 2
                            for i in range(3, 6, gap):
                                mat = re.search('(.*)[um]s/(.*)[um]s',
                                                metrics[i])
                                self.assertGreater(float(mat.group(1)), 0.0)
                                self.assertGreater(float(mat.group(2)), 0.0)
                            # Make sure device is profiled.
                            if test.is_gpu_available():
                                self.assertTrue(metrics[6].find('gpu') > 0)
                                self.assertFalse(metrics[6].find('cpu') > 0)
                            else:
                                self.assertFalse(metrics[6].find('gpu') > 0)
                                self.assertTrue(metrics[6].find('cpu') > 0)
                            # Make sure float_ops is profiled.
                            mat = re.search('(.*)k/(.*)k flops',
                                            metrics[1].strip())
                            self.assertGreater(float(mat.group(1)), 0.0)
                            self.assertGreater(float(mat.group(2)), 0.0)
                            # Make sure op_count is profiled.
                            self.assertEqual(metrics[8].strip(), '1/1|1/1')
                            # Make sure input_shapes is profiled.
                            self.assertEqual(metrics[9].strip(),
                                             '0:2x6x6x3|1:3x3x3x6')

                        if o.find('DW (3x3x3x6') > 0:
                            metrics = o[o.find('(') + 1:o.find(')')].split(',')
                            mat = re.search('(.*)/(.*) params',
                                            metrics[1].strip())
                            self.assertGreater(float(mat.group(1)), 0.0)
                            self.assertGreater(float(mat.group(2)), 0.0)
                    # pylint: enable=line-too-long

        # Test that profiler restored from profile file gives the same result.
        gfile.Remove(outfile)
        profile_file = os.path.join(test.get_temp_dir(), 'profile_1')
        with lib.ProfilerFromFile(profile_file) as profiler:
            profiler.profile_name_scope(options=opts)
            with gfile.Open(outfile, 'r') as f:
                self.assertEqual(dump_str, f.read())
Example no. 29
    def testPartialRunAlreadyFetchedDirect(self):
        self.RunTestPartialRunAlreadyFetched(session.Session())
Example no. 30
    def testFunctionalModelMultipleInputs(self):
        """Test a Functional tf.keras model with multiple inputs and outputs."""
        with session.Session().as_default():
            a = keras.layers.Input(shape=(3, ), name='input_a')
            b = keras.layers.Input(shape=(3, ), name='input_b')
            dense = keras.layers.Dense(4, name='dense')
            c = dense(a)
            d = dense(b)
            e = keras.layers.Dropout(0.5, name='dropout')(c)

            model = keras.models.Model([a, b], [d, e])
            model.compile(loss=keras.losses.MSE,
                          optimizer=keras.optimizers.RMSprop(),
                          metrics=[keras.metrics.mae],
                          loss_weights=[1., 0.5])

            input_a_np = np.random.random((10, 3))
            input_b_np = np.random.random((10, 3))
            output_d_np = np.random.random((10, 4))
            output_e_np = np.random.random((10, 4))
            model.train_on_batch([input_a_np, input_b_np],
                                 [output_d_np, output_e_np])

            model.predict([input_a_np, input_b_np], batch_size=5)
            fd, keras_file = tempfile.mkstemp('.h5')
            try:
                keras.models.save_model(model, keras_file)
            finally:
                os.close(fd)

        # Convert to TFLite model.
        converter = lite.TocoConverter.from_keras_model_file(keras_file)
        tflite_model = converter.convert()
        self.assertTrue(tflite_model)

        os.remove(keras_file)

        # Check values from converted model.
        interpreter = Interpreter(model_content=tflite_model)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        self.assertEqual(2, len(input_details))
        self.assertEqual('input_a', input_details[0]['name'])
        self.assertEqual(np.float32, input_details[0]['dtype'])
        self.assertTrue(([1, 3] == input_details[0]['shape']).all())
        self.assertEqual((0., 0.), input_details[0]['quantization'])

        self.assertEqual('input_b', input_details[1]['name'])
        self.assertEqual(np.float32, input_details[1]['dtype'])
        self.assertTrue(([1, 3] == input_details[1]['shape']).all())
        self.assertEqual((0., 0.), input_details[1]['quantization'])

        output_details = interpreter.get_output_details()
        self.assertEqual(2, len(output_details))
        self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
        self.assertEqual(np.float32, output_details[0]['dtype'])
        self.assertTrue(([1, 4] == output_details[0]['shape']).all())
        self.assertEqual((0., 0.), output_details[0]['quantization'])

        self.assertEqual('dropout/Identity', output_details[1]['name'])
        self.assertEqual(np.float32, output_details[1]['dtype'])
        self.assertTrue(([1, 4] == output_details[1]['shape']).all())
        self.assertEqual((0., 0.), output_details[1]['quantization'])
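
The assertions above only inspect tensor metadata. A minimal sketch of actually running the converted model, reusing the interpreter, input_details, output_details, and np from the test, would be:

# Feed both inputs, run the model, and read back both outputs.
input_a = np.random.random((1, 3)).astype(np.float32)
input_b = np.random.random((1, 3)).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_a)
interpreter.set_tensor(input_details[1]['index'], input_b)
interpreter.invoke()
dense_out = interpreter.get_tensor(output_details[0]['index'])    # shape (1, 4)
dropout_out = interpreter.get_tensor(output_details[1]['index'])  # shape (1, 4)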