# Example No. 1
# 0
def test_cli_run_log_regression(cli_args):
    """Smoke-test the logistic-regression CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.regression.logistic_regression import cli_main

    # split() with no separator collapses runs of whitespace, including
    # newlines. split(" ") would emit empty tokens that argparse rejects as
    # unrecognized arguments.
    argv = cli_args.split() if cli_args else []
    # argparse imports sys as ``_sys``. Patching ``argparse._sys.argv``
    # therefore replaces sys.argv for the duration of the call.
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
# Example No. 2
# 0
def test_cli_run_retinanet(cli_args):
    """Smoke-test the RetinaNet CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.detection.retinanet.retinanet_module import cli_main

    # Whitespace split avoids the empty tokens that split(" ") produces on
    # repeated spaces. It also splits on newlines in multi-line arguments.
    argv = cli_args.split() if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
# Example No. 3
# 0
def test_cli_run_basic_vae(cli_args):
    """Smoke-test the basic VAE CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import cli_main

    # Whitespace split avoids the empty tokens that split(" ") produces on
    # repeated spaces. It also splits on newlines in multi-line arguments.
    argv = cli_args.split() if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
# Example No. 4
# 0
def test_cli_run_vision_image_gpt(cli_args):
    """Smoke-test the ImageGPT CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.vision.image_gpt.igpt_module import cli_main

    # Whitespace split avoids the empty tokens that split(" ") produces on
    # repeated spaces. It also splits on newlines in multi-line arguments.
    argv = cli_args.split() if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
def test_cli_run_basic_gan(cli_args, dataset_name):
    """Smoke-test the basic GAN CLI for a parametrized dataset.

    ``cli_args`` is a %-format template. It is filled from a mapping with the
    key ``dataset_name``, split on whitespace, and installed as ``sys.argv``
    while ``cli_main`` runs.
    """
    from pl_bolts.models.gans.basic.basic_gan_module import cli_main

    rendered = cli_args % {'dataset_name': dataset_name}
    fake_argv = ["any.py", *rendered.split()]
    with mock.patch("argparse._sys.argv", fake_argv):
        cli_main()
def test_cli_run_mnist(cli_args):
    """Smoke-test the MNIST module CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.mnist_module import cli_main

    # split(' ') would keep empty tokens from leading, trailing or repeated
    # spaces, and would leave newlines inside tokens. Whitespace split matches
    # the sibling CLI tests.
    argv = cli_args.split() if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
def test_cli_basic_ae(dataset_name):
    """Smoke-test the basic autoencoder CLI on ``dataset_name``.

    Runs ``cli_main`` with a deliberately small budget: one epoch, three
    training batches and three validation batches, each of size three.
    """
    from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main

    options = " ".join([
        f"--dataset {dataset_name}",
        "--max_epochs 1",
        "--limit_train_batches 3",
        "--limit_val_batches 3",
        "--batch_size 3",
    ])
    argv = ["any.py"] + options.split()

    with mock.patch("argparse._sys.argv", argv):
        cli_main()
def test_cli_cpc(cli_args):
    """Smoke-test the CPC self-supervised CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.self_supervised.cpc.cpc_module import cli_main

    # split(' ') would keep empty tokens from leading, trailing or repeated
    # spaces, and would leave newlines inside tokens.
    argv = cli_args.split() if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
def test_cli_basic_gan(cli_args):
    """Smoke-test the basic GAN CLI with the given arguments.

    ``cli_args`` is a whitespace-separated string of command-line options, or a
    falsy value meaning "no options". It is tokenized and installed as
    ``sys.argv`` while ``cli_main`` runs.
    """
    from pl_bolts.models.gans.basic.basic_gan_module import cli_main

    # split(' ') would keep empty tokens from leading, trailing or repeated
    # spaces, and would leave newlines inside tokens.
    argv = cli_args.split() if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + argv):
        cli_main()
# Example No. 10
# 0
def test_cli_run_srresnet(cli_args):
    """Smoke-test the SRResNet CLI with ``cli_args`` installed as ``sys.argv``.

    ``cli_args`` must be a string; it is split on whitespace into argv tokens.
    """
    from pl_bolts.models.gans.srgan.srresnet_module import cli_main

    tokens = cli_args.strip().split()
    patched_argv = ["any.py", *tokens]
    with mock.patch("argparse._sys.argv", patched_argv):
        cli_main()
def test_cli_run_basic_gan(cli_args):
    """Smoke-test the basic GAN CLI with ``cli_args`` installed as ``sys.argv``.

    ``cli_args`` must be a string; it is split on whitespace into argv tokens.
    """
    from pl_bolts.models.gans.basic.basic_gan_module import cli_main

    tokens = cli_args.strip().split()
    patched_argv = ["any.py", *tokens]
    with mock.patch("argparse._sys.argv", patched_argv):
        cli_main()