def test_save_and_load_roundtrip(self):
    """Saving a state and loading it back must yield an equal state."""
    original = build_fake_state()
    checkpoint_path = os.path.join(self.get_temp_dir(), 'ckpt_1')
    checkpoint_utils.save(original, checkpoint_path)

    # The original object doubles as the structural template for `load`.
    restored = checkpoint_utils.load(checkpoint_path, original)
    self.assertEqual(original, restored)
Esempio n. 2
0
    def test_load_latest_state(self):
        """Checks that the newest checkpoint is located and restored.

        Two checkpoints are written ('ckpt_1', then 'ckpt_2'). The directory
        reported by `checkpoint_utils.latest_checkpoint` is expected to be
        'ckpt_2', and loading it must reproduce the second object's values.
        """
        optimizer_factory = functools.partial(
            tf.keras.optimizers.SGD, learning_rate=0.1, momentum=0.9)

        process = tff.learning.build_federated_averaging_process(
            models.model_fn, server_optimizer_fn=optimizer_factory)
        initial_state = process.initialize()

        # TODO(b/130724878): These conversions should not be needed.
        first = Obj.from_anon_tuple(initial_state, 1)
        checkpoint_utils.save(
            first, os.path.join(self.get_temp_dir(), 'ckpt_1'))

        # TODO(b/130724878): These conversions should not be needed.
        second = Obj.from_anon_tuple(initial_state, 2)
        checkpoint_utils.save(
            second, os.path.join(self.get_temp_dir(), 'ckpt_2'))

        latest_dir = checkpoint_utils.latest_checkpoint(self.get_temp_dir())

        # `first` is only the structural template used for deserialization.
        restored = checkpoint_utils.load(latest_dir, first)

        expected_dir = os.path.join(self.get_temp_dir(), 'ckpt_2')
        self.assertEqual(expected_dir, latest_dir)
        self.assertAllClose(tf.nest.flatten(second),
                            tf.nest.flatten(restored))
Esempio n. 3
0
def read_checkpoint(filepath, example_server_state):
    """Read a previously saved experiment state to memory.

    Args:
      filepath: Checkpoint location passed to `checkpoint_utils.load`.
      example_server_state: Server state used to build the template that
        `checkpoint_utils.load` restores into.

    Returns:
      A tuple `(server_state, metrics_dict, round_num)`:
        * `server_state`: the restored server state.
        * `metrics_dict`: a `pandas.DataFrame` parsed from the stored metrics
          CSV (first row as header, first column as index); an empty
          `DataFrame` when the stored CSV string is empty.
        * `round_num`: the stored round number converted via `.numpy()`.
    """
    template = ExperimentState(round_num=0,
                               metrics_csv_string='',
                               server_state=example_server_state)
    experiment_state = checkpoint_utils.load(filepath, template)

    # NOTE(review): assumes the restored field is a string tensor, so
    # `.numpy()` yields bytes suitable for `io.BytesIO` -- confirm.
    csv_bytes = experiment_state.metrics_csv_string.numpy()
    if csv_bytes:
        metrics_dict = pd.read_csv(io.BytesIO(csv_bytes),
                                   header=0,
                                   index_col=0,
                                   engine='c')
    else:
        # The template's default is '' -- a state saved before any metrics
        # were written would make `pd.read_csv` raise `EmptyDataError`.
        metrics_dict = pd.DataFrame()

    return (experiment_state.server_state, metrics_dict,
            experiment_state.round_num.numpy())
Esempio n. 4
0
def read_checkpoint(checkpoint_dir, server_state):
  """Load the experiment state stored in `checkpoint_dir`.

  Args:
    checkpoint_dir: Location handed to `checkpoint_utils.load`.
    server_state: Server state used to build the template that
      `checkpoint_utils.load` restores into.

  Returns:
    A `(server_state, round_num)` tuple taken from the restored state.
  """
  restored = checkpoint_utils.load(
      checkpoint_dir,
      ExperimentState(round_num=0, server_state=server_state))
  return restored.server_state, restored.round_num