def test_output_handler_output_transform():
    wrapper = OutputHandler("tag", output_transform=lambda x: x)
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.output = 12345
    mock_engine.state.iteration = 123

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    mock_logger.log_metrics.assert_called_once_with({"tag output": 12345}, step=123)

    wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
    mock_logger.log_metrics.assert_called_once_with({"another_tag loss": 12345}, step=123)

def test_mlflow_bad_metric_name_handling(dirname):
    import mlflow

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={"metric:0 in %": 123.0, "metric 0": 1000.0})

        with pytest.warns(UserWarning, match=r"MLflowLogger output_handler encountered an invalid metric name"):
            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

        for v in true_values:
            engine.state.epoch += 1
            engine.state.metrics["metric 0"] = v
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

        from mlflow.tracking import MlflowClient

        client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
        stored_values = client.get_metric_history(active_run.info.run_id, "training metric 0")

        for t, s in zip([1000.0] + true_values, stored_values):
            assert t == s.value

def test_output_handler_with_global_step_from_engine():
    mock_another_engine = MagicMock()
    mock_another_engine.state = State()
    mock_another_engine.state.epoch = 10
    mock_another_engine.state.output = 12.345

    wrapper = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(mock_another_engine),
    )

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 1
    mock_engine.state.output = 0.123

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
    )

    mock_another_engine.state.epoch = 11
    mock_engine.state.output = 1.123

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    assert mock_logger.log_metrics.call_count == 2
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag loss": mock_engine.state.output}, step=mock_another_engine.state.epoch)]
    )

def test_output_handler_metric_names():
    wrapper = OutputHandler("tag", metric_names=["a", "b", "c"])
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45, "c": torch.tensor(10.0)})
    mock_engine.state.iteration = 5

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_called_once_with({"tag a": 12.23, "tag b": 23.45, "tag c": 10.0}, step=5)

    # A 1-D tensor metric is unrolled into one logged value per element.
    wrapper = OutputHandler("tag", metric_names=["a"])

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": torch.Tensor([0.0, 1.0, 2.0, 3.0])})
    mock_engine.state.iteration = 5

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls(
        [call({"tag a 0": 0.0, "tag a 1": 1.0, "tag a 2": 2.0, "tag a 3": 3.0}, step=5)], any_order=True
    )

    # Non-numeric metric values trigger a warning and are skipped.
    wrapper = OutputHandler("tag", metric_names=["a", "c"])

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 55.56, "c": "Some text"})
    mock_engine.state.iteration = 7

    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    with pytest.warns(UserWarning):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_has_calls([call({"tag a": 55.56}, step=7)], any_order=True)

def test_output_handler_with_wrong_logger_type():
    wrapper = OutputHandler("tag", output_transform=lambda x: x)

    mock_logger = MagicMock()
    mock_engine = MagicMock()

    with pytest.raises(TypeError, match="Handler 'OutputHandler' works only with MLflowLogger"):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

def test_output_handler_with_global_step_transform():
    def global_step_transform(*args, **kwargs):
        return 10

    wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform)
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    mock_logger.log_metrics.assert_called_once_with({"tag loss": 12345}, step=10)

def test_output_handler_with_wrong_global_step_transform_output():
    def global_step_transform(*args, **kwargs):
        return "a"

    wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform)
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    with pytest.raises(TypeError, match="global_step must be int"):
        wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

def test_output_handler_both():
    wrapper = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
    mock_engine.state.epoch = 5
    mock_engine.state.output = 12345

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    assert mock_logger.log_metrics.call_count == 1
    mock_logger.log_metrics.assert_called_once_with({"tag a": 12.23, "tag b": 23.45, "tag loss": 12345}, step=5)

def test_output_handler_state_attrs():
    wrapper = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    mock_logger = MagicMock(spec=MLflowLogger)
    mock_logger.log_metrics = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.iteration = 5
    mock_engine.state.alpha = 3.899
    mock_engine.state.beta = torch.tensor(12.21)
    mock_engine.state.gamma = torch.tensor([21.0, 6.0])

    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

    mock_logger.log_metrics.assert_called_once_with(
        {"tag alpha": 3.899, "tag beta": torch.tensor(12.21).item(), "tag gamma 0": 21.0, "tag gamma 1": 6.0},
        step=5,
    )

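# The tests above mock the logger; for reference, here is a minimal usage sketch with a
# real MLflowLogger. The `trainer` engine and the "accuracy" metric are assumptions for
# illustration, not fixtures from this test module.
def _example_attach_output_handler(trainer):
    mlflow_logger = MLflowLogger()
    mlflow_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag="training",
            metric_names=["accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        ),
        event_name=Events.EPOCH_COMPLETED,
    )
    return mlflow_logger
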
def __call__(self, model, train_dataset, val_dataset=None, **_):
    """Train a PyTorch model.

    Args:
        model (torch.nn.Module): PyTorch model to train.
        train_dataset (torch.utils.data.Dataset): Dataset used to train.
        val_dataset (torch.utils.data.Dataset, optional): Dataset used to validate.

    Returns:
        trained_model (torch.nn.Module): Trained PyTorch model.
    """
    assert train_dataset is not None
    train_params = self.train_params
    mlflow_logging = self.mlflow_logging

    if mlflow_logging:
        try:
            import mlflow  # NOQA
        except ImportError:
            log.warning("Failed to import mlflow. MLflow logging is disabled.")
            mlflow_logging = False

    loss_fn = train_params.get("loss_fn")
    assert loss_fn
    epochs = train_params.get("epochs")
    seed = train_params.get("seed")
    optimizer = train_params.get("optimizer")
    assert optimizer
    optimizer_params = train_params.get("optimizer_params", dict())

    train_dataset_size_limit = train_params.get("train_dataset_size_limit")
    if train_dataset_size_limit:
        train_dataset = PartialDataset(train_dataset, train_dataset_size_limit)
        log.info("train dataset size is set to {}".format(len(train_dataset)))

    val_dataset_size_limit = train_params.get("val_dataset_size_limit")
    if val_dataset_size_limit and (val_dataset is not None):
        val_dataset = PartialDataset(val_dataset, val_dataset_size_limit)
        log.info("val dataset size is set to {}".format(len(val_dataset)))

    train_data_loader_params = train_params.get("train_data_loader_params", dict())
    val_data_loader_params = train_params.get("val_data_loader_params", dict())
    evaluation_metrics = train_params.get("evaluation_metrics")
    evaluate_train_data = train_params.get("evaluate_train_data")
    evaluate_val_data = train_params.get("evaluate_val_data")
    progress_update = train_params.get("progress_update")
    scheduler = train_params.get("scheduler")
    scheduler_params = train_params.get("scheduler_params", dict())
    model_checkpoint = train_params.get("model_checkpoint")
    model_checkpoint_params = train_params.get("model_checkpoint_params")
    early_stopping_params = train_params.get("early_stopping_params")
    time_limit = train_params.get("time_limit")
    cudnn_deterministic = train_params.get("cudnn_deterministic")
    cudnn_benchmark = train_params.get("cudnn_benchmark")

    if seed:
        torch.manual_seed(seed)
        np.random.seed(seed)
    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = cudnn_deterministic
    if cudnn_benchmark:
        torch.backends.cudnn.benchmark = cudnn_benchmark

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer_ = optimizer(model.parameters(), **optimizer_params)
    trainer = create_supervised_trainer(model, optimizer_, loss_fn=loss_fn, device=device)

    train_data_loader_params.setdefault("shuffle", True)
    train_data_loader_params.setdefault("drop_last", True)
    train_data_loader_params["batch_size"] = _clip_batch_size(
        train_data_loader_params.get("batch_size", 1), train_dataset, "train")
    train_loader = DataLoader(train_dataset, **train_data_loader_params)

    RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(trainer, "ema_loss")
    # With alpha ~ 0 the running average collapses to the latest value, so
    # "batch_loss" tracks the raw per-batch loss.
    RunningAverage(output_transform=lambda x: x, alpha=2 ** (-1022)).attach(trainer, "batch_loss")

    if scheduler:

        class ParamSchedulerSavingAsMetric(ParamSchedulerSavingAsMetricMixIn, scheduler):
            pass

        cycle_epochs = scheduler_params.pop("cycle_epochs", 1)
        scheduler_params.setdefault("cycle_size", int(cycle_epochs * len(train_loader)))
        scheduler_params.setdefault("param_name", "lr")
        scheduler_ = ParamSchedulerSavingAsMetric(optimizer_, **scheduler_params)
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_)

    if evaluate_train_data:
        evaluator_train = create_supervised_evaluator(model, metrics=evaluation_metrics, device=device)

    if evaluate_val_data:
        val_data_loader_params["batch_size"] = _clip_batch_size(
            val_data_loader_params.get("batch_size", 1), val_dataset, "val")
        val_loader = DataLoader(val_dataset, **val_data_loader_params)
        evaluator_val = create_supervised_evaluator(model, metrics=evaluation_metrics, device=device)

    if model_checkpoint_params:
        assert isinstance(model_checkpoint_params, dict)
        minimize = model_checkpoint_params.pop("minimize", True)
        save_interval = model_checkpoint_params.get("save_interval", None)
        if not save_interval:
            model_checkpoint_params.setdefault(
                "score_function", get_score_function("ema_loss", minimize=minimize))
            model_checkpoint_params.setdefault("score_name", "ema_loss")
        mc = model_checkpoint(**model_checkpoint_params)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, mc, {"model": model})

    if early_stopping_params:
        assert isinstance(early_stopping_params, dict)
        metric = early_stopping_params.pop("metric", None)
        assert (metric is None) or (metric in evaluation_metrics)
        minimize = early_stopping_params.pop("minimize", False)
        if metric:
            assert "score_function" not in early_stopping_params, (
                "Remove either 'metric' or 'score_function' from "
                "early_stopping_params: {}".format(early_stopping_params))
            early_stopping_params["score_function"] = get_score_function(metric, minimize=minimize)

        es = EarlyStopping(trainer=trainer, **early_stopping_params)
        if evaluate_val_data:
            evaluator_val.add_event_handler(Events.COMPLETED, es)
        elif evaluate_train_data:
            evaluator_train.add_event_handler(Events.COMPLETED, es)
        else:
            log.warning(
                "Early Stopping is disabled because neither "
                "evaluate_val_data nor evaluate_train_data is set True.")

    if time_limit:
        assert isinstance(time_limit, (int, float))
        tl = TimeLimit(limit_sec=time_limit)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, tl)

    pbar = None
    if progress_update:
        if not isinstance(progress_update, dict):
            progress_update = dict()
        progress_update.setdefault("persist", True)
        progress_update.setdefault("desc", "")
        pbar = ProgressBar(**progress_update)
        pbar.attach(trainer, ["ema_loss"])
    else:

        def log_train_metrics(engine):
            log.info("[Epoch: {} | {}]".format(engine.state.epoch, engine.state.metrics))

        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_metrics)

    if evaluate_train_data:

        def log_evaluation_train_data(engine):
            evaluator_train.run(train_loader)
            train_report = _get_report_str(engine, evaluator_train, "Train Data")
            if pbar:
                pbar.log_message(train_report)
            else:
                log.info(train_report)

        eval_train_event = (
            Events[evaluate_train_data]
            if isinstance(evaluate_train_data, str)
            else Events.EPOCH_COMPLETED)
        trainer.add_event_handler(eval_train_event, log_evaluation_train_data)

    if evaluate_val_data:

        def log_evaluation_val_data(engine):
            evaluator_val.run(val_loader)
            val_report = _get_report_str(engine, evaluator_val, "Val Data")
            if pbar:
                pbar.log_message(val_report)
            else:
                log.info(val_report)

        eval_val_event = (
            Events[evaluate_val_data]
            if isinstance(evaluate_val_data, str)
            else Events.EPOCH_COMPLETED)
        trainer.add_event_handler(eval_val_event, log_evaluation_val_data)

    if mlflow_logging:
        mlflow_logger = MLflowLogger()

        logging_params = {
            "train_n_samples": len(train_dataset),
            "train_n_batches": len(train_loader),
            "optimizer": _name(optimizer),
            "loss_fn": _name(loss_fn),
            "pytorch_version": torch.__version__,
            "ignite_version": ignite.__version__,
        }
        logging_params.update(_loggable_dict(optimizer_params, "optimizer"))
        logging_params.update(_loggable_dict(train_data_loader_params, "train"))
        if scheduler:
            logging_params.update({"scheduler": _name(scheduler)})
            logging_params.update(_loggable_dict(scheduler_params, "scheduler"))
        if evaluate_val_data:
            logging_params.update({
                "val_n_samples": len(val_dataset),
                "val_n_batches": len(val_loader),
            })
            logging_params.update(_loggable_dict(val_data_loader_params, "val"))
        mlflow_logger.log_params(logging_params)

        batch_metric_names = ["batch_loss", "ema_loss"]
        if scheduler:
            batch_metric_names.append(scheduler_params.get("param_name"))

        mlflow_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="step",
                metric_names=batch_metric_names,
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.ITERATION_COMPLETED,
        )
        if evaluate_train_data:
            mlflow_logger.attach(
                evaluator_train,
                log_handler=OutputHandler(
                    tag="train",
                    metric_names=list(evaluation_metrics.keys()),
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.COMPLETED,
            )
        if evaluate_val_data:
            mlflow_logger.attach(
                evaluator_val,
                log_handler=OutputHandler(
                    tag="val",
                    metric_names=list(evaluation_metrics.keys()),
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.COMPLETED,
            )

    trainer.run(train_loader, max_epochs=epochs)

    try:
        if pbar and pbar.pbar:
            pbar.pbar.close()
    except Exception as e:
        log.error(e, exc_info=True)

    model = load_latest_model(model_checkpoint_params)(model)

    return model

def inference(config, local_rank, with_pbar_on_iters=True):
    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    # Load model and weights
    model_weights_filepath = Path(get_artifact_path(config.run_uuid, config.weights_filename))
    assert model_weights_filepath.exists(), \
        "Model weights file '{}' is not found".format(model_weights_filepath.as_posix())

    model = config.model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank)

    if hasattr(config, "custom_weights_loading"):
        config.custom_weights_loading(model, model_weights_filepath)
    else:
        state_dict = torch.load(model_weights_filepath)
        if not all([k.startswith("module.") for k in state_dict]):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    model.eval()

    prepare_batch = config.prepare_batch
    non_blocking = getattr(config, "non_blocking", True)
    model_output_transform = getattr(config, "model_output_transform", lambda x: x)

    tta_transforms = getattr(config, "tta_transforms", None)

    def eval_update_function(engine, batch):
        with torch.no_grad():
            x, y, meta = prepare_batch(batch, device=device, non_blocking=non_blocking)

            if tta_transforms is not None:
                # Average predictions over all test-time augmentations.
                y_preds = []
                for t in tta_transforms:
                    t_x = t.augment_image(x)
                    t_y_pred = model(t_x)
                    t_y_pred = model_output_transform(t_y_pred)
                    y_pred = t.deaugment_mask(t_y_pred)
                    y_preds.append(y_pred)
                y_preds = torch.stack(y_preds, dim=0)
                y_pred = torch.mean(y_preds, dim=0)
            else:
                y_pred = model(x)
                y_pred = model_output_transform(y_pred)
            return {"y_pred": y_pred, "y": y, "meta": meta}

    evaluator = Engine(eval_update_function)

    has_targets = getattr(config, "has_targets", False)

    if has_targets:

        def output_transform(output):
            return output['y_pred'], output['y']

        num_classes = config.num_classes
        cm_metric = ConfusionMatrix(num_classes=num_classes, output_transform=output_transform)
        pr = cmPrecision(cm_metric, average=False)
        re = cmRecall(cm_metric, average=False)

        val_metrics = {
            "IoU": IoU(cm_metric),
            "mIoU_bg": mIoU(cm_metric),
            "Accuracy": cmAccuracy(cm_metric),
            "Precision": pr,
            "Recall": re,
            "F1": Fbeta(beta=1.0, output_transform=output_transform),
        }

        if hasattr(config, "metrics") and isinstance(config.metrics, dict):
            val_metrics.update(config.metrics)

        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)

        if dist.get_rank() == 0:
            # Log val metrics:
            mlflow_logger = MLflowLogger()
            mlflow_logger.attach(
                evaluator,
                log_handler=OutputHandler(tag="validation", metric_names=list(val_metrics.keys())),
                event_name=Events.EPOCH_COMPLETED)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=True, desc="Inference").attach(evaluator)

    if dist.get_rank() == 0:
        do_save_raw_predictions = getattr(config, "do_save_raw_predictions", True)
        do_save_overlayed_predictions = getattr(config, "do_save_overlayed_predictions", True)

        if not has_targets:
            assert do_save_raw_predictions or do_save_overlayed_predictions, \
                "If there are no targets, either do_save_raw_predictions or " \
                "do_save_overlayed_predictions should be set to True in the config"

        # Save predictions
        if do_save_raw_predictions:
            raw_preds_path = config.output_path / "raw"
            raw_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(
                Events.ITERATION_COMPLETED, save_raw_predictions_with_geoinfo, raw_preds_path)

        if do_save_overlayed_predictions:
            overlayed_preds_path = config.output_path / "overlay"
            overlayed_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(
                Events.ITERATION_COMPLETED, save_overlayed_predictions, overlayed_preds_path,
                img_denormalize_fn=config.img_denormalize, palette=default_palette)

    evaluator.add_event_handler(Events.EXCEPTION_RAISED, report_exception)

    # Run evaluation
    evaluator.run(config.data_loader)

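# Each entry of `config.tta_transforms` used in `eval_update_function` above is expected
# to provide paired `augment_image(x)` / `deaugment_mask(y_pred)` methods, which matches
# the interface of the `ttach` library (an assumption inferred from the method names),
# e.g.:
#
#     import ttach as tta
#     config.tta_transforms = tta.Compose([tta.HorizontalFlip(), tta.Rotate90(angles=[0, 90])])
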
def build_engine(spec: RunSpec,
                 output_dir=None,
                 trainer=None,
                 metric_cls=RunningAverage,
                 tag="",
                 mlflow_logger=None,
                 is_training=None,
                 device=None):
    if spec.plot_event is not None or spec.log_event is not None:
        assert output_dir is not None
        plot_fname = os.path.join(output_dir, "{}-{}".format(tag, PLOT_FNAME))
        logs_fname = os.path.join(output_dir, "{}-{}".format(tag, LOGS_FNAME))
    else:
        plot_fname = None
        logs_fname = None

    # Create engine
    if device:
        to_device = tensors_to_device(device=device)

        def step(engine, batch):
            batch = to_device(batch)
            return spec.step(engine, batch)
    else:
        step = spec.step

    engine = Engine(step)

    if trainer is None:
        # training
        trainer = engine
        if is_training is None:
            is_training = True
    else:
        # evaluation
        if is_training is None:
            is_training = False

    spec.set_defaults(is_training=is_training)

    # Attach metrics
    for name, metric in spec.metrics.items():
        metric = auto_metric(metric, cls=metric_cls)
        metric.attach(engine, name)

    if spec.enable_timer:
        timer_metric(engine=engine)

    # Progress bar
    if spec.enable_pbar:
        ProgressBar(file=sys.stdout).attach(engine, metric_names=spec.pbar_metrics)

    # Print logs
    if spec.print_event is not None:
        engine.add_event_handler(
            event_name=spec.print_event,
            handler=print_logs,
            trainer=trainer,
            fmt=spec.print_fmt,
            metric_fmt=spec.print_metric_fmt,
            metric_names=spec.print_metrics)

    # Save logs
    if spec.log_event is not None:
        engine.add_event_handler(
            event_name=spec.log_event,
            handler=save_logs,
            fname=logs_fname,
            trainer=trainer,
            metric_names=spec.log_metrics)

    # Plot metrics
    if spec.plot_event is not None:
        # Plots require logs
        assert spec.log_event is not None
        engine.add_event_handler(
            event_name=spec.plot_event,
            handler=create_plots,
            logs_fname=logs_fname,
            plots_fname=plot_fname,
            metric_names=spec.plot_metrics)

    # Optional user callback for additional configuration
    chain_callbacks(callbacks=spec.callback, engine=engine, trainer=trainer)

    if mlflow_logger is not None and spec.log_event is not None:
        mlflow_logger.attach(
            engine,
            log_handler=OutputHandler(
                tag=tag,
                metric_names=spec.log_metrics,
                global_step_transform=global_step_from_engine(trainer)),
            event_name=spec.log_event)

    return engine

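# `auto_metric` above is assumed to pass ignite Metric instances through unchanged and
# to wrap plain values in `cls`. A plausible sketch under that assumption (not
# necessarily this project's implementation):
def auto_metric(metric, cls=RunningAverage):
    from ignite.metrics import Metric

    if isinstance(metric, Metric):
        return metric
    # Hypothetical convention: a string names a key of the step function's output dict.
    return cls(output_transform=lambda output: output[metric])
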
### TRAINING

# Attach metrics
metrics = ["loss_d", "loss_g"]
RunningAverage(alpha=0.98, output_transform=lambda x: x["loss_d"]).attach(engine, "loss_d")
RunningAverage(alpha=0.98, output_transform=lambda x: x["loss_g"]).attach(engine, "loss_g")

if args.local_rank <= 0:
    pbar = ProgressBar()
    pbar.attach(engine, metric_names=metrics)

    mlflow_logger.attach(
        engine,
        log_handler=OutputHandler(tag="generator/loss", metric_names=["loss_g"]),
        event_name=Events.ITERATION_COMPLETED(every=10))
    mlflow_logger.attach(
        engine,
        log_handler=OutputHandler(tag="discriminator/loss", metric_names=["loss_d"]),
        event_name=Events.ITERATION_COMPLETED(every=10))

    # Only the rank-0 process owns the progress bar, so log timing there as well.
    @engine.on(Events.EPOCH_COMPLETED)
    def log_times(engine):
        pbar.log_message(
            "Epoch {} finished: Batch average time is {:.3f}".format(
                engine.state.epoch, timer.value()))
        timer.reset()
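
# `timer` in `log_times` above is assumed to be an ignite Timer tracking average batch
# time; a typical setup (a sketch, not necessarily this script's exact configuration):
from ignite.handlers import Timer

timer = Timer(average=True)
timer.attach(
    engine,
    start=Events.EPOCH_STARTED,       # restart the timer at each epoch
    resume=Events.ITERATION_STARTED,  # accumulate only time spent inside iterations
    pause=Events.ITERATION_COMPLETED,
    step=Events.ITERATION_COMPLETED,  # advance the running average once per batch
)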