Esempio n. 1
0
File: ops.py Progetto: xvdp/DALI
def _preprocess_inputs(inputs, op_name, device, schema=None):
    """Normalize the positional inputs of an operator call.

    An element is kept as-is when it already is an acceptable input: a
    ``_DataNode``, a ``ScalarConstant``, or a list of those that contains at
    least one ``_DataNode``. Any other element is converted with ``_Constant``
    and, if that does not yield a ``_DataNode``, turned into a constant node
    via ``_instantiate_constant_node``.

    Args:
        inputs: Iterable of positional inputs. It is copied; the caller's
            object is never modified in place.
        op_name: Operator name, used only in error messages.
        device: Operator device. ``"gpu"`` makes converted constants default
            to GPU; any other value defaults them to CPU.
        schema: Optional operator schema. When truthy, its
            ``GetInputDevice(idx)`` result (if truthy) overrides the default
            device for input ``idx``.

    Returns:
        A new list with the normalized inputs, in the original order.

    Raises:
        TypeError: If an input cannot be converted to a constant node
            (the original exception is chained as the cause).
    """
    def is_input(x):
        # Accept a single data node / scalar constant, or a list made only of
        # those that contains at least one data node.
        if isinstance(x, (_DataNode, nvidia.dali.types.ScalarConstant)):
            return True
        return isinstance(x, list) and \
            any(isinstance(y, _DataNode) for y in x) and \
            all(isinstance(y, (_DataNode, nvidia.dali.types.ScalarConstant)) for y in x)

    default_input_device = "gpu" if device == "gpu" else "cpu"

    processed = []
    for idx, inp in enumerate(inputs):
        if not is_input(inp):
            # Spelled out explicitly; equivalent to
            # (schema.GetInputDevice(idx) or default) if schema else default.
            if schema:
                input_device = schema.GetInputDevice(idx) or default_input_device
            else:
                input_device = default_input_device
            # `is_input` already accepts ScalarConstant, so `inp` cannot be one
            # here and must be converted.
            try:
                inp = _Constant(inp, device=input_device)
            except Exception as ex:
                raise TypeError("""when calling operator {0}:
Input {1} is neither a DALI `DataNode` nor a list of data nodes but `{2}`.
Attempt to convert it to a constant node failed."""
                .format(op_name, idx, type(inp).__name__)) from ex

            # `_Constant` may return a non-node constant; materialize it as a node.
            if not isinstance(inp, _DataNode):
                inp = nvidia.dali.ops._instantiate_constant_node(input_device, inp)

        processed.append(inp)
    return processed
Esempio n. 2
0
def _instantiate_constant_node(device, constant):
    """Build a ``_Constant`` node on ``device`` from a constant descriptor.

    ``constant`` must expose ``value``, ``dtype`` and ``shape`` attributes
    (callers in this file pass ``ScalarConstant`` objects); they are forwarded
    unchanged to ``_Constant``.
    """
    node_kwargs = {
        "device": device,
        "value": constant.value,
        "dtype": constant.dtype,
        "shape": constant.shape,
    }
    return _Constant(**node_kwargs)
Esempio n. 3
0
    def __init__(self, inputs, op, **kwargs):
        """Create one invocation of operator ``op``.

        Args:
            inputs: Positional inputs of the call, or ``None`` when there are
                none. ``_ScalarConstant`` entries are turned into constant nodes
                on the operator's default input device; every remaining entry
                must be a ``_DataNode``.
            op: The operator being invoked. Its ``spec`` is copied and only the
                copy is modified here.
            **kwargs: Call-time arguments. ``_separate_kwargs`` splits them into
                spec arguments (added to the copied spec) and call arguments.
                Call arguments (merged over ``op._call_args``) other than
                ``name`` become argument inputs; ``None`` values are skipped.

        Raises:
            ValueError: If a call-time argument was already given in the
                operator's ``__init__``.
            TypeError: If a positional input is not a ``_DataNode``, or an
                argument input cannot be converted to a constant node.
        """
        self._counter = _OpCounter()
        self._outputs = []
        self._op = op
        self._default_call_args = op._call_args
        self._spec = op.spec.copy()
        self._relation_id = self._counter.id

        if inputs is not None:
            # Scalar constants become constant nodes on the operator's default
            # input device; everything else is passed through unchanged.
            default_input_device = "gpu" if op.device == "gpu" else "cpu"
            inputs = tuple(
                _instantiate_constant_node(default_input_device, inp)
                if isinstance(inp, _ScalarConstant) else inp
                for inp in inputs)

        self._inputs = inputs

        spec_args, kwargs = _separate_kwargs(kwargs)
        _add_spec_args(op._schema, self._spec, spec_args)

        call_args = {**self._default_call_args}
        for k, v in kwargs.items():
            if v is None:
                # None at call time means "not given": skip it, so it neither
                # overrides nor conflicts with a value from __init__.
                continue
            if k in self._default_call_args:
                raise ValueError("The argument `{}` was already specified in __init__.".format(k))
            call_args[k] = v

        name = call_args.get("name", None)
        if name is not None:
            self._name = name
        else:
            self._name = '__' + type(op).__name__ + "_" + str(self._counter.id)
        # Add inputs
        if inputs:
            for inp in inputs:
                if not isinstance(inp, _DataNode):
                    raise TypeError(
                        "Expected inputs of type `DataNode`. Received input of type '{}'."
                        .format(type(inp).__name__))
                self._spec.AddInput(inp.name, inp.device)
        # Argument inputs (processed in sorted key order)
        arg_inputs = []
        for k in sorted(call_args.keys()):
            if k == "name":
                continue
            arg_inp = call_args[k]
            if arg_inp is None:
                continue
            if isinstance(arg_inp, _ScalarConstant):
                arg_inp = _instantiate_constant_node("cpu", arg_inp)
            if not isinstance(arg_inp, _DataNode):
                try:
                    arg_inp = _Constant(arg_inp, device="cpu")
                except Exception as e:
                    raise TypeError(
                            ("Expected inputs of type " +
                            "`DataNode` or convertible to constant nodes. Received " +
                            "input `{}` of type '{}'.")
                            .format(k, type(arg_inp).__name__)) from e

            _check_arg_input(op._schema, type(self._op).__name__, k)

            self._spec.AddArgumentInput(k, arg_inp.name)
            arg_inputs.append(arg_inp)
        if arg_inputs:
            # `inputs` may be None (no positional inputs); treat that as empty
            # instead of failing in list(None).
            self._inputs = list(self._inputs or ()) + arg_inputs

        if self._op.schema.IsDeprecated():
            # TODO(klecki): how to know if this is fn or ops?
            msg = "WARNING: `{}` is now deprecated".format(_op_name(type(self._op).__name__, "fn"))
            use_instead = _op_name(self._op.schema.DeprecatedInFavorOf(), "fn")
            if use_instead:
                msg +=". Use `" + use_instead + "` instead."
            explanation = self._op.schema.DeprecationMessage()
            if explanation:
                msg += "\n" + explanation
            # Force the "default" filter locally so the warning is shown even if
            # DeprecationWarnings are filtered out by the active configuration.
            with warnings.catch_warnings():
                warnings.simplefilter("default")
                warnings.warn(msg, DeprecationWarning, stacklevel=2)