Example #1
from unittest.mock import MagicMock, call

from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler
from ignite.engine import Events, State


def test_output_handler_with_global_step_from_engine():

    mock_another_engine = MagicMock()
    mock_another_engine.state = State()
    mock_another_engine.state.epoch = 10
    mock_another_engine.state.output = 12.345

    wrapper = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(mock_another_engine),
    )

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 1
    mock_engine.state.output = 0.123

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
    )

    mock_another_engine.state.epoch = 11
    mock_engine.state.output = 1.123

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    assert mock_logger.log_metrics.call_count == 2
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
    )
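
The mocks above stand in for real engines; wired against actual ignite objects, the same pattern looks roughly like this (the dummy step functions and the "val" tag are illustrative assumptions, not part of the test):

from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)    # dummy train step
evaluator = Engine(lambda engine, batch: batch)  # dummy eval step

mlflow_logger = MLflowLogger()
# Report the evaluator's output against the trainer's epoch counter,
# which is exactly what the mocked assertions above verify.
mlflow_logger.attach(
    evaluator,
    log_handler=OutputHandler(
        tag="val",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(trainer),
    ),
    event_name=Events.COMPLETED,
)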

Example #2

    def __call__(self, model, train_dataset, val_dataset=None, **_):
        """Train a PyTorch model.

        Args:
            model (torch.nn.Module): PyTorch model to train.
            train_dataset (torch.utils.data.Dataset): Dataset used to train.
            val_dataset (torch.utils.data.Dataset, optional): Dataset used to validate.

        Returns:
            trained_model (torch.nn.Module): Trained PyTorch model.
        """
        assert train_dataset is not None
        train_params = self.train_params
        mlflow_logging = self.mlflow_logging

        if mlflow_logging:
            try:
                import mlflow  # NOQA
            except ImportError:
                log.warning(
                    "Failed to import mlflow. MLflow logging is disabled.")
                mlflow_logging = False

        loss_fn = train_params.get("loss_fn")
        assert loss_fn
        epochs = train_params.get("epochs")
        seed = train_params.get("seed")
        optimizer = train_params.get("optimizer")
        assert optimizer
        optimizer_params = train_params.get("optimizer_params", dict())
        train_dataset_size_limit = train_params.get("train_dataset_size_limit")
        if train_dataset_size_limit:
            train_dataset = PartialDataset(train_dataset,
                                           train_dataset_size_limit)
            log.info("train dataset size is set to {}".format(
                len(train_dataset)))

        val_dataset_size_limit = train_params.get("val_dataset_size_limit")
        if val_dataset_size_limit and (val_dataset is not None):
            val_dataset = PartialDataset(val_dataset, val_dataset_size_limit)
            log.info("val dataset size is set to {}".format(len(val_dataset)))

        train_data_loader_params = train_params.get("train_data_loader_params",
                                                    dict())
        val_data_loader_params = train_params.get("val_data_loader_params",
                                                  dict())
        evaluation_metrics = train_params.get("evaluation_metrics")
        evaluate_train_data = train_params.get("evaluate_train_data")
        evaluate_val_data = train_params.get("evaluate_val_data")
        progress_update = train_params.get("progress_update")

        scheduler = train_params.get("scheduler")
        scheduler_params = train_params.get("scheduler_params", dict())

        model_checkpoint = train_params.get("model_checkpoint")
        model_checkpoint_params = train_params.get("model_checkpoint_params")
        early_stopping_params = train_params.get("early_stopping_params")
        time_limit = train_params.get("time_limit")

        cudnn_deterministic = train_params.get("cudnn_deterministic")
        cudnn_benchmark = train_params.get("cudnn_benchmark")

        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
        if cudnn_deterministic:
            torch.backends.cudnn.deterministic = cudnn_deterministic
        if cudnn_benchmark:
            torch.backends.cudnn.benchmark = cudnn_benchmark

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        optimizer_ = optimizer(model.parameters(), **optimizer_params)
        trainer = create_supervised_trainer(model,
                                            optimizer_,
                                            loss_fn=loss_fn,
                                            device=device)

        train_data_loader_params.setdefault("shuffle", True)
        train_data_loader_params.setdefault("drop_last", True)
        train_data_loader_params["batch_size"] = _clip_batch_size(
            train_data_loader_params.get("batch_size", 1), train_dataset,
            "train")
        train_loader = DataLoader(train_dataset, **train_data_loader_params)

        # Exponential moving average of the loss, used for progress display
        # and as the default model-checkpoint score below.
        RunningAverage(output_transform=lambda x: x,
                       alpha=0.98).attach(trainer, "ema_loss")

        # An alpha this close to zero makes the running average track the
        # latest value, so "batch_loss" is effectively the raw per-batch loss.
        RunningAverage(output_transform=lambda x: x,
                       alpha=2**(-1022)).attach(trainer, "batch_loss")

        if scheduler:

            class ParamSchedulerSavingAsMetric(
                    ParamSchedulerSavingAsMetricMixIn, scheduler):
                pass

            cycle_epochs = scheduler_params.pop("cycle_epochs", 1)
            scheduler_params.setdefault("cycle_size",
                                        int(cycle_epochs * len(train_loader)))
            scheduler_params.setdefault("param_name", "lr")
            scheduler_ = ParamSchedulerSavingAsMetric(optimizer_,
                                                      **scheduler_params)
            trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_)

        if evaluate_train_data:
            evaluator_train = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device)

        if evaluate_val_data:
            assert val_dataset is not None
            val_data_loader_params["batch_size"] = _clip_batch_size(
                val_data_loader_params.get("batch_size", 1), val_dataset,
                "val")
            val_loader = DataLoader(val_dataset, **val_data_loader_params)
            evaluator_val = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device)

        if model_checkpoint_params:
            assert isinstance(model_checkpoint_params, dict)
            minimize = model_checkpoint_params.pop("minimize", True)
            save_interval = model_checkpoint_params.get("save_interval", None)
            if not save_interval:
                model_checkpoint_params.setdefault(
                    "score_function",
                    get_score_function("ema_loss", minimize=minimize))
            model_checkpoint_params.setdefault("score_name", "ema_loss")
            mc = model_checkpoint(**model_checkpoint_params)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, mc,
                                      {"model": model})

        if early_stopping_params:
            assert isinstance(early_stopping_params, dict)
            metric = early_stopping_params.pop("metric", None)
            assert (metric is None) or (metric in evaluation_metrics)
            minimize = early_stopping_params.pop("minimize", False)
            if metric:
                assert (
                    "score_function" not in early_stopping_params
                ), "Remove either 'metric' or 'score_function' from early_stopping_params: {}".format(
                    early_stopping_params)
                early_stopping_params["score_function"] = get_score_function(
                    metric, minimize=minimize)

            es = EarlyStopping(trainer=trainer, **early_stopping_params)
            if evaluate_val_data:
                evaluator_val.add_event_handler(Events.COMPLETED, es)
            elif evaluate_train_data:
                evaluator_train.add_event_handler(Events.COMPLETED, es)
            else:
                log.warning(
                    "Early Stopping is disabled because neither "
                    "evaluate_val_data nor evaluate_train_data is set True.")

        if time_limit:
            assert isinstance(time_limit, (int, float))
            tl = TimeLimit(limit_sec=time_limit)
            trainer.add_event_handler(Events.ITERATION_COMPLETED, tl)

        pbar = None
        if progress_update:
            if not isinstance(progress_update, dict):
                progress_update = dict()
            progress_update.setdefault("persist", True)
            progress_update.setdefault("desc", "")
            pbar = ProgressBar(**progress_update)
            pbar.attach(trainer, ["ema_loss"])

        else:

            def log_train_metrics(engine):
                log.info("[Epoch: {} | {}]".format(engine.state.epoch,
                                                   engine.state.metrics))

            trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                      log_train_metrics)

        if evaluate_train_data:

            def log_evaluation_train_data(engine):
                evaluator_train.run(train_loader)
                train_report = _get_report_str(engine, evaluator_train,
                                               "Train Data")
                if pbar:
                    pbar.log_message(train_report)
                else:
                    log.info(train_report)

            eval_train_event = (Events[evaluate_train_data] if isinstance(
                evaluate_train_data, str) else Events.EPOCH_COMPLETED)
            trainer.add_event_handler(eval_train_event,
                                      log_evaluation_train_data)

        if evaluate_val_data:

            def log_evaluation_val_data(engine):
                evaluator_val.run(val_loader)
                val_report = _get_report_str(engine, evaluator_val, "Val Data")
                if pbar:
                    pbar.log_message(val_report)
                else:
                    log.info(val_report)

            eval_val_event = (Events[evaluate_val_data] if isinstance(
                evaluate_val_data, str) else Events.EPOCH_COMPLETED)
            trainer.add_event_handler(eval_val_event, log_evaluation_val_data)

        if mlflow_logging:
            mlflow_logger = MLflowLogger()

            logging_params = {
                "train_n_samples": len(train_dataset),
                "train_n_batches": len(train_loader),
                "optimizer": _name(optimizer),
                "loss_fn": _name(loss_fn),
                "pytorch_version": torch.__version__,
                "ignite_version": ignite.__version__,
            }
            logging_params.update(_loggable_dict(optimizer_params,
                                                 "optimizer"))
            logging_params.update(
                _loggable_dict(train_data_loader_params, "train"))
            if scheduler:
                logging_params.update({"scheduler": _name(scheduler)})
                logging_params.update(
                    _loggable_dict(scheduler_params, "scheduler"))

            if evaluate_val_data:
                logging_params.update({
                    "val_n_samples": len(val_dataset),
                    "val_n_batches": len(val_loader),
                })
                logging_params.update(
                    _loggable_dict(val_data_loader_params, "val"))

            mlflow_logger.log_params(logging_params)

            batch_metric_names = ["batch_loss", "ema_loss"]
            if scheduler:
                batch_metric_names.append(scheduler_params.get("param_name"))

            mlflow_logger.attach(
                trainer,
                log_handler=OutputHandler(
                    tag="step",
                    metric_names=batch_metric_names,
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.ITERATION_COMPLETED,
            )

            if evaluate_train_data:
                mlflow_logger.attach(
                    evaluator_train,
                    log_handler=OutputHandler(
                        tag="train",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )
            if evaluate_val_data:
                mlflow_logger.attach(
                    evaluator_val,
                    log_handler=OutputHandler(
                        tag="val",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )

        trainer.run(train_loader, max_epochs=epochs)

        try:
            if pbar and pbar.pbar:
                pbar.pbar.close()
        except Exception as e:
            log.error(e, exc_info=True)

        model = load_latest_model(model_checkpoint_params)(model)

        return model
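
A minimal driving sketch for the method above; the enclosing class name (NetworkTrain here) and its constructor are assumptions inferred from the self.train_params and self.mlflow_logging attributes the body reads:

import torch
from torch.utils.data import TensorDataset

# Hypothetical construction: only keys that the method above actually
# reads from train_params are shown.
train = NetworkTrain(
    train_params={
        "loss_fn": torch.nn.CrossEntropyLoss(),
        "optimizer": torch.optim.Adam,
        "optimizer_params": {"lr": 1e-3},
        "epochs": 3,
        "seed": 42,
        "train_data_loader_params": {"batch_size": 16},
    },
    mlflow_logging=False,
)

model = torch.nn.Linear(4, 2)
train_dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
trained_model = train(model, train_dataset)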
Example #3
def build_engine(spec: RunSpec,
                 output_dir=None,
                 trainer=None,
                 metric_cls=RunningAverage,
                 tag="",
                 mlflow_logger=None,
                 is_training=None,
                 device=None):
    if spec.plot_event is not None or spec.log_event is not None:
        assert output_dir is not None
        plot_fname = os.path.join(output_dir, "{}-{}".format(tag, PLOT_FNAME))
        logs_fname = os.path.join(output_dir, "{}-{}".format(tag, LOGS_FNAME))
    else:
        plot_fname = None
        logs_fname = None
    # Create engine
    if device:
        to_device = tensors_to_device(device=device)

        def step(engine, batch):
            batch = to_device(batch)
            return spec.step(engine, batch)
    else:
        step = spec.step
    engine = Engine(step)
    if trainer is None:
        # training
        trainer = engine
        if is_training is None:
            is_training = True
    else:
        # evaluation
        if is_training is None:
            is_training = False
    spec.set_defaults(is_training=is_training)

    # Attach metrics
    for name, metric in spec.metrics.items():
        metric = auto_metric(metric, cls=metric_cls)
        metric.attach(engine, name)

    if spec.enable_timer:
        timer_metric(engine=engine)

    # Progress bar
    if spec.enable_pbar:
        ProgressBar(file=sys.stdout).attach(engine,
                                            metric_names=spec.pbar_metrics)

    # Print logs
    if spec.print_event is not None:
        engine.add_event_handler(event_name=spec.print_event,
                                 handler=print_logs,
                                 trainer=trainer,
                                 fmt=spec.print_fmt,
                                 metric_fmt=spec.print_metric_fmt,
                                 metric_names=spec.print_metrics)

    # Save logs
    if spec.log_event is not None:
        engine.add_event_handler(event_name=spec.log_event,
                                 handler=save_logs,
                                 fname=logs_fname,
                                 trainer=trainer,
                                 metric_names=spec.log_metrics)

    # Plot metrics
    if spec.plot_event is not None:
        # Plots require logs
        assert spec.log_event is not None
        engine.add_event_handler(event_name=spec.plot_event,
                                 handler=create_plots,
                                 logs_fname=logs_fname,
                                 plots_fname=plot_fname,
                                 metric_names=spec.plot_metrics)

    # Optional user callback for additional configuration
    chain_callbacks(callbacks=spec.callback, engine=engine, trainer=trainer)

    if mlflow_logger is not None and spec.log_event is not None:
        mlflow_logger.attach(
            engine,
            log_handler=OutputHandler(
                tag=tag,
                metric_names=spec.log_metrics,
                global_step_transform=global_step_from_engine(trainer)),
            event_name=spec.log_event)

    return engine
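
A rough usage sketch; RunSpec's constructor is not shown above, so the field names below simply mirror the attributes build_engine reads, and RunSpec is assumed to default the remaining fields (pbar, timer, plotting, callbacks) to disabled:

from ignite.engine import Events

spec = RunSpec(
    step=lambda engine, batch: {"loss": float(batch.sum())},
    metrics={"loss": lambda out: out["loss"]},  # passed through auto_metric
    log_event=Events.EPOCH_COMPLETED,
    print_event=Events.EPOCH_COMPLETED,
    plot_event=None,
)
# output_dir is required because log_event is set (see the assert above).
train_engine = build_engine(spec, output_dir="runs", tag="train")
train_engine.run(train_loader, max_epochs=10)  # train_loader: any iterable of batches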