コード例 #1
    def _verify_input_spec(self, input_spec):
        Verifies the `input_spec` and its element type is valid.
        if not isinstance(input_spec, (tuple, list)):
            raise TypeError(
                "The type(input_spec) should be one of (tuple, list), but received {}."
        input_spec = tuple(input_spec)
        for spec in flatten(input_spec):
            if not isinstance(spec, paddle.static.InputSpec):
                raise ValueError(
                    "The type(elem) from input_spec should be `InputSpec`, but received {}."

        return input_spec
コード例 #2
def convert_to_input_spec(inputs, input_spec):
    """
    Replaces tensor in structured `inputs` by InputSpec in `input_spec`.

    Args:
        inputs(list|tuple|dict): nested structure list or dict.
        input_spec(list|tuple|dict|InputSpec): same nested structure as
            `inputs`. At a list/tuple level the spec may be shorter than
            `inputs`; the trailing inputs are kept unchanged. At a dict level,
            keys absent from the spec are kept unchanged.

    Returns:
        Same structure as `inputs`, with every element that has a matching
        spec replaced by that InputSpec. list/tuple levels are returned as
        lists, dict levels as dicts.

    Raises:
        TypeError: if `inputs` and `input_spec` have different container
            types, or `input_spec` is not an InputSpec or dict/list/tuple of it.
        ValueError: if `inputs` has fewer entries than `input_spec`.
    """
    import warnings

    def check_type_and_len(value, spec, check_length=False):
        # Report on the helper's own arguments (not the enclosing
        # `inputs`/`input_spec`) so the message always describes the pair
        # actually being checked.
        if type(value) is not type(spec):
            raise TypeError(
                'type(input) should be {}, but received {}.'.format(
                    type(spec), type(value)))
        if check_length and len(value) < len(spec):
            raise ValueError(
                'Requires len(inputs) >= len(input_spec), but received len(inputs):{} < len(input_spec):{}'
                .format(len(value), len(spec)))

    if isinstance(input_spec, (tuple, list)):
        check_type_and_len(inputs, input_spec, True)
        # len(inputs) >= len(input_spec) is guaranteed above, so zip pairs
        # every spec with its input.
        input_with_spec = [
            convert_to_input_spec(value, spec)
            for value, spec in zip(inputs, input_spec)
        ]

        # Note: If the rest inputs contain tensor or numpy.ndarray
        # without specific InputSpec, raise warning.
        rest_inputs = inputs[len(input_spec):]
        for rest_input in rest_inputs:
            if isinstance(rest_input, (core.VarBase, np.ndarray)):
                warnings.warn(
                    "The inputs contain `{}` without specifying InputSpec, its shape and dtype will be treated as immutable. "
                    "Please specify InputSpec information in `@declarative` if you expect them as mutable inputs."
                    .format(type(rest_input)))
        # Keep the un-specified trailing inputs so the result keeps the same
        # length as `inputs`, as documented.
        input_with_spec.extend(rest_inputs)

        return input_with_spec
    elif isinstance(input_spec, dict):
        check_type_and_len(inputs, input_spec, True)
        input_with_spec = {}
        for name, value in inputs.items():
            if name in input_spec:
                input_with_spec[name] = convert_to_input_spec(
                    value, input_spec[name])
            else:
                input_with_spec[name] = value
        return input_with_spec
    elif isinstance(input_spec, paddle.static.InputSpec):
        return input_spec
    else:
        raise TypeError(
            "The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}."
            .format(type(input_spec)))
コード例 #3
    def _verify_input_spec(self, input_spec):
        Verifies the `input_spec` and its element type is valid.
        if not isinstance(input_spec, (tuple, list)):
            raise TypeError(
                "The type(input_spec) should be one of (tuple, list), but received {}."

        return tuple(input_spec)
コード例 #4
 def get_program(self, item):
     """
     Returns the program already stored in `self._caches` for `item`.

     This lookup never builds a missing program; it raises instead.

     Args:
         item(CacheKey): key of the cached program.

     Returns:
         The value stored in `self._caches[item]`.

     Raises:
         ValueError: if `item` is not a `CacheKey`.
         RuntimeError: if nothing is cached for `item`.
     """
     if not isinstance(item, CacheKey):
         # The check is against CacheKey, so the message names CacheKey.
         raise ValueError(
             "Input item's type should be CacheKey, but received %s" %
             type(item))
     if item not in self._caches:
         raise RuntimeError(
             "Failed to find program for input item, please decorate input function by `@paddle.jit.to_static`."
         )
     return self._caches[item]
コード例 #5
    def __getitem__(self, item):
        """
        Returns the program for `item`, building it with `self._build_once`
        and storing it in `self._caches` on first access.

        Args:
            item(CacheKey): key of the program.

        Returns:
            The value stored in `self._caches[item]`.

        Raises:
            ValueError: if `item` is not a `CacheKey`.
        """
        if not isinstance(item, CacheKey):
            raise ValueError('type(item) should be CacheKey, but received %s' %
                             type(item))

        if item not in self._caches:
            self._caches[item] = self._build_once(item)
            # Note: raise warnings if number of traced program is more than `max_tracing_count`
            current_tracing_count = len(self._caches)
            if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
                import warnings
                warnings.warn(
                    "Current traced program number: {} > `max_tracing_count`:{}. Too many cached programs will bring expensive overhead. "
                    "The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors."
                    .format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT))

        return self._caches[item]
コード例 #6
def get_buffers(layer_instance, include_sublayer=True):
    """
    Returns Variable buffers of decorated layers. If `include_sublayer` is
    True, the Variable buffers created in sub layers will be added.

    Args:
        layer_instance(Layer|None): the decorated layer; None yields an
            empty OrderedDict.
        include_sublayer(bool): if True, collect `layer_instance.buffers()`
            keyed by each buffer's `.name`; otherwise return
            `layer_instance._buffers`.

    Returns:
        A mapping from names to buffers.

    Raises:
        TypeError: if `layer_instance` is neither None nor an `nn.Layer`.
    """
    if layer_instance is None:
        return collections.OrderedDict()
    if not isinstance(layer_instance, layers.Layer):
        raise TypeError(
            "Type of `layer_instance` should be nn.Layer, but received {}".
            format(type(layer_instance)))
    if include_sublayer:
        return collections.OrderedDict(
            (buffer.name, buffer) for buffer in layer_instance.buffers())
    # NOTE(review): this returns the layer's own `_buffers` mapping (not a
    # copy), and its keys may differ from the `.name` keys used above —
    # confirm callers expect this.
    return layer_instance._buffers
コード例 #7
def get_parameters(layer_instance, include_sublayer=True):
    """
    Returns parameters of decorated layers. If `include_sublayer` is True,
    the parameters created in sub layers will be added.

    Args:
        layer_instance(Layer|None): the decorated layer; None yields an
            empty OrderedDict.
        include_sublayer(bool): if True, collect `layer_instance.parameters()`
            keyed by each parameter's `.name`; otherwise return
            `layer_instance._parameters`.

    Returns:
        A mapping from names to parameters.

    Raises:
        TypeError: if `layer_instance` is neither None nor an `nn.Layer`.
    """
    if layer_instance is None:
        return collections.OrderedDict()
    if not isinstance(layer_instance, layers.Layer):
        raise TypeError(
            "Type of `layer_instance` should be nn.Layer, but received {}".
            format(type(layer_instance)))
    if include_sublayer:
        return collections.OrderedDict(
            (param.name, param) for param in layer_instance.parameters())
    # NOTE(review): this returns the layer's own `_parameters` mapping (not a
    # copy), and its keys may differ from the `.name` keys used above —
    # confirm callers expect this.
    return layer_instance._parameters