def test_finetuning_with_resume_from_checkpoint(tmpdir):
    """
    This test validates that the generated ModelCheckpoint points to the correct best_model_path during test
    """

    seed_everything(4)

    checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                          dirpath=tmpdir,
                                          filename="{epoch:02d}",
                                          save_top_k=-1)

    class ExtendedBoringModel(BoringModel):
        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                           step_size=1)
            return [optimizer], [lr_scheduler]

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    model = ExtendedBoringModel()
    model.validation_epoch_end = None
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=12,
        limit_val_batches=6,
        limit_test_batches=12,
        callbacks=[checkpoint_callback],
        logger=False,
    )
    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    best_model_paths = [checkpoint_callback.best_model_path]
    results = []

    for idx in range(3, 6):
        # load from checkpoint
        trainer = pl.Trainer(
            default_root_dir=tmpdir,
            max_epochs=idx,
            limit_train_batches=12,
            limit_val_batches=12,
            limit_test_batches=12,
            resume_from_checkpoint=best_model_paths[-1],
            progress_bar_refresh_rate=0,
        )
        trainer.fit(model)
        trainer.test()
        results.append(deepcopy(trainer.callback_metrics))
        best_model_paths.append(trainer.checkpoint_callback.best_model_path)

    for idx in range(len(results) - 1):
        assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]

    for idx, best_model_path in enumerate(best_model_paths):
        if idx == 0:
            assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
        else:
            assert f"epoch={idx + 1}" in best_model_path
def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step_none: bool,
                                         val_dataloaders_none: bool,
                                         monitor: str,
                                         reduce_lr_on_plateau: bool):
    """
    Test that when a model checkpoint is saved, the correct monitored score
    is recorded in both the checkpoint filename and the checkpoint data
    """
    max_epochs = 3
    limit_train_batches = 5
    limit_val_batches = 7
    lr = 1e-1

    class CustomBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.train_log_epochs = torch.randn(max_epochs,
                                                limit_train_batches)
            self.val_logs = torch.randn(max_epochs, limit_val_batches)

        def training_step(self, batch, batch_idx):
            log_value = self.train_log_epochs[self.current_epoch, batch_idx]
            self.log('train_log', log_value, on_epoch=True)
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            log_value = self.val_logs[self.current_epoch, batch_idx]
            self.log('val_log', log_value)
            self.log('epoch', self.current_epoch, on_epoch=True)
            return super().validation_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=lr)

            if reduce_lr_on_plateau:
                lr_scheduler = {
                    'scheduler':
                    optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                    'monitor': monitor,
                    'strict': True,
                }
            else:
                lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                         step_size=1)

            return [optimizer], [lr_scheduler]

    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
    checkpoint = ModelCheckpoint(dirpath=tmpdir,
                                 filename=filename,
                                 monitor=monitor,
                                 save_top_k=-1)

    model = CustomBoringModel()

    if validation_step_none:
        model.validation_step = None
    if val_dataloaders_none:
        model.val_dataloaders = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        progress_bar_refresh_rate=0,
    )
    results = trainer.fit(model)
    assert results
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
    scores = [
        metric[monitor] for metric in trainer.dev_debugger.logged_metrics
        if monitor in metric
    ]
    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
    assert len(ckpt_files) == len(scores) == max_epochs
    assert len(lr_scheduler_debug) == max_epochs

    for epoch in range(max_epochs):
        score = scores[epoch]
        expected_score = getattr(model, f'{monitor}s')[epoch].mean().item()
        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
        assert math.isclose(score, expected_score, rel_tol=1e-4)

        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
        assert chk['epoch'] == epoch + 1
        assert chk['global_step'] == limit_train_batches * (epoch + 1)

        mc_specific_data = chk['callbacks'][type(checkpoint)]
        assert mc_specific_data['dirpath'] == checkpoint.dirpath
        assert mc_specific_data['monitor'] == monitor
        assert mc_specific_data['current_score'] == score

        if not reduce_lr_on_plateau:
            lr_scheduler_specific_data = chk['lr_schedulers'][0]
            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(
                epoch + 1))

        assert lr_scheduler_debug[epoch]['monitor_val'] == (
            score if reduce_lr_on_plateau else None)
        assert lr_scheduler_debug[epoch]['monitor_key'] == (
            monitor if reduce_lr_on_plateau else None)
Example #3
    pspnet = LightningPSPNet.load_from_checkpoint(
        os.path.join(name + '_ckpt', 'last.ckpt'))

    wandb_logger = WandbLogger(
        name=name,
        project='PL-PSPNet',
        entity='yousiki',
    )

    wandb_logger.watch(pspnet.pspnet, log='all', log_freq=100)

    model_checkpoint = ModelCheckpoint(
        dirpath=os.path.join(os.getcwd(), name + '_ckpt'),
        filename='{epoch}-{step}-{val_loss:.2f}',
        monitor='val_loss',
        mode='min',
        verbose=True,
        save_last=True,
        save_top_k=10,
    )

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=100,
        benchmark=True,
        check_val_every_n_epoch=1,
        logger=wandb_logger,
        callbacks=[model_checkpoint],
        accumulate_grad_batches=4,
        limit_val_batches=0.1,
        # fast_dev_run=True,
def test_model_checkpoint_mode_options():
    with pytest.raises(MisconfigurationException,
                       match="`mode` can be .* but got unknown_option"):
        ModelCheckpoint(mode="unknown_option")
def test_invalid_top_k(tmpdir):
    """ Make sure that a MisconfigurationException is raised for a negative save_top_k argument. """
    with pytest.raises(MisconfigurationException,
                       match=r'.*Must be None or >= -1'):
        ModelCheckpoint(dirpath=tmpdir, save_top_k=-3)
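
The two tests above pin down the accepted argument ranges. As a hedged reference for the Lightning release these tests target: `mode` must be 'min' or 'max', and `save_top_k` must be None or an integer >= -1, where -1 keeps every checkpoint and 0 disables score-based saving.

from pytorch_lightning.callbacks import ModelCheckpoint

keep_all = ModelCheckpoint(dirpath='checkpoints/', monitor='val_loss', mode='min', save_top_k=-1)
keep_best3 = ModelCheckpoint(dirpath='checkpoints/', monitor='val_acc', mode='max', save_top_k=3)
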
Example #6
def main(cfg: DictConfig):
    print(f'Training {cfg.train.model} Model')
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)

    # Data Augmentation  --------------------------------------------------------
    transform = ImageTransform(cfg) if cfg.train.model != 'progan' else None

    # DataModule  ---------------------------------------------------------------
    dm = None
    data_dir = './data'
    if cfg.train.data == 'celeba_hq':
        img_paths = glob.glob(os.path.join(data_dir, 'celeba_hq', '**/*.jpg'),
                              recursive=True)
        dm = SingleImageDataModule(img_paths, transform, cfg)

    elif cfg.train.data == 'afhq':
        img_paths = glob.glob(os.path.join(data_dir, 'afhq', '**/*.jpg'),
                              recursive=True)
        dm = SingleImageDataModule(img_paths, transform, cfg)

    elif cfg.train.data == 'ffhq':
        img_paths = glob.glob(os.path.join(data_dir, 'ffhq', '**/*.png'),
                              recursive=True)
        dm = SingleImageDataModule(img_paths, transform, cfg)

    # Model  --------------------------------------------------------------------
    nets = build_model(cfg.train.model, cfg)

    # Comet_ml  -----------------------------------------------------------------
    load_dotenv('.env')
    logger = CometLogger(api_key=os.environ['COMET_ML_API_KEY'],
                         project_name=os.environ['COMET_ML_PROJECT_NAME'],
                         experiment_name=f"{cfg.train.model}")

    logger.log_hyperparams(dict(cfg.train))

    # Lightning Module  ---------------------------------------------------------
    model = None
    checkpoint_path = 'checkpoints/'
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_path,
                                          filename='{epoch:02d}',
                                          prefix=cfg.train.model,
                                          period=1)

    if cfg.train.model == 'vae':
        model = VAE_LightningSystem(nets[0], cfg)

    elif cfg.train.model == 'dcgan':
        model = DCGAN_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    elif cfg.train.model == 'wgan_gp':
        logger.log_hyperparams(dict(cfg.wgan_gp))
        model = WGAN_GP_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    elif cfg.train.model == 'cyclegan':
        logger.log_hyperparams(dict(cfg.cyclegan))
        data_dir = 'data/'
        base_img_paths = glob.glob(os.path.join(data_dir,
                                                cfg.cyclegan.base_imgs_dir,
                                                '**/*.jpg'),
                                   recursive=True)
        style_img_paths = glob.glob(os.path.join(data_dir,
                                                 cfg.cyclegan.style_imgs_dir,
                                                 '**/*.jpg'),
                                    recursive=True)
        dm = CycleGANDataModule(base_img_paths,
                                style_img_paths,
                                transform,
                                cfg,
                                phase='train',
                                seed=cfg.train.seed)
        model = CycleGAN_LightningSystem(nets[0], nets[1], nets[2], nets[3],
                                         transform, cfg, checkpoint_path)

    elif cfg.train.model == 'sagan':
        logger.log_hyperparams(dict(cfg.sagan))
        model = SAGAN_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    elif cfg.train.model == 'progan':
        logger.log_hyperparams(dict(cfg.progan))
        model = PROGAN_LightningSystem(nets[0], nets[1], cfg, checkpoint_path)

    # Trainer  ---------------------------------------------------------
    trainer = Trainer(
        logger=logger,
        max_epochs=cfg.train.epoch,
        gpus=1,
        callbacks=[checkpoint_callback],
        # fast_dev_run=True,
        # resume_from_checkpoint='./checkpoints/sagan-epoch=11.ckpt'
    )

    # Train
    trainer.fit(model, datamodule=dm)
def main(hparams, cluster=None, results_dict=None):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    name = 'immersions_scalogram_resnet_maestro_smaller'
    version = 0
    hparams.log_dir = '/home/idivinci3005/experiments/logs'
    hparams.checkpoint_dir = '/home/idivinci3005/experiments/checkpoints/' + name + '/' + str(
        version)
    hparams.training_set_path = '/home/idivinci3005/data/maestro-v2.0.0'
    hparams.validation_set_path = '/home/idivinci3005/data/maestro-v2.0.0'
    hparams.test_task_set_path = '/home/idivinci3005/data/maestro-v2.0.0'
    hparams.dummy_datasets = False
    hparams.audio_noise = 3e-3

    hparams.cqt_fmin = 40.
    hparams.cqt_bins_per_octave = 24
    hparams.cqt_n_bins = 216
    hparams.cqt_hop_length = 512
    hparams.cqt_filter_scale = 0.43

    hparams.enc_channels = (1, 8, 16, 32, 64, 128, 256, 512, 512)
    hparams.enc_kernel_1_w = (3, 3, 3, 3, 3, 3, 3, 3)
    hparams.enc_kernel_1_h = (3, 3, 3, 3, 3, 3, 3, 3)
    hparams.enc_kernel_2_w = (1, 3, 1, 3, 1, 3, 1, 3)
    hparams.enc_kernel_2_h = (25, 3, 25, 3, 25, 3, 4, 3)
    hparams.enc_padding_1 = (1, 1, 1, 1, 1, 1, 1, 1)
    hparams.enc_padding_2 = (0, 1, 0, 1, 0, 1, 0, 0)
    hparams.enc_stride_1 = (1, 1, 1, 1, 1, 1, 1, 1)
    hparams.enc_stride_2 = (1, 1, 1, 1, 1, 1, 1, 1)
    hparams.enc_pooling_1 = (2, 1, 1, 1, 2, 1, 1, 1)

    hparams.ar_kernel_sizes = (5, 4, 1, 3, 3, 1, 3, 1, 6)
    hparams.ar_self_attention = (False, False, False, False, False, False,
                                 False, False, False)

    hparams.batch_size = 32
    hparams.learning_rate = 3e-4
    hparams.warmup_steps = 1000
    hparams.annealing_steps = 100000
    hparams.score_over_all_timesteps = False
    hparams.visible_steps = 60

    # init experiment
    exp = Experiment(name=name,
                     debug=False,
                     save_dir=hparams.log_dir,
                     version=version,
                     autosave=False,
                     description='maestro dataset experiment')

    # set the hparams for the experiment
    exp.argparse(hparams)
    exp.save()

    # build model
    model = ContrastivePredictiveSystemMaestro(hparams)
    task_model = MaestroClassificationTaskModel(
        model, task_dataset_path=hparams.test_task_set_path)
    model.test_task_model = task_model

    # callbacks
    early_stop = EarlyStopping(monitor=hparams.early_stop_metric,
                               patience=hparams.early_stop_patience,
                               verbose=True,
                               mode=hparams.early_stop_mode)

    checkpoint = ModelCheckpoint(filepath=hparams.checkpoint_dir,
                                 save_best_only=False,
                                 verbose=True,
                                 monitor=hparams.model_save_monitor_value,
                                 mode=hparams.model_save_monitor_mode)

    # configure trainer
    trainer = Trainer(
        experiment=exp,
        checkpoint_callback=checkpoint,
        #early_stop_callback=early_stop,
        #distributed_backend='dp',
        gpus=[0],
        nb_sanity_val_steps=5,
        val_check_interval=0.1,
        val_percent_check=0.25,
        gradient_clip=0.5,
        track_grad_norm=2)

    # train model
    trainer.fit(model)
def test_cpu_slurm_save_load():
    """
    Verify model save/load/checkpoint on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    cluster_a = SlurmCluster()
    trainer_options = dict(
        max_nb_epochs=1,
        cluster=cluster_a,
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'amp + ddp model failed to complete'

    # predict with trained model before saving
    # make a prediction
    for batch in model.test_dataloader:
        break

    x, y = batch
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test registering a save function
    trainer.enable_auto_hpc_walltime_manager()

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(save_dir, exp)
    assert os.path.exists(saved_filepath)

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    continue_tng_hparams = get_hparams(continue_training=True,
                                       hpc_exp_number=cluster_a.hpc_exp_number)
    trainer_options = dict(
        max_nb_epochs=1,
        cluster=SlurmCluster(continue_tng_hparams),
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()
def test_amp_gpu_ddp_slurm_managed():
    """
    Make sure DDP + AMP work
    :return:
    """
    if not can_run_gpu_test():
        return

    # simulate setting slurm flags
    os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
    os.environ['SLURM_LOCALID'] = str(0)

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        show_progress_bar=True,
        max_nb_epochs=1,
        gpus=[0],
        distributed_backend='ddp',
        use_amp=True
    )

    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    # exp file to get weights
    checkpoint = ModelCheckpoint(save_dir)

    # add these to the trainer options
    trainer_options['checkpoint_callback'] = checkpoint
    trainer_options['experiment'] = exp

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'amp + ddp model failed to complete'

    # test root model address
    assert trainer.resolve_root_node_address('abc') == 'abc'
    assert trainer.resolve_root_node_address('abc[23]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'

    # test model loading with a map_location
    map_location = 'cuda:1'
    pretrained_model = load_model(exp, save_dir, True, map_location)

    # test model preds
    run_prediction(model.test_dataloader, pretrained_model)

    if trainer.use_ddp:
        # on hpc this would work fine... but need to hack it for the purpose of the test
        trainer.model = pretrained_model
        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()

    # test HPC loading / saving
    trainer.hpc_save(save_dir, exp)
    trainer.hpc_load(save_dir, on_gpu=True)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()

    clear_save_dir()
Example #10
def run_lightning(argv=None):
    '''Run training with PyTorch Lightning'''
    global RANK
    from pytorch_lightning.loggers import WandbLogger
    import numpy as np
    import traceback
    import os
    import pprint

    pformat = pprint.PrettyPrinter(sort_dicts=False, width=100,
                                   indent=2).pformat

    model, args, addl_targs, data_mod = process_args(parse_args(argv=argv))

    # if 'OMPI_COMM_WORLD_RANK' in os.environ or 'SLURMD_NODENAME' in os.environ:
    #     from mpi4py import MPI
    #     comm = MPI.COMM_WORLD
    #     RANK = comm.Get_rank()
    # else:
    #     RANK = 0
    #     print('OMPI_COMM_WORLD_RANK or SLURMD_NODENAME not set in environment -- not using MPI')

    # output is a wrapper function for os.path.join(outdir, <FILE>)
    outdir, output = process_output(args)
    check_directory(outdir)
    if not args.quiet:
        print0(' '.join(sys.argv), file=sys.stderr)
        print0("Processed Args:\n", pformat(vars(args)), file=sys.stderr)

    # save arguments
    with open(output('args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    checkpoint = None
    if args.init is not None:
        checkpoint = args.init
        link_dest = 'init.ckpt'
    elif args.checkpoint is not None:
        checkpoint = args.checkpoint
        link_dest = 'resumed_from.ckpt'

    if checkpoint is not None:
        if RANK == 0:
            print0(f'symlinking to {checkpoint} from {outdir}')
            dest = output(link_dest)
            src = os.path.relpath(checkpoint, start=outdir)
            if os.path.exists(dest):
                existing_src = os.readlink(dest)
                if existing_src != src:
                    msg = f'Cannot create symlink to checkpoint -- {dest} already exists, but points to {existing_src}'
                    raise RuntimeError(msg)
            else:
                os.symlink(src, dest)

    seed_everything(args.seed)

    if args.csv:
        logger = CSVLogger(save_dir=output('logs'))
    else:
        logger = WandbLogger(project="deep-taxon",
                             entity='deep-taxon',
                             name=args.experiment)

    # get dataset so we can set model parameters that are
    # dependent on the dataset, such as final number of outputs

    monitor, mode = (AbstractLit.val_loss,
                     'min') if args.manifold else (AbstractLit.val_acc, 'max')
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),
        TQDMProgressBar(refresh_rate=50)
    ]
    if not args.disable_checkpoint:
        callbacks.append(
            ModelCheckpoint(dirpath=outdir,
                            save_weights_only=False,
                            save_last=True,
                            save_top_k=3,
                            mode=mode,
                            monitor=monitor))

    if args.early_stop:
        callbacks.append(
            EarlyStopping(monitor=monitor,
                          min_delta=0.001,
                          patience=10,
                          verbose=False,
                          mode=mode))

    if args.swa:
        callbacks.append(
            StochasticWeightAveraging(swa_epoch_start=args.swa_start,
                                      annealing_epochs=args.swa_anneal))

    targs = dict(
        enable_checkpointing=True,
        callbacks=callbacks,
        logger=logger,
        num_sanity_val_steps=0,
    )
    targs.update(addl_targs)

    if args.debug:
        targs['log_every_n_steps'] = 1
        targs['fast_dev_run'] = 10

    if not args.quiet:
        print0('Trainer args:\n', pformat(targs), file=sys.stderr)
        print0('DataLoader args:\n',
               pformat(data_mod._loader_kwargs),
               file=sys.stderr)
        print0('Model:\n', model, file=sys.stderr)

    trainer = Trainer(**targs)

    if args.debug:
        #print_dataloader(data_mod.test_dataloader())
        print_dataloader(data_mod.train_dataloader())
        print_dataloader(data_mod.val_dataloader())

    s = datetime.now()
    print0('START_TIME', time())
    trainer.fit(model, data_mod)
    e = datetime.now()
    td = e - s
    hours, seconds = divmod(td.seconds, 3600)
    minutes, seconds = divmod(seconds, 60)

    print0("Took %02d:%02d:%02d.%d" %
           (hours, minutes, seconds, td.microseconds),
           file=sys.stderr)
    print0("Total seconds:", td.total_seconds(), file=sys.stderr)
Example #11
def test_cpu_restore_training():
    """
    Verify continue training session on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_epoch = trainer.current_epoch

    # training complete
    assert result == 1, 'amp + ddp model failed to complete'

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        _ = [run_prediction(dataloader, trainer.model) for dataloader in trainer.val_dataloader]

    model.on_sanity_check_start = assert_good_acc

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()
Example #12
        return mnist_val

    def test_dataloader(self):
        mnist_test = DataLoader(self.mnist_test,
                                batch_size=self.batch_size,
                                num_workers=self.num_workers)
        return mnist_test


# Step 3: Implement Callbacks and Create Them

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

lr_logger = LearningRateMonitor(**LearningRateMonitor_Params)

model_checkpoint = ModelCheckpoint(**ModelCheckpoint_Params)

# Step 4: Create NeptuneLogger

from pytorch_lightning.loggers.neptune import NeptuneLogger

neptune_logger = NeptuneLogger(
    api_key="ANONYMOUS",
    project_name="shared/pytorch-lightning-integration",
    close_after_fit=False,
    experiment_name="train-on-MNIST",
    params=ALL_PARAMS,
    tags=['1.x', 'advanced'],
)

# Step 5: Pass NeptuneLogger and Callbacks to the Trainer
def main(hparams):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    print('loading model...')
    model = LightningTemplateModel(hparams)
    print('model built')

    # ------------------------
    # 2 INIT TEST TUBE EXP
    # ------------------------

    # init experiment
    exp = Experiment(
        name=hparams.experiment_name,
        save_dir=hparams.test_tube_save_path,
        autosave=False,
        description='test demo'
    )

    exp.argparse(hparams)
    exp.save()

    # ------------------------
    # 3 DEFINE CALLBACKS
    # ------------------------
    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
    early_stop = EarlyStopping(
        monitor='val_acc',
        patience=3,
        verbose=True,
        mode='max'
    )

    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    # ------------------------
    # 4 INIT TRAINER
    # ------------------------
    trainer = Trainer(
        experiment=exp,
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stop,
        gpus=hparams.gpus,
        distributed_backend=hparams.dist_backend,
    )

    # ------------------------
    # 5 START TRAINING
    # ------------------------
    trainer.fit(model)
Example #14
        fold_data = self.data[self.data.fold == self.hparams.fold]
        fold_data = fold_data.reset_index().drop('index', axis=1)

        dataset = RsnaDS(
            df=fold_data,
            aug=aug,
            path=DIR_INPUT,
        )

        return DataLoader(dataset, shuffle=False, batch_size=64, num_workers=6)


if __name__ == '__main__':
    module = RsnaStrPeModule()

    checkpoint_callback = ModelCheckpoint(
        filepath='/var/data/rsna_checkpoints',
        save_top_k=5,
        verbose=True,
        monitor='avg_val_loss',
        mode='min',
        prefix=''
    )

    early_stop = EarlyStopping(monitor='avg_val_loss', verbose=True, patience=10, mode='min')
    trainer = pl.Trainer(gpus=1, max_epochs=50,
                         default_root_dir='/var/data/rsna_checkpoints/',
                         early_stop_callback=early_stop,
                         checkpoint_callback=checkpoint_callback)
    trainer.fit(module)
Example #15
def main(model_name='usp_1d', max_epochs=1020, data_dir='./data', dataset='sc09', ps=False, wn=False, mx=False, perc=1,
         ts=False, fd=False, tts=False, tm=False, train=True, order=True, model_num=None):
    model_name = model_name + '_' + str(int(perc * 100))
    dataset_f = dataset
    nsynth_class = None
    if dataset == 'sc09':
        sample_rate = 16000
        n_classes = 10
        length = 1
        batch_size = 256
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            pad,
        ])
    elif dataset == 'sc':
        sample_rate = 16000
        n_classes = 35
        batch_size = 128
        length = 1
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'nsynth11':
        sample_rate = 16000
        n_classes = 11
        batch_size = 32
        max_epochs = 120
        dataset = 'nsynth'
        nsynth_class = 'instrument_family'
        length = 4
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'nsynth128':
        sample_rate = 16000
        n_classes = 128
        batch_size = 16
        max_epochs = 120
        dataset = 'nsynth'
        nsynth_class = 'pitch'
        length = 4
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=16000, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'esc50':
        sample_rate = 16000
        max_epochs = 2000
        n_classes = 50
        batch_size = 32
        length = 5
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=44100, new_freq=sample_rate),
            partial(pad, length=length),
        ])
    elif dataset == 'esc10':
        sample_rate = 16000
        n_classes = 10
        max_epochs = 2000
        batch_size = 32
        length = 5
        train_transform = transforms.Compose([
            torchaudio.transforms.Resample(orig_freq=44100, new_freq=sample_rate),
            partial(pad, length=length),
        ])

    # model_name = model_name + '_' + dataset
    spec_transform = None
    aug_transform = []
    if order:
        if fd:
            aug_transform.append(transforms.RandomApply(add_fade))
            model_name = model_name + '_fd'
        if tm:
            aug_transform.append(transforms.RandomApply(time_masking))
            model_name = model_name + '_tm'
        if tts:
            aug_transform.append(transforms.RandomApply(partial(time_stret, length=length)))
            model_name = model_name + '_tts'
        if ps:
            aug_transform.append(transforms.RandomApply(pitch_shift))
            model_name = model_name + '_ps'
        if ts:
            aug_transform.append(transforms.RandomApply(time_shift))
            model_name = model_name + '_ts'
        if wn:
            aug_transform.append(transforms.RandomApply(add_white_noise))
            model_name = model_name + '_wn'
        if mx:
            m_x = Mixed_Noise(data_dir, sample_rate)
            aug_transform.append(transforms.RandomApply(m_x))
            model_name = model_name + '_mx'
    else:
        if mx:
            m_x = Mixed_Noise(data_dir, sample_rate)
            aug_transform.append(transforms.RandomApply(m_x))
            model_name = model_name + '_mx'
        if wn:
            aug_transform.append(transforms.RandomApply(add_white_noise))
            model_name = model_name + '_wn'
        if ts:
            aug_transform.append(transforms.RandomApply(time_shift))
            model_name = model_name + '_ts'
        if ps:
            aug_transform.append(transforms.RandomApply(pitch_shift))
            model_name = model_name + '_ps'
        if fd:
            aug_transform.append(transforms.RandomApply(add_fade))
            model_name = model_name + '_fd'
        if tts:
            aug_transform.append(transforms.RandomApply(partial(time_stret, length=length)))
            model_name = model_name + '_tts'
        if tm:
            aug_transform.append(transforms.RandomApply(time_masking))
            model_name = model_name + '_tm'

    aug_transform = transforms.Compose(aug_transform)
    print(f"Model: {model_name}")

    net = Main(batch_size=batch_size,
               sampling_rate=sample_rate,
               data_dir=data_dir,
               dataset=dataset,
               perc=perc,
               nsynth_class=nsynth_class,
               n_classes=n_classes,
               train_transform=train_transform,
               aug_transform=aug_transform,
               spec_transform=spec_transform,
               model=model_name)

    model_path = os.path.join(MODELS_FOLDER, model_name, dataset_f)
    os.makedirs(model_path, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=model_path,
        save_last=True,
        mode='min',
        period=10,
        save_top_k=20000000,
    )
    if model_num is not None:
        checkpoint = os.path.join(model_path, get_last(os.listdir(model_path), model_num))
    elif os.path.exists(model_path) and len(os.listdir(model_path)) > 0:
        checkpoint = os.path.join(model_path, get_last(os.listdir(model_path)))
    else:
        checkpoint = None

    logger = TensorBoardLogger(
        save_dir=LOGS_FOLDER,
        version=dataset_f,
        name=model_name
    )

    # finetune in real-time
    print(f"Loading model: {checkpoint}")

    def to_device(batch, device):
        (x1, x2), y = batch
        x1 = x1.to(device)
        y = y.to(device).squeeze()
        return x1, y

    online_eval = SSLOnlineEvaluator(hidden_dim=512,
                                     z_dim=512,
                                     num_classes=n_classes,
                                     train_transform=train_transform,
                                     data_dir=data_dir,
                                     dataset=dataset,
                                     batch_size=batch_size,
                                     nsynth_class=nsynth_class
                                     )
    online_eval.to_device = to_device

    trainer = Trainer(resume_from_checkpoint=checkpoint,
                      distributed_backend='ddp',
                      max_epochs=max_epochs,
                      sync_batchnorm=True,
                      checkpoint_callback=checkpoint_callback,
                      logger=logger,
                      gpus=-1 if train else 1,
                      log_save_interval=25,
                      callbacks=[online_eval]
                      )
    if train:
        trainer.fit(net)
    else:
        trainer.test(net)
Example #16
    def build(self, **kwargs):
        """
        Responsible for creating the class arguments
        """
        # Record that build was called
        self.build_called = True

        # Retrieving paths
        self.data_dirpath = self.config['dirpaths']['data_dirpath']
        self.log_dirpath = self.config['dirpaths']['log_dirpath']
        self.cwd_dirpath = self.config['dirpaths']['cwd_dirpath']

        # Retrieving parameters
        self.hparams = self.config['params']['hparams']
        self.lightning_params = self.config['params']['lightning_params']
        self.early_stop_callback_params = self.config['params'][
            'early_stop_callback_params']
        self.prepare_data_params = self.config['params']['prepare_data_params']
        #-
        self.test_size_from_dev = self.prepare_data_params[
            'test_size_from_dev']
        self.batch_dataset_preparation = self.prepare_data_params[
            'batch_dataset_preparation']
        #-
        self.model_name = self.hparams['model_name']
        self.train_batch_size = self.hparams['train_batch_size']
        self.eval_batch_size = self.hparams['eval_batch_size']
        self.max_length = self.hparams['max_length']
        self.doc_stride = self.hparams['doc_stride']
        self.learning_rate = self.hparams['learning_rate']
        self.eps = self.hparams['eps']
        self.seed = self.hparams['seed']
        #-
        self.num_gpus = self.lightning_params[
            'num_gpus'] if torch.cuda.is_available() else 0
        self.profiler = self.lightning_params['profiler']
        self.max_epochs = self.lightning_params['max_epochs']
        self.accumulate_grad_batches = self.lightning_params[
            'accumulate_grad_batches']
        self.check_val_every_n_epoch = self.lightning_params[
            'check_val_every_n_epoch']
        self.progress_bar_refresh_rate = self.lightning_params[
            'progress_bar_refresh_rate']
        self.gradient_clip_val = self.lightning_params['gradient_clip_val']
        self.fast_dev_run = self.lightning_params['fast_dev_run']
        #-
        self.monitor = self.early_stop_callback_params['monitor']
        self.min_delta = self.early_stop_callback_params['min_delta']
        self.patience = self.early_stop_callback_params['patience']
        self.verbose = self.early_stop_callback_params['verbose']
        self.mode = self.early_stop_callback_params['mode']

        # Creating additional parameters
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config['params']['hparams']['model_name'])
        self.softmax = torch.nn.Softmax(dim=1)
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Trainer
        if self.fast_dev_run:
            self.TRAINER = pl.Trainer(
                gpus=self.num_gpus,
                checkpoint_callback=False,  # disable checkpoint saving
                fast_dev_run=True
            )
        else:
            checkpoint_callback = ModelCheckpoint(dirpath=self.data_dirpath,
                                                  save_top_k=-1)

            early_stop_callback = EarlyStopping(
                monitor=self.early_stop_callback_params['monitor'],
                min_delta=self.early_stop_callback_params['min_delta'],
                patience=self.early_stop_callback_params['patience'],
                verbose=self.early_stop_callback_params['verbose'],
                mode=self.early_stop_callback_params['mode'])

            callbacks = [early_stop_callback, checkpoint_callback]
            if self.num_gpus > 0:
                gpu_stats = GPUStatsMonitor()
                callbacks.append(gpu_stats)
                tb_logger = pl.loggers.TensorBoardLogger(f"{self.log_dirpath}")
            else:
                tb_logger = None

            self.TRAINER = pl.Trainer(
                gpus=self.lightning_params['num_gpus'],
                profiler=self.lightning_params['profiler'],
                max_epochs=self.lightning_params['max_epochs'],
                accumulate_grad_batches=self.lightning_params['accumulate_grad_batches'],
                check_val_every_n_epoch=self.lightning_params['check_val_every_n_epoch'],
                progress_bar_refresh_rate=self.lightning_params['progress_bar_refresh_rate'],
                callbacks=callbacks,
                resume_from_checkpoint=None,
                logger=tb_logger)
Example #17
def define_callbacks(log_dir):
    checkpoint_callback = ModelCheckpoint(dirpath=log_dir, monitor='val_loss')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    return [checkpoint_callback, lr_monitor]
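
A hedged usage sketch for the helper above; `MyLightningModule` is a placeholder that is not defined in this listing.

import pytorch_lightning as pl

callbacks = define_callbacks('logs/')  # [ModelCheckpoint, LearningRateMonitor]
trainer = pl.Trainer(max_epochs=10, callbacks=callbacks)
# trainer.fit(MyLightningModule())
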
Example #18
def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Verify resuming from checkpoint runs the right number of epochs"""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    hparams = EvalModelTemplate.get_default_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = EvalModelTemplate(**hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0
        model.num_on_load_checkpoint_called = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, batch, batch_idx, dataloader_idx):
            self.num_batches_seen += 1

        def increment_on_load_checkpoint(self, _):
            self.num_on_load_checkpoint_called += 1

        # Bind methods to keep track of epoch numbers, batch numbers it has seen
        # as well as number of times it has called on_load_checkpoint()
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_train_batch_start = types.MethodType(increment_batch, model)
        model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model)
        return model

    model = _new_model()

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=0.65,
        limit_val_batches=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
        default_root_dir=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    trainer = Trainer(**trainer_options)
    # fit model
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2
    assert model.num_on_load_checkpoint_called == 0

    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
    checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
    if url_ckpt:
        # transform local paths into url checkpoints
        ip, port = tmpdir_server
        checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints]

    for check in checkpoints:
        next_model = _new_model()
        state = pl_load(check)

        # Resume training
        trainer_options['max_epochs'] = 2
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        assert state['global_step'] + next_model.num_batches_seen == training_batches * trainer_options['max_epochs']
        assert next_model.num_on_load_checkpoint_called == 1
    input_size = 768
    hidden_size = 200
    training_epochs = 30

    model = LSTM_model(input_size, output_size, hidden_size)

    early_stop_callback = EarlyStopping(
        monitor='val_weighted_f1',
        min_delta=0.00,
        patience=10,
        verbose=False,
        mode='max'
    )

    # 3. Init ModelCheckpoint callback, monitoring 'val_weighted_f1'
    checkpoint_callback = ModelCheckpoint(
        monitor='val_weighted_f1',
        filepath='/home/yzhan273/Research/MPAA/Severity_Class_Pred/sent_bert/LT_save_model/',
        mode='max')

    trainer = pl.Trainer(fast_dev_run=False, max_epochs=training_epochs, gpus=[6],
                         early_stop_callback=early_stop_callback,
                         checkpoint_callback=checkpoint_callback)
    trainer.fit(model)

# output, final_hidden_state, final_cell_state = model(torch.rand(5, 10, 768), batch_size=5)
# print(output.shape, final_hidden_state.shape, final_cell_state.shape)

# best          precision    recall  f1-score   support
#
#            0     0.5962    0.4697    0.5254        66
#            1     0.5179    0.5321    0.5249       109
Example #20
def train_model_part(conf,
                     train_part='filterbank',
                     pretrained_filterbank=None):
    train_set, val_set, train_loader, val_loader = get_data_loaders(
        conf, train_part=train_part)

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this : re-instantiating the model and optimizer
    # for retraining and evaluating is straight-forward.
    model, optimizer = make_model_and_optimizer(
        conf,
        model_part=train_part,
        pretrained_filterbank=pretrained_filterbank)
    # Define scheduler
    scheduler = None
    if conf[train_part + '_training'][train_part[0] + '_half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    system = SystemTwoStep(model=model,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           train_loader=train_loader,
                           val_loader=val_loader,
                           scheduler=scheduler,
                           config=conf,
                           module=train_part)

    # Define callbacks
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=1,
                                 verbose=1)
    early_stopping = False
    if conf[train_part + '_training'][train_part[0] + '_early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=10,
                                       verbose=1)
    # Don't ask GPU if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU were found, set gpus to None')
        conf['main_args']['gpus'] = None

    trainer = pl.Trainer(
        max_nb_epochs=conf[train_part + '_training'][train_part[0] +
                                                     '_epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(checkpoint.best_k_models, file, indent=0)
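
After training, the `best_k_models` mapping dumped above relates each saved checkpoint path to its monitored 'val_loss'. A hedged inspection sketch, reusing the `checkpoint` object from this example; the `float()` call is a no-op if the scores are already plain numbers and unwraps them if they are tensors.

scores = {path: float(score) for path, score in checkpoint.best_k_models.items()}
print(scores)  # e.g. {'exp/checkpoints/epoch=42.ckpt': 7.31, ...}  (illustrative values only)
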
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """
    This test validates that a checkpoint can be resumed several times without
    the trainer internally increasing its global step if no training is run.
    """
    class ExtendedBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

        def validation_epoch_end(self, *_):
            ...

    def assert_trainer_init(trainer):
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

    def get_last_checkpoint(ckpt_dir):
        last = ckpt_dir.listdir(sort=True)[-1]
        return str(last)

    def assert_checkpoint_content(ckpt_dir):
        chk = pl_load(get_last_checkpoint(ckpt_dir))
        assert chk["epoch"] == epochs
        assert chk["global_step"] == 4

    def assert_checkpoint_log_dir(idx):
        lightning_logs = tmpdir / 'lightning_logs'
        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
        assert actual == [f'version_{i}' for i in range(idx + 1)]
        assert len(ckpt_dir.listdir()) == epochs

    ckpt_dir = tmpdir / 'checkpoints'
    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
    epochs = 2
    limit_train_batches = 2
    trainer_config = dict(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        limit_train_batches=limit_train_batches,
        limit_val_batches=3,
        limit_test_batches=4,
        callbacks=[checkpoint_cb],
    )
    trainer = pl.Trainer(**trainer_config)
    assert_trainer_init(trainer)

    model = ExtendedBoringModel()
    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained
    assert trainer.global_step == epochs * limit_train_batches
    assert trainer.current_epoch == epochs - 1
    assert_checkpoint_log_dir(0)
    assert_checkpoint_content(ckpt_dir)

    trainer.validate(model)
    assert trainer.current_epoch == epochs - 1

    trainer.test(model)
    assert trainer.current_epoch == epochs - 1

    for idx in range(1, 5):
        chk = get_last_checkpoint(ckpt_dir)
        assert_checkpoint_content(ckpt_dir)

        # load from checkpoint
        trainer_config["callbacks"] = [
            ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
        ]
        trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk)
        assert_trainer_init(trainer)

        model = ExtendedBoringModel()

        trainer.test(model)
        assert not trainer.checkpoint_connector.has_trained
        # resume_from_checkpoint is resumed when calling `.fit`
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

        trainer.fit(model)
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs
        assert_checkpoint_log_dir(idx)

        trainer.validate(model)
        assert not trainer.checkpoint_connector.has_trained
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs
        dataloader = DataLoader(self.val_set, batch_size=self.config["training"]["batch_size"],
                                shuffle=False, num_workers=self.config["training"]["num_workers"], drop_last=False)
        return dataloader


if __name__ == "__main__":
    args = parser.parse_args()
    with open(args.conf_file, "r") as f:
        confs = yaml.safe_load(f)

    # test if compatible with lightning
    confs.update(args.__dict__)
    sed = FixMatch(confs)

    checkpoint_dir = os.path.join(confs["log_dir"], 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min',  verbose=True, save_top_k=confs["training"]["save_top_k"])

    #early_stop_callback = EarlyStopping(
     #   monitor='val_loss',
      #  patience=confs["training"]["patience"],
       # verbose=True,
        #mode='min'
    #)

    with open(os.path.join(confs["log_dir"], "confs.yml"), "w") as f:
        yaml.dump(confs, f)

    logger = TensorBoardLogger(os.path.dirname(confs["log_dir"]), confs["log_dir"].split("/")[-1])

    trainer = pl.Trainer(max_nb_epochs=confs["training"]["n_epochs"], gpus=confs["gpus"], checkpoint_callback=checkpoint,
                         accumulate_grad_batches=confs["training"]["accumulate_batches"],
def test_model_checkpoint_format_checkpoint_name(tmpdir):
    # empty filename:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
    assert ckpt_name == 'epoch=3-step=2'

    ckpt_name = ModelCheckpoint._format_checkpoint_name(None,
                                                        3,
                                                        2, {},
                                                        prefix='test')
    assert ckpt_name == 'test-epoch=3-step=2'

    # no groups case:
    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt',
                                                        3,
                                                        2, {},
                                                        prefix='test')
    assert ckpt_name == 'test-ckpt'

    # no prefix
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        '{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
    assert ckpt_name == 'epoch=003-acc=0.03'

    # prefix
    char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}',
                                                        3,
                                                        2, {'acc': 0.03},
                                                        prefix='test')
    assert ckpt_name == 'test@epoch=3,acc=0.03000'
    ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org

    # no dirpath set
    ckpt_name = ModelCheckpoint(monitor='early_stop_on',
                                dirpath=None).format_checkpoint_name(3, 2, {})
    assert ckpt_name == 'epoch=3-step=2.ckpt'
    ckpt_name = ModelCheckpoint(monitor='early_stop_on',
                                dirpath='').format_checkpoint_name(5, 4, {})
    assert ckpt_name == 'epoch=5-step=4.ckpt'

    # CWD
    ckpt_name = ModelCheckpoint(monitor='early_stop_on',
                                dirpath='.').format_checkpoint_name(3, 4, {})
    assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

    # with version
    ckpt = ModelCheckpoint(monitor='early_stop_on',
                           dirpath=tmpdir,
                           filename='name')
    ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
    assert ckpt_name == tmpdir / 'name-v3.ckpt'

    # using slashes
    ckpt = ModelCheckpoint(monitor='early_stop_on',
                           dirpath=None,
                           filename='{epoch}_{val/loss:.5f}')
    ckpt_name = ckpt.format_checkpoint_name(4, 3, {'val/loss': 0.03})
    assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt'

    # auto_insert_metric_name=False
    ckpt_name = ModelCheckpoint._format_checkpoint_name(
        'epoch={epoch:03d}-val_acc={val/acc}',
        3,
        2, {'val/acc': 0.03},
        auto_insert_metric_name=False)
    assert ckpt_name == 'epoch=003-val_acc=0.03'
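
For reference, the same templating expressed as user-facing configuration rather than a direct `_format_checkpoint_name` call; a hedged sketch against the 1.x API used above.

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='{epoch:03d}-{val_loss:.4f}',  # metric names are filled in at save time
    monitor='val_loss',
)
# with auto_insert_metric_name=True (the default), epoch 7 with val_loss=0.1234
# is saved as 'epoch=007-val_loss=0.1234.ckpt'
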
Example #24
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer_options = dict(
        max_epochs=1,
        gpus=2,
        distributed_backend='dp',
        default_root_dir=tmpdir,
    )

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert result == 1, 'dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(dirpath=tmpdir)
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the on_train_start hook so we can run predictions before the model does the full training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        dataloader = trainer.train_dataloader
        tpipes.run_prediction(dataloader, dp_model, dp=True)

    # new model
    model = EvalModelTemplate(**hparams)
    model.on_train_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
def test_none_every_n_train_steps_val_epochs(tmpdir):
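    """Defaults: with neither trigger set, ModelCheckpoint saves once per validation epoch."""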
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
    assert checkpoint_callback.period == 1
    assert checkpoint_callback._every_n_val_epochs == 1
    assert checkpoint_callback._every_n_train_steps == 0
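
A short sketch (assuming the every_n_train_steps / every_n_val_epochs constructor arguments that back the private attributes checked above) which sets the triggers explicitly instead of relying on the defaults:

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch only: save every 500 optimizer steps and disable the epoch-based trigger.
step_ckpt = ModelCheckpoint(
    dirpath='checkpoints',
    every_n_train_steps=500,
    every_n_val_epochs=0,
)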
Example #26
def main():
    # Setup
    config = get_config()
    prepair_dir(config)
    set_seed(config.train.seed)
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(config.base.gpu_id)

    # Preparing for trainer
    base_path = os.path.join(config.save.save_path, f"fold{config.data.fold}")
    model_path = os.path.join(base_path, "model")
    model_name = config.base.yaml
    monitor_metric = "avg_val_loss"
    checkpoint_callback = ModelCheckpoint(
        filepath=model_path,
        save_best_only=True,
        verbose=True,
        monitor=monitor_metric,
        mode="min",
        prefix=model_name,
    )
    logger = TestTubeLogger(
        save_dir=os.path.join(base_path, "logs"),
        name=model_name,
        debug=False,
        create_git_tag=False,
    )
    backend = "ddp" if len(config.base.gpu_id) > 1 else None

    model = Model(config)
    if config.data.is_train:
        trainer = Trainer(
            logger=logger,
            early_stop_callback=False,
            max_nb_epochs=config.train.epoch,
            checkpoint_callback=checkpoint_callback,
            accumulate_grad_batches=config.train.accumulation_steps,
            use_amp=True,
            amp_level="O1",
            gpus=[int(id_) for id_ in config.base.gpu_id],
            distributed_backend=backend,
            show_progress_bar=True,
            train_percent_check=1.0,
            check_val_every_n_epoch=1,
            val_check_interval=1.0,
            val_percent_check=1.0,
            test_percent_check=0.0,
            nb_sanity_val_steps=0,
            nb_gpu_nodes=1,
            print_nan_grads=False,
            track_grad_norm=-1,
            gradient_clip_val=1,
            row_log_interval=1000,
            log_save_interval=10,
        )

        trainer.fit(model)
    else:
        trainer = Trainer(
            logger=False,
            early_stop_callback=False,
            max_nb_epochs=0,
            checkpoint_callback=checkpoint_callback,
            use_amp=True,
            amp_level="O1",
            gpus=[int(id_) for id_ in config.base.gpu_id],
            distributed_backend=backend,
            show_progress_bar=True,
            train_percent_check=0,
            check_val_every_n_epoch=0,
            val_check_interval=0.0,
            val_percent_check=0.0,
            test_percent_check=1.0,
            nb_sanity_val_steps=0,
            nb_gpu_nodes=1,
            print_nan_grads=False,
            track_grad_norm=-1,
        )
        trainer.test(model)
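
The filepath / save_best_only / prefix arguments above come from an older ModelCheckpoint signature; a rough sketch of the same callback using the dirpath / filename / save_top_k names that appear in the tests earlier in this document (the exact mapping is an assumption, and the values below are illustrative stand-ins for model_path, model_name and monitor_metric):

import os
from pytorch_lightning.callbacks import ModelCheckpoint

# Illustrative stand-ins for the config-derived values used in main() above.
model_path = os.path.join("output", "fold0", "model")
model_name = "baseline"
monitor_metric = "avg_val_loss"

# save_best_only=True roughly corresponds to save_top_k=1 in the newer signature;
# prefix/filepath map onto filename/dirpath.
checkpoint_callback = ModelCheckpoint(
    dirpath=model_path,
    filename=model_name + "-{epoch:02d}-{avg_val_loss:.4f}",
    save_top_k=1,
    verbose=True,
    monitor=monitor_metric,
    mode="min",
)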
Example #27
    dict_args = vars(args)

    if "accelerator" in dict_args:
        if dict_args["accelerator"] == "None":
            dict_args["accelerator"] = None

    model = LightningMNISTClassifier(**dict_args)

    dm = MNISTDataModule(**dict_args)
    dm.setup(stage="fit")

    early_stopping = EarlyStopping(
        monitor=dict_args["es_monitor"],
        mode=dict_args["es_mode"],
        verbose=dict_args["es_verbose"],
        patience=dict_args["es_patience"],
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.getcwd(), save_top_k=1, verbose=True, monitor="val_loss", mode="min"
    )
    lr_logger = LearningRateMonitor()

    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[lr_logger, early_stopping, checkpoint_callback],
        checkpoint_callback=True,
    )
    trainer.fit(model, dm)
    trainer.test()
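
The fragment above refers to argparse flags (es_monitor, es_mode, es_verbose, es_patience) and to args defined elsewhere in the original script; a hypothetical parser that would satisfy those references (flag names are taken from the fragment, defaults are guesses) could look like:

from argparse import ArgumentParser
import pytorch_lightning as pl

# Hypothetical reconstruction of the missing argument parsing; only the flags
# referenced in the fragment above are included.
parser = ArgumentParser()
parser.add_argument("--es_monitor", type=str, default="val_loss")
parser.add_argument("--es_mode", type=str, default="min")
parser.add_argument("--es_verbose", action="store_true")
parser.add_argument("--es_patience", type=int, default=3)
parser = pl.Trainer.add_argparse_args(parser)  # adds the standard Trainer flags
args = parser.parse_args()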
Example #28
def test_resume_from_checkpoint_epoch_restored(tmpdir):
    """Verify resuming from checkpoint runs the right number of epochs"""
    import types

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = LightningTestModel(hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, _):
            self.num_batches_seen += 1

        # Bind the increment_epoch function on_epoch_end so that the
        # model keeps track of the number of epochs it has seen.
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_batch_start = types.MethodType(increment_batch, model)
        return model

    model = _new_model()

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.65,
        val_percent_check=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        default_root_dir=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2

    # Only end-of-epoch checkpoints are produced here; mid-epoch resuming is not supported yet
    checkpoints = sorted(
        glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))

    for check in checkpoints:
        next_model = _new_model()
        state = torch.load(check)

        # Resume training
        trainer_options['max_epochs'] = 2
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        # steps already completed at checkpoint time plus batches seen after resuming
        # must add up to the total number of training batches over max_epochs
        expected_total = training_batches * trainer_options['max_epochs']
        assert state['global_step'] + next_model.num_batches_seen == expected_total
Example #29
        else:
            sampler = None
        return DataLoader(val_dataset,
                          shuffle=(sampler is None),
                          sampler=sampler,
                          num_workers=4,
                          batch_size=self.hparams.batch_size,
                          pin_memory=True)


if __name__ == '__main__':
    hparams = get_opts()
    system = MVSSystem(hparams)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join('ckpts', hparams.exp_name),
        monitor='val/acc_2mm',
        mode='max',
        save_top_k=5,
    )

    logger = TestTubeLogger(save_dir="logs",
                            name=hparams.exp_name,
                            debug=False,
                            create_git_tag=False)

    trainer = Trainer(
        max_epochs=hparams.num_epochs,
        checkpoint_callback=checkpoint_callback,
        logger=logger,
        early_stop_callback=None,
        weights_summary=None,
        gpus=hparams.num_gpus,
    )

    trainer.fit(system)
Example #30
def main(args):

    # get current dir as set by hydra
    run_path = os.getcwd()
    print(run_path)

    # Build dataset
    dataset = GPDataset(
        root=args.paths.datadir,
        min_context=args.min_context,
        max_context=args.max_context,
        n_points=args.total_points,
        **args.dataset,
    )

    pl.seed_everything(args.seed)

    steer_cnp = LightningSteerCNP(**args.model)
    datamodule = LightningGPDataModule(dataset,
                                       batch_size=args.batch_size,
                                       splits=args.splits)

    log_dir = os.path.join(args.paths.logdir, dataset.name, steer_cnp.name)
    run_name = args.experiment_name
    run_dir = os.path.join(log_dir, run_name)
    run_dir = interpolate_filename(run_dir)

    loggers = [
        # Log results to a csv file
        CSVLogger(save_dir="", name="logs", version=""),
        # Log data to tensorboard
        TensorBoardLogger(save_dir="", name="logs", version=""),
    ]

    callbacks = [
        # Callback to save recent + best validation checkpoint
        ModelCheckpoint(
            dirpath="checkpoints",
            monitor="val_ll",
            mode="max",
            save_last=True,
        ),
        # Callback to plot inferences on validation
        InferencePlotCallback(n_plots=3, dirpath="plots"),
        # Callback to plot a comparison of the inference against the GP the data was drawn from
        GPComparePlotCallback(
            n_plots=3,
            dirpath="plots",
            kernel=dataset.get_kernel(),
            obs_noise=dataset.obs_noise,
        ),
    ]

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        # default_root_dir=args.paths.logdir,
        logger=loggers,
        callbacks=callbacks,
        log_every_n_steps=args.log_every_n_steps,
        flush_logs_every_n_steps=args.flush_logs_every_n_steps,
        val_check_interval=args.val_check_interval,
        gpus=int(torch.cuda.is_available()),
        log_gpu_memory=args.log_gpu_memory,
    )

    trainer.fit(steer_cnp, datamodule)
    trainer.test(steer_cnp, datamodule=datamodule)