コード例 #1
0
    def __init__(self,
                 hidden_1_dim=256,
                 hidden_2_dim=256,
                 input_dim=28 * 28,
                 classes=10):
        """Create the variables for two hidden layers and an output layer.

        All weights and biases are initialized from tf.random.normal.

        Args:
          hidden_1_dim: Width of the first hidden layer.
          hidden_2_dim: Width of the second hidden layer.
          input_dim: Number of features in each input row.
          classes: Width of the output layer.
        """
        super().__init__()
        # Reset the seed before any random draw (presumably so that the
        # initial values are repeatable -- confirm in tf_utils).
        tf_utils.set_random_seed()

        self.hidden_1_dim = hidden_1_dim
        self.hidden_2_dim = hidden_2_dim
        self.input_dim = input_dim
        self.classes = classes

        # NOTE: keep this creation order (all weights, then all biases);
        # reordering the random draws may change the initial values that a
        # fixed seed produces.
        h1_shape = [input_dim, hidden_1_dim]
        h2_shape = [hidden_1_dim, hidden_2_dim]
        out_shape = [hidden_2_dim, classes]
        self.h1_weights = tf.Variable(tf.random.normal(h1_shape))
        self.h2_weights = tf.Variable(tf.random.normal(h2_shape))
        self.out_weights = tf.Variable(tf.random.normal(out_shape))
        self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
        self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
        self.out_bias = tf.Variable(tf.random.normal([classes]))

        # Wrap predict in a tf.function whose leading (batch) dimension is
        # left as None, i.e. dynamic.
        batched_input = tf.TensorSpec([None, self.input_dim])
        as_tf_function = tf.function(input_signature=[batched_input])
        self.predict = as_tf_function(self.predict)
コード例 #2
0
 def __init__(self):
     """Wrap a TF-Hub model in a Keras Sequential and expose `predict`.

     The model location is built from FLAGS.tf_hub_url, FLAGS.model and
     MODE; the input shape comes from get_input_shape().
     """
     super().__init__()
     tf_utils.set_random_seed()

     hub_model_url = posixpath.join(FLAGS.tf_hub_url, FLAGS.model, MODE)
     self.m = tf.keras.Sequential([hub.KerasLayer(hub_model_url)])

     # Build the model for the same shape that predict is specialized on.
     shape = get_input_shape()
     self.m.build(shape)

     signature = [tf.TensorSpec(shape)]
     self.predict = tf.function(input_signature=signature)(self.m.call)
コード例 #3
0
    def create_from_class(cls,
                          module_class: Type[tf.Module],
                          backend_info: "BackendInfo",
                          exported_names: Sequence[str] = (),
                          artifacts_dir: str = None):
        """Instantiate module_class and compile it for backend_info's target.

        Args:
          module_class: The tf.Module subclass to compile.
          backend_info: BackendInfo with the details for compiling the module
            to IREE.
          exported_names: Optional sequence representing the exported names
            to keep.
          artifacts_dir: An optional string pointing to where compilation
            artifacts should be saved. No compilation artifacts will be saved
            if this is not provided.

        Returns:
          The result of cls.create_from_instance for the new instance.
        """
        # Reset the seed right before construction, since the module's
        # constructor may draw random initial values.
        tf_utils.set_random_seed()
        instance = module_class()
        return cls.create_from_instance(instance, backend_info,
                                        exported_names, artifacts_dir)
コード例 #4
0
    def create_from_class(cls,
                          module_class: Type[tf.Module],
                          backend_info: "BackendInfo",
                          exported_names: Sequence[str] = (),
                          artifacts_dir: str = None):
        """Convert a tf.Module subclass to TFLite and wrap its interpreters.

        Args:
          module_class: The tf.Module subclass to compile.
          backend_info: BackendInfo with the details for compiling this
            module.
          exported_names: Optional sequence representing the exported names
            to keep.
          artifacts_dir: An optional string pointing to where compilation
            artifacts should be saved. No compilation artifacts will be saved
            if this is not provided.

        Returns:
          A new cls instance built from the module class name, backend_info,
          the compiled paths and the TFLite interpreters.
        """
        tf_utils.set_random_seed()

        # Step 1: module class -> serialized TFLite module bytes.
        module_bytes = tf_module_to_tflite_module_bytes(module_class,
                                                        exported_names)
        # Step 2: bytes -> interpreters (plus the paths of what was compiled).
        interpreters, compiled_paths = (
            tflite_module_bytes_to_tflite_interpreters(module_bytes,
                                                       artifacts_dir))

        return cls(module_class.__name__, backend_info, compiled_paths,
                   interpreters)
コード例 #5
0
 def reinitialize(self):
     """Reset the random seed and rebuild the wrapped tf module.

     Replacing self._tf_module with a freshly constructed instance is what
     reinitializes all stateful variables.
     """
     tf_utils.set_random_seed()
     fresh_module = self._constructor()
     self._tf_module = fresh_module
コード例 #6
0
  def compare_backends(self,
                       trace_function: Callable[[trace_utils.TracedModule],
                                                None],
                       modules: Modules) -> None:
    """Trace every backend with trace_function and check that they agree.

    The random seed is reset (tf_utils.set_random_seed) immediately before
    each invocation of trace_function. Every trace is written to disk before
    the comparison results are validated, so traces are saved even when the
    comparison fails.

    Args:
      trace_function: a function accepting a TracedModule as its argument.
      modules: provides the reference module (ref_module), the target modules
        (tar_modules) and the directory under which traces are saved
        (artifacts_dir).
    """

    def run_traced(module, trace):
      # Reset the seed before every single invocation of trace_function.
      tf_utils.set_random_seed()
      trace_function(trace_utils.TracedModule(module, trace))
      if FLAGS.log_all_traces:
        logging.info(trace)

    # One Trace per backend: the reference first, then each target.
    ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
    tar_traces = [
        trace_utils.Trace(tar_module, trace_function)
        for tar_module in modules.tar_modules
    ]

    # Record the calls made by trace_function on each backend.
    run_traced(modules.ref_module, ref_trace)
    for tar_module, tar_trace in zip(modules.tar_modules, tar_traces):
      run_traced(tar_module, tar_trace)

    # Compare each target trace against the reference, remembering which
    # targets mismatched and why.
    mismatched_traces = []
    error_messages = []
    for tar_trace in tar_traces:
      logging.info("Comparing the reference backend '%s' with '%s'",
                   ref_trace.backend_id, tar_trace.backend_id)
      traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
      if traces_match:
        continue
      mismatched_traces.append(tar_trace)
      error_messages.extend(errors)

    # Persist all traces (reference first) before validating, so they remain
    # available on disk when the check below fails.
    for trace in [ref_trace, *tar_traces]:
      trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, trace)
      trace.save_plaintext(trace_dir, FLAGS.summarize)
      trace.serialize(trace_dir)

    if not mismatched_traces:
      return

    # Report every failing backend together with the collected errors.
    failed_backends = [trace.backend_id for trace in mismatched_traces]
    error_list = ''.join(f'\n  - {message}' for message in error_messages)
    self.fail(
        "Comparison between the reference backend and the following targets "
        f"failed: {failed_backends}. Errors: {error_list}\n"
        "See the logs above for more details about the non-matching calls.")