Code example #1
File: test_spark.py  Project: scpwais/oarphpy
def _check_serialization(spark, rows, testname, schema=None):
    from oarphpy import util
    from oarphpy.spark import RowAdapter

    TEST_TEMPDIR = testutil.test_tempdir('spark_row_adapter_test')

    adapted_rows = [RowAdapter.to_row(r) for r in rows]
    if schema:
        df = spark.createDataFrame(adapted_rows,
                                   schema=schema,
                                   verifySchema=False)
        # verifySchema is expensive and improperly errors on mostly empty rows
    else:
        df = spark.createDataFrame(adapted_rows)
        # Automatically samples rows to get schema
    outpath = os.path.join(TEST_TEMPDIR, 'rowdata_%s' % testname)
    df.write.parquet(outpath)

    df2 = spark.read.parquet(outpath)
    decoded_wrapped_rows = df2.collect()

    decoded_rows = [RowAdapter.from_row(row) for row in decoded_wrapped_rows]

    # We can't do assert sorted(rows) == sorted(decoded_rows)
    # because numpy syntactic sugar breaks ==
    import pprint

    def sorted_row_str(rowz):
        return pprint.pformat(sorted(rowz, key=lambda row: row['id']))

    assert sorted_row_str(rows) == sorted_row_str(decoded_rows)
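
The pprint-based comparison above works around the fact that `==` on numpy arrays is element-wise, so an assert over rows containing arrays cannot be evaluated directly. A minimal illustration (not part of the test file):

# Why the test compares pprint strings instead of using == directly.
import numpy as np

a = np.array([1, 2, 3])
b = np.array([1, 2, 3])
print(a == b)     # [ True  True  True] -- an element-wise array, not a bool
# bool(a == b)    # raises ValueError: the truth value of an array with more
#                 # than one element is ambiguous, so a plain assert would fail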
Code example #2
def test_archive_flyweight_tar():
    TEST_TEMPDIR = testutil.test_tempdir('test_archive_flyweight_tar')
    fixture_path = os.path.join(TEST_TEMPDIR, 'test.tar')

    # Create the fixture:
    # test.tar
    #  |- foo: "foo"
    #  |- bar: "bar"
    # ... an archive with a few files, where each file contains just a string
    #   that matches the name of the file in the archive
    ss = [b'foo', b'bar', b'bazzzz']

    import tarfile
    with tarfile.open(fixture_path, mode='w') as t:
        for s in ss:
            from io import BytesIO
            buf = BytesIO()
            buf.write(s)
            buf.seek(0)

            buf.seek(0, os.SEEK_END)
            buf_len = buf.tell()
            buf.seek(0)

            info = tarfile.TarInfo(name=s.decode('utf-8'))
            info.size = buf_len

            t.addfile(tarinfo=info, fileobj=buf)

    # Test reading!
    fws = util.ArchiveFileFlyweight.fws_from(fixture_path)
    _check_fws(fws, ss)
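
The `_check_fws` helper is not part of this listing. A hedged sketch of the kind of check it could perform, assuming each flyweight exposes `name` and `data` (bytes) attributes -- an assumption about the `ArchiveFileFlyweight` API, not code from the project:

def _check_fws(fws, ss):
    # Hypothetical sketch only; the real helper is defined elsewhere in the project.
    # Assumes each flyweight exposes the archived file's contents as `fw.data`.
    assert len(fws) == len(ss)
    datas = sorted(fw.data for fw in fws)
    assert datas == sorted(ss)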
Code example #3
File: test_spark.py  Project: scpwais/oarphpy
    def test_archive_rdd_tar(self):
        TEST_TEMPDIR = testutil.test_tempdir('test_archive_rdd_tar')
        fixture_path = os.path.join(TEST_TEMPDIR, 'test.tar')

        # Create the fixture:
        # test.tar
        #  |- foo: "foo"
        #  |- bar: "bar"
        # ... an archive with a few files, where each file contains just a string
        #   that matches the name of the file in the archive
        ss = [b'foo', b'bar', b'baz']

        import tarfile
        with tarfile.open(fixture_path, mode='w') as t:
            for s in ss:
                from io import BytesIO
                buf = BytesIO()
                buf.write(s)
                buf.seek(0)

                buf.seek(0, os.SEEK_END)
                buf_len = buf.tell()
                buf.seek(0)

                info = tarfile.TarInfo(name=s.decode('utf-8'))
                info.size = buf_len

                t.addfile(tarinfo=info, fileobj=buf)

        # Test Reading!
        with testutil.LocalSpark.sess() as spark:
            from oarphpy import spark as S
            fw_rdd = S.archive_rdd(spark, fixture_path)
            self._check_rdd(fw_rdd, ss)
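
As with `_check_fws`, the `_check_rdd` helper is defined elsewhere in the test class. A hypothetical sketch, assuming the RDD elements are flyweight-like objects with a `data` attribute:

    def _check_rdd(self, fw_rdd, ss):
        # Hypothetical sketch only; not taken from the project.
        datas = fw_rdd.map(lambda fw: fw.data).collect()
        assert sorted(datas) == sorted(ss)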
Code example #4
    def _check_serialization(self, rows, schema=None, do_adaption=True):
        import inspect
        from oarphpy import util
        from oarphpy.spark import RowAdapter

        test_name = inspect.stack()[1][3]

        TEST_TEMPDIR = testutil.test_tempdir('TestRowAdapter.' + test_name)

        if do_adaption:
            adapted_rows = [RowAdapter.to_row(r) for r in rows]
        else:
            adapted_rows = rows

        with testutil.LocalSpark.sess() as spark:
            if schema:
                df = spark.createDataFrame(adapted_rows,
                                           schema=schema,
                                           verifySchema=False)
                # verifySchema is expensive and improperly errors on mostly
                # empty rows
            else:
                df = spark.createDataFrame(adapted_rows)
                # Automatically samples rows to get schema
            outpath = os.path.join(TEST_TEMPDIR, 'rowdata_%s' % test_name)
            df.write.parquet(outpath)

            df2 = spark.read.parquet(outpath)
            decoded_wrapped_rows = df2.collect()

            if do_adaption:
                decoded_rows = [
                    RowAdapter.from_row(row) for row in decoded_wrapped_rows
                ]
                # We can't do assert sorted(rows) == sorted(decoded_rows)
                # because numpy syntactic sugar breaks __eq__, so use pprint,
                # which is safe for our tests
                import pprint

                def sorted_row_str(rowz):
                    if self._is_spark_2x():
                        # Spark 2.x has non-stable sorting semantics for Row
                        if len(rowz) > 1:
                            rowz = sorted(rowz, key=lambda r: r.id)
                        return pprint.pformat(rowz)
                    else:
                        return pprint.pformat(sorted(rowz))

                assert sorted_row_str(rows) == sorted_row_str(decoded_rows)

            return df
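
A hypothetical call to this helper from a test method, with rows carrying numpy arrays (the kind of payload `RowAdapter` exists to translate). The row contents and test name are illustrative, not from the project:

    def test_numpy_roundtrip(self):
        # Hypothetical usage sketch -- would live inside the TestRowAdapter class.
        import numpy as np
        from pyspark.sql import Row

        rows = [
            Row(id=1, arr=np.array([1.0, 2.0, 3.0])),
            Row(id=2, arr=np.zeros((2, 2))),
        ]
        # RowAdapter.to_row() converts the numpy payloads into Spark-friendly
        # structs; from_row() restores them after the parquet round-trip.
        self._check_serialization(rows)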
Code example #5
def test_archive_flyweight_zip():
    TEST_TEMPDIR = testutil.test_tempdir('test_archive_flyweight_zip')
    fixture_path = os.path.join(TEST_TEMPDIR, 'test.zip')

    # Create the fixture:
    # test.zip
    #  |- foo: "foo"
    #  |- bar: "bar"
    # ... an archive with a few files, where each file contains just a string
    #   that matches the name of the file in the archive
    ss = [b'foo', b'bar', b'baz']

    import zipfile
    with zipfile.ZipFile(fixture_path, mode='w') as z:
        for s in ss:
            z.writestr(s.decode('utf-8'), s)

    # Test Reading!
    fws = util.ArchiveFileFlyweight.fws_from(fixture_path)
    _check_fws(fws, ss)
Code example #6
File: test_tfutil.py  Project: scpwais/oarphpy
def test_tf_records_file_as_list_of_str():
  TEST_TEMPDIR = testutil.test_tempdir(
                      'test_tf_records_file_as_list_of_str')
  util.cleandir(TEST_TEMPDIR)
  
  # Create the fixture: simply three strings in the file.  A TFRecords file
  # is just a size-delimited concatenation of string records.
  ss = [b'foo', b'bar', b'bazzzz']
  fixture_path = os.path.join(TEST_TEMPDIR, 'test.tfrecord')

  with tf.io.TFRecordWriter(fixture_path) as writer:
    for s in ss:
      writer.write(s)
  
  # Test reading!
  tf_lst = util.TFRecordsFileAsListOfStrings(open(fixture_path, 'rb'))
  assert len(tf_lst) == len(ss)
  assert sorted(tf_lst) == sorted(ss)
  for i in range(len(ss)):
    assert tf_lst[i] == ss[i]
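
The "size-delimited concatenation" the comment describes can be read without TensorFlow. A minimal sketch of the on-disk TFRecord framing (uint64 length, masked CRC of the length, payload, masked CRC of the payload), skipping CRC verification; for illustration only, not production use:

import struct

def iter_tfrecord_strings(path):
  # Each record is framed as:
  #   uint64 length | uint32 masked CRC of length | data | uint32 masked CRC of data
  # CRCs are ignored here for brevity.
  with open(path, 'rb') as f:
    while True:
      header = f.read(8)
      if not header:
        break
      (length,) = struct.unpack('<Q', header)
      f.read(4)                # skip the length CRC
      yield f.read(length)     # the record payload
      f.read(4)                # skip the data CRC

# e.g. list(iter_tfrecord_strings(fixture_path)) == [b'foo', b'bar', b'bazzzz']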
Code example #7
File: test_spark.py  Project: scpwais/oarphpy
    def test_archive_rdd_zip(self):
        TEST_TEMPDIR = testutil.test_tempdir('test_archive_rdd_zip')
        fixture_path = os.path.join(TEST_TEMPDIR, 'test.zip')

        # Create the fixture:
        # test.zip
        #  |- foo: "foo"
        #  |- bar: "bar"
        # ... an archive with a few files, where each file contains just a string
        #   that matches the name of the file in the archive
        ss = [b'foo', b'bar', b'baz']

        import zipfile
        with zipfile.ZipFile(fixture_path, mode='w') as z:
            for s in ss:
                z.writestr(s.decode('utf-8'), s)

        # Test Reading!
        with testutil.LocalSpark.sess() as spark:
            from oarphpy import spark as S
            fw_rdd = S.archive_rdd(spark, fixture_path)
            self._check_rdd(fw_rdd, ss)