コード例 #1
0
ファイル: io.py プロジェクト: zkkxu/tf-quant-finance
"""

from typing import Dict, Callable, Optional

import numpy as np
import tensorflow.compat.v2 as tf

__all__ = [
    'encode_array',
    'decode_array',
    'ArrayDictReader',
    'ArrayDictWriter',
]

# Needed for decoding serialized arrays.
_CLS = type(tf.make_tensor_proto([0]))

ArrayEncoderFn = Callable[[np.ndarray], bytes]
ArrayDecoderFn = Callable[[bytes], np.ndarray]


def encode_array(x: np.ndarray) -> bytes:
    """Encodes a numpy array using `TensorProto` protocol buffer."""
    return tf.make_tensor_proto(x).SerializeToString()


def decode_array(bytestring: bytes) -> np.ndarray:
    """Decodes a bytestring into a numpy array.

  The bytestring should be a serialized `TensorProto` instance. For more details
  see `tf.make_tensor_proto`.
コード例 #2
0
ファイル: io.py プロジェクト: zkkxu/tf-quant-finance
def encode_array(x: np.ndarray) -> bytes:
    """Encodes a numpy array using `TensorProto` protocol buffer."""
    return tf.make_tensor_proto(x).SerializeToString()
コード例 #3
0
def _hash(tensor):
    content = tf.make_tensor_proto(tensor).SerializeToString()
    return hashlib.md5(content).hexdigest()