def test_save_and_load_roundtrip(self):
  """Saving a state and loading it back yields an equal state."""
  original = build_fake_state()
  ckpt_path = os.path.join(self.get_temp_dir(), 'ckpt_1')
  checkpoint_utils.save(original, ckpt_path)
  # The original state doubles as the structure template for loading.
  restored = checkpoint_utils.load(ckpt_path, original)
  self.assertEqual(original, restored)
def test_load_latest_state(self):
  """Checks latest_checkpoint picks 'ckpt_2' and that it restores obj_2."""
  server_optimizer_fn = functools.partial(
      tf.keras.optimizers.SGD, learning_rate=0.1, momentum=0.9)
  process = tff.learning.build_federated_averaging_process(
      models.model_fn, server_optimizer_fn=server_optimizer_fn)
  initial_server_state = process.initialize()
  # get_temp_dir() is stable within a single test, so compute it once.
  root_dir = self.get_temp_dir()

  # TODO(b/130724878): These conversions should not be needed.
  first = Obj.from_anon_tuple(initial_server_state, 1)
  checkpoint_utils.save(first, os.path.join(root_dir, 'ckpt_1'))

  # TODO(b/130724878): These conversions should not be needed.
  second = Obj.from_anon_tuple(initial_server_state, 2)
  checkpoint_utils.save(second, os.path.join(root_dir, 'ckpt_2'))

  latest_dir = checkpoint_utils.latest_checkpoint(root_dir)
  # `first` serves only as the structure template for restoring.
  restored = checkpoint_utils.load(latest_dir, first)
  self.assertEqual(os.path.join(root_dir, 'ckpt_2'), latest_dir)
  self.assertAllClose(tf.nest.flatten(second), tf.nest.flatten(restored))
def read_checkpoint(filepath, example_server_state):
  """Read a previously saved experiment state to memory.

  Args:
    filepath: Location of the checkpoint, passed to `checkpoint_utils.load`.
    example_server_state: A server state whose structure is used as the
      template for restoring the saved server state.

  Returns:
    A tuple `(server_state, metrics, round_num)`:
      server_state: the restored server state.
      metrics: a `pandas.DataFrame` parsed from the checkpointed CSV text
        (first column used as the index); an empty DataFrame if the saved
        CSV text is empty.
      round_num: the restored round counter, converted via `.numpy()`.
  """
  template = ExperimentState(
      round_num=0, metrics_csv_string='', server_state=example_server_state)
  experiment_state = checkpoint_utils.load(filepath, template)
  # `.numpy()` yields the raw bytes of the CSV text, hence BytesIO below.
  csv_bytes = experiment_state.metrics_csv_string.numpy()
  try:
    metrics = pd.read_csv(
        io.BytesIO(csv_bytes), header=0, index_col=0, engine='c')
  except pd.errors.EmptyDataError:
    # An empty CSV string (e.g. the '' default used in the template above)
    # has no columns to parse; treat it as "no metrics recorded yet" rather
    # than failing the whole restore.
    metrics = pd.DataFrame()
  return (experiment_state.server_state, metrics,
          experiment_state.round_num.numpy())
def read_checkpoint(checkpoint_dir, server_state):
  """Read a previously saved experiment state to memory.

  Args:
    checkpoint_dir: Checkpoint location passed to `checkpoint_utils.load`.
    server_state: Server state used as the structural template for loading.

  Returns:
    A `(server_state, round_num)` pair taken from the restored
    `ExperimentState`; `round_num` is returned exactly as loaded.
  """
  restored = checkpoint_utils.load(
      checkpoint_dir,
      ExperimentState(round_num=0, server_state=server_state))
  return restored.server_state, restored.round_num