Esempio n. 1
0
def TransformReturnedLocalBlob(local_blob, annotation):
    """Convert a fetched local blob structure into Python values per annotation.

    The result mirrors the structure described by ``annotation``:

    * ``typing.Tuple[...]`` -- ``local_blob`` must be a tuple with one element
      per type argument; each element is transformed with its own argument.
    * ``typing.List[T]`` -- ``local_blob`` must be a list; every element is
      transformed with ``T``.
    * ``typing.Dict[K, V]`` -- ``local_blob`` must be a dict; keys are kept
      unchanged and every value is transformed with ``V``.
    * anything ``oft.OriginFrom`` reports as an ``oft.PyStructCompatibleToBlob``
      is handed to ``TransformLocalBlob``.

    Raises:
        AssertionError: if ``local_blob`` does not have the expected container
            type / arity.
        NotImplementedError: for any other annotation.
    """
    if oft.OriginFrom(annotation, typing.Tuple):
        assert type(local_blob) is tuple
        assert len(local_blob) == len(annotation.__args__)
        return tuple(
            TransformReturnedLocalBlob(item, item_annotation)
            for item, item_annotation in zip(local_blob, annotation.__args__)
        )

    if oft.OriginFrom(annotation, typing.List):
        assert type(local_blob) is list
        assert len(annotation.__args__) == 1
        return [
            TransformReturnedLocalBlob(item, annotation.__args__[0])
            for item in local_blob
        ]

    if oft.OriginFrom(annotation, typing.Dict):
        assert type(local_blob) is dict
        assert len(annotation.__args__) == 2
        # Only values are transformed; keys are carried over unchanged.
        return {
            key: TransformReturnedLocalBlob(value, annotation.__args__[1])
            for key, value in local_blob.items()
        }

    if oft.OriginFrom(annotation, oft.PyStructCompatibleToBlob):
        return TransformLocalBlob(local_blob, annotation)

    raise NotImplementedError(
        "invalid watch callback parameter annotation %s found" %
        annotation)
Esempio n. 2
0
def TransformLocalBlob(future_blob, annotation):
    """Fetch a local blob as numpy data in the form its annotation requests.

    ``oft.Numpy`` maps to ``future_blob.numpy()`` and ``oft.ListNumpy`` maps
    to ``future_blob.numpy_list()``. The candidates are tried in that order.

    Raises:
        NotImplementedError: for any other annotation.
    """
    # (leaf annotation, name of the future_blob method that materializes it)
    for leaf_annotation, fetch_method in (
        (oft.Numpy, "numpy"),
        (oft.ListNumpy, "numpy_list"),
    ):
        if oft.OriginFrom(annotation, leaf_annotation):
            return getattr(future_blob, fetch_method)()
    raise NotImplementedError(
        "invalid watch callback parameter annotation %s found" % annotation
    )
Esempio n. 3
0
def CheckGlobalFunctionParamAnnotation(cls):
    """Recursively validate a global function's parameter annotation.

    Accepted annotations are:

    * those ``oft.OriginFrom`` reports as ``oft.OneflowNumpyDef``;
    * a non-empty ``typing.Tuple[...]`` whose members are themselves accepted.

    Raises:
        AssertionError: for a Tuple annotation with missing/empty arguments.
        NotImplementedError: for any other annotation.
    """
    if oft.OriginFrom(cls, typing.Tuple):
        members = cls.__args__
        assert members is not None, "T in typing.Tuple[T, ...] cannot be omitted"
        assert len(members) > 0
        for member in members:
            CheckGlobalFunctionParamAnnotation(member)
        return
    if oft.OriginFrom(cls, oft.OneflowNumpyDef):
        return
    raise NotImplementedError("invalid parameter annotation %s found" % cls)
Esempio n. 4
0
def _RecusiveMakeInputBlobDef(cls):
    """Build input blob definition(s) from a global function parameter annotation.

    * An annotation ``oft.OriginFrom`` reports as ``oft.OneflowNumpyDef``
      yields ``cls.NewInputBlobDef()``. Judging by the error message below,
      these are the Numpy/ListNumpy ``Placeholder`` annotations.
    * A ``typing.Tuple[...]`` annotation yields a tuple of definitions, built
      recursively from its type arguments.

    Raises:
        NotImplementedError: for any other annotation.
    """
    if oft.OriginFrom(cls, oft.OneflowNumpyDef):
        return cls.NewInputBlobDef()
    elif oft.OriginFrom(cls, typing.Tuple):
        return tuple(_RecusiveMakeInputBlobDef(a) for a in cls.__args__)
    else:
        # Format with a 1-tuple so an annotation that is itself a tuple
        # cannot break the %-formatting.
        raise NotImplementedError(
            "\nannotation %s not supported" % (cls,)
            + "\nonly support oneflow.typing.Numpy.Placeholder, "
            "oneflow.typing.ListNumpy.Placeholder"
        )
Esempio n. 5
0
def _CheckReturnByAnnotation(function_name, ret, annotation):
    """Recursively assert that a global function's return value fits ``annotation``.

    Supported annotations:

    * ``typing.Tuple[...]`` -- ``ret`` must be a tuple of matching length; each
      element is checked against the corresponding type argument.
    * ``typing.List[T]`` -- ``ret`` must be a list; every element is checked
      against ``T``.
    * ``typing.Dict[K, V]`` -- ``ret`` must be a dict; every key's exact type
      must be ``K`` and every value is checked against ``V``.
    * ``oft.Numpy`` -- ``ret`` must be a non-dynamic
      ``oneflow._oneflow_internal.BlobDesc``.
    * ``oft.ListNumpy`` -- ``ret`` must be a
      ``oneflow._oneflow_internal.BlobDesc``.

    Args:
        function_name: name of the global function; used only in messages.
        ret: the (possibly nested) value returned by the global function.
        annotation: the return annotation to check against.

    Raises:
        AssertionError: on any mismatch.
        NotImplementedError: for an unsupported annotation.
    """
    mismatch = "%s does not matched return annotation %s of global_function %s." % (
        ret,
        annotation,
        function_name,
    )

    if oft.OriginFrom(annotation, typing.Tuple):
        assert type(ret) is tuple, mismatch
        type_args = annotation.__args__
        assert len(ret) == len(type_args), "%s length compare: %s v.s. %s" % (
            mismatch,
            len(ret),
            len(type_args),
        )
        for item, item_annotation in zip(ret, type_args):
            _CheckReturnByAnnotation(function_name, item, item_annotation)
        return

    if oft.OriginFrom(annotation, typing.List):
        assert type(ret) is list, mismatch
        assert len(annotation.__args__) == 1, (
            "%s element type in list must be unique" % mismatch
        )
        for item in ret:
            _CheckReturnByAnnotation(function_name, item, annotation.__args__[0])
        return

    if oft.OriginFrom(annotation, typing.Dict):
        assert len(annotation.__args__) == 2
        assert type(ret) is dict, mismatch
        for key, value in ret.items():
            # Keys are compared by exact type (no subclasses); values recurse.
            assert type(key) is annotation.__args__[0], (
                "type of %s:%s and %s:%s do not matched return annotation "
                "(%s, %s) of global_function %s."
                % (
                    key,
                    type(key),
                    value,
                    type(value),
                    annotation.__args__[0],
                    annotation.__args__[1],
                    function_name,
                )
            )
            _CheckReturnByAnnotation(function_name, value, annotation.__args__[1])
        return

    if oft.OriginFrom(annotation, oft.Numpy):
        assert isinstance(ret, oneflow._oneflow_internal.BlobDesc), (
            "type(ret): %s" % type(ret)
        )
        # TODO(chengcheng) oft.Numpy support dynamic.
        assert not ret.is_dynamic, (
            "only fixed shaped blob compatible to oneflow.typing.Numpy. "
            "you can change annotation to oneflow.typing.ListNumpy "
        )
        return

    if oft.OriginFrom(annotation, oft.ListNumpy):
        assert isinstance(ret, oneflow._oneflow_internal.BlobDesc), (
            "type(ret): %s" % type(ret)
        )
        return

    raise NotImplementedError("invalid return annotation %s found" % annotation)
Esempio n. 6
0
def _CheckGlobalFunctionReturnAnnotation(cls):
    """Recursively validate a global function's return annotation.

    Accepted annotations are:

    * a non-empty ``typing.Tuple[...]`` whose members are accepted;
    * ``typing.List[T]`` whose element annotation ``T`` is accepted. This
      matches the ``typing.List`` support in ``TransformReturnedLocalBlob``
      and ``_CheckReturnByAnnotation``.
    * ``typing.Dict[K, V]`` whose value annotation ``V`` is accepted. Only
      ``V`` is validated here; ``K`` is not inspected.
    * anything ``oft.OriginFrom`` reports as an ``oft.PyStructCompatibleToBlob``.

    Raises:
        AssertionError: for a container annotation with missing or wrongly
            sized type arguments.
        NotImplementedError: for any other annotation.
    """
    # getattr: an unsubscripted generic may have no ``__args__`` attribute at
    # all on newer Pythons. Plain attribute access would then raise
    # AttributeError instead of reaching the "cannot be omitted" messages.
    args = getattr(cls, "__args__", None)
    if oft.OriginFrom(cls, typing.Tuple):
        assert args is not None, "T in typing.Tuple[T, ...] cannot be omitted"
        assert len(args) > 0
        for cls_arg in args:
            _CheckGlobalFunctionReturnAnnotation(cls_arg)
    elif oft.OriginFrom(cls, typing.List):
        assert args is not None, "T in typing.List[T] cannot be omitted"
        assert len(args) == 1
        _CheckGlobalFunctionReturnAnnotation(args[0])
    elif oft.OriginFrom(cls, typing.Dict):
        assert args is not None, "(K, V) in typing.Dict[K,V] cannot be omitted"
        assert len(args) == 2
        _CheckGlobalFunctionReturnAnnotation(args[1])
    elif oft.OriginFrom(cls, oft.PyStructCompatibleToBlob):
        pass
    else:
        raise NotImplementedError("invalid return annotation %s found" % cls)
Esempio n. 7
0
def CheckWatchedBlobByAnnotation(blob, annotation):
    """Validate a watched blob against its watch-callback parameter annotation.

    * No check is done when the parameter is unannotated (``inspect._empty``).
    * ``oft.Numpy`` requires ``blob.is_dynamic`` to be false.
    * ``oft.ListNumpy`` accepts any blob.

    Raises:
        AssertionError: if an ``oft.Numpy`` annotation is paired with a
            dynamic blob.
        NotImplementedError: for any other annotation.
    """
    if annotation is not inspect._empty:
        if oft.OriginFrom(annotation, oft.Numpy):
            # TODO(chengcheng) oft.Numpy support dynamic.
            assert not blob.is_dynamic, (
                "only fixed shaped blob compatible to oneflow.typing.Numpy. "
                "you can change annotation to oneflow.typing.ListNumpy "
            )
        elif not oft.OriginFrom(annotation, oft.ListNumpy):
            raise NotImplementedError(
                "invalid watch callback parameter annotation %s found" % annotation
            )
Esempio n. 8
0
def TransformGlobalFunctionResult(future_blob, annotation):
    """Turn a global function's future result into what its annotation promises.

    * Unannotated (``inspect._empty``): ``future_blob`` is returned as-is.
    * ``None``: asserts ``future_blob`` is ``None`` and returns ``None``.
    * ``oft.Callback[T]``: returns a function that takes a user callback ``f``
      and calls ``future_blob.async_get`` with a wrapper. The wrapper passes
      ``f`` the fetched value converted by ``TransformReturnedLocalBlob(.., T)``.
    * ``oft.Bundle``: ``TransformReturnedBundle(future_blob.get(), annotation)``.
    * Anything else: ``TransformReturnedLocalBlob(future_blob.get(), annotation)``.
    """
    if annotation is inspect._empty:
        return future_blob
    if annotation is None:
        assert future_blob is None
        return None
    if oft.OriginFrom(annotation, oft.Callback):
        element_annotation = annotation.__args__[0]

        def _SubscribeUserCallback(user_callback):
            def _OnLocalBlob(local_blob):
                return user_callback(
                    TransformReturnedLocalBlob(local_blob, element_annotation)
                )

            return future_blob.async_get(_OnLocalBlob)

        return _SubscribeUserCallback
    if oft.OriginFrom(annotation, oft.Bundle):
        return TransformReturnedBundle(future_blob.get(), annotation)
    return TransformReturnedLocalBlob(future_blob.get(), annotation)
Esempio n. 9
0
def CheckGlobalFunctionReturnAnnotation(cls):
    """Validate the return annotation of a global function.

    * ``None`` is accepted as-is.
    * ``oft.Callback[T]``: ``T`` is validated as a return annotation.
    * ``oft.Bundle[T]``: ``T`` must be ``oft.Numpy`` or ``oft.ListNumpy``.
    * Anything else is validated directly by
      ``_CheckGlobalFunctionReturnAnnotation``.

    Raises:
        AssertionError: for a Callback/Bundle with missing, wrongly sized or
            unsupported type arguments.
        NotImplementedError: propagated from the recursive check.
    """
    if cls is None:
        pass
    elif oft.OriginFrom(cls, oft.Callback):
        # getattr: an unsubscripted generic may lack ``__args__`` entirely, so
        # plain attribute access would raise AttributeError instead of the
        # intended assertion message.
        args = getattr(cls, "__args__", None)
        assert args is not None, "T in oneflow.typing.Callback[T] cannot be omitted"
        assert len(args) == 1
        _CheckGlobalFunctionReturnAnnotation(args[0])
    elif oft.OriginFrom(cls, oft.Bundle):
        args = getattr(cls, "__args__", None)
        # Validate presence and arity BEFORE indexing args[0]; otherwise a
        # bare/mis-sized Bundle crashes with AttributeError/TypeError/IndexError.
        assert args is not None, "T in oneflow.typing.Bundle[T] cannot be omitted"
        assert len(args) == 1
        assert args[0] in (
            oft.Numpy,
            oft.ListNumpy,
        ), "T in oneflow.typing.Bundle[T] must be one of (oneflow.typing.Numpy, oneflow.typing.ListNumpy)"
        _CheckGlobalFunctionReturnAnnotation(args[0])
    else:
        _CheckGlobalFunctionReturnAnnotation(cls)
Esempio n. 10
0
def CheckGlobalFunctionReturnAnnotation(cls):
    """Validate the return annotation of a global function.

    * ``None`` is accepted as-is.
    * ``oft.Callback[T]``: ``T`` is validated as a return annotation.
    * Anything else is validated directly by
      ``_CheckGlobalFunctionReturnAnnotation``.

    Raises:
        AssertionError: for a Callback with missing or wrongly sized type
            arguments.
        NotImplementedError: propagated from the recursive check.
    """
    if cls is None:
        pass
    elif oft.OriginFrom(cls, oft.Callback):
        # getattr: an unsubscripted generic may lack ``__args__`` entirely, so
        # plain attribute access would raise AttributeError instead of the
        # intended "cannot be omitted" assertion.
        args = getattr(cls, "__args__", None)
        assert args is not None, "T in oneflow.typing.Callback[T] cannot be omitted"
        assert len(args) == 1
        _CheckGlobalFunctionReturnAnnotation(args[0])
    else:
        _CheckGlobalFunctionReturnAnnotation(cls)
Esempio n. 11
0
def CheckReturnByAnnotation(function_name, ret, annotation):
    """Check a global function's return value against its return annotation.

    * Unannotated (``inspect._empty``): nothing is checked.
    * ``None``: ``ret`` must be ``None``.
    * ``oft.Callback[T]``: ``ret`` is checked against ``T``.
    * Anything else: ``ret`` is checked against the annotation itself.

    Raises:
        AssertionError / NotImplementedError: on mismatch or unsupported
            annotation (the latter propagated from ``_CheckReturnByAnnotation``).
    """
    if annotation is inspect._empty:
        return
    if annotation is None:
        mismatch = (
            "%s does not matched return annotation %s of global_function %s."
            % (ret, annotation, function_name)
        )
        assert ret is None, mismatch
        return
    if oft.OriginFrom(annotation, oft.Callback):
        target_annotation = annotation.__args__[0]
    else:
        target_annotation = annotation
    _CheckReturnByAnnotation(function_name, ret, target_annotation)
Esempio n. 12
0
def CheckWatchCallbackParameterAnnotation(parameters):
    """Validate the annotation of a watch callback's single parameter.

    Args:
        parameters: mapping of parameter name -> parameter object exposing
            ``.annotation``. This is presumably ``inspect.signature(cb).parameters``;
            confirm against callers.

    Raises:
        AssertionError: if the callback does not take exactly one parameter.
        NotImplementedError: if the parameter is unannotated while
            ``enable_typing_check.typing_check_enabled`` is set, or if it is
            annotated with something ``oft.OriginFrom`` does not report as an
            ``oft.PyStructCompatibleToBlob``.
    """
    assert len(parameters) == 1, "watch callback should accept only one parameter"
    # Exactly one entry exists (asserted above); read its value directly.
    annotation = list(parameters.values())[0].annotation
    if annotation is inspect._empty:
        if enable_typing_check.typing_check_enabled:
            raise NotImplementedError("the watch callback's parameter is not annotated")
        return
    if not oft.OriginFrom(annotation, oft.PyStructCompatibleToBlob):
        # Format with a 1-tuple: an annotation expression may itself be a
        # tuple, which would otherwise break the %-formatting.
        raise NotImplementedError(
            ("invalid watch callback parameter annotation %s found. " % (annotation,))
            + "candidate annotations: oneflow.typing.Numpy, oneflow.typing.ListNumpy, "
            "oneflow.typing.ListListNumpy"
        )
Esempio n. 13
0
def CheckReturnByAnnotation(function_name, ret, annotation):
    """Check a global function's return value against its return annotation.

    * Unannotated (``inspect._empty``): nothing is checked.
    * ``None``: ``ret`` must be ``None``.
    * ``oft.Callback[T]``: ``ret`` is checked against ``T``.
    * ``oft.Bundle[T]``: a ``oneflow._oneflow_internal.BlobDesc`` is checked
      against ``T``. A list/tuple (by element) or dict (by value) is walked
      recursively with the same Bundle annotation. Any other type raises
      NotImplementedError.
    * Anything else: ``ret`` is checked against the annotation itself.
    """
    if annotation is inspect._empty:
        return
    if annotation is None:
        mismatch = (
            "%s does not matched return annotation %s of global_function %s."
            % (ret, annotation, function_name)
        )
        assert ret is None, mismatch
        return
    if oft.OriginFrom(annotation, oft.Callback):
        _CheckReturnByAnnotation(function_name, ret, annotation.__args__[0])
        return
    if not oft.OriginFrom(annotation, oft.Bundle):
        _CheckReturnByAnnotation(function_name, ret, annotation)
        return

    # Bundle[T]: leaves are BlobDescs; containers are descended into.
    if isinstance(ret, oneflow._oneflow_internal.BlobDesc):
        _CheckReturnByAnnotation(function_name, ret, annotation.__args__[0])
        return
    if isinstance(ret, (list, tuple)):
        children = ret
    elif type(ret) is dict:
        children = ret.values()
    else:
        raise NotImplementedError("invalid return  %s found" % (type(ret)))
    for child in children:
        CheckReturnByAnnotation(function_name, child, annotation)