Example #1
from pytorch_lightning import Trainer

# GATADoubleDQN and RLEarlyStopping are project-specific classes; their
# imports are not shown in the source snippet.
def test_rl_early_stopping():
    gata_double_dqn = GATADoubleDQN()
    trainer = Trainer()
    es = RLEarlyStopping("val_monitor", "train_monitor", 0.95, patience=3)

    # if both val and train scores are below the threshold 0.95, don't stop
    trainer.callback_metrics = {"val_monitor": 0.1, "train_monitor": 0.1}
    es._run_early_stopping_check(trainer, gata_double_dqn)
    assert not trainer.should_stop

    # if val score is 1.0 and train score is at or above the threshold, stop
    trainer.callback_metrics = {"val_monitor": 1.0, "train_monitor": 0.95}
    trainer.current_epoch = 1
    es._run_early_stopping_check(trainer, gata_double_dqn)
    assert trainer.should_stop
    assert es.stopped_epoch == 1

    # if train score stays at or above the threshold for `patience`
    # consecutive checks, but val score never reaches 1.0, stop
    trainer.should_stop = False
    es.wait_count = 0
    es.stopped_epoch = 0
    for i in range(3):
        trainer.current_epoch = i
        trainer.callback_metrics = {"val_monitor": 0.9, "train_monitor": 0.95}
        es._run_early_stopping_check(trainer, gata_double_dqn)
        if i == 2:
            assert trainer.should_stop
            assert es.stopped_epoch == 2
        else:
            assert not trainer.should_stop
            assert es.stopped_epoch == 0
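
The stopping rule itself is not shown above. `_run_early_stopping_check` is the hook name used by PyTorch Lightning's own `EarlyStopping` callback, so `RLEarlyStopping` most likely subclasses it. A minimal sketch consistent with the assertions in the test (the attribute names and the exact rule are assumptions, not the project's actual code):

class RLEarlyStopping:
    """Hypothetical reconstruction of the stopping rule the test implies."""

    def __init__(self, val_monitor, train_monitor, threshold, patience=3):
        self.val_monitor = val_monitor
        self.train_monitor = train_monitor
        self.threshold = threshold
        self.patience = patience
        self.wait_count = 0
        self.stopped_epoch = 0

    def _run_early_stopping_check(self, trainer, pl_module):
        val_score = trainer.callback_metrics[self.val_monitor]
        train_score = trainer.callback_metrics[self.train_monitor]

        # Perfect val score once train has reached the threshold: stop at once.
        if val_score == 1.0 and train_score >= self.threshold:
            trainer.should_stop = True
            self.stopped_epoch = trainer.current_epoch
            return

        # Train score holding the threshold for `patience` consecutive checks
        # without val ever reaching 1.0: stop as well.
        if train_score >= self.threshold:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                trainer.should_stop = True
                self.stopped_epoch = trainer.current_epoch
        else:
            self.wait_count = 0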
Example #2
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath, *args):
        open(filepath, 'a').close()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    checkpoint_callback = ModelCheckpoint(tmpdir,
                                          save_top_k=save_top_k,
                                          save_last=save_last,
                                          prefix=file_prefix,
                                          verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), \
        "Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k)

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
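
The fixture arguments (`save_top_k`, `save_last`, `file_prefix`, `expected_files`) imply a `pytest.mark.parametrize` decorator that the snippet omits. A plausible, purely hypothetical parametrization, with filenames following the `{prefix}_ckpt_epoch_{n}.ckpt` scheme visible in Example #6:

import pytest

@pytest.mark.parametrize(
    "save_top_k, save_last, file_prefix, expected_files",
    [
        (-1, False, "", {f"_ckpt_epoch_{i}.ckpt" for i in range(5)}),  # keep all
        (0, False, "", set()),                                         # keep none
        (1, False, "test_prefix", {"test_prefix_ckpt_epoch_4.ckpt"}),  # best only
    ],
)
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    ...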
Example #3
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix,
                                  expected_files):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath, *args):
        open(filepath, 'a').close()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    checkpoint_callback = ModelCheckpoint(tmpdir,
                                          monitor='checkpoint_on',
                                          save_top_k=save_top_k,
                                          save_last=save_last,
                                          prefix=file_prefix,
                                          verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.logger_connector.callback_metrics = {
            'checkpoint_on': torch.tensor(loss)
        }
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), (
        f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
    )

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
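
Relative to Example #2, this variant targets a later PyTorch Lightning API: the monitored key is passed explicitly (`monitor='checkpoint_on'`), metrics live on `trainer.logger_connector.callback_metrics`, and each value is wrapped in `torch.tensor` because the callback handles metrics as tensors. A small illustration of why the wrapping matters (hypothetical, not from the source):

import torch

loss = 2.5
metric = torch.tensor(loss)
# Tensor metrics expose the tensor interface the callback's bookkeeping uses;
# a bare Python float has no .item() and cannot be treated uniformly.
assert metric.item() == loss
assert bool(metric < torch.tensor(2.8))  # comparable against other tensor metrics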
Example #4
import math
import os

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

# MMETrainingModule and NUM_CLS are project-specific; their imports are not
# shown in the source snippet.
def trainWithTune(config,
                  checkpoint_dir=None,
                  datamodule=None,
                  num_epochs=10,
                  num_gpus=0):
    trainer = Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                 name="",
                                 version="."),
        progress_bar_refresh_rate=0,
        callbacks=[
            TuneReportCheckpointCallback(metrics={
                "loss": "val_loss",
                "mean_accuracy": "val_acc",
                "mean_iou": "val_iou",
            },
                                         filename="checkpoint",
                                         on="validation_end")
        ])

    if checkpoint_dir:
        # Currently, this leads to errors:
        # model = LightningMNISTClassifier.load_from_checkpoint(
        #     os.path.join(checkpoint, "checkpoint"))
        # Workaround:
        ckpt = pl_load(os.path.join(checkpoint_dir, "checkpoint"),
                       map_location=lambda storage, loc: storage)
        model = MMETrainingModule._load_model_state(
            ckpt,
            lr=10**config['log_lr'],
            lrRatio=10**config['log_lrRatio'],
            decay=10**config['log_decay'],
            num_cls=NUM_CLS)
        trainer.current_epoch = ckpt["epoch"]
    else:
        model = MMETrainingModule(lr=10**config['log_lr'],
                                  lrRatio=10**config['log_lrRatio'],
                                  decay=10**config['log_decay'],
                                  num_cls=NUM_CLS)

    trainer.fit(model, datamodule=datamodule)
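
`trainWithTune` is written as a Ray Tune trainable, so it is meant to be launched via `tune.run`. A minimal launch sketch against the Ray 1.x API this snippet targets; the search-space keys match the `10**config[...]` lookups above, but the ranges, resources, and sample count are assumptions:

from ray import tune

config = {
    # log10-scaled hyperparameters, decoded with 10**value inside trainWithTune
    "log_lr": tune.uniform(-5, -1),
    "log_lrRatio": tune.uniform(-3, 0),
    "log_decay": tune.uniform(-6, -2),
}

analysis = tune.run(
    tune.with_parameters(trainWithTune,
                         datamodule=datamodule,  # assumed defined elsewhere
                         num_epochs=10,
                         num_gpus=1),
    resources_per_trial={"cpu": 4, "gpu": 1},
    config=config,
    num_samples=20,
    metric="loss",  # reported by TuneReportCheckpointCallback above
    mode="min",
)
print("Best config:", analysis.best_config)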
Example #5
    trainer = Trainer(  # snippet opens mid-call; earlier arguments not shown
        default_root_dir=args.results,
        resume_from_checkpoint=ckpt_path,
        accelerator="ddp" if args.gpus > 1 else None,
        limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
        limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
        limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
    )

    if args.benchmark:
        if args.exec_mode == "train":
            trainer.fit(model, train_dataloader=data_module.train_dataloader())
        else:
            # warmup
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
            # benchmark run
            trainer.current_epoch = 1
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
    elif args.exec_mode == "train":
        trainer.fit(model, data_module)
        if is_main_process():
            logname = args.logname if args.logname is not None else "train_log.json"
            log(logname, torch.tensor(model.best_mean_dice), results=args.results)
    elif args.exec_mode == "evaluate":
        model.args = args
        trainer.test(model, test_dataloaders=data_module.val_dataloader())
        if is_main_process():
            logname = args.logname if args.logname is not None else "eval_log.json"
            log(logname, model.eval_dice, results=args.results)
    elif args.exec_mode == "predict":
        if args.save_preds:
            ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
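
The logging branches above are guarded by `is_main_process()`, which is defined elsewhere in the project. A common implementation for DDP runs (an assumption, not the project's code) checks the torch.distributed rank:

import torch.distributed as dist

def is_main_process():
    # Only rank 0 writes logs; single-process runs count as the main process.
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0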
Example #6
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# tutils and LightningTestModel are PyTorch Lightning test helpers; their
# imports are not shown in the source snippet.
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath):
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    save_dir = tmp_path / "1"
    save_dir.mkdir()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1  (all)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(losses), \
        "Should save all models when save_top_k=-1"

    # verify correct naming
    for i in range(0, len(losses)):
        assert f"_ckpt_epoch_{i}.ckpt" in file_lists

    save_dir = tmp_path / "2"
    save_dir.mkdir()

    # -----------------
    # CASE K=0 (none)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    save_dir = tmp_path / "3"
    save_dir.mkdir()

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    checkpoint_callback = ModelCheckpoint(save_dir,
                                          save_top_k=1,
                                          verbose=1,
                                          prefix='test_prefix')
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    save_dir = tmp_path / "4"
    save_dir.mkdir()

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f"{save_dir}/other_file.ckpt", 'a').close()
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, \
        'Should save 2 models when save_top_k=2 (plus the pre-existing other_file.ckpt)'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    save_dir = tmp_path / "5"
    save_dir.mkdir()

    # -----------------
    # CASE K=4 (keep the best 4 of the 5 checkpoints)
    # multiple checkpoints within same epoch

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 4, \
        'Should save all 4 models when save_top_k=4 within same epoch'

    save_dir = tmp_path / "6"
    save_dir.mkdir()

    # -----------------
    # CASE K=3 (keep the three best: losses 2.8, 5, 2.5)
    # multiple checkpoints within same epoch

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists
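
Across all six cases the callback keeps the `save_top_k` smallest `val_loss` values ('min' mode); the `_v1`/`_v2` suffixes in the last case appear because every checkpoint lands in epoch 0, so the callback versions the colliding filenames instead of overwriting. A tiny standalone helper (illustrative only, not part of the test) that predicts which losses survive:

import heapq

def surviving_losses(losses, k):
    """Return the k smallest losses, i.e. the checkpoints a 'min'-mode top-k keeps."""
    return sorted(losses) if k < 0 else sorted(heapq.nsmallest(k, losses))

print(surviving_losses([10, 9, 2.8, 5, 2.5], 3))   # [2.5, 2.8, 5]
print(surviving_losses([10, 9, 2.8, 5, 2.5], 0))   # [] -- nothing saved
print(surviving_losses([10, 9, 2.8, 5, 2.5], -1))  # all five are kept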