Пример #1
0
    def test_trace_serialize_and_load(self):
        """Round-trips an IREE VMLA trace through serialize() and Trace.load().

        Records a trace of several StatefulCountingModule calls, serializes it
        into a temporary artifacts directory (checking that a 'flagfile' is
        written), loads it back, and checks that the recorded calls and every
        other trace attribute match the original.
        """

        def trace_function(module):
            # Mix calls with no arguments, one array argument, and two array
            # arguments so the serialized trace covers several call shapes.
            module.increment()
            module.increment_by(np.array([81.], dtype=np.float32))
            module.increment_by_max(np.array([81], dtype=np.float32),
                                    np.array([92], dtype=np.float32))
            module.get_count()

        module = tf_utils.IreeCompiledModule(StatefulCountingModule,
                                             tf_utils.BackendInfo('iree_vmla'))
        trace = tf_test_utils.Trace(module, trace_function)
        # Calls go through a TracedModule wrapper; presumably this is what
        # records them into `trace` (its 'calls' attribute is skipped below).
        trace_function(tf_test_utils.TracedModule(module, trace))

        with tempfile.TemporaryDirectory() as artifacts_dir:
            trace_function_dir = tf_test_utils._get_trace_dir(
                artifacts_dir, trace)
            trace.serialize(trace_function_dir)
            self.assertTrue(
                os.path.exists(os.path.join(trace_function_dir, 'flagfile')))
            loaded_trace = tf_test_utils.Trace.load(trace_function_dir)

            # Check all calls match.
            self.assertTrue(
                tf_test_utils.Trace.compare_traces(trace, loaded_trace))

            # Check all other metadata match. Attribute names are compared as
            # sorted lists with assertEqual rather than a numpy array assertion
            # on dict_keys views, so a mismatch reports the differing names.
            original_attrs = vars(trace)
            loaded_attrs = vars(loaded_trace)
            self.assertEqual(sorted(original_attrs), sorted(loaded_attrs))
            for key, value in original_attrs.items():
                if key == 'calls':
                    # Calls were already checked by compare_traces above.
                    continue
                self.assertEqual(value, loaded_attrs[key])
Пример #2
0
    def test_nonmatching_inputs(self):
        """Traces of the same method called with different inputs must not match.

        The TF backend calls `increment_by` with 42 and the IREE VMLA backend
        calls it with 22; compare_traces is expected to return a falsy result.
        """

        def run_traced(compiled_module, function):
            # Build a Trace for `function`, then drive it through a
            # TracedModule wrapper around `compiled_module`.
            recorded = tf_test_utils.Trace(compiled_module, function)
            function(tf_test_utils.TracedModule(compiled_module, recorded))
            return recorded

        def tf_function(module):
            module.increment_by(np.array([42.0], dtype=np.float32))

        def vmla_function(module):
            module.increment_by(np.array([22.0], dtype=np.float32))

        reference_trace = run_traced(
            tf_utils.TfCompiledModule(StatefulCountingModule,
                                      tf_utils.BackendInfo('tf')),
            tf_function)
        vmla_trace = run_traced(
            tf_utils.IreeCompiledModule(StatefulCountingModule,
                                        tf_utils.BackendInfo('iree_vmla')),
            vmla_function)

        traces_match = tf_test_utils.Trace.compare_traces(reference_trace,
                                                          vmla_trace)
        self.assertFalse(traces_match)
Пример #3
0
    def test_nonmatching_methods(self):
        """compare_traces must raise when the traces call different methods.

        Both traces make two calls, but the second one differs: `increment`
        in the TF trace versus `decrement` in the IREE VMLA trace. Comparing
        them is expected to raise ValueError.
        """

        def tf_function(module):
            module.increment()
            module.increment()

        def vmla_function(module):
            module.increment()
            module.decrement()

        # (compiled-module class, backend name, traced function), in the
        # order the modules are compiled and traced.
        backends = (
            (tf_utils.TfCompiledModule, 'tf', tf_function),
            (tf_utils.IreeCompiledModule, 'iree_vmla', vmla_function),
        )

        traces = []
        for compiled_module_cls, backend_name, function in backends:
            compiled = compiled_module_cls(StatefulCountingModule,
                                           tf_utils.BackendInfo(backend_name))
            trace = tf_test_utils.Trace(compiled, function)
            function(tf_test_utils.TracedModule(compiled, trace))
            traces.append(trace)

        tf_trace, vmla_trace = traces
        self.assertRaises(ValueError, tf_test_utils.Trace.compare_traces,
                          tf_trace, vmla_trace)