def coerce_cached_input(index, name, dtype, shape):
    cached_feed_dict = self.cache[iteration]
    cached_name = misc.find_in_dict(name, cached_feed_dict, index)
    assert cached_name is not None

    if cached_name != name:
        G_LOGGER.warning("Input tensor: {:24} | Cached buffer name ({:}) does not match input name ({:}).".format(
            name, cached_name, name))

    buffer = cached_feed_dict[cached_name]

    if dtype != buffer.dtype:
        G_LOGGER.warning("Input tensor: {:24} | Cached buffer dtype ({:}) does not match input dtype ({:}), attempting cast.".format(
            name, buffer.dtype, np.dtype(dtype).name))
        buffer = buffer.astype(dtype)

    if not misc.is_valid_shape_override(buffer.shape, shape):
        G_LOGGER.warning("Input tensor: {:24} | Cached buffer shape ({:}) does not match input shape ({:}), attempting reshape.".format(
            name, buffer.shape, shape))
        buffer = misc.try_match_shape(buffer, shape)

    assert buffer.dtype == dtype and misc.is_valid_shape_override(buffer.shape, shape)
    return buffer
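# A minimal standalone sketch of the coercion flow above: cast the cached
# buffer when dtypes differ, then reshape when shapes differ. The names here
# (`cached`, `target_dtype`, `target_shape`) are illustrative assumptions, and
# the plain `reshape` stands in for misc.try_match_shape.
import numpy as np

cached = np.zeros((12,), dtype=np.float64)
target_dtype, target_shape = np.float32, (3, 4)

if cached.dtype != target_dtype:
    cached = cached.astype(target_dtype)    # mirrors the dtype cast above
if cached.shape != target_shape:
    cached = cached.reshape(target_shape)   # simplified stand-in for misc.try_match_shape
assert cached.dtype == target_dtype and cached.shape == target_shape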
def get_static_shape(name, shape):
    static_shape = shape
    if misc.is_shape_dynamic(shape):
        static_shape = misc.override_dynamic_shape(shape)
        if static_shape != shape and name not in self.user_input_metadata:
            if not misc.is_valid_shape_override(static_shape, shape):
                G_LOGGER.critical("Input tensor: {:24} | Cannot override original shape: {:} to {:}".format(
                    name, shape, static_shape))
            G_LOGGER.warning(
                "Input tensor: {:24} | Adjusted shape: {:} to: {:}. If this is incorrect, please set input_metadata "
                "or provide a custom data loader.".format(name, shape, static_shape),
                mode=LogMode.ONCE)
    return static_shape
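# A hedged sketch of the behavior get_static_shape relies on: replacing
# dynamic dimensions (None or negative) with a fixed default so concrete data
# can be generated. `_override_dynamic_shape_sketch` is an assumed stand-in
# for misc.override_dynamic_shape, not the real implementation.
def _override_dynamic_shape_sketch(shape, default_dim=1):
    return tuple(default_dim if dim is None or dim < 0 else dim for dim in shape)

assert _override_dynamic_shape_sketch((-1, 3, None, 224)) == (1, 3, 1, 224)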
def __getitem__(self, index):
    """
    Randomly generates input data.

    Args:
        index (int):
                Since this class behaves like an iterable, it takes an index parameter.
                Generated data is guaranteed to be the same for the same index.

    Returns:
        OrderedDict[str, numpy.ndarray]: A mapping of input names to input numpy buffers.
    """
    if index >= self.iterations:
        raise IndexError()

    G_LOGGER.verbose("Generating data using numpy seed: {:}".format(self.seed + index))
    rng = np.random.RandomState(self.seed + index)

    def get_static_shape(name, shape):
        static_shape = shape
        if misc.is_shape_dynamic(shape):
            static_shape = misc.override_dynamic_shape(shape)
            if static_shape != shape and name not in self.user_input_metadata:
                if not misc.is_valid_shape_override(static_shape, shape):
                    G_LOGGER.critical("Input tensor: {:24} | Cannot override original shape: {:} to {:}".format(
                        name, shape, static_shape))
                G_LOGGER.warning(
                    "Input tensor: {:24} | Will generate data of shape: {:} (tensor shape is: {:}).\n"
                    "If this is incorrect, please set input_metadata "
                    "or provide a custom data loader.".format(name, static_shape, shape),
                    mode=LogMode.ONCE)
        return static_shape

    # Whether the user provided values for a shape tensor input, rather than the shape of the input.
    # If the input's shape is 1D and its single dimension equals the rank of the user-provided shape,
    # the input is likely a shape tensor, and so its value, not its shape, should be overridden.
    def is_shape_tensor(name, dtype):
        if name not in self.input_metadata or name not in self.user_input_metadata:
            return False

        _, shape = self.input_metadata[name]
        is_shape = np.issubdtype(dtype, np.integer) and (not misc.is_shape_dynamic(shape)) and (len(shape) == 1)

        user_shape = self.user_input_metadata[name][1]
        is_shape &= len(user_shape) == shape[0]
        # Can't have negative values in shapes.
        is_shape &= all([elem >= 0 for elem in user_shape])
        return is_shape

    def generate_buffer(name, dtype, shape):
        if is_shape_tensor(name, dtype):
            buffer = np.array(shape, dtype=dtype)
            G_LOGGER.info(
                "Assuming {:} is a shape tensor. Setting input values to: {:}. If this is not correct, "
                "please set it correctly in 'input_metadata' or by providing --input-shapes".format(name, buffer),
                mode=LogMode.ONCE)
        elif np.issubdtype(dtype, np.integer):
            # `high` is one greater than the maximum int drawn.
            buffer = rng.randint(low=self.int_range[0], high=self.int_range[1] + 1, size=shape, dtype=dtype)
        elif np.issubdtype(dtype, np.bool_):
            buffer = rng.randint(low=0, high=2, size=shape).astype(dtype)
        else:
            buffer = (rng.random_sample(size=shape) * (self.float_range[1] - self.float_range[0])
                      + self.float_range[0]).astype(dtype)

        # Wrap in np.array to handle scalars, since the functions above return a float if shape is ().
        buffer = np.array(buffer)
        return buffer

    if self.input_metadata is None and self.user_input_metadata is not None:
        self.input_metadata = self.user_input_metadata

    buffers = OrderedDict()
    for name, (dtype, shape) in self.input_metadata.items():
        if name in self.user_input_metadata:
            user_dtype, user_shape = self.user_input_metadata[name]
            dtype = misc.default_value(user_dtype, dtype)

            is_valid_shape_override = user_shape is not None and misc.is_valid_shape_override(user_shape, shape)
            if not is_valid_shape_override and not is_shape_tensor(name, dtype):
                G_LOGGER.warning(
                    "Input tensor: {:24} | Cannot use provided custom shape: {:}, since this input has "
                    "a static shape: {:}".format(name, user_shape, shape),
                    mode=LogMode.ONCE)
            else:
                shape = misc.default_value(user_shape, shape)

        static_shape = get_static_shape(name, shape)
        buffers[name] = generate_buffer(name, dtype, shape=static_shape)

    # Warn about unused metadata
    for name in self.user_input_metadata.keys():
        if name not in self.input_metadata:
            msg = "Input tensor: {:24} | Metadata was provided, but the input does not exist in one or more runners.".format(name)
            close_match = misc.find_in_dict(name, self.input_metadata)
            if close_match:
                msg += "\nMaybe you meant to set: {:}".format(close_match)
            G_LOGGER.warning(msg)

    return buffers
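# A small runnable sketch of the reproducibility guarantee documented in
# __getitem__: seeding with (seed + index) always yields the same buffer for
# the same index. This uses RandomState directly, just as generate_buffer
# does above; the shapes and values are illustrative.
import numpy as np

seed, index = 1, 0
buf_a = np.random.RandomState(seed + index).random_sample(size=(2, 3)).astype(np.float32)
buf_b = np.random.RandomState(seed + index).random_sample(size=(2, 3)).astype(np.float32)
assert np.array_equal(buf_a, buf_b)  # identical data for identical index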
def test_is_valid_shape_override(case):
    override, shape, expected = case
    assert misc.is_valid_shape_override(new_shape=override, original_shape=shape) == expected
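# A sketch of how this test might be driven with pytest parametrization,
# assuming `misc` comes from the surrounding test module's imports. The cases
# below are illustrative assumptions, not taken from the real test suite: an
# exact match and a dynamic (-1) dimension are valid overrides, while a
# mismatched static dimension is not.
import pytest

@pytest.mark.parametrize("case", [
    ((1, 3, 224, 224), (1, 3, 224, 224), True),   # exact match
    ((1, 3, 224, 224), (-1, 3, 224, 224), True),  # overrides a dynamic dim
    ((2, 3), (1, 3), False),                      # static dims conflict
])
def test_is_valid_shape_override_cases(case):
    override, shape, expected = case
    assert misc.is_valid_shape_override(new_shape=override, original_shape=shape) == expected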