コード例 #1
0
ファイル: test_log_writer.py プロジェクト: SenWu/emmental
def test_tensorboard_writer(caplog):
    """Run TensorBoardWriter through scalar logging, config dump and close."""
    caplog.set_level(logging.INFO)

    # Reset emmental's Meta and initialise it before creating the writer.
    emmental.Meta.reset()
    emmental.init()

    writer = TensorBoardWriter()

    # Log two scalar points, one per step.
    for scalar_name, scalar_value, scalar_step in (
        ("step 1", 0.1, 1),
        ("step 2", 0.2, 2),
    ):
        writer.add_scalar(name=scalar_name, value=scalar_value, step=scalar_step)

    # Write the config through the writer, then read it back from the log path.
    config_filename = "config.yaml"
    writer.write_config(config_filename)

    config_path = os.path.join(emmental.Meta.log_path, config_filename)
    with open(config_path, "r") as config_file:
        dumped = yaml.load(config_file, Loader=yaml.FullLoader)

    # Check a few entries of the dumped configuration.
    assert dumped["meta_config"]["verbose"] is True
    assert dumped["logging_config"]["counter_unit"] == "epoch"
    assert dumped["logging_config"]["checkpointing"] is False

    writer.write_log()
    writer.close()
コード例 #2
0
class LoggingManager(object):
    r"""A class to manage logging during training progress.

    It counts the samples, batches and epochs seen so far, decides when an
    evaluation or a checkpoint is due, and forwards metrics to the configured
    log writer. The writer is optional: when the configured writer option is
    ``None`` no writer is created and the writer-related calls are no-ops.

    Args:
      n_batches_per_epoch(int): Total number batches per epoch.

    Raises:
      ValueError: If the configured counter unit or writer option is not
        recognized.

    """

    def __init__(self, n_batches_per_epoch: int) -> None:
        self.n_batches_per_epoch = n_batches_per_epoch

        logging_config = Meta.config["logging_config"]

        # Set up evaluation/checkpointing unit (sample, batch, epoch)
        self.counter_unit = logging_config["counter_unit"]

        if self.counter_unit not in ["sample", "batch", "epoch"]:
            raise ValueError(f"Unrecognized unit: {self.counter_unit}")

        # Set up evaluation frequency
        self.evaluation_freq = logging_config["evaluation_freq"]
        logger.info(f"Evaluating every {self.evaluation_freq} {self.counter_unit}.")

        if logging_config["checkpointing"]:
            self.checkpointing = True

            # Set up checkpointing frequency. It is counted in number of
            # evaluations (see trigger_checkpointing), hence the product with
            # the evaluation frequency in the message below.
            self.checkpointing_freq = int(
                logging_config["checkpointer_config"]["checkpoint_freq"]
            )
            logger.info(
                f"Checkpointing every "
                f"{self.checkpointing_freq * self.evaluation_freq} {self.counter_unit}."
            )

            # Set up checkpointer
            self.checkpointer = Checkpointer()
        else:
            self.checkpointing = False
            logger.info("No checkpointing.")

        # Set up number of samples passed since last evaluation/checkpointing and
        # total number of samples passed since learning process
        self.sample_count: int = 0
        self.sample_total: int = 0

        # Set up number of batches passed since last evaluation/checkpointing and
        # total number of batches passed since learning process
        self.batch_count: int = 0
        self.batch_total: int = 0

        # Set up number of epochs passed since last evaluation/checkpointing and
        # total number of epochs passed since learning process
        self.epoch_count: Union[float, int] = 0
        self.epoch_total: Union[float, int] = 0

        # Set up number of unit passed since last evaluation/checkpointing and
        # total number of unit passed since learning process
        self.unit_count: Union[float, int] = 0
        self.unit_total: Union[float, int] = 0

        # Set up count that triggers the evaluation since last checkpointing
        self.trigger_count = 0

        # Set up log writer (None means logging to a writer is disabled)
        writer_opt = logging_config["writer_config"]["writer"]

        if writer_opt is None:
            self.writer = None
        elif writer_opt == "json":
            self.writer = LogWriter()
        elif writer_opt == "tensorboard":
            self.writer = TensorBoardWriter()
        else:
            raise ValueError(f"Unrecognized writer option '{writer_opt}'")

    def update(self, batch_size: int) -> None:
        r"""Update the counter.

        Args:
          batch_size(int): The number of the samples in the batch.

        """
        # Update number of samples
        self.sample_count += batch_size
        self.sample_total += batch_size

        # Update number of batches
        self.batch_count += 1
        self.batch_total += 1

        # Update number of epochs (fractional while an epoch is in progress)
        self.epoch_count = self.batch_count / self.n_batches_per_epoch
        self.epoch_total = self.batch_total / self.n_batches_per_epoch

        # Mirror the counters of the configured unit into the generic ones.
        # The branches are mutually exclusive, so use a single chain.
        if self.counter_unit == "sample":
            self.unit_count = self.sample_count
            self.unit_total = self.sample_total
        elif self.counter_unit == "batch":
            self.unit_count = self.batch_count
            self.unit_total = self.batch_total
        elif self.counter_unit == "epoch":
            self.unit_count = self.epoch_count
            self.unit_total = self.epoch_total

    def trigger_evaluation(self) -> bool:
        r"""Check if triggers the evaluation.

        When the evaluation is due, the trigger count is incremented and the
        since-last-evaluation counters are reset.

        Returns:
          bool: Whether the unit count reached the evaluation frequency.

        """
        satisfied = self.unit_count >= self.evaluation_freq
        if satisfied:
            self.trigger_count += 1
            self.reset()
        return satisfied

    def trigger_checkpointing(self) -> bool:
        r"""Check if triggers the checkpointing.

        When the checkpointing is due, the trigger count is reset to zero.

        Returns:
          bool: False when checkpointing is disabled; otherwise whether the
            number of triggered evaluations reached the checkpointing frequency.

        """
        if not self.checkpointing:
            return False
        satisfied = self.trigger_count >= self.checkpointing_freq
        if satisfied:
            self.trigger_count = 0
        return satisfied

    def reset(self) -> None:
        r"""Reset the counter.

        Only the since-last-evaluation counters are cleared; the running
        totals are left untouched.

        """
        self.sample_count = 0
        self.batch_count = 0
        self.epoch_count = 0
        self.unit_count = 0

    def write_log(self, metric_dict: Dict[str, float]) -> None:
        r"""Write the metrics to the log.

        Each metric is written with the total batch count as its step. Nothing
        is written when no writer is configured.

        Args:
          metric_dict(dict): The metric dict.

        """
        # The writer is None when the writer option is None; skip instead of
        # failing with an AttributeError.
        if self.writer is None:
            return

        for metric_name, metric_value in metric_dict.items():
            self.writer.add_scalar(metric_name, metric_value, self.batch_total)

    def checkpoint_model(
        self,
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        r"""Checkpoint the model.

        The checkpointer only exists when checkpointing is enabled in the
        config.

        Args:
          model(EmmentalModel): The model to checkpoint.
          optimizer(Optimizer): The optimizer used during training process.
          lr_scheduler(_LRScheduler): Learning rate scheduler.
          metric_dict(dict): the metric dict.

        """
        self.checkpointer.checkpoint(
            self.unit_total, model, optimizer, lr_scheduler, metric_dict
        )

    def close(self, model: EmmentalModel) -> EmmentalModel:
        r"""Close the writer and reload the model if necessary.

        Args:
          model(EmmentalModel): The trained model.

        Returns:
          EmmentalModel: The model returned by the checkpointer's
            ``load_best_model`` when checkpointing is enabled, otherwise the
            given model unchanged.

        """
        # The writer is None when the writer option is None; there is nothing
        # to close in that case.
        if self.writer is not None:
            self.writer.close()
        if self.checkpointing:
            model = self.checkpointer.load_best_model(model)
            self.checkpointer.clear()
        return model