Example #1
        def __init__(self, **kwargs):
            """Create the operator and record its arguments in an OpSpec.

            The ``device`` and ``preserve`` keyword arguments are consumed here;
            every other keyword argument is converted to the type declared for
            it by the operator schema and added to the spec.

            Raises:
                RuntimeError: if any argument is given as an empty list.
            """
            self._spec = b.OpSpec(type(self).__name__)
            self._schema = b.GetSchema(type(self).__name__)

            # Get the device argument. We will need this to determine
            # the device that our outputs will be stored on. When it is not
            # given explicitly, fall back to `op_device` from the enclosing scope.
            self._device = kwargs.pop("device", op_device)
            self._spec.AddArg("device", self._device)

            # Pop `preserve` so the generic loop below does not add it to the
            # spec a second time.
            self._preserve = kwargs.pop("preserve", False)
            self._spec.AddArg("preserve", self._preserve)
            # Operators whose schema is flagged no-prune are always preserved.
            self._preserve = self._preserve or self._schema.IsNoPrune()

            # Store the specified arguments
            for key, value in kwargs.items():
                if isinstance(value, list) and not value:
                    raise RuntimeError(
                        "List arguments need to have at least 1 element.")
                dtype = self._schema.GetArgumentType(key)
                converted_value = _type_convert_value(dtype, value)
                self._spec.AddArg(key, converted_value)
Example #2
        def __init__(self, **kwargs):
            """Create the operator and record its arguments in an OpSpec.

            The ``device`` and ``preserve`` keyword arguments are consumed here.
            Every other keyword argument whose value is not None is added to
            the spec: empty lists/tuples via ``AddArgEmptyList`` with the
            element type derived from the schema, everything else after
            conversion to the type the schema declares.
            """
            self._spec = b.OpSpec(type(self).__name__)
            self._schema = b.GetSchema(type(self).__name__)

            # Get the device argument. We will need this to determine
            # the device that our outputs will be stored on. When it is not
            # given explicitly, fall back to `op_device` from the enclosing scope.
            self._device = kwargs.pop("device", op_device)
            self._spec.AddArg("device", self._device)

            # Pop `preserve` so the generic loop below does not add it to the
            # spec a second time.
            self._preserve = kwargs.pop("preserve", False)
            self._spec.AddArg("preserve", self._preserve)
            # Operators whose schema is flagged no-prune are always preserved.
            self._preserve = self._preserve or self._schema.IsNoPrune()

            # Store the specified arguments
            for key, value in kwargs.items():
                if value is None:
                    # None is not a valid value for any argument type, so treat it
                    # as if the argument was not supplied at all
                    continue

                dtype = self._schema.GetArgumentType(key)
                if isinstance(value, (list, tuple)) and len(value) == 0:
                    # Empty sequences are registered with an explicit element
                    # type derived from the schema's declared argument type.
                    self._spec.AddArgEmptyList(key, _vector_element_type(dtype))
                    continue
                converted_value = _type_convert_value(dtype, value)
                self._spec.AddArg(key, converted_value)
Example #3
def _get_kwargs(schema, only_tensor=False):
    Get the keywords arguments from the schema.

        the schema in which to lookup arguments
    `only_tensor`: bool
        If True list only keyword arguments that can be passed as TensorLists (argument inputs)
        If False list all the arguments. False indicates that we list arguments to the
        constructor of the operator which does not accept TensorLists (argument inputs) - that
        fact will be reflected in specified type
    ret = ""
    for arg in schema.GetArgumentNames():
        if not only_tensor or schema.IsTensorArgument(arg):
            arg_name_doc = arg
            dtype = schema.GetArgumentType(arg)
            type_name = _type_name_convert_to_string(dtype,
            if schema.IsArgumentOptional(arg):
                type_name += ", optional"
                if schema.HasArgumentDefaultValue(arg):
                    default_value_string = schema.GetArgumentDefaultValueString(
                    default_value = eval(default_value_string)
                    type_name += ", default = {}".format(
                        repr(_type_convert_value(dtype, default_value)))
            doc = schema.GetArgumentDox(arg)
            ret += _numpydoc_formatter(arg, type_name, doc)
            ret += '\n'
    return ret
Example #4
def _docstring_generator(cls):
    """Build the reST documentation string for the operator class `cls`.

    The text contains a link target for the operator, the device kinds it is
    registered for, the schema description, notes about sequence support,
    deprecation and pruning, and a "Parameters" section describing every
    schema argument (type, optional default, and description).
    """
    op_name = cls.__name__
    # Collect the device kinds for which this operator is registered.
    op_dev = []
    if op_name in _cpu_ops:
        op_dev.append("'CPU'")
    if op_name in _gpu_ops:
        op_dev.append("'GPU'")
    if op_name in _mixed_ops:
        op_dev.append("'mixed'")
    if op_name in _support_ops:
        op_dev.append("'support'")
    pre_doc = "This is a " + ", ".join(op_dev) + " operator\n\n"

    schema = b.GetSchema(op_name)
    # insert tag to easily link to the operator
    ret = '.. _' + op_name + ':\n\n'
    ret += pre_doc
    ret += schema.Dox()
    ret += '\n'
    if schema.IsSequenceOperator():
        ret += "\nThis operator expects sequence inputs\n"
    elif schema.AllowsSequences():
        ret += "\nThis operator allows sequence inputs\n"

    if schema.IsDeprecated():
        use_instead = schema.DeprecatedInFavorOf()
        ret += "\n.. warning::\n\n   This operator is now deprecated"
        if use_instead:
            ret += ". Use `" + use_instead + "` instead"
        ret += "\n"

    if schema.IsNoPrune():
        ret += "\nThis operator will **not** be optimized out of the graph.\n"

    ret += "\nParameters\n----------\n"
    for arg in schema.GetArgumentNames():
        dtype = schema.GetArgumentType(arg)
        arg_name_doc = "`" + arg + "` : "
        ret += (arg_name_doc +
                _type_name_convert_to_string(dtype, schema.IsTensorArgument(arg)))
        if schema.IsArgumentOptional(arg):
            default_value_string = schema.GetArgumentDefaultValueString(arg)
            # Evaluating empty string results in an error
            # so we need to prevent that
            if default_value_string:
                # NOTE(review): eval() of a schema-provided string, not user input.
                default_value = eval(default_value_string)
            else:
                default_value = default_value_string
            if dtype == DALIDataType.STRING:
                default_value = "\'" + str(default_value) + "\'"
            ret += (", optional, default = " +
                    str(_type_convert_value(dtype, default_value)))
        # Continuation lines of the description start at the column right
        # after the "`name` : " prefix.
        indent = '\n' + " " * len(arg_name_doc)
        ret += indent
        ret += schema.GetArgumentDox(arg).replace("\n", indent)
        ret += '\n'
    return ret
Example #5
File: ops.py Project: zzm422/DALI
def _docstring_generator(cls):
    """Build the documentation string for the operator class `cls`.

    The text is the schema description followed by a "Parameters" section
    describing every schema argument (type, optional default, description).
    """
    schema = b.GetSchema(cls.__name__)
    ret = schema.Dox()
    ret += '\n'
    ret += "\nParameters\n----------\n"
    for arg in schema.GetArgumentNames():
        dtype = schema.GetArgumentType(arg)
        arg_name_doc = "`" + arg + "` : "
        ret += (
            arg_name_doc +
            _type_name_convert_to_string(dtype, schema.IsTensorArgument(arg)))
        if schema.IsArgumentOptional(arg):
            default_value_string = schema.GetArgumentDefaultValueString(arg)
            # Evaluating empty string results in an error
            # so we need to prevent that
            if default_value_string:
                # NOTE(review): eval() of a schema-provided string, not user input.
                default_value = eval(default_value_string)
            else:
                default_value = default_value_string
            if dtype == DALIDataType.STRING:
                default_value = "\'" + str(default_value) + "\'"
            ret += (", optional, default = " +
                    str(_type_convert_value(dtype, default_value)))
        # Continuation lines of the description start at the column right
        # after the "`name` : " prefix.
        indent = '\n' + " " * len(arg_name_doc)
        ret += indent
        ret += schema.GetArgumentDox(arg).replace("\n", indent)
        ret += '\n'
    return ret
Example #6
        def __init__(self, **kwargs):
            """Create the operator and record its arguments in an OpSpec.

            The ``device`` and ``preserve`` keyword arguments are consumed here.
            Deprecated arguments (as reported by the schema) are renamed or
            dropped according to their deprecation metadata, and a
            DeprecationWarning carrying the schema's message is issued for each.
            Every remaining argument whose value is not None is added to the spec.

            Raises:
                TypeError: if both a deprecated argument and the argument it was
                    renamed to are supplied.
            """
            schema_name = _schema_name(type(self))
            self._spec = _b.OpSpec(schema_name)
            self._schema = _b.GetSchema(schema_name)

            # Get the device argument. We will need this to determine
            # the device that our outputs will be stored on. When it is not
            # given explicitly, fall back to `op_device` from the enclosing scope.
            self._device = kwargs.pop("device", op_device)
            self._spec.AddArg("device", self._device)

            # Pop `preserve` so the generic loop below does not add it to the
            # spec a second time.
            self._preserve = kwargs.pop("preserve", False)
            self._spec.AddArg("preserve", self._preserve)
            # Operators whose schema is flagged no-prune are always preserved.
            self._preserve = self._preserve or self._schema.IsNoPrune()

            # Check for any deprecated arguments that should be replaced or removed.
            # Iterate over a snapshot of the names because kwargs is mutated below.
            for arg_name in list(kwargs.keys()):
                if not self._schema.IsDeprecatedArg(arg_name):
                    continue
                meta = self._schema.DeprecatedArgMeta(arg_name)
                new_name = meta['renamed_to']
                removed = meta['removed']
                msg = meta['msg']
                if new_name:
                    if new_name in kwargs:
                        raise TypeError(
                            "Operator {} got an unexpected '{}' deprecated argument when '{}' was already provided"
                            .format(type(self).__name__, arg_name, new_name))
                    kwargs[new_name] = kwargs.pop(arg_name)
                elif removed:
                    del kwargs[arg_name]

                # DeprecationWarning is ignored by Python's default filters in
                # most contexts; enable it locally so the user sees the message.
                with warnings.catch_warnings():
                    warnings.simplefilter("default")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)

            # Store the specified arguments
            for key, value in kwargs.items():
                if value is None:
                    # None is not a valid value for any argument type, so treat it
                    # as if the argument was not supplied at all
                    continue

                dtype = self._schema.GetArgumentType(key)
                if isinstance(value, (list, tuple)) and len(value) == 0:
                    # Empty sequences are registered with an explicit element
                    # type derived from the schema's declared argument type.
                    self._spec.AddArgEmptyList(key, _vector_element_type(dtype))
                    continue
                converted_value = _type_convert_value(dtype, value)
                self._spec.AddArg(key, converted_value)
Example #7
File: ops.py Project: xeransis/DALI
def _docstring_generator(cls):
    """Build the reST documentation string for the operator class `cls`.

    The text contains a link target for the operator, the device kinds it is
    registered for, the schema description, a note about sequence support,
    and a "Parameters" section describing every schema argument.
    """
    __cpu_ops = set(b.RegisteredCPUOps())
    __gpu_ops = set(b.RegisteredGPUOps())
    __mix_ops = set(b.RegisteredMixedOps())
    __support_ops = set(b.RegisteredSupportOps())
    op_name = cls.__name__
    # Collect the device kinds for which this operator is registered.
    op_dev = []
    if op_name in __cpu_ops:
        op_dev.append("'CPU'")
    if op_name in __gpu_ops:
        op_dev.append("'GPU'")
    if op_name in __mix_ops:
        op_dev.append("'mixed'")
    if op_name in __support_ops:
        op_dev.append("'support'")
    pre_doc = "This is a " + ", ".join(op_dev) + " operator\n\n"

    schema = b.GetSchema(op_name)
    # insert tag to easily link to the operator
    ret = '.. _' + op_name + ':\n\n'
    ret += pre_doc
    ret += schema.Dox()
    ret += '\n'
    if schema.IsSequenceOperator():
        ret += "\nThis operator expects sequence inputs\n"
    elif schema.AllowsSequences():
        ret += "\nThis operator allows sequence inputs\n"
    ret += "\nParameters\n----------\n"
    for arg in schema.GetArgumentNames():
        dtype = schema.GetArgumentType(arg)
        arg_name_doc = "`" + arg + "` : "
        ret += (
            arg_name_doc +
            _type_name_convert_to_string(dtype, schema.IsTensorArgument(arg)))
        if schema.IsArgumentOptional(arg):
            default_value_string = schema.GetArgumentDefaultValueString(arg)
            # Evaluating empty string results in an error
            # so we need to prevent that
            if default_value_string:
                # NOTE(review): eval() of a schema-provided string, not user input.
                default_value = eval(default_value_string)
            else:
                default_value = default_value_string
            if dtype == DALIDataType.STRING:
                default_value = "\'" + str(default_value) + "\'"
            ret += (", optional, default = " +
                    str(_type_convert_value(dtype, default_value)))
        # Continuation lines of the description start at the column right
        # after the "`name` : " prefix.
        indent = '\n' + " " * len(arg_name_doc)
        ret += indent
        ret += schema.GetArgumentDox(arg).replace("\n", indent)
        ret += '\n'
    return ret
Example #8
def _add_spec_args(schema, spec, kwargs):
    for key, value in kwargs.items():
        if value is None:
            # None is not a valid value for any argument type, so treat it
            # as if the argument was not supplied at all

        dtype = schema.GetArgumentType(key)
        if isinstance(value, (list, tuple)):
            if len(value) == 0:
                spec.AddArgEmptyList(key, _vector_element_type(dtype))
        converted_value = _type_convert_value(dtype, value)
        spec.AddArg(key, converted_value)