Ejemplo n.º 1
0
    def testSaveLoadCheckpoint(self):
        """Round-trip non-embedding variables through a checkpoint.

        Saves a Parameters snapshot of the custom model's trainable
        variables with CheckpointSaver, restores it from the version-0
        directory, and verifies every variable survives unchanged.
        """
        initial_vars = m["custom_model"]().trainable_variables
        with tempfile.TemporaryDirectory() as tempdir:
            ckpt_dir = os.path.join(tempdir, "testSaveLoadCheckpoint")
            os.makedirs(ckpt_dir)
            saver = CheckpointSaver(ckpt_dir, 3, 5, False)
            self.assertTrue(saver.is_enabled())

            params = Parameters()
            for variable in initial_vars:
                params.non_embedding_params[variable.name] = variable

            saver.save(0, params.to_model_pb(), False)

            version_dir = os.path.join(ckpt_dir, "version-0")
            restored = CheckpointSaver.restore_params_from_checkpoint(
                version_dir, 0, 1)
            self.assertEqual(restored.version, params.version)
            for name, variable in params.non_embedding_params.items():
                self.assertTrue(
                    np.array_equal(
                        variable.numpy(),
                        restored.non_embedding_params[name].numpy(),
                    ))
Ejemplo n.º 2
0
def _get_trained_params_from_checkpoint(checkpoint_dir):
    """Get parameters from a checkpoint directory saved by ElasticDL.

    Args:
        checkpoint_dir: path of a checkpoint directory written by ElasticDL.

    Returns:
        A dict mapping parameter names to their values, containing both
        the non-embedding parameters and the embedding tables.
    """
    parameters = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir, 0, 1)

    # Copy the dict so adding the embedding tables below does not mutate
    # the restored Parameters object's internal non_embedding_params.
    trained_params = dict(parameters.non_embedding_params)
    trained_params.update(parameters.embedding_params)
    return trained_params
Ejemplo n.º 3
0
def _get_trained_params_from_checkpoint(checkpoint_dir):
    """Get parameters from a checkpoint directory saved by ElasticDL.

    Args:
        checkpoint_dir: path of a checkpoint directory written by ElasticDL.

    Returns:
        A dict mapping variable names to their values. Embedding tables are
        keyed under the tf.keras.layers.Embedding variable naming scheme.
    """
    parameters = CheckpointSaver.restore_params_from_checkpoint(
        checkpoint_dir, 0, 1)

    # Copy the dict so the additions below do not mutate the restored
    # Parameters object's internal non_embedding_params as a side effect.
    trained_params = dict(parameters.non_embedding_params)
    for name, table in parameters.embedding_params.items():
        # The name of variable in a tf.keras.layers.Embedding layer is
        # "{layer_name}/embeddings:0"
        var_name = name + "/embeddings:0"
        trained_params[var_name] = table
    return trained_params
Ejemplo n.º 4
0
 def testSaveLoadCheckpoint(self):
     """Save self.params to a checkpoint and verify the restore matches."""
     with tempfile.TemporaryDirectory() as tempdir:
         self.params.version = 0
         ckpt_dir = save_variables_to_checkpoint(tempdir, self.params)
         restored = CheckpointSaver.restore_params_from_checkpoint(
             os.path.join(ckpt_dir, "version-0"), 0, 1)
         self.assertEqual(restored.version, self.params.version)
         for name, variable in self.params.non_embedding_params.items():
             self.assertTrue(
                 np.array_equal(
                     variable.numpy(),
                     restored.non_embedding_params[name].numpy(),
                 ))
Ejemplo n.º 5
0
    def _restore_params_from_checkpoint(self, checkpoint_dir_for_init):
        """Restore parameters from a checkpoint directory for the PS instance.

        Args:
            checkpoint_dir_for_init: directory of the checkpoint to restore
                from. If falsy, restoration is skipped with a log message.

        Raises:
            ValueError: if the directory is not a valid checkpoint.
        """
        if not checkpoint_dir_for_init:
            self.logger.info("checkpoint directory for init is None")
            return

        if not CheckpointSaver.check_checkpoint_valid(checkpoint_dir_for_init):
            raise ValueError("Invalid checkpoint directory")

        self.parameters = CheckpointSaver.restore_params_from_checkpoint(
            checkpoint_dir_for_init, self.ps_id, self.num_ps_pods)
        self.parameters.init_status = True
        # Pass the version as a lazy %-style argument so formatting only
        # happens if the record is actually emitted.
        self.logger.info("The version of restored parameters is %d",
                         self.parameters.version)