示例#1
0
 def testCheckpointAveraging(self):
     """Averaging two checkpoints yields the mean of float variables while
     int64 variables keep the value from the most recent checkpoint."""
     ckpt_dir = os.path.join(self.get_temp_dir(), "ckpt")
     os.makedirs(ckpt_dir)
     # (step, value of "x", value of "words_per_sec/features_init")
     specs = [
         (10, np.zeros((2, 3), dtype=np.float32), np.int64(42)),
         (20, np.ones((2, 3), dtype=np.float32), np.int64(89)),
     ]
     saved = []
     for step, x_value, wps_value in specs:
         saved.append(
             self._generateCheckpoint(
                 ckpt_dir,
                 step, {
                     "x": x_value,
                     "words_per_sec/features_init": wps_value
                 },
                 last_checkpoints=saved))
     averaged_dir = os.path.join(ckpt_dir, "avg")
     checkpoint.average_checkpoints(ckpt_dir, averaged_dir)
     variables = checkpoint.get_checkpoint_variables(averaged_dir)
     # Step-like int64 variables are carried over, not averaged.
     self.assertEqual(variables["global_step"].dtype, np.int64)
     self.assertEqual(variables["global_step"], 20)
     self.assertEqual(variables["words_per_sec/features_init"].dtype,
                      np.int64)
     self.assertEqual(variables["words_per_sec/features_init"], 89)
     # Float variables are averaged: mean(zeros, ones) == 0.5.
     self.assertAllEqual(variables["x"], np.full((2, 3),
                                                 0.5,
                                                 dtype=np.float32))
示例#2
0
    def testCheckpointAveraging(self):
        """Averaging checkpoints saved with constant weights produces a
        checkpoint whose weights equal the mean of the saved values and whose
        variable list matches the latest source checkpoint."""
        model = _DummyModel()
        optimizer = tf.keras.optimizers.Adam()

        @tf.function
        def _run_one_step():
            # One forward/backward pass so model and optimizer variables exist.
            inputs = tf.random.uniform([4, 10])
            outputs = model(inputs)
            loss = tf.reduce_mean(outputs)
            grads = optimizer.get_gradients(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        def _fill(variable, value):
            variable.assign(tf.ones_like(variable) * value)

        def _is_constant(variable, value):
            mismatches = tf.where(tf.not_equal(variable, value))
            return tf.size(mismatches).numpy() == 0

        def _variable_names(prefix):
            return [name for name, _ in tf.train.list_variables(prefix)]

        _run_one_step()

        # Save one checkpoint per step, with every weight set to the step value.
        step_values = [10, 20, 30, 40]
        expected_mean = sum(step_values) / len(step_values)
        src_dir = os.path.join(self.get_temp_dir(), "src")
        ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
        manager = tf.train.CheckpointManager(
            ckpt, src_dir, max_to_keep=len(step_values))
        for step in step_values:
            _fill(model.layers[0].kernel, step)
            _fill(model.layers[0].bias, step)
            manager.save(checkpoint_number=step)

        dst_dir = os.path.join(self.get_temp_dir(), "dst")
        checkpoint_util.average_checkpoints(
            src_dir, dst_dir, dict(model=model, optimizer=optimizer))
        averaged_prefix = tf.train.latest_checkpoint(dst_dir)
        self.assertIsNotNone(averaged_prefix)
        # The averaged checkpoint inherits the step of the last source one.
        self.assertEqual(
            checkpoint_util.get_step_from_checkpoint_prefix(averaged_prefix),
            step_values[-1])
        ckpt.restore(averaged_prefix)
        self.assertTrue(_is_constant(model.layers[0].kernel, expected_mean))
        self.assertTrue(_is_constant(model.layers[0].bias, expected_mean))
        self.assertListEqual(
            _variable_names(averaged_prefix),
            _variable_names(manager.latest_checkpoint))
示例#3
0
    def average_checkpoints(self, output_dir, max_count=8):
        """Averages the last training checkpoints into a single checkpoint.

        Args:
          output_dir: The directory that will contain the averaged checkpoint.
          max_count: The maximum number of checkpoints to average.

        Returns:
          The path to the directory containing the averaged checkpoint.
        """
        config = self._finalize_config()
        model = self._init_model(config)
        optimizer = model.get_optimizer()
        checkpoint = checkpoint_util.Checkpoint.from_config(
            config, model, optimizer=optimizer)
        checkpoint.restore()
        # Make sure all model and optimizer variables exist before averaging.
        model.create_variables(optimizer=optimizer)
        output_dir = checkpoint_util.average_checkpoints(
            checkpoint.model_dir,
            output_dir,
            dict(model=model, optimizer=optimizer),
            max_count=max_count)
        _forward_model_description(self.model_dir, output_dir)
        # Point the configuration at the averaged checkpoint for later runs.
        self._config["model_dir"] = output_dir
        return output_dir
示例#4
0
def main():
    """Command line entry point for averaging checkpoints."""
    tf.logging.set_verbosity(tf.logging.INFO)

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--model_dir",
        required=True,
        help="The model directory containing the checkpoints.")
    parser.add_argument(
        "--output_dir",
        required=True,
        help="The output directory where the averaged checkpoint will be saved."
    )
    parser.add_argument(
        "--max_count",
        type=int,
        default=8,
        help="The maximal number of checkpoints to average.")
    arguments = parser.parse_args()
    average_checkpoints(
        arguments.model_dir, arguments.output_dir, max_count=arguments.max_count)
示例#5
0
    def average_checkpoints(self, output_dir, max_count=8):
        """Averages checkpoints.

        Args:
          output_dir: The directory that will contain the averaged checkpoint.
          max_count: The maximum number of checkpoints to average.

        Returns:
          The path to the directory containing the averaged checkpoint.
        """
        # Delegate to the checkpoint module, reading the source checkpoints
        # from the configured model directory.
        # NOTE(review): session_config is forwarded as-is; presumably a
        # tf.ConfigProto used when loading the variables — confirm upstream.
        return checkpoint.average_checkpoints(
            self._config["model_dir"],
            output_dir,
            max_count=max_count,
            session_config=self._session_config)
示例#6
0
    def average_checkpoints(self, output_dir, max_count=8):
        """Averages checkpoints.

        Args:
          output_dir: The directory that will contain the averaged checkpoint.
          max_count: The maximum number of checkpoints to average.

        Returns:
          The path to the directory containing the averaged checkpoint.
        """
        checkpoint, _ = self._init_run()
        # Restore the latest weights so the averaged values start from the
        # currently trained model.
        checkpoint.restore()
        model = checkpoint.model
        optimizer = checkpoint.optimizer
        # Ensure all model and optimizer variables are created before the
        # trackables are handed to the averaging routine.
        model.create_variables(optimizer=optimizer)
        trackables = dict(model=model, optimizer=optimizer)
        return checkpoint_util.average_checkpoints(checkpoint.model_dir,
                                                   output_dir,
                                                   trackables,
                                                   max_count=max_count)