コード例 #1
0
  def test_trace_serialize_and_load(self):
    """Round-trips a Trace through serialize() and load() and compares them."""

    def trace_function(module):
      module.increment()
      module.increment_by(np.array([81.], dtype=np.float32))
      module.increment_by_max(np.array([81], dtype=np.float32),
                              np.array([92], dtype=np.float32))
      module.get_count()

    module = tf_utils.IreeCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
    trace = tf_test_utils.Trace(module, trace_function)
    trace_function(tf_test_utils.TracedModule(module, trace))

    with tempfile.TemporaryDirectory() as artifacts_dir:
      trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
      trace.serialize(trace_function_dir)
      self.assertTrue(
          os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
      loaded_trace = tf_test_utils.Trace.load(trace_function_dir)

      # Check all calls match.
      self.assertTrue(tf_test_utils.Trace.compare_traces(trace, loaded_trace))

      # Check all other metadata match. Attribute names are compared
      # order-insensitively with assertCountEqual; assertAllEqual is a
      # numeric array assertion and is not meant for dict key views.
      original_attrs = vars(trace)
      loaded_attrs = vars(loaded_trace)
      self.assertCountEqual(original_attrs, loaded_attrs)
      for key, value in original_attrs.items():
        # 'calls' was already compared by compare_traces above.
        if key != 'calls':
          self.assertEqual(
              value, loaded_attrs[key],
              msg='Trace attribute {!r} differs after load.'.format(key))
コード例 #2
0
  def test_nonmatching_inputs(self):
    """Traces of the same method called with different inputs do not match."""

    def tf_function(module):
      # Increments by 42; the other backend's function uses 22 instead.
      module.increment_by(np.array([42.], dtype=np.float32))

    def vmla_function(module):
      module.increment_by(np.array([22.], dtype=np.float32))

    def _record(compiled_module, function):
      # Build a Trace for `function` and populate it by running the function
      # against a TracedModule wrapper.
      recorded = tf_test_utils.Trace(compiled_module, function)
      function(tf_test_utils.TracedModule(compiled_module, recorded))
      return recorded

    tf_trace = _record(
        tf_utils.TfCompiledModule.create_from_class(
            StatefulCountingModule, tf_utils.BackendInfo('tf')),
        tf_function)
    vmla_trace = _record(
        tf_utils.IreeCompiledModule.create_from_class(
            StatefulCountingModule, tf_utils.BackendInfo('iree_vmla')),
        vmla_function)

    traces_match = tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
    self.assertFalse(traces_match)
コード例 #3
0
    def test_nonmatching_methods(self):
        """Comparing traces whose call sequences differ by method raises."""

        def tf_function(module):
            module.increment()
            module.increment()

        def vmla_function(module):
            # The second call diverges: decrement() instead of increment().
            module.increment()
            module.decrement()

        recorded_traces = []
        for module_class, backend_name, function in (
                (tf_utils.TfCompiledModule, 'tf', tf_function),
                (tf_utils.IreeCompiledModule, 'iree_vmla', vmla_function),
        ):
            compiled = module_class(StatefulCountingModule,
                                    tf_utils.BackendInfo(backend_name))
            recorded = tf_test_utils.Trace(compiled, function)
            function(tf_test_utils.TracedModule(compiled, recorded))
            recorded_traces.append(recorded)
        tf_trace, vmla_trace = recorded_traces

        with self.assertRaises(ValueError):
            tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
コード例 #4
0
    def test_trace_inputs_and_outputs(self):
        """Trace stores call inputs/outputs as tuples, empty when absent."""
        def trace_function(module):
            # No inputs or outputs
            module.increment()
            # Only inputs
            module.increment_by(np.array([81.], dtype=np.float32))
            # Only outputs
            module.get_count()

        module = tf_utils.TfCompiledModule(StatefulCountingModule,
                                           tf_utils.BackendInfo('tf'))
        trace = tf_test_utils.Trace(module, trace_function)
        trace_function(tf_test_utils.TracedModule(module, trace))

        # One recorded call per method invocation in trace_function.
        self.assertLen(trace.calls, 3)

        # increment(): neither inputs nor outputs, both stored as tuples.
        self.assertIsInstance(trace.calls[0].inputs, tuple)
        self.assertEmpty(trace.calls[0].inputs)
        self.assertIsInstance(trace.calls[0].outputs, tuple)
        self.assertEmpty(trace.calls[0].outputs)

        # increment_by(81): only the input is recorded.
        self.assertAllClose(trace.calls[1].inputs[0], [81.])
        self.assertEmpty(trace.calls[1].outputs)

        # get_count(): only the output is recorded. 82 = 1 from increment()
        # plus 81 from increment_by(), assuming the counter starts at 0.
        self.assertEmpty(trace.calls[2].inputs)
        self.assertAllClose(trace.calls[2].outputs[0], [82.])