Beispiel #1
0
def run_evaluation(config_filepath, backend="nccl", with_clearml=True):
    """Entry point of the model evaluation: computes the validation metrics.

    The configuration file is loaded and validated against
    ``InferenceConfigSchema``, experiment tracking is set up, and the
    ``evaluation`` function is launched through ``idist.Parallel``.

    Args:
        config_filepath (str): evaluation configuration .py file
        backend (str): distributed backend: nccl, gloo, horovod or None to run without distributed config
        with_clearml (bool): if True, uses ClearML as experiment tracking system

    Raises:
        AssertionError: if CUDA is unavailable, cuDNN is disabled or the
            configuration file does not exist.
    """
    # Hard requirements of this script: a CUDA device with cuDNN enabled.
    assert torch.cuda.is_available(), torch.cuda.is_available()
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True

    config_filepath = Path(config_filepath)
    assert config_filepath.exists(), f"File '{config_filepath.as_posix()}' is not found"

    with idist.Parallel(backend=backend) as runner:
        logger = setup_logger(name="Pascal-VOC12 Evaluation", distributed_rank=idist.get_rank())

        # Load + validate the configuration, then enrich it with runtime info.
        config = ConfigObject(config_filepath)
        InferenceConfigSchema.validate(config)
        config.script_filepath = Path(__file__)
        config.output_path = setup_experiment_tracking(config, with_clearml=with_clearml, task_type="testing")

        params = get_params(config, InferenceConfigSchema)
        utils.log_basic_info(logger, params)

        try:
            runner.run(evaluation, config, logger=logger, with_clearml=with_clearml)
        except KeyboardInterrupt:
            # A manual interruption is reported but not treated as a failure.
            logger.info("Catched KeyboardInterrupt -> exit")
        except Exception as exc:  # noqa
            # Record the traceback in the log before propagating the error.
            logger.exception("")
            raise exc
Beispiel #2
0
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.model}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

    train_dataset, eval_dataset = get_datasets(path=config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.train_batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )
    eval_dataloader = idist.auto_dataloader(
        eval_dataset,
        batch_size=config.eval_batch_size,
        num_workers=config.num_workers,
        shuffle=False,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    config.num_iters_per_epoch = len(train_dataloader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config=config)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------

    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # attach metrics to evaluator
    # ---------------------------------
    accuracy = Accuracy(device=device)
    metrics = {
        "eval_accuracy": accuracy,
        "eval_loss": Loss(loss_fn, device=device),
        "eval_error": (1.0 - accuracy) * 100,
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------

    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Beispiel #3
0
def main():
    """Command-line entry point that trains a sockeye sequence-to-sequence model.

    Parses the training CLI arguments, sets up the module-level ``logger``,
    stores the arguments in the output folder, builds vocabularies, data
    iterators, model configuration and optimizer settings, and finally calls
    ``training_model.fit``.
    """
    parser = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(parser)
    args = parser.parse_args()

    utils.seedRNGs(args)

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training, training_state_dir = check_resume(args, output_folder)

    # The logger is a module-level global so that other functions can use it.
    global logger
    log_path = os.path.join(output_folder, C.LOG_NAME)
    logger = setup_main_logger(__name__, file_logging=True, console=not args.quiet, path=log_path)
    utils.log_basic_info(args)

    # Persist the parsed arguments next to the training output.
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as args_file:
        json.dump(vars(args), args_file)

    with ExitStack() as exit_stack:
        context = determine_context(args, exit_stack)

        vocab_source, vocab_target = load_or_create_vocabs(args, resume_training, output_folder)
        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)

        train_iter, eval_iter, config_data = create_data_iters(args, vocab_source, vocab_target)
        lr_scheduler_instance = create_lr_scheduler(args, resume_training, training_state_dir)

        model_config = create_model_config(args, vocab_source_size, vocab_target_size, config_data)
        model_config.freeze()

        training_model = create_training_model(
            model_config,
            args,
            context,
            train_iter,
            lr_scheduler_instance,
            resume_training,
            training_state_dir,
        )

        weight_initializer = initializer.get_initializer(
            default_init_type=args.weight_init,
            default_init_scale=args.weight_init_scale,
            default_init_xavier_rand_type=args.weight_init_xavier_rand_type,
            default_init_xavier_factor_type=args.weight_init_xavier_factor_type,
            embed_init_type=args.embed_weight_init,
            embed_init_sigma=vocab_source_size**-0.5,  # TODO
            rnn_init_type=args.rnn_h2h_init,
        )

        (optimizer,
         optimizer_params,
         kvstore,
         gradient_clipping_type,
         gradient_clipping_threshold) = define_optimizer(args, lr_scheduler_instance)

        # Stopping criteria start from the CLI values and may be overridden below.
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        max_num_epochs = args.max_num_epochs
        if not (min_num_epochs is None or max_num_epochs is None):
            check_condition(
                min_num_epochs <= max_num_epochs,
                "Minimum number of epochs must be smaller than maximum number of epochs"
            )
        if args.learning_rate_schedule:
            # A fixed schedule always runs for the sum of its update counts;
            # early stopping and epoch limits are switched off.
            max_updates = sum(updates for _, updates in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = None
            max_num_epochs = None

        decode_and_evaluate, decode_and_evaluate_context = determine_decode_and_evaluate_context(
            args, exit_stack, context)

        training_model.fit(
            train_iter,
            eval_iter,
            output_folder=output_folder,
            max_params_files_to_keep=args.keep_last_params,
            metrics=args.metrics,
            initializer=weight_initializer,
            allow_missing_params=args.allow_missing_params,
            max_updates=max_updates,
            checkpoint_frequency=args.checkpoint_frequency,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            optimized_metric=args.optimized_metric,
            gradient_clipping_type=gradient_clipping_type,
            clip_gradient_threshold=gradient_clipping_threshold,
            kvstore=kvstore,
            max_num_not_improved=max_num_checkpoint_not_improved,
            min_num_epochs=min_num_epochs,
            max_num_epochs=max_num_epochs,
            decode_and_evaluate=decode_and_evaluate,
            decode_and_evaluate_context=decode_and_evaluate_context,
            use_tensorboard=args.use_tensorboard,
            mxmonitor_pattern=args.monitor_pattern,
            mxmonitor_stat_func=args.monitor_stat_func,
            lr_decay_param_reset=args.learning_rate_decay_param_reset,
            lr_decay_opt_states_reset=args.learning_rate_decay_optimizer_states_reset,
        )
Beispiel #4
0
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """Training routine to be run by the idist.Parallel context manager.

    Seeds the process, prepares the output folder, builds the dataloaders,
    model and engines, attaches handlers / loggers (including two custom
    ``TrainEvents`` hooks) and finally runs the training loop.

    Args:
        local_rank: first positional argument supplied by the launcher (not read here).
        config: configuration object. Its attributes are read throughout;
            ``output_dir`` is reassigned and the optimizer defaults are merged
            into it.
        *args: not used by this function.
        **kwargs: forwarded as-is to both ``idist.auto_dataloader`` calls.
    """

    # ------------------------------------------------
    # seed: every rank gets its own value (seed + rank)
    # ------------------------------------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # ------------------------------------------------------------
    # output folder: created by rank 0 only, then its path is
    # broadcast from rank 0 and stored as a Path on every process
    # ------------------------------------------------------------

    if rank == 0:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        run_dir = Path(config.output_dir, f"{config.dataset}-backend-{idist.backend()}-{timestamp}")
        run_dir.mkdir(parents=True, exist_ok=True)
        config.output_dir = run_dir.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : plug in your own datasets and dataloader configuration here.
    # `idist.auto_dataloader` takes care of the distributed configuration.
    # TODO : replace `kwargs` with the DataLoader arguments you want.
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

    train_dataset, eval_dataset = get_datasets()

    train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
    eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

    # ------------------------------------------
    # device, model, optimizer, loss, scheduler
    # ------------------------------------------

    device = idist.device()
    model, optimizer, loss_fn, lr_scheduler = initialize()

    # -----------------------------
    # training / evaluation engines
    # -----------------------------

    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------------------------
    # merge the optimizer defaults into the config, then
    # set up python logging (one logger for both engines)
    # and print the training configuration
    # ---------------------------------------------------

    config.__dict__.update(**optimizer.defaults)
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    checkpoint_objects = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, _es_handler, _timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        # TODO : set the metric name used to save the best model if you
        # enabled `Save the best model by evaluation score`, else keep None.
        # The metric must be present in evaluator.state.metrics.
        metric_name=None,
        # TODO : set the metric name used for early stopping if you enabled
        # `Early stop the training by evaluation score`, else keep None.
        # The metric must be present in evaluator.state.metrics.
        es_metric_name=None,
        to_save=checkpoint_objects,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # the experiment logger exists on rank 0 only
    if rank == 0:
        exp_logger = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # optionally resume from a checkpoint
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=checkpoint_objects, checkpoint_fp=config.resume_from)

    # ----------------------------------------------------
    # custom events, variant 1: an event filter.
    # The filter returns a boolean telling whether the
    # handler should run; here it fires on the 1st
    # occurrence and on every 100th one.
    # ----------------------------------------------------

    def first_and_every_100th(_engine, event):
        return (event % 100 == 0) or (event == 1)

    def on_backward_completed():
        # do something interesting
        pass

    trainer.add_event_handler(TrainEvents.BACKWARD_COMPLETED(first_and_every_100th), on_backward_completed)

    # ----------------------------------------------------
    # custom events, variant 2: `every`, fires once every
    # 100 occurrences
    # ----------------------------------------------------

    def on_optim_step_completed():
        # do something interesting
        pass

    trainer.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(every=100), on_optim_step_completed)

    # ---------------------------------------------
    # training stats: logged every N iterations
    # ---------------------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # evaluation at the end of every training epoch,
    # followed by a log of the evaluation metrics
    # ---------------------------------------------

    def evaluate_and_log():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), evaluate_and_log)

    # --------------------------------------------------
    # sanity check: one evaluation pass before training
    # --------------------------------------------------

    def evaluate_before_training():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    trainer.add_event_handler(Events.STARTED, evaluate_before_training)

    # ------------------------------------------
    # everything is wired up: run the training
    # ------------------------------------------
    # TODO : make sure `max_epochs` is provided in the config

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the experiment logger once training is over
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(exp_logger, WandBLogger):
            # W&B is shut down through `finish()` instead of `close()`.
            # See : https://github.com/pytorch/ignite/issues/1894
            exp_logger.finish()
        elif exp_logger:
            exp_logger.close()

    # -----------------------------------------
    # report where the checkpoint was written
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Beispiel #5
0
def main():
    """Command-line entry point for translation.

    Parses the translate CLI arguments, optionally sets up the module-level
    ``logger``, loads the models and vocabularies, builds a
    ``inference.Translator`` (with optional lexicon restriction and
    dictionary settings) and hands it to ``read_and_translate``.
    """
    parser = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(parser)
    args = parser.parse_args()

    # File logging is only configured when an output path was given.
    if args.output is not None:
        global logger
        logger = setup_main_logger(
            __name__,
            console=not args.quiet,
            file_logging=True,
            path="%s.%s" % (args.output, C.LOG_NAME),
        )

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

    log_basic_info(args)

    out_handler = output_handler.get_output_handler(args.output_type, args.output, args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)
        with context:
            # Both decoder options below are tied to lexicon restriction.
            lexicon_requested = args.restrict_lexicon is not None
            models, vocab_source, vocab_target = inference.load_models(
                context,
                args.max_input_len,
                args.beam_size,
                args.batch_size,
                args.models,
                args.checkpoints,
                args.softmax_temperature,
                args.max_output_length_num_stds,
                decoder_return_logit_inputs=lexicon_requested,
                cache_output_layer_w_b=lexicon_requested,
                vis_target_enc_attention_layer=args.vis_target_enc_attention_layer,
            )

            restrict_lexicon = None  # type: TopKLexicon
            if args.restrict_lexicon:
                restrict_lexicon = TopKLexicon(vocab_source, vocab_target)
                restrict_lexicon.load(args.restrict_lexicon)

            length_penalty = inference.LengthPenalty(args.length_penalty_alpha, args.length_penalty_beta)
            translator = inference.Translator(
                context,
                args.ensemble_mode,
                args.bucket_width,
                length_penalty,
                models,
                vocab_source,
                vocab_target,
                restrict_lexicon,
                lex_weight=args.lex_weight,
                align_weight=args.align_weight,
                align_skip_threshold=args.align_threshold,
                align_k_best=args.align_beam_size,
            )

            # Dictionary settings are attached to the translator as attributes.
            if args.dictionary:
                translator.dictionary = data_io.read_dictionary(args.dictionary)
            else:
                translator.dictionary = None

            # Both dictionary options below require a batch size of 1.
            translator.dictionary_override_with_max_attention = args.dictionary_override_with_max_attention
            if translator.dictionary_override_with_max_attention:
                utils.check_condition(args.batch_size == 1, "batching not supported with dictionary override yet")

            translator.dictionary_ignore_se = args.dictionary_ignore_se
            if translator.dictionary_ignore_se:
                utils.check_condition(args.batch_size == 1, "batching not supported with dictionary override yet")

            read_and_translate(translator, out_handler, args.chunk_size, args.input, args.reference)
Beispiel #6
0
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------

    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------

    train_dataset, num_channels = get_datasets(config.dataset, config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------

    device = idist.device()
    netD, netG, optimizerD, optimizerG, loss_fn, lr_scheduler = initialize(config, num_channels)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    ws = idist.get_world_size()
    real_labels = torch.ones(config.batch_size // ws, device=device)
    fake_labels = torch.zeros(config.batch_size // ws, device=device)
    fixed_noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)

    trainer = create_trainers(
        config=config,
        netD=netD,
        netG=netG,
        optimizerD=optimizerD,
        optimizerG=optimizerG,
        loss_fn=loss_fn,
        device=device,
        real_labels=real_labels,
        fake_labels=fake_labels,
    )

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------

    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------

    to_save = {'netD': netD, 'netG': netG, 'optimizerD': optimizerD, 'optimizerG': optimizerG, 'trainer': trainer}
    optimizers = {'optimizerD': optimizerD, 'optimizerG': optimizerG}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model={'netD', netD, 'netG', netG},
        trainer=trainer,
        evaluator=trainer,
        metric_name='errD',
        es_metric_name='errD',
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(config=config, trainer=trainer, optimizers=optimizers)

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------

    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = config.output_dir / (FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = config.output_dir / (REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        if not timer_handler:
            logger.info(f"Epoch {engine.state.epoch} done. Time per batch: {timer_handler.value():.3f}[s]")
            timer_handler.reset()

    @trainer.on(Events.ITERATION_COMPLETED(every=config.log_every_iters))
    @idist.one_rank_only()
    def print_logs(engine):
        fname = config.output_dir / LOGS_FNAME
        columns = ["iteration", ] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration), ] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)
        message = f"[{engine.state.epoch}/{config.max_epochs}][{engine.state.iteration % len(train_dataloader)}/{len(train_dataloader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import pandas as pd

        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

        else:
            df = pd.read_csv(config.output_dir / LOGS_FNAME, delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = config.output_dir / PLOT_FNAME

            fig.savefig(path)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ------------------------------------------
    # setup if done. let's run the training
    # ------------------------------------------

    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------

    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------

    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
Beispiel #7
0
def run(local_rank, config):
    """Training routine executed in every process.

    Seeds the process, prepares the output folder, builds the data loaders,
    model and engines, attaches metrics / handlers / loggers and finally runs
    the training loop.

    Args:
        local_rank: first positional argument supplied by the launcher (not read here).
        config: configuration object. Its attributes are read throughout, and
            ``output_dir`` and ``num_iters_per_epoch`` are (re)assigned here.
    """

    # ------------------------------------------------
    # seed: every rank gets its own value (seed + rank)
    # ------------------------------------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)
    device = idist.device()

    # ------------------------------------------------------------
    # output folder: created by rank 0 only, then its path is
    # broadcast from rank 0 and stored as a Path on every process
    # ------------------------------------------------------------
    if rank == 0:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        run_dir = Path(config.output_dir, f"{config.model}-backend-{idist.backend()}-{timestamp}")
        run_dir.mkdir(parents=True, exist_ok=True)
        config.output_dir = run_dir.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # train / test data loaders
    # -----------------------------
    train_loader, test_loader = get_dataflow(config)

    # ------------------------------------------
    # model, optimizer, loss, scheduler
    # ------------------------------------------
    # stored on the config before `initialize` is called with it
    config.num_iters_per_epoch = len(train_loader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config)

    # -----------------------------
    # training / evaluation engines
    # -----------------------------
    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # evaluation metrics
    # ---------------------------------
    eval_metrics = {
        "eval_accuracy": Accuracy(output_transform=thresholded_output_transform, device=device),
        "eval_loss": Loss(loss_fn, device=device),
    }

    for metric_name, metric in eval_metrics.items():
        metric.attach(evaluator, metric_name)

    # -------------------------------------------------
    # python logging: one logger shared by both engines,
    # the configuration is printed once at start-up
    # -------------------------------------------------
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------
    checkpoint_objects = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, _es_handler, _timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=checkpoint_objects,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # the experiment logger exists on rank 0 only
    if rank == 0:
        exp_logger = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # optionally resume from a checkpoint
    # -----------------------------------
    if config.resume_from:
        resume_from(to_load=checkpoint_objects, checkpoint_fp=config.resume_from)

    # ---------------------------------------------
    # training stats: logged every N iterations
    # ---------------------------------------------
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------------
    # evaluation every `validate_every` training epochs,
    # followed by a log of the evaluation metrics
    # ---------------------------------------------------
    def evaluate_and_log():
        evaluator.run(test_loader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, tag="eval")

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.validate_every), evaluate_and_log)

    # --------------------------------------------------
    # sanity check: one evaluation pass before training
    # --------------------------------------------------
    def evaluate_before_training():
        evaluator.run(test_loader, epoch_length=config.eval_epoch_length)

    trainer.add_event_handler(Events.STARTED, evaluate_before_training)

    # ------------------------------------------
    # everything is wired up: run the training
    # ------------------------------------------
    trainer.run(train_loader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the experiment logger once training is over
    # ------------------------------------------------------------
    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(exp_logger, WandBLogger):
            # W&B is shut down through `finish()` instead of `close()`.
            # See : https://github.com/pytorch/ignite/issues/1894
            exp_logger.finish()
        elif exp_logger:
            exp_logger.close()

    # -----------------------------------------
    # report where the checkpoint was written
    # -----------------------------------------
    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)