Example #1
def _test_distrib_config(local_rank, backend, ws, true_device, rank=None):
    assert idist.backend() == backend, f"{idist.backend()} vs {backend}"

    this_device = idist.device()
    assert isinstance(this_device, torch.device)
    if backend in ("nccl", "horovod") and "cuda" in this_device.type:
        true_device = torch.device(f"{true_device}:{local_rank}")
        assert this_device == true_device, f"{this_device} vs {true_device}"
    elif backend in ("gloo", "horovod"):
        assert this_device == torch.device(true_device)
    elif backend == "xla-tpu":
        assert true_device in this_device.type

    if rank is None:
        if idist.model_name() == "native-dist":
            rank = dist.get_rank()

    if rank is not None:
        assert idist.get_rank() == rank

    assert idist.get_world_size() == ws
    assert idist.get_local_rank() == local_rank

    assert idist.model_name() in ("native-dist", "xla-dist", "horovod-dist")

    _sanity_check()
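
The `_sanity_check()` helper called above is not defined in this snippet; a minimal sketch of the invariants it could verify, assuming access to the internal `_model` object of `ignite.distributed.utils`:

def _sanity_check():
    # Hypothetical sketch: basic consistency of the distributed configuration.
    from ignite.distributed.utils import _model

    assert _model.get_world_size() == _model.get_nnodes() * _model.get_nproc_per_node()
    assert _model.get_local_rank() < _model.get_nproc_per_node()
    assert _model.get_rank() < _model.get_world_size()
    assert _model.get_node_rank() < _model.get_nnodes()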
Example #2
def log_basic_info(self, logger):
    logger.info("- PyTorch version: {}".format(torch.__version__))
    logger.info("- Ignite version: {}".format(ignite.__version__))
    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info("\tbackend: {}".format(idist.backend()))
        logger.info("\tworld size: {}".format(idist.get_world_size()))
        logger.info("\n")
Example #3
def _test_func(index, ws, device, backend, true_init_method):
    assert 0 <= index < ws
    assert index == idist.get_local_rank()
    assert ws == idist.get_world_size()
    assert torch.device(device).type == idist.device().type
    assert backend == idist.backend()

    if idist.model_name() == "native-dist":
        from ignite.distributed.utils import _model

        assert _model._init_method == true_init_method
Example #4
def log_basic_info(logger, config):

    msg = "\n- PyTorch version: {}".format(torch.__version__)
    msg += "\n- Ignite version: {}".format(ignite.__version__)
    logger.info(msg)

    if idist.get_world_size() > 1:
        msg = "\nDistributed setting:"
        msg += "\n\tbackend: {}".format(idist.backend())
        msg += "\n\trank: {}".format(idist.get_rank())
        msg += "\n\tworld size: {}".format(idist.get_world_size())
        logger.info(msg)
Example #5
def _mp_train(rank):

    # Specific ignite.distributed
    print(
        idist.get_rank(),
        "- backend=",
        idist.backend(),
        "- world size",
        idist.get_world_size(),
        "- device",
        idist.device(),
    )
    print(idist.get_rank(), " with seed ", torch.initial_seed())
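
A worker like `_mp_train` is typically launched through `idist.Parallel`, which spawns the processes and passes the local rank as the first argument; a minimal launch sketch, assuming a `gloo` backend with 2 processes on one node:

import ignite.distributed as idist

if __name__ == "__main__":
    # idist.Parallel spawns nproc_per_node processes and calls
    # _mp_train(local_rank) in each of them.
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(_mp_train)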
Example #6
def log_basic_info(logger, config):
    logger.info("Train {} on CIFAR10".format(config["model"]))
    logger.info("- PyTorch version: {}".format(torch.__version__))
    logger.info("- Ignite version: {}".format(ignite.__version__))

    logger.info("\n")
    logger.info("Configuration:")
    for key, value in config.items():
        logger.info("\t{}: {}".format(key, value))
    logger.info("\n")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info("\tbackend: {}".format(idist.backend()))
        logger.info("\tworld size: {}".format(idist.get_world_size()))
        logger.info("\n")
Example #7
def log_basic_info(logger: Logger, config: ConfigSchema):
    logger.info("Experiment: {}".format(config.experiment_name))
    logger.info("- PyTorch version: {}".format(torch.__version__))
    logger.info("- Ignite version: {}".format(ignite.__version__))

    logger.info("\n")
    logger.info("Configuration:")
    for line in OmegaConf.to_yaml(config).split("\n"):
        logger.info("\t" + line)
    logger.info("\n")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info("\tbackend: {}".format(idist.backend()))
        logger.info("\tworld size: {}".format(idist.get_world_size()))
        logger.info("\n")
Example #8
def _test_auto_model(model, ws, device, sync_bn=False, **kwargs):
    model = auto_model(model, sync_bn=sync_bn, **kwargs)
    bnd = idist.backend()
    if ws > 1 and torch.device(device).type in ("cuda", "cpu"):
        if idist.has_native_dist_support and bnd in ("nccl", "gloo"):
            assert isinstance(model, nn.parallel.DistributedDataParallel)
            if sync_bn:
                assert any([isinstance(m, nn.SyncBatchNorm) for m in model.modules()])
            if "find_unused_parameters" in kwargs:
                assert model.find_unused_parameters == kwargs["find_unused_parameters"]
        elif idist.has_hvd_support and bnd in ("horovod",):
            assert isinstance(model, nn.Module)
    elif device != "cpu" and torch.cuda.is_available() and torch.cuda.device_count() > 1:
        assert isinstance(model, nn.parallel.DataParallel)
    else:
        assert isinstance(model, nn.Module)

    assert all(
        [p.device.type == torch.device(device).type for p in model.parameters()]
    ), f"{[p.device.type for p in model.parameters()]} vs {torch.device(device).type}"
Example #9
def _test_auto_model(model, ws, device, sync_bn=False):
    model = auto_model(model, sync_bn=sync_bn)
    bnd = idist.backend()
    if ws > 1 and device in ("cuda", "cpu"):
        if idist.has_native_dist_support and bnd in ("nccl" or "gloo"):
            assert isinstance(model, nn.parallel.DistributedDataParallel)
            if sync_bn:
                assert any(
                    [isinstance(m, nn.SyncBatchNorm) for m in model.modules()])
        elif idist.has_hvd_support and bnd in ("horovod", ):
            assert isinstance(model, nn.Module)
    elif device != "cpu" and torch.cuda.is_available(
    ) and torch.cuda.device_count() > 1:
        assert isinstance(model, nn.parallel.DataParallel)
    else:
        assert isinstance(model, nn.Module)

    assert all([p.device.type == device
                for p in model.parameters()]), "{} vs {}".format(
                    [p.device.type for p in model.parameters()], device)
Example #10
def _test_auto_model_optimizer(ws, device):
    # Test auto_model
    model = nn.Linear(10, 10)
    _test_auto_model(model, ws, device)

    model = nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100))
    _test_auto_model(model, ws, device, sync_bn="cuda" in torch.device(device).type)
    if ws > 1:
        _test_auto_model(model, ws, device, find_unused_parameters=True)
        _test_auto_model(model, ws, device, find_unused_parameters=False)

    # Test auto_optim
    bnd = idist.backend()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = auto_optim(optimizer)
    if idist.has_xla_support and "xla" in device:
        assert isinstance(optimizer, optim.SGD) and hasattr(optimizer, "wrapped_optimizer")
    elif idist.has_hvd_support and bnd in ("horovod",):
        assert isinstance(optimizer, optim.SGD) and hasattr(optimizer, "_allreduce_grad_async")
    else:
        assert isinstance(optimizer, optim.SGD) and not hasattr(optimizer, "wrapped_optimizer")
Example #11
def test_no_distrib(capsys):

    from ignite.distributed.utils import _model

    print("test_no_distrib : dist: ", dist.is_available())
    print("test_no_distrib : _model", type(_model))

    assert idist.backend() is None
    if torch.cuda.is_available():
        assert idist.device().type == "cuda"
    else:
        assert idist.device().type == "cpu"
    assert idist.get_rank() == 0
    assert idist.get_world_size() == 1
    assert idist.get_local_rank() == 0
    assert idist.model_name() == "serial"

    from ignite.distributed.utils import _model, _SerialModel

    _sanity_check()
    assert isinstance(_model, _SerialModel)

    idist.show_config()
    captured = capsys.readouterr()
    out = captured.err.split("\r")
    out = list(map(lambda x: x.strip(), out))
    out = list(filter(None, out))
    assert "ignite.distributed.utils INFO: distributed configuration: serial" in out[
        -1]
    assert "ignite.distributed.utils INFO: backend: None" in out[-1]
    if torch.cuda.is_available():
        assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
    else:
        assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
    assert "ignite.distributed.utils INFO: rank: 0" in out[-1]
    assert "ignite.distributed.utils INFO: local rank: 0" in out[-1]
    assert "ignite.distributed.utils INFO: world size: 1" in out[-1]
Example #12
def log_basic_info(logger: Logger, config: Any) -> None:
    """Logging about pytorch, ignite, configurations, gpu system
    distributed settings.

    Parameters
    ----------
    logger
        Logger instance for logging
    config
        config object to log
    """
    import ignite

    logger.info("PyTorch version: %s", torch.__version__)
    logger.info("Ignite version: %s", ignite.__version__)
    if torch.cuda.is_available():
        # explicitly import cudnn, as torch.backends.cudnn
        # cannot be pickled when horovod spawns processes
        from torch.backends import cudnn

        logger.info("GPU device: %s", torch.cuda.get_device_name(idist.get_local_rank()))
        logger.info("CUDA version: %s", torch.version.cuda)
        logger.info("CUDNN version: %s", cudnn.version())

    logger.info("Configuration: %s", pformat(vars(config)))

    if idist.get_world_size() > 1:
        logger.info("distributed configuration: %s", idist.model_name())
        logger.info("backend: %s", idist.backend())
        logger.info("device: %s", idist.device().type)
        logger.info("hostname: %s", idist.hostname())
        logger.info("world size: %s", idist.get_world_size())
        logger.info("rank: %s", idist.get_rank())
        logger.info("local rank: %s", idist.get_local_rank())
        logger.info("num processes per node: %s", idist.get_nproc_per_node())
        logger.info("num nodes: %s", idist.get_nnodes())
        logger.info("node rank: %s", idist.get_node_rank())
Example #13
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
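
Two helpers used above, `cycle` and `to_list_str`, are not defined in this snippet; minimal sketches of what they could look like:

def cycle(iterable):
    # Hypothetical sketch: yield batches forever, restarting the loader
    # whenever it is exhausted (unlike itertools.cycle, nothing is cached).
    while True:
        for batch in iterable:
            yield batch

def to_list_str(t):
    # Hypothetical sketch: render a 1-D tensor as a comma-separated string.
    return ", ".join(["{:.4f}".format(v) for v in t.tolist()])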
Example #14
def test_idist_methods_no_dist():
    assert idist.get_world_size() < 2
    assert idist.backend() is None, "{}".format(idist.backend())
Example #15
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #16
def training(rank, config):

    # Specific ignite.distributed
    print(
        idist.get_rank(),
        ": run with config:",
        config,
        "- backend=",
        idist.backend(),
        "- world size",
        idist.get_world_size(),
    )

    device = idist.device()

    # Data preparation:
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific ignite.distributed
    train_loader = idist.auto_dataloader(dataset,
                                         batch_size=config["batch_size"])

    # Model, criterion, optimizer setup
    model = idist.auto_model(wide_resnet50_2(num_classes=100))
    criterion = NLLLoss()
    optimizer = idist.auto_optim(SGD(model.parameters(), lr=0.01))

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(engine, batch):

        data = batch[0].to(device)
        target = batch[1].to(device)

        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities, so apply log_softmax over the class dimension
        log_probs = torch.nn.functional.log_softmax(output, dim=1)

        loss_val = criterion(log_probs, target)
        loss_val.backward()
        optimizer.step()

        return loss_val

    # Run _train_step over the whole train_loader for a single epoch (max_epochs=1 below)
    trainer = Engine(_train_step)

    # Add a logger
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training():
        print("Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
            idist.get_rank(),
            idist.get_world_size(),
            trainer.state.epoch,
            trainer.state.iteration * len(trainer.state.batch[0]),
            len(dataset) / idist.get_world_size(),
            trainer.state.output,
        ))

    trainer.run(train_loader, max_epochs=1)
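
Since `training` takes the process rank as its first argument, it can also be launched with `idist.spawn`; a minimal sketch, assuming a `gloo` backend, 2 processes and hypothetical config values:

import ignite.distributed as idist

config = {"nb_samples": 128, "batch_size": 16, "log_interval": 5}  # hypothetical values

# Spawns nproc_per_node processes; each one calls training(local_rank, config).
idist.spawn("gloo", training, args=(config,), nproc_per_node=2)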
Example #17
def training(local_rank, config):

    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-QAT-Training",
                          distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")

        folder_name = "{}_backend-{}-{}_{}".format(config["model"],
                                                   idist.backend(),
                                                   idist.get_world_size(), now)
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info("Output path: {}".format(config["output_path"]))

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler,
                             train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators, as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"])
        | Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path,
                                            trainer,
                                            optimizer,
                                            evaluators=evaluators)

    # Store the best model by validation accuracy (n_saved=1):
    common.save_best_model_by_val_score(
        output_path=config["output_path"],
        evaluator=evaluator,
        model=model,
        metric_name="Accuracy",
        n_saved=1,
        trainer=trainer,
        tag="test",
    )

    trainer.run(train_loader, max_epochs=config["num_epochs"])

    if rank == 0:
        tb_logger.close()
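
The `log_metrics` helper used by `run_validation` above is not shown; a minimal sketch, assuming it formats the computed metrics together with the evaluation time:

def log_metrics(logger, epoch, elapsed, tag, metrics):
    # Hypothetical sketch: pretty-print the metrics of one evaluation run.
    metrics_output = "\n".join(["\t{}: {}".format(k, v) for k, v in metrics.items()])
    logger.info("Epoch {} - Evaluation time ({}): {:.2f}s\n{}".format(epoch, tag, elapsed, metrics_output))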
Example #18
def test_idist_methods_no_dist():
    assert idist.get_world_size() < 2
    assert idist.backend() is None, f"{idist.backend()}"
Example #19
def training(local_rank, config):

    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="ImageNet-Training",
                          distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = "stop-on-{}".format(config["stop_iteration"])

        folder_name = "{}_backend-{}-{}_{}".format(config["model"],
                                                   idist.backend(),
                                                   idist.get_world_size(), now)
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info("Output path: {}".format(config["output_path"]))

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_imagenet_dataloader(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_supervised_trainer(model, optimizer, criterion,
                                        lr_scheduler, train_loader.sampler,
                                        config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(),
        "loss": Loss(criterion),
    }

    # We define two evaluators, as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"])
        | Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path,
                                            trainer,
                                            optimizer,
                                            evaluators=evaluators)

    # Store 3 best models by validation accuracy:
    common.gen_save_best_models_by_val_score(
        save_handler=get_save_handler(config),
        evaluator=evaluator,
        models={"model": model},
        metric_name="accuracy",
        n_saved=3,
        trainer=trainer,
        tag="test",
    )

    # To check that training can be resumed, we can stop it at a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info("Stop training on {} iteration".format(
                trainer.state.iteration))
            trainer.terminate()

    @trainer.on(Events.ITERATION_COMPLETED(every=20))
    def print_acc(engine):
        if rank == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}"\
                    .format(engine.state.epoch, engine.state.iteration, len(train_loader),
                            engine.state.saved_batch_loss
                            ))

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
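
The `get_save_handler` helper referenced above is not shown; a minimal sketch, assuming checkpoints are written to `config["output_path"]` on local disk:

from ignite.handlers import DiskSaver

def get_save_handler(config):
    # Hypothetical sketch: save checkpoints under the configured output path.
    return DiskSaver(config["output_path"], require_empty=False)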