def test_load_v2(self):
    """`load.load` on a V2 SavedModel bumps the V2 load and read metrics.

    A model produced by `_create_save_v2_model` is loaded with `load.load`.
    Afterwards the `load._LOAD_V2_LABEL` read-API counter and the
    `write_version="2"` read counter must each be exactly one higher than
    before the load.
    """
    # Snapshot the counters so only this test's effect is measured.
    v2_api_before = metrics.GetReadApi(load._LOAD_V2_LABEL)
    v2_reads_before = metrics.GetRead(write_version="2")

    load.load(self._create_save_v2_model())

    v2_api_after = metrics.GetReadApi(load._LOAD_V2_LABEL)
    self.assertEqual(v2_api_after, v2_api_before + 1)
    v2_reads_after = metrics.GetRead(write_version="2")
    self.assertEqual(v2_reads_after, v2_reads_before + 1)
def test_loader_v1(self):
    """`SavedModelLoader.load` bumps the v1 loader and read metrics.

    A model from `_create_save_v1_model` is loaded through
    `loader_impl.SavedModelLoader` inside a graph-mode session. Afterwards
    the `loader_impl._LOADER_LABEL` read-API counter and the
    `write_version="1"` read counter must each have grown by exactly one.
    """
    # Snapshot both counters first so the assertions measure this test's
    # effect only, regardless of what earlier tests already recorded.
    read_count = metrics.GetRead(write_version="1")
    loader_count = metrics.GetReadApi(loader_impl._LOADER_LABEL)

    # The loader is driven through an explicit graph/session, so eager
    # execution is switched off while the model is saved and loaded.
    ops.disable_eager_execution()
    try:
        save_dir = self._create_save_v1_model()
        loader = loader_impl.SavedModelLoader(save_dir)
        with self.session(graph=ops.Graph()) as sess:
            loader.load(sess, ["foo"])
    finally:
        # Restore eager mode even if saving/loading raises, so graph mode
        # does not leak into subsequent tests.
        ops.enable_eager_execution()

    self.assertEqual(
        metrics.GetReadApi(loader_impl._LOADER_LABEL), loader_count + 1)
    self.assertEqual(metrics.GetRead(write_version="1"), read_count + 1)
def test_load_v1_in_v2(self):
    """`load.load` on a V1 SavedModel goes through the V1-in-V2 metrics path.

    After loading a model from `_create_save_v1_model` with `load.load`:
    the `load_v1_in_v2._LOAD_V1_V2_LABEL` read-API counter and the
    `write_version="1"` read counter must each rise by one. The
    `load._LOAD_V2_LABEL` read-API counter and the `write_version="2"` read
    counter must stay unchanged.
    """
    (v1_reads_before, v2_reads_before,
     v2_api_before, v1_in_v2_api_before) = (
        metrics.GetRead(write_version="1"),
        metrics.GetRead(write_version="2"),
        metrics.GetReadApi(load._LOAD_V2_LABEL),
        metrics.GetReadApi(load_v1_in_v2._LOAD_V1_V2_LABEL),
    )

    load.load(self._create_save_v1_model())

    # Neither V2 counter may move: `load_v2` must *not* be incremented.
    self.assertEqual(metrics.GetReadApi(load._LOAD_V2_LABEL), v2_api_before)
    self.assertEqual(metrics.GetRead(write_version="2"), v2_reads_before)
    # The V1-in-V2 and V1-read counters each advance by exactly one.
    self.assertEqual(
        metrics.GetReadApi(load_v1_in_v2._LOAD_V1_V2_LABEL),
        v1_in_v2_api_before + 1)
    self.assertEqual(metrics.GetRead(write_version="1"), v1_reads_before + 1)
def test_SM_increment_read(self):
    """`IncrementReadApi` / `IncrementRead` each bump their counter by one.

    Baselines are read before incrementing instead of assuming the counters
    start at zero. Other tests such as `test_load_v2` also increment the
    `write_version="2"` read counter, so an absolute check against 0 only
    passes when this test happens to run first (unittest sorts method names,
    and uppercase `SM` sorts before lowercase names).
    """
    read_v2_count = metrics.GetRead(write_version="2")
    bar_count = metrics.GetReadApi("bar")

    metrics.IncrementReadApi("bar")
    self.assertEqual(metrics.GetReadApi("bar"), bar_count + 1)

    metrics.IncrementRead(write_version="2")
    self.assertEqual(metrics.GetRead(write_version="2"), read_v2_count + 1)