def call(*args, **kwargs):
    """Invoke `method` and append a record of the call to the trace.

    The optional keyword arguments `rtol` and `atol` are consumed here
    rather than forwarded to `method`; those that are not None are passed
    on to the recorded `ModuleCall` instead.

    Returns:
      Whatever `method` returns for the (numpy-converted) inputs.
    """
    # Strip any manually specified tolerances out of the kwargs, keeping
    # only the ones the user actually provided so that ModuleCall is not
    # handed explicit None values.
    tolerances = {}
    for tolerance_name in ("rtol", "atol"):
        tolerance = kwargs.pop(tolerance_name, None)
        if tolerance is not None:
            tolerances[tolerance_name] = tolerance

    # Ensure the inputs are numpy inputs.
    args = tf_utils.convert_to_numpy(args)
    kwargs = tf_utils.convert_to_numpy(kwargs)

    # Run the method, then capture the details of this call.
    outputs = method(*args, **kwargs)
    serialized_inputs, serialized_outputs = method.get_serialized_values()
    record = ModuleCall(method_name, args, outputs, serialized_inputs,
                        serialized_outputs, **tolerances)
    self._trace.calls.append(record)
    return outputs
def __call__(self, *args,
             **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
    """Run the wrapped TFLite function on the given inputs.

    Inputs are supplied either positionally (paired in order with the
    interpreter's input details) or by keyword (looked up by each input
    detail's "name"), but never both. A single positional list is treated
    as the sequence of positional inputs.

    Returns:
      A dict mapping output names to values when `self._output_names` is
      set. Otherwise the lone output value when there is exactly one
      output, or a tuple of all output values.

    Raises:
      ValueError: If both args and kwargs are passed, or if an output's
        name is not in `self._output_names`.
      KeyError: If inputs are passed by keyword and an input name is
        missing from kwargs.
    """
    if args and kwargs:
        raise ValueError(
            "Passing both args and kwargs is not supported by "
            "_TfLiteFunctionWrapper")

    # Specifically to get TFLite to work with keras models that take a list of
    # inputs instead of a sequence of args as their inputs, because it decides
    # to change the input signature but it still technically works if you
    # ignore that it does that.
    if len(args) == 1 and isinstance(args[0], list):
        args = args[0]

    # Resolve which value feeds which input tensor once, up front. Only the
    # "index" and "name" fields of the input details are used, so a single
    # query serves both the resize and the copy phases. A missing keyword
    # input therefore fails here, before the interpreter is modified.
    input_details = self._interpreter.get_input_details()
    if args:
        feeds = [(detail["index"], arg)
                 for arg, detail in zip(args, input_details)]
    else:
        feeds = [(detail["index"], kwargs[detail["name"]])
                 for detail in input_details]

    # Tell TFLite what the shapes of the input tensors are before allocation.
    for tensor_index, value in feeds:
        self._interpreter.resize_tensor_input(tensor_index, value.shape)

    # Allocate the (potentially dynamic) tensors.
    self._interpreter.allocate_tensors()

    # Copy the input data into the allocated tensors.
    for tensor_index, value in feeds:
        self._interpreter.set_tensor(tensor_index, value)

    # Execute the function.
    self._interpreter.invoke()

    # Extract the outputs from the TFLite interpreter.
    has_names = self._output_names is not None
    outputs = []
    for detail in self._interpreter.get_output_details():
        # Normalize for comparison with IREE.
        value = tf_utils.convert_to_numpy(
            self._interpreter.get_tensor(detail["index"]))
        if has_names:
            name = detail["name"]
            if name not in self._output_names:
                raise ValueError(
                    f"Expected '{name}' to be in {self._output_names}")
            outputs.append((name, value))
        else:
            outputs.append(value)

    # Process them to match the output of the tf.Module.
    if has_names:
        return dict(outputs)
    if len(outputs) == 1:
        return outputs[0]
    return tuple(outputs)
def __call__(self, *args, **kwargs):
    """Call the wrapped function and return its results as numpy values."""
    # The inputs are forwarded untouched: TensorFlow will auto-convert all
    # inbound args, so only the results need normalizing.
    return tf_utils.convert_to_numpy(self._f(*args, **kwargs))