def _preprocess_inputs(inputs, op_name, device, schema=None): if isinstance(inputs, tuple): inputs = list(inputs) def is_input(x): if isinstance(x, (_DataNode, nvidia.dali.types.ScalarConstant)): return True return isinstance(x, (list)) and \ any(isinstance(y, _DataNode) for y in x) and \ all(isinstance(y, (_DataNode, nvidia.dali.types.ScalarConstant)) for y in x) default_input_device = "gpu" if device == "gpu" else "cpu" for idx, inp in enumerate(inputs): if not is_input(inp): input_device = schema.GetInputDevice(idx) or default_input_device if schema else default_input_device if not isinstance(inp, nvidia.dali.types.ScalarConstant): try: inp = _Constant(inp, device=input_device) except Exception as ex: raise TypeError("""when calling operator {0}: Input {1} is neither a DALI `DataNode` nor a list of data nodes but `{2}`. Attempt to convert it to a constant node failed.""" .format(op_name, idx, type(inp).__name__)) from ex if not isinstance(inp, _DataNode): inp = nvidia.dali.ops._instantiate_constant_node(input_device, inp) inputs[idx] = inp return inputs
def _instantiate_constant_node(device, constant):
    """Create a constant node on ``device`` from ``constant``.

    The ``value``, ``dtype`` and ``shape`` attributes of ``constant`` are
    read and forwarded, together with ``device``, as keyword arguments to
    ``_Constant``; whatever ``_Constant`` returns is returned as is.
    """
    node_kwargs = {
        "device": device,
        "value": constant.value,
        "dtype": constant.dtype,
        "shape": constant.shape,
    }
    return _Constant(**node_kwargs)
def __init__(self, inputs, op, **kwargs):
    """Bind ``op`` to concrete inputs and arguments and populate its spec.

    Args:
        inputs: Iterable of positional inputs, or ``None``.
            ``ScalarConstant`` entries are replaced by constant nodes on
            ``"gpu"`` when ``op.device == "gpu"`` and on ``"cpu"``
            otherwise. Every entry must be a ``DataNode`` afterwards.
        op: Operator being invoked. Its ``_call_args``, ``spec``,
            ``device``, ``_schema`` and ``schema`` attributes are read.
        **kwargs: Call-time arguments. ``_separate_kwargs`` splits them
            into spec arguments (applied with ``_add_spec_args``) and the
            remaining arguments, which are merged with ``op._call_args``
            and, except for ``name``, added as argument inputs. Entries
            whose value is ``None`` are ignored.

    Raises:
        ValueError: If a non-``None`` call-time argument is also present
            in ``op._call_args``.
        TypeError: If a positional input is not a ``DataNode``, or an
            argument input cannot be converted to a constant node.
    """
    self._counter = _OpCounter()
    self._outputs = []
    self._op = op
    self._default_call_args = op._call_args
    self._spec = op.spec.copy()
    self._relation_id = self._counter.id

    if inputs is not None:
        default_input_device = "gpu" if op.device == "gpu" else "cpu"
        inputs = tuple(
            _instantiate_constant_node(default_input_device, inp)
            if isinstance(inp, _ScalarConstant) else inp
            for inp in inputs
        )

    self._inputs = inputs

    spec_args, kwargs = _separate_kwargs(kwargs)
    _add_spec_args(op._schema, self._spec, spec_args)

    call_args = {**self._default_call_args}
    for k, v in kwargs.items():
        if v is None:
            # if an argument was specified in __init__ and in __call__ it is None, ignore it
            continue
        if k in self._default_call_args:
            raise ValueError("The argument `{}` was already specified in __init__.".format(k))
        call_args[k] = v

    name = call_args.get("name", None)
    if name is not None:
        self._name = name
    else:
        self._name = '__' + type(op).__name__ + "_" + str(self._counter.id)

    # Add inputs
    if inputs:
        for inp in inputs:
            if not isinstance(inp, _DataNode):
                raise TypeError(
                    "Expected inputs of type `DataNode`. Received input of type '{}'."
                    .format(type(inp).__name__))
            self._spec.AddInput(inp.name, inp.device)

    # Argument inputs, processed in sorted key order.
    arg_inputs = []
    for k in sorted(call_args):
        if k == "name":
            continue
        arg_inp = call_args[k]
        if arg_inp is None:
            continue
        if isinstance(arg_inp, _ScalarConstant):
            arg_inp = _instantiate_constant_node("cpu", arg_inp)
        if not isinstance(arg_inp, _DataNode):
            try:
                arg_inp = _Constant(arg_inp, device="cpu")
            except Exception as e:
                raise TypeError(
                    ("Expected inputs of type " +
                     "`DataNode` or convertible to constant nodes. Received " +
                     "input `{}` of type '{}'.")
                    .format(k, type(arg_inp).__name__)) from e

        _check_arg_input(op._schema, type(self._op).__name__, k)

        self._spec.AddArgumentInput(k, arg_inp.name)
        arg_inputs.append(arg_inp)

    if arg_inputs:
        # `self._inputs` may be None when no positional inputs were given;
        # treat that as empty instead of failing in list(None).
        self._inputs = list(self._inputs or ()) + arg_inputs

    if self._op.schema.IsDeprecated():
        # TODO(klecki): how to know if this is fn or ops?
        msg = "WARNING: `{}` is now deprecated".format(_op_name(type(self._op).__name__, "fn"))
        use_instead = _op_name(self._op.schema.DeprecatedInFavorOf(), "fn")
        if use_instead:
            msg += ". Use `" + use_instead + "` instead."
        explanation = self._op.schema.DeprecationMessage()
        if explanation:
            msg += "\n" + explanation
        # Force the "default" filter locally so the deprecation notice is
        # not hidden by filters active outside this block.
        with warnings.catch_warnings():
            warnings.simplefilter("default")
            warnings.warn(msg, DeprecationWarning, stacklevel=2)