Example #1
    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

        _prediction_tensor = torch.cat(self._predictions, dim=0)
        _target_tensor = torch.cat(self._targets, dim=0)

        ws = idist.get_world_size()
        if ws > 1 and not self._is_reduced:
            # All gather across all processes
            _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
            _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
        self._is_reduced = True

        if idist.get_rank() == 0:
            # Run compute_fn on zero rank only
            precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
            precision = torch.tensor(precision)
            recall = torch.tensor(recall)
            # thresholds can have negative strides, not compatible with torch tensors
            # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
            thresholds = torch.tensor(thresholds.copy())
        else:
            precision, recall, thresholds = None, None, None

        if ws > 1:
            # broadcast result to all processes
            precision = idist.broadcast(precision, src=0, safe_mode=True)
            recall = idist.broadcast(recall, src=0, safe_mode=True)
            thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

        return precision, recall, thresholds
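
The same gather-on-all-ranks, compute-on-rank-0, broadcast-to-all pattern can be written standalone. Below is a minimal sketch, assuming a process group is already initialized (e.g. via idist.Parallel) and `values` is a per-rank tensor; the function name is illustrative:

import torch
import ignite.distributed as idist

def distributed_median(values: torch.Tensor) -> torch.Tensor:
    # Gather the per-rank tensors so every process, including rank 0, sees all of them.
    all_values = idist.all_gather(values)

    # Run the (non-distributed) computation on rank 0 only.
    result = torch.median(all_values) if idist.get_rank() == 0 else None

    # safe_mode=True lets the non-source ranks pass None and still receive the result.
    return idist.broadcast(result, src=0, safe_mode=True)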
Example #2
def _test_distrib_broadcast(device):

    rank = idist.get_rank()
    ws = idist.get_world_size()
    for src in range(ws):

        d = 10 if rank == src else 0
        res = idist.broadcast(d, src=src)
        true_res = 10
        assert res == true_res

        if rank == src:
            t = torch.tensor([1.2345, 2.3456],
                             dtype=torch.float,
                             device=device)
        else:
            t = torch.empty(2, device=device)

        res = idist.broadcast(t, src=src)
        true_res = torch.tensor([1.2345, 2.3456],
                                dtype=torch.float,
                                device=device)
        assert (res == true_res).all(), f"{res} vs {true_res}"

        def _test(text):

            if rank == src:
                t = text
            else:
                t = ""

            res = idist.broadcast(t, src=src)
            true_res = text
            assert res == true_res

        _test("test-abcdefg")
        _test(
            "tests/ignite/distributed/utils/test_horovod.py::test_idist_broadcast_hvd"
            * 200)

        if rank == src:
            t = torch.arange(100, device=device).reshape(4, 25) * (src + 1)
        else:
            t = torch.empty(4, 25, dtype=torch.long, device=device)

        in_dtype = torch.long
        res = idist.broadcast(t, src)
        assert res.shape == (4, 25)
        assert res.dtype == in_dtype
        true_res = torch.arange(100, device=device).reshape(4, 25) * (src + 1)
        assert (res == true_res).all()

    if idist.get_world_size() > 1:
        with pytest.raises(TypeError, match=r"Unhandled input type"):
            idist.broadcast([0, 1, 2], src=0)
Example #3
def _test_distrib_broadcast(device):

    rank = idist.get_rank()
    ws = idist.get_world_size()

    def _test(data_src, data_others, safe_mode):
        for src in range(ws):

            data = data_src if rank == src else data_others
            res = idist.broadcast(data, src=src, safe_mode=safe_mode)

            if isinstance(res, torch.Tensor):
                assert (res == data_src).all(), f"{res} vs {data_src}"
                assert data_src.dtype == res.dtype
            else:
                assert res == data_src, f"{res} vs {data_src}"

    _test(10, 0, safe_mode=False)
    _test(10, None, safe_mode=True)

    t = torch.tensor([1.2345, 2.3456], dtype=torch.float, device=device)
    _test(t, torch.empty_like(t), safe_mode=False)
    _test(t, None, safe_mode=True)
    _test(t, "abc", safe_mode=True)

    _test("test-abcdefg", "", safe_mode=False)
    _test("test-abcdefg", None, safe_mode=True)
    _test("test-abcdefg", 1.2, safe_mode=True)

    s = "tests/ignite/distributed/utils/test_horovod.py::test_idist_broadcast_hvd" * 200
    _test(s, "", safe_mode=False)
    _test(s, None, safe_mode=True)
    _test(s, 123.0, safe_mode=True)

    t = torch.arange(100, device=device).reshape(4, 25) * 2
    _test(t, torch.empty_like(t), safe_mode=False)
    _test(t, None, safe_mode=True)
    _test(t, "None", safe_mode=True)

    t = torch.tensor(12)
    _test(t, torch.empty_like(t), safe_mode=False)
    _test(t, None, safe_mode=True)
    _test(t, 123.4, safe_mode=True)

    if idist.get_world_size() > 1:
        with pytest.raises(TypeError, match=r"Unhandled input type"):
            idist.broadcast([0, 1, 2], src=0)

    if idist.get_world_size() > 1:
        msg = "Source data can not be None" if rank == 0 else "Argument safe_mode should be True"
        with pytest.raises(ValueError, match=msg):
            idist.broadcast(None, src=0)
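
A compact sketch of the two modes exercised above, assuming a process group is already initialized: with safe_mode=False every rank must pass a placeholder of a compatible type, while safe_mode=True also accepts None (or an arbitrary dummy) on the non-source ranks; the values here are illustrative:

import ignite.distributed as idist

rank = idist.get_rank()

# Without safe_mode, non-source ranks still supply a same-type placeholder.
value = 10 if rank == 0 else 0
value = idist.broadcast(value, src=0)  # -> 10 on every rank

# With safe_mode, non-source ranks may pass None.
text = "run-001" if rank == 0 else None
text = idist.broadcast(text, src=0, safe_mode=True)  # -> "run-001" on every rank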
Example #4
    def compute(self) -> float:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError(
                "EpochMetric must have at least one example before it can be computed."
            )

        _prediction_tensor = torch.cat(self._predictions, dim=0)
        _target_tensor = torch.cat(self._targets, dim=0)

        ws = idist.get_world_size()

        if ws > 1 and not self._is_reduced:
            # All gather across all processes
            _prediction_tensor = cast(torch.Tensor,
                                      idist.all_gather(_prediction_tensor))
            _target_tensor = cast(torch.Tensor,
                                  idist.all_gather(_target_tensor))
        self._is_reduced = True

        result = 0.0
        if idist.get_rank() == 0:
            # Run compute_fn on zero rank only
            result = self.compute_fn(_prediction_tensor, _target_tensor)

        if ws > 1:
            # broadcast result to all processes
            result = cast(float, idist.broadcast(result, src=0))

        return result
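
When the payload is a plain Python number, as in the EpochMetric above, every rank can initialize a default and let the broadcast overwrite it with the rank-0 value; a minimal sketch, with the hard-coded score standing in for whatever compute_fn would return on rank 0:

from typing import cast

import ignite.distributed as idist

result = 0.0
if idist.get_rank() == 0:
    result = 0.987  # stands in for the value compute_fn would return on rank 0
# Every rank passes a float placeholder; all ranks receive the rank-0 value.
result = cast(float, idist.broadcast(result, src=0))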
Example #5
def setup_experiment_tracking(config, with_clearml, task_type="training"):
    from datetime import datetime

    assert task_type in ("training", "testing"), task_type

    output_path = ""
    if idist.get_rank() == 0:
        if with_clearml:
            from clearml import Task

            schema = TrainvalConfigSchema if task_type == "training" else InferenceConfigSchema

            task = Task.init("Pascal-VOC12 Training", config.config_filepath.stem, task_type=task_type)
            task.connect_configuration(config.config_filepath.as_posix())

            task.upload_artifact(config.script_filepath.name, config.script_filepath.as_posix())
            task.upload_artifact(config.config_filepath.name, config.config_filepath.as_posix())
            task.connect(get_params(config, schema))

            output_path = Path(os.environ.get("CLEARML_OUTPUT_PATH", "/tmp"))
            output_path = output_path / "clearml" / datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            import shutil

            output_path = Path(os.environ.get("OUTPUT_PATH", "/tmp/output-pascal-voc12"))
            output_path = output_path / task_type / config.config_filepath.stem
            output_path = output_path / datetime.now().strftime("%Y%m%d-%H%M%S")
            output_path.mkdir(parents=True, exist_ok=True)

            shutil.copyfile(config.script_filepath.as_posix(), output_path / config.script_filepath.name)
            shutil.copyfile(config.config_filepath.as_posix(), output_path / config.config_filepath.name)

        output_path = output_path.as_posix()
    return Path(idist.broadcast(output_path, src=0))
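
The closing idiom, create the run directory on rank 0 and broadcast its path as a string, also works on its own; a minimal sketch, with the directory name purely illustrative:

from pathlib import Path

import ignite.distributed as idist

output_path = ""
if idist.get_rank() == 0:
    path = Path("/tmp/output") / "run-001"  # illustrative location
    path.mkdir(parents=True, exist_ok=True)
    output_path = path.as_posix()
# Broadcast the string so every rank resolves the same directory.
output_path = Path(idist.broadcast(output_path, src=0))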
Example #6
    def compute(self) -> float:
        if len(self._probs) < 1:
            raise NotComputableError(
                "Inception score must have at least one example before it can be computed."
            )

        ws = idist.get_world_size()
        _probs_tensor = torch.cat(self._probs, dim=0)

        if ws > 1 and not self._is_reduced:
            _probs_tensor = cast(torch.Tensor, idist.all_gather(_probs_tensor))
        self._is_reduced = True

        result = 0.0
        if idist.get_rank() == 0:
            N = _probs_tensor.shape[0]
            scores = torch.zeros((self.n_splits,))
            for i in range(self.n_splits):
                part = _probs_tensor[i * (N // self.n_splits) : (i + 1) * (N // self.n_splits)]
                kl = part * (torch.log(part) - torch.log(torch.mean(part, dim=0)))
                kl = torch.mean(torch.sum(kl, dim=1))
                scores[i] = torch.exp(kl)
            result = torch.mean(scores).item()
        if ws > 1:
            result = cast(float, idist.broadcast(result, src=0))

        return result
Example #7
        def _test(text):

            if rank == src:
                t = text
            else:
                t = ""

            res = idist.broadcast(t, src=src)
            true_res = text
            assert res == true_res
Example #8
    def _test(data_src, data_others, safe_mode):
        for src in range(ws):

            data = data_src if rank == src else data_others
            res = idist.broadcast(data, src=src, safe_mode=safe_mode)

            if isinstance(res, torch.Tensor):
                assert (res == data_src).all(), f"{res} vs {data_src}"
                assert data_src.dtype == res.dtype
            else:
                assert res == data_src, f"{res} vs {data_src}"
Example #9
def _test_distrib_broadcast(device):

    rank = idist.get_rank()
    ws = idist.get_world_size()
    for src in range(ws):

        d = 10 if rank == src else 0
        res = idist.broadcast(d, src=src)
        true_res = 10
        assert res == true_res

        if rank == src:
            t = torch.tensor([1.2345, 2.3456],
                             dtype=torch.float,
                             device=device)
        else:
            t = torch.empty(2, device=device)

        res = idist.broadcast(t, src=src)
        true_res = torch.tensor([1.2345, 2.3456],
                                dtype=torch.float,
                                device=device)
        assert (res == true_res).all(), "{} vs {}".format(res, true_res)

        if rank == src:
            t = "test-abcdefg"
        else:
            t = ""

        res = idist.broadcast(t, src=src)
        true_res = "test-abcdefg"
        assert res == true_res

        if rank == src:
            t = torch.arange(100, device=device).reshape(4, 25) * (src + 1)
        else:
            t = torch.empty(4, 25, dtype=torch.long, device=device)

        in_dtype = torch.long
        res = idist.broadcast(t, src)
        assert res.shape == (4, 25)
        assert res.dtype == in_dtype
        true_res = torch.arange(100, device=device).reshape(4, 25) * (src + 1)
        assert (res == true_res).all()

    if idist.get_world_size() > 1:
        with pytest.raises(TypeError, match=r"Unhandled input type"):
            idist.broadcast([0, 1, 2], src=0)
Example #10
def get_model_weights(config, logger, with_clearml):

    path = ""
    if with_clearml:
        from clearml import Model

        if idist.get_rank() > 0:
            idist.barrier()
        else:
            model_id = config.weights_path

            logger.info(f"Loading trained model: {model_id}")
            model = Model(model_id)
            assert model is not None, f"{model_id}"
            path = model.get_local_copy()
            idist.barrier()
        path = idist.broadcast(path, src=0)
    else:
        path = config.weights_path
        logger.info(f"Loading {path}")

    assert Path(path).exists(), f"{path} is not found"
    return torch.load(path)
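
The barrier pair above makes rank 0 fetch the weights while the other ranks wait, and the broadcast then shares the resulting local path; a minimal sketch of the same pattern, where download_once stands in for any one-time fetch (a hypothetical helper, not a real API):

import ignite.distributed as idist

path = ""
if idist.get_rank() > 0:
    idist.barrier()  # wait until rank 0 has finished the download
else:
    path = download_once()  # hypothetical: fetch the artifact exactly once
    idist.barrier()  # release the waiting ranks
path = idist.broadcast(path, src=0)  # every rank gets the same local path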
Example #11
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.model}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

    train_dataset, eval_dataset = get_datasets(path=config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.train_batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )
    eval_dataloader = idist.auto_dataloader(
        eval_dataset,
        batch_size=config.eval_batch_size,
        num_workers=config.num_workers,
        shuffle=False,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    config.num_iters_per_epoch = len(train_dataloader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config=config)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------

    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # attach metrics to evaluator
    # ---------------------------------
    accuracy = Accuracy(device=device)
    metrics = {
        "eval_accuracy": accuracy,
        "eval_loss": Loss(loss_fn, device=device),
        "eval_error": (1.0 - accuracy) * 100,
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------

    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Example #12
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

    train_dataset, eval_dataset = get_datasets()

    train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
    eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    model, optimizer, loss_fn, lr_scheduler = initialize()

    # -----------------------------
    # trainer and evaluator
    # -----------------------------

    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # -------------------------------------------
    # update config with optimizer parameters
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    config.__dict__.update(**optimizer.defaults)
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name=None,
        # TODO : replace with the metric name to save the best model
        # if you check `Save the best model by evaluation score` otherwise leave it None
        # metric must be in evaluator.state.metrics.
        es_metric_name=None,
        # TODO : replace with the metric name to early stop
        # if you check `Early stop the training by evaluation score` otherwise leave it None
        # metric must be in evaluator.state.metrics.
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------
    # let's trigger custom events we registered
    # we will use a `event_filter` to trigger that
    # `event_filter` has to return boolean
    # whether this event should be executed
    # here will log the gradients on the 1st iteration
    # and every 100 iterations
    # --------------------------------------------

    @trainer.on(TrainEvents.BACKWARD_COMPLETED(lambda _, ev: (ev % 100 == 0) or (ev == 1)))
    def _():
        # do something interesting
        pass

    # ----------------------------------------
    # here we will use `every` to trigger
    # every 100 iterations
    # ----------------------------------------

    @trainer.on(TrainEvents.OPTIM_STEP_COMPLETED(every=100))
    def _():
        # do something interesting
        pass

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------

    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------
    # TODO : PLEASE provide `max_epochs` parameters

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Example #13
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------

    train_dataset, num_channels = get_datasets(config.dataset, config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    netD, netG, optimizerD, optimizerG, loss_fn, lr_scheduler = initialize(config, num_channels)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    ws = idist.get_world_size()
    real_labels = torch.ones(config.batch_size // ws, device=device)
    fake_labels = torch.zeros(config.batch_size // ws, device=device)
    fixed_noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)

    trainer = create_trainers(
        config=config,
        netD=netD,
        netG=netG,
        optimizerD=optimizerD,
        optimizerG=optimizerG,
        loss_fn=loss_fn,
        device=device,
        real_labels=real_labels,
        fake_labels=fake_labels,
    )

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {'netD': netD, 'netG': netG, 'optimizerD': optimizerD, 'optimizerG': optimizerG, 'trainer': trainer}
    optimizers = {'optimizerD': optimizerD, 'optimizerG': optimizerG}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model={'netD': netD, 'netG': netG},
        trainer=trainer,
        evaluator=trainer,
        metric_name='errD',
        es_metric_name='errD',
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(config=config, trainer=trainer, optimizers=optimizers)

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = config.output_dir / (FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = config.output_dir / (REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        if timer_handler:
            logger.info(f"Epoch {engine.state.epoch} done. Time per batch: {timer_handler.value():.3f}[s]")
            timer_handler.reset()

    @trainer.on(Events.ITERATION_COMPLETED(every=config.log_every_iters))
    @idist.one_rank_only()
    def print_logs(engine):
        fname = config.output_dir / LOGS_FNAME
        columns = ["iteration", ] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration), ] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)
        message = f"[{engine.state.epoch}/{config.max_epochs}][{engine.state.iteration % len(train_dataloader)}/{len(train_dataloader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import pandas as pd

        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

        else:
            df = pd.read_csv(config.output_dir / LOGS_FNAME, delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = config.output_dir / PLOT_FNAME

            fig.savefig(path)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Example #14
def run(local_rank, config):

    # ----------------------
    # Make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)
    device = idist.device()

    # -----------------------
    # Create output folder
    # -----------------------
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.model}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    train_loader, test_loader = get_dataflow(config)

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------
    config.num_iters_per_epoch = len(train_loader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # attach metrics to evaluator
    # ---------------------------------
    metrics = {
        "eval_accuracy": Accuracy(output_transform=thresholded_output_transform, device=device),
        "eval_loss": Loss(loss_fn, device=device),
    }

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------
    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------
    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED(every=config.validate_every))
    def _():
        evaluator.run(test_loader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, tag="eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------
    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(test_loader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------
    trainer.run(train_loader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------
    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------
    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)