示例#1
0
def test_scene_parameters_loading():
    """Check that sampled parameters survive a JSON round trip unchanged."""
    original = two4two.Sampler().sample()
    # Serialize the state dict to a JSON string, then parse it back before
    # handing it to ``SceneParameters.load``.
    serialized = json.dumps(original.state_dict())
    restored = two4two.SceneParameters.load(json.loads(serialized))
    assert original == restored
示例#2
0
def test_blender_fliplr(tmp_path: Path):
    """Check that ``fliplr`` renders the same image, only mirrored.

    One parameter set is sampled after seeding NumPy, cloned, and the clone
    gets ``fliplr = True``. Both are rendered into separate sub-directories
    of ``tmp_path``; the flipped image and mask must equal the originals
    reversed along their second axis.
    """
    np.random.seed(200002)
    base_param = two4two.Sampler().sample()
    flipped_param = base_param.clone()
    flipped_param.fliplr = True

    print(tmp_path)
    base_dir = tmp_path / 'original'
    flipped_dir = tmp_path / 'fliplr'
    for directory in (base_dir, flipped_dir):
        directory.mkdir()

    # Each render iterator is drained completely; the loop targets keep the
    # values of the last yielded item once the loop ends.
    base_renders = two4two.render([base_param], output_dir=str(base_dir))
    for (base_img, base_mask, _) in base_renders:
        pass

    flipped_renders = two4two.render(
        [flipped_param], output_dir=str(flipped_dir))
    for (flipped_img, flipped_mask, _) in flipped_renders:
        pass

    assert (flipped_img == base_img[:, ::-1]).all()
    assert (flipped_mask == base_mask[:, ::-1]).all()
示例#3
0
def test_samplers_valid():
    """Smoke-test the sampler classes by drawing 40 samples from each."""
    sampler_classes = (
        two4two.Sampler,
        two4two.ColorBiasedSampler,
        two4two.HighVariationSampler,
        two4two.HighVariationColorBiasedSampler,
    )
    # Instantiate every sampler up front, before any sampling happens.
    instances = [sampler_cls() for sampler_cls in sampler_classes]

    for instance in instances:
        for _ in range(40):
            instance.sample()
示例#4
0
def test_generic_sampler():
    """Check that the generic ``_sample`` helper accepts its intended inputs.

    Covers a distribution passed directly, a dict of distributions keyed by
    name, the ``size`` argument, and a name that is missing from the dict.
    """
    sampler = two4two.Sampler()
    trunc_normal = utils.truncated_normal(0, 0.5, 0, 1)
    distributions = {
        'peaky': trunc_normal,
        'stretchy': random.random,
        'ignore': None,
    }

    # Distribution object handed over directly.
    direct = sampler._sample('peaky', trunc_normal)
    assert isinstance(direct, numbers.Number)

    # Same names, but resolved through the dict.
    for name in ('peaky', 'stretchy'):
        value = sampler._sample(name, distributions)
        assert isinstance(value, numbers.Number)

    # Asking for several values is expected to give a tuple back.
    several = sampler._sample('stretchy', distributions, size=5)
    assert isinstance(several, tuple)

    with pytest.raises(KeyError):
        two4two.Sampler._sample('ronny', distributions)
示例#5
0
def test_sample_scene_parameters():
    """Draw 1000 parameter sets and run ``check_values`` on each of them."""
    sampler = two4two.Sampler()
    for _ in range(1000):
        sampler.sample().check_values()
示例#6
0
def test_printing_scene_parameters():
    """Smoke-test ``__str__`` by printing one sampled parameter object."""
    sampler = two4two.Sampler()
    text = str(sampler.sample())
    print(text)
示例#7
0
def test_pytorch_dataloader(tmp_path: Path):
    """Check that the PyTorch dataset can load freshly rendered images.

    Two parameter sets are sampled and rendered into ``tmp_path / 'train'``.
    The test then verifies the dataframe contents, the shapes and types of a
    single item and of a ``DataLoader`` batch, and finally that changing the
    returned attributes changes the label names and the label shape.

    Args:
        tmp_path: Temporary directory provided by pytest.
    """
    print("test temp dir: ", tmp_path)
    np.random.seed(242)

    sampler = two4two.Sampler()
    sampled_params = [sampler.sample() for _ in range(2)]

    (tmp_path / 'train').mkdir()

    # Drain the render iterator; only the files written to disk are needed.
    for _ in two4two.render(
        sampled_params,
        n_processes=1,
        output_dir=str(tmp_path / 'train'),
    ):
        pass

    dataset = two4two.pytorch.Two4Two(str(tmp_path), split='train')

    df = dataset.get_dataframe()
    assert df.obj_name[0] == sampled_params[0].obj_name
    assert df.obj_name[1] == sampled_params[1].obj_name
    assert "resolution" not in set(df.keys())

    df = dataset.get_dataframe(to_dict=two4two.pytorch.all_attributes)
    assert df.attribute_status_obj_name[0] == "sampled"
    assert df.attribute_status_obj_name[1] == "sampled"

    # check dataset shapes
    assert len(dataset) == 2
    img, mask, labels = dataset[0]
    assert img.shape == (3, 128, 128)
    assert mask.shape == (1, 128, 128)
    assert labels.shape == (1,)

    # isinstance instead of an exact ``type(...) ==`` comparison, so that
    # torch.Tensor subclasses are accepted as tensors too.
    assert isinstance(img, torch.Tensor)
    assert isinstance(mask, torch.Tensor)
    assert isinstance(labels, torch.Tensor)

    # check dataset loader
    dataloader = DataLoader(dataset, batch_size=2)
    imgs, masks, labels = next(iter(dataloader))

    assert isinstance(imgs, torch.Tensor)
    assert isinstance(masks, torch.Tensor)
    assert isinstance(labels, torch.Tensor)

    assert imgs.shape == (2, 3, 128, 128)
    assert masks.shape == (2, 1, 128, 128)
    assert labels.shape == (2, 1,)

    # Single source of truth for the requested attributes; the dataset gets
    # its own copy so the expected list cannot be mutated behind our back.
    expected_label_names = [
        'obj_name',
        'bending',
        'bg_color',
        'spherical',
    ]
    dataset.set_return_attributes(list(expected_label_names))

    label_names = dataset.get_label_names()
    assert label_names == expected_label_names

    img, mask, labels = dataset[0]
    assert labels.shape == (len(expected_label_names),)