Code Example #1
def trained_model():
    args = [
        "--num_users",
        str(100),
        "--num_items",
        str(100),
        "--max_epochs",
        str(100),
        "--num_workers",
        str(4),
    ]

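    # "th" is assumed to be the torch module ("import torch as th" in the source).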
    if th.cuda.is_available():
        args = args + ["--gpus", str(1)]

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser = ImplicitMatrixFactorization.add_model_specific_args(parser)
    parser = TestDataModule.add_dataset_specific_args(parser)
    parsed_args = parser.parse_args(args)
    model = ImplicitMatrixFactorization(parsed_args)

    test_dataset = TestDataModule(parsed_args.batch_size, parsed_args.num_workers)

    trainer = Trainer.from_argparse_args(parsed_args)
    trainer.fit(model, test_dataset)

    return model
Code Example #2
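# Presumably driven by pytest parametrization; a plausible (assumed, not from
# the source) decorator would be:
# @pytest.mark.parametrize(["arg", "expected"], [("--precision=16", 16)])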
def test_precision_parsed_correctly(arg, expected):
    """Test to ensure that the precision flag is passed correctly when adding argparse args."""
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    fake_argv = [arg]
    args = parser.parse_args(fake_argv)
    assert args.precision == expected
Code Example #3
def main():
    seed_everything(4321)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--trainer_method", default="fit")
    parser.add_argument("--tmpdir")
    parser.add_argument("--workdir")
    parser.set_defaults(gpus=2)
    parser.set_defaults(strategy="ddp")
    args = parser.parse_args()

    dm = ClassifDataModule()
    model = ClassificationModel()
    trainer = Trainer.from_argparse_args(args)

    if args.trainer_method == "fit":
        trainer.fit(model, datamodule=dm)
        result = None
    elif args.trainer_method == "test":
        result = trainer.test(model, datamodule=dm)
    elif args.trainer_method == "fit_test":
        trainer.fit(model, datamodule=dm)
        result = trainer.test(model, datamodule=dm)
    else:
        raise ValueError(f"Unsupported: {args.trainer_method}")

    result_ext = {
        "status": "complete",
        "method": args.trainer_method,
        "result": result
    }
    file_path = os.path.join(args.tmpdir, "ddp.result")
    torch.save(result_ext, file_path)
Code Example #4
def run_fit():
    import pprint
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = WellnessClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    args.gpus = 2
    args.max_epochs = 25
    args.train_batch_size = 8
    args.batch_size = 8

    pprint.pprint(vars(args))

    wellness_dm = WellnessDataModule(args)
    wellness_dm.prepare_data()

    print("num labels=", wellness_dm.num_labels)
    model = WellnessClassifier(args, wellness_dm.num_labels)

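    # NOTE: early_stop_callback is assumed to be defined at module scope in the
    # source project (e.g., an EarlyStopping callback instance).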
    trainer = Trainer.from_argparse_args(
        args, early_stop_callback=early_stop_callback)
    trainer.fit(model, wellness_dm)
Code Example #5
def main():
    seed_everything(1234)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--trainer_method', default='fit')
    parser.add_argument('--tmpdir')
    parser.add_argument('--workdir')
    parser.set_defaults(gpus=2)
    parser.set_defaults(accelerator="ddp")
    args = parser.parse_args()

    model = EvalModelTemplate()
    trainer = Trainer.from_argparse_args(args)

    result = {}
    if args.trainer_method == 'fit':
        trainer.fit(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
    if args.trainer_method == 'test':
        result = trainer.test(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
    if args.trainer_method == 'fit_test':
        trainer.fit(model)
        result = trainer.test(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}

    if len(result) > 0:
        file_path = os.path.join(args.tmpdir, 'ddp.result')
        torch.save(result, file_path)
Code Example #6
File: main.py Project: zzalshcv1/learnopencv
def main():
    parser = get_program_level_args()
    parser = LitFood101.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    seed_everything(args.seed)

    checkpoint_callback = ModelCheckpoint(monitor="avg_val_acc", mode="max")
    trainer = Trainer.from_argparse_args(
        args,
        deterministic=True,
        benchmark=False,
        checkpoint_callback=checkpoint_callback,
        precision=16 if args.amp_level != "O0" else 32,
    )

    # create model
    model = resnet18(pretrained=True)
    if args.use_knowledge_distillation:
        teacher_model = resnet50(pretrained=False)
        model = LitFood101KD(model, teacher_model, args)
    else:
        model = LitFood101(model, args)

    if args.evaluate:
        checkpoint = torch.load(args.checkpoint)
        model.load_state_dict(checkpoint["state_dict"])
        trainer.test(model, test_dataloaders=model.test_dataloader())
        return 0

    trainer.fit(model)

    trainer.test()
Code Example #7
File: train.py Project: mpeven/pytorch-slurm
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s',
                        '--sweep',
                        action='store_true',
                        help='Run a hyperparameter sweep over all options')

    # DataModule args
    parser = MNISTDataModule.add_argparse_args(parser)

    # Trainer args (https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags)
    parser = Trainer.add_argparse_args(parser)
    # Set some sane defaults
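    # (note: iterating parser._actions relies on argparse private internals;
    # parser.set_defaults(gpus=1, max_epochs=100) is the public alternative)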
    for x in parser._actions:
        if x.dest == 'gpus':
            x.default = 1
        if x.dest == 'max_epochs':
            x.default = 100

    # TestTube args - hyperparam parser & slurm info
    parser = HyperOptArgumentParser(strategy='grid_search',
                                    add_help=False,
                                    parents=[parser])
    parser.add_argument('--test_tube_exp_name', default='sweep_test')
    parser.add_argument('--log_path', default='./pytorch-slurm')

    # LightningModule args (hyperparameters)
    parser = MNISTClassifier.add_model_specific_args(parser)

    args = parser.parse_args()
    return args
Code Example #8
File: task_segmention.py Project: xyz9911/ltp
def main():
    parser = ArgumentParser()

    # add task level args
    parser = add_common_specific_args(parser)
    parser = add_tune_specific_args(parser)
    parser = add_task_specific_args(parser)

    # add model specific args
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    # set task specific args
    parser.set_defaults(gradient_clip_val=1.0, min_epochs=1, max_epochs=10)
    parser.set_defaults(num_labels=2)

    args = parser.parse_args()

    if args.build_dataset:
        build_distill_dataset(args)
    elif args.tune:
        tune_train(args, model_class=Model, task_info=task_info)
    else:
        common_train(args, model_class=Model, task_info=task_info)
Code Example #9
def cli_main() -> None:
    from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
    from pl_bolts.utils import _SKLEARN_AVAILABLE

    seed_everything(1234)

    # create dataset
    if _SKLEARN_AVAILABLE:
        from sklearn.datasets import load_diabetes
    else:  # pragma: no cover
        raise ModuleNotFoundError(
            'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
        )

    # args
    parser = ArgumentParser()
    parser = LinearRegression.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # model
    model = LinearRegression(input_dim=10, l1_strength=1, l2_strength=1)
    # model = LinearRegression(**vars(args))

    # data
    X, y = load_diabetes(return_X_y=True)  # these are numpy arrays
    loaders = SklearnDataModule(X, y, batch_size=args.batch_size)

    # train
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model,
                train_dataloader=loaders.train_dataloader(),
                val_dataloaders=loaders.val_dataloader())
Code Example #10
def cli_main(args=None):
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    seed_everything()

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", type=str, choices=["cifar10", "stl10", "imagenet"])
    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "cifar10":
        dm_cls = CIFAR10DataModule
    elif script_args.dataset == "stl10":
        dm_cls = STL10DataModule
    elif script_args.dataset == "imagenet":
        dm_cls = ImagenetDataModule
    else:
        raise ValueError(f"undefined dataset {script_args.dataset}")

    parser = VAE.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    args.input_height = dm.size()[-1]

    if args.max_steps == -1:
        args.max_steps = None

    model = VAE(**vars(args))

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=dm)
    return dm, model, trainer
Code Example #11
def main():
    """main"""
    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    model = BertClassificationTask(args)

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.save_path, 'checkpoint',
                              '{epoch}-{val_loss:.4f}-{val_acc:.4f}'),
        save_top_k=1,
        save_last=False,
        monitor="val_acc",
        mode="max",
    )
    logger = TensorBoardLogger(save_dir=args.save_path, name='log')

    # save args
    with open(os.path.join(args.save_path, 'checkpoint', "args.json"),
              'w') as f:
        args_dict = args.__dict__
        del args_dict['tpu_cores']
        json.dump(args_dict, f, indent=4)

    trainer = Trainer.from_argparse_args(
        args,
        checkpoint_callback=checkpoint_callback,
        distributed_backend="ddp",
        logger=logger)

    trainer.fit(model)
Code Example #12
def test_wandb_sanitize_callable_params(tmpdir):
    """
    Callback function are not serializiable. Therefore, we get them a chance to return
    something and if the returned type is not accepted, return None.
    """
    opt = "--max_epochs 1".split(" ")
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parent_parser=parser)
    params = parser.parse_args(opt)

    def return_something():
        return "something"

    params.something = return_something

    def wrapper_something():
        return return_something

    params.wrapper_something_wo_name = lambda: lambda: '1'
    params.wrapper_something = wrapper_something

    assert isinstance(params.gpus, types.FunctionType)
    params = WandbLogger._convert_params(params)
    params = WandbLogger._flatten_dict(params)
    params = WandbLogger._sanitize_callable_params(params)
    assert params["gpus"] == '_gpus_arg_default'
    assert params["something"] == "something"
    assert params["wrapper_something"] == "wrapper_something"
    assert params["wrapper_something_wo_name"] == "<lambda>"
Code Example #13
def cli_main(args=None):
    from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
    script_args, _ = parser.parse_known_args(args)

    if script_args.dataset == "mnist":
        dm_cls = MNISTDataModule
    elif script_args.dataset == "cifar10":
        dm_cls = CIFAR10DataModule
    elif script_args.dataset == "stl10":
        dm_cls = STL10DataModule
    elif script_args.dataset == "imagenet":
        dm_cls = ImagenetDataModule
    else:
        raise ValueError(f"undefined dataset {script_args.dataset}")

    parser = dm_cls.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    parser = GAN.add_model_specific_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    model = GAN(*dm.size(), **vars(args))
    callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
    trainer = Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
    trainer.fit(model, datamodule=dm)
    return dm, model, trainer
Code Example #14
File: fdkconvnet.py Project: phernst/sparse_dbp
def main():
    parser = ArgumentParser()
    parser = FDKConvNet.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    hparams = parser.parse_args()
    hparams.lr = 5e-2
    hparams.end_lr = 1e-2
    hparams.max_epochs = 300
    hparams.batch_size = CONFIGURATION['batch_size']
    hparams.data_dir = DATA_DIRS['datasets']
    hparams.valid_dir = 'valid_fdkconvnet' \
        + ('_pre' if hparams.pretrained else '')
    with open('train_valid.json') as json_file:
        json_dict = json.load(json_file)
        hparams.train_files = json_dict['train_files']
        hparams.valid_files = json_dict['valid_files']

    model = FDKConvNet(**vars(hparams))

    checkpoint_callback = ModelCheckpoint(
        dirpath=hparams.valid_dir,
        monitor='val_loss',
        save_last=True,
    )
    trainer = Trainer(
        precision=CONFIGURATION['precision'],
        progress_bar_refresh_rate=CONFIGURATION['progress_bar_refresh_rate'],
        gpus=1,
        checkpoint_callback=checkpoint_callback,
        max_epochs=hparams.max_epochs,
        terminate_on_nan=True,
    )
    trainer.fit(model)
Code Example #15
def main():
    parser = ArgumentParser()

    # add task level args
    parser = add_task_specific_args(parser)

    # add model specific args
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    # set default args
    parser.set_defaults(num_labels=27)

    args = parser.parse_args()

    if args.build_dataset:
        build_distill_dataset(args)
    else:
        common_train(
            args,
            metric=f'val_{task_info.metric_name}',
            model_class=Model,
            build_method=build_method,
            task=task_info.task_name
        )
Code Example #16
def main():
    parser = get_parser()
    parser = BertForGLUETask.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    task_model = BertForGLUETask(args)

    if len(args.pretrained_checkpoint) > 1:
        task_model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device("cpu"))["state_dict"])

    checkpoint_callback = ModelCheckpoint(filepath=args.output_dir,
                                          save_top_k=args.max_keep_ckpt,
                                          save_last=False,
                                          monitor="val_f1",
                                          verbose=True,
                                          mode='max',
                                          period=-1)

    task_trainer = Trainer.from_argparse_args(
        args, checkpoint_callback=checkpoint_callback, deterministic=True)

    task_trainer.fit(task_model)

    # After training, use the checkpoint that achieves the best F1 on the dev
    # set to compute F1 on the test set.
    best_f1_on_dev, path_to_best_checkpoint = find_best_checkpoint_on_dev(
        args.output_dir,
        only_keep_the_best_ckpt=args.only_keep_the_best_ckpt_after_training)
    task_model.result_logger.info("=&" * 20)
    task_model.result_logger.info(f"Best F1 on DEV is {best_f1_on_dev}")
    task_model.result_logger.info(
        f"Best checkpoint on DEV set is {path_to_best_checkpoint}")
    task_model.result_logger.info("=&" * 20)
Code Example #17
File: trainer.py Project: Xiang-Pan/CovidQA
def main():
    """main"""
    parser = get_parser()

    # add model specific args
    parser = BertLabeling.add_model_specific_args(parser)

    # add all the available trainer options to argparse,
    # i.e. now --gpus --num_nodes ... --fast_dev_run all work from the CLI
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    model = BertLabeling(args)
    if args.pretrained_checkpoint:
        model.load_state_dict(
            torch.load(args.pretrained_checkpoint,
                       map_location=torch.device('cpu'))["state_dict"])

    checkpoint_callback = ModelCheckpoint(
        filepath=args.default_root_dir,
        save_top_k=10,
        verbose=True,
        monitor="span_f1",
        period=-1,
        mode="max",
    )
    trainer = Trainer.from_argparse_args(
        args, checkpoint_callback=checkpoint_callback)

    trainer.fit(model)
Code Example #18
File: trainer.py Project: Xiang-Pan/CovidQA
def run_dataloader():
    """test dataloader"""
    parser = get_parser()

    # add model specific args
    parser = BertLabeling.add_model_specific_args(parser)

    # add all the available trainer options to argparse,
    # i.e. now --gpus --num_nodes ... --fast_dev_run all work from the CLI
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    args.workers = 0
    args.default_root_dir = "/mnt/data/mrc/train_logs/debug"

    model = BertLabeling(args)
    from tokenizers import BertWordPieceTokenizer
    tokenizer = BertWordPieceTokenizer(
        os.path.join(args.bert_config_dir, "vocab.txt"))

    loader = model.get_dataloader("dev", limit=1000)
    for d in loader:
        input_ids = d[0][0].tolist()
        match_labels = d[-1][0]
        start_positions, end_positions = torch.where(match_labels > 0)
        start_positions = start_positions.tolist()
        end_positions = end_positions.tolist()
        if not start_positions:
            continue
        print("=" * 20)
        print(tokenizer.decode(input_ids, skip_special_tokens=False))
        for start, end in zip(start_positions, end_positions):
            print(tokenizer.decode(input_ids[start:end + 1]))
Code Example #19
def main():
    parser = ArgumentParser()

    # add task level args
    parser = add_task_specific_args(parser)
    # add model specific args
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    # task specific default args
    parser.set_defaults(gradient_clip_val=1.0, min_epochs=1, max_epochs=10)
    parser.set_defaults(num_labels=56,
                        arc_hidden_size=600,
                        rel_hidden_size=600)

    args = parser.parse_args()

    if args.build_dataset:
        build_distill_dataset(args)
    elif args.tune:
        tune_train(args,
                   model_class=Model,
                   task_info=task_info,
                   model_kwargs={'loss_func': sdp_loss})
    else:
        common_train(args,
                     model_class=Model,
                     task_info=task_info,
                     model_kwargs={'loss_func': sdp_loss})
Code Example #20
def main():
    parser = ArgumentParser()

    # add task level args
    parser = add_task_specific_args(parser)
    # add model specific args
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    # task specific default args
    parser.set_defaults(num_labels=56)
    parser.set_defaults(arc_hidden_size=600, rel_hidden_size=600)

    args = parser.parse_args()

    if args.build_dataset:
        build_distill_dataset(args)
    else:
        common_train(args,
                     metric=f'val_{task_info.metric_name}',
                     model_class=Model,
                     build_method=build_method,
                     task='sdp',
                     loss_func=sdp_loss)
Code Example #21
def test_sac():
    """Smoke test that the SAC model runs."""

    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = Trainer.add_argparse_args(parent_parser)
    parent_parser = SAC.add_model_specific_args(parent_parser)
    args_list = [
        "--warm_start_size",
        "100",
        "--gpus",
        "0",
        "--env",
        "Pendulum-v0",
        "--batch_size",
        "10",
    ]
    hparams = parent_parser.parse_args(args_list)

    trainer = Trainer(
        gpus=hparams.gpus,
        max_steps=100,
        max_epochs=100,  # set the same as max_steps to ensure training doesn't stop early
        val_check_interval=1,  # just needs 'some' value; does not affect training right now
        fast_dev_run=True,
    )
    model = SAC(**hparams.__dict__)
    trainer.fit(model)
Code Example #22
def main():
    """main"""
    parser = get_parser()
    # add model specific arguments.
    parser = BertForQA.add_model_specific_args(parser)
    # add all the available trainer options to argparse
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    model = BertForQA(args)

    if len(args.pretrained_checkpoint) > 1:
        model.load_state_dict(torch.load(args.pretrained_checkpoint,
                                         map_location=torch.device('cpu'))["state_dict"])

    checkpoint_callback = ModelCheckpoint(
        filepath=args.output_dir,
        save_top_k=args.max_keep_ckpt,
        verbose=True,
        period=-1,
        mode="auto"
    )

    trainer = Trainer.from_argparse_args(
        args,
        checkpoint_callback=checkpoint_callback,
        deterministic=True
    )

    trainer.fit(model)
Code Example #23
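# Presumably driven by pytest parametrization; a plausible (assumed, not from
# the source) decorator would be:
# @pytest.mark.parametrize(["cli_args", "expected"],
#                          [("", False), ("--fast_dev_run=True", True)])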
def test_argparse_args_parsing_fast_dev_run(cli_args, expected):
    """Test multi type argument with bool."""
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        parser = ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parent_parser=parser)
        args = Trainer.parse_argparser(parser)
    assert args.fast_dev_run is expected
Code Example #24
def main():
    parser = add_model_specific_args()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    if args.mode == 'label':
        label_unlabeled_data(args)
    else:
        train_model(args)
Code Example #25
def parse_args(argv=None):
    argv = argv or []
    parser = ArgumentParser()
    parser = ConvVAE.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args(argv)

    return args
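# Usage sketch (assumed): Trainer flags registered above are accepted here, e.g.
# args = parse_args(["--max_epochs", "5"])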
Code Example #26
File: ixi_train_t2net.py Project: chunmeifeng/T2Net
def build_args():
    # ------------------------
    # TRAINING ARGUMENTS
    # ------------------------
    path_config = "ixi_config.yaml"

    with open(path_config) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
        ixi_args = SimpleNamespace(**data)
        ixi_args.mask_path = './masks_mei/1D-Cartesian_6X_256.mat'

    data_path = data['data_dir']
    logdir = data['output_dir'] + "/dense_edsr/ixi/edsr_transformer"

    parent_parser = ArgumentParser(add_help=False)

    parser = UnetModule.add_model_specific_args(parent_parser)
    parser = Trainer.add_argparse_args(parser)

    num_gpus = 1
    backend = "ddp"
    batch_size = 8 if backend == "ddp" else num_gpus

    # module config
    config = dict(
        n_channels_in=1,
        n_channels_out=1,
        n_resgroups=5,  # 10
        n_resblocks=8,  # 20
        n_feats=64,  # 64
        lr=0.00005,
        lr_step_size=40,
        lr_gamma=0.1,
        weight_decay=0.0,
        data_path=data_path,
        exp_dir=logdir,
        exp_name="unet_demo",
        test_split="test",
        batch_size=batch_size,
        ixi_args=ixi_args,
    )
    parser.set_defaults(**config)

    # trainer config
    parser.set_defaults(
        gpus=num_gpus,
        max_epochs=50,
        default_root_dir=logdir,
        replace_sampler_ddp=(backend != "ddp"),
        distributed_backend=backend,
        seed=42,
        deterministic=True,
    )

    parser.add_argument("--mode", default="train", type=str)
    args = parser.parse_args()

    return args
Code Example #27
def cli_main():
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = AE.add_model_specific_args(parser)
    args = parser.parse_args()

    ae = AE(**vars(args))
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(ae)
Code Example #28
    def _parse_args(self) -> Dict:
        parser = argparse.ArgumentParser()
        parser = Trainer.add_argparse_args(parser)
        parser = Workspace.add_argparse_args(parser)
        parser = self.pl_module_cls.add_argparse_args(parser)

        args = parser.parse_args()

        return args.__dict__
Code Example #29
File: main.py Project: elouayas/JFR2020
def init_trainer():
    """ Init a Lightning Trainer using from_argparse_args
    Thus every CLI command (--gpus, distributed_backend, ...) become available.
    """
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
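    # NOTE: LearningRateLogger was later renamed to LearningRateMonitor in
    # PyTorch Lightning.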
    lr_logger = LearningRateLogger()
    return Trainer.from_argparse_args(args, callbacks=[lr_logger])
Code Example #30
def get_args():
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    temp_args, _ = parser.parse_known_args()
    with open('project/ecgresnet_config.json', 'r') as file:
        ECGResNet_params = json.load(file)

    return parser, ECGResNet_params
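# Usage sketch (assumed, not from the source project): the returned parser
# still needs to be evaluated by the caller, e.g.
# parser, ECGResNet_params = get_args()
# args = parser.parse_args()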