Exemple #1
0
 def test_numpy_codec(self):
     """NdarrayCodec must round-trip an int32 ndarray through encode/decode unchanged."""
     SHAPE = (10, 20, 30)
     # np.random.rand(...) yields values in [0, 1); casting them to int32 gives an
     # all-zero array, which makes the round trip trivially pass. Draw real int32 values.
     int32_info = np.iinfo(np.int32)
     expected = np.random.randint(int32_info.min, int32_info.max, size=SHAPE, dtype=np.int32)
     codec = NdarrayCodec()
     field = UnischemaField(name='test_name', numpy_dtype=np.int32, shape=SHAPE, codec=codec,
                            nullable=False)
     np.testing.assert_equal(codec.decode(field, codec.encode(field, expected)), expected)
Exemple #2
0
    def get_petastorm_column(df_column):
        """Translate a dataframe column descriptor into a petastorm ``UnischemaField``.

        Reads ``type``, ``name``, ``is_nullable`` and ``array_dimensions`` from
        *df_column*:

        * INTEGER -> ``np.int32`` scalar with ``ScalarCodec(IntegerType())``
        * FLOAT   -> ``np.float64`` scalar with ``ScalarCodec(FloatType())``
        * TEXT    -> ``np.str_`` scalar with ``ScalarCodec(StringType())``
        * NDARRAY -> ``np.uint8`` field shaped by ``array_dimensions`` with ``NdarrayCodec()``

        Any other column type is logged at ERROR level and ``None`` is returned.

        Reference:
        https://github.com/uber/petastorm/blob/master/petastorm/tests/test_common.py
        """
        col_type = df_column.type
        col_name = df_column.name
        nullable = df_column.is_nullable
        dims = df_column.array_dimensions

        if col_type == ColumnType.INTEGER:
            return UnischemaField(col_name, np.int32, (),
                                  ScalarCodec(IntegerType()), nullable)
        if col_type == ColumnType.FLOAT:
            return UnischemaField(col_name, np.float64, (),
                                  ScalarCodec(FloatType()), nullable)
        if col_type == ColumnType.TEXT:
            return UnischemaField(col_name, np.str_, (),
                                  ScalarCodec(StringType()), nullable)
        if col_type == ColumnType.NDARRAY:
            return UnischemaField(col_name, np.uint8, dims,
                                  NdarrayCodec(), nullable)

        # Unknown type: report it and fall through to an explicit None.
        LoggingManager().log("Invalid column type: " + str(col_type),
                             LoggingLevel.ERROR)
        return None
Exemple #3
0
def test_make_named_tuple():
    """make_namedtuple builds rows for a mixed schema, including None for nullable fields."""
    # np.bytes_ is the type the former np.string_ alias pointed to (alias removed in NumPy 2.0).
    TestSchema = Unischema('TestSchema', [
        UnischemaField('string_scalar', np.bytes_,
                       (), ScalarCodec(StringType()), True),
        UnischemaField('int32_scalar', np.int32,
                       (), ScalarCodec(ShortType()), False),
        UnischemaField('uint8_scalar', np.uint8,
                       (), ScalarCodec(ShortType()), False),
        UnischemaField('int32_matrix', np.float32,
                       (10, 20, 3), NdarrayCodec(), True),
        # A scalar field: its shape is (), not the matrix shape.
        UnischemaField('decimal_scalar', Decimal,
                       (), ScalarCodec(DecimalType(10, 9)), False),
    ])

    # A real (10, 20, 3) float32 matrix matching the declared field, rather than
    # np.int32((10, 20, 3)), which is only the 3-element vector [10, 20, 3].
    matrix = np.zeros((10, 20, 3), dtype=np.float32)
    row = TestSchema.make_namedtuple(string_scalar='abc',
                                     int32_scalar=10,
                                     uint8_scalar=20,
                                     int32_matrix=matrix,
                                     decimal_scalar=Decimal(123) / Decimal(10))
    assert row.string_scalar == 'abc'
    assert row.int32_scalar == 10
    assert row.uint8_scalar == 20
    np.testing.assert_equal(row.int32_matrix, matrix)
    assert row.decimal_scalar == Decimal('12.3')

    row = TestSchema.make_namedtuple(string_scalar=None,
                                     int32_scalar=10,
                                     uint8_scalar=20,
                                     int32_matrix=None,
                                     decimal_scalar=Decimal(123) / Decimal(10))
    assert row.string_scalar is None
    assert row.int32_matrix is None
Exemple #4
0
    def test_dict_to_spark_row_field_validation_ndarrays(self):
        """Test various validations done on data types when converting a dictionary to a spark row"""
        TestSchema = Unischema('TestSchema', [
            UnischemaField('tensor3d', np.float32,
                           (10, 20, 30), NdarrayCodec(), False),
        ])

        self.assertIsInstance(
            dict_to_spark_row(
                TestSchema,
                {'tensor3d': np.zeros((10, 20, 30), dtype=np.float32)}),
            Row)

        # The negative cases must use the schema's own field name ('tensor3d').
        # A key such as 'string_field' is not part of TestSchema, so the error
        # would not come from the null / shape validation under test.

        # Null value into not nullable field
        with self.assertRaises(ValueError):
            dict_to_spark_row(TestSchema, {'tensor3d': None})

        # Wrong dimensions
        with self.assertRaises(ValueError):
            dict_to_spark_row(
                TestSchema,
                {'tensor3d': np.zeros((1, 2, 3), dtype=np.float32)})
Exemple #5
0
def test_nominal_case():
    """Nominal flow: a matrix encoded with NdarrayCodec decodes back to the same values via decode_row."""
    original = np.random.rand(10, 10)
    encoded_row = {'matrix': NdarrayCodec().encode(MatrixField, original)}
    decoded = decode_row(encoded_row, MatrixSchema)
    np.testing.assert_equal(decoded['matrix'], original)
Exemple #6
0
 def test_get_petastorm_column_ndarray(self):
     """NDARRAY columns of each NdArrayType map to an UnischemaField with the paired numpy dtype."""
     # np.str_ replaces np.unicode_: the same type object in NumPy 1.x, and the
     # np.unicode_ alias no longer exists in NumPy 2.0.
     expected_type = [np.int8, np.uint8, np.int16, np.int32, np.int64,
                      np.str_, np.bool_, np.float32, np.float64,
                      Decimal, np.str_, np.datetime64]
     col_name = 'frame_id'
     # Pairs NdArrayType members (in iteration order) with expected_type by position.
     for array_type, np_type in zip(NdArrayType, expected_type):
         col = DataFrameColumn(col_name, ColumnType.NDARRAY, True,
                               array_type, [10, 10])
         petastorm_col = UnischemaField(col_name, np_type, [10, 10],
                                        NdarrayCodec(), True)
         self.assertEqual(SchemaUtils.get_petastorm_column(col),
                          petastorm_col)
Exemple #7
0
    def test_get_petastorm_column(self):
        """get_petastorm_column maps each ColumnType to the expected UnischemaField, or None if unknown."""
        col_name = 'frame_id'
        col = DataFrameColumn(col_name, ColumnType.INTEGER, False)
        petastorm_col = UnischemaField(col_name, np.int32, (),
                                       ScalarCodec(IntegerType()), False)
        self.assertEqual(SchemaUtils.get_petastorm_column(col), petastorm_col)

        col = DataFrameColumn(col_name, ColumnType.FLOAT, True)
        petastorm_col = UnischemaField(col_name, np.float64, (),
                                       ScalarCodec(FloatType()), True)
        self.assertEqual(SchemaUtils.get_petastorm_column(col), petastorm_col)

        # np.bytes_ is the type the former np.string_ alias pointed to
        # (the alias was removed in NumPy 2.0).
        col = DataFrameColumn(col_name, ColumnType.TEXT, False)
        petastorm_col = UnischemaField(col_name, np.bytes_, (),
                                       ScalarCodec(StringType()), False)
        self.assertEqual(SchemaUtils.get_petastorm_column(col), petastorm_col)

        col = DataFrameColumn(col_name, ColumnType.NDARRAY, True, [10, 10])
        petastorm_col = UnischemaField(col_name, np.uint8, [10, 10],
                                       NdarrayCodec(), True)
        self.assertEqual(SchemaUtils.get_petastorm_column(col), petastorm_col)

        # An unknown column type yields no field.
        col = DataFrameColumn(col_name, None, True, [10, 10])
        self.assertIsNone(SchemaUtils.get_petastorm_column(col))
Exemple #8
0
def decode_image(tensor):
    """Decode ``tensor[0]`` with the ``TrainSchema[0]`` field, print its shape and return it."""
    decoded = NdarrayCodec().decode(TrainSchema[0], tensor[0])
    print(decoded.shape)
    return decoded
#  Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from pyspark.sql.types import IntegerType

from petastorm.codecs import ScalarCodec, NdarrayCodec
from petastorm.unischema import Unischema, UnischemaField

# MNIST schema: two non-nullable integer scalars ('idx', 'digit') and a
# non-nullable 28x28 uint8 'image' array.
MnistSchema = Unischema('MnistSchema', [
    UnischemaField(name='idx', numpy_dtype=np.int_, shape=(),
                   codec=ScalarCodec(IntegerType()), nullable=False),
    UnischemaField(name='digit', numpy_dtype=np.int_, shape=(),
                   codec=ScalarCodec(IntegerType()), nullable=False),
    UnischemaField(name='image', numpy_dtype=np.uint8, shape=(28, 28),
                   codec=NdarrayCodec(), nullable=False),
])
Exemple #10
0
from pyspark.sql.types import IntegerType, StringType

import numpy as np

# Schema for L1b samples: time stamp fields, source file name, h/v identifiers,
# a sample id and a 64x64x16 float32 data cube.
# np.bytes_ replaces np.string_ (same type; the alias was removed in NumPy 2.0).
# NOTE(review): 'h' and 'v' are declared as byte strings but encoded with
# ScalarCodec(IntegerType()); dtype and codec disagree -- confirm which one
# matches the stored data.
L1bSchema = Unischema('L1bSchema', [
    UnischemaField('year', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('dayofyear', np.int32,
                   (), ScalarCodec(IntegerType()), False),
    UnischemaField('hour', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('minute', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('file', np.bytes_, (), ScalarCodec(StringType()), False),
    UnischemaField('h', np.bytes_, (), ScalarCodec(IntegerType()), False),
    UnischemaField('v', np.bytes_, (), ScalarCodec(IntegerType()), False),
    UnischemaField('sample_id', np.int32,
                   (), ScalarCodec(IntegerType()), False),
    UnischemaField('data', np.float32, (64, 64, 16), NdarrayCodec(), False),
])

MAIACSchema = Unischema('MAIAC', [
    UnischemaField('year', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('dayofyear', np.int32,
                   (), ScalarCodec(IntegerType()), False),
    UnischemaField('hour', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('minute', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('fileahi05', np.string_,
                   (), ScalarCodec(StringType()), False),
    UnischemaField('fileahi12', np.string_,
                   (), ScalarCodec(StringType()), False),
    UnischemaField('h', np.string_, (), ScalarCodec(IntegerType()), False),
    UnischemaField('v', np.string_, (), ScalarCodec(IntegerType()), False),
    UnischemaField('sample_id', np.int32,
import numpy as np
from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField
from pyspark.sql.types import IntegerType, StringType
from PIL import Image
from io import BytesIO

DEFAULT_IMAGE_SIZE = (128, 128)

# Schemas describing the layout of each dataset.
# ``np.unicode`` was a deprecated NumPy alias for the builtin ``str`` (deprecated
# in NumPy 1.20, since removed), so accessing it fails on current NumPy; the
# builtin ``str`` is the exact same object the alias referred to.
FeatureSchema = Unischema('FeatureSchema', [
    UnischemaField('features', np.uint8, (DEFAULT_IMAGE_SIZE[0], DEFAULT_IMAGE_SIZE[1], 3), NdarrayCodec(), False),
    UnischemaField('img_name', str, (), ScalarCodec(StringType()), False),
])

MaskSchema = Unischema('MaskSchema', [
    UnischemaField('masks', np.uint8, (DEFAULT_IMAGE_SIZE[0], DEFAULT_IMAGE_SIZE[1]), NdarrayCodec(), False),
    UnischemaField('img_name', str, (), ScalarCodec(StringType()), False),
])

TrainSchema = Unischema('TrainSchema', [
    UnischemaField('features', np.uint8, (DEFAULT_IMAGE_SIZE[0], DEFAULT_IMAGE_SIZE[1], 3), NdarrayCodec(), False),
    UnischemaField('masks', np.uint8, (DEFAULT_IMAGE_SIZE[0], DEFAULT_IMAGE_SIZE[1]), NdarrayCodec(), False)
])

# def resize_image(raw_image_data, image_size = (128, 128)):
#     img = Image.open(BytesIO(raw_image_data))
#     img = img.resize((image_size[0], image_size[1]), Image.ANTIALIAS)
#     return img
def test_str_special_method():
    """str() of each ndarray codec renders as its class name followed by '()'."""
    def _check(codec_instance, expected_text):
        assert str(codec_instance) == expected_text

    _check(NdarrayCodec(), 'NdarrayCodec()')
    _check(CompressedNdarrayCodec(), 'CompressedNdarrayCodec()')
Exemple #13
0
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from petastorm.codecs import NdarrayCodec
from petastorm.unischema import UnischemaField, Unischema
from petastorm.utils import decode_row, DecodeFieldError

# A non-nullable 10x10 float64 matrix field, and a schema holding only that field.
MatrixField = UnischemaField(name='matrix', numpy_dtype=np.float64, shape=(10, 10),
                             codec=NdarrayCodec(), nullable=False)
MatrixSchema = Unischema('TestSchema', [MatrixField])


def test_nominal_case():
    """Nominal flow: decode_row restores a matrix that was encoded with NdarrayCodec."""
    source_matrix = np.random.rand(10, 10)
    row = {'matrix': NdarrayCodec().encode(MatrixField, source_matrix)}
    np.testing.assert_equal(decode_row(row, MatrixSchema)['matrix'], source_matrix)


def test_can_not_decode():
    """Make sure field name is part of the error message"""
Exemple #14
0
# Shape shared by the image and matrix fields of the TestSchema below.
_DEFAULT_IMAGE_SIZE = (32, 16, 3)

TestSchema = Unischema('TestSchema', [
    UnischemaField('partition_key', np.unicode_,
                   (), ScalarCodec(StringType()), False),
    UnischemaField('id', np.int64, (), ScalarCodec(LongType()), False),
    UnischemaField('id2', np.int32, (), ScalarCodec(ShortType()), False),
    UnischemaField('id_float', np.float64,
                   (), ScalarCodec(DoubleType()), False),
    UnischemaField('id_odd', np.bool_, (), ScalarCodec(BooleanType()), False),
    UnischemaField('python_primitive_uint8', np.uint8,
                   (), ScalarCodec(ShortType()), False),
    UnischemaField('image_png', np.uint8, _DEFAULT_IMAGE_SIZE,
                   CompressedImageCodec('png'), False),
    UnischemaField('matrix', np.float32, _DEFAULT_IMAGE_SIZE, NdarrayCodec(),
                   False),
    UnischemaField('decimal', Decimal,
                   (), ScalarCodec(DecimalType(10, 9)), False),
    UnischemaField('matrix_uint16', np.uint16, _DEFAULT_IMAGE_SIZE,
                   NdarrayCodec(), False),
    UnischemaField('matrix_string', np.string_, (
        None,
        None,
    ), NdarrayCodec(), False),
    UnischemaField('empty_matrix_string', np.string_,
                   (None, ), NdarrayCodec(), False),
    UnischemaField('matrix_nullable', np.uint16, _DEFAULT_IMAGE_SIZE,
                   NdarrayCodec(), True),
    UnischemaField('sensor_name', np.unicode_, (1, ), NdarrayCodec(), False),
    UnischemaField('string_array_nullable', np.unicode_,
Exemple #15
0
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.etl.rowgroup_indexers import SingleFieldIndexer
from petastorm.etl.rowgroup_indexing import build_rowgroup_index
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row

_DEFAULT_IMAGE_SIZE = (32, 16, 3)

# Test schema covering scalar, image, matrix and string-array fields.
# np.str_ / np.bytes_ replace np.unicode_ / np.string_: same types in NumPy 1.x,
# and those aliases were removed in NumPy 2.0.
TestSchema = Unischema('TestSchema', [
    UnischemaField('partition_key', np.str_, (), ScalarCodec(StringType()), False),
    UnischemaField('id', np.int64, (), ScalarCodec(LongType()), False),
    UnischemaField('id2', np.int32, (), ScalarCodec(ShortType()), False),
    UnischemaField('id_float', np.float64, (), ScalarCodec(DoubleType()), False),
    UnischemaField('id_odd', np.bool_, (), ScalarCodec(BooleanType()), False),
    UnischemaField('python_primitive_uint8', np.uint8, (), ScalarCodec(ShortType()), False),
    UnischemaField('image_png', np.uint8, _DEFAULT_IMAGE_SIZE, CompressedImageCodec('png'), False),
    UnischemaField('matrix', np.float32, _DEFAULT_IMAGE_SIZE, NdarrayCodec(), False),
    UnischemaField('decimal', Decimal, (), ScalarCodec(DecimalType(10, 9)), False),
    UnischemaField('matrix_uint16', np.uint16, _DEFAULT_IMAGE_SIZE, NdarrayCodec(), False),
    UnischemaField('matrix_string', np.bytes_, (None, None,), NdarrayCodec(), False),
    UnischemaField('empty_matrix_string', np.bytes_, (None,), NdarrayCodec(), False),
    UnischemaField('matrix_nullable', np.uint16, _DEFAULT_IMAGE_SIZE, NdarrayCodec(), True),
    UnischemaField('sensor_name', np.str_, (1,), NdarrayCodec(), False),
    UnischemaField('string_array_nullable', np.str_, (None,), NdarrayCodec(), True),
])


def _random_binary_string_gen(max_length):
    """Returns a single random string up to max_length specified length that may include \x00 character anywhere in the
    string"""
    size = random.randint(0, max_length)
    return ''.join(random.choice(('\x00', 'A', 'B')) for _ in range(size))
Exemple #16
0
from petastorm import make_batch_reader
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField
import numpy as np
import tensorflow.keras.backend as K
import shutil
import horovod.tensorflow.keras as hvd
from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
# `tf` is used below, but the visible imports only bring in tensorflow.keras.backend.
import tensorflow as tf

# NOTE: tf.enable_eager_execution / tf.ConfigProto / tf.Session are TensorFlow 1.x
# APIs (available as tf.compat.v1.* in TensorFlow 2.x).
tf.enable_eager_execution()
DEFAULT_IMAGE_SIZE = (128, 128)

# A plain list of UnischemaFields (indexed by position), not a Unischema:
# image features (H x W x 3) followed by masks (H x W).
TrainSchema = [
    UnischemaField('features', np.uint8,
                   (DEFAULT_IMAGE_SIZE[0], DEFAULT_IMAGE_SIZE[1], 3),
                   NdarrayCodec(), False),
    UnischemaField('masks', np.uint8,
                   (DEFAULT_IMAGE_SIZE[0], DEFAULT_IMAGE_SIZE[1]),
                   NdarrayCodec(), False)
]

# Horovod: initialize Horovod inside the trainer.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process), if GPUs are available.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
K.set_session(tf.Session(config=config))
from petastorm.transform import TransformSpec
Exemple #17
0
def decode_image(tensor):
    """Decode ``tensor[0]`` with the ``TrainSchema[0]`` field and reshape it to (128, 128, 3)."""
    ndarray_codec = NdarrayCodec()
    return ndarray_codec.decode(TrainSchema[0], tensor[0]).reshape((128, 128, 3))
from pyspark.sql.types import IntegerType

from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField

from pycarbon.core.carbon_dataset_metadata import materialize_dataset_carbon

from pycarbon.tests import DEFAULT_CARBONSDK_PATH

# Layout of the generated dataset: an integer 'id', a PNG-compressed 128x256x3
# uint8 'image1', and a uint8 'array_4d' whose shape entries given as None are
# left unspecified.
HelloWorldSchema = Unischema('HelloWorldSchema', [
    UnischemaField(name='id', numpy_dtype=np.int_, shape=(),
                   codec=ScalarCodec(IntegerType()), nullable=False),
    UnischemaField(name='image1', numpy_dtype=np.uint8, shape=(128, 256, 3),
                   codec=CompressedImageCodec('png'), nullable=False),
    UnischemaField(name='array_4d', numpy_dtype=np.uint8, shape=(None, 128, 30, None),
                   codec=NdarrayCodec(), nullable=False),
])


def row_generator(x):
    """Return one entry of the generated dataset, keyed by the HelloWorldSchema field names.

    ``id`` is ``x``; ``image1`` (128x256x3) and ``array_4d`` (4x128x30x3) are filled
    with random uint8 values covering the full 0..255 range.
    """
    # np.random.randint's upper bound is exclusive: 256 (not 255) is needed for 255 to occur.
    return {
        'id': x,
        'image1': np.random.randint(0, 256, dtype=np.uint8,
                                    size=(128, 256, 3)),
        'array_4d': np.random.randint(0, 256, dtype=np.uint8,
                                      size=(4, 128, 30, 3)),
    }
Exemple #19
0
def decode_mask(tensor):
    """Decode ``tensor[1]`` with the ``TrainSchema[1]`` field and return the resulting array."""
    return NdarrayCodec().decode(TrainSchema[1], tensor[1])
Exemple #20
0
# Shape shared by the image and matrix fields of the TestSchema below.
_DEFAULT_IMAGE_SIZE = (32, 16, 3)

TestSchema = Unischema(
    'TestSchema',
    [
        UnischemaField('partition_key', np.unicode_, ()),
        UnischemaField('id', np.int64, ()),
        UnischemaField('id2', np.int32, (), ScalarCodec(ShortType()),
                       False),  # Explicit scalar codec in some scalar fields
        UnischemaField('id_float', np.float64, ()),
        UnischemaField('id_odd', np.bool_, ()),
        UnischemaField('python_primitive_uint8', np.uint8, ()),
        UnischemaField('image_png', np.uint8, _DEFAULT_IMAGE_SIZE,
                       CompressedImageCodec('png'), False),
        UnischemaField('matrix', np.float32, _DEFAULT_IMAGE_SIZE,
                       NdarrayCodec(), False),
        UnischemaField('decimal', Decimal,
                       (), ScalarCodec(DecimalType(10, 9)), False),
        UnischemaField('matrix_uint16', np.uint16, _DEFAULT_IMAGE_SIZE,
                       NdarrayCodec(), False),
        UnischemaField('matrix_uint32', np.uint32, _DEFAULT_IMAGE_SIZE,
                       NdarrayCodec(), False),
        UnischemaField('matrix_string', np.string_, (
            None,
            None,
        ), NdarrayCodec(), False),
        UnischemaField('empty_matrix_string', np.string_,
                       (None, ), NdarrayCodec(), False),
        UnischemaField('matrix_nullable', np.uint16, _DEFAULT_IMAGE_SIZE,
                       NdarrayCodec(), True),
        UnischemaField('sensor_name', np.unicode_,
Exemple #21
0
def petastorm_unischema_codec(shape, type):
    """Choose the petastorm codec for a field.

    When ``shape == 1`` the result is ``ScalarCodec(type())``. ``type`` is called
    with no arguments, presumably a Spark type class such as ``IntegerType``.
    Any other shape gets an ``NdarrayCodec()``.
    """
    # NOTE: the parameter name ``type`` shadows the builtin; it is kept so that
    # keyword callers keep working.
    return ScalarCodec(type()) if shape == 1 else NdarrayCodec()
def main():
    """Convert pre-processed H5 data into Petastorm train and validation datasets.

    Lists the entries under ``H5_PRE_PROCESSED_DATA_DIR`` (``hdfs://`` or
    ``file://``), shuffles them and splits them by ``TRAIN_FRACTION``. It then
    writes one Petastorm dataset per split to ``OUTPUT_PATH/train_data`` and
    ``OUTPUT_PATH/val_data``.
    """
    SPARK_MASTER_URL = 'spark://...'  # Change the Spark master URL.
    H5_PRE_PROCESSED_DATA_DIR = 'file://...'  # Change pre-processed data input path. Should be accessible from all Spark workers.
    OUTPUT_PATH = 'file:///...'  # Change Petastorm output path. Should be accessible from all Spark workers.
    TRAIN_FRACTION = 0.7  # Fraction of train data. Remaining is validation data.

    ROW_GROUP_SIZE_MB = 512  # Size of Parquet row group size.
    NUM_PARTITIONS = 100  # Number of Parquet partitions for train and val data each.

    spark = SparkSession \
        .builder \
        .master(SPARK_MASTER_URL) \
        .appName("Deep Postures Example - Petastorm Data Generation") \
        .getOrCreate()

    input_data = _list_input_dirs(H5_PRE_PROCESSED_DATA_DIR)

    random.shuffle(input_data)
    n_train = int(len(input_data) * TRAIN_FRACTION)
    train_data = input_data[:n_train]
    val_data = input_data[n_train:]

    # np.bytes_ is the type the former np.string_ alias pointed to (removed in NumPy 2.0).
    schema = Unischema('schema', [
        UnischemaField('id', np.bytes_, (), ScalarCodec(StringType()), False),
        UnischemaField('time', np.int64, (), ScalarCodec(LongType()), False),
        UnischemaField('data', np.float32, (100, 3), NdarrayCodec(), False),
        UnischemaField('non_wear', np.int32, (), ScalarCodec(IntegerType()), False),
        UnischemaField('sleeping', np.int32, (), ScalarCodec(IntegerType()), False),
        UnischemaField('label', np.int32, (), ScalarCodec(IntegerType()), False)
    ])

    _write_split(spark, schema, train_data, os.path.join(OUTPUT_PATH, 'train_data'),
                 ROW_GROUP_SIZE_MB, NUM_PARTITIONS)
    _write_split(spark, schema, val_data, os.path.join(OUTPUT_PATH, 'val_data'),
                 ROW_GROUP_SIZE_MB, NUM_PARTITIONS)


def _list_input_dirs(data_dir):
    """Return the entries of ``data_dir`` as URLs carrying the same scheme (hdfs:// or file://)."""
    if data_dir.startswith('hdfs://'):
        # Same as `hdfs dfs -ls <dir> | awk '{print $8}'`, but without a shell and
        # with text output (the original concatenated str with bytes).
        listing = subprocess.run(['hdfs', 'dfs', '-ls', data_dir],
                                 stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                 universal_newlines=True, check=True).stdout
        paths = []
        for line in listing.splitlines():
            fields = line.split()
            if len(fields) < 8:  # e.g. the "Found N items" header line
                continue
            path = fields[7]
            # `ls` on a fully qualified URL already prints qualified paths.
            paths.append(path if path.startswith('hdfs://') else 'hdfs://' + path)
        return paths
    if data_dir.startswith('file://'):
        # os.listdir needs a local path, not a file:// URL.
        local_dir = data_dir[len('file://'):]
        return [os.path.join(data_dir, name)
                for name in os.listdir(local_dir)
                if not name.startswith('.')]  # skip hidden entries
    # ValueError subclasses Exception, so existing `except Exception` callers still work.
    raise ValueError('Unsupported file system in: {}'.format(data_dir))


def _to_row_dict(item):
    """Map a positional record yielded by load_h5 onto the schema's field names."""
    return {'id': item[0], 'time': item[1], 'data': item[2],
            'non_wear': item[3], 'sleeping': item[4], 'label': item[5]}


def _write_split(spark, schema, paths, dataset_url, row_group_size_mb, num_partitions):
    """Load the H5 inputs in ``paths`` and write them as a Petastorm dataset at ``dataset_url``."""
    with materialize_dataset(spark, dataset_url, schema, row_group_size_mb):
        rdd = spark.sparkContext.parallelize(paths) \
            .flatMap(lambda path: load_h5(path)) \
            .map(_to_row_dict) \
            .map(lambda row: dict_to_spark_row(schema, row))

        df = spark.createDataFrame(rdd, schema=schema.as_spark_schema())
        df.orderBy("id", "time").coalesce(num_partitions).write.mode('overwrite').parquet(dataset_url)


if __name__ == "__main__":
    main()
Exemple #23
0
# limitations under the License.

import numpy as np
from pyspark.sql.types import IntegerType

from petastorm.codecs import ScalarCodec, NdarrayCodec
from petastorm.ngram import NGram
from petastorm.reader_impl.row_group_decoder import RowDecoder
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row

_matrix_shape = (10, 20)

# Two-field test schema: a non-nullable uint8 scalar and a non-nullable
# float32 matrix of shape _matrix_shape.
TestSchema = Unischema('TestSchema', [
    UnischemaField(name='some_number', numpy_dtype=np.uint8, shape=(),
                   codec=ScalarCodec(IntegerType()), nullable=False),
    UnischemaField(name='some_matrix', numpy_dtype=np.float32, shape=_matrix_shape,
                   codec=NdarrayCodec(), nullable=False),
])


def _rand_row(some_number=None):
    """Return a random row dict with the TestSchema field names.

    :param some_number: value for ``some_number``. If omitted, a new random
        integer in [0, 255) is drawn on each call.
    :return: dict with ``some_number`` and a float32 ``some_matrix`` of shape ``_matrix_shape``.
    """
    # The previous default, np.random.randint(0, 255), was evaluated once when the
    # function was defined, so every call without an argument reused the same number.
    if some_number is None:
        some_number = np.random.randint(0, 255)
    return {
        'some_number': some_number,
        'some_matrix': np.random.random(size=_matrix_shape).astype(np.float32),
    }


def test_row_decoding():
    expected_row = _rand_row()