def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
    """Test that printing in LightningModule goes through built-in print function when progress bar is disabled."""
    model = PrintModel()
    bar = TQDMProgressBar()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_steps=1,
        callbacks=[bar],
    )
    bar.disable()
    trainer.fit(model)
    trainer.test(model, verbose=False)
    trainer.predict(model)

    mock_print.assert_has_calls([
        call("training_step", end=""),
        call("validation_step", file=ANY),
        call("test_step"),
        call("predict_step")
    ])
    tqdm_write.assert_not_called()
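
As extracted, the snippet above has lost the patch decorators that supply tqdm_write and mock_print. A plausible reconstruction, inferred from the asserts (the exact patch targets are assumptions):

from unittest import mock
from unittest.mock import ANY, call

# stacked patches are injected bottom-up, matching (tqdm_write, mock_print, tmpdir)
@mock.patch("builtins.print")
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
    ...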
Example #2
def test_tqdm_progress_bar_misconfiguration():
    """Test that Trainer doesn't accept multiple progress bars."""
    # Trainer supports only a single progress bar callback at the moment
    callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")]
    with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"):
        Trainer(callbacks=callbacks)

    with pytest.raises(MisconfigurationException, match=r"enable_progress_bar=False` but found `TQDMProgressBar"):
        Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False)
Example #4
def test_tqdm_progress_bar_value_on_colab(tmpdir):
    """Test that Trainer will override the default in Google COLAB."""
    trainer = Trainer(default_root_dir=tmpdir)
    assert trainer.progress_bar_callback.refresh_rate == 20

    trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
    assert trainer.progress_bar_callback.refresh_rate == 20

    trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19))
    assert trainer.progress_bar_callback.refresh_rate == 19

    with pytest.deprecated_call(match=r"progress_bar_refresh_rate=19\)` is deprecated"):
        trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19)
    assert trainer.progress_bar_callback.refresh_rate == 19
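
The 20-batch default asserted here only applies when Lightning detects a Colab runtime, so the extraction presumably stripped an environment patch. A minimal sketch of the assumed setup:

import os
from unittest import mock

# assumed: simulate Colab so Trainer picks the larger default refresh rate
@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
def test_tqdm_progress_bar_value_on_colab(tmpdir):
    ...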
Example #5
    def _configure_progress_bar(self,
                                enable_progress_bar: bool = True) -> None:
        progress_bars = [
            c for c in self.trainer.callbacks
            if isinstance(c, ProgressBarBase)
        ]
        if len(progress_bars) > 1:
            raise MisconfigurationException(
                "You added multiple progress bar callbacks to the Trainer, but currently only one"
                " progress bar is supported.")
        if len(progress_bars) == 1:
            # the user specified the progress bar in the callbacks list
            # so the trainer doesn't need to provide a default one
            if enable_progress_bar:
                return

            # otherwise the user specified a progress bar callback but also
            # elected to disable the progress bar with the trainer flag
            progress_bar_callback = progress_bars[0]
            raise MisconfigurationException(
                "Trainer was configured with `enable_progress_bar=False`"
                f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list."
            )

        if enable_progress_bar:
            progress_bar_callback = TQDMProgressBar()
            self.trainer.callbacks.append(progress_bar_callback)
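
The three branches above correspond to three observable behaviours; a quick sketch, assuming this connector is wired into Trainer construction as in recent pytorch_lightning:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Trainer()                               # no bar in callbacks: a default TQDMProgressBar is appended
Trainer(callbacks=[TQDMProgressBar()])  # user-supplied bar: kept as-is
try:
    Trainer(callbacks=[TQDMProgressBar()], enable_progress_bar=False)
except MisconfigurationException:
    pass                                # user bar plus enable_progress_bar=False: rejected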
Example #6
def test_main_progress_bar_update_amount(
    tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_updates, val_updates
):
    """Test that the main progress updates with the correct amount together with the val progress.

    At the end of the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh
    rate.
    """
    model = BoringModel()
    progress_bar = TQDMProgressBar(refresh_rate=refresh_rate)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=train_batches,
        limit_val_batches=val_batches,
        callbacks=[progress_bar],
        logger=False,
        enable_checkpointing=False,
    )
    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
        trainer.fit(model)
    if train_batches > 0:
        assert progress_bar.main_progress_bar.n_values == train_updates
    if val_batches > 0:
        assert progress_bar.val_progress_bar.n_values == val_updates
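
The parametrize decorator feeding train_batches, val_batches, refresh_rate and the expected update lists did not survive extraction. The row below is hypothetical and only illustrates the decorator's shape; MockTqdm records every value of n in n_values, so the expected lists can be compared directly:

import pytest

# hypothetical row, not from the original test file
@pytest.mark.parametrize(
    "train_batches,val_batches,refresh_rate,train_updates,val_updates",
    [
        (2, 0, 1, [1, 2], None),  # e.g. two train batches, no val, refresh every batch
    ],
)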
Example #7
def test_progress_bar_max_val_check_interval(
    tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
):
    limit_batches = 7
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        val_check_interval=val_check_interval,
        limit_train_batches=limit_batches,
        limit_val_batches=limit_batches,
        callbacks=TQDMProgressBar(refresh_rate=3),
    )
    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
        trainer.fit(model)

    pbar = trainer.progress_bar_callback
    assert pbar.main_progress_bar.n_values == main_progress_bar_updates
    assert pbar.val_progress_bar.n_values == val_progress_bar_updates

    val_check_batch = (
        max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
    )
    assert trainer.val_check_batch == val_check_batch
    val_checks_per_epoch = limit_batches // val_check_batch  # number of completed val runs in the epoch
    total_val_batches = limit_batches * val_checks_per_epoch

    assert pbar.val_progress_bar.n == limit_batches
    assert pbar.val_progress_bar.total == limit_batches
    assert pbar.main_progress_bar.n == limit_batches + total_val_batches
    assert pbar.main_progress_bar.total == limit_batches + total_val_batches
    assert pbar.is_enabled
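
A worked instance of the arithmetic (val_check_interval=0.5 is assumed, since the parametrize values were stripped; limit_batches=7 comes from the test):

val_check_batch = max(1, int(7 * 0.5))        # -> 3
val_checks_per_epoch = 7 // val_check_batch   # -> 2 completed val runs in the epoch
total_val_batches = 7 * val_checks_per_epoch  # -> 14
main_bar_total = 7 + total_val_batches        # -> 21 steps shown on the main bar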
Example #8
    def configure_progress_bar(
            self,
            refresh_rate: Optional[int] = None,
            process_position: int = 0,
            enable_progress_bar: bool = True) -> Optional[ProgressBarBase]:
        if os.getenv("COLAB_GPU") and refresh_rate is None:
            # smaller refresh rate on colab causes crashes, choose a higher value
            refresh_rate = 20
        refresh_rate = 1 if refresh_rate is None else refresh_rate

        progress_bars = [
            c for c in self.trainer.callbacks
            if isinstance(c, ProgressBarBase)
        ]
        if len(progress_bars) > 1:
            raise MisconfigurationException(
                "You added multiple progress bar callbacks to the Trainer, but currently only one"
                " progress bar is supported.")
        if len(progress_bars) == 1:
            progress_bar_callback = progress_bars[0]
            if not enable_progress_bar:
                raise MisconfigurationException(
                    "Trainer was configured with `enable_progress_bar=False`"
                    f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list."
                )
        elif refresh_rate > 0 and enable_progress_bar:
            progress_bar_callback = TQDMProgressBar(
                refresh_rate=refresh_rate, process_position=process_position)
            self.trainer.callbacks.append(progress_bar_callback)
        else:
            progress_bar_callback = None

        return progress_bar_callback
Example #9
def test_checkpoint_callbacks_are_last(tmpdir):
    """Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
    checkpoint1 = ModelCheckpoint(tmpdir)
    checkpoint2 = ModelCheckpoint(tmpdir)
    model_summary = ModelSummary()
    early_stopping = EarlyStopping(monitor="foo")
    lr_monitor = LearningRateMonitor()
    progress_bar = TQDMProgressBar()

    # no model reference
    trainer = Trainer(callbacks=[
        checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2
    ])
    assert trainer.callbacks == [
        progress_bar,
        lr_monitor,
        model_summary,
        trainer.accumulation_scheduler,
        checkpoint1,
        checkpoint2,
    ]

    # no model callbacks
    model = LightningModule()
    model.configure_callbacks = lambda: []
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [
        progress_bar,
        lr_monitor,
        model_summary,
        trainer.accumulation_scheduler,
        checkpoint1,
        checkpoint2,
    ]

    # with model-specific callbacks that substitute ones in Trainer
    model = LightningModule()
    model.configure_callbacks = lambda: [
        checkpoint1, early_stopping, model_summary, checkpoint2
    ]
    trainer = Trainer(
        callbacks=[progress_bar, lr_monitor,
                   ModelCheckpoint(tmpdir)])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [
        progress_bar,
        lr_monitor,
        trainer.accumulation_scheduler,
        early_stopping,
        model_summary,
        checkpoint1,
        checkpoint2,
    ]
Example #10
def main():
    """The main for this multi-source domain adaptation example, showing the workflow"""
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
    logging.basicConfig(format=format_str)
    # ---- setup dataset ----
    if isinstance(cfg.DATASET.SOURCE, list):
        sub_domain_set = cfg.DATASET.SOURCE + [cfg.DATASET.TARGET]
    else:
        sub_domain_set = None
    num_channels = cfg.DATASET.NUM_CHANNELS
    if cfg.DATASET.NAME.upper() == "DIGITS":
        kwargs = {"return_domain_label": True}
    else:
        kwargs = {"download": True, "return_domain_label": True}

    data_access = ImageAccess.get_multi_domain_images(
        cfg.DATASET.NAME.upper(), cfg.DATASET.ROOT, sub_domain_set=sub_domain_set, **kwargs
    )

    # Repeat multiple times to get std
    for i in range(0, cfg.DATASET.NUM_REPEAT):
        seed = cfg.SOLVER.SEED + i * 10
        dataset = MultiDomainAdapDataset(data_access, random_state=seed)
        set_seed(seed)  # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        print(f"==> Building model for seed {seed} ......")
        # ---- setup model and logger ----
        model, train_params = get_model(cfg, dataset, num_channels)

        tb_logger = TensorBoardLogger(cfg.OUTPUT.TB_DIR, name="seed{}".format(seed))
        checkpoint_callback = ModelCheckpoint(
            filename="{epoch}-{step}-{valid_loss:.4f}", monitor="valid_loss", mode="min",
        )
        progress_bar = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)

        trainer = pl.Trainer(
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            callbacks=[checkpoint_callback, progress_bar],
            gpus=args.gpus,
            auto_select_gpus=True,
            logger=tb_logger,  # logger,
            # weights_summary='full',
            fast_dev_run=False,  # True,
        )

        trainer.fit(model)
        trainer.test()
Example #11
def test_tqdm_progress_bar_can_be_pickled():
    bar = TQDMProgressBar()
    trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1)
    model = BoringModel()

    pickle.dumps(bar)
    trainer.fit(model)
    pickle.dumps(bar)
    trainer.test(model)
    pickle.dumps(bar)
    trainer.predict(model)
    pickle.dumps(bar)
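
Picklability matters because spawn-based strategies such as DDP spawn pickle the Trainer's callbacks when starting worker processes, and live tqdm handles are not picklable, so the bar has to exclude them from its state. A round-trip sketch of the same property:

import pickle

bar2 = pickle.loads(pickle.dumps(bar))  # survives a full round-trip even after fit/test/predict
assert isinstance(bar2, TQDMProgressBar)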
Example #12
def train(model: LightningModule, data_module: LightningDataModule,
          config: DictConfig):
    seed_everything(config.seed)
    params = config.train

    # define logger
    wandb_logger = WandbLogger(
        project=config.wandb.project,
        group=config.wandb.group,
        log_model=False,
        offline=config.wandb.offline,
        config=OmegaConf.to_container(config),
    )

    # define model checkpoint callback
    checkpoint_callback = ModelCheckpointWithUploadCallback(
        dirpath=wandb_logger.experiment.dir,
        filename="{epoch:02d}-val_loss={val/loss:.4f}",
        monitor="val/loss",
        every_n_epochs=params.save_every_epoch,
        save_top_k=-1,
        auto_insert_metric_name=False,
    )
    # define early stopping callback
    early_stopping_callback = EarlyStopping(patience=params.patience,
                                            monitor="val/loss",
                                            verbose=True,
                                            mode="min")
    # define callback for printing intermediate result
    print_epoch_result_callback = PrintEpochResultCallback(after_test=False)
    # define learning rate logger
    lr_logger = LearningRateMonitor("step")
    # define progress bar callback
    progress_bar = TQDMProgressBar(
        refresh_rate=config.progress_bar_refresh_rate)
    trainer = Trainer(
        max_epochs=params.n_epochs,
        gradient_clip_val=params.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=params.val_every_epoch,
        log_every_n_steps=params.log_every_n_steps,
        logger=wandb_logger,
        gpus=params.gpu,
        callbacks=[
            lr_logger, early_stopping_callback, checkpoint_callback,
            print_epoch_result_callback, progress_bar
        ],
        resume_from_checkpoint=config.get("checkpoint", None),
    )

    trainer.fit(model=model, datamodule=data_module)
    trainer.test(model=model, datamodule=data_module)
Example #13
File: main.py Project: sz144/pykale
def main():
    """The main for this domain adaptation example, showing the workflow"""
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
    logging.basicConfig(format=format_str)
    # ---- setup dataset ----
    source, target, num_channels = DigitDataset.get_source_target(
        DigitDataset(cfg.DATASET.SOURCE.upper()),
        DigitDataset(cfg.DATASET.TARGET.upper()), cfg.DATASET.ROOT)
    dataset = MultiDomainDatasets(
        source,
        target,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
        valid_split_ratio=cfg.DATASET.VALID_SPLIT_RATIO,
    )

    # Repeat multiple times to get std
    for i in range(0, cfg.DATASET.NUM_REPEAT):
        seed = cfg.SOLVER.SEED + i * 10
        # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        set_seed(seed)
        print(f"==> Building model for seed {seed} ......")
        # ---- setup model and logger ----
        model, train_params = get_model(cfg, dataset, num_channels)
        tb_logger = pl_loggers.TensorBoardLogger(cfg.OUTPUT.TB_DIR,
                                                 name="seed{}".format(seed))
        checkpoint_callback = ModelCheckpoint(
            filename="{epoch}-{step}-{valid_loss:.4f}",
            monitor="valid_loss",
            mode="min",
        )
        progress_bar = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)

        trainer = pl.Trainer(
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            callbacks=[checkpoint_callback, progress_bar],
            logger=tb_logger,
            gpus=args.gpus,
        )

        trainer.fit(model)
        trainer.test()
Example #14
def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
    """Test that the progress bar is disabled when not in global rank zero."""
    pbar = TQDMProgressBar()
    model = BoringModel()
    trainer = Trainer(
        callbacks=[pbar],
        fast_dev_run=True,
    )

    pbar.enable()
    trainer.fit(model)
    assert pbar.is_disabled

    pbar.enable()
    trainer.predict(model)
    assert pbar.is_disabled

    pbar.enable()
    trainer.validate(model)
    assert pbar.is_disabled

    pbar.enable()
    trainer.test(model)
    assert pbar.is_disabled
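
The is_global_zero argument is injected by a patch decorator that was stripped in extraction; it forces the trainer to report a non-zero rank so each entry point re-disables the bar. The patch target below is an assumption:

from unittest import mock

# assumed target: make the trainer report it is not global rank zero
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.is_global_zero", new_callable=mock.PropertyMock, return_value=False)
def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
    ...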
Example #15
def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list):
    """Test that test progress updates with the correct amount."""
    model = BoringModel()
    progress_bar = TQDMProgressBar(refresh_rate=refresh_rate)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_test_batches=test_batches,
        callbacks=[progress_bar],
        logger=False,
        enable_checkpointing=False,
    )
    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
        trainer.test(model)
    assert progress_bar.test_progress_bar.n_values == updates
Example #16
def test_attach_model_callbacks_override_info(caplog):
    """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
    model = LightningModule()
    model.configure_callbacks = lambda: [
        LearningRateMonitor(),
        EarlyStopping(monitor="foo")
    ]
    trainer = Trainer(enable_checkpointing=False,
                      callbacks=[
                          EarlyStopping(monitor="foo"),
                          LearningRateMonitor(),
                          TQDMProgressBar()
                      ])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    with caplog.at_level(logging.INFO):
        cb_connector._attach_model_callbacks()

    assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text
def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
    """Gets the trainer callbacks based on the given D2Go Config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.

    Returns:
        A list of configured Callbacks to be used by the Lightning Trainer.
    """
    callbacks: List[Callback] = [
        TQDMProgressBar(refresh_rate=10),  # Arbitrary refresh_rate.
        LearningRateMonitor(logging_interval="step"),
        ModelCheckpoint(
            dirpath=cfg.OUTPUT_DIR,
            save_last=True,
        ),
    ]
    if cfg.QUANTIZATION.QAT.ENABLED:
        callbacks.append(QuantizationAwareTraining.from_config(cfg))
    return callbacks
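
A sketch of how the helper slots into trainer construction (cfg is assumed to be a prepared D2Go CfgNode):

from pytorch_lightning import Trainer

trainer = Trainer(callbacks=_get_trainer_callbacks(cfg))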
Example #18
def test_tqdm_progress_bar_print_no_train(tqdm_write, tmpdir):
    """Test that printing in the LightningModule redirects arguments to the progress bar without training."""
    model = PrintModel()
    bar = TQDMProgressBar()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_steps=1,
        callbacks=[bar],
    )

    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
    assert tqdm_write.call_args_list == [
        call("validation_step", file=sys.stderr),
        call("test_step"),
        call("predict_step"),
    ]
Example #19
def test_tqdm_progress_bar_print(tqdm_write, tmpdir):
    """Test that printing in the LightningModule redirects arguments to the progress bar."""
    model = PrintModel()
    bar = TQDMProgressBar()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_steps=1,
        callbacks=[bar],
    )
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)
    assert tqdm_write.call_count == 4
    assert tqdm_write.call_args_list == [
        call("training_step", end="", file=None, nolock=False),
        call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
        call("test_step", end=os.linesep, file=None, nolock=False),
        call("predict_step", end=os.linesep, file=None, nolock=False),
    ]
Example #20
File: main.py Project: sz144/pykale
def main():
    """The main for this domain adaptation example, showing the workflow"""
    args = arg_parse()

    # ---- setup configs ----
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.cfg)
    cfg.freeze()
    print(cfg)

    # ---- setup output ----
    format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
    logging.basicConfig(format=format_str)
    # ---- setup dataset ----
    seed = cfg.SOLVER.SEED
    source, target, num_classes = VideoDataset.get_source_target(
        VideoDataset(cfg.DATASET.SOURCE.upper()), VideoDataset(cfg.DATASET.TARGET.upper()), seed, cfg
    )
    dataset = VideoMultiDomainDatasets(
        source,
        target,
        image_modality=cfg.DATASET.IMAGE_MODALITY,
        seed=seed,
        config_weight_type=cfg.DATASET.WEIGHT_TYPE,
        config_size_type=cfg.DATASET.SIZE_TYPE,
    )

    # ---- training/test process ----
    ### Repeat multiple times to get std
    for i in range(0, cfg.DATASET.NUM_REPEAT):
        seed = seed + i * 10
        set_seed(seed)  # seed_everything in pytorch_lightning did not set torch.backends.cudnn
        print(f"==> Building model for seed {seed} ......")
        # ---- setup model and logger ----
        model, train_params = get_model(cfg, dataset, num_classes)
        tb_logger = pl_loggers.TensorBoardLogger(cfg.OUTPUT.TB_DIR, name="seed{}".format(seed))
        checkpoint_callback = ModelCheckpoint(
            # dirpath=full_checkpoint_dir,
            filename="{epoch}-{step}-{valid_loss:.4f}",
            # save_last=True,
            # save_top_k=1,
            monitor="valid_loss",
            mode="min",
        )

        ### Set early stopping
        # early_stop_callback = EarlyStopping(monitor="valid_target_acc", min_delta=0.0000, patience=100, mode="max")

        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        progress_bar = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)

        ### Set the lightning trainer. Comment `limit_train_batches`, `limit_val_batches`, `limit_test_batches` when
        # training. Uncomment and change the ratio to test the code on the smallest sub-dataset for efficiency in
        # debugging. Uncomment early_stop_callback to activate early stopping.
        trainer = pl.Trainer(
            min_epochs=cfg.SOLVER.MIN_EPOCHS,
            max_epochs=cfg.SOLVER.MAX_EPOCHS,
            # resume_from_checkpoint=last_checkpoint_file,
            gpus=args.gpus,
            logger=tb_logger,  # logger,
            # weights_summary='full',
            fast_dev_run=cfg.OUTPUT.FAST_DEV_RUN,  # True,
            callbacks=[lr_monitor, checkpoint_callback, progress_bar],
            # callbacks=[early_stop_callback, lr_monitor],
            # limit_train_batches=0.005,
            # limit_val_batches=0.06,
            # limit_test_batches=0.06,
        )

        ### Find learning_rate
        # lr_finder = trainer.tuner.lr_find(model, max_lr=0.1, min_lr=1e-6)
        # fig = lr_finder.plot(suggest=True)
        # fig.show()
        # logging.info(lr_finder.suggestion())

        ### Training/validation process
        trainer.fit(model)

        ### Test process
        trainer.test()
Example #21
def test_attach_model_callbacks():
    """Test that the callbacks defined in the model and through Trainer get merged correctly."""

    def _attach_callbacks(trainer_callbacks, model_callbacks):
        model = LightningModule()
        model.configure_callbacks = lambda: model_callbacks
        has_progress_bar = any(isinstance(cb, ProgressBarBase) for cb in trainer_callbacks + model_callbacks)
        trainer = Trainer(
            enable_checkpointing=False,
            enable_progress_bar=has_progress_bar,
            enable_model_summary=False,
            callbacks=trainer_callbacks,
        )
        trainer.model = model
        cb_connector = CallbackConnector(trainer)
        cb_connector._attach_model_callbacks()
        return trainer

    early_stopping = EarlyStopping(monitor="foo")
    progress_bar = TQDMProgressBar()
    lr_monitor = LearningRateMonitor()
    grad_accumulation = GradientAccumulationScheduler({1: 1})

    # no callbacks
    trainer = _attach_callbacks(trainer_callbacks=[], model_callbacks=[])
    assert trainer.callbacks == [trainer.accumulation_scheduler]

    # callbacks of different types
    trainer = _attach_callbacks(trainer_callbacks=[early_stopping], model_callbacks=[progress_bar])
    assert trainer.callbacks == [early_stopping, trainer.accumulation_scheduler, progress_bar]

    # same callback type twice, different instance
    trainer = _attach_callbacks(
        trainer_callbacks=[progress_bar, EarlyStopping(monitor="foo")],
        model_callbacks=[early_stopping],
    )
    assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

    # multiple callbacks of the same type in trainer
    trainer = _attach_callbacks(
        trainer_callbacks=[
            LearningRateMonitor(),
            EarlyStopping(monitor="foo"),
            LearningRateMonitor(),
            EarlyStopping(monitor="foo"),
        ],
        model_callbacks=[early_stopping, lr_monitor],
    )
    assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]

    # multiple callbacks of the same type, in both trainer and model
    trainer = _attach_callbacks(
        trainer_callbacks=[
            LearningRateMonitor(),
            progress_bar,
            EarlyStopping(monitor="foo"),
            LearningRateMonitor(),
            EarlyStopping(monitor="foo"),
        ],
        model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
    )
    assert trainer.callbacks == [progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping]
Example #22
def test_tqdm_progress_bar_main_bar_resume():
    """Test that the progress bar can resume its counters based on the Trainer state."""
    bar = TQDMProgressBar()
    trainer = Mock()
    model = Mock()

    trainer.sanity_checking = False
    trainer.check_val_every_n_epoch = 1
    trainer.current_epoch = 1
    trainer.num_training_batches = 5
    trainer.val_check_batch = 5
    trainer.num_val_batches = [3]
    trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3

    bar.setup(trainer, model)
    bar.on_train_start(trainer, model)
    bar.on_train_epoch_start(trainer, model)

    # resumed mid-epoch: 3 of 5 train batches completed; total = 5 train + 3 val batches = 8
    assert bar.main_progress_bar.n == 3
    assert bar.main_progress_bar.total == 8

    # bar.on_train_epoch_end(trainer, model)
    bar.on_validation_start(trainer, model)
    bar.on_validation_epoch_start(trainer, model)

    # restarting mid validation epoch is not currently supported
    assert bar.val_progress_bar.n == 0
    assert bar.val_progress_bar.total == 3
Example #23
def run_lightning(argv=None):
    '''Run training with PyTorch Lightning'''
    global RANK
    from pytorch_lightning.loggers import WandbLogger
    import numpy as np
    import traceback
    import os
    import pprint

    pformat = pprint.PrettyPrinter(sort_dicts=False, width=100,
                                   indent=2).pformat

    model, args, addl_targs, data_mod = process_args(parse_args(argv=argv))

    # if 'OMPI_COMM_WORLD_RANK' in os.environ or 'SLURMD_NODENAME' in os.environ:
    #     from mpi4py import MPI
    #     comm = MPI.COMM_WORLD
    #     RANK = comm.Get_rank()
    # else:
    #     RANK = 0
    #     print('OMPI_COMM_WORLD_RANK or SLURMD_NODENAME not set in environment -- not using MPI')

    # output is a wrapper function for os.path.join(outdir, <FILE>)
    outdir, output = process_output(args)
    check_directory(outdir)
    if not args.quiet:
        print0(' '.join(sys.argv), file=sys.stderr)
        print0("Processed Args:\n", pformat(vars(args)), file=sys.stderr)

    # save arguments
    with open(output('args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    checkpoint = None
    if args.init is not None:
        checkpoint = args.init
        link_dest = 'init.ckpt'
    elif args.checkpoint is not None:
        checkpoint = args.checkpoint
        link_dest = 'resumed_from.ckpt'

    if checkpoint is not None:
        if RANK == 0:
            print0(f'symlinking to {args.checkpoint} from {outdir}')
            dest = output(link_dest)
            src = os.path.relpath(checkpoint, start=outdir)
            if os.path.exists(dest):
                existing_src = os.readlink(dest)
                if existing_src != src:
                    msg = f'Cannot create symlink to checkpoint -- {dest} already exists, but points to {existing_src}'
                    raise RuntimeError(msg)
            else:
                os.symlink(src, dest)

    seed_everything(args.seed)

    if args.csv:
        logger = CSVLogger(save_dir=output('logs'))
    else:
        logger = WandbLogger(project="deep-taxon",
                             entity='deep-taxon',
                             name=args.experiment)

    # get dataset so we can set model parameters that are
    # dependent on the dataset, such as final number of outputs

    monitor, mode = (AbstractLit.val_loss,
                     'min') if args.manifold else (AbstractLit.val_acc, 'max')
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),
        TQDMProgressBar(refresh_rate=50)
    ]
    if not args.disable_checkpoint:
        callbacks.append(
            ModelCheckpoint(dirpath=outdir,
                            save_weights_only=False,
                            save_last=True,
                            save_top_k=3,
                            mode=mode,
                            monitor=monitor))

    if args.early_stop:
        callbacks.append(
            EarlyStopping(monitor=monitor,
                          min_delta=0.001,
                          patience=10,
                          verbose=False,
                          mode=mode))

    if args.swa:
        callbacks.append(
            StochasticWeightAveraging(swa_epoch_start=args.swa_start,
                                      annealing_epochs=args.swa_anneal))

    targs = dict(
        enable_checkpointing=True,
        callbacks=callbacks,
        logger=logger,
        num_sanity_val_steps=0,
    )
    targs.update(addl_targs)

    if args.debug:
        targs['log_every_n_steps'] = 1
        targs['fast_dev_run'] = 10

    if not args.quiet:
        print0('Trainer args:\n', pformat(targs), file=sys.stderr)
        print0('DataLoader args:\n',
               pformat(data_mod._loader_kwargs),
               file=sys.stderr)
        print0('Model:\n', model, file=sys.stderr)

    trainer = Trainer(**targs)

    if args.debug:
        #print_dataloader(data_mod.test_dataloader())
        print_dataloader(data_mod.train_dataloader())
        print_dataloader(data_mod.val_dataloader())

    s = datetime.now()
    print0('START_TIME', time())
    trainer.fit(model, data_mod)
    e = datetime.now()
    td = e - s
    hours, seconds = divmod(td.seconds, 3600)
    minutes, seconds = divmod(seconds, 60)

    print0("Took %02d:%02d:%02d.%d" %
           (hours, minutes, seconds, td.microseconds),
           file=sys.stderr)
    print0("Total seconds:", td.total_seconds(), file=sys.stderr)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


@pytest.mark.parametrize(
    "kwargs",
    [
        # won't print but is still set
        {"callbacks": TQDMProgressBar(refresh_rate=0)},
        {"callbacks": TQDMProgressBar()},
        {"progress_bar_refresh_rate": 1},
    ],
)
def test_tqdm_progress_bar_on(tmpdir, kwargs):
    """Test different ways the progress bar can be turned on."""
    if "progress_bar_refresh_rate" in kwargs:
        with pytest.deprecated_call(
                match=r"progress_bar_refresh_rate=.*` is deprecated"):
            trainer = Trainer(default_root_dir=tmpdir, **kwargs)
Example #25
def create_lightning_trainer(container: LightningContainer,
                             resume_from_checkpoint: Optional[Path] = None,
                             num_nodes: int = 1,
                             multiple_trainloader_mode: str = "max_size_cycle") -> \
        Tuple[Trainer, StoringLogger]:
    """
    Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
    and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
    return value.
    :param container: The container with model and data.
    :param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
    :param num_nodes: The number of nodes to use in distributed training.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    logging.debug(f"resume_from_checkpoint: {resume_from_checkpoint}")
    num_gpus = container.num_gpus_per_node()
    effective_num_gpus = num_gpus * num_nodes
    strategy = None
    if effective_num_gpus == 0:
        accelerator = "cpu"
        devices = 1
        message = "CPU"
    else:
        accelerator = "gpu"
        devices = num_gpus
        message = f"{devices} GPU"
        if effective_num_gpus > 1:
            # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of
            # GPU memory).
            # Initialize the DDP plugin. The default for pl_find_unused_parameters is False. If True, the plugin
            # prints out lengthy warnings about the performance impact of find_unused_parameters.
            strategy = DDPPlugin(find_unused_parameters=container.pl_find_unused_parameters)
            message += "s per node with DDP"
    logging.info(f"Using {message}")
    tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
    loggers = [tensorboard_logger, AzureMLLogger(False)]
    storing_logger = StoringLogger()
    loggers.append(storing_logger)
    # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
    # https://pytorch.org/docs/stable/notes/randomness.html
    # Note that switching to deterministic models can have large performance downside.
    if container.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True

    # The last checkpoint is considered the "best" checkpoint. For large segmentation
    # models, this still appears to be the best way of choosing them because validation loss on the relatively small
    # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
    # not for the HeadAndNeck model.
    # Note that "last" is somehow a misnomer, it should rather be "latest". There is a "last" checkpoint written in
    # every epoch. We could use that for recovery too, but it could happen that the job gets preempted right during
    # writing that file, and we would end up with an invalid file.
    last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                               save_last=True,
                                               save_top_k=0)
    recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                                   filename=AUTOSAVE_CHECKPOINT_FILE_NAME,
                                                   every_n_val_epochs=container.autosave_every_n_val_epochs,
                                                   save_last=False)
    callbacks: List[Callback] = [
        last_checkpoint_callback,
        recovery_checkpoint_callback,
    ]
    if container.monitor_loading:
        # TODO antonsc: Remove after fixing the callback.
        raise NotImplementedError("Monitoring batch loading times has been temporarily disabled.")
        # callbacks.append(BatchTimeCallback())
    if num_gpus > 0 and container.monitor_gpu:
        logging.info("Adding monitoring for GPU utilization")
        callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
    # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
    additional_args = container.get_trainer_arguments()
    # Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new get_callbacks method
    if "callbacks" in additional_args:
        more_callbacks = additional_args.pop("callbacks")
        if isinstance(more_callbacks, list):
            callbacks.extend(more_callbacks)  # type: ignore
        else:
            callbacks.append(more_callbacks)  # type: ignore
    callbacks.extend(container.get_callbacks())
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
    if progress_bar_refresh_rate is None:
        progress_bar_refresh_rate = 50
        logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
                     f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
    if is_azureml_run:
        callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
                                            write_to_logging_info=True,
                                            print_timestamp=False))
    else:
        callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
    # Read out additional model-specific args here.
    # We probably want to keep essential ones like numgpu and logging.
    trainer = Trainer(default_root_dir=str(container.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      strategy=strategy,
                      max_epochs=container.num_epochs,
                      # Both these arguments can be integers or floats. If integers, it is the number of batches.
                      # If float, it's the fraction of batches. We default to 1.0 (processing all batches).
                      limit_train_batches=container.pl_limit_train_batches or 1.0,
                      limit_val_batches=container.pl_limit_val_batches or 1.0,
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
                      callbacks=callbacks,
                      logger=loggers,
                      num_nodes=num_nodes,
                      devices=devices,
                      precision=precision,
                      sync_batchnorm=True,
                      detect_anomaly=container.detect_anomaly,
                      profiler=container.pl_profiler,
                      resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
                      multiple_trainloader_mode=multiple_trainloader_mode,
                      **additional_args)
    return trainer, storing_logger
Example #26
    @total.setter
    def total(self, value):
        self.__total = value
        self.total_values.append(value)

    def set_description(self, *args, **kwargs):
        super().set_description(*args, **kwargs)
        self.descriptions.append(self.desc)


@pytest.mark.parametrize(
    "pbar",
    [
        # won't print but is still set
        TQDMProgressBar(refresh_rate=0),
        TQDMProgressBar(),
    ],
)
def test_tqdm_progress_bar_on(tmpdir, pbar):
    """Test different ways the progress bar can be turned on."""
    trainer = Trainer(default_root_dir=tmpdir, callbacks=pbar)

    progress_bars = [
        c for c in trainer.callbacks if isinstance(c, ProgressBarBase)
    ]
    assert len(progress_bars) == 1
    assert progress_bars[0] is trainer.progress_bar_callback


def test_tqdm_progress_bar_off(tmpdir):
Example #27
File: train.py Project: RangiLyu/nanodet
def main(args):
    load_config(cfg, args.config)
    if cfg.model.arch.head.num_classes != len(cfg.class_names):
        raise ValueError(
            "cfg.model.arch.head.num_classes must equal len(cfg.class_names), "
            "but got {} and {}".format(cfg.model.arch.head.num_classes,
                                       len(cfg.class_names)))
    local_rank = int(args.local_rank)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    mkdir(local_rank, cfg.save_dir)

    logger = NanoDetLightningLogger(cfg.save_dir)
    logger.dump_cfg(cfg)

    if args.seed is not None:
        logger.info("Set random seed to {}".format(args.seed))
        pl.seed_everything(args.seed)

    logger.info("Setting up data...")
    train_dataset = build_dataset(cfg.data.train, "train")
    val_dataset = build_dataset(cfg.data.val, "test")

    evaluator = build_evaluator(cfg.evaluator, val_dataset)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.device.batchsize_per_gpu,
        shuffle=True,
        num_workers=cfg.device.workers_per_gpu,
        pin_memory=True,
        collate_fn=naive_collate,
        drop_last=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.device.batchsize_per_gpu,
        shuffle=False,
        num_workers=cfg.device.workers_per_gpu,
        pin_memory=True,
        collate_fn=naive_collate,
        drop_last=False,
    )

    logger.info("Creating model...")
    task = TrainingTask(cfg, evaluator)

    if "load_model" in cfg.schedule:
        ckpt = torch.load(cfg.schedule.load_model)
        if "pytorch-lightning_version" not in ckpt:
            warnings.warn(
                "Warning! Old .pth checkpoint is deprecated. "
                "Convert the checkpoint with tools/convert_old_checkpoint.py ")
            ckpt = convert_old_model(ckpt)
        load_model_weight(task.model, ckpt, logger)
        logger.info("Loaded model weight from {}".format(
            cfg.schedule.load_model))

    model_resume_path = (os.path.join(cfg.save_dir, "model_last.ckpt")
                         if "resume" in cfg.schedule else None)
    if cfg.device.gpu_ids == -1:
        logger.info("Using CPU training")
        accelerator, devices = "cpu", None
    else:
        accelerator, devices = "gpu", cfg.device.gpu_ids

    trainer = pl.Trainer(
        default_root_dir=cfg.save_dir,
        max_epochs=cfg.schedule.total_epochs,
        check_val_every_n_epoch=cfg.schedule.val_intervals,
        accelerator=accelerator,
        devices=devices,
        log_every_n_steps=cfg.log.interval,
        num_sanity_val_steps=0,
        resume_from_checkpoint=model_resume_path,
        callbacks=[TQDMProgressBar(refresh_rate=0)],  # disable tqdm bar
        logger=logger,
        benchmark=cfg.get("cudnn_benchmark", True),
        gradient_clip_val=cfg.get("grad_clip", 0.0),
    )

    trainer.fit(task, train_dataloader, val_dataloader)
Example #28
                                           pin_memory=True)

    def test_dataloader(self):
        return self.val_dataloader()


if __name__ == "__main__":
    LightningCLI(
        ImageNetLightningModel,
        trainer_defaults={
            "max_epochs": 90,
            "accelerator": "auto",
            "devices": 1,
            "logger": False,
            "benchmark": True,
            "callbacks": [
                # the PyTorch example refreshes every 10 batches
                TQDMProgressBar(refresh_rate=10),
                # save when the validation top1 accuracy improves
                ModelCheckpoint(monitor="val_acc1", mode="max"),
            ],
        },
        seed_everything_default=42,
        save_config_overwrite=True,
    )