Esempio n. 1
0
    def __init__(
        self,
        op: str,
        name: str = None,
        attrs: Dict[str, object] = None,
        inputs: List["Tensor"] = None,
        outputs: List["Tensor"] = None,
    ):
        """
        Build a node: a single operation in a graph that consumes zero or more Tensors
        and produces zero or more Tensors.

        Args:
            op (str): The operation this node performs.

            name (str): The name of this node. When omitted, ``""`` is handed to
                ``misc.default_value`` as the fallback.
            attrs (Dict[str, object]): A dictionary that maps attribute names to their values.
                When omitted, a new ``OrderedDict`` is handed to ``misc.default_value`` as the fallback.
            inputs (List[Tensor]): A list of zero or more input Tensors.
            outputs (List[Tensor]): A list of zero or more output Tensors.
        """
        self.op = op
        self.name = misc.default_value(name, "")
        self.attrs = misc.default_value(attrs, OrderedDict())

        # NOTE(review): `field_name` presumably names the attribute on each tensor that
        # mirrors this list (a node's inputs <-> a tensor's outputs) -- confirm against
        # misc.SynchronizedList.
        initial_inputs = misc.default_value(inputs, [])
        self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=initial_inputs)

        # Built after `self.inputs`, in the same order as before.
        initial_outputs = misc.default_value(outputs, [])
        self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=initial_outputs)
Esempio n. 2
0
    def __init__(
        self,
        name: str,
        values: Union[np.ndarray, LazyValues],
        data_location: int = None,
    ):
        """
        Build a Tensor whose value is known.

        Args:
            name (str): The name of the tensor.
            values (Union[numpy.ndarray, LazyValues]):
                    The values in this tensor, given either as a NumPy array or as a
                    LazyValues instance. Any other type is reported through
                    ``G_LOGGER.critical``.

            data_location (int):
                    An enum value indicating the location where the tensor data is stored.
                    Generally, this will come from onnx.TensorProto.DataLocation.
        """
        self.name = name
        self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=[])
        self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=[])

        accepted_types = (np.ndarray, LazyValues)
        if not isinstance(values, accepted_types):
            # NOTE(review): whether G_LOGGER.critical raises or returns is not visible
            # here; if it returns, the assignments below still run.
            message = (
                "Provided `values` argument is not a NumPy array or a LazyValues instance. "
                "Please provide a NumPy array or LazyValues instance to construct a Constant. "
                "Note: Provided `values` parameter was: {:}"
            ).format(values)
            G_LOGGER.critical(message)

        # Stored as given (no copy or conversion is made here).
        self._values = values
        self.data_location = data_location
Esempio n. 3
0
    def __init__(
        self,
        name: str,
        dtype: np.dtype = None,
        shape: Sequence[Union[int, str]] = None,
    ):
        """
        Build a Tensor whose value is not known until inference-time.

        Both synchronized lists (``inputs`` and ``outputs``) start out empty.

        Args:
            name (str): The name of the tensor.
            dtype (numpy.dtype): The data type of the tensor.
            shape (Sequence[Union[int, str]]): The shape of the tensor. This may contain
                strings if the model uses dimension parameters.
        """
        self.name = name

        # NOTE(review): `field_name` presumably names the mirrored attribute on the
        # list's elements -- confirm against misc.SynchronizedList.
        empty_inputs = misc.SynchronizedList(self, field_name="outputs", initial=[])
        self.inputs = empty_inputs
        empty_outputs = misc.SynchronizedList(self, field_name="inputs", initial=[])
        self.outputs = empty_outputs

        self.dtype = dtype
        # The fallback passed to misc.default_value is None, same as the parameter default.
        self.shape = misc.default_value(shape, None)
Esempio n. 4
0
    def __init__(self, name: str, values: np.ndarray):
        """
        Build a Tensor whose value is known.

        Both synchronized lists (``inputs`` and ``outputs``) start out empty.

        Args:
            name (str): The name of the tensor.
            values (numpy.ndarray): The values in this tensor, in the form of a NumPy array.
                Anything that is not a NumPy array is reported through ``G_LOGGER.critical``.
        """
        self.name = name
        self.inputs = misc.SynchronizedList(
            self,
            field_name="outputs",
            initial=[],
        )
        self.outputs = misc.SynchronizedList(
            self,
            field_name="inputs",
            initial=[],
        )

        is_numpy_array = isinstance(values, np.ndarray)
        if not is_numpy_array:
            # NOTE(review): whether G_LOGGER.critical raises or returns is not visible
            # here; if it returns, `values` is still converted below.
            message = "Provided `values` argument is not a NumPy array (please provide a NumPy array to construct a Constant): {:}".format(values)
            G_LOGGER.critical(message)

        # np.array() copies its argument by default, so the tensor holds its own array.
        self.values = np.array(values)