Exemple #1
0
def test_cpu_template(cli_args):
    """Test running CLI for an example with default params."""
    from pl_examples.basic_examples.cpu_template import run_cli

    cli_args = cli_args.split(' ') if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        run_cli()
def test_sync_bn(cli_args):
    """Test running CLI for an example with sync bn."""
    from pl_examples.basic_examples.sync_bn import run_cli

    cli_args = cli_args.split(' ') if cli_args else []
    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        run_cli()
def test_imagenet(tmpdir, cli_args):
    """Test running CLI for the ImageNet example with default params."""

    from pl_examples.domain_templates.imagenet import run_cli

    # https://github.com/pytorch/vision/blob/master/test/fakedata_generation.py#L105
    def _make_image(file_path):
        Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file_path)

    for split in ['train', 'val']:
        for class_id in ['a', 'b']:
            os.makedirs(os.path.join(tmpdir, split, class_id))
            # Generate 5 black images
            for image_id in range(5):
                _make_image(
                    os.path.join(tmpdir, split, class_id,
                                 str(image_id) + '.JPEG'))

    cli_args = cli_args.split(' ') if cli_args else []
    cli_args += ['--data-path', str(tmpdir)]
    cli_args += ['--default_root_dir', str(tmpdir)]

    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
        run_cli()