def coerce_cached_input(index, name, dtype, shape):
    cached_feed_dict = self.cache[iteration]
    cached_name = util.find_in_dict(name, cached_feed_dict, index)
    assert cached_name is not None

    if cached_name != name:
        G_LOGGER.warning(
            "Input tensor: {:} | Cached buffer name ({:}) does not match input name ({:}).".format(
                name, cached_name, name
            )
        )

    buffer = cached_feed_dict[cached_name]

    if dtype != buffer.dtype:
        G_LOGGER.warning(
            "Input tensor: {:} | Cached buffer dtype ({:}) does not match input dtype ({:}), attempting cast.".format(
                name, buffer.dtype, np.dtype(dtype).name
            )
        )
        type_info = None
        if np.issubdtype(dtype, np.integer):
            type_info = np.iinfo(np.dtype(dtype))
        elif np.issubdtype(dtype, np.floating):
            type_info = np.finfo(np.dtype(dtype))

        if type_info is not None and np.any((buffer < type_info.min) | (buffer > type_info.max)):
            G_LOGGER.warning(
                "Some values in this input are out of range of {:}. Unexpected behavior may ensue!".format(dtype)
            )
        buffer = buffer.astype(dtype)

    if not util.is_valid_shape_override(buffer.shape, shape):
        G_LOGGER.warning(
            "Input tensor: {:} | Cached buffer shape ({:}) does not match input shape ({:}), attempting reshape.".format(
                name, buffer.shape, shape
            )
        )
        buffer = util.try_match_shape(buffer, shape)

    assert buffer.dtype == dtype and util.is_valid_shape_override(buffer.shape, shape)
    return buffer
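
# Hedged sketch (added for illustration, not part of the original source):
# demonstrates the out-of-range check used above with plain NumPy, assuming a
# cached float32 buffer being coerced to int8. Values beyond the target
# dtype's representable range trigger the warning before the lossy cast.
def _example_out_of_range_check():
    import numpy as np

    buffer = np.array([1.0, 200.0], dtype=np.float32)
    type_info = np.iinfo(np.dtype(np.int8))
    # 200.0 > np.iinfo(np.int8).max (127), so the warning branch would fire.
    assert np.any((buffer < type_info.min) | (buffer > type_info.max))
    return buffer.astype(np.int8)  # The cast proceeds regardless, wrapping out-of-range values.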
def infer(self, feed_dict, check_inputs=True, *args, **kwargs):
    """
    Runs inference using the provided feed_dict.

    NOTE: Some runners may accept additional parameters in infer().
    For details on these, see the documentation for their `infer_impl()` methods.

    Args:
        feed_dict (OrderedDict[str, numpy.ndarray]):
                A mapping of input tensor names to corresponding input NumPy arrays.
        check_inputs (bool):
                Whether to check that the provided ``feed_dict`` includes the
                expected inputs with the expected data types and shapes.
                Disabling this may improve performance.
                Defaults to True.

    Returns:
        OrderedDict[str, numpy.ndarray]:
                A mapping of output tensor names to their corresponding NumPy arrays.

                IMPORTANT: Runners may reuse these output buffers. Thus, if you need to save
                outputs from multiple inferences, you should make a copy with ``copy.deepcopy(outputs)``.
    """
    if not self.is_active:
        G_LOGGER.critical("{:35} | Must be activated prior to calling infer()".format(self.name))

    if check_inputs:
        input_metadata = self.get_input_metadata()
        G_LOGGER.verbose("Runner input metadata is: {:}".format(input_metadata))

        util.check_dict_contains(
            feed_dict, input_metadata.keys(), dict_name="feed_dict", log_func=G_LOGGER.critical
        )

        for name, inp in feed_dict.items():
            meta = input_metadata[name]
            if not np.issubdtype(inp.dtype, meta.dtype):
                G_LOGGER.critical(
                    "Input tensor: {:} | Received unexpected dtype: {:}.\n"
                    "Note: Expected type: {:}".format(name, inp.dtype, meta.dtype)
                )

            if not util.is_valid_shape_override(inp.shape, meta.shape):
                G_LOGGER.critical(
                    "Input tensor: {:} | Received incompatible shape: {:}.\n"
                    "Note: Expected a shape compatible with: {:}".format(name, inp.shape, meta.shape)
                )

    return self.infer_impl(feed_dict, *args, **kwargs)
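
# Hedged usage sketch (assumption: a Polygraphy-style runner; `runner` and the
# contents of `feed_dict` are placeholders). Runners must be activated before
# infer() is called; using the runner as a context manager handles activation
# and deactivation. Outputs are deep-copied because, per the docstring above,
# runners may reuse their output buffers across calls.
def _example_infer_usage(runner, feed_dict):
    import copy

    with runner:  # Activates the runner so self.is_active is True.
        outputs = runner.infer(feed_dict, check_inputs=True)
        saved = copy.deepcopy(outputs)  # Safe to keep across subsequent infer() calls.
    return saved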
def get_static_shape(name, shape):
    static_shape = shape
    if util.is_shape_dynamic(shape):
        static_shape = util.override_dynamic_shape(shape)
        if static_shape != shape and name not in self.user_input_metadata:
            if not util.is_valid_shape_override(static_shape, shape):
                G_LOGGER.critical(
                    "Input tensor: {:} | Cannot override original shape: {:} to {:}".format(
                        name, shape, static_shape
                    )
                )
            G_LOGGER.warning(
                "Input tensor: {:} | Will generate data of shape: {:}.\n"
                "If this is incorrect, please set input_metadata "
                "or provide a custom data loader.".format(name, static_shape),
                mode=LogMode.ONCE,
            )
    return static_shape
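
# Hedged sketch (an assumption about util.override_dynamic_shape's behavior,
# for illustration only): dynamic dimensions, here modeled as None or negative
# values, are replaced with a fixed placeholder of 1. This is the shape the
# warning above reports as the one data will be generated with.
def _example_override_dynamic_shape(shape):
    # e.g. (None, 3, 224, 224) -> (1, 3, 224, 224)
    return tuple(1 if dim is None or dim < 0 else dim for dim in shape)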
def __getitem__(self, index):
    """
    Generates random input data.

    May update the DataLoader's `input_metadata` attribute.

    Args:
        index (int):
                Since this class behaves like an iterable, it takes an index parameter.
                Generated data is guaranteed to be the same for the same index.

    Returns:
        OrderedDict[str, numpy.ndarray]: A mapping of input names to input numpy buffers.
    """
    if index >= self.iterations:
        raise IndexError()

    G_LOGGER.verbose("Generating data using numpy seed: {:}".format(self.seed + index))
    rng = np.random.RandomState(self.seed + index)

    def get_static_shape(name, shape):
        static_shape = shape
        if util.is_shape_dynamic(shape):
            static_shape = util.override_dynamic_shape(shape)
            if static_shape != shape:
                if not util.is_valid_shape_override(static_shape, shape):
                    G_LOGGER.critical(
                        "Input tensor: {:} | Cannot override original shape: {:} to {:}".format(
                            name, shape, static_shape
                        )
                    )
                G_LOGGER.warning(
                    "Input tensor: {:} | Will generate data of shape: {:}.\n"
                    "If this is incorrect, please set input_metadata "
                    "or provide a custom data loader.".format(name, static_shape),
                    mode=LogMode.ONCE,
                )
        return static_shape

    # Whether the user provided the values for a shape tensor input,
    # rather than the shape of the input.
    # If the shape is 1D, and has a value equal to the rank of the provided default shape, it is
    # likely to be a shape tensor, and so its value, not shape, should be overridden.
    def is_shape_tensor(name, dtype):
        if name not in self.input_metadata or name not in self.user_input_metadata:
            return False

        _, shape = self.input_metadata[name]
        is_shape = np.issubdtype(dtype, np.integer) and (not util.is_shape_dynamic(shape)) and (len(shape) == 1)

        user_shape = self.user_input_metadata[name].shape
        is_shape &= len(user_shape) == shape[0]
        is_shape &= not util.is_shape_dynamic(user_shape)  # Shape of shape cannot be dynamic.
        return is_shape

    def generate_buffer(name, dtype, shape):
        if is_shape_tensor(name, dtype):
            buffer = np.array(shape, dtype=dtype)
            G_LOGGER.info(
                "Assuming {:} is a shape tensor. Setting input values to: {:}. If this is not correct, "
                "please set it correctly in 'input_metadata' or by providing --input-shapes".format(name, buffer),
                mode=LogMode.ONCE,
            )
        elif np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
            imin, imax = self._get_range(name, cast_type=int if np.issubdtype(dtype, np.integer) else bool)
            G_LOGGER.verbose(
                "Input tensor: {:} | Generating input data in range: [{:}, {:}]".format(name, imin, imax),
                mode=LogMode.ONCE,
            )
            # high is 1 greater than the max int drawn.
            buffer = rng.randint(low=imin, high=imax + 1, size=shape, dtype=dtype)
        else:
            fmin, fmax = self._get_range(name, cast_type=float)
            G_LOGGER.verbose(
                "Input tensor: {:} | Generating input data in range: [{:}, {:}]".format(name, fmin, fmax),
                mode=LogMode.ONCE,
            )
            buffer = (rng.random_sample(size=shape) * (fmax - fmin) + fmin).astype(dtype)

        # To handle scalars, since the above functions return a float if shape is ().
        buffer = np.array(buffer)
        return buffer

    if self.input_metadata is None and self.user_input_metadata is not None:
        self.input_metadata = self.user_input_metadata

    buffers = OrderedDict()
    for name, (dtype, shape) in self.input_metadata.items():
        if name in self.user_input_metadata:
            user_dtype, user_shape = self.user_input_metadata[name]

            dtype = util.default(user_dtype, dtype)

            is_valid_shape_override = user_shape is not None and util.is_valid_shape_override(user_shape, shape)
            if util.is_shape_dynamic(user_shape):
                G_LOGGER.warning(
                    "Input tensor: {:} | Provided input shape: {:} is dynamic.\n"
                    "Dynamic shapes cannot be used to generate inference data. "
                    "Will use default shape instead.\n"
                    "To avoid this, please provide a fixed shape to the data loader.".format(name, user_shape)
                )
            elif not is_valid_shape_override and not is_shape_tensor(name, dtype):
                G_LOGGER.warning(
                    "Input tensor: {:} | Cannot use provided custom shape: {:} "
                    "to override: {:}. Will use default shape instead.".format(name, user_shape, shape),
                    mode=LogMode.ONCE,
                )
            else:
                shape = util.default(user_shape, shape)

        static_shape = get_static_shape(name, shape)
        buffers[name] = generate_buffer(name, dtype, shape=static_shape)

    # Warn about unused metadata
    for name in self.user_input_metadata.keys():
        if name not in self.input_metadata:
            msg = "Input tensor: {:} | Metadata was provided, but the input does not exist in one or more runners.".format(
                name
            )
            close_match = util.find_in_dict(name, self.input_metadata)
            if close_match:
                msg += "\nMaybe you meant to set: {:}".format(close_match)
            G_LOGGER.warning(msg)

    # Warn about unused val_range
    if not isinstance(self.val_range, tuple):
        util.check_dict_contains(
            self.val_range, list(self.input_metadata.keys()) + [""], check_missing=False, dict_name="val_range"
        )

    return buffers
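
# Hedged usage sketch (assumption: the surrounding class is a Polygraphy-style
# DataLoader; `data_loader` is a placeholder). Because __getitem__ seeds a
# fresh np.random.RandomState with self.seed + index, the same index always
# reproduces identical buffers:
def _example_deterministic_data(data_loader):
    first = data_loader[0]
    again = data_loader[0]
    for name in first:
        assert (first[name] == again[name]).all()  # Reproducible per index.
    return first  # OrderedDict[str, numpy.ndarray]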
def test_is_valid_shape_override(case):
    override, shape, expected = case
    assert util.is_valid_shape_override(new_shape=override, original_shape=shape) == expected
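
# Hedged sketch (illustrative, not from the original source): one way the
# `case` parameter above might be supplied via pytest.mark.parametrize, as
# (override, original_shape, expected) tuples. Assumes -1 marks a dynamic
# dimension that any concrete value may override, and that `util` is imported
# as in the test above.
import pytest

@pytest.mark.parametrize(
    "example_case",
    [
        ((1, 3, 224, 224), (-1, 3, 224, 224), True),  # Dynamic batch dim may be overridden.
        ((2, 3), (1, 3), False),  # Static dims must match exactly.
        ((1, 3), (1, 3), True),  # Identical shapes are trivially valid.
    ],
)
def example_test_is_valid_shape_override_cases(example_case):
    override, shape, expected = example_case
    assert util.is_valid_shape_override(new_shape=override, original_shape=shape) == expected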