def testBfloat16Reload(self):
  checkpoint_path = os.path.join(self.get_temp_dir(), "bfloat16_restore")

  # Create a resource variable of type tf.float32 and save it to disk.
  g_for_save_graph = tf.Graph()
  fl = 0.99
  with self.session(graph=g_for_save_graph) as sess:
    v0 = tf.Variable(fl, name="v0", dtype=tf.float32, use_resource=True)
    tf.global_variables_initializer().run()
    self.assertAlmostEqual(fl, v0.eval())

    saver = tf.train.Saver({"v0": v0}, restore_sequentially=True)
    val = saver.save(sess, checkpoint_path)
    self.assertEqual(checkpoint_path, val)

  # Restore the variable as bfloat16.
  g_for_restore_graph = tf.Graph()
  with self.session(graph=g_for_restore_graph) as sess:
    v0 = tf.Variable(0.0, name="v0", dtype=tf.bfloat16, use_resource=True)
    tf.global_variables_initializer().run()
    self.assertAlmostEqual(0.0, v0.eval())

    saveable = bfloat16_variables.Bfloat16VariableSaveable(
        v0, tf.float32, "", "v0")
    saver = tf.train.Saver({"v0": saveable}, restore_sequentially=True)
    saver.restore(sess, checkpoint_path)
    self.assertAlmostEqual(fl, v0.eval())
def GetVariablesWithBfloat16Overrides(variables_to_load):
  """Returns a dictionary containing overrides to load variables as bf16."""
  saver_dict = {}
  for v in variables_to_load:
    var_name = _GetVarName(v)
    if v.dtype == tf.bfloat16:
      # TODO(rohananil): Add support for PartitionedVariables if there is
      # demand.
      savable = bfloat16_variables.Bfloat16VariableSaveable(
          v, tf.float32, '', var_name)
      saver_dict[var_name] = savable
    else:
      saver_dict[var_name] = v
  return saver_dict
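# A minimal usage sketch for GetVariablesWithBfloat16Overrides (not part of
# the original file): it assumes a TF1 graph/session setup, a float32
# checkpoint at `checkpoint_path`, and a hypothetical `build_model()` that
# constructs the graph with some variables created in bfloat16.
def _RestoreIntoBfloat16Model(checkpoint_path):
  with tf.Graph().as_default():
    build_model()  # hypothetical: creates variables, some with dtype bfloat16

    # Map every bfloat16 variable to a SaveableObject that casts the stored
    # float32 value at restore time; float32 variables pass through unchanged.
    saver_dict = GetVariablesWithBfloat16Overrides(tf.global_variables())
    saver = tf.train.Saver(saver_dict)

    with tf.Session() as sess:
      saver.restore(sess, checkpoint_path)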