Exemplo n.º 1
0
        def coerce_cached_input(index, name, dtype, shape):
            """
            Looks up a cached input buffer and coerces it to the expected dtype and shape.

            ``self`` and ``iteration`` are read from the enclosing scope; the feed dict
            used is ``self.cache[iteration]``.

            Args:
                index (int): Forwarded to ``misc.find_in_dict`` when locating the cached entry.
                name (str): Name of the input tensor being fed.
                dtype: Expected data type of the input.
                shape: Expected shape of the input.

            Returns:
                numpy.ndarray: The cached buffer, cast and/or reshaped if it did not already match.
            """
            feed_dict = self.cache[iteration]
            matched_name = misc.find_in_dict(name, feed_dict, index)
            assert matched_name is not None

            # find_in_dict may resolve to a differently-named entry; make that visible.
            if matched_name != name:
                G_LOGGER.warning(
                    "Input tensor: {:24} | Cached buffer name ({:}) does not match input name ({:})."
                    .format(name, matched_name, name))

            buf = feed_dict[matched_name]

            if dtype != buf.dtype:
                expected_dtype_name = np.dtype(dtype).name
                G_LOGGER.warning(
                    "Input tensor: {:24} | Cached buffer dtype ({:}) does not match input dtype ({:}), attempting cast. "
                    .format(name, buf.dtype, expected_dtype_name))
                buf = buf.astype(dtype)

            shape_ok = misc.is_valid_shape_override(buf.shape, shape)
            if not shape_ok:
                G_LOGGER.warning(
                    "Input tensor: {:24} | Cached buffer shape ({:}) does not match input shape ({:}), attempting reshape. "
                    .format(name, buf.shape, shape))
                buf = misc.try_match_shape(buf, shape)

            # Sanity check: after coercion the buffer must be compatible with the expected dtype/shape.
            assert buf.dtype == dtype
            assert misc.is_valid_shape_override(buf.shape, shape)
            return buf
Exemplo n.º 2
0
 def get_static_shape(name, shape):
     """
     Returns a static version of ``shape``, overriding dynamic dimensions if needed.

     If ``shape`` is not dynamic it is returned unchanged. Otherwise the result of
     ``misc.override_dynamic_shape`` is returned. When that differs from ``shape`` and
     the user supplied no metadata for ``name``, a warning is logged (once). A critical
     log is emitted first if the override is not a valid one for ``shape``.
     """
     if not misc.is_shape_dynamic(shape):
         return shape

     static_shape = misc.override_dynamic_shape(shape)
     needs_notice = static_shape != shape and name not in self.user_input_metadata
     if not needs_notice:
         return static_shape

     if not misc.is_valid_shape_override(static_shape, shape):
         G_LOGGER.critical(
             "Input tensor: {:24} | Cannot override original shape: {:} to {:}".format(
                 name, shape, static_shape))

     G_LOGGER.warning(
         "Input tensor: {:24} | Adjusted shape: {:} to: {:}. If this is incorrect, please set input_metadata "
         "or provide a custom data loader.".format(name, shape, static_shape),
         mode=LogMode.ONCE)
     return static_shape
Exemplo n.º 3
0
    def __getitem__(self, index):
        """
        Randomly generates input data.

        Args:
            index (int):
                    Since this class behaves like an iterable, it takes an index parameter.
                    Generated data is guaranteed to be the same for the same index.

        Returns:
            OrderedDict[str, numpy.ndarray]: A mapping of input names to input numpy buffers.

        Raises:
            IndexError: If ``index`` is greater than or equal to ``self.iterations``.
        """
        if index >= self.iterations:
            raise IndexError()

        G_LOGGER.verbose(
            "Generating data using numpy seed: {:}".format(self.seed + index))
        # Seeding per index makes the generated data reproducible for a given index.
        rng = np.random.RandomState(self.seed + index)

        def get_static_shape(name, shape):
            # Replace dynamic dimensions with concrete values, warning when the
            # user did not explicitly provide metadata for this input.
            static_shape = shape
            if misc.is_shape_dynamic(shape):
                static_shape = misc.override_dynamic_shape(shape)
                if static_shape != shape and name not in self.user_input_metadata:
                    if not misc.is_valid_shape_override(static_shape, shape):
                        G_LOGGER.critical(
                            "Input tensor: {:24} | Cannot override original shape: {:} to {:}"
                            .format(name, shape, static_shape))
                    G_LOGGER.warning(
                        "Input tensor: {:24} | Will generate data of shape: {:} (tensor shape is: {:}).\n"
                        "If this is incorrect, please set input_metadata "
                        "or provide a custom data loader.".format(
                            name, static_shape, shape),
                        mode=LogMode.ONCE)
            return static_shape

        # Whether the user provided the values for a shape tensor input,
        # rather than the shape of the input.
        # If the shape is 1D, and has a value equal to the rank of the provided default shape, it is
        # likely to be a shape tensor, and so its value, not shape, should be overridden.
        def is_shape_tensor(name, dtype):
            if name not in self.input_metadata or name not in self.user_input_metadata:
                return False

            _, shape = self.input_metadata[name]
            # Only a static, 1D, integer tensor can be a shape tensor. Checking this
            # first also guarantees shape[0] below is safe (e.g. for scalars, shape == ()).
            if not (np.issubdtype(dtype, np.integer)
                    and not misc.is_shape_dynamic(shape) and len(shape) == 1):
                return False

            user_shape = self.user_input_metadata[name][1]
            # The user may provide only a dtype, leaving the shape as None.
            if user_shape is None or len(user_shape) != shape[0]:
                return False
            # Can't have negative values in shapes
            return all(elem >= 0 for elem in user_shape)

        def generate_buffer(name, dtype, shape):
            if is_shape_tensor(name, dtype):
                buffer = np.array(shape, dtype=dtype)
                G_LOGGER.info(
                    "Assuming {:} is a shape tensor. Setting input values to: {:}. If this is not correct, "
                    "please set it correctly in 'input_metadata' or by providing --input-shapes"
                    .format(name, buffer),
                    mode=LogMode.ONCE)
            elif np.issubdtype(dtype, np.integer):
                # high is 1 greater than the max int drawn
                buffer = rng.randint(low=self.int_range[0],
                                     high=self.int_range[1] + 1,
                                     size=shape,
                                     dtype=dtype)
            elif np.issubdtype(dtype, np.bool_):
                buffer = rng.randint(low=0, high=2, size=shape).astype(dtype)
            else:
                buffer = (rng.random_sample(size=shape) *
                          (self.float_range[1] - self.float_range[0]) +
                          self.float_range[0]).astype(dtype)

            buffer = np.array(
                buffer
            )  # To handle scalars, since the above functions return a float if shape is ().
            return buffer

        if self.input_metadata is None and self.user_input_metadata is not None:
            self.input_metadata = self.user_input_metadata

        buffers = OrderedDict()
        for name, (dtype, shape) in self.input_metadata.items():
            if name in self.user_input_metadata:
                user_dtype, user_shape = self.user_input_metadata[name]

                dtype = misc.default_value(user_dtype, dtype)

                if user_shape is None:
                    # Only a dtype was provided; keep the shape from input_metadata.
                    pass
                elif misc.is_valid_shape_override(user_shape, shape) or is_shape_tensor(name, dtype):
                    shape = user_shape
                else:
                    G_LOGGER.warning(
                        "Input tensor: {:24} | Cannot use provided custom shape: {:}, since this input has "
                        "a static shape: {:}".format(name, user_shape, shape),
                        mode=LogMode.ONCE)

            static_shape = get_static_shape(name, shape)
            buffers[name] = generate_buffer(name, dtype, shape=static_shape)

        # Warn about unused metadata
        for name in self.user_input_metadata.keys():
            if name not in self.input_metadata:
                msg = "Input tensor: {:24} | Metadata was provided, but the input does not exist in one or more runners.".format(
                    name)
                close_match = misc.find_in_dict(name, self.input_metadata)
                if close_match:
                    msg += "\nMaybe you meant to set: {:}".format(close_match)
                G_LOGGER.warning(msg)

        return buffers
Exemplo n.º 4
0
def test_is_valid_shape_override(case):
    """Checks misc.is_valid_shape_override against an (override, shape, expected) case."""
    new_shape, original_shape, expected = case
    result = misc.is_valid_shape_override(new_shape=new_shape,
                                          original_shape=original_shape)
    assert result == expected