예제 #1
0
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
    """Check that LightningModule prints use the built-in ``print`` once the progress bar is disabled."""
    model = PrintModel()
    progress_bar = ProgressBar()
    # One batch per stage and a single step keep the run as short as possible.
    trainer_options = dict(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_steps=1,
        callbacks=[progress_bar],
    )
    trainer = Trainer(**trainer_options)
    progress_bar.disable()

    trainer.fit(model)
    trainer.test(model, verbose=False)
    trainer.predict(model)

    expected_prints = [
        call("training_step", end=""),
        call("validation_step", file=ANY),
        call("test_step"),
        call("predict_step"),
    ]
    mock_print.assert_has_calls(expected_prints)
    # With the bar disabled nothing may be routed through tqdm.
    tqdm_write.assert_not_called()
예제 #2
0
def test_progress_bar_misconfiguration():
    """Check that constructing a Trainer with two progress bar callbacks is rejected."""
    first_bar = ProgressBar()
    second_bar = ProgressBar()
    checkpoint = ModelCheckpoint(dirpath='../trainer')
    expected_message = r'^You added multiple progress bar callbacks'

    with pytest.raises(MisconfigurationException, match=expected_message):
        Trainer(callbacks=[first_bar, second_bar, checkpoint])
    def configure_progress_bar(self, refresh_rate=1, process_position=0):
        """Resolve which progress bar callback the trainer should use.

        Args:
            refresh_rate: Refresh rate given to a newly created ``ProgressBar``;
                a value that is not greater than 0 means no bar is created.
            process_position: Forwarded to a newly created ``ProgressBar``.

        Returns:
            The single ``ProgressBarBase`` callback already present in
            ``self.trainer.callbacks``; otherwise a new ``ProgressBar`` (also
            appended to ``self.trainer.callbacks``) when ``refresh_rate > 0``;
            otherwise ``None``.

        Raises:
            MisconfigurationException: If more than one progress bar callback is present.
        """
        # Low refresh rates are reported to crash Google Colab, so warn the user.
        running_on_colab = os.getenv('COLAB_GPU')
        if running_on_colab and refresh_rate < 20:
            rank_zero_warn(
                "You have set progress_bar_refresh_rate < 20 on Google Colab. This"
                " may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.",
                UserWarning)

        existing_bars = []
        for callback in self.trainer.callbacks:
            if isinstance(callback, ProgressBarBase):
                existing_bars.append(callback)

        if len(existing_bars) > 1:
            raise MisconfigurationException(
                'You added multiple progress bar callbacks to the Trainer, but currently only one'
                ' progress bar is supported.')

        # A user-supplied bar always wins over the refresh_rate setting.
        if len(existing_bars) == 1:
            return existing_bars[0]

        if refresh_rate > 0:
            default_bar = ProgressBar(
                refresh_rate=refresh_rate,
                process_position=process_position,
            )
            self.trainer.callbacks.append(default_bar)
            return default_bar

        return None
예제 #4
0
def test_checkpoint_callbacks_are_last(tmpdir):
    """Check that checkpoint callbacks end up at the tail of the callback list in their original order."""
    ckpt_a = ModelCheckpoint(tmpdir)
    ckpt_b = ModelCheckpoint(tmpdir)
    stopper = EarlyStopping()
    lr_cb = LearningRateMonitor()
    bar = ProgressBar()

    # Case 1: the model contributes no callbacks of its own.
    module = Mock()
    module.configure_callbacks.return_value = []
    trainer = Trainer(callbacks=[ckpt_a, bar, lr_cb, ckpt_b])
    trainer.model = module
    connector = CallbackConnector(trainer)
    connector._attach_model_callbacks()
    expected_order = [bar, lr_cb, ckpt_a, ckpt_b]
    assert trainer.callbacks == expected_order

    # Case 2: the model supplies callbacks that replace same-typed ones given to the Trainer.
    module = Mock()
    module.configure_callbacks.return_value = [ckpt_a, stopper, ckpt_b]
    trainer = Trainer(callbacks=[bar, lr_cb, ModelCheckpoint(tmpdir)])
    trainer.model = module
    connector = CallbackConnector(trainer)
    connector._attach_model_callbacks()
    expected_order = [bar, lr_cb, stopper, ckpt_a, ckpt_b]
    assert trainer.callbacks == expected_order
예제 #5
0
    def configure_progress_bar(self, refresh_rate=None, process_position=0):
        """Resolve which progress bar callback the trainer should use.

        Args:
            refresh_rate: Refresh rate for a newly created ``ProgressBar``. When
                ``None`` it becomes 20 if the ``COLAB_GPU`` environment variable
                is set and 1 otherwise. A value not greater than 0 means no bar
                is created.
            process_position: Forwarded to a newly created ``ProgressBar``.

        Returns:
            The single ``ProgressBarBase`` callback already present in
            ``self.trainer.callbacks``; otherwise a new ``ProgressBar`` (also
            appended to ``self.trainer.callbacks``) when the resolved rate is
            positive; otherwise ``None``.

        Raises:
            MisconfigurationException: If more than one progress bar callback is present.
        """
        if refresh_rate is None:
            # A small refresh rate on Colab causes crashes, so default to a higher value there.
            refresh_rate = 20 if os.getenv('COLAB_GPU') else 1

        existing_bars = []
        for callback in self.trainer.callbacks:
            if isinstance(callback, ProgressBarBase):
                existing_bars.append(callback)

        if len(existing_bars) > 1:
            raise MisconfigurationException(
                'You added multiple progress bar callbacks to the Trainer, but currently only one'
                ' progress bar is supported.')

        # A user-supplied bar always wins over the refresh_rate setting.
        if len(existing_bars) == 1:
            return existing_bars[0]

        if refresh_rate > 0:
            default_bar = ProgressBar(
                refresh_rate=refresh_rate,
                process_position=process_position,
            )
            self.trainer.callbacks.append(default_bar)
            return default_bar

        return None
예제 #6
0
def get_trainer(args):
    """Build a ``pl.Trainer`` from parsed arguments with TensorBoard logging and standard callbacks.

    Reads ``args.seed``, ``args.default_root_dir`` and
    ``args.progress_bar_refresh_rate``; the remaining settings are taken from
    ``args`` by ``Trainer.from_argparse_args``.

    Returns:
        The configured trainer.
    """
    pl.seed_everything(args.seed)

    # Logging: TensorBoard output goes to a "tb" folder under the resolved root directory.
    output_root = Path(args.default_root_dir).expanduser().resolve()
    output_root.mkdir(parents=True, exist_ok=True)
    tensorboard_dir = output_root / "tb"
    loggers = [TensorBoardLogger(save_dir=tensorboard_dir)]
    logger.info(f"Run tensorboard --logdir {tensorboard_dir}")

    # Callbacks: LR monitor, progress bar, checkpointing and GPU statistics, in that list order.
    checkpoint_cb = ModelCheckpoint(verbose=True)
    lr_monitor_cb = LearningRateMonitor(logging_interval="step")
    progress_cb = ProgressBar(refresh_rate=args.progress_bar_refresh_rate)
    gpu_stats_cb = GPUStatsMonitor()
    callbacks = [lr_monitor_cb, progress_cb, checkpoint_cb, gpu_stats_cb]

    return pl.Trainer.from_argparse_args(
        args,
        logger=loggers,
        callbacks=callbacks,
        plugins=[],
    )
예제 #7
0
def train_ecg(config,
              data_dir=None,
              num_epochs=10,
              normalised=True,
              num_gpus=1):
    """Train an ``MLECG`` model on data served by ``ECGDataModule``.

    Args:
        config: Hyperparameter mapping; passed to ``MLECG`` and read for
            ``"batch_size"``.
        data_dir: Data location forwarded to ``ECGDataModule``.
        num_epochs: Maximum number of training epochs.
        normalised: Forwarded to the data module; when truthy the model is
            converted with ``.float()``.
        num_gpus: Number of GPUs; fractional values are rounded up.
    """
    model = MLECG(config)
    if normalised:
        model = model.float()
    dm = ECGDataModule(data_dir=data_dir,
                       num_workers=8,
                       batch_size=config["batch_size"],
                       normalised=normalised)
    bar = ProgressBar()
    early_stopping = EarlyStopping(patience=10, monitor='ptl/val_loss')
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        progress_bar_refresh_rate=0,
        callbacks=[early_stopping, bar],
    )
    trainer.fit(model, dm)
def test_attach_model_callbacks():
    """Check how callbacks given to the Trainer and callbacks returned by the model are combined."""

    def check_merge(trainer_callbacks, model_callbacks, expected):
        # Build a fresh trainer/model pair, attach the model callbacks and compare the result.
        module = Mock()
        module.configure_callbacks.return_value = model_callbacks
        trainer = Trainer(
            checkpoint_callback=False,
            progress_bar_refresh_rate=0,
            callbacks=trainer_callbacks,
        )
        connector = CallbackConnector(trainer)
        connector._attach_model_callbacks(module, trainer)
        assert trainer.callbacks == expected

    early_stopping = EarlyStopping()
    progress_bar = ProgressBar()
    lr_monitor = LearningRateMonitor()
    grad_accumulation = GradientAccumulationScheduler({1: 1})

    # Nothing on either side.
    check_merge([], [], [])

    # Different callback types on each side.
    check_merge(
        [early_stopping],
        [progress_bar],
        [early_stopping, progress_bar],
    )

    # Same type on both sides, but distinct instances.
    check_merge(
        [progress_bar, EarlyStopping()],
        [early_stopping],
        [progress_bar, early_stopping],
    )

    # Several callbacks of one type passed to the trainer.
    check_merge(
        [
            LearningRateMonitor(),
            EarlyStopping(),
            LearningRateMonitor(),
            EarlyStopping(),
        ],
        [early_stopping, lr_monitor],
        [early_stopping, lr_monitor],
    )

    # Several callbacks of one type on both the trainer and the model side.
    check_merge(
        [
            LearningRateMonitor(),
            progress_bar,
            EarlyStopping(),
            LearningRateMonitor(),
            EarlyStopping(),
        ],
        [early_stopping, lr_monitor, grad_accumulation, early_stopping],
        [
            progress_bar,
            early_stopping,
            lr_monitor,
            grad_accumulation,
            early_stopping,
        ],
    )
예제 #9
0
def test_progress_bar_can_be_pickled():
    """Check that the progress bar can be pickled before any run and after each of fit, test and predict."""
    progress_bar = ProgressBar()
    trainer = Trainer(fast_dev_run=True, callbacks=[progress_bar], max_steps=1)
    boring_model = BoringModel()

    pickle.dumps(progress_bar)
    # Re-pickle after every stage, in the order fit -> test -> predict.
    for run_stage in (trainer.fit, trainer.test, trainer.predict):
        run_stage(boring_model)
        pickle.dumps(progress_bar)
예제 #10
0
파일: train.py 프로젝트: ycdhqzhiai/nanodet
def main(args):
    """Set up data, model and trainer from the config file and run training.

    Args:
        args: Parsed command-line arguments; ``config``, ``local_rank`` and
            ``seed`` are read here.

    Raises:
        ValueError: If the head's ``num_classes`` does not match the number of
            configured class names.
    """
    load_config(cfg, args.config)
    if cfg.model.arch.head.num_classes != len(cfg.class_names):
        raise ValueError(
            'cfg.model.arch.head.num_classes must equal len(cfg.class_names),but got {} and {}'.format(
                cfg.model.arch.head.num_classes, len(cfg.class_names)))
    local_rank = int(args.local_rank)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    mkdir(local_rank, cfg.save_dir)
    logger = Logger(local_rank, cfg.save_dir)

    if args.seed is not None:
        logger.log('Set random seed to {}'.format(args.seed))
        pl.seed_everything(args.seed)

    logger.log('Setting up data...')
    train_dataset = build_dataset(cfg.data.train, 'train')
    val_dataset = build_dataset(cfg.data.val, 'test')

    evaluator = build_evaluator(cfg, val_dataset)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.device.batchsize_per_gpu,
                                                   shuffle=True, num_workers=cfg.device.workers_per_gpu,
                                                   pin_memory=True, collate_fn=collate_function, drop_last=True)
    # TODO: batch eval
    # drop_last stays False for validation: every sample must be evaluated. It is a
    # no-op at batch_size=1, but would silently discard samples once batching is added.
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                                 num_workers=cfg.device.workers_per_gpu,
                                                 pin_memory=True, collate_fn=collate_function, drop_last=False)

    logger.log('Creating model...')
    task = TrainingTask(cfg, evaluator)

    if 'load_model' in cfg.schedule:
        ckpt = torch.load(cfg.schedule.load_model)
        # Checkpoints lacking the Lightning version key are treated as the old .pth format.
        if 'pytorch-lightning_version' not in ckpt:
            warnings.warn('Warning! Old .pth checkpoint is deprecated. '
                          'Convert the checkpoint with tools/convert_old_checkpoint.py ')
            ckpt = convert_old_model(ckpt)
        task.load_state_dict(ckpt['state_dict'], strict=False)

    # Resume from the last checkpoint in save_dir only when the schedule has a 'resume' entry.
    model_resume_path = os.path.join(cfg.save_dir, 'model_last.ckpt') if 'resume' in cfg.schedule else None

    trainer = pl.Trainer(default_root_dir=cfg.save_dir,
                         max_epochs=cfg.schedule.total_epochs,
                         gpus=cfg.device.gpu_ids,
                         check_val_every_n_epoch=cfg.schedule.val_intervals,
                         accelerator='ddp',
                         log_every_n_steps=cfg.log.interval,
                         num_sanity_val_steps=0,
                         resume_from_checkpoint=model_resume_path,
                         callbacks=[ProgressBar(refresh_rate=0)]  # disable tqdm bar
                         )

    trainer.fit(task, train_dataloader, val_dataloader)
def test_checkpoint_callbacks_are_last(tmpdir):
    """Check that the Trainer moves checkpoint callbacks to the end of the list, keeping their order."""
    ckpt_a = ModelCheckpoint(tmpdir)
    ckpt_b = ModelCheckpoint(tmpdir)
    lr_cb = LearningRateMonitor()
    bar = ProgressBar()

    module = Mock()
    module.configure_callbacks.return_value = []

    given_order = [ckpt_a, bar, lr_cb, ckpt_b]
    expected_order = [bar, lr_cb, ckpt_a, ckpt_b]
    trainer = Trainer(callbacks=given_order)
    assert trainer.callbacks == expected_order
예제 #12
0
파일: train.py 프로젝트: gbotev1/cgmfpim
def main(args: Namespace) -> None:
    """Tune and then fit a ``GPT2`` model on ``MemesDataModule``.

    Args:
        args: Parsed arguments; ``seed_everything`` and ``max_epochs`` are read
            here, and the namespace is also passed to the data module, the
            model and ``Trainer.from_argparse_args``.
    """
    if args.seed_everything:
        # Fixed seed for reproducibility.
        seed_everything(0)

    datamodule = MemesDataModule(args)
    model = GPT2(args=args, tokenizer=datamodule.tokenizer)

    progress_bar = ProgressBar()
    # Keep up to ``max_epochs`` weight-only checkpoints ranked by train_loss
    # (intended to save a checkpoint after every epoch).
    checkpoint = ModelCheckpoint(
        monitor='train_loss',
        save_top_k=args.max_epochs,
        save_weights_only=True,
    )
    trainer = Trainer.from_argparse_args(args, callbacks=[progress_bar, checkpoint])

    trainer.tune(model, datamodule=datamodule)
    trainer.fit(model, datamodule)
def test_attach_model_callbacks_override_info(caplog):
    """Check that the log mentions the Trainer callbacks overridden by ``configure_callbacks`` results."""
    module = Mock()
    module.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()]

    trainer_callbacks = [EarlyStopping(), LearningRateMonitor(), ProgressBar()]
    trainer = Trainer(checkpoint_callback=False, callbacks=trainer_callbacks)
    connector = CallbackConnector(trainer)

    # Capture INFO-level records emitted while the model callbacks are attached.
    with caplog.at_level(logging.INFO):
        connector._attach_model_callbacks(module, trainer)

    expected_fragment = "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor"
    assert expected_fragment in caplog.text
예제 #14
0
 def configure_progress_bar(self):
     """Set ``self.progress_bar_callback`` from the configured callbacks.

     Uses the single ``ProgressBarBase`` callback found in ``self.callbacks``;
     otherwise creates a ``ProgressBar`` (and appends it to ``self.callbacks``)
     when ``self.progress_bar_refresh_rate`` is positive; otherwise stores ``None``.

     Raises:
         MisconfigurationException: If more than one progress bar callback is present.
     """
     existing_bars = []
     for callback in self.callbacks:
         if isinstance(callback, ProgressBarBase):
             existing_bars.append(callback)

     if len(existing_bars) > 1:
         raise MisconfigurationException(
             'You added multiple progress bar callbacks to the Trainer, but currently only one'
             ' progress bar is supported.')

     # A user-supplied bar takes precedence over the refresh-rate setting.
     if len(existing_bars) == 1:
         self.progress_bar_callback = existing_bars[0]
         return

     if self.progress_bar_refresh_rate > 0:
         self.progress_bar_callback = ProgressBar(
             refresh_rate=self.progress_bar_refresh_rate,
             process_position=self.process_position,
         )
         self.callbacks.append(self.progress_bar_callback)
         return

     self.progress_bar_callback = None
예제 #15
0
def test_progress_bar_print(tqdm_write, tmpdir):
    """Check that printing inside the LightningModule is forwarded, with its arguments, to the progress bar."""
    model = PrintModel()
    progress_bar = ProgressBar()
    # One batch per stage and a single step keep the run as short as possible.
    trainer_options = dict(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        max_steps=1,
        callbacks=[progress_bar],
    )
    trainer = Trainer(**trainer_options)

    trainer.fit(model)
    trainer.test(model)

    expected_writes = [
        call("training_step", end="", file=None, nolock=False),
        call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
        call("test_step", end=os.linesep, file=None, nolock=False),
    ]
    assert tqdm_write.call_count == 3
    assert tqdm_write.call_args_list == expected_writes
    # Build the model via load_model, requesting pretrain=True.
    model = load_model(model_name,
                       loss,
                       LEARNING_RATE,
                       BATCH_SIZE,
                       NUM_WORKERS,
                       pretrain=True)

    # Report the value returned by count_parameters, formatted with thousands separators.
    print('The model has {:,} trainable parameters'.format(
        count_parameters(model)))
    """
    ------------------------------------
    Intialize the trainier from the model and the callbacks
    ------------------------------------
    """

    # Progress bar callback with default settings.
    bar = ProgressBar()
    # Early stopping is currently disabled:
    # early_stopping = EarlyStopping('val_acc_en', patience=15)

    # Two validation metrics are tracked, so a separate checkpoint saver is used for each.
    # Both write into save_dir and also keep the last checkpoint (save_last=True).
    # Saver monitoring 'val_acc_en'.
    ckpt_en = ModelCheckpoint(dirpath=save_dir,
                              monitor='val_acc_en',
                              mode='auto',
                              save_last=True,
                              filename='{epoch:02d}-{val_acc_en:.4f}')

    # Saver monitoring 'val_acc'.
    ckpt_reg = ModelCheckpoint(dirpath=save_dir,
                               monitor='val_acc',
                               mode='auto',
                               save_last=True,
                               filename='{epoch:02d}-{val_acc:.4f}')
예제 #17
0
def test_progress_bar_main_bar_resume():
    """Check that the progress bar restores its counters from the state held by the Trainer."""
    progress_bar = ProgressBar()
    module = Mock()

    # Mocked trainer state describing a run resumed partway through a training epoch.
    trainer = Mock(
        sanity_checking=False,
        check_val_every_n_epoch=1,
        current_epoch=1,
        num_training_batches=5,
        val_check_batch=5,
        num_val_batches=[3],
    )
    trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3

    progress_bar.on_init_end(trainer)
    progress_bar.on_train_start(trainer, module)
    progress_bar.on_train_epoch_start(trainer, module)

    main_bar = progress_bar.main_progress_bar
    # Position matches the 3 completed batches; 8 is presumably 5 training + 3 validation batches.
    assert main_bar.n == 3
    assert main_bar.total == 8

    # progress_bar.on_train_epoch_end(trainer, module)
    progress_bar.on_validation_start(trainer, module)
    progress_bar.on_validation_epoch_start(trainer, module)

    val_bar = progress_bar.val_progress_bar
    # Restarting in the middle of a validation epoch is not currently supported.
    assert val_bar.n == 0
    assert val_bar.total == 3
예제 #18
0
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate


@pytest.mark.parametrize('callbacks,refresh_rate', [
    ([], 1),
    ([], 2),
    ([ProgressBar(refresh_rate=1)], 0),
    ([ProgressBar(refresh_rate=2)], 0),
    ([ProgressBar(refresh_rate=2)], 1),
])
def test_progress_bar_on(callbacks, refresh_rate):
    """Check the different configurations that leave the Trainer with an active progress bar."""
    trainer_options = dict(
        callbacks=callbacks,
        progress_bar_refresh_rate=refresh_rate,
        max_epochs=1,
        overfit_pct=0.2,
    )
    trainer = Trainer(**trainer_options)

    # The Trainer supports only one progress bar callback at the moment.
    bar_count = sum(1 for cb in trainer.callbacks if isinstance(cb, ProgressBarBase))
    assert bar_count == 1
예제 #19
0
def main(args):
    """Set up data, model, logger and trainer from the config file and run training.

    Args:
        args: Parsed command-line arguments; ``config``, ``local_rank`` and
            ``seed`` are read here.

    Raises:
        ValueError: If the head's ``num_classes`` does not match the number of
            configured class names.
    """
    load_config(cfg, args.config)
    head_class_count = cfg.model.arch.head.num_classes
    class_name_count = len(cfg.class_names)
    if head_class_count != class_name_count:
        raise ValueError(
            "cfg.model.arch.head.num_classes must equal len(cfg.class_names), "
            f"but got {head_class_count} and {class_name_count}")

    local_rank = int(args.local_rank)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    mkdir(local_rank, cfg.save_dir)

    logger = NanoDetLightningLogger(cfg.save_dir)
    logger.dump_cfg(cfg)

    if args.seed is not None:
        logger.info(f"Set random seed to {args.seed}")
        pl.seed_everything(args.seed)

    logger.info("Setting up data...")
    train_dataset = build_dataset(cfg.data.train, "train")
    val_dataset = build_dataset(cfg.data.val, "test")

    evaluator = build_evaluator(cfg.evaluator, val_dataset)

    # Options shared by the training and validation loaders.
    shared_loader_options = dict(
        batch_size=cfg.device.batchsize_per_gpu,
        num_workers=cfg.device.workers_per_gpu,
        pin_memory=True,
        collate_fn=naive_collate,
    )
    # Training shuffles and drops the incomplete last batch; validation does neither.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, drop_last=True, **shared_loader_options)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, shuffle=False, drop_last=False, **shared_loader_options)

    logger.info("Creating model...")
    task = TrainingTask(cfg, evaluator)

    if "load_model" in cfg.schedule:
        weights_path = cfg.schedule.load_model
        ckpt = torch.load(weights_path)
        # Checkpoints lacking the Lightning version key are treated as the old .pth format.
        if "pytorch-lightning_version" not in ckpt:
            warnings.warn(
                "Warning! Old .pth checkpoint is deprecated. "
                "Convert the checkpoint with tools/convert_old_checkpoint.py ")
            ckpt = convert_old_model(ckpt)
        load_model_weight(task.model, ckpt, logger)
        logger.info(f"Loaded model weight from {cfg.schedule.load_model}")

    # Resume from the last checkpoint in save_dir only when the schedule has a "resume" entry.
    if "resume" in cfg.schedule:
        model_resume_path = os.path.join(cfg.save_dir, "model_last.ckpt")
    else:
        model_resume_path = None

    # DDP is requested only when more than one GPU id is configured.
    accelerator = "ddp" if len(cfg.device.gpu_ids) > 1 else None

    trainer_options = dict(
        default_root_dir=cfg.save_dir,
        max_epochs=cfg.schedule.total_epochs,
        gpus=cfg.device.gpu_ids,
        check_val_every_n_epoch=cfg.schedule.val_intervals,
        accelerator=accelerator,
        log_every_n_steps=cfg.log.interval,
        num_sanity_val_steps=0,
        resume_from_checkpoint=model_resume_path,
        callbacks=[ProgressBar(refresh_rate=0)],  # disable tqdm bar
        logger=logger,
        benchmark=True,
        gradient_clip_val=cfg.get("grad_clip", 0.0),
    )
    trainer = pl.Trainer(**trainer_options)

    trainer.fit(task, train_dataloader, val_dataloader)