Example #1
def init_checkpoint_callback(logger, path_dir=None):
    exp_path = get_data_path(logger, path_dir=path_dir)
    ckpt_dir = os.path.join(exp_path, 'checkpoints')
    checkpoint = ModelCheckpoint(ckpt_dir)
    return checkpoint
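A minimal usage sketch for a helper like this, assuming a TensorBoardLogger in place of get_data_path and the signature where ModelCheckpoint accepts the target directory as its first positional argument:

import os
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Hypothetical usage: derive the checkpoint directory from the logger's run directory.
logger = TensorBoardLogger(save_dir="logs", name="demo")
ckpt_dir = os.path.join(logger.log_dir, "checkpoints")
checkpoint = ModelCheckpoint(ckpt_dir)  # pass to the Trainer (callbacks=[...] on newer versions, checkpoint_callback=... on older ones)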
Example #2
def single_run(
    config,
    log_dir,
    gpus,
    checkpoint_resume=None,
    test_state_dict=None,
    fast_dev_run=False,
):
    """
    Running the sound event detection baseline

    Args:
        config (dict): the dictionary of configuration params
        log_dir (str): path to log directory
        gpus (int): number of gpus to use
        checkpoint_resume (str, optional): path to checkpoint to resume from. Defaults to None.
        test_state_dict (dict, optional): if not None, no training is involved. This dictionary is the state_dict
            to be loaded to test the model.
        fast_dev_run (bool, optional): whether to use a run with only one batch at train and validation, useful
            for development purposes.
    """
    config.update({"log_dir": log_dir})

    ##### data prep test ##########
    encoder = ManyHotEncoder(
        list(classes_labels.keys()),
        audio_len=config["data"]["audio_max_len"],
        frame_len=config["feats"]["n_filters"],
        frame_hop=config["feats"]["hop_length"],
        net_pooling=config["data"]["net_subsample"],
        fs=config["data"]["fs"],
    )

    devtest_df = pd.read_csv(config["data"]["test_tsv"], sep="\t")
    devtest_dataset = StronglyAnnotatedSet(
        config["data"]["test_folder"],
        devtest_df,
        encoder,
        return_filename=True,
        pad_to=config["data"]["audio_max_len"],
    )

    test_dataset = devtest_dataset

    ##### model definition  ############
    sed_student = CRNN(**config["net"])

    if test_state_dict is None:
        ##### data prep train valid ##########
        synth_df = pd.read_csv(config["data"]["synth_tsv"], sep="\t")
        synth_set = StronglyAnnotatedSet(
            config["data"]["synth_folder"],
            synth_df,
            encoder,
            pad_to=config["data"]["audio_max_len"],
        )

        weak_df = pd.read_csv(config["data"]["weak_tsv"], sep="\t")
        train_weak_df = weak_df.sample(frac=config["training"]["weak_split"],
                                       random_state=config["training"]["seed"])
        valid_weak_df = weak_df.drop(
            train_weak_df.index).reset_index(drop=True)
        train_weak_df = train_weak_df.reset_index(drop=True)
        weak_set = WeakSet(
            config["data"]["weak_folder"],
            train_weak_df,
            encoder,
            pad_to=config["data"]["audio_max_len"],
        )

        unlabeled_set = UnlabelledSet(
            config["data"]["unlabeled_folder"],
            encoder,
            pad_to=config["data"]["audio_max_len"],
        )

        synth_df_val = pd.read_csv(config["data"]["synth_val_tsv"], sep="\t")
        synth_val = StronglyAnnotatedSet(
            config["data"]["synth_val_folder"],
            synth_df_val,
            encoder,
            return_filename=True,
            pad_to=config["data"]["audio_max_len"],
        )

        weak_val = WeakSet(
            config["data"]["weak_folder"],
            valid_weak_df,
            encoder,
            pad_to=config["data"]["audio_max_len"],
            return_filename=True,
        )

        tot_train_data = [synth_set, weak_set, unlabeled_set]
        train_dataset = torch.utils.data.ConcatDataset(tot_train_data)

        batch_sizes = config["training"]["batch_size"]
        samplers = [torch.utils.data.RandomSampler(x) for x in tot_train_data]
        batch_sampler = ConcatDatasetBatchSampler(samplers, batch_sizes)

        valid_dataset = torch.utils.data.ConcatDataset([synth_val, weak_val])

        ##### training params and optimizers ############
        epoch_len = min([
            len(tot_train_data[indx]) //
            (config["training"]["batch_size"][indx] *
             config["training"]["accumulate_batches"])
            for indx in range(len(tot_train_data))
        ])

        opt = torch.optim.Adam(sed_student.parameters(),
                               1e-3,
                               betas=(0.9, 0.999))
        exp_steps = config["training"]["n_epochs_warmup"] * epoch_len
        exp_scheduler = {
            "scheduler": ExponentialWarmup(opt, config["opt"]["lr"],
                                           exp_steps),
            "interval": "step",
        }
        logger = TensorBoardLogger(
            os.path.dirname(config["log_dir"]),
            config["log_dir"].split("/")[-1],
        )
        print(f"experiment dir: {logger.log_dir}")

        callbacks = [
            EarlyStopping(monitor="val/obj_metric",
                          patience=config["training"]["early_stop_patience"],
                          verbose=True,
                          mode="max"),
            ModelCheckpoint(logger.log_dir,
                            monitor="val/obj_metric",
                            save_top_k=1,
                            mode="max",
                            save_last=True),
        ]
    else:
        train_dataset = None
        valid_dataset = None
        batch_sampler = None
        opt = None
        exp_scheduler = None
        logger = True
        callbacks = None

    desed_training = SEDTask4_2021(
        config,
        encoder=encoder,
        sed_student=sed_student,
        opt=opt,
        train_data=train_dataset,
        valid_data=valid_dataset,
        test_data=test_dataset,
        train_sampler=batch_sampler,
        scheduler=exp_scheduler,
        fast_dev_run=fast_dev_run,
    )

    # Not using the Trainer's fast_dev_run because it creates a DummyLogger, so problems with the logger cannot be checked
    if fast_dev_run:
        flush_logs_every_n_steps = 1
        log_every_n_steps = 1
        limit_train_batches = 2
        limit_val_batches = 2
        limit_test_batches = 2
        n_epochs = 3
    else:
        flush_logs_every_n_steps = 100
        log_every_n_steps = 40
        limit_train_batches = 1.
        limit_val_batches = 1.
        limit_test_batches = 1.
        n_epochs = config["training"]["n_epochs"]

    trainer = pl.Trainer(
        max_epochs=n_epochs,
        callbacks=callbacks,
        gpus=gpus,
        distributed_backend=config["training"].get("backend"),
        accumulate_grad_batches=config["training"]["accumulate_batches"],
        logger=logger,
        resume_from_checkpoint=checkpoint_resume,
        gradient_clip_val=config["training"]["gradient_clip"],
        check_val_every_n_epoch=config["training"]["validation_interval"],
        num_sanity_val_steps=0,
        log_every_n_steps=log_every_n_steps,
        flush_logs_every_n_steps=flush_logs_every_n_steps,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
    )

    if test_state_dict is None:
        trainer.fit(desed_training)
        best_path = trainer.checkpoint_callback.best_model_path
        print(f"best model: {best_path}")
        test_state_dict = torch.load(best_path)["state_dict"]

    desed_training.load_state_dict(test_state_dict)
    trainer.test(desed_training)
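A hedged sketch of how single_run is typically driven from the command line; the argument names and the YAML path below are assumptions, not part of the original:

import argparse

import yaml

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Sound event detection baseline")
    parser.add_argument("--conf_file", default="./confs/sed.yaml")  # hypothetical default
    parser.add_argument("--log_dir", default="./exp/sed")
    parser.add_argument("--gpus", type=int, default=1)
    parser.add_argument("--resume_from_checkpoint", default=None)
    parser.add_argument("--fast_dev_run", action="store_true")
    args = parser.parse_args()

    with open(args.conf_file, "r") as f:
        configs = yaml.safe_load(f)

    single_run(
        configs,
        args.log_dir,
        args.gpus,
        checkpoint_resume=args.resume_from_checkpoint,
        fast_dev_run=args.fast_dev_run,
    )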
Example #3
def train_task(init,
               close,
               exp_cfg_path,
               env_cfg_path,
               task_nr,
               logger_pass=None):
    seed_everything(42)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if local_rank != 0 or not init:
        print(init, local_rank)
        rm = exp_cfg_path.find('cfg/exp/') + len('cfg/exp/')
        exp_cfg_path = os.path.join(exp_cfg_path[:rm], 'tmp/',
                                    exp_cfg_path[rm:])

    exp = load_yaml(exp_cfg_path)
    env = load_yaml(env_cfg_path)

    if local_rank == 0 and init:
        # Set the correct model path in the name
        if exp.get('timestamp', True):
            timestamp = datetime.datetime.now().replace(
                microsecond=0).isoformat()

            model_path = os.path.join(env['base'], exp['name'])
            p = model_path.split('/')
            model_path = os.path.join('/', *p[:-1],
                                      str(timestamp) + '_' + p[-1])
        else:
            model_path = os.path.join(env['base'], exp['name'])
            try:
                shutil.rmtree(model_path)
            except:
                pass
        # Create the directory
        if not os.path.exists(model_path):
            try:
                os.makedirs(model_path)
            except:
                print("Failed generating network run folder")
        else:
            print("Network run folder already exits")

        # Only copy config files for the main ddp-task
        exp_cfg_fn = os.path.split(exp_cfg_path)[-1]
        env_cfg_fn = os.path.split(env_cfg_path)[-1]
        print(f'Copy {env_cfg_path} to {model_path}/{exp_cfg_fn}')
        shutil.copy(exp_cfg_path, f'{model_path}/{exp_cfg_fn}')
        shutil.copy(env_cfg_path, f'{model_path}/{env_cfg_fn}')
        exp['name'] = model_path
    else:
        # the correct model path has already been written to the yaml file.
        model_path = os.path.join(exp['name'], f'rank_{local_rank}_{task_nr}')
        # Create the directory
        if not os.path.exists(model_path):
            try:
                os.makedirs(model_path)
            except:
                pass

    # if local_rank == 0 and env['workstation'] == False:
    #     cm = open(os.path.join(model_path, f'info{local_rank}_{task_nr}.log'), 'w')
    # else:
    #     cm = nullcontext()
    # with cm as f:
    #   if local_rank == 0 and env['workstation'] == False:
    #     cm2 = redirect_stdout(f)
    #   else:
    #     cm2 = nullcontext()
    #   with cm2:
    # # Setup logger for each ddp-task
    # logging.getLogger("lightning").setLevel(logging.DEBUG)
    # logger = logging.getLogger("lightning")
    # fh = logging.FileHandler( , 'a')
    # logger.addHandler(fh)

    # Copy dataset from scratch to the node's SSD

    if env['workstation'] == False:
        # use proxy hack for neptune.ai
        NeptuneLogger._create_or_get_experiment = _create_or_get_experiment2

        # move data to ssd
        if exp['move_datasets'][0]['env_var'] != 'none':
            for dataset in exp['move_datasets']:
                scratchdir = os.getenv('TMPDIR')
                env_var = dataset['env_var']
                tar = os.path.join(env[env_var], f'{env_var}.tar')
                name = (tar.split('/')[-1]).split('.')[0]

                if not os.path.exists(
                        os.path.join(scratchdir, dataset['env_var'])):

                    try:
                        cmd = f"tar -xvf {tar} -C $TMPDIR >/dev/null 2>&1"
                        st = time.time()
                        print(f'Start moving dataset-{env_var}: {cmd}')
                        os.system(cmd)
                        env[env_var] = str(os.path.join(scratchdir, name))
                        print(
                            f'Finished moving dataset-{env_var} in {time.time()-st}s'
                        )

                    except:
                        rank_zero_warn('ENV Var ' + env_var)
                        env[env_var] = str(os.path.join(scratchdir, name))
                        rank_zero_warn('Copying data failed')
                else:
                    env[env_var] = str(os.path.join(scratchdir, name))
        else:
            env['mlhypersim'] = str(
                os.path.join(env['mlhypersim'], 'mlhypersim'))

    if (exp['trainer']).get('gpus', -1):
        nr = torch.cuda.device_count()
        exp['trainer']['gpus'] = nr
        print(f'Set GPU Count for Trainer to {nr}!')

    model = Network(exp=exp, env=env)

    lr_monitor = LearningRateMonitor(**exp['lr_monitor']['cfg'])

    if exp['cb_early_stopping']['active']:
        early_stop_callback = EarlyStopping(**exp['cb_early_stopping']['cfg'])
        cb_ls = [early_stop_callback, lr_monitor]
    else:
        cb_ls = [lr_monitor]

    tses = TaskSpecificEarlyStopping(
        nr_tasks=exp['task_generator']['total_tasks'],
        **exp['task_specific_early_stopping'])
    cb_ls.append(tses)
    if local_rank == 0:
        for i in range(exp['task_generator']['total_tasks']):
            if i == task_nr:
                m = '/'.join(
                    [a for a in model_path.split('/') if a.find('rank') == -1])

                dic = copy.deepcopy(exp['cb_checkpoint']['cfg'])
                # try:
                #   if len(exp['cb_checkpoint'].get('nameing',[])) > 0:
                #     #filepath += '-{task_name:10s}'
                #     for m in exp['cb_checkpoint']['nameing']:
                #       filepath += '-{'+ m + ':.2f}'
                # except:
                #   pass
                # dic['monitor'] += str(i)
                checkpoint_callback = ModelCheckpoint(
                    dirpath=m,
                    filename='task' + str(i) + '-{epoch:02d}--{step:06d}',
                    **dic)

                cb_ls.append(checkpoint_callback)

    params = log_important_params(exp)

    if env['workstation']:
        t1 = 'workstation'
    else:
        t1 = 'leonhard'

    # if local_rank == 0:
    cwd = os.getcwd()
    files = [
        str(p).replace(cwd + '/', '') for p in Path(cwd).rglob('*.py')
        if str(p).find('vscode') == -1
    ]
    files.append(exp_cfg_path)
    files.append(env_cfg_path)

    if not exp.get('offline_mode', False):
        # if exp.get('experiment_id',-1) == -1:
        #create new experiment_id and write back
        if logger_pass is None:
            logger = NeptuneLogger(
                api_key=os.environ["NEPTUNE_API_TOKEN"],
                project_name="jonasfrey96/asl",
                experiment_name=exp['name'].split('/')[-2] + "_" +
                exp['name'].split('/')[-1],  # Optional,
                params=params,  # Optional,
                tags=[
                    t1, exp['name'].split('/')[-2], exp['name'].split('/')[-1]
                ] + exp["tag_list"],  # Optional,
                close_after_fit=False,
                offline_mode=exp.get('offline_mode', False),
                upload_source_files=files,
                upload_stdout=False,
                upload_stderr=False)
            exp['experiment_id'] = logger.experiment.id
            print('created experiment id ' + str(exp['experiment_id']))
        else:
            logger = logger_pass

        # else:
        # print('loaded experiment id' +  str( exp['experiment_id']))
        # TODO
        # logger = NeptuneLogger(
        #   api_key=os.environ["NEPTUNE_API_TOKEN"],
        #   project_name="jonasfrey96/asl",
        #   experiment_name= exp['name'].split('/')[-2] +"_"+ exp['name'].split('/')[-1], # Optional,
        #   params=params, # Optional,
        #   tags=[t1, exp['name'].split('/')[-2], exp['name'].split('/')[-1]] + exp["tag_list"], # Optional,
        #   close_after_fit = False,
        #   offline_mode = exp.get('offline_mode', False),
        #   upload_source_files=files,
        #   upload_stdout=False,
        #   upload_stderr=False
        # )

        # logger = NeptuneLogger(
        #   api_key=os.environ["NEPTUNE_API_TOKEN"],
        #   project_name="jonasfrey96/asl",
        #   experiment_id=exp.get('experiment_id',-1),
        #   close_after_fit = False,
        # )
        print('Neptune Experiment ID: ' + str(logger.experiment.id) +
              " TASK NR " + str(task_nr))
    else:
        logger = TensorBoardLogger(
            save_dir=model_path,
            name='tensorboard',  # Optional,
            default_hp_metric=params,  # Optional,
        )
    # else:
    #   logger = TensorBoardLogger(
    #       save_dir=model_path+'/rank/'+str(local_rank),
    #       name= exp['name'].split('/')[-2] +"_"+ exp['name'].split('/')[-1], # Optional,
    #   )

    weight_restore = exp.get('weights_restore', False)
    checkpoint_load = exp['checkpoint_load']

    if local_rank == 0 and init:
        # write back the exp file with the correct name set to the model_path!
        # other ddp-tasks don't need to care about timestamps
        # also storing the path to last.ckpt so that downstream tasks can restore the model state
        exp['weights_restore_2'] = False
        exp['checkpoint_restore_2'] = True
        exp['checkpoint_load_2'] = os.path.join(model_path, 'last.ckpt')

        rm = exp_cfg_path.find('cfg/exp/') + len('cfg/exp/')
        exp_cfg_path = os.path.join(exp_cfg_path[:rm], 'tmp/',
                                    exp_cfg_path[rm:])
        Path(exp_cfg_path).parent.mkdir(parents=True, exist_ok=True)
        with open(exp_cfg_path, 'w+') as f:
            yaml.dump(exp, f, default_flow_style=False, sort_keys=False)

    if not init:
        # restore model state from previous task.
        exp['checkpoint_restore'] = exp['checkpoint_restore_2']
        exp['checkpoint_load'] = exp['checkpoint_load_2']
        exp['weights_restore'] = exp['weights_restore_2']

    # When profiling is enabled, always use the advanced profiler
    if exp['trainer'].get('profiler', False):
        exp['trainer']['profiler'] = AdvancedProfiler(
            output_filename=os.path.join(model_path, 'profile.out'))
    else:
        exp['trainer']['profiler'] = False

    # print( exp['trainer'] )
    # print(os.environ.get('GLOBAL_RANK'))
    if exp.get('checkpoint_restore', False):
        p = os.path.join(env['base'], exp['checkpoint_load'])
        trainer = Trainer(**exp['trainer'],
                          default_root_dir=model_path,
                          callbacks=cb_ls,
                          resume_from_checkpoint=p,
                          logger=logger)
    else:
        trainer = Trainer(**exp['trainer'],
                          default_root_dir=model_path,
                          callbacks=cb_ls,
                          logger=logger)

    if exp['weights_restore']:
        # it is not strict since the latent replay buffer is not always available
        p = os.path.join(env['base'], exp['checkpoint_load'])
        if os.path.isfile(p):
            res = model.load_state_dict(torch.load(
                p, map_location=lambda storage, loc: storage)['state_dict'],
                                        strict=False)
            print('Restoring weights: ' + str(res))
        else:
            raise Exception('Checkpoint not a file')

    main_visu = MainVisualizer(p_visu=os.path.join(model_path, 'main_visu'),
                               logger=logger,
                               epoch=0,
                               store=True,
                               num_classes=22)

    tc = TaskCreator(**exp['task_generator'],
                     output_size=exp['model']['input_size'])
    print(tc)
    _task_start_training = time.time()
    _task_start_time = time.time()

    for idx, out in enumerate(tc):
        if idx == task_nr:
            break

    if True:
        #for idx, out in enumerate(tc):
        task, eval_lists = out
        main_visu.epoch = idx
        # New Logger
        print(f'<<<<<<<<<<<< TASK IDX {idx} TASK NAME : ' + task.name +
              ' >>>>>>>>>>>>>')

        model._task_name = task.name
        model._task_count = idx
        dataloader_train, dataloader_buffer = get_dataloader_train(
            d_train=task.dataset_train_cfg, env=env, exp=exp)
        print(str(dataloader_train.dataset))
        print(str(dataloader_buffer.dataset))
        dataloader_list_test = eval_lists_into_dataloaders(eval_lists,
                                                           env=env,
                                                           exp=exp)
        print(f'<<<<<<<<<<<< All Datasets are loaded and set up >>>>>>>>>>>>>')
        #Training the model
        trainer.should_stop = False
        # print("GLOBAL STEP ", model.global_step)
        for d in dataloader_list_test:
            print(str(d.dataset))

        if idx < exp['start_at_task']:
            # trainer.limit_val_batches = 1.0
            trainer.limit_train_batches = 1
            trainer.max_epochs = 1
            trainer.check_val_every_n_epoch = 1
            train_res = trainer.fit(model=model,
                                    train_dataloader=dataloader_train,
                                    val_dataloaders=dataloader_list_test)

            trainer.max_epochs = exp['trainer']['max_epochs']
            trainer.check_val_every_n_epoch = exp['trainer'][
                'check_val_every_n_epoch']
            trainer.limit_val_batches = exp['trainer']['limit_val_batches']
            trainer.limit_train_batches = exp['trainer']['limit_train_batches']
        else:
            print('Train', dataloader_train)
            print('Val', dataloader_list_test)
            train_res = trainer.fit(model=model,
                                    train_dataloader=dataloader_train,
                                    val_dataloaders=dataloader_list_test)
        res = trainer.logger_connector.callback_metrics
        res_store = {}
        for k in res.keys():
            try:
                res_store[k] = float(res[k])
            except:
                pass
        base_path = '/'.join(
            [a for a in model_path.split('/') if a.find('rank') == -1])
        with open(f"{base_path}/res{task_nr}.pkl", "wb") as f:
            pickle.dump(res_store, f)

        print(f'<<<<<<<<<<<< TASK IDX {idx} TASK NAME : ' + task.name +
              ' Trained >>>>>>>>>>>>>')

        if exp.get('buffer', {}).get('fill_after_fit', False):
            print(f'<<<<<<<<<<<< Performance Test to Get Buffer >>>>>>>>>>>>>')

            trainer.test(model=model, test_dataloaders=dataloader_buffer)

            if local_rank == 0:
                checkpoint_callback.save_checkpoint(trainer, model)
            print(f'<<<<<<<<<<<< Performance Test DONE >>>>>>>>>>>>>')

        number_validation_dataloaders = len(dataloader_list_test)

        if model._rssb_active:
            # visualize rssb
            bins, valids = model._rssb.get()
            fill_status = (bins != 0).sum(axis=1)
            main_visu.plot_bar(fill_status,
                               x_label='Bin',
                               y_label='Filled',
                               title='Fill Status per Bin',
                               sort=False,
                               reverse=False,
                               tag='Buffer_Fill_Status')

        plot_from_pkl(main_visu, base_path, task_nr)

    try:
        if close:
            logger.experiment.stop()
    except:
        pass
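The weights_restore branch above loads weights non-strictly, so optional entries such as the latent replay buffer may be missing from the checkpoint; a self-contained sketch of that pattern with placeholder objects:

import torch

# Stand-ins for Network(exp=exp, env=env) and torch.load(p, ...)['state_dict'].
model = torch.nn.Linear(4, 2)
state_dict = torch.nn.Linear(4, 2).state_dict()
res = model.load_state_dict(state_dict, strict=False)
print('Restoring weights: ' + str(res))  # reports missing and unexpected keys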
Example #4
    dm = MultiLabelImageClassificationDatamodule(16, train_transform,
                                                 val_transform)
    model = MultiLabelImageClassifier(pos_weight)

    # Fast run first
    trainer = Trainer(gpus=1,
                      fast_dev_run=True,
                      checkpoint_callback=False,
                      logger=False)
    trainer.fit(model, dm)

    checkpoint_callback = ModelCheckpoint(
        filepath=os.getcwd(),
        save_top_k=2,
        verbose=True,
        monitor="val/loss",
        mode="min",
    )

    experiment_name = ...
    PROJECT_NAME = ...

    logger = WandbLogger(name=experiment_name, project=PROJECT_NAME)

    # And then actual training
    pl.seed_everything(42)
    trainer = Trainer(
        max_epochs=40,
        logger=logger,
        gpus=1,
Example #5
def main(conf):
    exp_dir = conf["main_args"]["exp_dir"]
    # Define Dataloader
    train_loader, val_loader = make_dataloaders(**conf["data"],
                                                **conf["training"])
    conf["masknet"].update({"n_src": conf["data"]["n_src"]})
    # Define model, optimizer + scheduler
    model, optimizer = make_model_and_optimizer(conf)
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)

    # Save config
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define loss function
    loss_func = ChimeraLoss(alpha=conf["training"]["loss_alpha"])
    # Put together in System
    system = ChimeraSystem(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=1)
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=1)
    gpus = -1
    # Don't ask GPU if they are not available.
    if not torch.cuda.is_available():
        print("No available GPU were found, set gpus to None")
        gpus = None

    # Train model
    trainer = pl.Trainer(
        max_nb_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend="dp",
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=200,
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    # Save last model for convenience
    torch.save(system.model.state_dict(),
               os.path.join(exp_dir, "checkpoints/final.pth"))
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    if not tutils.can_run_gpu_test():
        return

    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        show_progress_bar=True,
        max_epochs=2,
        gpus=2,
        distributed_backend='dp',
    )

    # get logger
    logger = tutils.get_test_tube_logger(tmpdir, debug=False)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['checkpoint_callback'] = checkpoint

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    result = trainer.fit(model)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # correct result and ok accuracy
    assert result == 1, 'amp + dp model failed to complete'

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_test_tube_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
    trainer_options['train_percent_check'] = 0.2
    trainer_options['val_percent_check'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert new_trainer.current_epoch == real_global_epoch and new_trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        dp_model = new_trainer.model
        dp_model.eval()

        dataloader = trainer.get_train_dataloader()
        tutils.run_prediction(dataloader, dp_model, dp=True)

    # new model
    model = LightningTestModel(hparams)
    model.on_train_start = assert_good_acc

    # fit new model which should load hpc weights
    new_trainer.fit(model)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()
Example #7
        return self.transform(img), self.transform(img)


if __name__ == "__main__":
    ds = ImagesDataset(args.image_folder, IMAGE_SIZE)

    train_loader = DataLoader(ds,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True,
                              drop_last=True)

    cl = ContrastiveLearning()

    logger = TensorBoardLogger(save_dir='lightning_logs', name='logs')

    checkpoint_callback = ModelCheckpoint(period=10, save_top_k=-1)

    trainer = pl.Trainer(gpus=NUM_GPUS,
                         distributed_backend='ddp',
                         max_epochs=EPOCHS,
                         accumulate_grad_batches=1,
                         callbacks=[checkpoint_callback],
                         logger=logger
                         # resume_from_checkpoint = '*.ckpt'
                         )

    trainer.sync_batchnorm = True
    trainer.fit(cl, train_loader)
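The period argument in the snippet above was deprecated in later pytorch_lightning releases in favour of every_n_epochs (the exact version boundary is an assumption here); a roughly equivalent callback on newer versions:

from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint every 10 epochs and keep all of them, matching period=10, save_top_k=-1.
checkpoint_callback = ModelCheckpoint(every_n_epochs=10, save_top_k=-1)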
def test_none_every_n_train_steps_val_epochs(tmpdir):
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
    assert checkpoint_callback.every_n_epochs == 1
    assert checkpoint_callback._every_n_train_steps == 0
def test_model_checkpoint_score_and_ckpt(
    tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool
):
    """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and
    checkpoint data."""
    max_epochs = 3
    limit_train_batches = 5
    limit_val_batches = 7
    lr, gamma = 1e-1, 2

    class CustomBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.train_log_epochs = torch.randn(max_epochs, limit_train_batches)
            self.val_logs = torch.randn(max_epochs, limit_val_batches)
            self.scores = []

        def training_step(self, batch, batch_idx):
            log_value = self.train_log_epochs[self.current_epoch, batch_idx]
            self.log("train_log", log_value, on_epoch=True)
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            log_value = self.val_logs[self.current_epoch, batch_idx]
            self.log("val_log", log_value)
            self.log("epoch", self.current_epoch, on_epoch=True)
            return super().validation_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=lr)

            if reduce_lr_on_plateau:
                lr_scheduler = {
                    "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                    "monitor": monitor,
                    "strict": True,
                }
            else:
                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

            return [optimizer], [lr_scheduler]

        def on_train_epoch_end(self):
            if "train" in monitor:
                self.scores.append(self.trainer.logged_metrics[monitor])

        def on_validation_epoch_end(self):
            if not self.trainer.sanity_checking and "val" in monitor:
                self.scores.append(self.trainer.logged_metrics[monitor])

    filename = "{" + f"{monitor}" + ":.4f}-{epoch}"
    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

    model = CustomBoringModel()

    if validation_step_none:
        model.validation_step = None
    if val_dataloaders_none:
        model.val_dataloaders = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        enable_progress_bar=False,
    )
    calls = mock_training_epoch_loop(trainer)
    trainer.fit(model)

    ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
    assert len(ckpt_files) == len(model.scores) == max_epochs

    for epoch in range(max_epochs):
        score = model.scores[epoch]
        expected_score = getattr(model, f"{monitor}s")[epoch].mean().item()
        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
        assert math.isclose(score, expected_score, rel_tol=1e-4)

        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
        assert chk["epoch"] == epoch + 1
        assert chk["global_step"] == limit_train_batches * (epoch + 1)

        mc_specific_data = chk["callbacks"][
            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
        ]
        assert mc_specific_data["dirpath"] == checkpoint.dirpath
        assert mc_specific_data["monitor"] == monitor
        assert mc_specific_data["current_score"] == score

        if not reduce_lr_on_plateau:
            actual_step_count = chk["lr_schedulers"][0]["_step_count"]
            actual_lr = chk["lr_schedulers"][0]["_last_lr"][0]
            # checkpoint is saved after updating lr_scheduler states
            assert actual_step_count == epoch + 2  # step_count starts at 1
            assert actual_lr == lr * gamma ** (epoch + 1)
        else:
            assert calls[epoch] == {monitor: score}
def test_model_checkpoint_score_and_ckpt_val_check_interval(
    tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned
):
    """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and
    checkpoint data with val_check_interval."""
    seed_everything(0)
    max_epochs = 3
    limit_train_batches = 12
    limit_val_batches = 7
    lr, gamma = 1e-1, 2
    monitor = "val_log"
    per_val_train_batches = int(limit_train_batches * val_check_interval)
    per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)

    class CustomBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
            self.val_loop_count = 0
            self.scores = []

        def validation_step(self, batch, batch_idx):
            log_value = self.val_logs[self.val_loop_count, batch_idx]
            self.log("val_log", log_value)
            return super().validation_step(batch, batch_idx)

        def validation_epoch_end(self, outputs):
            self.val_loop_count += 1
            super().validation_epoch_end(outputs)
            self.scores.append(self.trainer.logged_metrics[monitor])

        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=lr)

            if reduce_lr_on_plateau:
                lr_scheduler = {
                    "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                    "monitor": monitor,
                    "strict": True,
                }
            else:
                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

            return [optimizer], [lr_scheduler]

    filename = "{" + f"{monitor}" + ":.4f}-{epoch}"
    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

    model = CustomBoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        val_check_interval=val_check_interval,
        enable_progress_bar=False,
        num_sanity_val_steps=0,
    )
    calls = mock_training_epoch_loop(trainer)
    trainer.fit(model)

    def _make_assertions(epoch, ix):
        global_ix = ix + per_epoch_val_checks * epoch

        # checkpoint saved at the end of training epoch will have updated lr_scheduler states
        epoch_end_checkpoint = epoch_aligned and ix == (per_epoch_val_checks - 1)

        score = model.scores[global_ix]
        expected_score = getattr(model, f"{monitor}s")[global_ix].mean().item()
        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
        assert math.isclose(score, expected_score, rel_tol=1e-4)

        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
        assert chk["epoch"] == epoch + 1
        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch)
        assert chk["global_step"] == expected_global_step

        mc_specific_data = chk["callbacks"][
            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
            " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
        ]
        assert mc_specific_data["dirpath"] == checkpoint.dirpath
        assert mc_specific_data["monitor"] == monitor
        assert mc_specific_data["current_score"] == score

        if not reduce_lr_on_plateau:
            actual_step_count = chk["lr_schedulers"][0]["_step_count"]
            actual_lr = chk["lr_schedulers"][0]["_last_lr"][0]
            assert actual_step_count == epoch + 1 + epoch_end_checkpoint
            assert actual_lr == lr * gamma ** (epoch + epoch_end_checkpoint)

        return score

    ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
    assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs

    for epoch in range(max_epochs):
        for i in range(per_epoch_val_checks):
            score = _make_assertions(epoch, i)

        if reduce_lr_on_plateau:
            assert calls[epoch] == {monitor: score}
def test_invalid_top_k(tmpdir):
    """Make sure that a MisconfigurationException is raised for a negative save_top_k argument."""
    with pytest.raises(MisconfigurationException, match=r".*Must be >= -1"):
        ModelCheckpoint(dirpath=tmpdir, save_top_k=-3)
def test_model_checkpoint_mode_options():
    with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"):
        ModelCheckpoint(mode="unknown_option")
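For reference, a minimal sketch of valid mode choices for the checks above; the monitored metric names are placeholders:

from pytorch_lightning.callbacks import ModelCheckpoint

ModelCheckpoint(monitor="val_loss", mode="min")  # keep the checkpoint with the lowest val_loss
ModelCheckpoint(monitor="val_acc", mode="max")   # keep the checkpoint with the highest val_acc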
Example #13
    vocab_path = args.vocab_path
    max_length = args.max_length
    batch_size = args.batch_size
    epochs = args.epochs
    output_path = args.output_dir
    eval_interval = args.eval_interval
    lr = args.lr
    warmup_steps = args.warmup_steps
    data_path = args.data_path
    config_path = args.config_path
    t_total = args.t_total

    checkpoint_callback = ModelCheckpoint(
        dirpath=output_path,
        verbose=True,
#         period=1,
        save_top_k=1,
        monitor="val_loss",
        mode="min",
    )
    learning_rate_callback = LearningRateMonitor()
    trainer = pl.Trainer(
        default_root_dir=output_path,
        gradient_clip_val=1,
        max_epochs=epochs,
        gpus=args.device,
#         distributed_backend="dp",
        val_check_interval=eval_interval,
        callbacks=[learning_rate_callback, checkpoint_callback],
        precision=32,
    )
    net = Net(
Example #14
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath):
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated losses
    save_dir = tmp_path / "1"
    save_dir.mkdir()
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1  (all)
    w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(
        losses), "Should save all models when save_top_k=-1"

    # verify correct naming
    for i in range(0, len(losses)):
        assert f'_ckpt_epoch_{i}.ckpt' in file_lists

    save_dir = tmp_path / "2"
    save_dir.mkdir()

    # -----------------
    # CASE K=0 (none)
    w = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    save_dir = tmp_path / "3"
    save_dir.mkdir()

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    w = ModelCheckpoint(save_dir,
                        save_top_k=1,
                        verbose=1,
                        prefix='test_prefix')
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    save_dir = tmp_path / "4"
    save_dir.mkdir()

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted

    w = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f'{save_dir}/other_file.ckpt', 'a').close()
    w.save_function = mock_save_function
    for i, loss in enumerate(losses):
        w.on_epoch_end(i, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 2 models when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    save_dir = tmp_path / "5"
    save_dir.mkdir()

    # -----------------
    # CASE K=4 (save all 4 models)
    # multiple checkpoints within same epoch

    w = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    w.save_function = mock_save_function
    for loss in losses:
        w.on_epoch_end(0, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(
        file_lists
    ) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    save_dir = tmp_path / "6"
    save_dir.mkdir()

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th model)
    # multiple checkpoints within same epoch

    w = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    w.save_function = mock_save_function
    for loss in losses:
        w.on_epoch_end(0, logs={'val_loss': loss})

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists
Example #15
def init_model(hyper_params, num_workers, output, validation_interval, gpu,
               no_model_restoring, debug, print_sentence_stats):

    check_and_log_hp(
        [
            "train_file", "dev_files", "test_file", "batch_size",
            "tokenizer_name", "model", "max_question_len", "max_paragraph_len",
            "patience", "gradient_clipping", "max_epochs", "loss_type",
            "optimizer", "precision", "accumulate_grad_batches", "seed",
            "logging", "keep_ood"
        ],
        hyper_params,
    )

    if hyper_params["seed"] is not None:
        # fix the seed
        torch.manual_seed(hyper_params["seed"])
        np.random.seed(hyper_params["seed"])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer_name = hyper_params["tokenizer_name"]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    ret = load_model(hyper_params, tokenizer, debug)

    os.makedirs(output, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output,
                              "{epoch}-{val_acc_0:.2f}-{val_loss_0:.2f}"),
        save_top_k=1,
        verbose=True,
        monitor="val_acc_0",
        mode="max",
        period=0,
    )
    early_stopping = EarlyStopping("val_acc_0",
                                   mode="max",
                                   patience=hyper_params["patience"])

    if (hyper_params["model"].get("name") == "bert_encoder"
            and hyper_params["model"].get("cache_size", 0) > 0):
        cbs = [CacheManagerCallback(ret, output)]
    else:
        cbs = []

    if hyper_params["precision"] not in {16, 32}:
        raise ValueError("precision should be either 16 or 32")
    if not no_model_restoring:
        ckpt_to_resume = try_to_restore_model_weights(output)
    else:
        ckpt_to_resume = None
        logger.info(
            "will not try to restore previous models because --no-model-restoring"
        )
    if hyper_params["logging"]["logger"] == "tensorboard":
        pl_logger = loggers.TensorBoardLogger("experiment_logs")
        for hparam in list(hyper_params):
            pl_logger.experiment.add_text(hparam, str(hyper_params[hparam]))
    elif hyper_params["logging"]["logger"] == "wandb":
        orion_trial_id = os.environ.get('ORION_TRIAL_ID')
        name = orion_trial_id if orion_trial_id else hyper_params["logging"][
            "name"]
        pl_logger = WandbLogger(
            name=name,
            project=hyper_params["logging"]["project"],
            group=hyper_params["logging"]["group"],
        )
        pl_logger.log_hyperparams(hyper_params)
    else:
        raise ValueError(
            logger.info("logger {} is not implemnted".format(
                hyper_params["logging"]["logger"])))

    trainer = pl.Trainer(
        logger=pl_logger,
        gpus=gpu,
        distributed_backend="dp",
        val_check_interval=validation_interval,
        min_epochs=1,
        gradient_clip_val=hyper_params["gradient_clipping"],
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping,
        callbacks=cbs,
        precision=hyper_params["precision"],
        resume_from_checkpoint=ckpt_to_resume,
        accumulate_grad_batches=hyper_params["accumulate_grad_batches"],
        max_epochs=hyper_params["max_epochs"],
    )

    dev_dataloaders, test_dataloader, train_dataloader = get_data_loaders(
        hyper_params, num_workers, tokenizer)

    if print_sentence_stats:
        evaluate_tokenizer_cutoff(
            hyper_params["train_file"],
            tokenizer,
            hyper_params["max_question_len"],
            hyper_params["max_paragraph_len"],
        )

    ret_trainee = RetrieverTrainer(
        ret,
        train_dataloader,
        dev_dataloaders,
        test_dataloader,
        hyper_params["loss_type"],
        hyper_params["optimizer"],
    )
    return ckpt_to_resume, ret_trainee, trainer
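The filepath argument used in init_model above was later split into dirpath and filename; a hedged equivalent on newer pytorch_lightning versions (the output directory is a placeholder):

from pytorch_lightning.callbacks import ModelCheckpoint

# Roughly equivalent to the filepath-based call above on newer releases.
checkpoint_callback = ModelCheckpoint(
    dirpath="output",
    filename="{epoch}-{val_acc_0:.2f}-{val_loss_0:.2f}",
    save_top_k=1,
    verbose=True,
    monitor="val_acc_0",
    mode="max",
)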
def test_checkpoint_repeated_strategy_extended(tmpdir):
    """This test validates checkpoint can be called several times without increasing internally its global step if
    nothing run."""

    class ExtendedBoringModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

        def validation_epoch_end(self, *_):
            ...

    def assert_trainer_init(trainer):
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

    def get_last_checkpoint(ckpt_dir):
        last = ckpt_dir.listdir(sort=True)[-1]
        return str(last)

    def assert_checkpoint_content(ckpt_dir):
        chk = pl_load(get_last_checkpoint(ckpt_dir))
        assert chk["epoch"] == epochs
        assert chk["global_step"] == 4

    def assert_checkpoint_log_dir(idx):
        lightning_logs = tmpdir / "lightning_logs"
        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
        assert actual == [f"version_{i}" for i in range(idx + 1)]
        assert len(ckpt_dir.listdir()) == epochs

    ckpt_dir = tmpdir / "checkpoints"
    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
    epochs = 2
    limit_train_batches = 2
    trainer_config = dict(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        limit_train_batches=limit_train_batches,
        limit_val_batches=3,
        limit_test_batches=4,
        callbacks=[checkpoint_cb],
    )
    trainer = pl.Trainer(**trainer_config)
    assert_trainer_init(trainer)

    model = ExtendedBoringModel()
    trainer.fit(model)
    assert trainer.global_step == epochs * limit_train_batches
    assert trainer.current_epoch == epochs - 1
    assert_checkpoint_log_dir(0)
    assert_checkpoint_content(ckpt_dir)

    trainer.validate(model)
    assert trainer.current_epoch == epochs - 1

    trainer.test(model)
    assert trainer.current_epoch == epochs - 1

    for idx in range(1, 5):
        chk = get_last_checkpoint(ckpt_dir)
        assert_checkpoint_content(ckpt_dir)

        # load from checkpoint
        trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
        trainer = pl.Trainer(**trainer_config)
        assert_trainer_init(trainer)

        model = ExtendedBoringModel()

        trainer.test(model)
        assert trainer.global_step == 0
        assert trainer.current_epoch == 0

        trainer.fit(model, ckpt_path=chk)
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs

        trainer.validate(model)
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs

        trainer.fit(model)
        assert trainer.global_step == epochs * limit_train_batches
        assert trainer.current_epoch == epochs
        assert_checkpoint_log_dir(idx)
Example #17
    model = DeepHsAblationStudyModule(hparams)
    logger = WandbLogger(hparams['git_id'], offline=not opt.online_logging,
                         save_dir=opt.log_path, project='deephs')

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        verbose=True,
        mode='min',
        patience=20
    )

    checkpoint_callback = ModelCheckpoint(
        filepath='best.ckpt',
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    trainer = lightning.Trainer(max_epochs=opt.num_epochs, gpus=-1, logger=logger,
                                early_stop_callback=early_stop_callback, min_epochs=50,
                                checkpoint_callback=checkpoint_callback, callbacks=[LRLoggingCallback()])

    trainer.fit(model)
    best_model = DeepHsAblationStudyModule.load_from_checkpoint(checkpoint_callback.best_model_path)

    print("Best model..")
    trainer.test(best_model)
Example #18
def main(config):
    pl.seed_everything(config.seed)
    gpus = [0] if torch.cuda.is_available() else None

    filepath_list_train=generate_pathlist.make_datapath_list(
        config.project_dir+config.train_dir,
    )
    dm = image_datamodule.ImageDataModule(
        filepath_list_train=filepath_list_train,
        filepath_list_test=filepath_list_train,
        )

    discriminator=dcgan.Discriminator()
    generator=dcgan.Generator()
    criterion=nn.BCEWithLogitsLoss(reduction="mean")
    model = GAN(
        discriminator=discriminator,
        generator=generator,
        criterion=criterion,
        **dc.asdict(config),
        )

    mlflow_tags={}
    mlflow_tags["mlflow.user"]=config.user
    mlflow_tags["mlflow.source.name"]=str(os.path.abspath(__file__)).replace("/",'\\')
    mlf_logger = MLFlowLogger(
        experiment_name=config.experiment_name,
        tracking_uri=config.tracking_uri,
        tags=mlflow_tags
        )

    now=datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d-%H-%M-%S')
    ckpt_path=f"{config.checkpoint_dir}{now}_{mlf_logger.run_id}"
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=None,
        monitor=None,
    )

    trainer = pl.Trainer(
        max_epochs=config.max_epochs,
        logger=mlf_logger,
        gpus=gpus,
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=None,
        )
    trainer.fit(model, datamodule=dm)

    # save log, model, and config to mlflow
    mlf_logger.experiment.log_artifact(mlf_logger.run_id,
                                config.log_dir+"/"+config.log_normal)
    mlf_logger.experiment.log_artifact(mlf_logger.run_id,
                                config.log_dir+"/"+config.log_error)

    with tempfile.TemporaryDirectory() as dname:
        for ckptfile in glob.glob(f"{ckpt_path}*"):
            model=model.load_from_checkpoint(checkpoint_path=ckptfile)
            with tempfile.TemporaryDirectory() as dname:
                filepath = pathlib.Path(dname).joinpath(f"{pathlib.Path(ckptfile).stem}.pth")
                torch.save(model.state_dict(),filepath)
                mlf_logger.experiment.log_artifact(mlf_logger.run_id,filepath)

    with tempfile.TemporaryDirectory() as dname:
        filepath = pathlib.Path(dname).joinpath("config.yml")
        config.save(filepath)
        mlf_logger.experiment.log_artifact(mlf_logger.run_id,filepath)
def test_cpu_restore_training(tmpdir):
    """Verify continue training session on CPU."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # logger file to get meta
    test_logger_version = 10
    logger = tutils.get_test_tube_logger(tmpdir,
                                         False,
                                         version=test_logger_version)

    trainer_options = dict(max_epochs=8,
                           val_check_interval=0.50,
                           val_percent_check=0.2,
                           train_percent_check=0.2,
                           logger=logger,
                           checkpoint_callback=ModelCheckpoint(tmpdir,
                                                               save_top_k=-1))

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_epoch = trainer.current_epoch

    # training complete
    assert result == 1, 'amp + ddp model failed to complete'

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_logger = tutils.get_test_tube_logger(tmpdir,
                                             False,
                                             version=test_logger_version)
    trainer_options = dict(
        max_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        logger=new_logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert trainer.current_epoch == real_global_epoch
        assert trainer.current_epoch >= 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        for dataloader in trainer.get_val_dataloaders():
            tutils.run_prediction(dataloader, trainer.model)

    model.on_train_start = assert_good_acc

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)
Example #20
            args.logname if args.logname is not None else "perf.json")
        callbacks = [
            LoggingCallback(
                log_dir=log_dir,
                global_batch_size=batch_size * args.gpus,
                mode=args.exec_mode,
                warmup=args.warmup,
                dim=args.dim,
                profile=args.profile,
            )
        ]
    elif args.exec_mode == "train":
        model = NNUnet(args)
        if args.save_ckpt:
            model_ckpt = ModelCheckpoint(monitor="dice_sum",
                                         mode="max",
                                         save_last=True)
        callbacks = [
            EarlyStopping(monitor="dice_sum",
                          patience=args.patience,
                          verbose=True,
                          mode="max")
        ]
    else:  # Evaluation or inference
        if ckpt_path is not None:
            model = NNUnet.load_from_checkpoint(ckpt_path)
        else:
            model = NNUnet(args)

    trainer = Trainer(
        logger=False,
Example #21
def main():
    config_filepath = str(sys.argv[1])
    cfg = load_config(filepath=config_filepath)
    pprint.pprint(cfg)
    cfg = munchify(cfg)
    seed(cfg)
    seed_everything(cfg.seed)

    log_dir = '_'.join([
        cfg.log_dir,
        str(cfg.if_sound),
        str(cfg.if_vision),
        str(cfg.if_depth), cfg.depth_representation, cfg.model_name,
        str(cfg.if_all_input_data), cfg.output_representation,
        str(cfg.seed)
    ])

    model = SoundBoxModel(lr=cfg.lr,
                          seed=cfg.seed,
                          if_cuda=cfg.if_cuda,
                          if_test=False,
                          gamma=cfg.gamma,
                          log_dir=log_dir,
                          train_batch=cfg.train_batch,
                          val_batch=cfg.val_batch,
                          test_batch=cfg.test_batch,
                          num_workers=cfg.num_workers,
                          in_channels=cfg.in_channels,
                          model_name=cfg.model_name,
                          num_branches=cfg.num_branches,
                          branches_in_channels=cfg.branches_in_channels,
                          data_filepath=cfg.data_filepath,
                          shapes=cfg.shapes,
                          if_sound=cfg.if_sound,
                          if_vision=cfg.if_vision,
                          if_depth=cfg.if_depth,
                          if_all_input_data=cfg.if_all_input_data,
                          depth_representation=cfg.depth_representation,
                          output_representation=cfg.output_representation,
                          lr_schedule=cfg.schedule,
                          test_hsv_threshold_lst=cfg.test_hsv_threshold_lst)

    # define callback for selecting checkpoints during training
    checkpoint_callback = ModelCheckpoint(
        filepath=log_dir + "/lightning_logs/checkpoints/{epoch}_{iou_score}",
        verbose=True,
        monitor='iou_score',
        mode='max',
        prefix='')

    # define trainer
    trainer = Trainer(gpus=cfg.num_gpus,
                      max_epochs=cfg.epochs,
                      deterministic=True,
                      accelerator='ddp',
                      amp_backend='native',
                      default_root_dir=log_dir,
                      val_check_interval=1.0,
                      checkpoint_callback=checkpoint_callback)

    trainer.fit(model)
Example #22
def test__training_step__log(tmpdir):
    """Tests that only training_step can be used."""
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            out = super().training_step(batch, batch_idx)
            loss = out["loss"]

            # -----------
            # default
            # -----------
            self.log("default", loss)

            # -----------
            # logger
            # -----------
            # on_step T on_epoch F
            self.log("l_s",
                     loss,
                     on_step=True,
                     on_epoch=False,
                     prog_bar=False,
                     logger=True)

            # on_step F on_epoch T
            self.log("l_e",
                     loss,
                     on_step=False,
                     on_epoch=True,
                     prog_bar=False,
                     logger=True)

            # on_step T on_epoch T
            self.log("l_se",
                     loss,
                     on_step=True,
                     on_epoch=True,
                     prog_bar=False,
                     logger=True)

            # -----------
            # pbar
            # -----------
            # on_step T on_epoch F
            self.log("p_s",
                     loss,
                     on_step=True,
                     on_epoch=False,
                     prog_bar=True,
                     logger=False)

            # on_step F on_epoch T
            self.log("p_e",
                     loss,
                     on_step=False,
                     on_epoch=True,
                     prog_bar=True,
                     logger=False)

            # on_step T on_epoch T
            self.log("p_se",
                     loss,
                     on_step=True,
                     on_epoch=True,
                     prog_bar=True,
                     logger=False)

            return loss

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=2,
        log_every_n_steps=1,
        enable_model_summary=False,
        callbacks=[ModelCheckpoint(monitor="l_se")],
    )
    trainer.fit(model)

    logged_metrics = set(trainer.logged_metrics)
    assert logged_metrics == {
        "default", "l_e", "l_s", "l_se_step", "l_se_epoch"
    }

    pbar_metrics = set(trainer.progress_bar_metrics)
    assert pbar_metrics == {"p_e", "p_s", "p_se_step", "p_se_epoch"}

    assert set(trainer.callback_metrics) == (logged_metrics | pbar_metrics
                                             | {"p_se", "l_se"})
    assert all(
        isinstance(v, torch.Tensor) for v in trainer.callback_metrics.values())
    assert all(
        isinstance(v, torch.Tensor) for v in trainer.logged_metrics.values())
    assert all(
        isinstance(v, float) for v in trainer.progress_bar_metrics.values())
Example #23
def test_v1_5_0_model_checkpoint_period(tmpdir):
    with no_warning_call(DeprecationWarning):
        ModelCheckpoint(dirpath=tmpdir)
    with pytest.deprecated_call(
            match="is deprecated in v1.3 and will be removed in v1.5"):
        ModelCheckpoint(dirpath=tmpdir, period=1)
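
# A minimal sketch of the non-deprecated spelling tested above (an assumption: a recent
# PyTorch Lightning release in which ``every_n_epochs`` supersedes the removed ``period``):
from pytorch_lightning.callbacks import ModelCheckpoint

# save a checkpoint once per epoch, equivalent in intent to period=1
checkpoint = ModelCheckpoint(dirpath="checkpoints/", every_n_epochs=1)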
Example #24
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath):
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated losses
    save_dir = tmp_path / "1"
    save_dir.mkdir()
    losses = [10, 9, 2.8, 5, 2.5]

    # -----------------
    # CASE K=-1  (all)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(
        losses), "Should save all models when save_top_k=-1"

    # verify correct naming
    for fname in {
            '_epoch=4_val_loss=2.50.ckpt', '_epoch=3_val_loss=5.00.ckpt',
            '_epoch=2_val_loss=2.80.ckpt', '_epoch=1_val_loss=9.00.ckpt',
            '_epoch=0_val_loss=10.00.ckpt'
    }:
        assert fname in file_lists

    save_dir = tmp_path / "2"
    save_dir.mkdir()

    # -----------------
    # CASE K=0 (none)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    save_dir = tmp_path / "3"
    save_dir.mkdir()

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    checkpoint_callback = ModelCheckpoint(save_dir,
                                          save_top_k=1,
                                          verbose=1,
                                          prefix='test_prefix')
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists

    save_dir = tmp_path / "4"
    save_dir.mkdir()

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f"{save_dir}/other_file.ckpt", 'a').close()
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 2 models when save_top_k=2 (plus the pre-existing other_file.ckpt)'
    for fname in {
            '_epoch=4_val_loss=2.50.ckpt', '_epoch=2_val_loss=2.80.ckpt',
            'other_file.ckpt'
    }:
        assert fname in file_lists

    save_dir = tmp_path / "5"
    save_dir.mkdir()

    # -----------------
    # CASE K=4 (keep the best 4 of the 5 checkpoints)
    # multiple checkpoints within same epoch

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(
        file_lists
    ) == 4, 'Should save all 4 models when save_top_k=4 within same epoch'

    save_dir = tmp_path / "6"
    save_dir.mkdir()

    # -----------------
    # CASE K=3 (keep the three best losses: 2.5, 2.8, 5.0)
    # multiple checkpoints within same epoch

    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate callback's calls during the training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    for fname in {
            '_epoch=0_val_loss=2.80.ckpt', '_epoch=0_val_loss=2.50.ckpt',
            '_epoch=0_val_loss=5.00.ckpt'
    }:
        assert fname in file_lists
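
# A minimal sketch of the same top-k selection with a current ModelCheckpoint signature
# (an assumption: recent PyTorch Lightning, where the ``prefix`` argument used above is gone
# and any prefix is written directly into the ``filename`` template):
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="test_prefix_{epoch}_{val_loss:.2f}",  # Lightning fills epoch/val_loss at save time
    monitor="val_loss",
    mode="min",
    save_top_k=2,
)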
Example #25
# models = {}
# for method_path in parent_path.glob("*"):
#     for ckpt_path in method_path.glob("*"):
#         ckpt_path = str(ckpt_path)
#         models[method_path] = CellTyper.load_from_checkpoint(ckpt_path)
# new_trainer = pl.Trainer(resume_from_checkpoint='checkpoints/0_standard/epoch=9-avg_val_loss=0.80.ckpt')
seed_everything(0)

model_test = CellTyper.load_from_checkpoint(
    'checkpoints/0_variational/epoch=29-avg_val_loss=-161639.77.ckpt')
data = CellTyperDataModule(model_test.hparams)
data.setup('fit')
data.setup('test')

checkpoint_callback = ModelCheckpoint(monitor=model_test.hparams.monitor,
                                      dirpath=model_test.hparams.dirpath,
                                      filename=model_test.hparams.filename,
                                      save_top_k=model_test.hparams.save_top_k)

wandb_logger = WandbLogger(name=model_test.hparams.wandb_name,
                           project=model_test.hparams.wandb_project)

trainer = Trainer.from_argparse_args(model_test.hparams,
                                     logger=wandb_logger,
                                     callbacks=[checkpoint_callback])

trainer.test(model_test, data.test_dataloader())
# print(data)

# test_loader = data.test_dataloader()
# for batch in test_loader:
#     print(batch)
Example #26
def test_resume_from_checkpoint_epoch_restored(tmpdir):
    """Verify resuming from checkpoint runs the right number of epochs"""
    import types

    tutils.reset_seed()

    hparams = tutils.get_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = LightningTestModel(hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, _):
            self.num_batches_seen += 1

        # Bind the increment_epoch function on_epoch_end so that the
        # model keeps track of the number of epochs it has seen.
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_batch_start = types.MethodType(increment_batch, model)
        return model

    model = _new_model()

    trainer_options = dict(
        show_progress_bar=False,
        max_epochs=2,
        train_percent_check=0.65,
        val_percent_check=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        logger=False,
        default_save_path=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2

    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
    checkpoints = sorted(
        glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))

    for check in checkpoints:
        next_model = _new_model()
        state = torch.load(check)

        # Resume training
        trainer_options['max_epochs'] = 4
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        assert state[
            'global_step'] + next_model.num_batches_seen == training_batches * 4
Example #27
def train_model(train_glob: str, tensorboard_root: str, max_epochs: int,
                num_samples: int, batch_size: int, num_workers: int,
                learning_rate: float, accelerator: str, model_save_path: str):
    """
    method to train and validate the model

    :param train_glob: Glob pattern pointing to the training data
    :param tensorboard_root: Path to save the tensorboard logs
    :param max_epochs: Maximum number of epochs
    :param num_samples: Maximum number of samples to train the model
    :param batch_size: Number of samples per batch
    :param num_workers: Number of cores to train the model
    :param learning_rate: Learning rate used to train the model
    :param accelerator: single or multi GPU
    :param model_save_path: Path for the model to be saved
    """

    if accelerator == "None":
        accelerator = None

    dict_args = {
        "train_glob": train_glob,
        "max_epochs": max_epochs,
        "num_samples": num_samples,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "lr": learning_rate,
        "accelerator": accelerator,
    }

    dm = BertDataModule(**dict_args)
    dm.prepare_data()
    dm.setup(stage="fit")

    model = BertNewsClassifier(**dict_args)
    early_stopping = EarlyStopping(monitor="val_loss",
                                   mode="min",
                                   verbose=True)

    if os.path.exists(os.path.join(tensorboard_root,
                                   "bert_lightning_kubeflow")):
        shutil.rmtree(os.path.join(tensorboard_root,
                                   "bert_lightning_kubeflow"))

    Path(tensorboard_root).mkdir(parents=True, exist_ok=True)

    # Tensorboard root name of the logging directory
    tboard = TensorBoardLogger(tensorboard_root, "bert_lightning_kubeflow")

    Path(model_save_path).mkdir(parents=True, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        dirpath=model_save_path,
        filename="bert_news_classification_{epoch:02d}",
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min",
        prefix="",
    )
    lr_logger = LearningRateMonitor()

    trainer = pl.Trainer(
        logger=tboard,
        accelerator=accelerator,
        callbacks=[lr_logger, early_stopping],
        checkpoint_callback=checkpoint_callback,
        max_epochs=max_epochs,
    )
    trainer.fit(model, dm)
    trainer.test()
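
# A hypothetical invocation of the function above; every path and value here is a
# placeholder chosen for illustration, not taken from the original source:
train_model(
    train_glob="data/train/*.csv",
    tensorboard_root="tb_logs",
    max_epochs=3,
    num_samples=1000,
    batch_size=16,
    num_workers=4,
    learning_rate=0.001,
    accelerator="None",        # the function maps the string "None" to a single-device run
    model_save_path="models/bert_news",
)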
Example #28
        warmup_steps=100,
        gradient_accumulation_steps=16,
        train_batch_size=16,
        eval_batch_size=12,
        num_train_epochs=2000,
        n_gpu=gpu,
        fp_16=False,  # fp_16=True usually shortens training time; 32-bit precision is the default
        opt_level='O1',  # apex AMP optimization level ('O1' = mixed precision)
        seed=42)
    args = argparse.Namespace(**args_dict)
    print(args_dict)

    checkpoint_callback = ModelCheckpoint(filepath=args.output_dir,
                                          prefix=str(lr) +
                                          '_checkpoint_self_attn-{epoch:02d}',
                                          monitor="val_loss",
                                          mode="min",
                                          save_top_k=5)

    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        min_delta=0.00,
                                        patience=3,
                                        verbose=True,
                                        mode='min')
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        amp_level=args.opt_level,
        gradient_clip_val=args.gradient_clip_val,
        auto_lr_find=True,
Example #29
    train_fold = train.iloc[train_index]
    val_fold = train.iloc[val_index]
    train_fold.to_csv(path / f'fold_{f+1}_train.csv')
    val_fold.to_csv(path / f'fold_{f+1}_val.csv')

    dm = DataModule(file=f'fold_{f+1}', **config)

    model = Model(config)

    #wandb_logger = WandbLogger(project="cassava", config=config)
    #es = MyEarlyStopping(monitor='val_acc', mode='max', patience=config['patience'])
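    # the doubled braces in the filename below escape Python's f-string, leaving the literal
    # {val_acc:.5f} placeholder for Lightning to fill in when the checkpoint is written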
    checkpoint = ModelCheckpoint(
        dirpath='./',
        filename=
        f'{config["backbone"]}-{config["size"]}-fold_{f+1}-{{val_acc:.5f}}',
        save_top_k=1,
        monitor='val_acc',
        mode='max')
    #lr_monitor = LearningRateMonitor(logging_interval='step')

    trainer = pl.Trainer(
        gpus=config['gpus'],
        precision=config['precision'],
        #logger= wandb_logger,
        max_epochs=config['max_epochs'],
        #callbacks=[es, checkpoint, lr_monitor],
        callbacks=[checkpoint],
        limit_val_batches=config['val_batches'])

    trainer.fit(model, dm)
Example #30
def test_dp_resume(tmpdir):
    """Make sure DP continues training correctly."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1,
                           gpus=2,
                           accelerator='dp',
                           default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options['logger'] = logger
    trainer_options['callbacks'] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.is_slurm_managing_tasks = True
    trainer.fit(model, datamodule=dm)

    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
    real_global_epoch = trainer.current_epoch + 1

    # correct result and ok accuracy
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options['logger'] = new_logger
    trainer_options['callbacks'] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options['limit_train_batches'] = 0.5
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_pretrain_routine_end_called = False

        # set the epoch start hook so we can predict before the model does the full training
        def on_pretrain_routine_end(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

            # if model and state loaded correctly, predictions will be good even though we
            # haven't trained with the new loaded model
            new_trainer.state.stage = RunningStage.VALIDATING

            dataloader = self.train_dataloader()
            tpipes.run_prediction_eval_model_template(
                self.trainer.lightning_module, dataloader=dataloader)
            self.on_pretrain_routine_end_called = True

    # new model
    model = CustomModel()

    # fit new model which should load hpc weights
    new_trainer.fit(model, datamodule=dm)
    assert model.on_pretrain_routine_end_called

    # test freeze on gpu
    model.freeze()
    model.unfreeze()