Example #1
def training_loop(train,
                  valid,
                  save_path,
                  pl_module,
                  callbacks,
                  n_epochs,
                  checkpoint_callback,
                  use_neptune=False,
                  resume=True,
                  limit_train_batches=2,
                  neptune_tags="",
                  neptune_name=""):
    """
    Largely model/application agnostic training code.
    """
    # Train with proper resuming
    # Copy gin configs used, for reference, to the save folder
    os.system("rm " + os.path.join(save_path, "*gin"))
    for gin_config in sys.argv[2].split(";"):
        os.system("cp {} {}/base_config.gin".format(gin_config, save_path))
    with open(os.path.join(save_path, "config.gin"), "w") as f:
        f.write(gin.operative_config_str())
    hparams = parse_gin_config(os.path.join(save_path, 'config.gin'))
    if 'train.callbacks' in hparams:
        del hparams['train.callbacks']
    # TODO: What is a less messy way to pass hparams? This is only so that the loggers are aware of the hyperparameters
    pl_module._set_hparams(hparams)
    pl_module._hparams_initial = copy.deepcopy(hparams)
    loggers = []
    loggers.append(pl_loggers.CSVLogger(save_path))
    if use_neptune:
        from pytorch_lightning.loggers import NeptuneLogger
        loggers.append(
            NeptuneLogger(
                api_key=NEPTUNE_TOKEN,
                project_name=NEPTUNE_USER + "/" + NEPTUNE_PROJECT,
                experiment_name=neptune_name
                if len(neptune_name) else os.path.basename(save_path),
                tags=neptune_tags.split(',') if len(neptune_tags) else None,
            ))
        callbacks += [MetaSaver(), Heartbeat(), LearningRateMonitor()]
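    # Resume from save_path/last.ckpt if it exists; otherwise start from scratch.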
    trainer = pl.Trainer(
        default_root_dir=save_path,
        limit_train_batches=limit_train_batches,
        max_epochs=n_epochs,
        logger=loggers,
        callbacks=callbacks,
        log_every_n_steps=1,
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=os.path.join(save_path, 'last.ckpt') if resume
        and os.path.exists(os.path.join(save_path, 'last.ckpt')) else None)
    trainer.fit(pl_module, train, valid)
    return trainer
Example #2
    def tune_train_once(config,
                        checkpoint_dir=None,
                        args: argparse.Namespace = None,
                        model_class: type = None,
                        build_method=None,
                        task_info: TaskInfo = None,
                        model_kwargs: dict = None,
                        resume: str = None,
                        **kwargs):
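        """Train one Ray Tune trial: the sampled ``config`` overrides matching
        entries in ``args``, the model is built via ``build_method``, and a
        ``TuneReportCheckpointCallback`` reports ``val_<metric>`` back to Tune
        as ``tune_<metric>`` and writes ``tune.ckpt`` after every validation.
        ``resume='all'`` restores the full Trainer state from ``checkpoint_dir``;
        ``resume='model'`` restores only the model weights and the epoch counter.
        """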
        if resume is None:
            resume = 'all'
        args_vars = vars(args)
        args_vars.update(config)

        pl.seed_everything(args.seed)
        logger = [
            loggers.CSVLogger(save_dir=tune.get_trial_dir(),
                              name="",
                              version="."),
            loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                      name="",
                                      version=".",
                                      default_hp_metric=False)
        ]
        trainer_args = dict(logger=logger,
                            progress_bar_refresh_rate=0,
                            callbacks=[
                                TuneReportCheckpointCallback(
                                    metrics={
                                        f'tune_{task_info.metric_name}':
                                        f'val_{task_info.metric_name}'
                                    },
                                    filename="tune.ckpt",
                                    on="validation_end")
                            ])
        if checkpoint_dir and resume == 'all':
            trainer_args['resume_from_checkpoint'] = os.path.join(
                checkpoint_dir, "tune.ckpt")

        # Work around Lightning's SLURM auto-detection: a job name of "bash" marks the
        # session as interactive, so the Trainer does not hook into the SLURM
        # environment inside the Ray Tune worker.
        os.environ["SLURM_JOB_NAME"] = "bash"
        model = model_class(args, **model_kwargs)
        build_method(model, task_info)
        trainer: Trainer = Trainer.from_argparse_args(args, **trainer_args)
        if checkpoint_dir and resume == 'model':
            ckpt = pl_load(os.path.join(checkpoint_dir, "tune.ckpt"),
                           map_location=lambda storage, loc: storage)
            model = model._load_model_state(ckpt)
            trainer.current_epoch = ckpt["epoch"]
        trainer.fit(model)
Example #3
    def train_model(self, experiment_savedir, transform, batch_size,
                    classifier):
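        """Train ``classifier`` on the patch windows: build ImageFolder datasets
        and dataloaders, checkpoint the best ``val_accuracy``, stop early after
        10 epochs without improvement, and log metrics to CSV.
        """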
        train_set = ImageFolder(self.train_windows, transform=transform)
        valid_set = ImageFolder(self.valid_windows, transform=transform)
        print(len(train_set), len(valid_set))

        # create dataloaders
        train_loader = DataLoader(train_set,
                                  batch_size=batch_size,
                                  num_workers=8,
                                  shuffle=True)
        valid_loader = DataLoader(valid_set,
                                  batch_size=batch_size,
                                  num_workers=8)

        # configure logging and checkpoints
        checkpoint_callback = ModelCheckpoint(
            monitor="val_accuracy",
            dirpath=experiment_savedir + "patch_model/",
            filename=f"checkpoint.ckpt",
            save_top_k=1,
            mode="max",
        )

        early_stop_callback = EarlyStopping(monitor='val_accuracy',
                                            min_delta=0.00,
                                            patience=10,
                                            verbose=False,
                                            mode='max')

        # create a logger
        csv_logger = pl_loggers.CSVLogger(experiment_savedir + 'logs/',
                                          name='patch_classifier',
                                          version=0)

        # train our model
        trainer = pl.Trainer(
            callbacks=[checkpoint_callback, early_stop_callback],
            gpus=2,
            accelerator="dp",
            max_epochs=100,
            logger=csv_logger,
            log_every_n_steps=11)
        trainer.fit(classifier,
                    train_dataloader=train_loader,
                    val_dataloaders=valid_loader)
Example #4
def get_loggers_callbacks(args, model=None):
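    """Build CSV and Weights & Biases loggers plus checkpoint, early-stopping and
    learning-rate-monitor callbacks from ``args``; returns ``(None, None)`` when
    the expected attributes are missing on ``args``.
    """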

    try:
        # Setup logger(s) params
        csv_logger_params = dict(
            save_dir="./experiments",
            name=os.path.join(*args.save_dir.split("/")[1:-1]),
            version=args.save_dir.split("/")[-1],
        )
        wandb_logger_params = dict(
            log_model=False,
            name=os.path.join(*args.save_dir.split("/")[1:]),
            offline=args.debug,
            project="utime",
            save_dir=args.save_dir,
        )
        loggers = [
            pl_loggers.CSVLogger(**csv_logger_params),
            pl_loggers.WandbLogger(**wandb_logger_params),
        ]
        if model:
            loggers[-1].watch(model)

        # Setup callback(s) params
        checkpoint_monitor_params = dict(
            filepath=os.path.join(args.save_dir,
                                  "{epoch:03d}-{eval_loss:.2f}"),
            monitor=args.checkpoint_monitor,
            save_last=True,
            save_top_k=1,
        )
        earlystopping_parameters = dict(
            monitor=args.earlystopping_monitor,
            patience=args.earlystopping_patience,
        )
        callbacks = [
            pl_callbacks.ModelCheckpoint(**checkpoint_monitor_params),
            pl_callbacks.EarlyStopping(**earlystopping_parameters),
            pl_callbacks.LearningRateMonitor(),
        ]

        return loggers, callbacks
    except AttributeError:
        return None, None
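Example #5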
def build_trainer_and_model(conf, dirpath='.', progress=False) -> Tuple[pl.Trainer, AbstractLightningModule]:
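    """Instantiate the model with the loss selected by ``conf`` (IRM, subgraph
    regularization, label smoothing, or plain cross-entropy) and a Trainer that
    checkpoints on ``val/accuracy`` whenever a validation split is present.
    """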
    cons = model_name_to_cons(conf.model)
    if conf.irm is not None:
        assert conf.irm > 0
        if conf.dataset_name not in ["SYNTHETIC2", "NCI1", "NCI109", "PROTEINS", "DD"]:
            assert conf.cutoff is None
        model = cons(conf, IRMLoss(conf.irm, dataset_name=conf.dataset_name, cutoff=conf.cutoff))
    elif conf.reg_const is not None:
        assert conf.reg_const > 0
        assert conf.dataset_name in ["SYNTHETIC3", "NCI1", "NCI109", "PROTEINS", "DD"] and conf.model in [
            ModelName.KaryGNN, ModelName.KaryRPGNN]
        model = cons(conf, SubgraphRegularizedLoss(conf.reg_const))
    elif conf.label_smooth is not None:
        assert conf.label_smooth > 0
        assert conf.dataset_name in ["NCI1", "NCI109", "PROTEINS", "DD"] and conf.model in [ModelName.KaryGNN,
                                                                                            ModelName.KaryRPGNN]
        model = cons(conf, LabelSmoothingLoss(conf.num_out, conf.label_smooth))
    else:
        model = cons(conf, CELoss(conf.dataset_name))

    kwgs = {}

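    # Track validation accuracy only when a validation split exists; without a
    # monitor, ModelCheckpoint falls back to saving the latest epoch's checkpoint.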
    if len(model.val_ds_list) > 0:
        kwgs["monitor"] = 'val/accuracy'
        kwgs["mode"] = 'max'
    else:
        print("No validation present")

    chkp = pl.callbacks.ModelCheckpoint(dirpath=dirpath, filename="model.ckpt", **kwgs)

    csv_logger = pl_loggers.CSVLogger(dirpath, name='csv_logs')
    tb_logger = pl_loggers.TensorBoardLogger(dirpath, name="tb_logs")
    trainer = pl.Trainer(weights_summary='full', max_epochs=conf.num_epochs, callbacks=[chkp],
                         reload_dataloaders_every_epoch=(
                                 conf.dataset_name in ["SYNTHETIC3", "NCI1", "NCI109", "PROTEINS"]
                                 and conf.model not in [ModelName.GNN, ModelName.RPGNN]),
                         logger=[tb_logger, csv_logger], gpus=1, progress_bar_refresh_rate=0 if not progress else 1)

    return trainer, model
Example #6
def get_default_loggers(
        save_path: str) -> List[pl_loggers.LightningLoggerBase]:
    return [pl_loggers.CSVLogger(save_path, name='csv_logs')]
Example #7
File: common_train.py  Project: zmjm4/ltp
    def tune_train(args,
                   model_class,
                   task_info: TaskInfo,
                   build_method=default_build_method,
                   model_kwargs: dict = None,
                   tune_config=None):
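        """Run a Ray Tune search over ``tune_train_once`` with an ASHA scheduler,
        then reload the best ``tune.ckpt`` and evaluate it with ``trainer.test``.
        """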
        if model_kwargs is None:
            model_kwargs = {}
        this_time = time.strftime("%m-%d_%H:%M:%S", time.localtime())
        experiment_name = f'{task_info.task_name}_{this_time}'

        if tune_config is None:
            config = {
                # 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
                "lr":
                tune.loguniform(args.tune_min_lr, args.tune_max_lr),

                # -1 for disable, 0.8 for Base/Small, 0.9 for Large
                "layerwise_lr_decay_power":
                tune.choice([0.8, 0.9]),

                # lr scheduler
                "lr_scheduler":
                tune.choice([
                    'linear_schedule_with_warmup',
                    'polynomial_decay_schedule_with_warmup'
                ]),
            }
        else:
            config = tune_config
        if torch.cuda.is_available():
            resources_per_trial = {
                "cpu": args.tune_cpus_per_trial,
                "gpu": args.tune_gpus_per_trial
            }
        else:
            resources_per_trial = {"cpu": args.tune_cpus_per_trial}
        print("resources_per_trial", resources_per_trial)

        tune_dir = os.path.abspath('tune_lightning_logs')

        analysis = tune.run(
            tune.with_parameters(
                tune_train_once,
                args=args,
                task_info=task_info,
                model_class=model_class,
                build_method=build_method,
                model_kwargs=model_kwargs,
                resume=args.tune_resume,
                group=experiment_name,
                log_dir=tune_dir,
            ),
            mode="max",
            config=config,
            num_samples=args.tune_num_samples,
            metric=f'tune_{task_info.metric_name}',
            name=experiment_name,
            progress_reporter=CLIReporter(
                parameter_columns=list(config.keys()),
                metric_columns=[
                    "loss", f'tune_{task_info.metric_name}',
                    "training_iteration"
                ]),
            callbacks=[TBXLoggerCallback(),
                       CSVLoggerCallback()],
            resources_per_trial=resources_per_trial,
            scheduler=ASHAScheduler(
                max_t=args.max_epochs + 1,  # for test
                grace_period=args.min_epochs),
            queue_trials=True,
            keep_checkpoints_num=args.tune_keep_checkpoints_num,
            checkpoint_score_attr=f'tune_{task_info.metric_name}',
            local_dir=tune_dir,
        )
        print("Best hyperparameters found were: ", analysis.best_config)
        print("Best checkpoint: ", analysis.best_checkpoint)

        args_vars = vars(args)
        args_vars.update(analysis.best_config)
        model = model_class.load_from_checkpoint(os.path.join(
            analysis.best_checkpoint, "tune.ckpt"),
                                                 hparams=args,
                                                 **model_kwargs)

        pl_loggers = [
            loggers.CSVLogger(save_dir=tune.get_trial_dir(),
                              name="",
                              version="."),
            loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                      name="",
                                      version=".",
                                      default_hp_metric=False),
        ]

        try:
            import wandb
            pl_loggers.append(
                loggers.WandbLogger(save_dir=tune_dir,
                                    project=args.project,
                                    name=tune.get_trial_name(),
                                    id=tune.get_trial_id(),
                                    offline=args.offline,
                                    group=experiment_name))
        except Exception:
            pass

        trainer: Trainer = Trainer.from_argparse_args(args, logger=pl_loggers)
        build_method(model, task_info)
        trainer.test(model)
Example #8

# # Model

EPOCHS = 1
batch_size = 20
num_classes = 6
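# total number of optimizer steps over the run; df is the training data loaded earlier in the script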
n_train_steps = int(len(df) / batch_size * EPOCHS)

dm = CSVDataModule(batch_size=batch_size, data_dir="/media/hdd/Datasets/jigsaw/")
dm.setup()

model = LitModel(num_classes=num_classes, num_train_steps=n_train_steps)

logger = pl_loggers.CSVLogger("logs", name="eff-b5")

trainer = pl.Trainer(
    auto_select_gpus=True,
    gpus=1,
    precision=16,
    profiler=False,
    max_epochs=EPOCHS,
    callbacks=[pl.callbacks.ProgressBar()],
    automatic_optimization=True,
    enable_pl_optimizer=True,
    accelerator="ddp",
    plugins="ddp_sharded",
    logger=logger,
)