def test_deserializer():
    """Round-trip two rows through RecordSerializer / RecordDeserializer.

    Each deserialized record's ``values`` float64 tensor must equal the
    corresponding input row, and the number of records must equal the
    number of input rows.
    """
    array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
    serializer = RecordSerializer()
    buf = serializer.serialize(np.array(array_data))
    deserializer = RecordDeserializer()
    # NOTE(review): the second argument looks like a content type that the
    # deserializer does not inspect here — confirm against its signature.
    # Materialize the result: zip() alone stops at the shorter sequence, so a
    # missing or extra record would otherwise pass unnoticed.
    records = list(deserializer.deserialize(buf, "who cares"))
    assert len(records) == len(array_data)
    for record, expected in zip(records, array_data):
        assert record.features["values"].float64_tensor.values == expected
    def __init__(
        self,
        endpoint_name,
        sagemaker_session=None,
        serializer=None,
        deserializer=None,
    ):
        """Initialization for LinearLearnerPredictor.

        Args:
            endpoint_name (str): Name of the Amazon SageMaker endpoint to which
                requests are sent.
            sagemaker_session (sagemaker.session.Session): A SageMaker Session
                object, used for SageMaker interactions (default: None). If not
                specified, one is created using the default AWS configuration
                chain.
            serializer (sagemaker.serializers.BaseSerializer): Optional. When
                omitted (or None), a new ``RecordSerializer`` is created for
                this predictor, which serializes input data to
                x-recordio-protobuf format.
            deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
                When omitted (or None), a new ``RecordDeserializer`` is created
                for this predictor, which parses responses from
                x-recordio-protobuf format.
        """
        # Create the defaults per call. A default such as ``RecordSerializer()``
        # in the signature is evaluated once, when the function is defined, so
        # a single object would be shared by every predictor instance.
        if serializer is None:
            serializer = RecordSerializer()
        if deserializer is None:
            deserializer = RecordDeserializer()
        super(LinearLearnerPredictor, self).__init__(
            endpoint_name,
            sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )
# Example No. 3 (original page label: "Ejemplo n.º 3")
# 0
 def __init__(self, endpoint_name, sagemaker_session=None):
     """Create a predictor bound to a Factorization Machines endpoint.

     Each instance gets its own ``RecordSerializer`` for outgoing requests
     and its own ``RecordDeserializer`` for responses; this constructor
     does not let callers replace either one.

     Args:
         endpoint_name (str): Name of the Amazon SageMaker endpoint that
             receives the requests.
         sagemaker_session (sagemaker.session.Session): Optional session
             used for SageMaker interactions (default: None). When omitted,
             one is created using the default AWS configuration chain.
     """
     request_serializer = RecordSerializer()
     response_deserializer = RecordDeserializer()
     predictor_base = super(FactorizationMachinesPredictor, self)
     predictor_base.__init__(
         endpoint_name,
         sagemaker_session,
         serializer=request_serializer,
         deserializer=response_deserializer,
     )