Exemplo n.º 1
0
 def __init__(self, cache_key, device, **kwargs):
     """Initialize the cached-execution state.

     Parameters
     ----------
     cache_key : object
         The cache key, stored as-is in ``self._cache_key``.
     device : object
         The device spec, stored as-is in ``self._device``.
     **kwargs
         ``seed`` overrides the random seed. Otherwise
         ``config.config().random_seed`` is used.

     """
     self._def = None
     self._cache_key, self._device = cache_key, device
     # The argument device option is always the CPU one, kept serialized.
     cpu_option = proto_util.get_device_option('cpu')
     self._arg_device = cpu_option.SerializeToString()
     fallback_seed = config.config().random_seed
     self._seed = kwargs.get('seed', fallback_seed)
Exemplo n.º 2
0
 def _add_device(graph_def):
     """Add device.

     Builds a device option from the current ``context.get_device()`` spec
     (type and index) plus the configured random seed. It then copies that
     option into ``graph_def.device_option`` in place.
     """
     global_cfg = config.config()
     current = context.get_device()
     option = proto_util.get_device_option(
         current.type, current.index, global_cfg.random_seed)
     graph_def.device_option.CopyFrom(option)
Exemplo n.º 3
0
 def to_proto(self, serialized=True):
     """Return the device proto, building and caching it lazily.

     Parameters
     ----------
     serialized : bool, optional, default=True
         **True** to return the serialized proto (the result of
         ``SerializeToString``) instead of the proto object.

     Returns
     -------
     object
         The cached serialized proto if ``serialized``; otherwise the
         cached proto object.

     """
     proto = self._proto
     if proto is None:
         proto = self._proto = proto_util.get_device_option(self.type, self.index)
     if not serialized:
         return proto
     # Serialize at most once; later calls reuse the cached result.
     if self._serialized_proto is None:
         self._serialized_proto = proto.SerializeToString()
     return self._serialized_proto
Exemplo n.º 4
0
    def feed_tensor(self, tensor, value, dtype=None, enforce_cpu=False):
        """Copy the value to tensor.

        Examples:

        ```python
        # Define a named tensor to feed
        x = dragon.Tensor(name='x')
        dragon.get_workspace().feed_tensor(x, 0)

        # Feed by specifying a tensor name
        # Note that it will create the implementation whatever
        dragon.get_workspace().feed_tensor('y', 1)
        print(dragon.get_workspace().has_tensor('y'))  # True
        ```

        Parameters
        ----------
        tensor : Union[dragon.Tensor, str]
            The tensor to feed.
        value : array_like
            The value to copy.
        dtype : str, optional
            The optional data type. If ``tensor`` carries a non-None
            ``dtype``, that dtype is used instead.
        enforce_cpu : bool, optional, default=False
            **True** to copy using cpu context.

        Raises
        ------
        TypeError
            If ``tensor.dtype`` is not in ``mapping.TENSOR_TYPE_TO_NP_TYPE``.

        """
        if types.is_tensor(value):
            # Steal the data if value is a tensor.
            value = value.get_value()
        # Determine the data type: the explicit argument wins, then the
        # value's own dtype (ndarray only), falling back to float32.
        if dtype is None:
            dtype = value.dtype if isinstance(value, numpy.ndarray) else 'float32'
        # A tensor's declared dtype overrides everything above.
        tensor_dtype = getattr(tensor, 'dtype', None)
        if tensor_dtype is not None:
            if tensor_dtype not in mapping.TENSOR_TYPE_TO_NP_TYPE:
                # str() keeps the message intact for non-str dtypes instead
                # of failing on the concatenation itself.
                raise TypeError('Unsupported data type: ' + str(tensor_dtype))
            dtype = mapping.TENSOR_TYPE_TO_NP_TYPE[tensor_dtype]
        # Determine the copying device option.
        if enforce_cpu:
            device_option = proto_util.get_device_option('cpu')
        else:
            device_option = proto_util.get_default_device_option()
            if device_option is None:
                device_option = proto_util.get_global_device_option()
        # ``numpy.asarray`` copies only when needed, which matches the old
        # NumPy 1.x ``copy=False`` meaning. On NumPy >= 2.0,
        # ``numpy.array(..., copy=False)`` raises ValueError whenever a copy
        # is unavoidable (e.g. Python scalars or lists).
        self.FeedTensor(
            _stringify_object(tensor),
            numpy.asarray(value, dtype=dtype),
            serialization.serialize_proto(device_option),
        )
Exemplo n.º 5
0
 def _gen_def(self):
     """Generate the OpDef from attributes.

     Reads the attribute dict from ``self.attributes()`` and stores the
     result of ``proto_util.make_operator_def_cpp`` in ``self._def``. The
     device option is built from ``self._device`` and ``self._seed``.
     """
     attrs = self.attributes()
     op_name = attrs.get('name', 'Op')
     op_type = attrs['op_type']
     device = self._device
     device_option = proto_util.get_device_option(
         device.type, device.index, self._seed)
     self._def = proto_util.make_operator_def_cpp(
         name=op_name,
         cache_key=self._cache_key,
         op_type=op_type,
         device_option=device_option,
         **attrs['arguments'])
Exemplo n.º 6
0
    def __init__(self, key, dev, **kwargs):
        """Create a ``Function``.

        Parameters
        ----------
        key : str
            The cache key.
        dev : dragon.vm.torch.device
            The device spec.
        **kwargs
            ``seed`` optionally sets the random seed. It defaults to
            ``config.config().random_seed``.

        """
        super(Function, self).__init__()
        self._def = None
        self._cache_key = key
        self._device = dev
        # The argument device option is always the CPU one, kept serialized.
        arg_option = proto_util.get_device_option('cpu')
        self._arg_device = arg_option.SerializeToString()
        self._seed = kwargs.get('seed', config.config().random_seed)