Example #1
def test_output_handler_output_transform():

    wrapper = OutputHandler("tag", output_transform=lambda x: x)
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.output = 12345
    mock_engine.state.iteration = 123

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    mock_logger.log_metrics.assert_called_once_with({"tag output": 12345},
                                                    step=123)

    wrapper = OutputHandler("another_tag",
                            output_transform=lambda x: {"loss": x})
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    mock_logger.log_metrics.assert_called_once_with(
        {"another_tag loss": 12345},
        step=123,
    )
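
For orientation, here is a minimal, hedged sketch of the real-world pattern this test mocks: attaching an OutputHandler with an output_transform to a live engine. The dummy engine and constant loss are placeholders, and it assumes mlflow is installed with a tracking location configured.

# Minimal usage sketch of the pattern mocked above (placeholder engine,
# assumes a working mlflow installation).
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler
from ignite.engine import Engine, Events

# Placeholder update function; a real one would run a training step and
# return the loss.
trainer = Engine(lambda engine, batch: 0.0)

mlflow_logger = MLflowLogger()  # starts or reuses an MLflow run

# A scalar engine output is logged as "<tag> output"; a dict output is
# logged key by key as "<tag> <key>", matching the two assertions above.
mlflow_logger.attach(
    trainer,
    log_handler=OutputHandler("training", output_transform=lambda out: {"loss": out}),
    event_name=Events.ITERATION_COMPLETED,
)

trainer.run([0, 1, 2], max_epochs=1)  # each iteration logs "training loss"
mlflow_logger.close()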
Example #2
def test_mlflow_bad_metric_name_handling(dirname):
    import mlflow

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0,})

        with pytest.warns(UserWarning, match=r"MLflowLogger output_handler encountered an invalid metric name"):

            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

            for i, v in enumerate(true_values):
                engine.state.epoch += 1
                engine.state.metrics["metric 0"] = v
                handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

    for t, s in zip([1000.0,] + true_values, stored_values):
        assert t == s.value
Example #3
def test_output_handler_with_global_step_from_engine():

    mock_another_engine = MagicMock()
    mock_another_engine.state = State()
    mock_another_engine.state.epoch = 10
    mock_another_engine.state.output = 12.345

    wrapper = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(mock_another_engine),
    )

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 1
    mock_engine.state.output = 0.123

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
    )

    mock_another_engine.state.epoch = 11
    mock_engine.state.output = 1.123

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    assert mock_logger.log_metrics.call_count == 2
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
    )
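
The same idea without mocks: a hedged sketch of logging an evaluator's metrics against the trainer's epoch via global_step_from_engine, so training and validation curves share an x-axis in MLflow. The engines are placeholders, and the import paths follow the ignite.contrib.handlers layout used throughout these examples (newer ignite versions expose the same names under ignite.handlers).

# Sketch: stamp evaluator metrics with the trainer's epoch counter.
from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: 0.0)    # placeholder training step
evaluator = Engine(lambda engine, batch: 0.0)  # placeholder evaluation step

mlflow_logger = MLflowLogger()

# Log whatever metrics the evaluator computes, but use the trainer's epoch
# as the global step instead of the evaluator's own counter.
mlflow_logger.attach(
    evaluator,
    log_handler=OutputHandler(
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer),
    ),
    event_name=Events.COMPLETED,
)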
Example #4
def test_output_handler_metric_names():

    wrapper = OutputHandler("tag", metric_names=["a", "b", "c"])
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45, "c": torch.tensor(10.0)})
    mock_engine.state.iteration = 5

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_called_once_with(
        {"tag a": 12.23, "tag b": 23.45, "tag c": 10.0}, step=5,
    )

    wrapper = OutputHandler("tag", metric_names=["a",])

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])})
    mock_engine.state.iteration = 5

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag a 0": 0.0, "tag a 1": 1.0, "tag a 2": 2.0, "tag a 3": 3.0}, step=5),], any_order=True
    )

    wrapper = OutputHandler("tag", metric_names=["a", "c"])

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 55.56, "c": "Some text"})
    mock_engine.state.iteration = 7

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    with pytest.warns(UserWarning):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls([call({"tag a": 55.56}, step=7)], any_order=True)
Example #5
def test_output_handler_with_wrong_logger_type():

    wrapper = OutputHandler("tag", output_transform=lambda x: x)

    mock_logger = MagicMock()
    mock_engine = MagicMock()
    with pytest.raises(TypeError, match="Handler 'OutputHandler' works only with MLflowLogger"):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
Example #6
def test_output_handler_with_global_step_transform():
    def global_step_transform(*args, **kwargs):
        return 10

    wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform)
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    mock_logger.log_metrics.assert_called_once_with({"tag loss": 12345}, step=10)
Example #7
def test_output_handler_with_wrong_global_step_transform_output():
    def global_step_transform(*args, **kwargs):
        return "a"

    wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform)
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    with pytest.raises(TypeError, match="global_step must be int"):
        wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
Example #8
def test_output_handler_both():

    wrapper = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_called_once_with({"tag a": 12.23, "tag b": 23.45, "tag loss": 12345}, step=5)
Example #9
def test_output_handler_state_attrs():
    wrapper = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.iteration = 5
    mock_engine.state.alpha = 3.899
    mock_engine.state.beta = torch.tensor(12.21)
    mock_engine.state.gamma = torch.tensor([21.0, 6.0])

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    mock_logger.log_metrics.assert_called_once_with(
        {"tag alpha": 3.899, "tag beta": torch.tensor(12.21).item(), "tag gamma 0": 21.0, "tag gamma 1": 6.0,}, step=5,
    )
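
A hedged sketch of the state_attributes feature the test above exercises: any numeric attribute placed on engine.state can be logged alongside regular metrics. The attribute names are illustrative, and this assumes an ignite version recent enough to support state_attributes, as the test itself implies.

# Sketch: log custom engine.state attributes set during the update step.
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler
from ignite.engine import Engine, Events

def update(engine, batch):
    # Expose extra quantities on the engine state so the handler can pick
    # them up; "alpha" and "beta" are illustrative names.
    engine.state.alpha = 0.1 * engine.state.iteration
    engine.state.beta = 42.0
    return 0.0

trainer = Engine(update)
mlflow_logger = MLflowLogger()

# Each attribute is logged as "<tag> <attr>"; tensor-valued attributes are
# flattened element-wise, as the assertion above shows.
mlflow_logger.attach(
    trainer,
    log_handler=OutputHandler("training", state_attributes=["alpha", "beta"]),
    event_name=Events.ITERATION_COMPLETED,
)

trainer.run([0, 1, 2], max_epochs=1)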
Example #10
    def __call__(self, model, train_dataset, val_dataset=None, **_):
        """Train a PyTorch model.

        Args:
            model (torch.nn.Module): PyTorch model to train.
            train_dataset (torch.utils.data.Dataset): Dataset used to train.
            val_dataset (torch.utils.data.Dataset, optional): Dataset used to validate.

        Returns:
            trained_model (torch.nn.Module): Trained PyTorch model.
        """
        assert train_dataset is not None
        train_params = self.train_params
        mlflow_logging = self.mlflow_logging

        if mlflow_logging:
            try:
                import mlflow  # NOQA
            except ImportError:
                log.warning(
                    "Failed to import mlflow. MLflow logging is disabled.")
                mlflow_logging = False

        loss_fn = train_params.get("loss_fn")
        assert loss_fn
        epochs = train_params.get("epochs")
        seed = train_params.get("seed")
        optimizer = train_params.get("optimizer")
        assert optimizer
        optimizer_params = train_params.get("optimizer_params", dict())
        train_dataset_size_limit = train_params.get("train_dataset_size_limit")
        if train_dataset_size_limit:
            train_dataset = PartialDataset(train_dataset,
                                           train_dataset_size_limit)
            log.info("train dataset size is set to {}".format(
                len(train_dataset)))

        val_dataset_size_limit = train_params.get("val_dataset_size_limit")
        if val_dataset_size_limit and (val_dataset is not None):
            val_dataset = PartialDataset(val_dataset, val_dataset_size_limit)
            log.info("val dataset size is set to {}".format(len(val_dataset)))

        train_data_loader_params = train_params.get("train_data_loader_params",
                                                    dict())
        val_data_loader_params = train_params.get("val_data_loader_params",
                                                  dict())
        evaluation_metrics = train_params.get("evaluation_metrics")
        evaluate_train_data = train_params.get("evaluate_train_data")
        evaluate_val_data = train_params.get("evaluate_val_data")
        progress_update = train_params.get("progress_update")

        scheduler = train_params.get("scheduler")
        scheduler_params = train_params.get("scheduler_params", dict())

        model_checkpoint = train_params.get("model_checkpoint")
        model_checkpoint_params = train_params.get("model_checkpoint_params")
        early_stopping_params = train_params.get("early_stopping_params")
        time_limit = train_params.get("time_limit")

        cudnn_deterministic = train_params.get("cudnn_deterministic")
        cudnn_benchmark = train_params.get("cudnn_benchmark")

        if seed:
            torch.manual_seed(seed)
            np.random.seed(seed)
        if cudnn_deterministic:
            torch.backends.cudnn.deterministic = cudnn_deterministic
        if cudnn_benchmark:
            torch.backends.cudnn.benchmark = cudnn_benchmark

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        optimizer_ = optimizer(model.parameters(), **optimizer_params)
        trainer = create_supervised_trainer(model,
                                            optimizer_,
                                            loss_fn=loss_fn,
                                            device=device)

        train_data_loader_params.setdefault("shuffle", True)
        train_data_loader_params.setdefault("drop_last", True)
        train_data_loader_params["batch_size"] = _clip_batch_size(
            train_data_loader_params.get("batch_size", 1), train_dataset,
            "train")
        train_loader = DataLoader(train_dataset, **train_data_loader_params)

        RunningAverage(output_transform=lambda x: x,
                       alpha=0.98).attach(trainer, "ema_loss")

        RunningAverage(output_transform=lambda x: x,
                       alpha=2**(-1022)).attach(trainer, "batch_loss")

        if scheduler:

            class ParamSchedulerSavingAsMetric(
                    ParamSchedulerSavingAsMetricMixIn, scheduler):
                pass

            cycle_epochs = scheduler_params.pop("cycle_epochs", 1)
            scheduler_params.setdefault("cycle_size",
                                        int(cycle_epochs * len(train_loader)))
            scheduler_params.setdefault("param_name", "lr")
            scheduler_ = ParamSchedulerSavingAsMetric(optimizer_,
                                                      **scheduler_params)
            trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_)

        if evaluate_train_data:
            evaluator_train = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device)

        if evaluate_val_data:
            val_data_loader_params["batch_size"] = _clip_batch_size(
                val_data_loader_params.get("batch_size", 1), val_dataset,
                "val")
            val_loader = DataLoader(val_dataset, **val_data_loader_params)
            evaluator_val = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device)

        if model_checkpoint_params:
            assert isinstance(model_checkpoint_params, dict)
            minimize = model_checkpoint_params.pop("minimize", True)
            save_interval = model_checkpoint_params.get("save_interval", None)
            if not save_interval:
                model_checkpoint_params.setdefault(
                    "score_function",
                    get_score_function("ema_loss", minimize=minimize))
            model_checkpoint_params.setdefault("score_name", "ema_loss")
            mc = model_checkpoint(**model_checkpoint_params)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, mc,
                                      {"model": model})

        if early_stopping_params:
            assert isinstance(early_stopping_params, dict)
            metric = early_stopping_params.pop("metric", None)
            assert (metric is None) or (metric in evaluation_metrics)
            minimize = early_stopping_params.pop("minimize", False)
            if metric:
                assert (
                    "score_function" not in early_stopping_params
                ), "Remove either 'metric' or 'score_function' from early_stopping_params: {}".format(
                    early_stopping_params)
                early_stopping_params["score_function"] = get_score_function(
                    metric, minimize=minimize)

            es = EarlyStopping(trainer=trainer, **early_stopping_params)
            if evaluate_val_data:
                evaluator_val.add_event_handler(Events.COMPLETED, es)
            elif evaluate_train_data:
                evaluator_train.add_event_handler(Events.COMPLETED, es)
            elif early_stopping_params:
                log.warning(
                    "Early Stopping is disabled because neither "
                    "evaluate_val_data nor evaluate_train_data is set True.")

        if time_limit:
            assert isinstance(time_limit, (int, float))
            tl = TimeLimit(limit_sec=time_limit)
            trainer.add_event_handler(Events.ITERATION_COMPLETED, tl)

        pbar = None
        if progress_update:
            if not isinstance(progress_update, dict):
                progress_update = dict()
            progress_update.setdefault("persist", True)
            progress_update.setdefault("desc", "")
            pbar = ProgressBar(**progress_update)
            pbar.attach(trainer, ["ema_loss"])

        else:

            def log_train_metrics(engine):
                log.info("[Epoch: {} | {}]".format(engine.state.epoch,
                                                   engine.state.metrics))

            trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                      log_train_metrics)

        if evaluate_train_data:

            def log_evaluation_train_data(engine):
                evaluator_train.run(train_loader)
                train_report = _get_report_str(engine, evaluator_train,
                                               "Train Data")
                if pbar:
                    pbar.log_message(train_report)
                else:
                    log.info(train_report)

            eval_train_event = (Events[evaluate_train_data] if isinstance(
                evaluate_train_data, str) else Events.EPOCH_COMPLETED)
            trainer.add_event_handler(eval_train_event,
                                      log_evaluation_train_data)

        if evaluate_val_data:

            def log_evaluation_val_data(engine):
                evaluator_val.run(val_loader)
                val_report = _get_report_str(engine, evaluator_val, "Val Data")
                if pbar:
                    pbar.log_message(val_report)
                else:
                    log.info(val_report)

            eval_val_event = (Events[evaluate_val_data] if isinstance(
                evaluate_val_data, str) else Events.EPOCH_COMPLETED)
            trainer.add_event_handler(eval_val_event, log_evaluation_val_data)

        if mlflow_logging:
            mlflow_logger = MLflowLogger()

            logging_params = {
                "train_n_samples": len(train_dataset),
                "train_n_batches": len(train_loader),
                "optimizer": _name(optimizer),
                "loss_fn": _name(loss_fn),
                "pytorch_version": torch.__version__,
                "ignite_version": ignite.__version__,
            }
            logging_params.update(_loggable_dict(optimizer_params,
                                                 "optimizer"))
            logging_params.update(
                _loggable_dict(train_data_loader_params, "train"))
            if scheduler:
                logging_params.update({"scheduler": _name(scheduler)})
                logging_params.update(
                    _loggable_dict(scheduler_params, "scheduler"))

            if evaluate_val_data:
                logging_params.update({
                    "val_n_samples": len(val_dataset),
                    "val_n_batches": len(val_loader),
                })
                logging_params.update(
                    _loggable_dict(val_data_loader_params, "val"))

            mlflow_logger.log_params(logging_params)

            batch_metric_names = ["batch_loss", "ema_loss"]
            if scheduler:
                batch_metric_names.append(scheduler_params.get("param_name"))

            mlflow_logger.attach(
                trainer,
                log_handler=OutputHandler(
                    tag="step",
                    metric_names=batch_metric_names,
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.ITERATION_COMPLETED,
            )

            if evaluate_train_data:
                mlflow_logger.attach(
                    evaluator_train,
                    log_handler=OutputHandler(
                        tag="train",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )
            if evaluate_val_data:
                mlflow_logger.attach(
                    evaluator_val,
                    log_handler=OutputHandler(
                        tag="val",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )

        trainer.run(train_loader, max_epochs=epochs)

        try:
            if pbar and pbar.pbar:
                pbar.pbar.close()
        except Exception as e:
            log.error(e, exc_info=True)

        model = load_latest_model(model_checkpoint_params)(model)

        return model
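
Distilled from the trainer above, the MLflow wiring alone looks roughly like the sketch below. The trainer is a placeholder engine, the logged parameter values are illustrative, and the evaluators above are attached the same way with event_name=Events.COMPLETED; this is not a drop-in replacement for the full method.

# Sketch of the MLflow wiring performed by the trainer above.
from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers.mlflow_logger import MLflowLogger, OutputHandler
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

# Placeholder training step standing in for the supervised trainer built above.
trainer = Engine(lambda engine, batch: 0.0)
RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(trainer, "ema_loss")
RunningAverage(output_transform=lambda x: x, alpha=2 ** (-1022)).attach(trainer, "batch_loss")

mlflow_logger = MLflowLogger()
mlflow_logger.log_params({"optimizer": "SGD", "loss_fn": "CrossEntropyLoss"})  # illustrative values

# Per-iteration metrics, stamped with the trainer's own step counter.
mlflow_logger.attach(
    trainer,
    log_handler=OutputHandler(
        tag="step",
        metric_names=["batch_loss", "ema_loss"],
        global_step_transform=global_step_from_engine(trainer),
    ),
    event_name=Events.ITERATION_COMPLETED,
)

trainer.run([0, 1, 2], max_epochs=1)
mlflow_logger.close()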
Example #11
def inference(config, local_rank, with_pbar_on_iters=True):

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    # Load model and weights
    model_weights_filepath = Path(
        get_artifact_path(config.run_uuid, config.weights_filename))
    assert model_weights_filepath.exists(), \
        "Model weights file '{}' is not found".format(model_weights_filepath.as_posix())

    model = config.model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

    if hasattr(config, "custom_weights_loading"):
        config.custom_weights_loading(model, model_weights_filepath)
    else:
        state_dict = torch.load(model_weights_filepath)
        if not all([k.startswith("module.") for k in state_dict]):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    model.eval()

    prepare_batch = config.prepare_batch
    non_blocking = getattr(config, "non_blocking", True)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    tta_transforms = getattr(config, "tta_transforms", None)

    def eval_update_function(engine, batch):
        with torch.no_grad():
            x, y, meta = prepare_batch(batch,
                                       device=device,
                                       non_blocking=non_blocking)

            if tta_transforms is not None:
                y_preds = []
                for t in tta_transforms:
                    t_x = t.augment_image(x)
                    t_y_pred = model(t_x)
                    t_y_pred = model_output_transform(t_y_pred)
                    y_pred = t.deaugment_mask(t_y_pred)
                    y_preds.append(y_pred)

                y_preds = torch.stack(y_preds, dim=0)
                y_pred = torch.mean(y_preds, dim=0)
            else:
                y_pred = model(x)
                y_pred = model_output_transform(y_pred)
            return {"y_pred": y_pred, "y": y, "meta": meta}

    evaluator = Engine(eval_update_function)

    has_targets = getattr(config, "has_targets", False)

    if has_targets:

        def output_transform(output):
            return output['y_pred'], output['y']

        num_classes = config.num_classes
        cm_metric = ConfusionMatrix(num_classes=num_classes,
                                    output_transform=output_transform)
        pr = cmPrecision(cm_metric, average=False)
        re = cmRecall(cm_metric, average=False)

        val_metrics = {
            "IoU": IoU(cm_metric),
            "mIoU_bg": mIoU(cm_metric),
            "Accuracy": cmAccuracy(cm_metric),
            "Precision": pr,
            "Recall": re,
            "F1": Fbeta(beta=1.0, output_transform=output_transform)
        }

        if hasattr(config, "metrics") and isinstance(config.metrics, dict):
            val_metrics.update(config.metrics)

        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)

        if dist.get_rank() == 0:
            # Log val metrics:
            mlflow_logger = MLflowLogger()
            mlflow_logger.attach(evaluator,
                                 log_handler=OutputHandler(
                                     tag="validation",
                                     metric_names=list(val_metrics.keys())),
                                 event_name=Events.EPOCH_COMPLETED)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=True, desc="Inference").attach(evaluator)

    if dist.get_rank() == 0:
        do_save_raw_predictions = getattr(config, "do_save_raw_predictions",
                                          True)
        do_save_overlayed_predictions = getattr(
            config, "do_save_overlayed_predictions", True)

        if not has_targets:
            assert do_save_raw_predictions or do_save_overlayed_predictions, \
                "If no targets, either do_save_overlayed_predictions or do_save_raw_predictions should be " \
                "defined in the config and has value equal True"

        # Save predictions
        if do_save_raw_predictions:
            raw_preds_path = config.output_path / "raw"
            raw_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                        save_raw_predictions_with_geoinfo,
                                        raw_preds_path)

        if do_save_overlayed_predictions:
            overlayed_preds_path = config.output_path / "overlay"
            overlayed_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(
                Events.ITERATION_COMPLETED,
                save_overlayed_predictions,
                overlayed_preds_path,
                img_denormalize_fn=config.img_denormalize,
                palette=default_palette)

    evaluator.add_event_handler(Events.EXCEPTION_RAISED, report_exception)

    # Run evaluation
    evaluator.run(config.data_loader)
Example #12
def build_engine(spec: RunSpec,
                 output_dir=None,
                 trainer=None,
                 metric_cls=RunningAverage,
                 tag="",
                 mlflow_logger=None,
                 is_training=None,
                 device=None):
    if spec.plot_event is not None or spec.log_event is not None:
        assert output_dir is not None
        plot_fname = os.path.join(output_dir, "{}-{}".format(tag, PLOT_FNAME))
        logs_fname = os.path.join(output_dir, "{}-{}".format(tag, LOGS_FNAME))
    else:
        plot_fname = None
        logs_fname = None
    # Create engine
    if device:
        to_device = tensors_to_device(device=device)

        def step(engine, batch):
            batch = to_device(batch)
            return spec.step(engine, batch)
    else:
        step = spec.step
    engine = Engine(step)
    if trainer is None:
        # training
        trainer = engine
        if is_training is None:
            is_training = True
    else:
        # evaluation
        if is_training is None:
            is_training = False
    spec.set_defaults(is_training=is_training)

    # Attach metrics
    for name, metric in spec.metrics.items():
        metric = auto_metric(metric, cls=metric_cls)
        metric.attach(engine, name)

    if spec.enable_timer:
        timer_metric(engine=engine)

    # Progress bar
    if spec.enable_pbar:
        ProgressBar(file=sys.stdout).attach(engine,
                                            metric_names=spec.pbar_metrics)

    # Print logs
    if spec.print_event is not None:
        engine.add_event_handler(event_name=spec.print_event,
                                 handler=print_logs,
                                 trainer=trainer,
                                 fmt=spec.print_fmt,
                                 metric_fmt=spec.print_metric_fmt,
                                 metric_names=spec.print_metrics)

    # Save logs
    if spec.log_event is not None:
        engine.add_event_handler(event_name=spec.log_event,
                                 handler=save_logs,
                                 fname=logs_fname,
                                 trainer=trainer,
                                 metric_names=spec.log_metrics)

    # Plot metrics
    if spec.plot_event is not None:
        # Plots require logs
        assert spec.log_event is not None
        engine.add_event_handler(event_name=spec.plot_event,
                                 handler=create_plots,
                                 logs_fname=logs_fname,
                                 plots_fname=plot_fname,
                                 metric_names=spec.plot_metrics)

    # Optional user callback for additional configuration
    chain_callbacks(callbacks=spec.callback, engine=engine, trainer=trainer)

    if mlflow_logger is not None and spec.log_event is not None:
        mlflow_logger.attach(
            engine,
            log_handler=OutputHandler(
                tag=tag,
                metric_names=spec.log_metrics,
                global_step_transform=global_step_from_engine(trainer)),
            event_name=spec.log_event)

    return engine
Example #13
    ### TRAINING
    # Attach metrics
    metrics = ["loss_d", "loss_g"]
    RunningAverage(alpha=0.98, output_transform=lambda x: x["loss_d"]).attach(
        engine, "loss_d")
    RunningAverage(alpha=0.98, output_transform=lambda x: x["loss_g"]).attach(
        engine, "loss_g")

    if args.local_rank <= 0:
        pbar = ProgressBar()
        pbar.attach(engine, metric_names=metrics)

        mlflow_logger.attach(engine,
                             log_handler=OutputHandler(
                                 tag="generator/loss",
                                 metric_names=["loss_g"]),
                             event_name=Events.ITERATION_COMPLETED(every=10))
        mlflow_logger.attach(engine,
                             log_handler=OutputHandler(
                                 tag="discriminator/loss",
                                 metric_names=["loss_d"]),
                             event_name=Events.ITERATION_COMPLETED(every=10))

        @engine.on(Events.EPOCH_COMPLETED)
        def log_times(engine):
            pbar.log_message(
                "Epoch {} finished: Batch average time is {:.3f}".format(
                    engine.state.epoch, timer.value()))
            timer.reset()