def test_trace_serialize_and_load(self):
    """Round-trip a recorded trace through serialize/load and compare it.

    Records four calls on a StatefulCountingModule built through
    IreeCompiledModule with the 'iree_vmla' backend info, writes the trace
    into a temporary directory, loads it back, and checks that both the
    calls and every other attribute of the trace survive the round trip.
    """

    # The function object itself is handed to Trace, so its name is left
    # untouched.
    def trace_function(module):
        module.increment()
        module.increment_by(np.array([81.], dtype=np.float32))
        module.increment_by_max(np.array([81], dtype=np.float32),
                                np.array([92], dtype=np.float32))
        module.get_count()

    backend = tf_utils.BackendInfo('iree_vmla')
    compiled = tf_utils.IreeCompiledModule.create_from_class(
        StatefulCountingModule, backend)
    original = tf_test_utils.Trace(compiled, trace_function)
    trace_function(tf_test_utils.TracedModule(compiled, original))

    with tempfile.TemporaryDirectory() as artifacts_dir:
        trace_dir = tf_test_utils._get_trace_dir(artifacts_dir, original)
        original.serialize(trace_dir)

        # Serialization is expected to leave a metadata pickle behind.
        metadata_path = os.path.join(trace_dir, 'metadata.pkl')
        self.assertTrue(os.path.exists(metadata_path))

        restored = tf_test_utils.Trace.load(trace_dir)

        # The recorded calls must compare as matching.
        self.assertTrue(
            tf_test_utils.Trace.compare_traces(original, restored))

        # Every remaining attribute must be present and equal; 'calls' is
        # skipped here because compare_traces already covered it.
        original_attrs = original.__dict__
        restored_attrs = restored.__dict__
        self.assertAllEqual(original_attrs.keys(), restored_attrs.keys())
        for key, value in original_attrs.items():
            if key == 'calls':
                continue
            self.assertEqual(value, restored_attrs[key])
def test_nonmatching_inputs(self):
    """Check that traces with differing call inputs do not compare equal.

    Both traces call ``increment_by`` once, but with different values
    (42 vs. 22), so ``Trace.compare_traces`` is expected to return a
    falsy result.
    """

    # The function objects are handed to Trace, so their names are kept.
    def tf_function(module):
        module.increment_by(np.array([42.], dtype=np.float32))

    def vmla_function(module):
        module.increment_by(np.array([22.], dtype=np.float32))

    # Record the TensorFlow-backed trace first.
    tf_backend = tf_utils.BackendInfo('tf')
    reference_module = tf_utils.TfCompiledModule.create_from_class(
        StatefulCountingModule, tf_backend)
    reference_trace = tf_test_utils.Trace(reference_module, tf_function)
    tf_function(tf_test_utils.TracedModule(reference_module, reference_trace))

    # Then record the 'iree_vmla'-backed trace with the other input.
    vmla_backend = tf_utils.BackendInfo('iree_vmla')
    target_module = tf_utils.IreeCompiledModule.create_from_class(
        StatefulCountingModule, vmla_backend)
    target_trace = tf_test_utils.Trace(target_module, vmla_function)
    vmla_function(tf_test_utils.TracedModule(target_module, target_trace))

    traces_match = tf_test_utils.Trace.compare_traces(reference_trace,
                                                      target_trace)
    self.assertFalse(traces_match)
def test_nonmatching_methods(self):
    """Check that comparing traces with different method calls raises.

    The two traces agree on their first call (``increment``) but differ on
    the second (``increment`` vs. ``decrement``); ``Trace.compare_traces``
    is expected to raise ``ValueError`` for such a pair.
    """

    def tf_function(module):
        module.increment()
        module.increment()

    def vmla_function(module):
        module.increment()
        module.decrement()

    # Build the modules through the ``create_from_class`` factory, which
    # takes the same (module class, backend info) arguments, rather than
    # calling the constructors directly with a module class.
    tf_module = tf_utils.TfCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('tf'))
    tf_trace = tf_test_utils.Trace(tf_module, tf_function)
    tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))

    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
    vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
    vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))

    with self.assertRaises(ValueError):
        tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
def test_trace_inputs_and_outputs(self):
    """Check how a trace records the inputs and outputs of each call.

    Covers three shapes of call: one with neither inputs nor outputs, one
    with only inputs, and one with only outputs.
    """

    def trace_function(module):
        # No inputs or outputs
        module.increment()
        # Only inputs
        module.increment_by(np.array([81.], dtype=np.float32))
        # Only outputs
        module.get_count()

    # Build the module through the ``create_from_class`` factory, which
    # takes the same (module class, backend info) arguments, rather than
    # calling the constructor directly with a module class.
    module = tf_utils.TfCompiledModule.create_from_class(
        StatefulCountingModule, tf_utils.BackendInfo('tf'))
    trace = tf_test_utils.Trace(module, trace_function)
    trace_function(tf_test_utils.TracedModule(module, trace))

    # A call without inputs/outputs is recorded with empty tuples.
    first_call = trace.calls[0]
    self.assertIsInstance(first_call.inputs, tuple)
    self.assertEmpty(first_call.inputs)
    self.assertIsInstance(first_call.outputs, tuple)
    self.assertEmpty(first_call.outputs)

    # The single input of the second call is the array that was passed in.
    self.assertAllClose(trace.calls[1].inputs[0], [81.])
    # 82 = 1 (increment) + 81 (increment_by); presumably the counter
    # starts at 0 -- confirm against StatefulCountingModule.
    self.assertAllClose(trace.calls[2].outputs[0], [82.])