Ejemplo n.º 1
0
def run(config, **kwargs):
    """Entry point of the example training.

    Prints a banner, replaces ``config`` with the result of ``config.setup()``
    when the object exposes a ``setup`` attribute, seeds the random number
    generators from ``config.seed`` and finally asserts that ``foo()`` is
    truthy.

    Args:
        config: configuration object; it must provide ``seed`` (after the
            optional ``setup()`` call).
        **kwargs: accepted but not used by this function.
    """
    print("Run example training")

    needs_setup = hasattr(config, "setup")
    if needs_setup:
        # `setup()` returns the configuration that is actually used below.
        config = config.setup()

    seed = config.seed
    set_seed(seed)
    assert foo()
Ejemplo n.º 2
0
def test_set_seed():
    """Verify that ``set_seed(0)`` reseeds numpy, ``random`` and torch.

    For each library two values are drawn after seeding it natively with 0,
    then two more after calling ``set_seed(0)``; both pairs must be equal.
    """
    # numpy arrays
    import numpy as np

    def draw_np():
        first = np.random.rand(4, 10)
        second = np.random.rand(4, 10)
        return first, second

    np.random.seed(0)
    expected = draw_np()
    set_seed(0)
    actual = draw_np()
    for ref, got in zip(expected, actual):
        assert np.all(ref == got)

    # builtin `random` module
    import random

    def draw_py():
        first = random.randint(0, 100)
        second = random.randint(0, 100)
        return first, second

    random.seed(0)
    expected = draw_py()
    set_seed(0)
    actual = draw_py()
    for ref, got in zip(expected, actual):
        assert ref == got

    # torch tensors
    import torch

    def draw_torch():
        first = torch.randint(0, 100, size=(10, ))
        second = torch.randint(0, 100, size=(10, ))
        return first, second

    torch.manual_seed(0)
    expected = draw_torch()
    set_seed(0)
    actual = draw_torch()
    for ref, got in zip(expected, actual):
        assert torch.all(ref == got)
Ejemplo n.º 3
0
def training(local_rank, config, logger=None):
    """Train a segmentation model with periodic validation and logging.

    The model, optimizer and criterion come from ``initialize(config)`` and
    the trainer/evaluators from ``create_trainer`` / ``create_evaluators``
    (all defined outside this function). Validation runs on the train-eval
    and validation loaders every ``config.val_interval`` epochs (default 1),
    the 3 best models by "mIoU_bg" are saved, and rank 0 additionally sets up
    TensorBoard / experiment-tracking logging.

    Args:
        local_rank: integer added to ``config.seed`` to seed this process.
        config: configuration object; reads ``seed``, ``train_loader``,
            ``val_loader``, ``train_eval_loader``, ``num_classes``,
            ``num_epochs``, ``output_path``, ``img_denormalize`` and the
            optional ``use_fp16``, ``val_metrics``, ``val_interval``,
            ``start_by_validation`` and ``es_patience``.
        logger: passed to ``create_trainer`` and ``log_metrics``.

    Raises:
        RuntimeError: if ``config.use_fp16`` is set to a falsy value.
    """

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    torch.backends.cudnn.benchmark = True

    # Offset the seed by the local rank so each process is seeded differently.
    set_seed(config.seed + local_rank)

    train_loader, val_loader, train_eval_loader = config.train_loader, config.val_loader, config.train_eval_loader

    # Setup model, optimizer, criterion
    model, optimizer, criterion = initialize(config)

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler,
                             config, logger)

    # Setup evaluators
    num_classes = config.num_classes
    # A single confusion matrix instance backs both IoU metrics below.
    cm_metric = ConfusionMatrix(num_classes=num_classes)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
    }

    # User-supplied metrics extend (and may override) the defaults.
    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator, train_evaluator = create_evaluators(model, val_metrics, config)

    val_interval = getattr(config, "val_interval", 1)

    @trainer.on(Events.EPOCH_COMPLETED(every=val_interval))
    def run_validation():
        """Evaluate on the train-eval and validation loaders and log metrics."""
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    # The periodic handler only covers the last epoch when num_epochs is a
    # multiple of val_interval; otherwise validate once more at the very end.
    if config.num_epochs % val_interval != 0:
        trainer.add_event_handler(Events.COMPLETED, run_validation)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)

    score_metric_name = "mIoU_bg"

    # Early stopping is enabled only when the config defines `es_patience`.
    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Store 3 best models by validation accuracy:
    common.gen_save_best_models_by_val_score(
        save_handler=get_save_handler(config),
        evaluator=evaluator,
        models=model,
        metric_name=score_metric_name,
        n_saved=3,
        trainer=trainer,
        tag="val",
    )

    # Loggers are created on rank 0 only.
    if idist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )

        # The generic experiment-tracking logger is skipped when ClearML is
        # available (see the matching guard around `close()` at the end).
        if not exp_tracking.has_clearml:
            exp_tracking_logger = exp_tracking.setup_logging(
                trainer,
                optimizer,
                evaluators={
                    "training": train_evaluator,
                    "validation": evaluator
                })

        # Log validation predictions as images
        # We define a custom event filter to log less frequently the images (to reduce storage size)
        # - we plot images with masks of the middle validation batch
        # - once every 3 validations and
        # - at the end of the training
        def custom_event_filter(_, val_iteration):
            """Return True for the middle validation batch of selected epochs."""
            c1 = val_iteration == len(val_loader) // 2
            c2 = trainer.state.epoch % (getattr(config, "val_interval", 1) *
                                        3) == 0
            c2 |= trainer.state.epoch == config.num_epochs
            return c1 and c2

        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(
                event_filter=custom_event_filter),
        )

    # Log confusion matrix to ClearML:
    if exp_tracking.has_clearml:

        @trainer.on(Events.COMPLETED)
        def compute_and_log_cm():
            """Compute the recall-normalized confusion matrix; report it on rank 0."""
            # `compute()` runs on every rank; only the reporting is rank-0 only.
            cm = cm_metric.compute()
            # CM: values are normalized such that diagonal values represent class recalls
            cm = ConfusionMatrix.normalize(cm, "recall").cpu().numpy()

            if idist.get_rank() == 0:
                try:
                    from clearml import Task
                except ImportError:
                    # Backwards-compatibility for legacy Trains SDK
                    from trains import Task

                clearml_logger = Task.current_task().get_logger()
                clearml_logger.report_confusion_matrix(
                    title="Final Confusion Matrix",
                    series="cm-preds-gt",
                    matrix=cm,
                    iteration=trainer.state.iteration,
                    xlabels=VOCSegmentationOpencv.target_names,
                    ylabels=VOCSegmentationOpencv.target_names,
                )

    trainer.run(train_loader, max_epochs=config.num_epochs)

    # Close the loggers that were created on rank 0 above.
    if idist.get_rank() == 0:
        tb_logger.close()
        if not exp_tracking.has_clearml:
            exp_tracking_logger.close()
Ejemplo n.º 4
0
def training(config,
             local_rank=None,
             with_mlflow_logging=False,
             with_plx_logging=False):
    """Train a classification model with AMP and distributed data parallelism.

    Builds a trainer around a gradient-accumulating update step, creates two
    supervised evaluators (train-eval and validation), runs validation every
    ``config.val_interval`` epochs (default 1) and, on rank 0, sets up
    TensorBoard / MLflow / Polyaxon logging and best-model checkpointing by
    "Accuracy".

    Args:
        config: configuration object; reads ``seed``, ``train_loader``,
            ``train_eval_loader``, ``val_loader``, ``model``, ``optimizer``,
            ``criterion``, ``lr_scheduler``, ``output_path``, ``num_epochs``,
            ``img_denormalize`` plus several optional attributes that are
            fetched with ``getattr`` defaults below.
        local_rank: index of the CUDA device for this process. Despite the
            ``None`` default it must be an integer: it is added to
            ``config.seed`` and passed to ``torch.cuda.set_device``.
        with_mlflow_logging: enable MLflow logging on rank 0; also controls
            the per-iteration progress bars.
        with_plx_logging: enable Polyaxon logging on rank 0.

    Raises:
        RuntimeError: if ``config.use_fp16`` is set to a falsy value.
    """

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(train_loader))
    assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \
        "Train sampler should have a callable method `set_epoch`"

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=getattr(
                                          config, "fp16_opt_level", "O2"),
                                      num_losses=1)
    model = DDP(model, delay_allreduce=True)
    criterion = config.criterion.to(device)

    prepare_batch = getattr(config, "prepare_batch", _prepare_batch)
    non_blocking = getattr(config, "non_blocking", True)

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    # Fetched once and shared by the train step and the evaluators.
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    def train_update_function(engine, batch):
        """Run one forward/backward pass; step every `accumulation_steps`."""

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        # Scale the loss so accumulated gradients average over the steps.
        loss = criterion(y_pred, y) / accumulation_steps

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return {
            'supervised batch loss': loss.item(),
        }

    trainer = Engine(train_update_function)
    common.setup_common_distrib_training_handlers(
        trainer,
        train_sampler,
        to_save={
            'model': model,
            'optimizer': optimizer
        },
        save_every_iters=1000,
        output_path=config.output_path.as_posix(),
        lr_scheduler=config.lr_scheduler,
        with_gpu_stats=True,
        output_names=[
            'supervised batch loss',
        ],
        with_pbars=True,
        with_pbar_on_iters=with_mlflow_logging,
        log_every_iters=1)

    if getattr(config, "benchmark_dataflow", False):
        benchmark_dataflow_num_iters = getattr(config,
                                               "benchmark_dataflow_num_iters",
                                               1000)
        DataflowBenchmark(benchmark_dataflow_num_iters,
                          prepare_batch=prepare_batch,
                          device=device).attach(trainer, train_loader)

    # Setup evaluators
    val_metrics = {
        "Accuracy": Accuracy(device=device),
        "Top-5 Accuracy": TopKCategoricalAccuracy(k=5, device=device),
    }

    # User-supplied metrics extend (and may override) the defaults.
    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator_args = dict(model=model,
                          metrics=val_metrics,
                          device=device,
                          non_blocking=non_blocking,
                          prepare_batch=prepare_batch,
                          output_transform=lambda x, y, y_pred: (
                              model_output_transform(y_pred),
                              y,
                          ))
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

    if dist.get_rank() == 0 and with_mlflow_logging:
        ProgressBar(persist=False,
                    desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(_):
        """Evaluate on the train-eval loader, then on the validation loader."""
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    val_interval = getattr(config, "val_interval", 1)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=val_interval),
                              run_validation)
    # The periodic handler already covers the last epoch when num_epochs is a
    # multiple of val_interval; only add a final validation otherwise, so the
    # last epoch is not evaluated twice.
    if config.num_epochs % val_interval != 0:
        trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "Accuracy"

    # Early stopping is enabled only when the config defines `es_patience`.
    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Logging and best-model checkpointing happen on rank 0 only.
    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(config.output_path.as_posix(),
                                            trainer,
                                            optimizer,
                                            evaluators={
                                                "training": train_evaluator,
                                                "validation": evaluator
                                            })
        if with_mlflow_logging:
            common.setup_mlflow_logging(trainer,
                                        optimizer,
                                        evaluators={
                                            "training": train_evaluator,
                                            "validation": evaluator
                                        })

        if with_plx_logging:
            common.setup_plx_logging(trainer,
                                     optimizer,
                                     evaluators={
                                         "training": train_evaluator,
                                         "validation": evaluator
                                     })

        common.save_best_model_by_val_score(config.output_path.as_posix(),
                                            evaluator,
                                            model,
                                            metric_name=score_metric_name,
                                            trainer=trainer)

        # Log train/val predictions:
        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2))

        tb_logger.attach(train_evaluator,
                         log_handler=predictions_gt_images_handler(
                             img_denormalize_fn=config.img_denormalize,
                             n_images=15,
                             another_engine=trainer,
                             prefix_tag="training"),
                         event_name=Events.ITERATION_COMPLETED(
                             once=len(train_eval_loader) // 2))

    trainer.run(train_loader, max_epochs=config.num_epochs)
Ejemplo n.º 5
0
def training(local_rank, config, logger=None):
    """Train a classification model with periodic validation and logging.

    The model, optimizer and criterion come from ``initialize(config)`` and
    the trainer/evaluators from ``create_trainer`` / ``create_evaluators``
    (all defined outside this function). Validation runs every
    ``config.val_interval`` epochs (default 1), the 3 best models by
    "Accuracy" are saved, and rank 0 additionally sets up TensorBoard and
    experiment-tracking logging.

    Args:
        local_rank: integer added to ``config.seed`` to seed this process.
        config: configuration object; reads ``seed``, ``train_loader``,
            ``val_loader``, ``train_eval_loader``, ``num_epochs``,
            ``output_path``, ``img_denormalize`` and several optional
            attributes. ``config.prepare_batch`` is set to ``_prepare_batch``
            when the config does not define it.
        logger: passed to ``create_trainer`` and ``log_metrics``.

    Raises:
        RuntimeError: if ``config.use_fp16`` is set to a falsy value.
    """

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    torch.backends.cudnn.benchmark = True

    # Offset the seed by the local rank so each process is seeded differently.
    set_seed(config.seed + local_rank)

    train_loader, val_loader, train_eval_loader = config.train_loader, config.val_loader, config.train_eval_loader

    # Setup model, optimizer, criterion
    model, optimizer, criterion = initialize(config)

    if not hasattr(config, "prepare_batch"):
        config.prepare_batch = _prepare_batch

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler,
                             config, logger)

    if getattr(config, "benchmark_dataflow", False):
        benchmark_dataflow_num_iters = getattr(config,
                                               "benchmark_dataflow_num_iters",
                                               1000)
        DataflowBenchmark(benchmark_dataflow_num_iters,
                          prepare_batch=config.prepare_batch).attach(
                              trainer, train_loader)

    # Setup evaluators
    val_metrics = {
        "Accuracy": Accuracy(),
        "Top-5 Accuracy": TopKCategoricalAccuracy(k=5),
    }

    # User-supplied metrics extend (and may override) the defaults.
    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator, train_evaluator = create_evaluators(model, val_metrics, config)

    val_interval = getattr(config, "val_interval", 1)

    @trainer.on(Events.EPOCH_COMPLETED(every=val_interval))
    def run_validation():
        """Evaluate on the train-eval and validation loaders and log metrics."""
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    # The periodic handler already covers the last epoch when num_epochs is a
    # multiple of val_interval; only add a final validation otherwise, so the
    # last epoch is not evaluated (and logged) twice.
    if config.num_epochs % val_interval != 0:
        trainer.add_event_handler(Events.COMPLETED, run_validation)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)

    score_metric_name = "Accuracy"

    # Early stopping is enabled only when the config defines `es_patience`.
    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Store 3 best models by validation accuracy:
    common.save_best_model_by_val_score(
        config.output_path.as_posix(),
        evaluator,
        model=model,
        metric_name=score_metric_name,
        n_saved=3,
        trainer=trainer,
        tag="val",
    )

    # Loggers are created on rank 0 only.
    if idist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )

        exp_tracking_logger = exp_tracking.setup_logging(trainer,
                                                         optimizer,
                                                         evaluators={
                                                             "training":
                                                             train_evaluator,
                                                             "validation":
                                                             evaluator
                                                         })

        # Log train/val predictions:
        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2),
        )

        tb_logger.attach(
            train_evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="training"),
            event_name=Events.ITERATION_COMPLETED(
                once=len(train_eval_loader) // 2),
        )

    trainer.run(train_loader, max_epochs=config.num_epochs)

    # Close the loggers that were created on rank 0 above.
    if idist.get_rank() == 0:
        tb_logger.close()
        exp_tracking_logger.close()
Ejemplo n.º 6
0
def inference(config, local_rank, with_pbar_on_iters=True):
    """Run distributed inference, optionally with TTA, metrics and saved outputs.

    Loads the model weights located via ``get_artifact_path``, wraps the model
    in ``DistributedDataParallel`` and evaluates ``config.data_loader``. When
    ``config.has_targets`` is truthy, confusion-matrix based metrics are
    attached and logged to MLflow on rank 0. Rank 0 also registers handlers
    that save raw and/or overlayed predictions under ``config.output_path``.

    Args:
        config: configuration object; reads ``seed``, ``run_uuid``,
            ``weights_filename``, ``model``, ``prepare_batch``,
            ``data_loader``, ``output_path`` and several optional attributes
            fetched with ``getattr`` defaults below.
        local_rank: index of the CUDA device for this process; also added to
            ``config.seed``.
        with_pbar_on_iters: attach an iteration progress bar on rank 0.

    Raises:
        AssertionError: if the weights file does not exist, or if there are no
            targets and both prediction-saving options are disabled.
    """

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    # Load model and weights
    model_weights_filepath = Path(
        get_artifact_path(config.run_uuid, config.weights_filename))
    assert model_weights_filepath.exists(), \
        "Model weights file '{}' is not found".format(model_weights_filepath.as_posix())

    model = config.model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

    if hasattr(config, "custom_weights_loading"):
        config.custom_weights_loading(model, model_weights_filepath)
    else:
        # Without `map_location`, tensors are restored onto the device they
        # were saved from, so every rank would allocate on that same GPU.
        # Map them onto this process's current CUDA device instead.
        state_dict = torch.load(model_weights_filepath, map_location=device)
        # The model is DDP-wrapped, so its keys carry a "module." prefix.
        if not all(k.startswith("module.") for k in state_dict):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    model.eval()

    prepare_batch = config.prepare_batch
    non_blocking = getattr(config, "non_blocking", True)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    tta_transforms = getattr(config, "tta_transforms", None)

    def eval_update_function(engine, batch):
        """Predict one batch; average de-augmented predictions when TTA is on."""
        with torch.no_grad():
            x, y, meta = prepare_batch(batch,
                                       device=device,
                                       non_blocking=non_blocking)

            if tta_transforms is not None:
                y_preds = []
                for t in tta_transforms:
                    t_x = t.augment_image(x)
                    t_y_pred = model(t_x)
                    t_y_pred = model_output_transform(t_y_pred)
                    y_pred = t.deaugment_mask(t_y_pred)
                    y_preds.append(y_pred)

                y_preds = torch.stack(y_preds, dim=0)
                y_pred = torch.mean(y_preds, dim=0)
            else:
                y_pred = model(x)
                y_pred = model_output_transform(y_pred)
            return {"y_pred": y_pred, "y": y, "meta": meta}

    evaluator = Engine(eval_update_function)

    has_targets = getattr(config, "has_targets", False)

    if has_targets:

        def output_transform(output):
            """Extract the (prediction, target) pair from the engine output."""
            return output['y_pred'], output['y']

        num_classes = config.num_classes
        cm_metric = ConfusionMatrix(num_classes=num_classes,
                                    output_transform=output_transform)
        pr = cmPrecision(cm_metric, average=False)
        re = cmRecall(cm_metric, average=False)

        val_metrics = {
            "IoU": IoU(cm_metric),
            "mIoU_bg": mIoU(cm_metric),
            "Accuracy": cmAccuracy(cm_metric),
            "Precision": pr,
            "Recall": re,
            "F1": Fbeta(beta=1.0, output_transform=output_transform)
        }

        # User-supplied metrics extend (and may override) the defaults.
        if hasattr(config, "metrics") and isinstance(config.metrics, dict):
            val_metrics.update(config.metrics)

        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)

        if dist.get_rank() == 0:
            # Log val metrics:
            mlflow_logger = MLflowLogger()
            mlflow_logger.attach(evaluator,
                                 log_handler=OutputHandler(
                                     tag="validation",
                                     metric_names=list(val_metrics.keys())),
                                 event_name=Events.EPOCH_COMPLETED)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=True, desc="Inference").attach(evaluator)

    if dist.get_rank() == 0:
        do_save_raw_predictions = getattr(config, "do_save_raw_predictions",
                                          True)
        do_save_overlayed_predictions = getattr(
            config, "do_save_overlayed_predictions", True)

        # Without targets there are no metrics, so saving predictions is the
        # only output of the run.
        if not has_targets:
            assert do_save_raw_predictions or do_save_overlayed_predictions, \
                "If no targets, either do_save_overlayed_predictions or do_save_raw_predictions should be " \
                "defined in the config and has value equal True"

        # Save predictions
        if do_save_raw_predictions:
            raw_preds_path = config.output_path / "raw"
            # No `exist_ok`: an already existing output folder raises.
            raw_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                        save_raw_predictions_with_geoinfo,
                                        raw_preds_path)

        if do_save_overlayed_predictions:
            overlayed_preds_path = config.output_path / "overlay"
            overlayed_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(
                Events.ITERATION_COMPLETED,
                save_overlayed_predictions,
                overlayed_preds_path,
                img_denormalize_fn=config.img_denormalize,
                palette=default_palette)

    evaluator.add_event_handler(Events.EXCEPTION_RAISED, report_exception)

    # Run evaluation
    evaluator.run(config.data_loader)
Ejemplo n.º 7
0
def training(config, local_rank, with_pbar_on_iters=True):
    """Train a segmentation model with AMP and distributed data parallelism.

    Builds a trainer around a gradient-accumulating update step, creates two
    supervised evaluators (train-eval and validation) with confusion-matrix
    based metrics, runs validation every ``config.val_interval`` epochs
    (default 1) and, on rank 0, sets up TensorBoard / MLflow logging and
    best-model checkpointing by "mIoU_bg".

    Args:
        config: configuration object; reads ``seed``, ``train_loader``,
            ``train_eval_loader``, ``val_loader``, ``model``, ``optimizer``,
            ``criterion``, ``prepare_batch``, ``lr_scheduler``,
            ``output_path``, ``num_classes``, ``num_epochs``,
            ``img_denormalize`` plus several optional attributes fetched with
            ``getattr`` defaults below.
        local_rank: index of the CUDA device for this process; also added to
            ``config.seed``.
        with_pbar_on_iters: show per-iteration progress bars.

    Raises:
        RuntimeError: if ``config.use_fp16`` is set to a falsy value.
    """

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(train_loader))
    assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \
        "Train sampler should have a callable method `set_epoch`"

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=getattr(
                                          config, "fp16_opt_level", "O2"),
                                      num_losses=1)
    model = DDP(model, delay_allreduce=True)

    criterion = config.criterion.to(device)

    # Setup trainer
    prepare_batch = getattr(config, "prepare_batch")
    non_blocking = getattr(config, "non_blocking", True)
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    def train_update_function(engine, batch):
        """Run one forward/backward pass; step every `accumulation_steps`."""
        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)

        # The reported values are the unscaled losses in both branches.
        if isinstance(loss, Mapping):
            assert 'supervised batch loss' in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict['supervised batch loss']
        else:
            output = {'supervised batch loss': loss.item()}

        # Scale the loss for BOTH loss kinds so that accumulated gradients
        # average over the accumulation steps (previously a plain tensor loss
        # was left unscaled).
        loss = loss / accumulation_steps

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return output

    output_names = getattr(config, "output_names", [
        'supervised batch loss',
    ])

    trainer = Engine(train_update_function)
    common.setup_common_distrib_training_handlers(
        trainer,
        train_sampler,
        to_save={
            'model': model,
            'optimizer': optimizer
        },
        save_every_iters=1000,
        output_path=config.output_path.as_posix(),
        lr_scheduler=config.lr_scheduler,
        output_names=output_names,
        with_pbars=True,
        with_pbar_on_iters=with_pbar_on_iters,
        log_every_iters=1)

    def output_transform(output):
        """Extract the (prediction, target) pair from the evaluator output."""
        return output['y_pred'], output['y']

    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes,
                                output_transform=output_transform)
    pr = cmPrecision(cm_metric, average=False)
    re = cmRecall(cm_metric, average=False)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
        "Accuracy": cmAccuracy(cm_metric),
        "Precision": pr,
        "Recall": re,
        "F1": Fbeta(beta=1.0, output_transform=output_transform)
    }

    # User-supplied metrics extend (and may override) the defaults.
    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator_args = dict(model=model,
                          metrics=val_metrics,
                          device=device,
                          non_blocking=non_blocking,
                          prepare_batch=prepare_batch,
                          output_transform=lambda x, y, y_pred: {
                              'y_pred': model_output_transform(y_pred),
                              'y': y
                          })
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=False,
                    desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(engine):
        """Evaluate on the train-eval loader, then on the validation loader."""
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    val_interval = getattr(config, "val_interval", 1)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=val_interval),
                              run_validation)
    # The periodic handler already covers the last epoch when num_epochs is a
    # multiple of val_interval; only add a final validation otherwise, so the
    # last epoch is not evaluated twice.
    if config.num_epochs % val_interval != 0:
        trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "mIoU_bg"

    # Early stopping is enabled only when the config defines `es_patience`.
    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Logging and best-model checkpointing happen on rank 0 only.
    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(config.output_path.as_posix(),
                                            trainer,
                                            optimizer,
                                            evaluators={
                                                "training": train_evaluator,
                                                "validation": evaluator
                                            })
        common.setup_mlflow_logging(trainer,
                                    optimizer,
                                    evaluators={
                                        "training": train_evaluator,
                                        "validation": evaluator
                                    })

        common.save_best_model_by_val_score(config.output_path.as_posix(),
                                            evaluator,
                                            model,
                                            metric_name=score_metric_name,
                                            trainer=trainer)

        # Log train/val predictions:
        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2))

        log_train_predictions = getattr(config, "log_train_predictions", False)
        if log_train_predictions:
            # Tag train-evaluator images as "training" so they do not get
            # mixed with the validation images logged above.
            tb_logger.attach(train_evaluator,
                             log_handler=predictions_gt_images_handler(
                                 img_denormalize_fn=config.img_denormalize,
                                 n_images=15,
                                 another_engine=trainer,
                                 prefix_tag="training"),
                             event_name=Events.ITERATION_COMPLETED(
                                 once=len(train_eval_loader) // 2))

    trainer.run(train_loader, max_epochs=config.num_epochs)
Ejemplo n.º 8
0
def training(config, local_rank, with_pbar_on_iters=True):
    """Run distributed semi-supervised training with fp16 AMP.

    Each trainer iteration back-propagates a supervised loss on a batch from
    ``config.train_loader`` and then, ``unsup_batch_num_repetitions`` times,
    an unsupervised consistency loss on one batch drawn from
    ``config.unsup_train_loader``. The two losses are scaled separately by
    AMP (``loss_id`` 0 and 1). Validation runs on ``config.train_eval_loader``
    and ``config.val_loader``; logging and best-model saving are set up on
    rank 0 only.

    Args:
        config: configuration object. Attributes read without a default:
            ``seed``, ``train_loader``, ``unsup_train_loader``,
            ``train_eval_loader``, ``val_loader``, ``model``, ``optimizer``,
            ``criterion``, ``unsup_criterion``, ``prepare_batch``,
            ``num_classes``, ``output_path``, ``lr_scheduler``,
            ``img_denormalize``, ``num_epochs``. Optional attributes:
            ``use_fp16``, ``fp16_opt_level``, ``unsup_batch_num_repetitions``,
            ``non_blocking``, ``accumulation_steps``,
            ``model_output_transform``, ``output_names``,
            ``unsup_lambda_min``, ``unsup_lambda``, ``unsup_lambda_delta``,
            ``unsup_lambda_max``, ``val_metrics``, ``start_by_validation``,
            ``val_interval``, ``es_patience``, ``log_train_predictions``.
        local_rank: index of the CUDA device used by this process; it is also
            added to ``config.seed`` when seeding.
        with_pbar_on_iters: forwarded to the common training handlers and, on
            rank 0, enables progress bars on both evaluators.

    Raises:
        RuntimeError: if ``config.use_fp16`` is set to a falsy value.
        AssertionError: if a train loader has no ``sampler`` or the sampler
            has no callable ``set_epoch``.
        ValueError: if ``config.unsup_batch_num_repetitions`` is below 1.
    """

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(train_loader))
    assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \
        "Train sampler should have a callable method `set_epoch`"

    unsup_train_loader = config.unsup_train_loader
    unsup_train_sampler = getattr(unsup_train_loader, "sampler", None)
    assert unsup_train_sampler is not None, "Unsupervised train loader of type '{}' " \
                                            "should have attribute 'sampler'".format(type(unsup_train_loader))
    assert hasattr(unsup_train_sampler, 'set_epoch') and callable(unsup_train_sampler.set_epoch), \
        "Unsupervised train sampler should have a callable method `set_epoch`"
    # NOTE(review): `unsup_train_sampler.set_epoch` is checked above but never
    # called in this function — confirm whether the unsupervised sampler is
    # expected to be re-seeded between passes.

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
    # num_losses=2: the supervised and unsupervised losses are scaled separately.
    model, optimizer = amp.initialize(model, optimizer, opt_level=getattr(config, "fp16_opt_level", "O2"), num_losses=2)
    model = DDP(model, delay_allreduce=True)

    criterion = config.criterion.to(device)
    unsup_criterion = config.unsup_criterion.to(device)
    unsup_batch_num_repetitions = getattr(config, "unsup_batch_num_repetitions", 1)
    if unsup_batch_num_repetitions < 1:
        # With zero repetitions the update function would never bind
        # `unsup_loss` and would fail with a NameError on the first iteration.
        raise ValueError(
            "config.unsup_batch_num_repetitions should be at least 1, "
            "but given {}".format(unsup_batch_num_repetitions)
        )

    # Setup trainer
    prepare_batch = getattr(config, "prepare_batch")
    non_blocking = getattr(config, "non_blocking", True)
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform", lambda x: x)

    def cycle(seq):
        # Endless iterator that restarts `seq` each time it is exhausted
        # (unlike itertools.cycle, items are not cached).
        while True:
            for i in seq:
                yield i

    unsup_train_loader_iter = cycle(unsup_train_loader)

    def supervised_loss(batch):
        # Forward pass on a labelled batch; returns whatever `criterion` returns.
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)
        return loss

    def unsupervised_loss(x):
        # Consistency loss: the prediction on an augmented input is compared
        # with the identically transformed argmax of the no-grad prediction
        # on the original input.
        # Assumes x and the model output are 4D with spatial dims (2, 3) and
        # classes on dim 1 — TODO confirm.

        with torch.no_grad():
            y_pred_orig = model(x)

            # Data augmentation: geom only
            k = random.randint(1, 3)
            x_aug = torch.rot90(x, k=k, dims=(2, 3))
            y_pred_orig_aug = torch.rot90(y_pred_orig, k=k, dims=(2, 3))
            if random.random() < 0.5:
                x_aug = torch.flip(x_aug, dims=(2, ))
                y_pred_orig_aug = torch.flip(y_pred_orig_aug, dims=(2, ))
            if random.random() < 0.5:
                x_aug = torch.flip(x_aug, dims=(3, ))
                y_pred_orig_aug = torch.flip(y_pred_orig_aug, dims=(3, ))

            y_pred_orig_aug = y_pred_orig_aug.argmax(dim=1).long()

        y_pred_aug = model(x_aug.detach())

        loss = unsup_criterion(y_pred_aug, y_pred_orig_aug.detach())

        return loss

    def train_update_function(engine, batch):
        model.train()

        loss = supervised_loss(batch)
        if isinstance(loss, Mapping):
            assert 'supervised batch loss' in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict['supervised batch loss'] / accumulation_steps
        else:
            # NOTE(review): unlike the Mapping branch, this loss is not divided
            # by accumulation_steps — confirm whether that is intended.
            output = {'supervised batch loss': loss.item()}

        # Difference with original UDA
        # Apply separately grads from supervised/unsupervised parts
        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        unsup_batch = next(unsup_train_loader_iter)
        unsup_x = unsup_batch['image']
        unsup_x = convert_tensor(unsup_x, device=device, non_blocking=non_blocking)

        # The same unsupervised batch is reused for every repetition; the
        # random augmentation is re-drawn on each call.
        for _ in range(unsup_batch_num_repetitions):
            unsup_loss = engine.state.unsup_lambda * unsupervised_loss(unsup_x)

            assert isinstance(unsup_loss, torch.Tensor)
            output['unsupervised batch loss'] = unsup_loss.item()

            with amp.scale_loss(unsup_loss, optimizer, loss_id=1) as scaled_loss:
                scaled_loss.backward()

            if engine.state.iteration % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Drop references to the unsupervised batch before returning.
        unsup_batch = None
        unsup_x = None

        # Reported total uses the unsupervised loss of the last repetition.
        total_loss = loss + unsup_loss
        output['total batch loss'] = total_loss.item()

        return output

    output_names = getattr(config, "output_names",
                           ['supervised batch loss', 'unsupervised batch loss', 'total batch loss'])

    trainer = Engine(train_update_function)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Initial weight of the unsupervised loss.
        if hasattr(config, "unsup_lambda_min"):
            engine.state.unsup_lambda = config.unsup_lambda_min
        else:
            engine.state.unsup_lambda = getattr(config, "unsup_lambda", 0.001)

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_unsup_params(engine):
        # Linear ramp of the unsupervised weight, clipped at unsup_lambda_max if set.
        engine.state.unsup_lambda += getattr(config, "unsup_lambda_delta", 0.00001)
        if hasattr(config, "unsup_lambda_max"):
            m = config.unsup_lambda_max
            engine.state.unsup_lambda = engine.state.unsup_lambda if engine.state.unsup_lambda < m else m

    common.setup_common_distrib_training_handlers(
        trainer, train_sampler,
        to_save={'model': model, 'optimizer': optimizer},
        save_every_iters=1000, output_path=config.output_path.as_posix(),
        lr_scheduler=config.lr_scheduler, output_names=output_names,
        with_pbars=True, with_pbar_on_iters=with_pbar_on_iters, log_every_iters=1
    )

    def output_transform(output):
        # Evaluators emit a dict; metrics expect a (y_pred, y) pair.
        return output['y_pred'], output['y']

    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes, output_transform=output_transform)
    pr = cmPrecision(cm_metric, average=False)
    re = cmRecall(cm_metric, average=False)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
        "Accuracy": cmAccuracy(cm_metric),
        "Precision": pr,
        "Recall": re,
        "F1": Fbeta(beta=1.0, output_transform=output_transform)
    }

    # User-provided metrics are added to the defaults and override them on key clash.
    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator_args = dict(
        model=model, metrics=val_metrics, device=device, non_blocking=non_blocking, prepare_batch=prepare_batch,
        output_transform=lambda x, y, y_pred: {'y_pred': model_output_transform(y_pred), 'y': y}
    )
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=False, desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(engine):
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1)), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "mIoU_bg"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience, evaluator, trainer, metric_name=score_metric_name)

    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(config.output_path.as_posix(), trainer, optimizer,
                                            evaluators={"training": train_evaluator, "validation": evaluator})
        common.setup_mlflow_logging(trainer, optimizer,
                                    evaluators={"training": train_evaluator, "validation": evaluator})

        common.save_best_model_by_val_score(config.output_path.as_posix(), evaluator, model,
                                            metric_name=score_metric_name, trainer=trainer)

        # Log unsup_lambda
        @trainer.on(Events.ITERATION_COMPLETED(every=100))
        def tblog_unsupervised_lambda(engine):
            tb_logger.writer.add_scalar("training/unsupervised lambda", engine.state.unsup_lambda, engine.state.iteration)
            mlflow.log_metric("training unsupervised lambda", engine.state.unsup_lambda, step=engine.state.iteration)

        # Log train/val predictions (images from the middle batch of each evaluation run):
        tb_logger.attach(evaluator,
                         log_handler=predictions_gt_images_handler(img_denormalize_fn=config.img_denormalize,
                                                                   n_images=15,
                                                                   another_engine=trainer,
                                                                   prefix_tag="validation"),
                         event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2))

        log_train_predictions = getattr(config, "log_train_predictions", False)
        if log_train_predictions:
            # Training predictions get their own tag so they are not logged
            # under the same name as the validation images.
            tb_logger.attach(train_evaluator,
                             log_handler=predictions_gt_images_handler(img_denormalize_fn=config.img_denormalize,
                                                                       n_images=15,
                                                                       another_engine=trainer,
                                                                       prefix_tag="training"),
                             event_name=Events.ITERATION_COMPLETED(once=len(train_eval_loader) // 2))

    trainer.run(train_loader, max_epochs=config.num_epochs)
Ejemplo n.º 9
0
def training(local_rank, config, logger=None):
    """Train a segmentation model and periodically evaluate it.

    Builds the model, optimizer, criterion and trainer from ``config``,
    attaches validation, optional early stopping and best-model checkpointing,
    sets up TensorBoard / experiment-tracking logging on rank 0 and runs the
    trainer for ``config.num_epochs`` epochs.

    Args:
        local_rank: added to ``config.seed`` when seeding this process.
        config: configuration object. Attributes read without a default:
            ``seed``, ``train_loader``, ``val_loader``, ``train_eval_loader``,
            ``num_classes``, ``num_epochs``, ``output_path``,
            ``img_denormalize``. Optional attributes: ``val_metrics``,
            ``val_interval``, ``start_by_validation``, ``es_patience``.
            It is also passed as a whole to ``initialize``,
            ``create_trainer``, ``create_evaluators`` and
            ``get_save_handler``.
        logger: passed to ``create_trainer`` and ``log_metrics``.
    """
    # No fp16/AMP requirement is enforced in this script (the `use_fp16`
    # check that used to live here was disabled).

    torch.backends.cudnn.benchmark = True

    set_seed(config.seed + local_rank)

    train_loader, val_loader, train_eval_loader = config.train_loader, config.val_loader, config.train_eval_loader

    # Setup model, optimizer, criterion
    model, optimizer, criterion = initialize(config)

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler,
                             config, logger)

    # Setup evaluators
    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
    }

    # User-provided metrics are added to the defaults and override them on key clash.
    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator, train_evaluator = create_evaluators(model, val_metrics, config)

    val_interval = getattr(config, "val_interval", 1)

    @trainer.on(Events.EPOCH_COMPLETED(every=val_interval))
    def run_validation():
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    # When the last epoch is not a multiple of val_interval the periodic
    # handler does not fire on it, so validate once more on completion.
    if config.num_epochs % val_interval != 0:
        trainer.add_event_handler(Events.COMPLETED, run_validation)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)

    score_metric_name = "mIoU_bg"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Store 3 best models by validation mIoU_bg score:
    common.gen_save_best_models_by_val_score(
        save_handler=get_save_handler(config),
        evaluator=evaluator,
        models=model,
        metric_name=score_metric_name,
        n_saved=3,
        trainer=trainer,
        tag="val",
    )

    # Loggers are created, and later closed, on rank 0 only.
    is_main_process = idist.get_rank() == 0

    if is_main_process:

        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )

        exp_tracking_logger = tracking.setup_logging(trainer,
                                                     optimizer,
                                                     evaluators={
                                                         "training":
                                                         train_evaluator,
                                                         "validation":
                                                         evaluator
                                                     })

        # Log validation predictions as images
        # We define a custom event filter to log less frequently the images (to reduce storage size)
        # - we plot images with masks of the middle validation batch
        # - once every 3 validations and
        # - at the end of the training
        def custom_event_filter(_, val_iteration):
            c1 = val_iteration == len(val_loader) // 2
            c2 = trainer.state.epoch % (val_interval * 3) == 0
            c2 |= trainer.state.epoch == config.num_epochs
            return c1 and c2

        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(
                event_filter=custom_event_filter),
        )

    # NOTE(review): a "Log confusion matrix to Trains" step was announced here
    # by a placeholder comment, but no code implements it.

    try:
        trainer.run(train_loader, max_epochs=config.num_epochs)
    finally:
        # Close the loggers even if training raises, so they are not left open.
        if is_main_process:
            tb_logger.close()
            exp_tracking_logger.close()