コード例 #1
0
ファイル: mnist.py プロジェクト: Akhilez/vision_lab
def train():
    """Train an MNIST classifier with PyTorch Lightning, logging to W&B."""
    # Optimizer / schedule hyperparameters.
    hyperparams = dict(
        epochs=10,
        lr_initial=0.001,
        lr_decay_every=30,
        lr_decay_by=0.3,
    )

    # Data and bookkeeping settings.
    settings = dict(
        data_path="../data",
        val_split=0.05,
        batch_size=64,
        manual_seed=2,
        output_path="./output",
        model_save_frequency=5,
        dataloader_num_workers=0,
    )

    data = MnistDataset(**settings)
    net = MnistModel(**hyperparams, **settings)

    logger = WandbLogger(project="classification_test", log_model=True)
    trainer = pl.Trainer(
        gpus=0,
        max_epochs=hyperparams["epochs"],
        default_root_dir=settings["output_path"],
        logger=logger,
    )
    logger.watch(net)

    trainer.fit(net, datamodule=data)
コード例 #2
0
ファイル: main.py プロジェクト: undertherain/pl_stopping_test
def main():
    """Train the Cosmoflow model under Horovod with early stopping and W&B logging."""
    data_module = CFDataModule(batch_size=2)
    wandb_logger = WandbLogger(project="cosmoflow")
    # Stop when val_loss fails to improve by >= 1e-4 for 2 consecutive checks.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.0001,
        patience=2,
        verbose=True,
        mode="min",
    )
    print("create trainer")  # fixed typo: was "create tainer"
    trainer = pl.Trainer(
        gpus=1,
        num_sanity_val_steps=0,
        max_epochs=20,
        distributed_backend="horovod",
        replace_sampler_ddp=False,
        early_stop_callback=early_stop_callback,
        logger=wandb_logger,
        progress_bar_refresh_rate=0,
    )
    # print("trainer created")

    model = Cosmoflow()
    trainer.fit(model, data_module)
コード例 #3
0
def main():
    """Smoke-test DDP spawn on 2 GPUs with a BoringModel and W&B logging."""
    # Opt into the wandb "service" concurrency experiment.
    wandb.require(experiment="service")
    print("PIDPID", os.getpid())

    # Three identical random-feature loaders (32-dim samples).
    def make_loader():
        return DataLoader(RandomDataset(32, 100000), batch_size=32)

    train, val, test = make_loader(), make_loader(), make_loader()

    model = BoringModel()

    # Hyperparams logged to W&B before DDP workers fork.
    wandb_logger = WandbLogger(
        log_model=True,
        config=dict(some_hparam="Logged Before Trainer starts DDP"),
        save_code=True,
    )

    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy="ddp_spawn",
        logger=wandb_logger,
    )

    trainer.fit(model, train, val)
    trainer.test(test_dataloaders=test)
コード例 #4
0
def main(hparams):
    """Build a SegModel and train it with a W&B logger."""
    # 1. Lightning model.
    model = SegModel(hparams)

    # 2. W&B logger; optionally track the inner network's topology.
    wandb_logger = WandbLogger()
    wandb_logger.watch(model.net)

    # 3. Trainer (checkpointing disabled).
    trainer = pl.Trainer(
        gpus=hparams.gpus,
        logger=wandb_logger,
        max_epochs=hparams.epochs,
        accumulate_grad_batches=hparams.grad_batches,
        checkpoint_callback=False,
    )

    # 4. Train.
    trainer.fit(model)
コード例 #5
0
def main(hparams: Namespace):
    """Train SegModel; W&B logging is optional via ``hparams.log_wandb``."""
    # 1. Lightning model.
    model = SegModel(**vars(hparams))

    # 2. Logger: False disables logging entirely.
    if hparams.log_wandb:
        logger = WandbLogger()
        logger.watch(model.net)  # optional: log model topology
    else:
        logger = False

    # 3. Trainer with mixed precision when AMP is requested.
    trainer = pl.Trainer(
        gpus=hparams.gpus,
        logger=logger,
        max_epochs=hparams.epochs,
        accumulate_grad_batches=hparams.grad_batches,
        accelerator=hparams.accelerator,
        precision=16 if hparams.use_amp else 32,
    )

    # 4. Train.
    trainer.fit(model)
コード例 #6
0
def test_wandb_pickle(wandb):
    """Pickling a Trainer that holds a WandbLogger must round-trip.

    ``wandb`` is mocked out because it does not cooperate with pytest.
    """
    tutils.reset_seed()

    class Experiment:
        id = 'the_id'

    wandb.init.return_value = Experiment()

    logger = WandbLogger(id='the_id', offline=True)
    trainer = Trainer(max_epochs=1, logger=logger)

    trainer2 = pickle.loads(pickle.dumps(trainer))

    # offline=True must have switched wandb into dry-run mode.
    assert os.environ['WANDB_MODE'] == 'dryrun'
    assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
    _ = trainer2.logger.experiment  # touching it must re-initialise wandb

    wandb.init.assert_called()
    init_kwargs = wandb.init.call_args[1]
    assert 'id' in init_kwargs
    assert init_kwargs['id'] == 'the_id'

    del os.environ['WANDB_MODE']
コード例 #7
0
def main():
    """Load config, build the CheckpointedPyramid model, and train multi-node."""
    print("Running main")
    print(time.ctime())

    args = parse_args()
    with open(args.config) as fh:
        default_configs = yaml.load(fh, Loader=yaml.FullLoader)

    print("Initialising model")
    print(time.ctime())
    model = CheckpointedPyramid(default_configs)
    # model.setup(stage="fit")

    logger = WandbLogger(
        project=default_configs["project"],
        group="InitialTest",
        save_dir=default_configs["artifacts"],
    )

    # 8 nodes x 4 GPUs each, with a custom DDP plugin.
    trainer = Trainer(
        gpus=4,
        num_nodes=8,
        strategy=CustomDDPPlugin(find_unused_parameters=False),
        max_epochs=default_configs["max_epochs"],
        logger=logger,
    )
    trainer.fit(model)
コード例 #8
0
def test_wandb_logger_offline_log_model(wandb, tmpdir):
    """``log_model=True`` combined with ``offline=True`` must raise immediately."""
    with pytest.raises(MisconfigurationException,
                       match='checkpoints cannot be uploaded in offline mode'):
        WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)
コード例 #9
0
ファイル: train.py プロジェクト: sankovalev/goznak
def main(args) -> None:
    """Entry point that launches model training."""
    config = load_cfg(args.config)
    pprint.PrettyPrinter(indent=2).pprint(config)

    model = BaselineLearner(config)

    logger = False  # pl accepts False to disable logging
    if args.use_logger:
        logger = WandbLogger(name=config.name)
        logger.watch(model.net)

    checkpoint = ModelCheckpoint(monitor='valid_loss',
                                 dirpath=config.sources.ckpt_path,
                                 filename=config.name)
    trainer = pl.Trainer(
        gpus=args.gpus,
        logger=logger,
        callbacks=[checkpoint],
        max_epochs=config.training.epochs,
        distributed_backend=args.distributed_backend,
        precision=16 if args.use_amp else 32,
    )

    trainer.fit(model)
    print('Model training completed!')
コード例 #10
0
def cli_main():
    """CLI entry point: parse trainer + model args, set up the W&B logger and
    checkpoint/early-stopping callbacks, then run training and/or testing."""
    parser = argparse.ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SLExperiment.add_model_specific_args(parser)
    args = parser.parse_args()

    model = SLExperiment(**args.__dict__)

    # Optionally restore model weights from a checkpoint path.
    if args.resume:
        model.resume(args.resume)

    logger = None
    callbacks = []
    # Skip W&B logging and checkpointing entirely for fast dev runs.
    if not args.fast_dev_run:
        logger = WandbLogger(project="argumentation",
                             save_dir=str(config.root_dir),
                             tags=[args.tag])
        logger.log_hyperparams(args)

        # save checkpoints based on avg_reward
        checkpoint_callback = ModelCheckpoint(
            dirpath=logger.experiment.dir,
            save_top_k=1,
            monitor="validation/loss",
            mode="min",
            save_weights_only=True,
            verbose=True,
        )
        callbacks.append(checkpoint_callback)

        # With a tag, additionally mirror the best checkpoint into the model
        # directory under the model's own name.
        if args.tag:
            tag_checkpoint_callback = copy.deepcopy(checkpoint_callback)
            tag_checkpoint_callback.dirpath = model.model_dir
            tag_checkpoint_callback.filename = model.model_name
            callbacks.append(tag_checkpoint_callback)

    # early stopping
    if args.patience:
        early_stop_callback = EarlyStopping(monitor="validation/loss",
                                            patience=args.patience,
                                            mode="min",
                                            verbose=True)
        callbacks.append(early_stop_callback)

    pl.seed_everything(123)
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=callbacks,
        track_grad_norm=2,
    )

    # Fit only when both train and val datasets were provided.
    if args.train_ds and args.val_ds:
        trainer.fit(model)
    if args.test_ds:
        trainer.test(model)
コード例 #11
0
def get_logger(model_config):
    """Build a logger from ``model_config``.

    Args:
        model_config: mapping with a "logger" key ("wandb", "tb", or None).
            A missing "project" key is defaulted to "my_project" (mutates
            the mapping in place, as before).

    Returns:
        A WandbLogger, TensorBoardLogger, or None.

    Raises:
        ValueError: if "logger" is an unrecognized choice.
    """
    logger_choice = model_config["logger"]
    # Default the project name when the config omits it.
    if "project" not in model_config:
        model_config["project"] = "my_project"

    if logger_choice == "wandb":
        logger = WandbLogger(
            project=model_config["project"],
            save_dir=model_config["artifact_library"],
            id=model_config["resume_id"],
        )
    elif logger_choice == "tb":
        logger = TensorBoardLogger(
            name=model_config["project"],
            save_dir=model_config["artifact_library"],
            version=model_config["resume_id"],
        )
    elif logger_choice is None:  # was `== None` (PEP 8: compare to None with `is`)
        logger = None
    else:
        # Previously an unknown choice fell through and raised NameError at
        # the return statement; fail loudly with a clear message instead.
        raise ValueError(f"Unknown logger choice: {logger_choice!r}")

    logging.info("Logger retrieved")
    return logger
コード例 #12
0
ファイル: train.py プロジェクト: oleges1/kws
def train(config):
    """Train the keyword-spotting attention model described by ``config``."""
    fix_seeds(seed=config.train.seed)

    get = config.model.get  # shorthand for model-section lookups with defaults
    encoder = CRNNEncoder(
        in_channels=get('in_channels', 42),
        hidden_size=get('hidden_size', 16),
        dropout=get('dropout', 0.1),
        cnn_layers=get('cnn_layers', 2),
        rnn_layers=get('rnn_layers', 2),
        kernel_size=get('kernel_size', 9),
    )
    net = AttentionNet(
        encoder,
        hidden_size=get('hidden_size', 16),
        num_classes=get('num_classes', 3),
    )
    pl_model = KWSModel(
        net,
        lr=config.train.get('lr', 4e-5),
        in_channels=get('in_channels', 42),
        batch_size=config.train.get('batch_size', 32),
    )

    logger = WandbLogger(
        name=config.train.get('experiment_name', 'final_run'),
        project='kws-attention',
        log_model=True,
    )
    logger.log_hyperparams(config)
    logger.watch(net, log='all', log_freq=100)

    trainer = pl.Trainer(
        max_epochs=config.train.get('max_epochs', 15),
        logger=logger,
        gpus=config.train.get('gpus', 1),
    )
    trainer.fit(pl_model)
コード例 #13
0
def train(hparams):
    """Train and test a LATTE link predictor on the configured dataset."""
    num_gpus = hparams.num_gpus
    use_amp = False  # True if num_gpus > 1 else False

    dataset = load_link_dataset(hparams.dataset, hparams=hparams)
    hparams.n_classes = dataset.n_classes

    model = LATTELinkPredictor(hparams,
                               dataset,
                               collate_fn="triples_batch",
                               metrics=[hparams.dataset])
    logger = WandbLogger(name=model.name(),
                         tags=[dataset.name()],
                         project="multiplex-comparison")

    stopper = EarlyStopping(monitor='val_loss',
                            patience=10,
                            min_delta=0.01,
                            strict=False)
    trainer = Trainer(
        gpus=num_gpus,
        distributed_backend='ddp' if num_gpus > 1 else None,
        auto_lr_find=False,
        max_epochs=50,
        early_stop_callback=stopper,
        logger=logger,
        # regularizers=regularizers,
        weights_summary='top',
        amp_level='O1' if use_amp else None,
        precision=16 if use_amp else 32,
    )

    trainer.fit(model)
    trainer.test(model)
コード例 #14
0
def main(cfg: DictConfig):
    """Instantiate the datamodule, task, and trainer from a structured
    config, then run training with a W&B logger."""
    datamodule = instantiate(cfg.data)
    task = instantiate(cfg.task)
    wandb_logger = WandbLogger(**cfg.logger)
    # logger = CSVLogger(save_dir='logs')
    trainer = Trainer(**cfg.trainer, logger=wandb_logger)
    trainer.fit(model=task, datamodule=datamodule)
コード例 #15
0
ファイル: train.py プロジェクト: borisdayma/lightning-kitti
def main(config):
    """Train SegModel on the KITTI datamodule with W&B logging."""
    # 1. Lightning model and data pipeline.
    model = SegModel(config)
    kitti_data = KittiDataModule(config)

    # 2. W&B logger; optionally track the inner network's topology.
    wandb_logger = WandbLogger()
    wandb_logger.watch(model.net)

    # 3. Trainer on all available GPUs.
    trainer = pl.Trainer(
        gpus=-1,
        logger=wandb_logger,
        max_epochs=config.epochs,
        accumulate_grad_batches=config.grad_batches,
    )

    # 4. Train.
    trainer.fit(model, kitti_data)
コード例 #16
0
def main(hparams: Namespace):
    """Train SegModel with trainer options taken from the argparse namespace.

    W&B logging is enabled via ``hparams.log_wandb``.
    """
    model = SegModel(**vars(hparams))

    # Logger is disabled (False) unless W&B logging was requested.
    logger = False
    if hparams.log_wandb:
        logger = WandbLogger()
        logger.watch(model.net)  # optional: log model topology

    # Fixed: the logger was previously constructed but never handed to the
    # Trainer, so nothing was ever logged to W&B.
    trainer = pl.Trainer.from_argparse_args(hparams, logger=logger)

    trainer.fit(model)
コード例 #17
0
def test_wandb_pickle(wandb, tmpdir):
    """Pickling a Trainer holding a WandbLogger must survive a round-trip.

    ``wandb`` is mocked out because it does not cooperate with pytest.
    """
    class Experiment:
        """Minimal stand-in for a wandb run object."""
        id = 'the_id'

        def project_name(self):
            return 'the_project_name'

    wandb.init.return_value = Experiment()
    logger = WandbLogger(id='the_id', offline=True)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
    )
    # Access the experiment to ensure it's created before pickling.
    assert trainer.logger.experiment, 'missing experiment'

    trainer2 = pickle.loads(pickle.dumps(trainer))

    # offline=True must have switched wandb into dry-run mode.
    assert os.environ['WANDB_MODE'] == 'dryrun'
    assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
    assert trainer2.logger.experiment, 'missing experiment'

    wandb.init.assert_called()
    init_kwargs = wandb.init.call_args[1]
    assert 'id' in init_kwargs
    assert init_kwargs['id'] == 'the_id'

    del os.environ['WANDB_MODE']
コード例 #18
0
def test_wandb_logger_dirs_creation(wandb, tmpdir):
    """The logger must create its folders and files in the right place."""
    logger = WandbLogger(save_dir=str(tmpdir), offline=True)
    # Before the experiment is touched, nothing is known about the run.
    assert logger.version is None
    assert logger.name is None

    # mock return values of experiment
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'

    # Accessing .experiment repeatedly must be idempotent.
    _ = logger.experiment
    _ = logger.experiment

    assert logger.version == '1'
    assert logger.name == 'project'
    assert logger.save_dir == str(tmpdir)
    assert not os.listdir(tmpdir)  # nothing written yet

    version = logger.version
    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir,
                      logger=logger,
                      max_epochs=1,
                      limit_val_batches=3)
    trainer.fit(model)

    expected_ckpt_dir = str(tmpdir / 'project' / version / 'checkpoints')
    assert trainer.checkpoint_callback.dirpath == expected_ckpt_dir
    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
コード例 #19
0
def objective(trial):
    """Optuna objective: train a segmentation model with a trial-sampled
    loss function and return its validation IoU."""
    pl.seed_everything(42, workers=True)

    resize = {'Resize': {'width': 224, 'height': 224}}
    datamodule = DataModule(batch_size=64,
                            num_workers=24,
                            pin_memory=True,
                            train_trans=resize,
                            val_trans=resize,
                            shuffle_train=False)

    loss = trial.suggest_categorical(
        "loss", ["bce", "dice", "jaccard", "focal", "log_cosh_dice"])
    model = SMP({
        'optimizer': 'Adam',
        'lr': 0.0003,
        'loss': loss,
        'model': 'Unet',
        'backbone': 'resnet18',
        'pretrained': 'imagenet',
    })

    logger = WandbLogger(project="MnMs2-opt", name=loss)
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        logger=logger,
        max_epochs=10,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_iou")],
        checkpoint_callback=False,
        limit_train_batches=1.,
        limit_val_batches=1.,
        deterministic=True,
    )
    trainer.fit(model, datamodule)
    score = trainer.test(model, datamodule.val_dataloader())
    # Close the W&B run so the next trial starts a fresh one.
    logger.experiment.finish()
    return score[0]['test_iou']
コード例 #20
0
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    logger = WandbLogger(anonymous=True, offline=True)

    # log_metrics without an explicit step forwards step=None.
    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

    # continue training on same W&B run
    # After finalize, new steps are offset by the run's last step (3 + 3 = 6).
    wandb.init().step = 3
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

    # Hyperparams are flattened (nested dict -> 'nested/a') and None is
    # stringified before being sent to wandb's config.
    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
        allow_val_change=True,
    )

    # watch() is forwarded verbatim as (model, log=..., log_freq=...).
    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    # name/version mirror the underlying run's project name and id.
    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
コード例 #21
0
def main():
    """Smoke-test ddp_cpu (2 processes) with a BoringModel and W&B logging."""
    # Opt into the wandb "service" concurrency experiment.
    wandb.require(experiment="service")
    print("PIDPID", os.getpid())

    # Three identical random-feature loaders (32-dim samples).
    num_samples = 100000
    train = DataLoader(RandomDataset(32, num_samples), batch_size=32)
    val = DataLoader(RandomDataset(32, num_samples), batch_size=32)
    test = DataLoader(RandomDataset(32, num_samples), batch_size=32)

    model = BoringModel()

    # Hyperparams logged to W&B before DDP workers start.
    wandb_logger = WandbLogger(
        log_model=True,
        config=dict(some_hparam="Logged Before Trainer starts DDP"),
        save_code=True,
    )

    trainer = pl.Trainer(
        max_epochs=1,
        progress_bar_refresh_rate=20,
        num_processes=2,
        accelerator="ddp_cpu",
        logger=wandb_logger,
    )

    trainer.fit(model, train, val)
    trainer.test(dataloaders=test)
コード例 #22
0
def test_multi_gpu_wandb_ddp_spawn(tmpdir):
    """Make sure DP/DDP + AMP work."""
    from pytorch_lightning.loggers import WandbLogger
    tutils.set_random_master_port()

    model = EvalModelTemplate()

    # Pretend a wandb run already exists.
    wandb.run = MagicMock()
    wandb.init(name='name', project='project')

    logger = WandbLogger(name='name', offline=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=2,
        distributed_backend='ddp_spawn',
        precision=16,
        logger=logger,
    )
    # tutils.run_model_test(trainer_options, model)
    result = trainer.fit(model)
    assert result
    trainer.test(model)
コード例 #23
0
def test_wandb_pickle(tmpdir):
    """Verify that pickling trainer with wandb logger works."""
    tutils.reset_seed()

    logger = WandbLogger(save_dir=str(tmpdir), anonymous=True)
    assert logger is not None
コード例 #24
0
ファイル: vqvae_sweep.py プロジェクト: imatge-upc/PiCoEDL
def main():
    """Run one VQ-VAE sweep trial: load config, merge sweep overrides, train."""
    import sys
    import wandb

    from config import setSeed, getConfig
    from main.vqvae import VQVAE
    from pytorch_lightning.loggers import WandbLogger
    import pytorch_lightning as pl

    from IPython import embed

    run = wandb.init()
    conf = getConfig(sys.argv[1])
    # Overlay the sweep's parameter choices onto the base config.
    conf = update_custom(conf, run.config)

    # NOTE(review): `alg` is not defined in this function — presumably a
    # module-level global; confirm before reusing this code.
    wandb_logger = WandbLogger(project='mineRL',
                               name=conf['experiment'],
                               tags=[alg, 'sweep'])
    wandb_logger.log_hyperparams(conf)

    vqvae = VQVAE(conf)

    trainer = pl.Trainer(gpus=1,
                         max_epochs=conf['epochs'],
                         progress_bar_refresh_rate=20,
                         weights_summary='full',
                         logger=wandb_logger,
                         default_root_dir=f"./results/{conf['experiment']}")
    trainer.fit(vqvae)
コード例 #25
0
ファイル: finetune.py プロジェクト: azraar/nlp_summarization
def main(args, model=None) -> SummarizationModule:
    """Fine-tune (and optionally evaluate) a summarization model.

    Builds the model when one is not supplied, picks a logger based on
    ``args.logger``, trains via ``generic_train``, and — when
    ``args.do_predict`` is set — reloads the newest checkpoint and runs test.
    """
    Path(args.output_dir).mkdir(exist_ok=True)
    # Refuse to train into a directory that already holds artifacts.
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if model is None:
        model: BaseTransformer = SummarizationModule(args)
    # Logger selection: default/dev/tmp runs use the built-in logger (True),
    # otherwise a W&B logger in a private or shared project.
    if (args.logger == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
        logger = WandbLogger(name=model.output_dir.name,
                             project="hf_summarization")
    # NOTE(review): any other args.logger value leaves `logger` unbound and
    # raises NameError below — confirm argparse restricts the choices.
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    # Evaluate from the most recent checkpoint, if any exist.
    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(
            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                      recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(
        model
    )  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model
コード例 #26
0
def main(
    checkpoint: str,
    test: bool = False,
    overfit: float = 0,
    max_epochs: int = 1000,
):
    """Fine-tune the ELECTRA discriminator for captioning on COCO.

    Args:
        checkpoint: path to a VisualElectra checkpoint to start from.
        test: when True, skip W&B logging; combined with ``overfit == 0``
            this also enables a fast dev run.
        overfit: fraction/number of batches to overfit on (0 disables).
        max_epochs: maximum number of training epochs.
    """
    config: VisualElectraConfig = VisualElectraConfig()
    # Base BERT model
    config.tokenizer = AutoTokenizer.from_pretrained(
        "google/bert_uncased_L-4_H-512_A-8"
        # "bert-base-uncased"
    )
    gen_model_name = "google/bert_uncased_L-2_H-512_A-8"
    disc_model_name = "google/bert_uncased_L-8_H-768_A-12"
    config.hidden_size = 512

    gen_conf = AutoConfig.from_pretrained(gen_model_name)
    config.generator_model = AutoModelForMaskedLM.from_config(gen_conf)
    config.generator_hidden_size = 512

    disc_conf = AutoConfig.from_pretrained(disc_model_name)
    disc_conf.is_decoder = True
    config.discriminator_model = AutoModel.from_config(disc_conf)
    config.discriminator_hidden_size = 768

    # Load the pretrained pair, keep only the discriminator, and retarget it
    # to captioning with a language-model head.
    full_model = VisualElectra.load_from_checkpoint(checkpoint, config=config)
    model = full_model.discriminator
    model.training_objective = TrainingObjective.Captioning
    model.add_lm_head()

    data = CocoCaptions()
    data.prepare_data()
    data.setup()

    logger = None

    # Fixed: was `test & (overfit == 0)` — bitwise & on booleans; use the
    # logical operator (identical result for bool inputs, clearer intent).
    fast_dev_run = test and overfit == 0

    if test is not True:
        logger = WandbLogger(project="final-year-project",
                             offline=False,
                             log_model=True,
                             save_dir=work_dir,
                             config={'checkpoint': checkpoint},
                             tags=['electra-finetune'])

    callbacks = [CheckpointEveryNSteps(50000)]

    trainer = pl.Trainer(gpus=1,
                         fast_dev_run=fast_dev_run,
                         default_root_dir=work_dir,
                         log_every_n_steps=10,
                         logger=logger,
                         max_epochs=max_epochs,
                         overfit_batches=overfit,
                         callbacks=callbacks
                         # check_val_every_n_epoch=1000 if overfit > 0 else 1,
                         )
    trainer.fit(model, data)
コード例 #27
0
def train_classifier(logging=False, train=True):
    """Build (and optionally train) the LAVA video classifier.

    Returns the untrained model when ``train`` is False.
    """
    hparams = {
        'gpus': [1],
        'max_epochs': 25,
        'num_classes': 700,
        'feature_dimension': 512,
        'model_dimension': 1024,
        'pretrained_text': False,
        'num_modalities': 1,
        'batch_size': 32,
        'learning_rate': 1e-3,
        'model_path':
        "/home/sgurram/Projects/aai/aai/experimental/sgurram/lava/src/wandb/run-20210626_215155-yqwe58z7/files/lava/yqwe58z7/checkpoints/epoch=6-step=12529.ckpt",
        'model_descriptor': 'lava timesformer 1/3 kinetics data, unshuffled',
        'accumulate_grad_batches': 2,
        'overfit_batches': 0,
        'type_modalities': 'av',
        'modality_fusion': 'concat',
        'loss_funtions': ['cross_entropy'],
        'metrics': None,
        'optimizer': 'adam',
        'scheduler': 'n/a',
        'profiler': 'simple',
        'default_root_dir': '/home/sgurram/Desktop/video_lava_classifer',
    }

    classifier = EvalLightning(
        num_classes=hparams['num_classes'],
        feature_dimension=hparams['feature_dimension'],
        model_dimension=hparams['model_dimension'],
        num_modalities=hparams['num_modalities'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        model_path=hparams['model_path'],
        model=LAVALightning,
        pretrained_text=hparams['pretrained_text'],
    )

    # Optional W&B logging of hyperparams and gradients.
    wandb_logger = None
    if logging:
        wandb_logger = WandbLogger(name='run', project='lava')
        wandb_logger.log_hyperparams(hparams)
        wandb_logger.watch(classifier, log='gradients', log_freq=10)

    if not train:
        return classifier

    trainer = pl.Trainer(
        default_root_dir=hparams['default_root_dir'],
        gpus=hparams['gpus'],
        max_epochs=hparams['max_epochs'],
        accumulate_grad_batches=hparams['accumulate_grad_batches'],
        overfit_batches=hparams['overfit_batches'],
        logger=wandb_logger,
        profiler=hparams['profiler'])

    trainer.fit(classifier)
コード例 #28
0
def train(dataset_name: str,
          model_name: str,
          expt_dir: str,
          data_folder: str,
          num_workers: int = 0,
          is_test: bool = False,
          resume_from_checkpoint: str = None):
    """Train a code2seq model on the given dataset and save checkpoints.

    Args:
        dataset_name: dataset name, used in the W&B project name.
        model_name: currently only "code2seq" is supported.
        expt_dir: directory for checkpoints (including 'Latest.ckpt').
        data_folder: folder holding the dataset and 'vocabulary.pkl'.
        num_workers: dataloader workers passed to the model.
        is_test: use the smaller test config when True.
        resume_from_checkpoint: optional checkpoint path to resume from.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    seed_everything(SEED)
    dataset_main_folder = data_folder
    vocab = Vocabulary.load(join(dataset_main_folder, "vocabulary.pkl"))

    if model_name == "code2seq":
        config_function = get_code2seq_test_config if is_test else get_code2seq_default_config
        config = config_function(dataset_main_folder)
        model = Code2Seq(config, vocab, num_workers)
        # NOTE(review): casts all weights to fp16 outside of AMP — confirm intended.
        model.half()
    #elif model_name == "code2class":
    #	config_function = get_code2class_test_config if is_test else get_code2class_default_config
    #	config = config_function(dataset_main_folder)
    #	model = Code2Class(config, vocab, num_workers)
    else:
        raise ValueError(f"Model {model_name} is not supported")

    # define logger
    wandb_logger = WandbLogger(project=f"{model_name}-{dataset_name}",
                               log_model=True,
                               offline=True)
    wandb_logger.watch(model)
    # define model checkpoint callback
    model_checkpoint_callback = ModelCheckpoint(
        filepath=join(expt_dir, "{epoch:02d}-{val_loss:.4f}"),
        period=config.hyperparams.save_every_epoch,
        save_top_k=3,
    )
    # define early stopping callback
    early_stopping_callback = EarlyStopping(
        patience=config.hyperparams.patience, verbose=True, mode="min")
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateLogger()
    trainer = Trainer(
        max_epochs=20,
        gradient_clip_val=config.hyperparams.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.hyperparams.val_every_epoch,
        row_log_interval=config.hyperparams.log_every_epoch,
        logger=wandb_logger,
        checkpoint_callback=model_checkpoint_callback,
        early_stop_callback=early_stopping_callback,
        resume_from_checkpoint=resume_from_checkpoint,
        gpus=gpu,
        callbacks=[lr_logger],
        reload_dataloaders_every_epoch=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(join(expt_dir, 'Latest.ckpt'))

    trainer.test()
コード例 #29
0
def main(args, model_name: str, reproducible: bool, comet: bool, wandb: bool):
    """Train and test the right-lane segmentation model.

    Optionally enables reproducible seeding plus Comet and/or W&B logging.
    Note: if both loggers are requested, the W&B assignment to ``args.logger``
    wins because it runs last.
    """
    if reproducible:
        seed_everything(42)
        args.deterministic = True
        args.benchmark = True

    if comet:
        from pytorch_lightning.loggers import CometLogger
        comet_logger = CometLogger(
            api_key=os.environ.get('COMET_API_KEY'),
            workspace=os.environ.get('COMET_WORKSPACE'),  # Optional
            project_name=os.environ.get('COMET_PROJECT_NAME'),  # Optional
            experiment_name=model_name  # Optional
        )
        args.logger = comet_logger
    if wandb:
        from pytorch_lightning.loggers import WandbLogger
        wandb_logger = WandbLogger(
            project=os.environ.get('WANDB_PROJECT_NAME'),
            log_model=True,
            sync_step=True)
        args.logger = wandb_logger

    if args.default_root_dir is None:
        args.default_root_dir = 'results'

    # Save best model
    model_checkpoint = ModelCheckpoint(
        filename=model_name + '_{epoch}',
        save_top_k=1,
        monitor='val_iou',
        mode='max',
    )
    args.checkpoint_callback = model_checkpoint

    data = SimulatorDataModule(dataPath=args.dataPath,
                               augment=args.augment,
                               batch_size=args.batch_size,
                               num_workers=8)
    model = RightLaneModule(lr=args.learningRate,
                            lrRatio=args.lrRatio,
                            decay=args.decay,
                            num_cls=4)

    # Parse all trainer options available from the command line
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=data)

    # Reload best model
    model = RightLaneModule.load_from_checkpoint(
        model_checkpoint.best_model_path, dataPath=args.dataPath, num_cls=4)

    # Upload weights
    if comet:
        comet_logger.experiment.log_model(model_name + '_weights',
                                          model_checkpoint.best_model_path)

    # Perform testing
    trainer.test(model, datamodule=data)
コード例 #30
0
def get_logger(model_config):
    """Return a WandbLogger configured from ``model_config``.

    Reads the 'project', 'wandb_save_dir', and 'resume_id' keys.
    """
    logger = WandbLogger(
        project=model_config["project"],
        save_dir=model_config["wandb_save_dir"],
        id=model_config["resume_id"],
    )
    logging.info("Logger retrieved")
    return logger