예제 #1
0
def test_weighted_random_sampler_exception():
    """
    Check the errors reported for invalid WeightedRandomSampler weights.

    Two inputs (an empty string and a tuple of floats) are expected to raise
    TypeError when the sampler is constructed. Three inputs (an empty list, a
    list containing a negative value and an all-zero list) are expected to
    raise RuntimeError from the construct-then-parse() sequence.
    """
    logger.info("Test error cases for WeightedRandomSampler")

    # Inputs expected to fail with TypeError while the sampler is constructed.
    element_type_msg = "type of weights element should be number"
    for rejected in ("", (0.9, 0.8, 1.1)):
        with pytest.raises(TypeError, match=element_type_msg):
            ds.WeightedRandomSampler(rejected)

    # Inputs expected to fail with RuntimeError. Both the constructor call and
    # parse() sit inside the raises-context, so either may be the one raising.
    runtime_cases = (
        ([],
         "WeightedRandomSampler: weights vector must not be empty"),
        ([1.0, 0.1, 0.02, 0.3, -0.4],
         "WeightedRandomSampler: weights vector must not contain negative number, got: "),
        ([0, 0, 0, 0, 0],
         "WeightedRandomSampler: elements of weights vector must not be all zero"),
    )
    for rejected, expected_msg in runtime_cases:
        with pytest.raises(RuntimeError, match=expected_msg):
            ds.WeightedRandomSampler(rejected).parse()
def test_weighted_random_sampler_exception():
    """
    Exercise the constructor-time validation of WeightedRandomSampler.

    Every case builds a sampler from an invalid ``weights`` argument and
    expects the constructor call itself to raise the listed exception type
    with a message matching the listed pattern.
    """
    logger.info("Test error cases for WeightedRandomSampler")

    # Each entry: (invalid weights, expected exception type, expected message pattern)
    invalid_inputs = [
        ("", TypeError, "type of weights element should be number"),
        ((0.9, 0.8, 1.1), TypeError, "type of weights element should be number"),
        ([], ValueError, "weights size should not be 0"),
        ([1.0, 0.1, 0.02, 0.3, -0.4], ValueError, "weights should not contain negative numbers"),
        ([0, 0, 0, 0, 0], ValueError, "elements of weights should not be all zero"),
    ]
    for bad_weights, expected_error, pattern in invalid_inputs:
        with pytest.raises(expected_error, match=pattern):
            ds.WeightedRandomSampler(bad_weights)
예제 #3
0
def test_chained_sampler_06():
    """
    Chain a PKSampler under a WeightedRandomSampler and read an image folder.

    Expects both the reported dataset size and the number of rows produced by
    one epoch of iteration to be 12.
    """
    logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")

    # Parent sampler: weighted random over 12 weights, asked for 12 samples
    parent = ds.WeightedRandomSampler(
        weights=[1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5],
        num_samples=12,
    )
    # Child sampler: number of elements per class is 3 (and there are 4 classes)
    child = ds.PKSampler(num_val=3)
    parent.add_child(child)

    # Image folder dataset driven by the chained sampler
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=parent)

    # The dataset must report 12 rows
    reported_size = dataset.get_dataset_size()
    logger.info(f"dataset size is: {reported_size}")
    assert reported_size == 12

    # One epoch over the dataset must also yield 12 rows
    rows_seen = 0
    for rows_seen, row in enumerate(dataset.create_dict_iterator(num_epochs=1), start=1):
        # each row is a dictionary; here it has keys "image" and "label"
        logger.info(f"image is {row['image']}")
        logger.info(f"label is {row['label']}")

    logger.info(f"Number of data in data1: {rows_seen}")
    # Note: WeightedRandomSampler produces 12 samples
    # Note: Child PKSampler produces 12 samples
    assert rows_seen == 12
예제 #4
0
def test_weighted_random_sampler():
    """
    Read an image folder through a WeightedRandomSampler and count the rows.

    The sampler is built from 12 weights with 11 as its second argument, and
    the dataset is repeated once. The test expects one epoch of iteration to
    yield 11 rows.
    """
    logger.info("Test Case WeightedRandomSampler")

    # how many times the dataset is repeated
    times = 1

    # build the sampled (and repeated) dataset
    sample_weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
    weighted = ds.WeightedRandomSampler(sample_weights, 11)
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=weighted).repeat(times)

    rows_seen = 0
    for rows_seen, row in enumerate(dataset.create_dict_iterator(num_epochs=1), start=1):
        # each row is a dictionary; here it has keys "image" and "label"
        logger.info(f"image is {row['image']}")
        logger.info(f"label is {row['label']}")

    logger.info(f"Number of data in data1: {rows_seen}")
    assert rows_seen == 11
예제 #5
0
def test_serdes_imagefolder_dataset(remove_json_files=True):
    """
    Round-trip a resnet50-style image-folder pipeline through serialization.

    The pipeline (weighted random sampler with a sequential child, decode,
    rescale, resize, batch of 2) is serialized to a JSON file and to the value
    returned by ds.serialize(); both forms are deserialized, the file form is
    serialized and deserialized a second time, and all four pipelines are
    iterated in lockstep. The "image" and "label" arrays must match and 6
    batches are expected.

    Args:
        remove_json_files: when truthy, delete_json_files() is called after
            the checks pass.
    """
    image_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # data augmentation parameters
    scale_factor, offset = 1.0 / 255.0, 0.0
    target_size = (224, 224)
    sample_weights = [
        1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1
    ]

    # Build the DE pipeline
    top_sampler = ds.WeightedRandomSampler(sample_weights, 11)
    top_sampler.add_child(ds.SequentialSampler())
    pipeline = ds.ImageFolderDataset(image_dir, sampler=top_sampler)
    pipeline = pipeline.repeat(1)
    pipeline = pipeline.map(operations=[vision.Decode(True)], input_columns=["image"])
    rescale_step = vision.Rescale(scale_factor, offset)
    resize_step = vision.Resize(target_size, Inter.LINEAR)
    pipeline = pipeline.map(operations=[rescale_step, resize_step], input_columns=["image"])
    pipeline = pipeline.batch(2)

    first_json = "imagenet_dataset_pipeline.json"
    second_json = "imagenet_dataset_pipeline_1.json"

    # Serialize the pre-processing pipeline, both to a file and to a return
    # value. The original pipeline should still work after saving.
    ds.serialize(pipeline, first_json)
    serialized = ds.serialize(pipeline)
    assert validate_jsonfile(first_json) is True

    # Print the serialized pipeline to stdout
    ds.show(pipeline)

    # Rebuild a pipeline from the first json file
    from_first_json = ds.deserialize(json_filepath=first_json)

    # Serializing the rebuilt pipeline must reproduce the first file's content
    ds.serialize(from_first_json, second_json)
    assert validate_jsonfile(second_json) is True
    assert filecmp.cmp(first_json, second_json)

    # Rebuild once more from the second file, and from the in-memory form
    from_second_json = ds.deserialize(json_filepath=second_json)
    from_serialized = ds.deserialize(input_dict=serialized)

    # Walk the original pipeline and the three deserialized ones in lockstep
    iterators = [
        source.create_dict_iterator(num_epochs=1, output_numpy=True)
        for source in (pipeline, from_first_json, from_second_json, from_serialized)
    ]
    batches_seen = 0
    for batches_seen, (orig, first, second, direct) in enumerate(zip(*iterators), start=1):
        np.testing.assert_array_equal(orig['image'], first['image'])
        np.testing.assert_array_equal(orig['image'], second['image'])
        np.testing.assert_array_equal(orig['label'], first['label'])
        np.testing.assert_array_equal(orig['label'], second['label'])
        np.testing.assert_array_equal(second['image'], direct['image'])
        np.testing.assert_array_equal(second['label'], direct['label'])

    logger.info(f"Number of data in data1: {batches_seen}")
    assert batches_seen == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()