Example #1
def cachingToRedis():
    ## cache the data into Redis memory and delete it after 60 seconds
    context = pa.default_serialization_context()
    redis.set('key',
              context.serialize(scrapes_BTC_Data()).to_buffer().to_pybytes())
    redis.expire('key', 60)
    return context
Example #2
    def _df_from_redis(self, key):
        buffer = self._config.redis.get(key)
        if buffer is not None:
            context = pa.default_serialization_context()
            return context.deserialize(buffer)
        else:
            return None
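Examples 1 and 2 are the two halves of the same pattern: serialize a DataFrame to bytes with pyarrow's (legacy) serialization context, park the bytes in Redis, and deserialize them on the way back out. Below is a minimal, self-contained round-trip sketch; it assumes a local Redis server, an older pyarrow release that still ships default_serialization_context() (the API was deprecated and later removed), and an illustrative key name and sample data.

import pandas as pd
import pyarrow as pa
import redis

r = redis.Redis(host="localhost", port=6379, db=0)
context = pa.default_serialization_context()

# Write: serialize the DataFrame to bytes and store it with a 60-second TTL
df = pd.DataFrame({"price": [1.0, 2.0, 3.0]})
r.set("btc", context.serialize(df).to_buffer().to_pybytes())
r.expire("btc", 60)

# Read: fetch the raw bytes and deserialize, or fall back to None
buffer = r.get("btc")
restored = context.deserialize(buffer) if buffer is not None else None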
Example #3
def test_serialize_subclasses():

    # This test shows how subclasses can be handled in an idiomatic way
    # by having only a serializer for the base class

    # This technique should however be used with care, since pickling
    # type(obj) with cloudpickle will include the full class definition
    # in the serialized representation.
    # This means the class definition is part of every instance of the
    # object, which in general is not desirable; registering all subclasses
    # with register_type will result in faster and more memory
    # efficient serialization.

    context = pa.default_serialization_context()
    context.register_type(Serializable,
                          "Serializable",
                          custom_serializer=serialize_serializable,
                          custom_deserializer=deserialize_serializable)

    a = SerializableClass()
    serialized = pa.serialize(a, context=context)

    deserialized = serialized.deserialize(context=context)
    assert type(deserialized).__name__ == SerializableClass.__name__
    assert deserialized.value == 3
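The comment above contrasts two strategies: a single cloudpickle-backed serializer on the base class versus an explicit register_type call per subclass. Here is a small, self-contained sketch of the second ("register every subclass") approach; the class and helper names are illustrative rather than taken from the test suite, and it again assumes a pyarrow version with the legacy serialization API.

import pyarrow as pa

class Base:
    def __init__(self, value=3):
        self.value = value

class Child(Base):
    pass

def _serialize(obj):
    return {"value": obj.value}

context = pa.default_serialization_context()
# Registering each concrete type avoids embedding the pickled class
# definition in every serialized instance.
context.register_type(Base, "Base",
                      custom_serializer=_serialize,
                      custom_deserializer=lambda data: Base(data["value"]))
context.register_type(Child, "Child",
                      custom_serializer=_serialize,
                      custom_deserializer=lambda data: Child(data["value"]))

restored = pa.serialize(Child(7), context=context).deserialize(context=context)
assert type(restored) is Child and restored.value == 7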
Example #4
def test_serialize_subclasses():

    # This test shows how subclasses can be handled in an idiomatic way
    # by having only a serializer for the base class

    # This technique should however be used with care, since pickling
    # type(obj) with cloudpickle will include the full class definition
    # in the serialized representation.
    # This means the class definition is part of every instance of the
    # object, which in general is not desirable; registering all subclasses
    # with register_type will result in faster and more memory
    # efficient serialization.

    context = pa.default_serialization_context()
    context.register_type(
        Serializable, "Serializable",
        custom_serializer=serialize_serializable,
        custom_deserializer=deserialize_serializable)

    a = SerializableClass()
    serialized = pa.serialize(a, context=context)

    deserialized = serialized.deserialize(context=context)
    assert type(deserialized).__name__ == SerializableClass.__name__
    assert deserialized.value == 3
Example #5
def fetchHospital(rconn, key):
    context = pyarrow.default_serialization_context()

    #
    # See: https://dev.socrata.com/foundry/healthdata.gov/g62h-syeh
    #

    #
    # Check date of main dataframe
    #
    expires = rconn.hget("statehos" + key, "expires")
    if expires and time.time() < float(expires):
        return context.deserialize(rconn.hget("statehos" + key, "dataframe"))

    #
    # Fetch
    # Make sure we include a user agent. We are limited to 50,000 records per query,
    # but that should be plenty for this table (which has rows per day)
    #
    # The HHS sure loves long column names...
    #
    columns = [
        'date',
        'state',
        'inpatient_beds',
        'inpatient_beds_used',
        'inpatient_beds_used_covid',
        'staffed_icu_adult_patients_confirmed_and_suspected_covid',
        'total_staffed_adult_icu_beds',
    ]

    req = requests.get(
        "https://healthdata.gov/resource/g62h-syeh.csv",
        params={
            'state': key,
            '$limit': 5000,
            '$select': ",".join(columns),
            "$$app_token": app.config['SOCRATA_TOKEN']
        },
        headers={
            'User-Agent':
            'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:77.0) Gecko/20100101 Firefox/77.0'
        })

    if req.status_code != 200:
        raise Exception("Request failure: {}".format(req.status_code))

    answer = pd.read_csv(StringIO(req.text),
                         parse_dates=["date"]).rename(columns={'date': 'dt'})

    answer = answer.sort_values('dt')

    #
    # Save
    #
    rconn.hset("statehos" + key, "dataframe",
               context.serialize(answer).to_buffer().to_pybytes())
    rconn.hset("statehos" + key, "expires", str(time.time() + 600.0))

    return answer
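fetchHospital keeps the serialized DataFrame and its expiry timestamp as two fields of a single Redis hash, the same caching pattern used again further down. Below is a hedged, generic sketch of that pattern factored into a pair of helpers; the helper names and the 600-second default TTL are assumptions, not code from the source project.

import time

import pyarrow

def cache_dataframe(rconn, name, df, ttl=600.0):
    context = pyarrow.default_serialization_context()
    rconn.hset(name, "dataframe",
               context.serialize(df).to_buffer().to_pybytes())
    rconn.hset(name, "expires", str(time.time() + ttl))

def cached_dataframe(rconn, name):
    context = pyarrow.default_serialization_context()
    expires = rconn.hget(name, "expires")
    if expires and time.time() < float(expires):
        return context.deserialize(rconn.hget(name, "dataframe"))
    return None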
Example #6
def _serialize_and_expand_data(
    result_set: SupersetResultSet,
    db_engine_spec: BaseEngineSpec,
    use_msgpack: Optional[bool] = False,
    expand_data: bool = False,
) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
    selected_columns = result_set.columns
    all_columns: List[Any]
    expanded_columns: List[Any]

    if use_msgpack:
        with stats_timing("sqllab.query.results_backend_pa_serialization",
                          stats_logger):
            data = (pa.default_serialization_context().serialize(
                result_set.pa_table).to_buffer().to_pybytes())

        # expand when loading data from results backend
        all_columns, expanded_columns = (selected_columns, [])
    else:
        df = result_set.to_pandas_df()
        data = df_to_records(df) or []

        if expand_data:
            all_columns, data, expanded_columns = db_engine_spec.expand_data(
                selected_columns, data)
        else:
            all_columns = selected_columns
            expanded_columns = []

    return (data, selected_columns, all_columns, expanded_columns)
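The msgpack branch above serializes result_set.pa_table into raw bytes before handing them to the results backend. A hedged sketch of what the matching read side could look like; the function name is illustrative and this is not Superset's actual backend code.

import pyarrow as pa

def _deserialize_pa_table(payload: bytes) -> pa.Table:
    # Reverse of the serialize(...).to_buffer().to_pybytes() call above
    context = pa.default_serialization_context()
    return context.deserialize(payload)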
Example #7
def mars_serialize_context():
    global _serialize_context
    if _serialize_context is None:
        ctx = pyarrow.default_serialization_context()
        ctx.register_type(SparseNDArray,
                          'mars.SparseNDArray',
                          custom_serializer=_serialize_sparse_nd_array,
                          custom_deserializer=_deserialize_sparse_nd_array)
        ctx.register_type(GroupByWrapper,
                          'pandas.GroupByWrapper',
                          custom_serializer=_serialize_groupby_wrapper,
                          custom_deserializer=_deserialize_groupby_wrapper)
        ctx.register_type(pd.Interval,
                          'pandas.Interval',
                          custom_serializer=_serialize_pandas_interval,
                          custom_deserializer=_deserialize_pandas_interval)
        ctx.register_type(pd.Categorical,
                          'pandas.Categorical',
                          custom_serializer=_serialze_pandas_categorical,
                          custom_deserializer=_deserialize_pandas_categorical)
        ctx.register_type(
            pd.CategoricalDtype,
            'pandas.CategoricalDtype',
            custom_serializer=_serialize_pandas_categorical_dtype,
            custom_deserializer=_deserialize_pandas_categorical_dtype)
        _apply_pyarrow_serialization_patch(ctx)
        if vineyard is not None:  # pragma: no cover
            vineyard.register_vineyard_serialize_context(ctx)
        _serialize_context = ctx
    return _serialize_context
Example #8
def fetchData(rconn):
    context = pyarrow.default_serialization_context()

    #
    # Check date of main dataframe
    #
    expires = rconn.hget("county", "expires")
    if expires and time.time() < float(expires):
        return context.deserialize(rconn.hget("county", "dataframe"))

    #
    # Fetch new copy
    #
    dt = pd.read_csv(
        "https://github.com/nytimes/covid-19-data/blob/master/us-counties.csv?raw=true"
    )
    dt['dt'] = pd.to_datetime(dt.date, format="%Y-%m-%d")

    #
    # Save
    #
    rconn.hset("county", "dataframe",
               context.serialize(dt).to_buffer().to_pybytes())
    rconn.hset("county", "expires", str(time.time() + 600.0))
    return dt
Example #9
    def getOGCdataframe(cls, cache_key, csv_url, process):
        serializer = pa.default_serialization_context()
        data = cache.get(cache_key)

        app = current_app._get_current_object()

        # If the dataset is empty, refresh it synchronously; otherwise refresh in the background and continue
        if not data:
            df = refreshOGCdata(app, cache_key, csv_url, process)
        else:
            thread = Thread(target=refreshOGCdata,
                            args=(
                                app,
                                cache_key,
                                csv_url,
                                process,
                            ))
            thread.daemon = True
            thread.start()

        #update data and return
        data = cache.get(cache_key)
        if data:
            df = serializer.deserialize(data)

        return df
Example #10
def test_torch_serialization(large_buffer):
    pytest.importorskip("torch")

    serialization_context = pa.default_serialization_context()
    pa.register_torch_serialization_handlers(serialization_context)

    # Dense tensors:

    # These are the only types that are supported for the
    # PyTorch to NumPy conversion
    for t in ["float32", "float64",
              "uint8", "int16", "int32", "int64"]:
        obj = torch.from_numpy(np.random.randn(1000).astype(t))
        serialization_roundtrip(obj, large_buffer,
                                context=serialization_context)

    tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)
    serialization_roundtrip(tensor_requiring_grad, large_buffer,
                            context=serialization_context)

    # Sparse tensors:

    # These are the only types that are supported for the
    # PyTorch to NumPy conversion
    for t in ["float32", "float64",
              "uint8", "int16", "int32", "int64"]:
        i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
        v = torch.from_numpy(np.array([3, 4, 5]).astype(t))
        obj = torch.sparse_coo_tensor(i.t(), v, torch.Size([2, 3]))
        serialization_roundtrip(obj, large_buffer,
                                context=serialization_context)
Example #11
    def connect(self):
        logger.info("Connecting to redis cache")
        import redis
        import pyarrow as pa

        self.redis_client = redis.Redis(**self.redis_args)
        self.pyarrow_context = pa.default_serialization_context()
Example #12
def update_histogram(trigger_idx, left_hide_trigger, histogram_sw, x_histogram,
                     y_histogram, keys_dict, num_keys, cat_keys,
                     categorical_key_values, numerical_key_values):
    x_key = keys_dict[x_histogram]['key']
    x_label = keys_dict[x_histogram]['description']
    y_key = y_histogram

    if histogram_sw:
        context = pa.default_serialization_context()
        data = context.deserialize(redis_instance.get("DATASET"))
        filtered_table = filter_all(data, num_keys, numerical_key_values,
                                    cat_keys, categorical_key_values)

        histogram_fig = get_histogram(filtered_table, x_key, x_label, y_key)
        histogram_x_disabled = False
        histogram_y_disabled = False
    else:
        histogram_fig = {
            'data': [{
                'type': 'histogram',
                'x': []
            }],
            'layout': {}
        }
        histogram_x_disabled = True
        histogram_y_disabled = True

    return [
        histogram_fig,
        histogram_x_disabled,
        histogram_y_disabled,
    ]
Example #13
def left_hide_button(
    btn,
    selectedData,
    trigger_idx,
):
    if btn > 0 and selectedData is not None:
        context = pa.default_serialization_context()
        data = context.deserialize(redis_instance.get("DATASET"))

        s_data = pd.DataFrame(selectedData['points'])
        idx = s_data['id']
        idx.index = idx

        vis_idx = idx[data['Visibility'][idx] == 'visible']
        hid_idx = idx[data['Visibility'][idx] == 'hidden']

        data.loc[vis_idx, 'Visibility'] = 'hidden'
        data.loc[hid_idx, 'Visibility'] = 'visible'

        redis_instance.set(REDIS_KEYS["DATASET"],
                           context.serialize(data).to_buffer().to_pybytes())

        return trigger_idx + 1

    else:
        raise PreventUpdate
Example #14
    def __init__(self, redis_instance, key=None, hash_keys=False):
        # Checks that a redis object has been passed
        # https://stackoverflow.com/questions/57949871/how-to-set-get-pandas-dataframes-into-redis-using-pyarrow/57986261#57986261
        if not isinstance(redis_instance, redis.client.Redis):
            raise AttributeError(
                "Did not receive a Redis object, instead received {}".format(
                    type(redis_instance)))

        # Sets the redis object as an attribute of this class
        self.cache_container = redis_instance

        # pyarrow serializer, uses the default serialization method
        self.context = pa.default_serialization_context()

        # leading key prefix used in the db; identifies this caching object in Redis
        self.leading_key = 'RandasCache'

        # Key to identify the object/state of the object;
        # if None, the key will be generated from the function's inputs
        self.key = key

        assert isinstance(hash_keys,
                          bool), "[ERROR]: hash_keys should be a bool"
        self.hash_keys = hash_keys

        # store the keys on the object so we know how to retrieve the values
        self.keys = {}
Example #15
def test_sparse_csf_tensor_serialization(index_type, tensor_type):
    tensor_dtype = np.dtype(tensor_type)
    index_dtype = np.dtype(index_type)
    data = np.array([[1, 2, 3, 4, 5, 6, 7, 8]]).T.astype(tensor_dtype)
    indptr = [
        np.array([0, 2, 3]),
        np.array([0, 1, 3, 4]),
        np.array([0, 2, 4, 5, 8]),
    ]
    indices = [
        np.array([0, 1]),
        np.array([0, 1, 1]),
        np.array([0, 0, 1, 1]),
        np.array([1, 2, 0, 2, 0, 0, 1, 2]),
    ]
    indptr = [x.astype(index_dtype) for x in indptr]
    indices = [x.astype(index_dtype) for x in indices]
    shape = (2, 3, 4, 5)
    axis_order = (0, 1, 2, 3)
    dim_names = ("a", "b", "c", "d")

    for ndim in [2, 3, 4]:
        sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr[:ndim - 1],
                                                      indices[:ndim],
                                                      shape[:ndim],
                                                      axis_order[:ndim],
                                                      dim_names[:ndim])

        context = pa.default_serialization_context()
        serialized = pa.serialize(sparse_tensor, context=context).to_buffer()
        result = pa.deserialize(serialized)
        assert_equal(result, sparse_tensor)
        assert isinstance(result, pa.SparseCSFTensor)
Example #16
    def save_data_frame(self, key: str, df: pd.DataFrame):
        """
        Save a Pandas dataframe to Redis store
        """
        context = pa.default_serialization_context()
        self.redis_client.setex(key, EXPIRATION_SECONDS,
                                context.serialize(df).to_buffer().to_pybytes())
Example #17
def test_numpy_subclass_serialization():
    # Check that we can properly serialize subclasses of np.ndarray.
    class CustomNDArray(np.ndarray):
        def __new__(cls, input_array):
            array = np.asarray(input_array).view(cls)
            return array

    def serializer(obj):
        return {'numpy': obj.view(np.ndarray)}

    def deserializer(data):
        array = data['numpy'].view(CustomNDArray)
        return array

    context = pa.default_serialization_context()

    context.register_type(CustomNDArray,
                          'CustomNDArray',
                          custom_serializer=serializer,
                          custom_deserializer=deserializer)

    x = CustomNDArray(np.zeros(3))
    serialized = pa.serialize(x, context=context).to_buffer()
    new_x = pa.deserialize(serialized, context=context)
    assert type(new_x) == CustomNDArray
    assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3))
Example #18
def _serialize_and_expand_data(
    cdf: SupersetDataFrame,
    db_engine_spec: BaseEngineSpec,
    use_msgpack: Optional[bool] = False,
    expand_data: bool = False,
) -> Tuple[Union[bytes, str], list, list, list]:
    selected_columns: list = cdf.columns or []
    expanded_columns: list

    if use_msgpack:
        with stats_timing(
            "sqllab.query.results_backend_pa_serialization", stats_logger
        ):
            data = (
                pa.default_serialization_context()
                .serialize(cdf.raw_df)
                .to_buffer()
                .to_pybytes()
            )
        # expand when loading data from results backend
        all_columns, expanded_columns = (selected_columns, [])
    else:
        data = cdf.data or []
        if expand_data:
            all_columns, data, expanded_columns = db_engine_spec.expand_data(
                selected_columns, data
            )
        else:
            all_columns = selected_columns
            expanded_columns = []

    return (data, selected_columns, all_columns, expanded_columns)
Example #19
def test_numpy_subclass_serialization():
    # Check that we can properly serialize subclasses of np.ndarray.
    class CustomNDArray(np.ndarray):
        def __new__(cls, input_array):
            array = np.asarray(input_array).view(cls)
            return array

    def serializer(obj):
        return {'numpy': obj.view(np.ndarray)}

    def deserializer(data):
        array = data['numpy'].view(CustomNDArray)
        return array

    context = pa.default_serialization_context()

    context.register_type(CustomNDArray, 'CustomNDArray',
                          custom_serializer=serializer,
                          custom_deserializer=deserializer)

    x = CustomNDArray(np.zeros(3))
    serialized = pa.serialize(x, context=context).to_buffer()
    new_x = pa.deserialize(serialized, context=context)
    assert type(new_x) == CustomNDArray
    assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3))
Example #20
def save_signal_database():
    logger.debug("Saving signal database to redis")
    r = redis_handle()
    alerts = Alerts()
    context = pa.default_serialization_context()
    for signal, df in alerts.data.items():
        r.sadd(signal, context.serialize(df).to_buffer().to_pybytes())
Example #21
    def cache_df(self, alias, df):
        cur = redis.Redis(connection_pool=self.redis_pool)
        context = pa.default_serialization_context()
        df_compressed = context.serialize(df).to_buffer().to_pybytes()

        res = cur.set(alias, df_compressed)
        if res:
            print('df cached')
Example #22
    def _df_from_redis(self, key):
        """Loads a pd.DataFrame from redis."""
        buffer = self._r.get(key)
        if buffer is not None:
            context = pa.default_serialization_context()
            return context.deserialize(buffer)
        else:
            return None
Example #23
def get_api_dataframe(level, month):
    """ Gets API data for a given area hierachy level and month (yyyy-mm) and saves it as a dataframe to Redis """     

    query_filters = [
        "areaType="+level,
        "date>="+month+"-01",
        "date<="+month+"-31"
    ]
    
    # Query API data for this hierarchy level (API can be flaky, so try up to 3 times)
    attempts = 1
    json_data = list()
    while attempts < 4 and len(json_data) == 0:
        print(f"Querying API..... Attempt {attempts} - Level {level} - Month {month}")
        json_data = get_api_paginated_dataset(query_filters, get_structure(level))
        attempts += 1 
        if len(json_data) == 0:
            # Wait 30 seconds before trying again
            time.sleep(30)

    checktotal = None
    if len(json_data) > 0:

        # Convert json data to dataframe
        df = pd.DataFrame.from_dict(json_data, orient='columns')
       
        # Rename api level with our application level text (e.g. utla to Upper tier local authority)
        df['type'] = df['type'].replace(level,LEVELS_DICT[level])

        # Tidy up dataframe, rename columns, replace some data, fill nas, sort, convert types etc
        df.rename(columns=COLUMN_RENAMES,inplace=True)
        df['Date'] = pd.to_datetime(df['Date'],format="%Y-%m-%d")
        df.sort_values(by=['Date'], inplace=True)
        df.set_index('Date',inplace=True)
        df.fillna(0,inplace=True)
        df = df.astype({'Cases': int, 'Tests': int, 'Hospital Cases': int, 'Deaths within 28 Days of Positive Test': int})

        # Save dataframe to redis - note we create a new connection here so it can be serialised by Spark
        r = redis.Redis(host="localhost", port=6379, db=0)
        ac = pa.default_serialization_context()

        # Save dataframe to redis - first backup previous data to "Old.xxx" key
        oldkey = "Old.Cases."+level+"."+month
        currentkey = "Cases."+level+"."+month

        if r.exists(oldkey):
            r.delete(oldkey)
        
        if r.exists(currentkey):
            r.rename(currentkey,oldkey)

        # Save new data
        r.set(currentkey, ac.serialize(df).to_buffer().to_pybytes())

        checktotal = df['Cases'].sum()

    print(f"Total Cases : {checktotal}")
    return checktotal
Example #24
def mars_serialize_context():
    global _serialize_context
    if _serialize_context is None:
        ctx = pyarrow.default_serialization_context()
        ctx.register_type(SparseNDArray, 'mars.SparseNDArray',
                          custom_serializer=_serialize_sparse_csr_list,
                          custom_deserializer=_deserialize_sparse_csr_list)
        _serialize_context = ctx
    return _serialize_context
Example #25
def load_signal_database():
    logger.debug("Loading signal database")
    alerts = Alerts()

    r = redis_handle()
    context = pa.default_serialization_context()
    signals = get_signals()
    for signal in signals:
        alerts.data[signal] = context.deserialize(r.get(signal))
Example #26
    def __init__(self, raylet_socket_name, object_store_socket_name,
                 is_worker):
        # Connect to the Raylet and object store.
        self.node_manager_client = ray.local_scheduler.LocalSchedulerClient(
            raylet_socket_name, random_string(), is_worker)
        self.plasma_client = plasma.connect(object_store_socket_name, "", 0)
        self.serialization_context = pyarrow.default_serialization_context()
        self.raylet_socket_name = raylet_socket_name
        self.object_store_socket_name = object_store_socket_name
Example #27
    def get_df_cache(self, ns='twitter_cache:df:', key=''):
        context = pa.default_serialization_context()
        if type(key) is bytes:
            get_key = key
        else:
            get_key = ns + key
        # Fetch once and deserialize if present; otherwise return an empty frame
        value = self.r.get(get_key)
        if value:
            return context.deserialize(value)
        return pd.DataFrame()
Example #28
    def setup(self):
        # Transpose to make column-major
        values = np.random.randn(10, 100000)

        df = pd.DataFrame(values.T)
        ctx = pa.default_serialization_context()

        self.serialized = ctx.serialize(df)
        self.as_buffer = self.serialized.to_buffer()
        self.as_components = self.serialized.to_components()
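setup() keeps both a single contiguous buffer and a components dict for the same serialized DataFrame. A hedged sketch of companion benchmark methods that would consume them (the asv-style method names are illustrative); pa.deserialize and pa.deserialize_components are the legacy read-side counterparts.

    def time_deserialize_buffer(self):
        # Rebuild the DataFrame from the single contiguous buffer
        pa.deserialize(self.as_buffer)

    def time_deserialize_components(self):
        # Rebuild it from the components dict produced by to_components()
        pa.deserialize_components(self.as_components)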
Example #29
    def get_data_frame(self, key: str) -> Optional:
        """
        Retrieve a Pandas dataframe from Redis store
        """
        value = self.redis_client.get(key)

        if value is None:
            return

        context = pa.default_serialization_context()
        return context.deserialize(value)
Example #30
def mars_serialize_context():
    global _serialize_context
    if _serialize_context is None:
        from ..compat import apply_pyarrow_serialization_patch
        ctx = pyarrow.default_serialization_context()
        ctx.register_type(SparseNDArray, 'mars.SparseNDArray',
                          custom_serializer=_serialize_sparse_csr_list,
                          custom_deserializer=_deserialize_sparse_csr_list)
        apply_pyarrow_serialization_patch(ctx)
        _serialize_context = ctx
    return _serialize_context
Example #31
def mars_serialize_context():
    from ..dataframe.arrays import ArrowArray, ArrowDtype

    global _serialize_context
    if _serialize_context is None:
        ctx = pyarrow.default_serialization_context()
        ctx.register_type(_PickleWrapper,
                          'mars.PickleWrapper',
                          custom_serializer=_PickleWrapper.serialize,
                          custom_deserializer=_PickleWrapper.deserialize)
        ctx.register_type(_ComplexWrapper,
                          'mars.ComplexWrapper',
                          custom_serializer=_ComplexWrapper.serialize,
                          custom_deserializer=_ComplexWrapper.deserialize)
        ctx.register_type(SparseNDArray,
                          'mars.SparseNDArray',
                          custom_serializer=_serialize_sparse_nd_array,
                          custom_deserializer=_deserialize_sparse_nd_array)
        ctx.register_type(GroupByWrapper,
                          'pandas.GroupByWrapper',
                          custom_serializer=_serialize_groupby_wrapper,
                          custom_deserializer=_deserialize_groupby_wrapper)
        ctx.register_type(pd.Interval,
                          'pandas.Interval',
                          custom_serializer=_serialize_pandas_interval,
                          custom_deserializer=_deserialize_pandas_interval)
        ctx.register_type(pd.Categorical,
                          'pandas.Categorical',
                          custom_serializer=_serialze_pandas_categorical,
                          custom_deserializer=_deserialize_pandas_categorical)
        ctx.register_type(
            pd.CategoricalDtype,
            'pandas.CategoricalDtype',
            custom_serializer=_serialize_pandas_categorical_dtype,
            custom_deserializer=_deserialize_pandas_categorical_dtype)
        ctx.register_type(ArrowDtype,
                          'mars.dataframe.ArrowDtype',
                          custom_serializer=_serialize_arrow_dtype,
                          custom_deserializer=_deserialize_arrow_dtype)
        ctx.register_type(ArrowArray,
                          'mars.dataframe.ArrowArray',
                          custom_serializer=_serialize_arrow_array,
                          custom_deserializer=_deserialize_arrow_array)
        ctx.register_type(pd.arrays.SparseArray,
                          'pandas.arrays.SparseArray',
                          custom_serializer=_serialize_sparse_array,
                          custom_deserializer=_deserialzie_sparse_array)
        ctx.register_type(pd.SparseDtype,
                          'pandas.SparseDtype',
                          custom_serializer=_serialize_sparse_dtype,
                          custom_deserializer=_deserialize_sparse_dtype)
        _apply_pyarrow_serialization_patch(ctx)
        _serialize_context = ctx
    return _serialize_context
Example #32
def test_torch_serialization(large_buffer):
    pytest.importorskip("torch")

    serialization_context = pa.default_serialization_context()
    pa.register_torch_serialization_handlers(serialization_context)
    # These are the only types that are supported for the
    # PyTorch to NumPy conversion
    for t in ["float32", "float64", "uint8", "int16", "int32", "int64"]:
        obj = torch.from_numpy(np.random.randn(1000).astype(t))
        serialization_roundtrip(obj,
                                large_buffer,
                                context=serialization_context)
Example #33
def test_torch_serialization(large_buffer):
    pytest.importorskip("torch")

    serialization_context = pa.default_serialization_context()
    pa.register_torch_serialization_handlers(serialization_context)
    # These are the only types that are supported for the
    # PyTorch to NumPy conversion
    for t in ["float32", "float64",
              "uint8", "int16", "int32", "int64"]:
        obj = torch.from_numpy(np.random.randn(1000).astype(t))
        serialization_roundtrip(obj, large_buffer,
                                context=serialization_context)
Example #34
def make_serialization_context():
    context = pa.default_serialization_context()

    context.register_type(Foo, "Foo")
    context.register_type(Bar, "Bar")
    context.register_type(Baz, "Baz")
    context.register_type(Qux, "Quz")
    context.register_type(SubQux, "SubQux")
    context.register_type(SubQuxPickle, "SubQuxPickle", pickle=True)
    context.register_type(Exception, "Exception")
    context.register_type(CustomError, "CustomError")
    context.register_type(Point, "Point")
    context.register_type(NamedTupleExample, "NamedTupleExample")

    return context
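register_type without a custom serializer lets pyarrow handle the instance generically (essentially via its __dict__), or via pickle when pickle=True as for SubQuxPickle. A hedged usage sketch for the context built above; Foo is assumed to be a plain class with picklable attributes, as in the pyarrow test suite.

context = make_serialization_context()
buf = pa.serialize(Foo(), context=context).to_buffer()
restored = pa.deserialize(buf, context=context)
assert type(restored) is Foo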
Example #35
def test_serialization_callback_numpy():

    class DummyClass(object):
        pass

    def serialize_dummy_class(obj):
        x = np.zeros(4)
        return x

    def deserialize_dummy_class(serialized_obj):
        return serialized_obj

    context = pa.default_serialization_context()
    context.register_type(DummyClass, "DummyClass",
                          custom_serializer=serialize_dummy_class,
                          custom_deserializer=deserialize_dummy_class)

    pa.serialize(DummyClass(), context=context)
Example #36
def test_buffer_serialization():

    class BufferClass(object):
        pass

    def serialize_buffer_class(obj):
        return pa.py_buffer(b"hello")

    def deserialize_buffer_class(serialized_obj):
        return serialized_obj

    context = pa.default_serialization_context()
    context.register_type(
        BufferClass, "BufferClass",
        custom_serializer=serialize_buffer_class,
        custom_deserializer=deserialize_buffer_class)

    b = pa.serialize(BufferClass(), context=context).to_buffer()
    assert pa.deserialize(b, context=context).to_pybytes() == b"hello"