예제 #1
0
파일: dataset.py 프로젝트: zyfra/ebonite
 def serialize(self, instance: dict):
     """
     Serialize a dict after translating each of its keys with ``self.get_key``.

     :param instance: dict to serialize
     :return: whatever the parent class' ``serialize`` returns for the re-keyed dict
     :raises SerializationError: when ``self.get_key`` raises ``ValueError`` for a key
         (type validation is delegated to ``self._check_type``)
     """
     self._check_type(instance, dict, SerializationError)
     items = {}
     try:
         for key, value in instance.items():
             items[self.get_key(key)] = value
     except ValueError as e:
         # Re-raise key translation failures as the serialization error type.
         raise SerializationError(e)
     # NOTE(review): `self` is passed explicitly in addition to the implicit binding done
     # by super(); presumably the parent `serialize` expects it - confirm against the base class.
     return super().serialize(self, items)
예제 #2
0
파일: dataset.py 프로젝트: zyfra/ebonite
 def serialize(self, instance: np.ndarray):
     """
     Convert a numpy array into (nested) Python lists via ``ndarray.tolist()``.

     The array's dtype must equal the dtype resolved from ``self.dtype`` by
     ``np_type_from_string``; shape validation is delegated to ``self._check_shape``.

     :param instance: array to serialize
     :return: result of ``instance.tolist()``
     :raises SerializationError: on dtype mismatch (type and shape checks are handed
         ``SerializationError`` as the error class to use)
     """
     self._check_type(instance, np.ndarray, SerializationError)

     exp_type = np_type_from_string(self.dtype)
     if instance.dtype != exp_type:
         message = f'given array is of type: {instance.dtype}, expected: {exp_type}'
         raise SerializationError(message)

     self._check_shape(instance, SerializationError)
     return instance.tolist()
예제 #3
0
파일: dataset.py 프로젝트: zyfra/ebonite
 def serialize(self, instance: torch.Tensor):
     """
     Convert a torch tensor into (nested) Python lists via ``Tensor.tolist()``.

     The tensor's dtype must be the very ``torch`` attribute named by ``self.dtype``;
     shape validation is delegated to ``self._check_shape``.

     :param instance: tensor to serialize
     :return: result of ``instance.tolist()``
     :raises SerializationError: on dtype mismatch (type and shape checks are handed
         ``SerializationError`` as the error class to use)
     """
     self._check_type(instance, torch.Tensor, SerializationError)

     # dtype objects are compared by identity against the attribute looked up on torch
     if instance.dtype is not getattr(torch, self.dtype):
         message = (f'given tensor is of dtype: {instance.dtype}, '
                    f'expected: {getattr(torch, self.dtype)}')
         raise SerializationError(message)

     self._check_shape(instance, SerializationError)
     return instance.tolist()
예제 #4
0
def _serialize_with_serializer(obj, serializer: Serializer):
    """
    Serialize ``obj`` with the given serializer.

    A serializer that passes ``issubclass_safe(serializer, StaticSerializer)`` is used
    directly; any other serializer must have a truthy ``_is_dynamic`` attribute.

    :param obj: object to serialize
    :param serializer: serializer to delegate to
    :return: result of ``serializer.serialize(obj)``
    :raises SerializationError: if the serializer is neither static nor dynamic
    """
    is_static = issubclass_safe(serializer, StaticSerializer)
    # `_is_dynamic` is only consulted for non-static serializers (short-circuit).
    if not is_static and not serializer._is_dynamic:
        raise SerializationError(
            'Cannot use uninitialized serializer. '
            'Initialize it or replace with StaticSerializer')
    return serializer.serialize(obj)
예제 #5
0
def _serialize_union(obj, class_union):
    """
    Serialize ``obj`` as the first member of ``class_union`` that accepts it.

    Members are taken from ``union_args(class_union)`` and tried in order; a
    ``SerializationError`` from one member simply moves on to the next.

    :param obj: object to serialize
    :param class_union: union type whose members are tried
    :return: result of the first ``serialize(obj, member)`` call that succeeds
    :raises SerializationError: if every member fails (or there are no members)
    """
    for candidate in union_args(class_union):
        try:
            return serialize(obj, candidate)
        except SerializationError:
            # This member did not match; try the next one.
            continue

    raise SerializationError(
        'None of the possible types matched for obj {} and type {}'.format(
            obj, class_union))
예제 #6
0
def _serialize_to_dict(cls, obj):
    """
    Build a dict from the fields of ``obj`` as described by ``cls``.

    Fields come from ``get_class_fields(cls)``. A field whose value on ``obj`` is
    ``None`` is omitted. If ``cls`` has the attribute named by
    ``FIELD_MAPPING_NAME_FIELD``, output keys are renamed through that mapping
    (falling back to the field name). When the type field position of ``cls`` is
    ``Position.INSIDE``, the class' type field value is stored in the dict as well.

    :param cls: class describing the fields to serialize
    :param obj: object whose attributes are read
    :return: dict of output key -> ``serialize(value, field type)``
    :raises SerializationError: if the type field name is already used as an output key
    """
    payload = {}
    for f in get_class_fields(cls):
        key = f.name
        value = getattr(obj, key)
        if value is None:
            continue
        if hasattr(cls, FIELD_MAPPING_NAME_FIELD):
            key = getattr(cls, FIELD_MAPPING_NAME_FIELD).get(key, key)
        payload[key] = serialize(value, f.type)

    if not type_field_position_is(cls, Position.INSIDE):
        return payload

    type_field_name = get_type_field_name(cls)
    if type_field_name in payload:
        raise SerializationError(
            'Type field name {} conflicts with field name in {}'.format(
                type_field_name, cls))
    payload[type_field_name] = getattr(cls, type_field_name)
    return payload
예제 #7
0
파일: dataset.py 프로젝트: zyfra/ebonite
 def serialize(self, instance: xgboost.DMatrix) -> list:
     """
     Always fail: the original data cannot be extracted back out of a DMatrix.

     :param instance: ignored
     :raises SerializationError: unconditionally
     """
     message = 'xgboost matrix does not support serialization'
     raise SerializationError(message)