Example 1
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """Test parsing complex types."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if instantiate:
        assert Trainer.from_argparse_args(args)
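Every example below follows the same two-step pattern: register the Trainer's options on an argument parser, then build the Trainer directly from the parsed arguments. Here is a minimal, self-contained sketch of that pattern, assuming a pre-2.0 pytorch_lightning release (these helpers were deprecated and later removed); the entry-point name is hypothetical:

from argparse import ArgumentParser

from pytorch_lightning import Trainer


def cli_main():
    parser = ArgumentParser()
    # Inject every Trainer option (--gpus, --max_epochs, ...) as a CLI flag.
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    # Build the Trainer straight from the parsed Namespace; explicit keyword
    # arguments passed here override the corresponding CLI values.
    return Trainer.from_argparse_args(args)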
Example 2
def main():
    """main"""
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--bert_path",
                        required=True,
                        type=str,
                        help="bert config file")
    parser.add_argument("--mode",
                        default="bert",
                        type=str,
                        help="bert config file")
    parser.add_argument("--batch_size",
                        type=int,
                        default=20,
                        help="batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="learning rate")
    parser.add_argument("--workers",
                        type=int,
                        default=0,
                        help="num workers for dataloader")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="warmup steps")
    parser.add_argument("--use_memory",
                        action="store_true",
                        help="load dataset to memory to accelerate.")
    parser.add_argument("--max_length",
                        default=512,
                        type=int,
                        help="max length of dataset")
    parser.add_argument("--data_dir",
                        required=True,
                        type=str,
                        help="train data path")
    parser.add_argument("--save_topk",
                        default=0,
                        type=int,
                        help="save topk checkpoint")
    parser = BertPPL.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    model = BertPPL(args)
    trainer = Trainer.from_argparse_args(args, distributed_backend="ddp")
    trainer.test(model)
Example 3
def test_argparse_args_parsing(cli_args, expected):
    """Test multi type argument with bool."""
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)

    for k, v in expected.items():
        assert getattr(args, k) == v
    assert Trainer.from_argparse_args(args)
Example 4
def inference():
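    """Restore RobertaClassificationModel from a checkpoint and test under DDP."""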
    parser = add_model_specific_args()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    model = RobertaClassificationModel(args)
    checkpoint = torch.load(args.checkpoint_path,
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    trainer = Trainer.from_argparse_args(args, distributed_backend="ddp")

    trainer.test(model)
Example 5
def test_init_from_argparse_args(cli_args, extra_args):
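    """Unknown keys in the Namespace are ignored; unknown keyword args raise TypeError."""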
    unknown_args = dict(unknown_arg=0)

    # unknown args in the argparser/namespace should be ignored
    with mock.patch("pytorch_lightning.Trainer.__init__",
                    autospec=True,
                    return_value=None) as init:
        trainer = Trainer.from_argparse_args(
            Namespace(**cli_args, **unknown_args), **extra_args)
        expected = dict(cli_args)
        expected.update(extra_args)  # extra args should override any cli arg
        init.assert_called_with(trainer, **expected)

    # passing in unknown manual args should throw an error
    with pytest.raises(
            TypeError,
            match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args,
                                   **unknown_args)
Example 6
def test_tpu_cores_with_argparse(cli_args, expected):
    """Test passing tpu_cores in command line"""
    cli_args = cli_args.split(' ') if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)

    for k, v in expected.items():
        assert getattr(args, k) == v
    assert Trainer.from_argparse_args(args)
Example 7
def run_noisy_clip():
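    """Train NoisyCLIP; reload dataloaders every epoch when args.increasing is set."""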
    args = grab_config()

    seed_everything(args.seed)

    dataset = ImageNetCLIPDataset(args)
    dataset.setup()
    model = NoisyCLIP(args)

    logger = TensorBoardLogger(
        save_dir=args.logdir,
        version=args.experiment_name,
        name='NoisyCLIP_Logs'
    )
    if not hasattr(args, 'increasing') or not args.increasing:
        trainer = Trainer.from_argparse_args(
            args,
            logger=logger,
            callbacks=[ModelCheckpoint(save_top_k=-1, period=25)])
        trainer.fit(model, datamodule=dataset)
    else:
        trainer = Trainer.from_argparse_args(
            args,
            logger=logger,
            reload_dataloaders_every_epoch=True,
            callbacks=[ModelCheckpoint(save_top_k=-1, period=25)])
        trainer.fit(model)
Example 8
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """Test parsing of gpus and instantiation of Trainer."""
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    cli_args = cli_args.split(" ") if cli_args else []
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        args = parser.parse_args()

    trainer = Trainer.from_argparse_args(args)
    assert trainer.data_parallel_device_ids == expected_gpu
Example 9
def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg,
                                               expected_profiler):
    cli_args = cli_args.split(' ')
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)

    assert getattr(args, "profiler") == expected_parsed_arg
    trainer = Trainer.from_argparse_args(args)
    assert isinstance(trainer.profiler, expected_profiler)
Example 10
def main(args):
    model = Model(hparams=args,
                  data_path=os.path.join(
                      PARENT_DIR, 'datasets',
                      'pendulum-gym-image-dataset-train.pkl'))
    checkpoint_callback = ModelCheckpoint(monitor='loss',
                                          prefix=args.name + f'-T_p={args.T_pred}-',
                                          save_top_k=1,
                                          save_last=True)
    trainer = Trainer.from_argparse_args(
        args,
        deterministic=True,
        default_root_dir=os.path.join(PARENT_DIR, 'logs', args.name),
        checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
Example 11
def test_parse_args_parsing(cli_args, expected):
    """Test parsing simple types and None optionals not modified."""
    cli_args = cli_args.split(" ") if cli_args else []
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    assert Trainer.from_argparse_args(args)
Example 12
def main():
    """main"""
    parser = get_parser()

    # add model specific args
    parser = BertLabeling.add_model_specific_args(parser)

    # add all the available trainer options to argparse,
    # i.e. --gpus, --num_nodes, ..., --fast_dev_run now all work in the CLI
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    if not args.test_only:
        model = BertLabeling(args)
        if args.pretrained_checkpoint:
            model.load_state_dict(
                torch.load(args.pretrained_checkpoint,
                           map_location=torch.device('cpu'))["state_dict"])

        checkpoint_callback = ModelCheckpoint(
            filepath=args.default_root_dir,
            save_top_k=3,
            verbose=True,
            monitor="span_f1",
            period=-1,
            mode="max",
        )
        trainer = Trainer.from_argparse_args(
            args, checkpoint_callback=checkpoint_callback)

        trainer.fit(model)
        trainer.test()
    else:
        assert args.test_checkpoint_path, 'test_checkpoint_path is required in test_mode'
        model = BertLabeling.load_from_checkpoint(
            checkpoint_path=args.test_checkpoint_path,
            on_gpu=True,
        )
        trainer = Trainer.from_argparse_args(args)
        trainer.test(model)
Example 13
def main(args):
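    """Build the config, model, and data module, then train."""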
    config = create_config()
    model = create_model(args.mode)
    dm = create_data_module(config)

    print(f'Training for {args.max_epochs} epochs')

    # Output training parameters
    config.display()

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, dm)
Example 14
def main(args):
    logger = pl_loggers.WandbLogger(experiment="example", save_dir=None)
    early_stop = EarlyStopping(monitor="val_loss")
    checkpoint_callback = ModelCheckpoint(dirpath="ckpts/", monitor="val_loss")
    model = ExampleModel(args)
    lr_logger = LearningRateLogger()
    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[early_stop, lr_logger],
        checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
Example 15
def main(hparams):
    if hparams.checkpoint_path is None:
        model = vnet.VNet(**vars(hparams))
    else:
        # If any arguments were explicitly given, then force those
        seen_params = {a: getattr(hparams, a) for a in hparams.seen_args_}
        checkpoint_path = seen_params.pop('checkpoint_path')
        model = vnet.VNet.load_from_checkpoint(checkpoint_path, **seen_params)

    trainer = Trainer.from_argparse_args(hparams, auto_lr_find=True)

    trainer.tune(model)
Example 16
def main(hparams):

    # Prepare folders
    if not os.path.isdir(hparams.generated_images_folder):
        os.mkdir(hparams.generated_images_folder)

    data = CelebaDataModule(hparams)
    model = VaeGanModule(hparams)
    trainer = Trainer.from_argparse_args(hparams)

    # train
    trainer.fit(model, data)
Example 17
def test_varnet_trainer(backend, skip_module_test):
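    """Smoke-test VarNetModule training with fast_dev_run on the given backend."""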
    if skip_module_test:
        pytest.skip("config set to skip")

    args = build_varnet_args()
    args.fast_dev_run = True
    args.backend = backend

    model = VarNetModule(**vars(args))
    trainer = Trainer.from_argparse_args(args)

    trainer.fit(model)
Example 18
def main(args):
    logger = pl_loggers.WandbLogger(experiment=None, save_dir=None)
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"ckpts/{args.reduction}_reduction/", monitor="val_loss")
    model = AutoencoderModel(args)
    lr_logger = LearningRateMonitor()
    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[lr_logger],
        checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
Example 19
def cli_main() -> None:
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = Trainer.add_argparse_args(parent_parser)

    parser = PPO.add_model_specific_args(parent_parser)
    args = parser.parse_args()

    model = PPO(**vars(args))

    seed_everything(0)
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)
Example 20
def inference():
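    """Evaluate the best validation checkpoint on the test set."""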
    parser = add_model_specific_args()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    model = ConvolutionClassificationModel(args)
    trainer = Trainer.from_argparse_args(args)
    # pick the best validation checkpoint for testing
    checkpoint_path = os.path.join(args.save_path, "checkpoints")
    best_checkpoint_path = find_best_checkpoint(checkpoint_path)
    checkpoint = torch.load(best_checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    trainer.test(model)
Example 21
def evaluate():
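    """Restore ChnSentiClassificationTask from a checkpoint and test under DDP."""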
    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    model = ChnSentiClassificationTask(args)
    checkpoint = torch.load(args.checkpoint_path,
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    trainer = Trainer.from_argparse_args(args, distributed_backend="ddp")

    trainer.test(model)
Example 22
def cli_main(args=None):
    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--dataset",
                        default="mnist",
                        type=str,
                        choices=["lsun", "mnist"])
    parser.add_argument("--data_dir", default="./", type=str)
    parser.add_argument("--image_size", default=64, type=int)
    parser.add_argument("--num_workers", default=8, type=int)

    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "lsun":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.CenterCrop(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = LSUN(root=script_args.data_dir,
                       classes=["bedroom_train"],
                       transform=transforms)
        image_channels = 3
    elif script_args.dataset == "mnist":
        transforms = transform_lib.Compose([
            transform_lib.Resize(script_args.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize((0.5, ), (0.5, )),
        ])
        dataset = MNIST(root=script_args.data_dir,
                        download=True,
                        transform=transforms)
        image_channels = 1

    dataloader = DataLoader(dataset,
                            batch_size=script_args.batch_size,
                            shuffle=True,
                            num_workers=script_args.num_workers)

    parser = DCGAN.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    model = DCGAN(**vars(args), image_channels=image_channels)
    callbacks = [
        TensorboardGenerativeModelImageSampler(num_samples=5),
        LatentDimInterpolator(interpolate_epoch_interval=5),
    ]
    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, dataloader)
Example 23
def test_argparse_args_parsing_gpus(cli_args, expected_parsed,
                                    expected_device_ids):
    """Test multi type argument with bool."""
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)

    assert args.gpus == expected_parsed
    trainer = Trainer.from_argparse_args(args)
    assert trainer.data_parallel_device_ids == expected_device_ids
Example 24
def main(args):
    # logger = WandbLogger(project=args.project_name, save_dir=None, log_model=True)
    # model_checkpointer = ModelCheckpoint(dirpath=logger.save_dir, monitor=args.monitor, save_weights_only=True)
    trainer = Trainer.from_argparse_args(
        args,
        # logger=logger,
        # callbacks=[model_checkpointer],
        plugins=DDPPlugin(find_unused_parameters=True)
    )
    celeba = CelebaDataModule.from_argparse_args(args)
    model = VanillaStarGAN.from_argparse_args(
        args, image_shape=celeba.image_shape, label_names=celeba.attributes)
    trainer.fit(model, datamodule=celeba)
Example 25
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer"""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    args = parser.parse_args([])

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
Example 26
def main(args):
    # Sanity checks
    assert args.classification_task or args.regression_task

    seed_everything(42)

    # Build datasets
    train_ds, val_ds, test_ds = build_datasets(args, LABEL_COLUMNS)
    print("Size of train/val/test:",
          len(train_ds),
          len(val_ds),
          len(test_ds),
          end="\n\n")

    # Build dataloaders
    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.dataloader_workers,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_workers,
    )
    test_dl = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_workers,
    )

    # Weights & Biases logging
    if args.wandb_logging:
        wandb_logger = WandbLogger(name=args.wandb_name, project="mining")

    # Instantiate model, train and test
    dict_args = vars(args)
    model = MultiTaskLearner(
        classifier_loss_weights=train_ds.classifier_weights, **dict_args)
    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=f"{args.root_dir}/{args.wandb_name}",
        early_stop_callback=False,
        min_epochs=args.epochs,
        max_epochs=args.epochs,
        logger=wandb_logger if args.wandb_logging else None,
    )
    trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)
    trainer.test(test_dataloaders=[test_dl])
Example 27
def train_byol(model,
               loader,
               byol_parameters,
               log_training=True,
               logger_name='byol'):
    only_train_layers = [
        lambda trunk: trunk.blocks[-1],
        lambda trunk: trunk.conv_head,
        lambda trunk: trunk.bn2,
        lambda trunk: trunk.global_pool,
        lambda trunk: trunk.act2,
        lambda trunk: trunk.classifier,
    ]
    new_model = LeafDoctorModel(only_train_layers=only_train_layers)
    new_model.load_state_dict(model.state_dict())
    model = new_model

    hparams = Namespace(**byol_parameters)

    logger = TensorBoardLogger("lightning_logs",
                               name=logger_name) if log_training else None
    byol = BYOL(model.trunk, hparams=hparams)
    early_stopping = EarlyStopping('train_loss',
                                   mode='min',
                                   patience=hparams.early_stop_patience,
                                   verbose=True)
    callbacks = [early_stopping]
    if log_training:
        lr_monitor = LearningRateMonitor(logging_interval='step')
        callbacks.append(lr_monitor)
    trainer = Trainer.from_argparse_args(
        hparams,
        reload_dataloaders_every_epoch=True,
        terminate_on_nan=True,
        callbacks=callbacks,
        precision=hparams.precision,
        amp_level=hparams.amp_level,
        log_every_n_steps=hparams.log_every_n_steps,
        flush_logs_every_n_steps=hparams.flush_logs_every_n_steps,
        logger=logger,
    )

    if hparams.auto_lr_find:
        new_lr = lr_find(trainer, byol, loader)
        hparams.lr = new_lr
        byol.hparams.lr = new_lr

    trainer.fit(byol, loader, loader)

    pretrained_model = LeafDoctorModel(None)
    pretrained_model.trunk.load_state_dict(byol.encoder.model.state_dict())
    return pretrained_model
Example 28
def test_sfomm_texp_syn(): 
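    """Train SFOMM on synthetic data and check validation MSE against the reference value."""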
    seed_everything(0)

    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='sfomm', help='fomm, ssm, or gru')
    parser.add_argument('--lr', type=float, default=8e-3, help='learning rate')
    parser.add_argument('--anneal', type=float, default=-1., help='annealing rate')
    parser.add_argument('--fname', type=str, help='name of save file')
    parser.add_argument('--imp_sampling', type=bool, default=False, help='importance sampling to estimate marginal likelihood')
    parser.add_argument('--nsamples', default=1, type=int)
    parser.add_argument('--nsamples_syn', default=50, type=int, help='number of training samples for synthetic data')
    parser.add_argument('--optimizer_name', type=str, default='adam')
    parser.add_argument('--dataset', default='synthetic', type=str)
    parser.add_argument('--loss_type', type=str, default='unsup')
    parser.add_argument('--eval_type', type=str, default='nelbo')
    parser.add_argument('--bs', default=600, type=int, help='batch size')
    parser.add_argument('--fold', default=1, type=int)

    # THIS LINE IS KEY TO PULL THE MODEL NAME
    temp_args, _ = parser.parse_known_args()

    # add rest of args from SSM and base trainer
    parser = SFOMM.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    # parse args and convert to dict
    args = parser.parse_args()
    args.max_epochs = 5000
    args.mtype = 'treatment_exp'
    args.alpha1_type = 'linear'
    args.inftype = 'rnn'
    args.reg_type = 'l2'
    args.C = 0.
    args.dim_stochastic = 16
    args.reg_all = True
    args.add_stochastic = False
    dict_args = vars(args)

    # initialize SFOMM with args and train
    model = SFOMM(**dict_args)
    trainer = Trainer.from_argparse_args(args, deterministic=True, logger=False, checkpoint_callback=False, gpus=[2])
    trainer.fit(model)

    # evaluate on validation set; this should match what we were getting with the old codebase (after 100 epochs)
    if torch.cuda.is_available():
        device = torch.device('cuda:2')
    else:
        device = torch.device('cpu')
    _, valid_loader = model.load_helper('valid', device=device)
    preds, _ = model.predict(*valid_loader.dataset.tensors)
    mse, r2, ci = calc_stats(preds, valid_loader.dataset.tensors)
    assert abs(mse - 4.57) < 1e-1
Example 29
def train():
    # Parse args
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--job_name',
                        type=str,
                        help='Name of the job',
                        required=True)
    parser.add_argument('--batch_size',
                        type=int,
                        help='Batch size',
                        default=16)
    parser.add_argument('--num_workers',
                        type=int,
                        help='Num workers',
                        default=4)

    parser = MultiDepthDistilBertModel.add_model_specific_args(parser)

    args = parser.parse_args()

    # Load data and models
    data_path = 'data'
    toxic_data_module = ToxicDataModule(data_path, args.batch_size,
                                        args.num_workers)
    model = MultiDepthDistilBertModel(args=args)

    # Logger
    tb_logger = pl_loggers.TensorBoardLogger(save_dir='logs',
                                             name=args.job_name)

    # Callbacks
    callbacks = []

    # Save best model checkpoints callback
    checkpoint_callback = ModelCheckpoint(
        monitor='val_f1',
        dirpath=os.path.join('logs', args.job_name,
                             'version_' + str(tb_logger.version)),
        filename='{epoch:02d}-{val_f1:.2f}',
        save_top_k=3,
        mode='max',
        save_weights_only=True,
        save_last=False)

    callbacks.append(checkpoint_callback)

    # Train
    trainer = Trainer.from_argparse_args(args,
                                         logger=tb_logger,
                                         callbacks=callbacks)
    trainer.fit(model, datamodule=toxic_data_module)
Example 30
def cli_main():
    # args
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = LitMNIST.add_model_specific_args(parser)
    args = parser.parse_args()

    # model
    model = LitMNIST(**vars(args))

    # training
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)