def test_full_pytorch_example(large_mock_mnist_data, tmpdir):
    # First, generate mock dataset
    dataset_url = 'file://{}'.format(tmpdir)
    mnist_data_to_petastorm_dataset(tmpdir,
                                    dataset_url,
                                    mnist_data=large_mock_mnist_data,
                                    spark_master='local[1]',
                                    parquet_files_count=1)

    # Next, run a round of training using the petastorm pytorch DataLoader adapter
    from petastorm.pytorch import DataLoader

    torch.manual_seed(1)
    device = torch.device('cpu')
    model = pytorch_example.Net().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    with DataLoader(Reader('{}/train'.format(dataset_url),
                           reader_pool=DummyPool(),
                           num_epochs=1),
                    batch_size=32,
                    transform=pytorch_example._transform_row) as train_loader:
        pytorch_example.train(model, device, train_loader, 10, optimizer, 1)
    with DataLoader(Reader('{}/test'.format(dataset_url),
                           reader_pool=DummyPool(),
                           num_epochs=1),
                    batch_size=100,
                    transform=pytorch_example._transform_row) as test_loader:
        pytorch_example.test(model, device, test_loader)
Example 2
    def test_empty_ventilation(self):
        pool = DummyPool()
        ventilator = ConcurrentVentilator(pool.ventilate, [])
        pool.start(IdentityWorker, ventilator=ventilator)
        with self.assertRaises(EmptyResultError):
            pool.get_results()

        pool.stop()
        pool.join()
Example 3
def test_num_epochs_value_error(synthetic_dataset):
    """Tests that the reader raises value errors when appropriate"""

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), num_epochs=0)

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), num_epochs=-10)

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), num_epochs='abc')
Example 4
    def readout_all_ids(shuffle, drop_ratio):
        with Reader(dataset_url=synthetic_dataset.url,
                    reader_pool=DummyPool(),
                    shuffle_options=ShuffleOptions(shuffle,
                                                   drop_ratio)) as reader:
            ids = [row.id for row in reader]
        return ids
Example 5
def test_pytorch_dataloader_batched(synthetic_dataset):
    batch_size = 10
    loader = DataLoader(Reader(synthetic_dataset.url, reader_pool=DummyPool()),
                        batch_size=batch_size,
                        collate_fn=_noop_collate)
    for item in loader:
        assert len(item) == batch_size
Example 6
def test_ngram_basic_longer_no_overlap(synthetic_dataset):
    """Tests basic ngram with no delta threshold with no overlaps of timestamps."""
    fields = {
        -5: [TestSchema.id, TestSchema.id2, TestSchema.matrix],
        -4: [TestSchema.id, TestSchema.id2, TestSchema.image_png],
        -3: [TestSchema.id, TestSchema.id2, TestSchema.decimal],
        -2: [TestSchema.id, TestSchema.id2, TestSchema.sensor_name],
        -1: [TestSchema.id, TestSchema.id2]
    }

    dataset_dicts = synthetic_dataset.data
    ngram = NGram(fields=fields,
                  delta_threshold=10,
                  timestamp_field=TestSchema.id,
                  timestamp_overlap=False)
    with Reader(schema_fields=ngram,
                dataset_url=synthetic_dataset.url,
                reader_pool=DummyPool(),
                shuffle_options=ShuffleOptions(False)) as reader:

        timestamps_seen = set()
        for actual in reader:
            expected_ngram = _get_named_tuple_from_ngram(
                ngram, dataset_dicts, actual[min(actual.keys())].id)
            np.testing.assert_equal(actual, expected_ngram)
            for step in actual.values():
                timestamp = step.id
                assert timestamp not in timestamps_seen
                timestamps_seen.add(timestamp)
Example 7
def test_ngram_delta_small_threshold_tf():
    """Test to verify that a small threshold work in ngrams."""

    with temporary_directory() as tmp_dir:
        tmp_url = 'file://{}'.format(tmp_dir)
        ids = range(0, 99, 5)
        create_test_dataset(tmp_url, ids)

        fields = {
            0: [
                TestSchema.id, TestSchema.id2, TestSchema.image_png,
                TestSchema.matrix
            ],
            1: [TestSchema.id, TestSchema.id2, TestSchema.sensor_name],
        }
        ngram = NGram(fields=fields,
                      delta_threshold=1,
                      timestamp_field=TestSchema.id)
        reader = Reader(
            schema_fields=ngram,
            dataset_url=tmp_url,
            reader_pool=DummyPool(),
        )

        with tf.Session() as sess:
            with pytest.raises(OutOfRangeError):
                sess.run(tf_tensors(reader))

        reader.stop()
        reader.join()
Example 8
    def test_no_metadata(self):
        self.vanish_metadata()
        with self.assertRaises(PetastormMetadataError) as e:
            Reader(self._dataset_url, reader_pool=DummyPool())
        self.assertTrue(
            'Could not find _common_metadata file' in str(e.exception))
        self.restore_metadata()
Example 9
def compute_correlation_distribution(dataset_url,
                                     id_column,
                                     shuffle_options,
                                     num_corr_samples=100):
    """
    Compute the correlation distribution of a given shuffle_options on an existing dataset.
    Use this to compare two different shuffle options.
    It is encouraged to use a dataset generated by generate_shuffle_analysis_dataset for this analysis.

    :param dataset_url: Dataset url to compute correlation distribution of
    :param id_column: Column where an integer or string id can be found
    :param shuffle_options: shuffle options to test correlation against
    :param num_corr_samples: How many samples of the correlation to take to compute distribution
    :return: (mean, standard deviation) of computed distribution
    """

    # Read the dataset in order, without any shuffling (a dummy pool is needed for this).
    with Reader(dataset_url,
                shuffle_options=ShuffleOptions(False),
                reader_pool=DummyPool()) as reader:
        unshuffled = [row[id_column] for row in reader]

    correlations = []
    for _ in range(num_corr_samples):
        with Reader(dataset_url, shuffle_options=shuffle_options) as reader:
            shuffled = [row[id_column] for row in reader]
            correlations.append(abs(np.corrcoef(unshuffled, shuffled)[0, 1]))

    mean = np.mean(correlations)
    std_dev = np.std(correlations)

    return mean, std_dev
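To make the intended workflow concrete, here is a minimal usage sketch (not part of the original code). It assumes compute_correlation_distribution and ShuffleOptions are imported as in the examples above, and that a dataset, ideally one produced by generate_shuffle_analysis_dataset, already exists at the placeholder URL and has an 'id' column.

# Hypothetical usage sketch; the dataset URL and the 'id' column name are placeholders.
analysis_dataset_url = 'file:///tmp/shuffle_analysis_dataset'

# A lower mean correlation with the unshuffled order indicates a better shuffle.
mean_plain, std_plain = compute_correlation_distribution(
    analysis_dataset_url, 'id', ShuffleOptions(True, 1))
mean_drop, std_drop = compute_correlation_distribution(
    analysis_dataset_url, 'id', ShuffleOptions(True, 5))

print('shuffle only:          mean={:.3f}, std={:.3f}'.format(mean_plain, std_plain))
print('shuffle, drop_ratio=5: mean={:.3f}, std={:.3f}'.format(mean_drop, std_drop))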
Example 10
    def test_max_ventilation_size(self):
        """Tests that we dont surpass a max ventilation size in each pool type
        (since it relies on accurate ventilation size reporting)"""
        max_ventilation_size = 10

        for pool in [DummyPool(), ProcessPool(10), ThreadPool(10)]:
            ventilator = ConcurrentVentilator(ventilate_fn=pool.ventilate,
                                              items_to_ventilate=[{'item': i} for i in range(100)],
                                              max_ventilation_queue_size=max_ventilation_size)
            pool.start(IdentityWorker, ventilator=ventilator)

            # Give time for the thread to fill the ventilation queue
            while ventilator._ventilated_items_count - ventilator._processed_items_count < max_ventilation_size:
                time.sleep(.1)

            # After stopping the ventilator, we should only get max_ventilation_size (10) results
            ventilator.stop()
            for _ in range(max_ventilation_size):
                pool.get_results()

            with self.assertRaises(EmptyResultError):
                pool.get_results()

            pool.stop()
            pool.join()
Example 11
def test_partition_value_error(synthetic_dataset):
    """Tests that the reader raises value errors when appropriate"""

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), training_partition=0)

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), num_training_partitions=5)

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), training_partition='0',
               num_training_partitions=5)

    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url, reader_pool=DummyPool(), training_partition=0,
               num_training_partitions='5')
Example 12
    def test_reset_ventilator(self):
        """Resetting ventilator after all items were ventilated will make it re-ventilate the same items"""
        items_count = 100
        for pool in [DummyPool(), ThreadPool(10)]:
            ventilator = ConcurrentVentilator(ventilate_fn=pool.ventilate,
                                              items_to_ventilate=[{
                                                  'item': i
                                              } for i in range(items_count)],
                                              iterations=1)
            pool.start(IdentityWorker, ventilator=ventilator)

            # Read out all ventilated items
            for _ in range(items_count):
                pool.get_results()

            # Reading the next item should fail, as all items have been read by now
            with self.assertRaises(EmptyResultError):
                pool.get_results()

            # After resetting, the same items will be read out all over again
            ventilator.reset()

            for _ in range(items_count):
                pool.get_results()

            with self.assertRaises(EmptyResultError):
                pool.get_results()

            pool.stop()
            pool.join()
Example 13
def test_simple_read_moved_dataset(synthetic_dataset, tmpdir):
    """Tests that a dataset may be opened after being moved to a new location"""
    a_moved_path = tmpdir.join('moved').strpath
    copytree(synthetic_dataset.path, a_moved_path)

    with Reader('file://{}'.format(a_moved_path), reader_pool=DummyPool()) as reader:
        _check_simple_reader(reader, synthetic_dataset.data)
Example 14
def test_unlimited_epochs(synthetic_dataset):
    """Tests that unlimited epochs works as expected"""
    with Reader(synthetic_dataset.url, reader_pool=DummyPool(), num_epochs=None) as reader:
        # Read many expected entries from the dataset and compare the data to reference
        for _ in range(len(synthetic_dataset.data) * random.randint(10, 30) + random.randint(25, 50)):
            actual = dict(next(reader)._asdict())
            expected = next(d for d in synthetic_dataset.data if d['id'] == actual['id'])
            np.testing.assert_equal(expected, actual)
Example 15
def test_partition_multi_node(synthetic_dataset):
    """Tests that the reader only returns half of the expected data consistently"""
    reader = Reader(synthetic_dataset.url,
                    reader_pool=DummyPool(),
                    training_partition=0,
                    num_training_partitions=5)
    reader_2 = Reader(synthetic_dataset.url,
                      reader_pool=DummyPool(),
                      training_partition=0,
                      num_training_partitions=5)

    results_1 = []
    expected = []
    for row in reader:
        actual = dict(row._asdict())
        results_1.append(actual)
        expected.append(
            next(d for d in synthetic_dataset.data if d['id'] == actual['id']))

    results_2 = [dict(row._asdict()) for row in reader_2]

    # Since the order is non-deterministic, we need to sort results by id
    results_1.sort(key=lambda x: x['id'])
    results_2.sort(key=lambda x: x['id'])
    expected.sort(key=lambda x: x['id'])

    np.testing.assert_equal(expected, results_1)
    np.testing.assert_equal(results_1, results_2)

    assert len(results_1) < len(synthetic_dataset.data)

    # Test that separate partitions also have no overlap by checking ids
    id_set = set([item['id'] for item in results_1])
    for partition in range(1, 5):
        with Reader(synthetic_dataset.url,
                    reader_pool=DummyPool(),
                    training_partition=partition,
                    num_training_partitions=5) as reader_other:

            for row in reader_other:
                assert dict(row._asdict())['id'] not in id_set

    reader.stop()
    reader.join()
    reader_2.stop()
    reader_2.join()
Example 16
def test_rowgroup_selector_wrong_index_name(synthetic_dataset):
    """ Attempt to select row groups to based on wrong dataset index,
        Reader should raise exception
    """
    with pytest.raises(ValueError):
        Reader(synthetic_dataset.url,
               rowgroup_selector=SingleIndexSelector('WrongIndexName',
                                                     ['some_value']),
               reader_pool=DummyPool())
Example 17
def test_reading_subset_of_columns(synthetic_dataset):
    """Just a bunch of read and compares of all values to the expected values"""
    with Reader(synthetic_dataset.url, schema_fields=[TestSchema.id2, TestSchema.id],
                reader_pool=DummyPool()) as reader:
        # Read a bunch of entries from the dataset and compare the data to reference
        for row in reader:
            actual = dict(row._asdict())
            expected = next(d for d in synthetic_dataset.data if d['id'] == actual['id'])
            np.testing.assert_equal(expected['id2'], actual['id2'])
Example 18
def test_rowgroup_selector_string_field(synthetic_dataset):
    """ Select row groups to read based on dataset index for string field"""
    with Reader(synthetic_dataset.url,
                rowgroup_selector=SingleIndexSelector(TestSchema.sensor_name.name, ['test_sensor']),
                reader_pool=DummyPool()) as reader:
        count = 0
        for _ in reader:
            count += 1
        # Since we use an artificial dataset, all sensors have the same name,
        # so all row groups should be selected and all 100 generated rows should be returned
        assert 100 == count
Example 19
def test_predicate_on_single_column(synthetic_dataset):
    reader = Reader(synthetic_dataset.url,
                    schema_fields=[TestSchema.id2],
                    predicate=in_lambda(['id2'], lambda id2: True),
                    reader_pool=DummyPool())
    counter = 0
    for row in reader:
        counter += 1
        actual = dict(row._asdict())
        assert actual['id2'] < 2
    assert counter == len(synthetic_dataset.data)
Example 20
    def test_metadata_missing_unischema(self):
        """ Produce a BAD _metadata that is missing the unischema pickling first, then load dataset. """

        # Remove the common metadata file with unischema information
        self.vanish_metadata('_common_metadata')

        # Reader will now just get the metadata file which will not have the unischema information
        with self.assertRaises(ValueError) as e:
            Reader(self._dataset_url, reader_pool=DummyPool())
        self.assertTrue('Could not find the unischema' in str(e.exception))
        self.restore_metadata('_common_metadata')
Example 21
def test_rowgroup_selector_nullable_array_field(synthetic_dataset):
    """ Select row groups to read based on dataset index for array field"""
    with Reader(synthetic_dataset.url,
                rowgroup_selector=SingleIndexSelector(TestSchema.string_array_nullable.name, ['100']),
                reader_pool=DummyPool()) as reader:
        count = sum(1 for _ in reader)
        # This field contains id strings, generated like this:
        #   None if id % 5 == 0 else np.asarray([], dtype=np.string_) if id % 4 == 0 else
        #   np.asarray([str(i+id) for i in xrange(2)], dtype=np.string_)
        # hence '100' could be present in row id 99 (as 99+1) and in row id 100 (as 100+0),
        # but row 100 is skipped by the 'None if id % 5 == 0' condition, so only one row group should be selected
        assert 10 == count
Example 22
def _create_worker_pool(pool_type, workers_count, profiling_enabled, pyarrow_serialize):
    """Different worker pool implementation (in process none or thread-pool, out of process pool)"""
    if pool_type == WorkerPoolType.THREAD:
        worker_pool = ThreadPool(workers_count, profiling_enabled=profiling_enabled)
    elif pool_type == WorkerPoolType.PROCESS:
        worker_pool = ProcessPool(workers_count,
                                  serializer=PyArrowSerializer() if pyarrow_serialize else PickleSerializer())
    elif pool_type == WorkerPoolType.NONE:
        worker_pool = DummyPool()
    else:
        raise ValueError('Supported pool types are thread, process or dummy. Got {}.'.format(pool_type))
    return worker_pool
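For illustration, a hedged sketch of how this helper might be invoked; the WorkerPoolType members and parameter names come from the function above, while the concrete worker counts and flags are placeholder values.

# Hypothetical invocations; worker counts and flags are placeholder values.
dummy_pool = _create_worker_pool(WorkerPoolType.NONE, workers_count=1,
                                 profiling_enabled=False, pyarrow_serialize=False)
thread_pool = _create_worker_pool(WorkerPoolType.THREAD, workers_count=10,
                                  profiling_enabled=True, pyarrow_serialize=False)
process_pool = _create_worker_pool(WorkerPoolType.PROCESS, workers_count=10,
                                   profiling_enabled=False, pyarrow_serialize=True)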
Example 23
def test_ngram_delta_threshold_tf(dataset_0_3_8_10_11_20_23):
    """Test to verify that delta threshold work as expected in one partition in the same ngram
    and between consecutive ngrams. delta threshold here refers that each ngram must not be
    more than delta threshold apart for the field specified by timestamp_field."""

    fields = {
        0: [
            TestSchema.id, TestSchema.id2, TestSchema.image_png,
            TestSchema.matrix
        ],
        1: [TestSchema.id, TestSchema.id2, TestSchema.sensor_name],
    }
    ngram = NGram(fields=fields,
                  delta_threshold=4,
                  timestamp_field=TestSchema.id)
    with Reader(schema_fields=ngram,
                dataset_url=dataset_0_3_8_10_11_20_23.url,
                reader_pool=DummyPool(),
                shuffle_options=ShuffleOptions(False)) as reader:

        # Ngrams expected: (0, 3), (8, 10), (10, 11)

        with tf.Session() as sess:
            readout = tf_tensors(reader)
            for timestep in readout:
                for field in readout[timestep]:
                    assert field.get_shape().dims is not None
            first_item = sess.run(readout)
            expected_item = _get_named_tuple_from_ngram(
                ngram, dataset_0_3_8_10_11_20_23.data, 0)
            _assert_equal_ngram(first_item, expected_item)

            readout = tf_tensors(reader)
            for timestep in readout:
                for field in readout[timestep]:
                    assert field.get_shape().dims is not None
            second_item = sess.run(readout)
            expected_item = _get_named_tuple_from_ngram(
                ngram, dataset_0_3_8_10_11_20_23.data, 3)
            _assert_equal_ngram(second_item, expected_item)

            readout = tf_tensors(reader)
            for timestep in readout:
                for field in readout[timestep]:
                    assert field.get_shape().dims is not None
            third_item = sess.run(readout)
            expected_item = _get_named_tuple_from_ngram(
                ngram, dataset_0_3_8_10_11_20_23.data, 5)
            _assert_equal_ngram(third_item, expected_item)

            with pytest.raises(OutOfRangeError):
                sess.run(tf_tensors(reader))
Example 24
def test_stable_pieces_order(synthetic_dataset):
    """Tests that the reader raises value errors when appropriate"""

    RERUN_THE_TEST_COUNT = 20
    baseline_run = None
    for _ in range(RERUN_THE_TEST_COUNT):
        with Reader(synthetic_dataset.url, schema_fields=[TestSchema.id], shuffle_options=ShuffleOptions(False),
                    reader_pool=DummyPool()) as reader:
            this_run = [row.id for row in reader]
        if baseline_run:
            assert this_run == baseline_run

        baseline_run = this_run
Example 25
    def test_worker_produces_no_results(self):
        """Check edge case, when workers consistently does not produce results"""
        # 10000 is an interesting case as in the original implementation it caused stack overflow
        for ventilate_count in [10, 10000]:
            for pool in [DummyPool(), ThreadPool(2)]:
                pool.start(PreprogrammedReturnValueWorker, ventilate_count * [[]])
                for _ in range(ventilate_count):
                    pool.ventilate('not_important')

                with self.assertRaises(EmptyResultError):
                    pool.get_results()

                pool.stop()
                pool.join()
Example 26
def test_multiple_epochs(synthetic_dataset):
    """Tests that multiple epochs works as expected"""
    num_epochs = 5
    with Reader(synthetic_dataset.url, reader_pool=DummyPool(), num_epochs=num_epochs) as reader:
        # Read all expected entries from the dataset and compare the data to reference
        id_set = set([d['id'] for d in synthetic_dataset.data])

        for _ in range(num_epochs):
            current_epoch_set = set()
            for _ in range(len(id_set)):
                actual = dict(next(reader)._asdict())
                expected = next(d for d in synthetic_dataset.data if d['id'] == actual['id'])
                np.testing.assert_equal(expected, actual)
                current_epoch_set.add(actual['id'])
            np.testing.assert_equal(id_set, current_epoch_set)
Example 27
def test_real_reader(synthetic_dataset):
    readers = [
        Reader(synthetic_dataset.url,
               predicate=in_lambda(['id'], lambda id: id % 2 == 0),
               num_epochs=None,
               reader_pool=DummyPool()),
        Reader(synthetic_dataset.url,
               predicate=in_lambda(['id'], lambda id: id % 2 == 1),
               num_epochs=None,
               reader_pool=DummyPool())
    ]
    results = [0, 0]
    num_of_reads = 300
    with WeightedSamplingReader(readers, [0.5, 0.5]) as mixer:
        # Piggyback on this test to verify the container interface of the WeightedSamplingReader
        for i, sample in enumerate(mixer):
            next_id = sample.id % 2
            results[next_id] += 1
            if i >= num_of_reads:
                break

    np.testing.assert_allclose(results,
                               [num_of_reads * 0.5, num_of_reads * 0.5],
                               atol=num_of_reads / 10)
Example 28
def test_ngram_shuffle_drop_ratio(synthetic_dataset):
    """Test to verify the shuffle drop ratio work as expected."""
    fields = {
        -2: [TestSchema.id, TestSchema.id2, TestSchema.matrix],
        -1: [TestSchema.id, TestSchema.id2, TestSchema.image_png],
        0: [TestSchema.id, TestSchema.id2, TestSchema.decimal],
        1: [TestSchema.id, TestSchema.id2, TestSchema.sensor_name],
        2: [TestSchema.id, TestSchema.id2]
    }
    ngram = NGram(fields=fields,
                  delta_threshold=10,
                  timestamp_field=TestSchema.id)
    with Reader(synthetic_dataset.url,
                schema_fields=ngram,
                shuffle_options=ShuffleOptions(False),
                reader_pool=DummyPool()) as reader:
        unshuffled = [row[0].id for row in reader]
    with Reader(synthetic_dataset.url,
                schema_fields=ngram,
                shuffle_options=ShuffleOptions(True, 6),
                reader_pool=DummyPool()) as reader:
        shuffled = [row[0].id for row in reader]
    assert len(unshuffled) == len(shuffled)
    assert unshuffled != shuffled
Example 29
    def test_worker_produces_some_results(self):
        """Check edge case, when workers consistently does not produce results"""
        # 10000 is an interesting case as in the original implementation it caused stack overflow
        VENTILATE_COUNT = 4
        for pool in [DummyPool(), ThreadPool(1)]:
            pool.start(PreprogrammedReturnValueWorker, [[], [], [42], []])
            for _ in range(VENTILATE_COUNT):
                pool.ventilate('not_important')

            self.assertEqual(42, pool.get_results())
            with self.assertRaises(EmptyResultError):
                pool.get_results()

            pool.stop()
            pool.join()
Example 30
def test_rowgroup_selector_integer_field(synthetic_dataset):
    """ Select row groups to read based on dataset index for integer field"""
    with Reader(synthetic_dataset.url, rowgroup_selector=SingleIndexSelector(TestSchema.id.name, [2, 18]),
                reader_pool=DummyPool()) as reader:
        status = [False, False]
        count = 0
        for row in reader:
            if row.id == 2:
                status[0] = True
            if row.id == 18:
                status[1] = True
            count += 1
        # both id values in reader result
        assert all(status)
        # read only 2 row groups, 10 rows per row group
        assert 20 == count