def __init__(self, hidden_1_dim=256, hidden_2_dim=256, input_dim=28 * 28, classes=10):
  """Builds a two-hidden-layer MLP with randomly initialized parameters.

  Args:
    hidden_1_dim: width of the first hidden layer.
    hidden_2_dim: width of the second hidden layer.
    input_dim: flattened input size (defaults to a 28x28 image).
    classes: number of output classes.
  """
  super().__init__()
  tf_utils.set_random_seed()

  self.hidden_1_dim = hidden_1_dim
  self.hidden_2_dim = hidden_2_dim
  self.input_dim = input_dim
  self.classes = classes

  def normal_variable(shape):
    # All draws happen after the seed reset above, so initialization is
    # reproducible across instantiations.
    return tf.Variable(tf.random.normal(shape))

  # NOTE: creation order is significant — it fixes the RNG draw sequence.
  self.h1_weights = normal_variable([input_dim, hidden_1_dim])
  self.h2_weights = normal_variable([hidden_1_dim, hidden_2_dim])
  self.out_weights = normal_variable([hidden_2_dim, classes])
  self.h1_bias = normal_variable([hidden_1_dim])
  self.h2_bias = normal_variable([hidden_2_dim])
  self.out_bias = normal_variable([classes])

  # Compile with dynamic batch dim.
  self.predict = tf.function(
      input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
def __init__(self):
  """Loads the flag-specified TF Hub model and compiles its call method."""
  super().__init__()
  tf_utils.set_random_seed()

  # Resolve the hub model location from the command-line flags.
  path = posixpath.join(FLAGS.tf_hub_url, FLAGS.model, MODE)
  layer = hub.KerasLayer(path)
  self.m = tf.keras.Sequential([layer])

  # Build the model with a concrete input shape, then wrap its call in a
  # tf.function with that same fixed signature.
  shape = get_input_shape()
  self.m.build(shape)
  self.predict = tf.function(input_signature=[tf.TensorSpec(shape)])(
      self.m.call)
def create_from_class(cls,
                      module_class: Type[tf.Module],
                      backend_info: "BackendInfo",
                      exported_names: Sequence[str] = (),
                      artifacts_dir: str = None):
  """Compile a tf.Module subclass to the target backend in backend_info.

  Instantiates module_class (after seeding the RNGs for reproducibility)
  and delegates to create_from_instance.

  Args:
    module_class: The tf.Module subclass to compile.
    backend_info: BackendInfo with the details for compiling module to IREE.
    exported_names: Optional sequence representing the exported names to keep.
    artifacts_dir: An optional string pointing to where compilation artifacts
      should be saved. No compilation artifacts will be saved if this is not
      provided.
  """
  tf_utils.set_random_seed()
  instance = module_class()
  return cls.create_from_instance(instance, backend_info, exported_names,
                                  artifacts_dir)
def create_from_class(cls,
                      module_class: Type[tf.Module],
                      backend_info: "BackendInfo",
                      exported_names: Sequence[str] = (),
                      artifacts_dir: str = None):
  """Compile a tf.Module subclass to the target backend in backend_info.

  Converts module_class to TFLite flatbuffer bytes, builds interpreters for
  each exported function, and wraps them in a new instance of cls.

  Args:
    module_class: The tf.Module subclass to compile.
    backend_info: BackendInfo with the details for compiling this module.
    exported_names: Optional sequence representing the exported names to keep.
    artifacts_dir: An optional string pointing to where compilation artifacts
      should be saved. No compilation artifacts will be saved if this is not
      provided.
  """
  tf_utils.set_random_seed()
  flatbuffer_bytes = tf_module_to_tflite_module_bytes(module_class,
                                                      exported_names)
  interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
      flatbuffer_bytes, artifacts_dir)
  return cls(module_class.__name__, backend_info, compiled_paths, interpreters)
def reinitialize(self):
  """Rebuilds the wrapped tf.Module, reinitializing all stateful variables.

  The RNG seeds are reset first so the fresh module's variables are
  reproducible.
  """
  tf_utils.set_random_seed()
  self._tf_module = self._constructor()
def compare_backends(self, trace_function: Callable[[trace_utils.TracedModule],
                                                    None],
                     modules: Modules) -> None:
  """Run the reference and target backends on trace_function and compare them.

  Random seeds for tensorflow, numpy and python are set before each invocation
  of trace_function.

  Args:
    trace_function: a function accepting a TracedModule as its argument.
    modules: Modules container providing `ref_module`, `tar_modules` and
      `artifacts_dir` for the backends under comparison.
  """
  # Create Traces for each backend.
  ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
  tar_traces = [
      trace_utils.Trace(module, trace_function)
      for module in modules.tar_modules
  ]

  # Run the traces through trace_function with their associated modules.
  # The seed is reset before every run so each backend sees identical
  # (pseudo-)random inputs.
  tf_utils.set_random_seed()
  trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace))
  if FLAGS.log_all_traces:
    logging.info(ref_trace)
  for module, trace in zip(modules.tar_modules, tar_traces):
    tf_utils.set_random_seed()
    trace_function(trace_utils.TracedModule(module, trace))
    if FLAGS.log_all_traces:
      logging.info(trace)

  # Compare each target trace of trace_function with the reference trace.
  failed_backend_indices = []
  error_messages = []
  for i, tar_trace in enumerate(tar_traces):
    logging.info("Comparing the reference backend '%s' with '%s'",
                 ref_trace.backend_id, tar_trace.backend_id)
    traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
    if not traces_match:
      failed_backend_indices.append(i)
      error_messages.extend(errors)

  # Save the results to disk before validating, so artifacts exist even when
  # the comparison below fails the test.
  ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace)
  ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
  ref_trace.serialize(ref_trace_dir)
  for tar_trace in tar_traces:
    tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, tar_trace)
    tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
    tar_trace.serialize(tar_trace_dir)

  # Validate results.
  if failed_backend_indices:
    # Extract info for logging.
    failed_backends = [
        tar_traces[i].backend_id for i in failed_backend_indices
    ]
    error_list = ''.join([f'\n - {message}' for message in error_messages])
    self.fail(
        "Comparison between the reference backend and the following targets "
        f"failed: {failed_backends}. Errors: {error_list}\n"
        "See the logs above for more details about the non-matching calls.")