Example #1
def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune=False,
    use_global_unstructured=False,
    pruning_fn="l1_unstructured",
    use_lottery_ticket_hypothesis=False,
    accelerator=None,
    gpus=None,
    num_processes=1,
):
    model = TestModel()

    # Weights are random; none of them should be 0
    assert torch.all(model.layer.mlp_2.weight != 0)

    pruning_kwargs = {
        "pruning_fn": pruning_fn,
        "amount": 0.3,
        "use_global_unstructured": use_global_unstructured,
        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
        "verbose": 1,
    }
    if parameters_to_prune:
        pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"),
                                                 (model.layer.mlp_2, "weight")]
    else:
        if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
            pruning_kwargs["parameter_names"] = ["weight"]
        else:
            pruning_kwargs["parameter_names"] = ["weight", "bias"]
    if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
        pruning_kwargs["pruning_dim"] = 0
    if pruning_fn == "ln_structured":
        pruning_kwargs["pruning_norm"] = 1

    # Misconfiguration checks
    if isinstance(pruning_fn, str) and pruning_fn.endswith(
            "_structured") and use_global_unstructured:
        with pytest.raises(
                MisconfigurationException,
                match="is supported with `use_global_unstructured=True`"):
            ModelPruning(**pruning_kwargs)
        return
    if ModelPruning._is_pruning_method(
            pruning_fn) and not use_global_unstructured:
        with pytest.raises(MisconfigurationException,
                           match="currently only supported with"):
            ModelPruning(**pruning_kwargs)
        return

    pruning = ModelPruning(**pruning_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        weights_summary=None,
        checkpoint_callback=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=10,
        accelerator=accelerator,
        gpus=gpus,
        num_processes=num_processes,
        callbacks=pruning,
    )
    trainer.fit(model)
    trainer.test(model)

    if not accelerator:
        # Check some have been pruned
        assert torch.any(model.layer.mlp_2.weight == 0)
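A minimal, hedged sketch (not part of the test above) of how the ModelPruning callback is typically attached to a Trainer; the arguments mirror the pruning_kwargs built in train_with_pruning_callback, and the exact Trainer flags depend on your PyTorch Lightning version.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

# prune 30% of the smallest-magnitude weights with L1 unstructured pruning
pruning = ModelPruning(pruning_fn="l1_unstructured", amount=0.3)
trainer = Trainer(max_epochs=1, callbacks=[pruning])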
Example #2
def test_disabled_checkpointing(tmpdir):
    # no callback
    trainer = Trainer(max_epochs=3, enable_checkpointing=False)
    assert not trainer.checkpoint_callbacks
    trainer.fit(BoringModel())
    assert not trainer.checkpoint_callbacks
Example #3
def test_gradient_accumulation_scheduling():
    """
    Test grad accumulation by the freq of optimizer updates
    """
    # test incorrect configs; each bad config gets its own `pytest.raises` block,
    # otherwise only the first statement inside the block would ever run
    with pytest.raises(IndexError):
        Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
    with pytest.raises(IndexError):
        Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches={})
    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches={1: 2, 3.: 4})
    with pytest.raises(TypeError):
        Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        # only test the first 12 batches in epoch
        if batch_nb < 12:
            if epoch_nb == 0:
                # reset counter when starting epoch
                if batch_nb == 0:
                    self.prev_called_batch_nb = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 1

            elif 1 <= epoch_nb <= 2:
                # reset counter when starting epoch
                if batch_nb == 1:
                    self.prev_called_batch_nb = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 2

            else:
                if batch_nb == 3:
                    self.prev_called_batch_nb = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 4

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = get_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_nb_epochs=4)

    # for the test
    trainer.optimizer_step = optimizer_step
    model.prev_called_batch_nb = 0

    trainer.fit(model)
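The schedule passed to the Trainer above is simply a dict mapping "from this epoch on" to the number of batches to accumulate. A hedged sketch of the same idea with the schedule used in the test, for Lightning versions that accept a dict for accumulate_grad_batches (the legacy test still uses max_nb_epochs and *_percent_check):

from pytorch_lightning import Trainer

# accumulate 2 batches starting at epoch 1 and 4 batches starting at epoch 3
trainer = Trainer(accumulate_grad_batches={1: 2, 3: 4}, max_epochs=4)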
Example #4
def test_metric_collections(tmpdir):
    """This test ensures the metric attribute is properly found even with complex nested metric structure."""
    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metrics_list = ModuleList([DummyMetric() for _ in range(2)])
            self.metrics_dict = ModuleDict({
                "a": DummyMetric(),
                "b": DummyMetric()
            })
            self.metrics_collection_dict = MetricCollection({
                "a": DummyMetric(),
                "b": DummyMetric()
            })
            self.metrics_collection_dict_nested = ModuleDict({
                "a":
                ModuleList([ModuleDict({"b": DummyMetric()}),
                            DummyMetric()])
            })

        def training_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)
            self.metrics_list[0](batch_idx)
            self.metrics_list[1](batch_idx)

            self.metrics_dict["a"](batch_idx)
            self.metrics_dict["b"](batch_idx)

            self.metrics_collection_dict["a"](batch_idx)
            self.metrics_collection_dict["b"](batch_idx)

            self.metrics_collection_dict_nested["a"][0]["b"](batch_idx)
            self.metrics_collection_dict_nested["a"][1](batch_idx)

            self.log("a", self.metrics_list[0])
            self.log("b", self.metrics_list[1])

            self.log("c", self.metrics_dict["a"])
            self.log("d", self.metrics_dict["b"])

            self.log("e", self.metrics_collection_dict["a"])
            self.log("f", self.metrics_collection_dict["b"])

            self.log("g", self.metrics_collection_dict_nested["a"][0]["b"])
            self.log("h", self.metrics_collection_dict_nested["a"][1])

            return loss

        def on_train_epoch_end(self) -> None:
            results = self.trainer.fit_loop.epoch_loop._results
            assert results[
                "training_step.a"].meta.metric_attribute == "metrics_list.0"
            assert results[
                "training_step.b"].meta.metric_attribute == "metrics_list.1"

            assert results[
                "training_step.c"].meta.metric_attribute == "metrics_dict.a"
            assert results[
                "training_step.d"].meta.metric_attribute == "metrics_dict.b"

            assert results[
                "training_step.e"].meta.metric_attribute == "metrics_collection_dict.a"
            assert results[
                "training_step.f"].meta.metric_attribute == "metrics_collection_dict.b"

            assert results[
                "training_step.g"].meta.metric_attribute == "metrics_collection_dict_nested.a.0.b"
            assert results[
                "training_step.h"].meta.metric_attribute == "metrics_collection_dict_nested.a.1"

    model = TestModel()

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_train_batches=2,
                      limit_val_batches=0)
    trainer.fit(model)
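The DummyMetric objects above follow the torchmetrics update/compute contract, which is what lets Lightning resolve the owning attribute when a metric object is passed to self.log. A standalone sketch of that contract, assuming a recent torchmetrics with the task-based Accuracy API:

import torch
import torchmetrics

acc = torchmetrics.Accuracy(task="binary")
acc.update(torch.tensor([0.9, 0.1, 0.8]), torch.tensor([1, 0, 0]))
print(acc.compute())  # accuracy aggregated over everything seen since the last reset
acc.reset()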
Example #5
def main(cfg: DictConfig):
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)
    # Random Seed
    seed_everything(cfg.train.seed)

    # Model  ####################################################################
    net = ENet(model_name=cfg.train.model_name)
    transform = ImageTransform(img_size=cfg.data.img_size)

    # Comet.ml
    experiment = Experiment(api_key=cfg.comet_ml.api_key,
                            project_name=cfg.comet_ml.project_name)
    # Log Parameters
    experiment.log_parameters(dict(cfg.exp))
    experiment.log_parameters(dict(cfg.data))
    experiment.log_parameters(dict(cfg.train))
    # Log Model Graph
    experiment.set_model_graph(str(net))

    # Lightning Module  #########################################################
    model = LightningSystem(net, cfg, experiment)
    datamodule = DataModule(data_dir, cfg, transform, cv)

    checkpoint_callback = ModelCheckpoint(filepath='./checkpoint',
                                          save_top_k=1,
                                          verbose=True,
                                          monitor='avg_val_loss',
                                          mode='min',
                                          prefix=cfg.exp.exp_name + '_')

    trainer = Trainer(logger=False,
                      max_epochs=cfg.train.epoch,
                      checkpoint_callback=checkpoint_callback,
                      gpus=1)

    # Train & Test  ############################################################
    # Train
    trainer.fit(model, datamodule=datamodule)
    experiment.log_metric('best_auc', model.best_auc)
    checkpoint_path = glob.glob(f'./checkpoint/{cfg.exp.exp_name}_*.ckpt')[0]
    experiment.log_asset(file_data=checkpoint_path)

    # Test
    for i in range(test_num):
        trainer.test(model)

    # Submit
    sub_list = glob.glob(f'submission_{cfg.exp.exp_name}*.csv')
    _ = summarize_submit(sub_list,
                         experiment,
                         filename=f'sub_{cfg.exp.exp_name}.csv')

    # oof
    oof_dataset = datamodule.oof_dataset
    oof_dataloader = DataLoader(oof_dataset,
                                batch_size=cfg.train.batch_size,
                                pin_memory=False,
                                shuffle=False,
                                drop_last=False)
    for i in range(10):
        trainer.test(model, test_dataloaders=oof_dataloader)

    # Submit
    sub_list = glob.glob('submission*.csv')
    _ = summarize_submit(sub_list,
                         experiment,
                         filename=f'oof_{cfg.exp.exp_name}.csv')
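main() above is written as a Hydra entry point (it takes a DictConfig and calls hydra.utils.get_original_cwd). A hedged sketch of the usual wiring around such a function; the config_path and config_name values are placeholders, not taken from the snippet:

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    print(cfg.train.seed)  # values come from conf/config.yaml plus CLI overrides

if __name__ == "__main__":
    main()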
Example #6
def test_cpu_restore_training():
    """
    Verify continue training session on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(max_nb_epochs=2,
                           val_check_interval=0.50,
                           val_percent_check=0.2,
                           train_percent_check=0.2,
                           experiment=exp,
                           checkpoint_callback=ModelCheckpoint(save_dir))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_epoch = trainer.current_epoch

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        run_prediction(trainer.val_dataloader, trainer.model)

    model.on_sanity_check_start = assert_good_acc

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()
Example #7
def test_trainer_callback_system(tmpdir):
    """Test the callback system."""
    class CurrentTestModel(
            LightTrainDataloader,
            LightTestMixin,
            LightValidationMixin,
            TestModelBase,
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    def _check_args(trainer, pl_module):
        assert isinstance(trainer, Trainer)
        assert isinstance(pl_module, LightningModule)

    class TestCallback(Callback):
        def __init__(self):
            super().__init__()
            self.on_init_start_called = False
            self.on_init_end_called = False
            self.on_epoch_start_called = False
            self.on_epoch_end_called = False
            self.on_batch_start_called = False
            self.on_batch_end_called = False
            self.on_train_start_called = False
            self.on_train_end_called = False
            self.on_validation_start_called = False
            self.on_validation_end_called = False
            self.on_test_start_called = False
            self.on_test_end_called = False

        def on_init_start(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_start_called = True

        def on_init_end(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_end_called = True

        def on_epoch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_start_called = True

        def on_epoch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_end_called = True

        def on_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_start_called = True

        def on_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_end_called = True

        def on_train_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_start_called = True

        def on_train_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_end_called = True

        def on_validation_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_start_called = True

        def on_validation_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_end_called = True

        def on_test_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_start_called = True

        def on_test_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_end_called = True

    test_callback = TestCallback()

    trainer_options = {
        'callbacks': [test_callback],
        'max_epochs': 1,
        'val_percent_check': 0.1,
        'train_percent_check': 0.2,
        'show_progress_bar': False
    }

    assert not test_callback.on_init_start_called
    assert not test_callback.on_init_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    # fit model
    trainer = Trainer(**trainer_options)

    assert trainer.callbacks[0] == test_callback
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    trainer.fit(model)

    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert test_callback.on_epoch_start_called
    assert test_callback.on_epoch_end_called
    assert test_callback.on_batch_start_called
    assert test_callback.on_batch_end_called
    assert test_callback.on_train_start_called
    assert test_callback.on_train_end_called
    assert test_callback.on_validation_start_called
    assert test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called

    trainer.test()

    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called
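A minimal, hedged sketch of the Callback interface this test drives; the two hooks shown here exist across Lightning versions, while some of the hooks asserted above (for example on_init_start) were removed in later releases:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("training is starting")

    def on_train_end(self, trainer, pl_module):
        print("training has ended")

trainer = Trainer(callbacks=[PrintingCallback()], max_epochs=1)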
Example #8
class LightningTrainer(BaseTrainer):
    def __init__(self, config: DictConfig):
        super().__init__(config)
        self.trainer = None
        self.trainer_config = self.config.trainer.params
        self.data_module = None

    def load(self):
        super().load()
        self._calculate_max_updates()
        self._load_loggers()
        self._load_trainer()

    def _load_trainer(self):
        lightning_params = self.trainer_config

        with omegaconf.open_dict(lightning_params):
            lightning_params.pop("max_steps")
            lightning_params.pop("max_epochs")

        lightning_params_dict = OmegaConf.to_container(lightning_params,
                                                       resolve=True)
        self.trainer = Trainer(callbacks=self._callbacks,
                               max_steps=self._max_updates,
                               default_root_dir=get_mmf_env(key="log_dir"),
                               **lightning_params_dict)

    def configure_device(self) -> None:
        pass

    def configure_seed(self) -> None:
        seed = self.config.training.seed
        seed_everything(seed)

    def _load_loggers(self) -> None:
        self.tb_writer = None
        if self.training_config.tensorboard:
            # TODO: @sash PL logger upgrade
            log_dir = setup_output_folder(folder_only=True)
            env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
            if env_tb_logdir:
                log_dir = env_tb_logdir

            self.tb_writer = TensorboardLogger(log_dir)

    def load_datasets(self) -> None:
        logger.info("Loading datasets")
        data_module = MultiDataModule(self.config)
        self.data_module = data_module

        self.train_loader = data_module.train_dataloader()
        self.val_loader = data_module.val_dataloader()
        self.test_loader = data_module.test_dataloader()

    def load_model(self) -> None:
        logger.info("Loading models")

        attributes = self.config.model_config[self.config.model]
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]
        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)
        self.model.is_pl_enabled = True
        self.model.build_meters(self.run_type)

    def load_optimizer(self) -> None:
        logger.info("Loading optimizer: noop for lightning")

    def load_metrics(self) -> None:
        logger.info("Loading metrics")
        metrics = self.config.evaluation.get("metrics", [])
        # moved metrics into the model object
        self.model.metrics = Metrics(metrics)

    def configure_callbacks(self) -> None:
        self._callbacks = [LightningLoopCallback(self)]

    def train(self) -> None:
        logger.info("===== Model =====")
        logger.info(self.model)
        print_model_parameters(self.model)

        logger.info("Starting training...")

        if "train" not in self.run_type:
            self.inference()
            return

        self.trainer.fit(self.model, self.data_module)
        # TODO: Look for a better way to hook this
        self.data_module.teardown()

    def inference(self) -> None:
        logger.info("Starting inference...")
        # TODO: @sash coming soon
        pass

    def _calculate_max_updates(self) -> int:
        self._max_updates = self.trainer_config.max_steps
        self._max_epochs = self.trainer_config.max_epochs
        if self._max_updates is None and self._max_epochs is None:
            raise ValueError(
                "Neither max_updates nor max_epochs is specified.")

        self._max_updates, max_epochs = get_max_updates(
            self._max_updates,
            self._max_epochs,
            self.train_loader,
            self.trainer_config.accumulate_grad_batches,
        )
        self._max_epochs = math.ceil(max_epochs)
        return self._max_updates
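_load_trainer above converts the remaining OmegaConf node into a plain dict so it can be splatted into Trainer(**kwargs). A small standalone sketch of that conversion:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"gradient_clip_val": 1.0, "precision": 16})
trainer_kwargs = OmegaConf.to_container(cfg, resolve=True)  # plain Python dict
print(trainer_kwargs)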
Example #9
def test_deepspeed_fp32_works(tmpdir):
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, gpus=1, strategy="deepspeed_stage_3", fast_dev_run=True)
    trainer.fit(model)
Example #10
def test_ckpt_metric_names_results(tmpdir):
    class ResultLog(BoringModel):
        def training_step(self, batch, batch_idx):
            y_hat = self(batch)

            # calculate loss
            loss_val = self.loss(batch, y_hat)
            log_val = loss_val

            # alternate between tensors and scalars for "log" and "progress_bar"
            if batch_idx % 2 == 0:
                log_val = log_val.item()

            result = pl.core.step_result.TrainResult(loss_val)
            result.log('some_val',
                       log_val * log_val,
                       prog_bar=True,
                       logger=False)
            result.log('train_some_val', log_val * log_val)
            return result

        def validation_step(self, batch, batch_idx):
            y_hat = self(batch)

            loss_val = self.loss(batch, y_hat)

            # acc
            labels_hat = torch.argmax(y_hat, dim=1)
            val_acc = torch.sum(batch == labels_hat).item() / (len(batch) *
                                                               1.0)
            val_acc = torch.tensor(val_acc).type_as(batch)

            result = pl.core.step_result.EvalResult(checkpoint_on=loss_val,
                                                    early_stop_on=loss_val)
            result.log_dict({
                'val_loss': loss_val,
                'val_acc': val_acc,
            })
            return result

    model = ResultLog()
    model.training_step_end = None
    model.training_epoch_end = None
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        callbacks=[
            ModelCheckpoint(monitor='early_stop_on',
                            dirpath=tmpdir,
                            filename="{val_loss:.2f}")
        ],
    )

    trainer.fit(model)

    # make sure the checkpoint we saved has the metric in the name
    ckpts = os.listdir(tmpdir)
    ckpts = [x for x in ckpts if "val_loss" in x]
    assert len(ckpts) == 1
    val = re.sub("[^0-9.]", "", ckpts[0])
    assert len(val) > 3
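The final check relies on ModelCheckpoint's filename templating: metrics logged during training can be interpolated into the checkpoint name. A hedged sketch with current argument names:

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(
    monitor="val_loss",                  # metric to track
    dirpath="checkpoints/",              # where checkpoint files are written
    filename="{epoch}-{val_loss:.2f}",   # logged metrics are substituted into the name
)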
Example #11
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU."""
    hparams = tutils.get_default_hparams()
    model = EvalModelTemplate(hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # predict with trained model before saving
    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        for batch in dataloader:
            break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(tmpdir, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    trainer = Trainer(
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    model = EvalModelTemplate(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)
Example #12
def main():

    if not pre_model:

        # Data Module and Model
        dm = DataModule(bs)
        dm.prepare_data()

        model = FCN(num_classes, learning_rate)

        # Running our Model
        trainer = Trainer(max_epochs=epochs,
                          fast_dev_run=False,
                          gpus=1,
                          profiler=False,
                          progress_bar_refresh_rate=1,
                          logger=tboard_logger)
        trainer.fit(model, dm)

    if pre_model:

        model = models.segmentation.fcn_resnet101(pretrained=True).eval()

        transform_VOC = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        full_dataset = torchvision.datasets.VOCSegmentation(
            "/projects/brfi3983/",
            image_set='train',
            download=True,
            transform=transform_VOC)
        test_dataset = torchvision.datasets.VOCSegmentation(
            "/projects/brfi3983/",
            image_set='val',
            download=True,
            transform=transform_VOC)

        # # Loading our custom image and transforming it to our pretrained network (uncomment which image to use)
        # image = Image.open('bird.jpg')
        # image = Image.open('person.jpg')
        image = Image.open('dog.jpg')

        trf = T.Compose([
            T.Resize(800),
            T.CenterCrop(720),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Convert single image to single batch
        inp = trf(image).unsqueeze(0)
        out = model(inp)['out']
        print(f'Output shape: {out.shape}')

        # Seeing which classes are most dominant across its depth
        om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()

        # Plotting Class distribution
        classes = np.unique(om)
        plt.hist(om, bins=num_classes)
        plt.title('Occurrences of Classes')
        plt.xlabel('Class')
        plt.ylabel('Count')

        # Plotting original image across segmented image
        fig, (ax1, ax2) = plt.subplots(1, 2)
        plt.suptitle('Image Segmentation')
        ax1.set_title('Original Image')
        ax1.imshow(image)
        ax2.set_title('Segmented Image')
        ax2.imshow(om, cmap='YlOrBr')

        plt.show()
Example #13
def test_write_predictions(tmpdir, option: int, do_train: bool, gpus: int):
    class CustomBoringModel(BoringModel):
        def test_step(self, batch, batch_idx, optimizer_idx=None):
            output = self(batch)
            test_loss = self.loss(batch, output)
            self.log('test_loss', test_loss)

            batch_size = batch.size(0)
            lst_of_str = [
                random.choice(['dog', 'cat']) for i in range(batch_size)
            ]
            lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
            lst_of_lst = [[x] for x in lst_of_int]
            lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]

            prediction_file = getattr(self, 'prediction_file',
                                      'predictions.pt')

            lazy_ids = torch.arange(batch_idx * batch_size,
                                    batch_idx * batch_size + batch_size)

            # Base
            if option == 0:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # Check mismatching tensor len
            elif option == 1:
                self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)),
                                      prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # write multi-dimension
            elif option == 2:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)
                self.write_prediction('x', batch, prediction_file)

            # write str list
            elif option == 3:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_str, prediction_file)

            # write int list
            elif option == 4:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_int, prediction_file)

            # write nested list
            elif option == 5:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_lst, prediction_file)

            # write dict list
            elif option == 6:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_dict, prediction_file)

            elif option == 7:
                self.write_prediction_dict({
                    'idxs': lazy_ids,
                    'preds': output
                }, prediction_file)

    prediction_file = Path(tmpdir) / 'predictions.pt'

    dm = BoringDataModule()
    model = CustomBoringModel()
    model.test_epoch_end = None
    model.prediction_file = prediction_file.as_posix()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        weights_summary=None,
        deterministic=True,
        gpus=gpus,
    )

    # Prediction file shouldn't exist yet because we haven't done anything
    assert not prediction_file.exists()

    if do_train:
        trainer.fit(model, dm)
        assert trainer.state.finished, f"Training failed with {trainer.state}"
        trainer.test(datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)

    # check prediction file now exists and is of expected length
    assert prediction_file.exists()
    predictions = torch.load(prediction_file)
    assert len(predictions) == len(dm.random_test)
Example #14
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--DATA_PATH",
                        type=str,
                        help="path to folders with images")
    parser.add_argument(
        "--encoder",
        default=None,
        type=str,
        help=
        "encoder to initialize. Can accept SimCLR model checkpoint or just encoder name in from encoders_dali"
    )
    parser.add_argument("--batch_size",
                        default=128,
                        type=int,
                        help="batch size for SSL")
    parser.add_argument("--num_workers",
                        default=1,
                        type=int,
                        help="number of workers to use to fetch data")
    parser.add_argument(
        "--hidden_dims",
        default=128,
        type=int,
        help=
        "hidden dimensions in classification layer added onto model for finetuning"
    )
    parser.add_argument("--epochs",
                        default=400,
                        type=int,
                        help="number of epochs to train model")
    parser.add_argument("--lr",
                        default=1e-3,
                        type=float,
                        help="learning rate for training model")
    parser.add_argument(
        "--patience",
        default=-1,
        type=int,
        help=
        "automatically cuts off training if validation does not drop for (patience) epochs. Leave blank to have no validation based early stopping."
    )
    parser.add_argument("--val_split",
                        default=0.2,
                        type=float,
                        help="percent in validation data")
    parser.add_argument(
        "--withhold_split",
        default=0,
        type=float,
        help=
        "decimal from 0-1 representing how much of the training data to withold from either training or validation. Used for experimenting with labels neeeded"
    )
    parser.add_argument("--gpus",
                        default=1,
                        type=int,
                        help="number of gpus to use for training")
    parser.add_argument("--log_name",
                        type=str,
                        help="name of model to log on wandb and locally")
    parser.add_argument(
        "--online_eval",
        default=False,
        # use store_true: `type=bool` would treat any non-empty string (even "False") as True
        action="store_true",
        help="Do finetuning on model if labels are provided as a sanity check")

    args = parser.parse_args()
    DATA_PATH = args.DATA_PATH
    batch_size = args.batch_size
    num_workers = args.num_workers
    hidden_dims = args.hidden_dims
    epochs = args.epochs
    lr = args.lr
    patience = args.patience
    val_split = args.val_split
    withhold = args.withhold_split
    gpus = args.gpus
    encoder = args.encoder
    log_name = 'SIMCLR_SSL_' + args.log_name + '.ckpt'
    online_eval = args.online_eval

    wandb_logger = WandbLogger(name=log_name, project='SpaceForce')
    checkpointed = '.ckpt' in encoder
    if checkpointed:
        print('Resuming SSL Training from Model Checkpoint')
        try:
            model = SIMCLR.load_from_checkpoint(checkpoint_path=encoder)
            embedding_size = model.embedding_size
        except Exception as e:
            print(e)
            print(
                'Invalid checkpoint to initialize SIMCLR: it must include both the encoder and the projection head and be of the SIMCLR class from this library. Will try to initialize just the encoder.'
            )
            checkpointed = False

    if not checkpointed:
        encoder, embedding_size = load_encoder(encoder)
        model = SIMCLR(encoder=encoder,
                       embedding_size=embedding_size,
                       gpus=gpus,
                       epochs=epochs,
                       DATA_PATH=DATA_PATH,
                       withhold=withhold,
                       batch_size=batch_size,
                       val_split=val_split,
                       hidden_dims=hidden_dims,
                       train_transform=SimCLRTrainDataTransform,
                       val_transform=SimCLRTrainDataTransform,
                       num_workers=num_workers,
                       lr=lr)

    online_evaluator = SSLOnlineEvaluator(drop_p=0.,
                                          hidden_dim=None,
                                          z_dim=embedding_size,
                                          num_classes=model.num_classes,
                                          dataset='None')

    cbs = []
    backend = 'dp'

    if patience > 0:
        cb = EarlyStopping('val_loss', patience=patience)
        cbs.append(cb)

    if online_eval:
        cbs.append(online_evaluator)
        backend = 'ddp'

    trainer = Trainer(
        gpus=gpus,
        max_epochs=epochs,
        progress_bar_refresh_rate=5,
        callbacks=cbs,
        distributed_backend=backend if args.gpus > 1 else None,
        logger=wandb_logger,
        enable_pl_optimizer=True)

    print('USING BACKEND______________________________ ', backend)
    trainer.fit(model)
    Path(f"./models/SSL").mkdir(parents=True, exist_ok=True)
    trainer.save_checkpoint(f"./models/SSL/{log_name}")
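cli_main wires up early stopping only when patience is positive. A minimal, hedged sketch of that callback on its own; "val_loss" must be a metric the LightningModule actually logs during validation:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor="val_loss", patience=3, mode="min")
trainer = Trainer(callbacks=[early_stop], max_epochs=100)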
Example #15
def test_cpu_slurm_save_load():
    """
    Verify model save/load/checkpoint on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    cluster_a = SlurmCluster()
    trainer_options = dict(max_nb_epochs=1,
                           cluster=cluster_a,
                           experiment=exp,
                           checkpoint_callback=ModelCheckpoint(save_dir))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # predict with trained model before saving
    # make a prediction
    for batch in model.test_dataloader:
        break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test registering a save function
    trainer.enable_auto_hpc_walltime_manager()

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(save_dir, exp)
    assert os.path.exists(saved_filepath)

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    continue_tng_hparams = get_hparams(continue_training=True,
                                       hpc_exp_number=cluster_a.hpc_exp_number)
    trainer_options = dict(
        max_nb_epochs=1,
        cluster=SlurmCluster(continue_tng_hparams),
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()
Example #16
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer_options = dict(
        max_epochs=1,
        gpus=2,
        distributed_backend='dp',
        default_root_dir=tmpdir,
    )

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert result == 1, 'amp + dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        dataloader = trainer.train_dataloader
        tpipes.run_prediction(dataloader, dp_model, dp=True)

    # new model
    model = EvalModelTemplate(**hparams)
    model.on_train_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #17
def test_amp_gpu_ddp_slurm_managed():
    """
    Make sure DDP + AMP work
    :return:
    """
    if not torch.cuda.is_available():
        warnings.warn('test_amp_gpu_ddp cannot run.'
                      ' Rerun on a GPU node to run this test')
        return
    if not torch.cuda.device_count() > 1:
        warnings.warn('test_amp_gpu_ddp cannot run.'
                      ' Rerun on a node with 2+ GPUs to run this test')
        return

    # simulate setting slurm flags
    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
    os.environ['SLURM_LOCALID'] = str(0)

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(progress_bar=True,
                           max_nb_epochs=1,
                           gpus=[0],
                           distributed_backend='ddp',
                           use_amp=True)

    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    # exp file to get weights
    checkpoint = ModelCheckpoint(save_dir)

    # add these to the trainer options
    trainer_options['checkpoint_callback'] = checkpoint
    trainer_options['experiment'] = exp

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'amp + ddp model failed to complete'

    # test root model address
    assert trainer.resolve_root_node_address('abc') == 'abc'
    assert trainer.resolve_root_node_address('abc[23]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address(
        'abc[23-24, 45-40, 40]') == 'abc23'

    # test model loading with a map_location
    map_location = 'cuda:1'
    pretrained_model = load_model(exp, save_dir, True, map_location)

    # test model preds
    run_prediction(model.test_dataloader, pretrained_model)

    if trainer.use_ddp:
        # on hpc this would work fine... but need to hack it for the purpose of the test
        trainer.model = pretrained_model
        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers(
        )

    # test HPC loading / saving
    trainer.hpc_save(save_dir, exp)
    trainer.hpc_load(save_dir, on_gpu=True)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()

    clear_save_dir()
Example #18
def test_lr_scheduler_step_hook(tmpdir):
    """Test that custom lr scheduler works and `lr_scheduler_step` is called at appropriate time."""
    class CustomEpochScheduler:
        def __init__(self, optimizer):
            self.optimizer = optimizer

        def step(self, epoch):
            ...

        def state_dict(self):
            ...

        def load_state_dict(self, state_dict):
            ...

    class CustomBoringModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx=0):
            return super().training_step(batch, batch_idx)

        def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
            # step-level
            if optimizer_idx == 0:
                super().lr_scheduler_step(scheduler, optimizer_idx, metric)
            # epoch-level
            elif optimizer_idx == 1:
                scheduler.step(epoch=self.current_epoch)

        def configure_optimizers(self):
            opt1 = torch.optim.SGD(self.layer.parameters(), lr=1e-2)
            lr_scheduler1 = {
                "scheduler": torch.optim.lr_scheduler.StepLR(opt1,
                                                             step_size=1),
                "interval": "step"
            }
            opt2 = torch.optim.SGD(self.layer.parameters(), lr=1e-2)
            lr_scheduler2 = CustomEpochScheduler(opt2)
            return {
                "optimizer": opt1,
                "lr_scheduler": lr_scheduler1
            }, {
                "optimizer": opt2,
                "lr_scheduler": lr_scheduler2,
            }

    model = CustomBoringModel()
    model.training_epoch_end = None
    max_epochs = 3
    limit_train_batches = 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_checkpointing=False,
        logger=False,
        max_epochs=max_epochs,
        limit_train_batches=limit_train_batches,
        limit_val_batches=0,
    )

    with patch.object(CustomEpochScheduler,
                      "step") as mock_method_epoch, patch.object(
                          torch.optim.lr_scheduler.StepLR,
                          "step") as mock_method_step:
        trainer.fit(model)

    assert mock_method_epoch.mock_calls == [
        call(epoch=e) for e in range(max_epochs)
    ]
    # first step is called by PyTorch _LRScheduler
    assert mock_method_step.call_count == max_epochs * limit_train_batches + 1
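For reference, a hedged sketch of the scheduler configuration dict returned from configure_optimizers above, built outside a LightningModule just to show the keys ("interval" may be "step" or "epoch"):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-2)
lr_scheduler_config = {
    "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
    "interval": "step",  # call scheduler.step() after every optimizer step
}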
Example #19
def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad_make_optimizer_step(
        tmpdir):
    """
    Test lightning optimize works with optimizer_zero_grad overrides and make_optimizer_step in automatic_optimization
    """

    try:
        with patch("torch.optim.Adam.zero_grad") as adam_zero_grad, \
             patch("torch.optim.SGD.zero_grad") as sgd_zero_grad:

            class TestModel(BoringModel):
                def training_step(self, batch, batch_idx, optimizer_idx=None):
                    output = self.layer(batch)
                    loss = self.loss(batch, output)
                    return {"loss": loss}

                def training_epoch_end(self, outputs):
                    outputs = sum(outputs, [])
                    torch.stack([x["loss"] for x in outputs]).mean()

                def optimizer_zero_grad(self, epoch: int, batch_idx: int,
                                        optimizer: Optimizer,
                                        optimizer_idx: int):
                    if optimizer_idx == 0:
                        if batch_idx % 2 == 0:
                            optimizer.zero_grad()

                    if optimizer_idx == 1:
                        if batch_idx % 5 == 0:
                            optimizer.zero_grad()

                def optimizer_step(
                    self,
                    epoch,
                    batch_idx,
                    optimizer,
                    optimizer_idx,
                    optimizer_closure,
                    on_tpu,
                    using_native_amp,
                    using_lbfgs,
                ):

                    assert optimizer_closure.__name__ == "train_step_and_backward_closure"

                    if optimizer_idx == 0:
                        optimizer.step(closure=optimizer_closure,
                                       make_optimizer_step=batch_idx % 3 == 0)
                        return
                    optimizer.step(closure=optimizer_closure)

                def configure_optimizers(self):
                    optimizer_1 = torch.optim.SGD(self.layer.parameters(),
                                                  lr=0.1)
                    optimizer_2 = torch.optim.Adam(self.layer.parameters(),
                                                   lr=0.1)
                    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1,
                                                                   step_size=1)
                    return [optimizer_1, optimizer_2], [lr_scheduler]

            model = TestModel()
            trainer = Trainer(
                default_root_dir=tmpdir,
                limit_train_batches=20,
                limit_val_batches=1,
                max_epochs=1,
                weights_summary=None,
            )
            trainer.fit(model)

            assert adam_zero_grad.call_count == 4
            assert sgd_zero_grad.call_count == 10

    except MisconfigurationException as e:
        assert "When overriding LightningModule `optimizer_zero_grad`, make_optimizer_step is not allowed" in str(
            e)
Example #20
def test_trainer_callback_system(tmpdir):
    """Test the callback system."""

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    def _check_args(trainer, pl_module):
        assert isinstance(trainer, Trainer)
        assert isinstance(pl_module, LightningModule)

    class TestCallback(Callback):
        def __init__(self):
            super().__init__()
            self.setup_called = False
            self.teardown_called = False
            self.on_init_start_called = False
            self.on_init_end_called = False
            self.on_fit_start_called = False
            self.on_fit_end_called = False
            self.on_sanity_check_start_called = False
            self.on_sanity_check_end_called = False
            self.on_epoch_start_called = False
            self.on_epoch_end_called = False
            self.on_batch_start_called = False
            self.on_batch_end_called = False
            self.on_train_batch_start_called = False
            self.on_train_batch_end_called = False
            self.on_validation_batch_start_called = False
            self.on_validation_batch_end_called = False
            self.on_test_batch_start_called = False
            self.on_test_batch_end_called = False
            self.on_train_start_called = False
            self.on_train_end_called = False
            self.on_pretrain_routine_start_called = False
            self.on_pretrain_routine_end_called = False
            self.on_validation_start_called = False
            self.on_validation_end_called = False
            self.on_test_start_called = False
            self.on_test_end_called = False
            self.on_after_backward_called = False
            self.on_before_zero_grad_called = False

        def setup(self, trainer, pl_module, stage: str):
            assert isinstance(trainer, Trainer)
            self.setup_called = True

        def teardown(self, trainer, pl_module, step: str):
            assert isinstance(trainer, Trainer)
            self.teardown_called = True

        def on_init_start(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_start_called = True

        def on_init_end(self, trainer):
            assert isinstance(trainer, Trainer)
            self.on_init_end_called = True

        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer, Trainer)
            self.on_fit_start_called = True

        def on_fit_end(self, trainer, pl_module):
            assert isinstance(trainer, Trainer)
            self.on_fit_end_called = True

        def on_sanity_check_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_start_called = True

        def on_sanity_check_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_sanity_check_end_called = True

        def on_epoch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_start_called = True

        def on_epoch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_epoch_end_called = True

        def on_batch_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_start_called = True

        def on_batch_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_batch_end_called = True

        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx,
                                 dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_train_batch_start_called = True

        def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                               batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_train_batch_end_called = True

        def on_validation_batch_start(self, trainer, pl_module, batch,
                                      batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_validation_batch_start_called = True

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                    batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_validation_batch_end_called = True

        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx,
                                dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_test_batch_start_called = True

        def on_test_batch_end(self, trainer, pl_module, outputs, batch,
                              batch_idx, dataloader_idx):
            _check_args(trainer, pl_module)
            self.on_test_batch_end_called = True

        def on_train_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_start_called = True

        def on_train_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_train_end_called = True

        def on_pretrain_routine_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_pretrain_routine_start_called = True

        def on_pretrain_routine_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_pretrain_routine_end_called = True

        def on_validation_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_start_called = True

        def on_validation_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_validation_end_called = True

        def on_test_start(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_start_called = True

        def on_test_end(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_test_end_called = True

        def on_after_backward(self, trainer, pl_module):
            _check_args(trainer, pl_module)
            self.on_after_backward_called = True

        def on_before_zero_grad(self, trainer, pl_module, optimizer):
            _check_args(trainer, pl_module)
            self.on_before_zero_grad_called = True

    test_callback = TestCallback()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[test_callback],
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
        progress_bar_refresh_rate=0,
    )

    assert not test_callback.setup_called
    assert not test_callback.teardown_called
    assert not test_callback.on_init_start_called
    assert not test_callback.on_init_end_called
    assert not test_callback.on_fit_start_called
    assert not test_callback.on_fit_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_batch_start_called
    assert not test_callback.on_train_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_pretrain_routine_start_called
    assert not test_callback.on_pretrain_routine_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called
    assert not test_callback.on_after_backward_called
    assert not test_callback.on_before_zero_grad_called

    # fit model
    trainer = Trainer(**trainer_options)

    assert trainer.callbacks[0] == test_callback
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert not test_callback.setup_called
    assert not test_callback.teardown_called
    assert not test_callback.on_fit_start_called
    assert not test_callback.on_fit_end_called
    assert not test_callback.on_sanity_check_start_called
    assert not test_callback.on_sanity_check_end_called
    assert not test_callback.on_epoch_start_called
    assert not test_callback.on_epoch_end_called
    assert not test_callback.on_batch_start_called
    assert not test_callback.on_batch_end_called
    assert not test_callback.on_train_batch_start_called
    assert not test_callback.on_train_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_train_start_called
    assert not test_callback.on_train_end_called
    assert not test_callback.on_pretrain_routine_start_called
    assert not test_callback.on_pretrain_routine_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called
    assert not test_callback.on_after_backward_called
    assert not test_callback.on_before_zero_grad_called

    trainer.fit(model)

    assert test_callback.setup_called
    assert test_callback.teardown_called
    assert test_callback.on_init_start_called
    assert test_callback.on_init_end_called
    assert test_callback.on_fit_start_called
    assert test_callback.on_fit_end_called
    assert test_callback.on_sanity_check_start_called
    assert test_callback.on_sanity_check_end_called
    assert test_callback.on_epoch_start_called
    assert test_callback.on_epoch_end_called
    assert test_callback.on_batch_start_called
    assert test_callback.on_batch_end_called
    assert test_callback.on_train_batch_start_called
    assert test_callback.on_train_batch_end_called
    assert test_callback.on_validation_batch_start_called
    assert test_callback.on_validation_batch_end_called
    assert test_callback.on_train_start_called
    assert test_callback.on_train_end_called
    assert test_callback.on_pretrain_routine_start_called
    assert test_callback.on_pretrain_routine_end_called
    assert test_callback.on_validation_start_called
    assert test_callback.on_validation_end_called
    assert not test_callback.on_test_batch_start_called
    assert not test_callback.on_test_batch_end_called
    assert not test_callback.on_test_start_called
    assert not test_callback.on_test_end_called
    assert test_callback.on_after_backward_called
    assert test_callback.on_before_zero_grad_called

    # reset setup teardown callback
    test_callback.teardown_called = False
    test_callback.setup_called = False

    test_callback = TestCallback()
    trainer_options.update(callbacks=[test_callback])
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    assert test_callback.setup_called
    assert test_callback.teardown_called
    assert test_callback.on_test_batch_start_called
    assert test_callback.on_test_batch_end_called
    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called
    assert not test_callback.on_validation_start_called
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_validation_batch_start_called
    assert not test_callback.on_after_backward_called
    assert not test_callback.on_before_zero_grad_called


def result_collection_reload(accelerator="auto", devices=1, **kwargs):
    """Validate that `_ResultCollection` is properly reloaded and that the final accumulation with
    Fault Tolerant Training is correct."""
    class CustomException(Exception):
        pass

    class ExtendedBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.breaking_batch_idx = 3
            self.has_validated_sum = False
            self.dummy_metric = DummyMeanMetric()

        @property
        def results(self):
            return self.trainer.fit_loop._results

        def training_step(self, batch, batch_idx):

            # In the training step, we will accumulate metrics using batch_idx from 0 to 4
            # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`
            # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`.
            # However, below we will simulate a failure on `batch_idx=3`.

            if self.trainer.fit_loop.restarting:
                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2",
                         batch_idx,
                         on_step=True,
                         on_epoch=True,
                         sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric",
                         self.dummy_metric,
                         on_step=True,
                         on_epoch=True)

                value = self.results["training_step.tracking_metric"].value
                value_2 = self.results["training_step.tracking"].value

                # On failure, the Metric states are accumulated on rank 0 and zeroed-out on the other ranks,
                # so rank 0 ends up ahead by sum(range(breaking_batch_idx)) == 3 and the other rank behind by 3.
                shift = 0
                if devices == 2:
                    shift = 3 if self.trainer.is_global_zero else -3
                expected = sum(range(batch_idx + 1)) + shift
                assert expected == value == value_2
            else:
                if batch_idx == self.breaking_batch_idx:
                    # simulate failure mid epoch
                    raise CustomException

                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2",
                         batch_idx,
                         on_step=True,
                         on_epoch=True,
                         sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric",
                         self.dummy_metric,
                         on_step=True,
                         on_epoch=True)

                value = self.results["training_step.tracking"].value
                assert value == sum(range(batch_idx + 1))

                value = self.results["training_step.tracking_2"]
                assert value == sum(range(batch_idx + 1))

            return super().training_step(batch, batch_idx)

        def on_train_epoch_end(self) -> None:
            if self.trainer.fit_loop.restarting:
                total = sum(range(5)) * devices
                metrics = self.results.metrics(on_step=False)
                assert self.results["training_step.tracking"].value == total
                assert metrics["callback"][
                    "tracking"] == self.dummy_metric.compute() == 2
                assert self.results["training_step.tracking_2"].value == total
                assert metrics["callback"][
                    "tracking_2"] == self.dummy_metric.compute() == 2
                self.has_validated_sum = True

    model = ExtendedBoringModel()
    trainer_kwargs = {
        "max_epochs": 1,
        "limit_train_batches": 5,
        "limit_val_batches": 0,
        "accelerator": accelerator,
        "devices": devices,
    }
    trainer_kwargs.update(kwargs)
    trainer = Trainer(**trainer_kwargs)

    with suppress(CustomException):
        trainer.fit(model)
    assert not model.has_validated_sum

    tmpdir = (trainer.strategy.broadcast(trainer_kwargs["default_root_dir"], 0)
              if devices >= 2 else trainer_kwargs["default_root_dir"])
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model, ckpt_path=ckpt_path)
    assert model.has_validated_sum
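
The helper above is only defined here, never invoked. Below is a minimal, hypothetical driver sketch showing how it might be wired into a test, assuming this Lightning version enables fault-tolerant training through the PL_FAULT_TOLERANT_TRAINING environment variable and that `default_root_dir` is forwarded through `**kwargs` (as the helper expects); any decorators used in the real test suite are omitted.

# Hypothetical sketch (not part of the original snippet): driving `result_collection_reload`.
import os
from unittest import mock


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_result_collection_reload_cpu_sketch(tmpdir):
    # Single CPU device; the helper itself restarts training from the auto-saved checkpoint.
    result_collection_reload(accelerator="cpu", devices=1, default_root_dir=tmpdir)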
Example #22
0
                                            "max")
profiler = pl.profiler.AdvancedProfiler(report_file)
callbacks = [SaveModelCallback(), checkpoint_callback, early_stopping]
tb_logger = pl_loggers.TensorBoardLogger(save_dir=logdir,
                                         name=exp_id,
                                         version=args.seed,
                                         log_graph=True)

epochs = args.epochs
trainer = Trainer(
    gpus=1 if device == "cuda" else 0,
    profiler=profiler,
    callbacks=callbacks,
    min_epochs=epochs,
    max_epochs=epochs + 5,
    progress_bar_refresh_rate=20,
    weights_summary="top",
    benchmark=True,
    logger=tb_logger,
)

trainer.fit(model, train_loader, valid_loader)

final_checkpoint_file = os.path.join(exp_log_dir, "final_epoch.pth")
torch.save(model.state_dict(), final_checkpoint_file)

model.eval()
model = model.to(device)
df = validate(model, valid_loader, "valid", exp_log_dir)
df = validate(model, test_loader, "test", exp_log_dir)
def train(param):
    if not isinstance(param, dict):
        args = vars(param)
    else:
        args = param

    framework = get_class_by_name('conditioned_separation', args['model'])
    if args['spec_type'] != 'magnitude':
        args['input_channels'] = 4

    if args['resume_from_checkpoint'] is None:
        if args['seed'] is not None:
            seed_everything(args['seed'])

    model = framework(**args)

    if args['last_activation'] != 'identity' and args[
            'spec_est_mode'] != 'masking':
        warn(
            'Please check if you really want to use a mapping-based spectrogram estimation method '
            'with a final activation function. ')
    ##########################################################

    # -- checkpoint
    ckpt_path = Path(args['ckpt_root_path'])
    mkdir_if_not_exists(ckpt_path)
    ckpt_path = ckpt_path.joinpath(args['model'])
    mkdir_if_not_exists(ckpt_path)
    run_id = args['run_id']
    ckpt_path = ckpt_path.joinpath(run_id)
    mkdir_if_not_exists(ckpt_path)
    save_top_k = args['save_top_k']

    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=save_top_k,
        verbose=False,
        monitor='val_loss',
        save_last=False,
        save_weights_only=args['save_weights_only'])
    args['checkpoint_callback'] = checkpoint_callback

    # -- early stop
    patience = args['patience']
    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        min_delta=0.0,
                                        patience=patience,
                                        verbose=False)
    args['early_stop_callback'] = early_stop_callback

    if args['resume_from_checkpoint'] is not None:
        run_id = run_id + "_resume_" + args['resume_from_checkpoint']
        args['resume_from_checkpoint'] = Path(args['ckpt_root_path']).joinpath(
            args['model']).joinpath(args['run_id']).joinpath(
                args['resume_from_checkpoint'])
        args['resume_from_checkpoint'] = str(args['resume_from_checkpoint'])

    # -- logger setting
    log = args['log']
    if log == 'False':
        args['logger'] = False
    elif log == 'wandb':
        args['logger'] = WandbLogger(project='lasaft',
                                     tags=args['model'],
                                     offline=False,
                                     id=run_id)
        args['logger'].log_hyperparams(model.hparams)
        args['logger'].watch(model, log='all')
    elif log == 'tensorboard':
        raise NotImplementedError
    else:
        args['logger'] = True  # default
        default_save_path = 'etc/lightning_logs'
        mkdir_if_not_exists(default_save_path)

    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = dict(
        (name, args[name]) for name in valid_kwargs if name in args)

    # DATASET
    ##########################################################
    data_provider = DataProvider(**args)
    ##########################################################
    # Trainer Definition

    # Trainer
    trainer = Trainer(**trainer_kwargs)
    n_fft, hop_length, num_frame = args['n_fft'], args['hop_length'], args[
        'num_frame']
    train_data_loader = data_provider.get_train_dataloader(
        n_fft, hop_length, num_frame)
    valid_data_loader = data_provider.get_valid_dataloader(
        n_fft, hop_length, num_frame)

    for key in sorted(args.keys()):
        print('{}:{}'.format(key, args[key]))

    if args['auto_lr_find']:
        lr_finder = trainer.lr_find(model,
                                    train_data_loader,
                                    valid_data_loader,
                                    early_stop_threshold=None)
        print(lr_finder.results)
        # torch.save(lr_finder.results, 'lr_result.cache')
        new_lr = lr_finder.suggestion()
        print('new_lr_suggestion:', new_lr)
        return 0

    print(model)

    trainer.fit(model, train_data_loader, valid_data_loader)

    return None
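
The `inspect.signature(Trainer.__init__).parameters` line above is a generic pattern: filter an arbitrary argument dict down to only the keyword arguments a callable accepts. Here is a self-contained sketch of the same pattern, using a toy function instead of `Trainer` so it runs without PyTorch Lightning installed; the names are illustrative only.

import inspect


def toy_trainer(max_epochs=1, gpus=None, logger=True):
    # Stand-in for Trainer.__init__, used only to demonstrate the filtering pattern.
    return {"max_epochs": max_epochs, "gpus": gpus, "logger": logger}


args = {"max_epochs": 20, "gpus": 1, "logger": False, "unrelated_key": 123}

# Keep only the keys declared in the callable's signature.
valid_kwargs = inspect.signature(toy_trainer).parameters
trainer_kwargs = {name: args[name] for name in valid_kwargs if name in args}

assert trainer_kwargs == {"max_epochs": 20, "gpus": 1, "logger": False}
print(toy_trainer(**trainer_kwargs))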
Example #24
0
def test_gradient_accumulation_scheduling(tmpdir):
    """
    Test grad accumulation by the freq of optimizer updates
    """
    tutils.reset_seed()

    # test incorrect configs
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def _optimizer_step(self,
                        epoch,
                        batch_idx,
                        optimizer,
                        optimizer_idx,
                        second_order_closure=None):
        # only test the first 12 batches in epoch
        if batch_idx < 12:
            if epoch == 0:
                # reset counter when starting epoch
                if batch_idx == 0:
                    self.prev_called_batch_idx = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 1

            elif 1 <= epoch <= 2:
                # reset counter when starting epoch
                if batch_idx == 1:
                    self.prev_called_batch_idx = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 2

            else:
                if batch_idx == 3:
                    self.prev_called_batch_idx = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_idx == self.prev_called_batch_idx
                self.prev_called_batch_idx += 4

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_epochs=2,
                      default_root_dir=tmpdir)

    # for the test
    trainer.optimizer_step = _optimizer_step
    model.prev_called_batch_idx = 0

    trainer.fit(model)
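
As a quick sanity check on the accumulation schedule used above ({1: 2, 3: 4}, with a factor of 1 before epoch 1), here is a minimal standalone sketch, not taken from the original test, that computes which batch indices should trigger an optimizer step. It assumes the simple rule "step on the last batch of every accumulation window" and ignores remainder handling at epoch end.

# Standalone sketch: optimizer-step batch indices for an accumulate_grad_batches schedule.
def expected_step_indices(schedule, num_batches, epoch, default=1):
    factor = default
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]
    # One optimizer step at the end of every window of `factor` batches.
    return [i for i in range(num_batches) if (i + 1) % factor == 0]


schedule = {1: 2, 3: 4}
assert expected_step_indices(schedule, 12, epoch=0) == list(range(12))      # step every batch
assert expected_step_indices(schedule, 12, epoch=1) == [1, 3, 5, 7, 9, 11]  # every 2 batches
assert expected_step_indices(schedule, 12, epoch=3) == [3, 7, 11]           # every 4 batches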
Example #25
0
def test_training_loop_hook_call_order(tmpdir):
    """Tests that hooks / methods called in the training loop are in the correct order as detailed in the docs:
    https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks"""
    class HookedModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.called = []

        def on_epoch_start(self):
            self.called.append("on_epoch_start")
            super().on_epoch_start()

        def on_train_epoch_start(self):
            self.called.append("on_train_epoch_start")
            super().on_train_epoch_start()

        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            self.called.append("on_train_batch_start")
            super().on_train_batch_start(batch, batch_idx, dataloader_idx)

        def training_step(self, batch, batch_idx):
            self.called.append("training_step")
            return super().training_step(batch, batch_idx)

        def on_before_zero_grad(self, optimizer):
            self.called.append("on_before_zero_grad")
            super().on_before_zero_grad(optimizer)

        def optimizer_zero_grad(self, epoch, batch_idx, optimizer,
                                optimizer_idx):
            self.called.append("optimizer_zero_grad")
            super().optimizer_zero_grad(epoch, batch_idx, optimizer,
                                        optimizer_idx)

        def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
            self.called.append("backward")
            super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)

        def on_after_backward(self):
            self.called.append("on_after_backward")
            super().on_after_backward()

        def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            on_tpu,
            using_native_amp,
            using_lbfgs,
        ):
            super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
                                   optimizer_closure, on_tpu, using_native_amp,
                                   using_lbfgs)
            self.called.append("optimizer_step"
                               )  # append after as closure calls other methods

        def on_train_batch_end(self, outputs, batch, batch_idx,
                               dataloader_idx):
            self.called.append("on_train_batch_end")
            super().on_train_batch_end(outputs, batch, batch_idx,
                                       dataloader_idx)

        def training_epoch_end(self, outputs):
            self.called.append("training_epoch_end")
            super().training_epoch_end(outputs)

        def on_train_epoch_end(self, outputs):
            self.called.append("on_train_epoch_end")
            super().on_train_epoch_end(outputs)

        def on_epoch_end(self):
            self.called.append("on_epoch_end")
            super().on_epoch_end()

    model = HookedModel()

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        limit_test_batches=1,
        progress_bar_refresh_rate=0,
        weights_summary=None,
    )

    assert model.called == []

    trainer.fit(model)
    expected = [
        'on_epoch_start',  # validation
        'on_epoch_end',
        'on_epoch_start',  # training
        'on_train_epoch_start',
        'on_train_batch_start',
        'training_step',
        'on_before_zero_grad',
        'optimizer_zero_grad',
        'backward',
        'on_after_backward',
        'optimizer_step',
        'on_train_batch_end',
        'training_epoch_end',
        'on_train_epoch_end',
        'on_epoch_end',
        'on_epoch_start',  # validation
        'on_epoch_end'
    ]
    assert model.called == expected
Example #26
0
def test_resume_from_checkpoint_epoch_restored(tmpdir):
    """Verify resuming from checkpoint runs the right number of epochs"""
    import types

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = LightningTestModel(hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, _):
            self.num_batches_seen += 1

        # Bind the increment_epoch function on_epoch_end so that the
        # model keeps track of the number of epochs it has seen.
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_batch_start = types.MethodType(increment_batch, model)
        return model

    model = _new_model()

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.65,
        val_percent_check=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        logger=False,
        default_root_dir=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2

    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
    checkpoints = sorted(
        glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))

    for check in checkpoints:
        next_model = _new_model()
        state = torch.load(check)

        # Resume training
        trainer_options['max_epochs'] = 2
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        assert (state['global_step'] + next_model.num_batches_seen ==
                training_batches * trainer_options['max_epochs'])

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def train_dataloader(self, *args, **kwargs):
        ds = CIFAR10(root='./data',
                     train=True,
                     download=True,
                     transform=self.transform)
        return torch.utils.data.DataLoader(ds,
                                           batch_size=32,
                                           shuffle=True,
                                           num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        ds = CIFAR10(root='./data',
                     train=False,
                     download=True,
                     transform=self.transform)
        return torch.utils.data.DataLoader(ds,
                                           batch_size=32,
                                           shuffle=True,
                                           num_workers=0)


if __name__ == "__main__":
    model = LiftModel(models.resnet50(pretrained=True))
    trainer = Trainer(max_epochs=10,
                      gpus=2,
                      profiler="pytorch",
                      accelerator="ddp")
    trainer.fit(model)
    trainer.validate(model)
Example #28
0
def train(config):
    # ======================================================
    # EXPERIMENT SETUP
    # ======================================================
    from pytorch_lightning import seed_everything
    # Seed
    seed_everything(config.seed)

    # DATASET SETUP
    print("======================================================")
    print("SETTING UP DATASET")
    print("======================================================")
    from ml4floods.models.dataset_setup import get_dataset
    dataset = get_dataset(config.data_params)

    # MODEL SETUP
    print("======================================================")
    print("SETTING UP MODEL")
    print("======================================================")
    from ml4floods.models.model_setup import get_model
    config.model_params.test = False
    config.model_params.train = True
    model = get_model(config.model_params)

    # LOGGING SETUP
    print("======================================================")
    print("SETTING UP LOGGERS")
    print("======================================================")
    import wandb
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(
        name=config.experiment_name,
        project=config.wandb_project,
        entity=config.wandb_entity,
        #         save_dir=f"{config.model_params.model_folder}/{config.experiment_name}"
    )

    # CHECKPOINTING SETUP
    print("======================================================")
    print("SETTING UP CHECKPOINTING")
    print("======================================================")
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"

    checkpoint_path = f"{experiment_path}/checkpoint"
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_top_k=True,
        verbose=True,
        monitor=config.model_params.hyperparameters.metric_monitor,
        mode='min',
        prefix='')

    early_stop_callback = EarlyStopping(
        monitor=config.model_params.hyperparameters.metric_monitor,
        patience=4,
        strict=False,
        verbose=False,
        mode='min')

    callbacks = [checkpoint_callback, early_stop_callback]

    # TRAINING SETUP
    print("======================================================")
    print("START TRAINING")
    print("======================================================")
    from pytorch_lightning import Trainer
    trainer = Trainer(
        fast_dev_run=False,
        logger=wandb_logger,
        callbacks=callbacks,
        default_root_dir=experiment_path,
        accumulate_grad_batches=1,
        gradient_clip_val=0.0,
        auto_lr_find=False,
        benchmark=False,
        distributed_backend=None,
        gpus=config.gpus if config.gpus != '' else None,
        max_epochs=config.model_params.hyperparameters.max_epochs,
        check_val_every_n_epoch=config.model_params.hyperparameters.val_every,
        log_gpu_memory=None,
        resume_from_checkpoint=checkpoint_path
        if config.resume_from_checkpoint else None)

    trainer.fit(model, dataset)

    # ======================================================
    # SAVING SETUP
    # ======================================================
    print("======================================================")
    print("FINISHED TRAINING, SAVING MODEL")
    print("======================================================")
    from pytorch_lightning.utilities.cloud_io import atomic_save
    atomic_save(model.state_dict(), f"{experiment_path}/model.pt")
    torch.save(model.state_dict(),
               os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.save(os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.finish()

    # Save config file in experiment_path
    config_file_path = f"{experiment_path}/config.json"

    save_json(config, config_file_path)

    return 1
Example #29
0
def test_dp_resume():
    """
    Make sure DP continues training correctly
    :return:
    """
    if not can_run_gpu_test():
        return

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        show_progress_bar=True,
        max_nb_epochs=2,
        gpus=2,
        distributed_backend='dp',
    )

    save_dir = init_save_dir()

    # get logger
    logger = get_test_tube_logger(debug=False)
    logger.log_hyperparams(hparams)

    # exp file to get weights
    checkpoint = ModelCheckpoint(save_dir)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert result == 1, 'amp + dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.hpc_save(save_dir, logger)

    # init new trainer
    new_logger = get_test_tube_logger(version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
    trainer_options['train_percent_check'] = 0.2
    trainer_options['val_percent_check'] = 0.2
    trainer_options['max_nb_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the sanity-check start hook so we can predict before the new trainer runs any training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        _ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]

    # new model
    model = LightningTestModel(hparams)
    model.on_sanity_check_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()

    clear_save_dir()
Example #30
0
def test_fit_can_fail_during_validation(train_datasets, val_datasets,
                                        val_check_interval, tmpdir):
    size, n_batches = 2, 4
    stop_batch = 1
    n_val_dataloaders = len(val_datasets)
    stop_dataloader = n_val_dataloaders - 1

    class TestModel(LightningModule):
        def __init__(self, should_fail):
            super().__init__()
            self.layer = torch.nn.Linear(size, 2)
            self.should_fail = should_fail

        def step(self, batch):
            return sum(self.layer(b).sum() for b in batch)

        def training_step(self, batch, batch_idx):
            return self.step(batch)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            if self.should_fail and dataloader_idx == stop_dataloader and batch_idx == stop_batch:
                raise CustomException
            return self.step(batch)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

        def train_dataloader(self):
            return [DataLoader(cls(size, n_batches)) for cls in train_datasets]

        def val_dataloader(self):
            return [DataLoader(cls(size, n_batches)) for cls in val_datasets]

    model = TestModel(False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=val_check_interval,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
    )
    trainer.fit(model)

    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
    assert not os.path.exists(ckpt_path), "Shouldn't have failed"
    state_dict = trainer.fit_loop.state_dict()
    expected_global_step = trainer.global_step

    assert state_dict["epoch_loop.batch_progress"] == {
        "total": {
            "ready": n_batches,
            "started": n_batches,
            "processed": n_batches,
            "completed": n_batches
        },
        "current": {
            "ready": n_batches,
            "started": n_batches,
            "processed": n_batches,
            "completed": n_batches
        },
        "is_last_batch": True,
    }

    val_per_epoch = int(1 // val_check_interval)
    assert state_dict["epoch_loop.val_loop.dataloader_progress"] == {
        "total": {
            "ready": n_val_dataloaders * val_per_epoch,
            "completed": n_val_dataloaders * val_per_epoch
        },
        "current": {
            "ready": n_val_dataloaders,
            "completed": n_val_dataloaders
        },
    }

    assert state_dict["epoch_loop.val_loop.epoch_loop.batch_progress"] == {
        "total": {
            "ready": n_val_dataloaders * val_per_epoch * n_batches,
            "started": n_val_dataloaders * val_per_epoch * n_batches,
            "processed": n_val_dataloaders * val_per_epoch * n_batches,
            "completed": n_val_dataloaders * val_per_epoch * n_batches,
        },
        "current": {
            "ready": n_batches,
            "completed": n_batches,
            "started": n_batches,
            "processed": n_batches
        },
        "is_last_batch": True,
    }

    model = TestModel(True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=val_check_interval,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
    )
    with pytest.raises(CustomException):
        # will stop during validation
        trainer.fit(model)

    assert os.path.exists(ckpt_path)
    checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"]

    per_val_train_batches = int(n_batches * val_check_interval)
    assert checkpoint["epoch_loop.batch_progress"] == {
        "total": {
            "ready": per_val_train_batches,
            "started": per_val_train_batches,
            "processed": per_val_train_batches,
            "completed": per_val_train_batches,
        },
        "current": {
            "ready": per_val_train_batches,
            "started": per_val_train_batches,
            "processed": per_val_train_batches,
            "completed": per_val_train_batches,
        },
        "is_last_batch": val_check_interval == 1,
    }

    val_batch_progress = "epoch_loop.val_loop.epoch_loop.batch_progress"
    # "nb_": non-breaking
    nb_total_val_batch = stop_dataloader * n_batches
    assert checkpoint[val_batch_progress] == {
        "total": {
            "ready": nb_total_val_batch + stop_batch + 1,
            "started": nb_total_val_batch + stop_batch + 1,
            "processed": nb_total_val_batch + stop_batch,
            "completed": nb_total_val_batch + stop_batch,
        },
        "current": {
            "ready": stop_batch + 1,
            "started": stop_batch + 1,
            "processed": stop_batch,
            "completed": stop_batch,
        },
        "is_last_batch": False,
    }

    model = TestModel(False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=val_check_interval,
        enable_progress_bar=False,
    )
    trainer.fit(model, ckpt_path=ckpt_path)

    assert trainer.global_step == expected_global_step

    state_dict_after_restart = trainer.fit_loop.state_dict()

    # should get the same values as in the run that did not fail
    # totals are increased by 1 (the failed batch which never completed)
    expected = state_dict.copy()

    assert state_dict_after_restart["epoch_loop.batch_progress"] == expected[
        "epoch_loop.batch_progress"]

    val_dl_progress = "epoch_loop.val_loop.dataloader_progress"
    expected[val_dl_progress]["total"]["ready"] += 1
    assert state_dict_after_restart[val_dl_progress] == expected[
        val_dl_progress]

    expected[val_batch_progress]["total"]["ready"] += 1
    expected[val_batch_progress]["total"]["started"] += 1
    assert state_dict_after_restart[val_batch_progress] == expected[
        val_batch_progress]
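
The snippet above omits the glue it relies on: the `CustomException` class, the dataset classes passed in as `train_datasets`/`val_datasets`, and the `pytest.mark.parametrize` decorators that supply those arguments together with `val_check_interval`. A hypothetical sketch of what that missing glue might look like follows; the dataset class, parameter names, and parameter values are illustrative assumptions, not the real test's configuration.

# Hypothetical glue for the test above (assumptions only; the real decorators and helpers are not shown here).
import pytest
import torch
from torch.utils.data import Dataset


class CustomException(Exception):
    """Raised to simulate a failure inside validation."""


class RandomDataset(Dataset):
    # Minimal stand-in matching the `cls(size, n_batches)` calls in the test body.
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


@pytest.mark.parametrize("val_check_interval", [0.5, 1.0])
@pytest.mark.parametrize(
    "train_datasets, val_datasets",
    [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])],
)
def test_fit_can_fail_during_validation_sketch(train_datasets, val_datasets, val_check_interval, tmpdir):
    ...  # the body would be the test shown above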