コード例 #1
0
ファイル: ops.py プロジェクト: zhangqiang880/DALI
def _get_kwargs(schema):
    """
    Get the keywords arguments from the schema.

    `schema`
        the schema in which to lookup arguments
    """
    ret = ""
    for arg in schema.GetArgumentNames():
        skip_full_doc = False
        type_name = ""
        dtype = None
        doc = ""
        deprecation_warning = None
        if schema.IsDeprecatedArg(arg):
            meta = schema.DeprecatedArgMeta(arg)
            msg = meta['msg']
            assert msg is not None
            deprecation_warning = ".. warning::\n\n    " + msg.replace(
                "\n", "\n    ")
            renamed_arg = meta['renamed_to']
            # Renamed and removed arguments won't show full documentation (only warning box)
            skip_full_doc = renamed_arg or meta['removed']
            # Renamed aliases are not fully registered to the schema, that's why we query for the
            # info on the renamed_arg name.
            if renamed_arg:
                dtype = schema.GetArgumentType(renamed_arg)
                type_name = _type_name_convert_to_string(
                    dtype, allow_tensors=schema.IsTensorArgument(renamed_arg))
        # Try to get dtype only if not set already (renamed args go through a different path, see above)
        if not dtype:
            dtype = schema.GetArgumentType(arg)
            type_name = _type_name_convert_to_string(
                dtype, allow_tensors=schema.IsTensorArgument(arg))
        # Add argument documentation if necessary
        if not skip_full_doc:
            if schema.IsArgumentOptional(arg):
                type_name += ", optional"
                if schema.HasArgumentDefaultValue(arg):
                    default_value_string = schema.GetArgumentDefaultValueString(
                        arg)
                    default_value = ast.literal_eval(default_value_string)
                    type_name += ", default = {}".format(
                        _default_converter(dtype, default_value))
            doc += schema.GetArgumentDox(arg)
            if deprecation_warning:
                doc += "\n\n" + deprecation_warning
        elif deprecation_warning:
            doc += deprecation_warning
        ret += _numpydoc_formatter(arg, type_name, doc)
        ret += '\n'
    return ret
コード例 #2
0
def _get_kwargs(schema, only_tensor=False):
    """
    Get the keywords arguments from the schema.

    `schema`
        the schema in which to lookup arguments
    `only_tensor`: bool
        If True list only keyword arguments that can be passed as TensorLists (argument inputs)
        If False list all the arguments. False indicates that we list arguments to the
        constructor of the operator which does not accept TensorLists (argument inputs) - that
        fact will be reflected in specified type
    """
    ret = ""
    for arg in schema.GetArgumentNames():
        if not only_tensor or schema.IsTensorArgument(arg):
            arg_name_doc = arg
            dtype = schema.GetArgumentType(arg)
            type_name = _type_name_convert_to_string(dtype,
                                                     is_tensor=only_tensor)
            if schema.IsArgumentOptional(arg):
                type_name += ", optional"
                if schema.HasArgumentDefaultValue(arg):
                    default_value_string = schema.GetArgumentDefaultValueString(
                        arg)
                    default_value = eval(default_value_string)
                    type_name += ", default = {}".format(
                        repr(_type_convert_value(dtype, default_value)))
            doc = schema.GetArgumentDox(arg)
            ret += _numpydoc_formatter(arg, type_name, doc)
            ret += '\n'
    return ret
コード例 #3
0
ファイル: ops.py プロジェクト: davidtranno1/DALI
def _docstring_generator(cls):
    """
    Assemble the reStructuredText docstring for the operator class `cls`.

    The text consists of a link anchor named after the operator, the backends the operator
    is registered for, the schema description, optional notes (sequence inputs, deprecation,
    no-prune) and a ``Parameters`` section with one entry per schema argument.

    `cls`
        operator class; its ``__name__`` selects the schema and is matched against the
        module-level backend sets
    """
    op_name = cls.__name__
    backends = (
        (_cpu_ops, "'CPU'"),
        (_gpu_ops, "'GPU'"),
        (_mixed_ops, "'mixed'"),
        (_support_ops, "'support'"),
    )
    op_dev = [label for registered, label in backends if op_name in registered]

    schema = b.GetSchema(op_name)
    pieces = [
        # Tag that makes it easy to link to the operator
        '.. _' + op_name + ':\n\n',
        "This is a " + ", ".join(op_dev) + " operator\n\n",
        schema.Dox(),
        '\n',
    ]
    if schema.IsSequenceOperator():
        pieces.append("\nThis operator expects sequence inputs\n")
    elif schema.AllowsSequences():
        pieces.append("\nThis operator allows sequence inputs\n")

    if schema.IsDeprecated():
        use_instead = schema.DeprecatedInFavorOf()
        pieces.append("\n.. warning::\n\n   This operator is now deprecated")
        if use_instead:
            pieces.append(". Use `" + use_instead + "` instead")
        pieces.append("\n")

    if schema.IsNoPrune():
        pieces.append("\nThis operator will **not** be optimized out of the graph.\n")

    pieces.append("\nParameters\n----------\n")
    for arg in schema.GetArgumentNames():
        dtype = schema.GetArgumentType(arg)
        arg_name_doc = "`" + arg + "` : "
        pieces.append(
            arg_name_doc +
            _type_name_convert_to_string(dtype, schema.IsTensorArgument(arg)))
        if schema.IsArgumentOptional(arg):
            default_value_string = schema.GetArgumentDefaultValueString(arg)
            # eval("") raises, so an empty default string is passed through untouched
            if default_value_string:
                default_value = eval(default_value_string)
            else:
                default_value = default_value_string
            if dtype == DALIDataType.STRING:
                default_value = "'" + str(default_value) + "'"
            pieces.append(", optional, default = " +
                          str(_type_convert_value(dtype, default_value)))
        # Description lines are aligned under the text that follows the argument name
        indent = '\n' + " " * len(arg_name_doc)
        pieces.append(indent)
        pieces.append(schema.GetArgumentDox(arg).replace("\n", indent))
        pieces.append('\n')
    return "".join(pieces)
コード例 #4
0
ファイル: ops.py プロジェクト: zzm422/DALI
def _docstring_generator(cls):
    """
    Assemble the docstring for the operator class `cls` from its schema.

    The result is the schema description followed by a ``Parameters`` section that lists
    every schema argument with its type, default value (for optional arguments) and
    description.

    `cls`
        operator class; its ``__name__`` is used to look up the schema
    """
    schema = b.GetSchema(cls.__name__)
    pieces = [schema.Dox(), '\n', "\nParameters\n----------\n"]
    for arg in schema.GetArgumentNames():
        dtype = schema.GetArgumentType(arg)
        arg_name_doc = "`" + arg + "` : "
        type_text = _type_name_convert_to_string(dtype, schema.IsTensorArgument(arg))
        pieces.append(arg_name_doc + type_text)
        if schema.IsArgumentOptional(arg):
            default_value_string = schema.GetArgumentDefaultValueString(arg)
            # eval("") raises, so an empty default string is passed through untouched
            if default_value_string:
                default_value = eval(default_value_string)
            else:
                default_value = default_value_string
            if dtype == DALIDataType.STRING:
                default_value = "'" + str(default_value) + "'"
            pieces.append(", optional, default = " +
                          str(_type_convert_value(dtype, default_value)))
        # Description lines are aligned under the text that follows the argument name
        indent = '\n' + " " * len(arg_name_doc)
        pieces.append(indent)
        pieces.append(schema.GetArgumentDox(arg).replace("\n", indent))
        pieces.append('\n')
    return "".join(pieces)
コード例 #5
0
ファイル: ops.py プロジェクト: bariarviv/DALI
def _get_kwargs(schema):
    """
    Get the keywords arguments from the schema.

    `schema`
        the schema in which to lookup arguments
    """
    ret = ""
    for arg in schema.GetArgumentNames():
        allow_tensors = schema.IsTensorArgument(arg)
        arg_name_doc = arg
        dtype = schema.GetArgumentType(arg)
        type_name = _type_name_convert_to_string(dtype,
                                                 allow_tensors=allow_tensors)
        if schema.IsArgumentOptional(arg):
            type_name += ", optional"
            if schema.HasArgumentDefaultValue(arg):
                default_value_string = schema.GetArgumentDefaultValueString(
                    arg)
                default_value = ast.literal_eval(default_value_string)
                type_name += ", default = {}".format(
                    _default_converter(dtype, default_value))
        doc = schema.GetArgumentDox(arg)
        ret += _numpydoc_formatter(arg, type_name, doc)
        ret += '\n'
    return ret
コード例 #6
0
ファイル: ops.py プロジェクト: zhangqiang880/DALI
def _check_arg_input(schema, op_name, name):
    if name == "name":
        return
    if not schema.IsTensorArgument(name):
        raise TypeError(
            "The argument `{}` for operator `{}` should not be a `DataNode` but a {}"
            .format(
                name, op_name,
                _type_name_convert_to_string(schema.GetArgumentType(name),
                                             False)))
コード例 #7
0
ファイル: ops.py プロジェクト: xeransis/DALI
def _docstring_generator(cls):
    """
    Assemble the reStructuredText docstring for the operator class `cls`.

    The text consists of a link anchor named after the operator, the backends the operator
    is registered for (queried from the backend on every call), the schema description, an
    optional note about sequence inputs and a ``Parameters`` section with one entry per
    schema argument.

    `cls`
        operator class; its ``__name__`` selects the schema and is matched against the
        registered operator names
    """
    cpu_ops = set(b.RegisteredCPUOps())
    # TFRecordReader is always reported as a CPU operator, registered or not
    cpu_ops.add("TFRecordReader")
    gpu_ops = set(b.RegisteredGPUOps())
    mixed_ops = set(b.RegisteredMixedOps())
    support_ops = set(b.RegisteredSupportOps())

    op_name = cls.__name__
    backends = (
        (cpu_ops, "'CPU'"),
        (gpu_ops, "'GPU'"),
        (mixed_ops, "'mixed'"),
        (support_ops, "'support'"),
    )
    op_dev = [label for registered, label in backends if op_name in registered]

    schema = b.GetSchema(op_name)
    pieces = [
        # Tag that makes it easy to link to the operator
        '.. _' + op_name + ':\n\n',
        "This is a " + ", ".join(op_dev) + " operator\n\n",
        schema.Dox(),
        '\n',
    ]
    if schema.IsSequenceOperator():
        pieces.append("\nThis operator expects sequence inputs\n")
    elif schema.AllowsSequences():
        pieces.append("\nThis operator allows sequence inputs\n")

    pieces.append("\nParameters\n----------\n")
    for arg in schema.GetArgumentNames():
        dtype = schema.GetArgumentType(arg)
        arg_name_doc = "`" + arg + "` : "
        type_text = _type_name_convert_to_string(dtype, schema.IsTensorArgument(arg))
        pieces.append(arg_name_doc + type_text)
        if schema.IsArgumentOptional(arg):
            default_value_string = schema.GetArgumentDefaultValueString(arg)
            # eval("") raises, so an empty default string is passed through untouched
            if default_value_string:
                default_value = eval(default_value_string)
            else:
                default_value = default_value_string
            if dtype == DALIDataType.STRING:
                default_value = "'" + str(default_value) + "'"
            pieces.append(", optional, default = " +
                          str(_type_convert_value(dtype, default_value)))
        # Description lines are aligned under the text that follows the argument name
        indent = '\n' + " " * len(arg_name_doc)
        pieces.append(indent)
        pieces.append(schema.GetArgumentDox(arg).replace("\n", indent))
        pieces.append('\n')
    return "".join(pieces)