Example #1
0
    def __init__(self, n_batches_per_epoch: int) -> None:
        """Configure evaluation, checkpointing and log writing from ``Meta.config``.

        Every setting is read from ``Meta.config["logging_config"]`` and all
        progress counters start at zero.

        Args:
            n_batches_per_epoch: Number of batches per epoch, stored unchanged
                on the instance.

        Raises:
            ValueError: If ``counter_unit`` is not ``"sample"``, ``"batch"`` or
                ``"epoch"``, or if ``writer_config["writer"]`` is not ``None``,
                ``"json"`` or ``"tensorboard"``.
        """
        self.n_batches_per_epoch = n_batches_per_epoch
        logging_config = Meta.config["logging_config"]

        # Unit in which evaluation/checkpointing intervals are expressed.
        valid_units = ("sample", "batch", "epoch")
        self.counter_unit = logging_config["counter_unit"]
        if self.counter_unit not in valid_units:
            raise ValueError(f"Unrecognized unit: {self.counter_unit}")

        # Evaluation interval, measured in ``counter_unit``.
        self.evaluation_freq = logging_config["evaluation_freq"]
        logger.info(
            f"Evaluating every {self.evaluation_freq} {self.counter_unit}.")

        self.checkpointing = bool(logging_config["checkpointing"])
        if not self.checkpointing:
            logger.info("No checkpointing.")
        else:
            # checkpoint_freq is treated as a multiple of evaluation_freq: the
            # interval reported in counter units is the product of the two.
            checkpointer_config = logging_config["checkpointer_config"]
            self.checkpointing_freq = int(checkpointer_config["checkpoint_freq"])
            every = self.checkpointing_freq * self.evaluation_freq
            logger.info(f"Checkpointing every {every} {self.counter_unit}.")
            self.checkpointer = Checkpointer()

        # "*_count" fields measure progress since the last evaluation or
        # checkpointing; "*_total" fields measure progress since learning
        # began. Epoch- and unit-based counters may hold fractional values.
        self.sample_count: int = 0
        self.batch_count: int = 0
        self.epoch_count: Union[float, int] = 0
        self.unit_count: Union[float, int] = 0

        self.sample_total: int = 0
        self.batch_total: int = 0
        self.epoch_total: Union[float, int] = 0
        self.unit_total: Union[float, int] = 0

        # Count of evaluation triggers since the last checkpointing.
        self.trigger_count = 0

        # Optional log writer selected by writer_config["writer"].
        writer_choice = Meta.config["logging_config"]["writer_config"]["writer"]
        if writer_choice is None:
            self.writer = None
        elif writer_choice == "json":
            self.writer = LogWriter()
        elif writer_choice == "tensorboard":
            self.writer = TensorBoardWriter()
        else:
            raise ValueError(f"Unrecognized writer option '{writer_choice}'")
Example #2
0
def test_log_writer(caplog):
    """Check that LogWriter persists both the run config and scalar logs.

    The writer receives the current emmental config and two scalars. The test
    then reads back the YAML config file and the JSON log file written under
    ``emmental.Meta.log_path`` and verifies their contents.
    """
    caplog.set_level(logging.INFO)

    # Reset and re-initialize emmental's global Meta, then override logging.
    emmental.Meta.reset()
    emmental.init()
    logging_overrides = {
        "counter_unit": "sample",
        "evaluation_freq": 10,
        "checkpointing": True,
        "checkpointer_config": {
            "checkpoint_freq": 2
        },
    }
    emmental.Meta.update_config(config={"logging_config": logging_overrides})

    writer = LogWriter()
    writer.add_config(emmental.Meta.config)
    for step, value in ((1, 0.1), (2, 0.2)):
        writer.add_scalar(name=f"step {step}", value=value, step=step)

    # Round-trip the config through YAML and inspect a few fields.
    config_filename = "config.yaml"
    writer.write_config(config_filename)
    config_path = os.path.join(emmental.Meta.log_path, config_filename)
    with open(config_path, "r") as config_file:
        saved_config = yaml.load(config_file, Loader=yaml.FullLoader)

    assert saved_config["meta_config"]["verbose"] is True
    logging_section = saved_config["logging_config"]
    assert logging_section["counter_unit"] == "sample"
    assert logging_section["checkpointing"] is True

    # Flush the scalars to JSON; each name maps to a list of [step, value].
    log_filename = "log.json"
    writer.write_log(log_filename)
    log_path = os.path.join(emmental.Meta.log_path, log_filename)
    with open(log_path, "r") as log_file:
        saved_log = json.load(log_file)

    assert saved_log == {"step 1": [[1, 0.1]], "step 2": [[2, 0.2]]}

    writer.close()