コード例 #1
0
    def testBfloat16Reload(self):
        """Round-trips a float32 checkpoint value into a bfloat16 variable.

        A float32 resource variable named "v0" is saved in one graph. A
        bfloat16 resource variable with the same name is then created in a
        fresh graph and restored from that checkpoint. The restore goes
        through a `Bfloat16VariableSaveable` built with `tf.float32`.
        """
        ckpt_prefix = os.path.join(self.get_temp_dir(), "bfloat16_restore")
        expected = 0.99

        # Step 1: write a float32 resource variable to disk.
        save_graph = tf.Graph()
        with self.session(graph=save_graph) as save_sess:
            f32_var = tf.Variable(
                expected, name="v0", dtype=tf.float32, use_resource=True)
            tf.global_variables_initializer().run()
            self.assertAlmostEqual(expected, f32_var.eval())

            writer = tf.train.Saver({"v0": f32_var}, restore_sequentially=True)
            written = writer.save(save_sess, ckpt_prefix)
            self.assertEqual(ckpt_prefix, written)

        # Step 2: restore that value into a bfloat16 variable of the same name.
        restore_graph = tf.Graph()
        with self.session(graph=restore_graph) as restore_sess:
            bf16_var = tf.Variable(
                0.0, name="v0", dtype=tf.bfloat16, use_resource=True)
            tf.global_variables_initializer().run()
            self.assertAlmostEqual(0.0, bf16_var.eval())

            # The saveable is registered under the checkpoint key "v0" and
            # built with tf.float32, presumably the dtype of the stored tensor.
            bf16_saveable = bfloat16_variables.Bfloat16VariableSaveable(
                bf16_var, tf.float32, "", "v0")
            reader = tf.train.Saver(
                {"v0": bf16_saveable}, restore_sequentially=True)
            reader.restore(restore_sess, ckpt_prefix)
            # NOTE(review): 0.99 is not exactly representable in bfloat16.
            # This check relies on how the bfloat16 result of eval() compares
            # with a Python float under assertAlmostEqual — confirm intended.
            self.assertAlmostEqual(expected, bf16_var.eval())
コード例 #2
0
def GetVariablesWithBfloat16Overrides(variables_to_load):
  """Returns a name -> saver-entry dict that loads bf16 variables via overrides.

  Args:
    variables_to_load: iterable of variables to be loaded.

  Returns:
    A dict keyed by `_GetVarName(v)` for each input variable `v`, in input
    order:
    - A variable whose `dtype` equals `tf.bfloat16` is wrapped in a
      `bfloat16_variables.Bfloat16VariableSaveable` constructed with
      `tf.float32`, an empty slice spec and that name.
    - Every other variable is stored unchanged.
    If two variables map to the same name, the later one wins.
  """
  # TODO(rohananil): Add support for PartitionedVariables if there is
  # demand.

  def _NamedEntry(var):
    """Returns the (name, saver entry) pair for a single variable."""
    var_name = _GetVarName(var)
    # NOTE(review): this compares `var.dtype` directly. A reference
    # (non-resource) variable reporting a `bfloat16_ref` dtype would not
    # match and would be passed through unwrapped — confirm if that matters.
    entry = (
        bfloat16_variables.Bfloat16VariableSaveable(
            var, tf.float32, '', var_name)
        if var.dtype == tf.bfloat16 else var)
    return var_name, entry

  return dict(_NamedEntry(var) for var in variables_to_load)