Example #1
0
 def __init__(self, name=None, shape=None, dtype=None):
     """
     Initialize the input with a name, a shape, and an element dtype.

     ``name`` and ``shape`` are forwarded unchanged to the parent
     initializer. ``dtype`` is resolved as follows:

     * ``None``: ``types.fp32`` is used.
     * a value accepted by ``is_builtin``: stored as-is.
     * anything else: assumed to be a numpy type and converted with
       ``numpy_type_to_builtin_type``.

     Raises
     ------
     TypeError
         If ``numpy_type_to_builtin_type`` rejects ``dtype``.
     """
     super(TensorType, self).__init__(name, shape)
     needs_conversion = dtype is not None and not is_builtin(dtype)
     if not needs_conversion:
         self.dtype = types.fp32 if dtype is None else dtype
     else:
         # Not a builtin type: treat it as a numpy type and convert it.
         try:
             self.dtype = numpy_type_to_builtin_type(dtype)
         except TypeError:
             raise TypeError("dtype={} is unsupported".format(dtype))
Example #2
0
    def __init__(self, name=None, shape=None, dtype=None, default_value=None):
        """
        Specify a (dense) tensor input.

        Parameters
        ----------
        name: str
            Input name. Must match an input name in the model (usually the
            Placeholder name for TensorFlow or the input name for PyTorch).

            The ``name`` is required except for a TensorFlow model in which there is
            exactly one input Placeholder.

        shape: (1) list of positive int or RangeDim, or (2) EnumeratedShapes
            The shape of the input.

            For TensorFlow:
              * The ``shape`` is optional. If omitted, the shape is inferred from
                TensorFlow graph's Placeholder shape.

            For PyTorch:
              * The ``shape`` is required.

        dtype: np.generic or mil.type type
            Numpy ``dtype`` (for example, ``np.int32``). Default is ``np.float32``.

        default_value: np.ndarray
            If provided, the input is considered optional. At runtime, if the
            input is not provided, ``default_value`` is used.

            Limitations:
              * All elements of ``default_value`` are required to have the
                same value.

              * The ``default_value`` may not be specified if ``shape`` is
                ``EnumeratedShapes``.

        Raises
        ------
        TypeError
            If ``dtype`` is neither a builtin type (per ``is_builtin``) nor a
            type that ``numpy_type_to_builtin_type`` can convert.
        ValueError
            If ``default_value`` is combined with ``EnumeratedShapes``, is not
            an ``np.ndarray``, has differing entries, has a shape different
            from a fully concrete input shape, or has a dtype different from
            ``dtype``.

        Examples
        --------
        * ``ct.TensorType(name="input", shape=(1, 2, 3))`` implies
          ``dtype == np.float32``

        * ``ct.TensorType(name="input", shape=(1, 2, 3), dtype=np.int32)``

        * ``ct.TensorType(name="input", shape=(1, 2, 3),
          dtype=ct.converters.mil.types.fp32)``
        """
        super(TensorType, self).__init__(name, shape)
        if dtype is None:
            self.dtype = types.fp32
        elif is_builtin(dtype):
            self.dtype = dtype
        else:
            # Anything that is not a builtin type is assumed to be a numpy type.
            try:
                self.dtype = numpy_type_to_builtin_type(dtype)
            except TypeError as err:
                raise TypeError("dtype={} is unsupported".format(dtype)) from err

        if default_value is not None:
            if isinstance(shape, EnumeratedShapes):
                msg = ('TensorType input {} has EnumeratedShapes and '
                       'may not be optional')
                raise ValueError(msg.format(name))
            if not isinstance(default_value, np.ndarray):
                msg = 'TensorType {} default_value is not np.ndarray'
                raise ValueError(msg.format(name))
            # Only constant-filled defaults are supported. An empty array
            # trivially satisfies that, and indexing its first element would
            # raise IndexError, so the check is skipped for it.
            if default_value.size > 0:
                # ``.flat[0]`` reads the first element without the full copy
                # that ``flatten()`` makes.
                default_fill_val = default_value.flat[0]
                if not np.all(default_value == default_fill_val):
                    msg = ('TensorType {} default_value can only have '
                           'same entries')
                    raise ValueError(msg.format(name))
            # Shapes are compared only when self.shape has no symbolic dims.
            if (not self.shape.has_symbolic
                    and list(default_value.shape)
                    != list(self.shape.symbolic_shape)):
                msg = ('TensorType {} default_value shape {} != '
                       'TensorType.shape {}')
                raise ValueError(
                    msg.format(name, default_value.shape,
                               self.shape.to_list()))
            if numpy_type_to_builtin_type(default_value.dtype) != self.dtype:
                msg = ('TensorType {} default_value dtype {} != '
                       'TensorType.dtype {}')
                raise ValueError(
                    msg.format(name, default_value.dtype,
                               self.dtype.__type_info__()))

        self.default_value = default_value