def __init__(self, name=None, shape=None, dtype=None):
    """
    Specify a (dense) tensor input with an optional element type.

    Parameters
    ----------
    name: str, optional
        Forwarded unchanged to the base-class initializer.
    shape: optional
        Forwarded unchanged to the base-class initializer.
    dtype: optional
        ``None`` selects ``types.fp32``. A value accepted by ``is_builtin`` is
        stored as-is. Anything else is treated as a numpy type and converted
        with ``numpy_type_to_builtin_type``.

    Raises
    ------
    TypeError
        If ``dtype`` is not a builtin type and the numpy conversion rejects it.
    """
    super(TensorType, self).__init__(name, shape)

    if dtype is None:
        resolved_dtype = types.fp32
    elif is_builtin(dtype):
        resolved_dtype = dtype
    else:
        # Not a builtin MIL type, so assume it is a numpy type and convert it.
        try:
            resolved_dtype = numpy_type_to_builtin_type(dtype)
        except TypeError:
            raise TypeError("dtype={} is unsupported".format(dtype))
    self.dtype = resolved_dtype
def __init__(self, name=None, shape=None, dtype=None, default_value=None):
    """
    Specify a (dense) tensor input.

    Parameters
    ----------
    name: str
        Input name. Must match an input name in the model (usually the
        Placeholder name for TensorFlow or the input name for PyTorch).

        The ``name`` is required except for a TensorFlow model in which there
        is exactly one input Placeholder.

    shape: (1) list of positive int or RangeDim, or (2) EnumeratedShapes
        The shape of the input.

        For TensorFlow:
          * The ``shape`` is optional. If omitted, the shape is inferred from
            TensorFlow graph's Placeholder shape.

        For PyTorch:
          * The ``shape`` is required.

    dtype: np.generic or mil.type type
        Numpy ``dtype`` (for example, ``np.int32``). Default is ``np.float32``.

    default_value: np.ndarray
        If provided, the input is considered optional. At runtime, if the
        input is not provided, ``default_value`` is used.

        Limitations:
          * If ``default_value`` is ``np.ndarray``, all elements are required
            to have the same value.

          * The ``default_value`` may not be specified if ``shape`` is
            ``EnumeratedShapes``.

    Raises
    ------
    TypeError
        If ``dtype`` is not ``None``, is not a builtin MIL type, and cannot be
        converted from a numpy type. Converting ``default_value.dtype`` may
        raise it as well.
    ValueError
        If ``default_value`` is given and any of these hold:
        ``shape`` is ``EnumeratedShapes``; ``default_value`` is not a
        non-empty ``np.ndarray``; its elements are not all equal; its shape
        differs from a fully static ``shape``; or its dtype differs from
        ``dtype``.

    Examples
    --------
    * ``ct.TensorType(name="input", shape=(1, 2, 3))`` implies
      ``dtype == np.float32``

    * ``ct.TensorType(name="input", shape=(1, 2, 3), dtype=np.int32)``

    * ``ct.TensorType(name="input", shape=(1, 2, 3), dtype=ct.converters.mil.types.fp32)``
    """
    super(TensorType, self).__init__(name, shape)

    if dtype is None:
        self.dtype = types.fp32
    elif is_builtin(dtype):
        self.dtype = dtype
    else:
        # Assume dtype is a numpy type and convert it to a builtin MIL type.
        try:
            self.dtype = numpy_type_to_builtin_type(dtype)
        except TypeError:
            raise TypeError("dtype={} is unsupported".format(dtype))

    if default_value is not None:
        # Checked against the raw argument: an enumerated set of shapes has
        # no single shape the default value could be validated against.
        if isinstance(shape, EnumeratedShapes):
            msg = ('TensorType input {} has EnumeratedShapes and '
                   'may not be optional')
            raise ValueError(msg.format(name))
        if not isinstance(default_value, np.ndarray):
            msg = 'TensorType {} default_value is not np.ndarray'
            raise ValueError(msg.format(name))
        # An empty array has no fill value. Without this check, indexing
        # element 0 below would fail with an IndexError instead of a
        # ValueError like the other checks.
        if default_value.size == 0:
            msg = 'TensorType {} default_value must not be empty'
            raise ValueError(msg.format(name))
        # Only constant-filled defaults are supported, so every element must
        # equal the first one. ``.flat`` avoids the copy ``flatten()`` makes.
        default_fill_val = default_value.flat[0]
        if not np.all(default_value == default_fill_val):
            msg = ('TensorType {} default_value can only have '
                   'same entries')
            raise ValueError(msg.format(name))
        # The shape can only be compared when it is fully static.
        # NOTE(review): self.shape is guarded against None in case the base
        # class stores None when ``shape`` is omitted — confirm.
        if self.shape is not None and not self.shape.has_symbolic and \
                list(default_value.shape) != list(self.shape.symbolic_shape):
            msg = ('TensorType {} default_value shape {} != '
                   'TensorType.shape {}')
            raise ValueError(
                msg.format(name, default_value.shape, self.shape.to_list()))
        if numpy_type_to_builtin_type(default_value.dtype) != self.dtype:
            msg = ('TensorType {} default_value dtype {} != '
                   'TensorType.dtype {}')
            raise ValueError(
                msg.format(name, default_value.dtype,
                           self.dtype.__type_info__()))

    # Always set, so a required input carries default_value=None.
    self.default_value = default_value