def tearDown(self):
    """Reset backend and dxl.learn global state after a test has run."""
    # Imports are kept local to the method, as in the original code.
    from dxl.learn.backend import current_backend, TensorFlow
    from dxl.learn.core.config import clear_config
    from dxl.learn.core import SubgraphMakerFactory, SubgraphMakerTable

    backend = current_backend()
    if isinstance(backend, TensorFlow):
        # Only the TensorFlow backend has a default graph to reset.
        backend.unbox().reset_default_graph()

    # Clear the configuration first, then both subgraph-maker registries.
    clear_config()
    for registry in (SubgraphMakerFactory, SubgraphMakerTable):
        registry.reset()
def kernel(self, inputs=None):
    """Create the learning-rate variable, its decay op and the optimizer.

    Everything is stored in ``self.tensors``. ``inputs`` is accepted but
    not used by this implementation.
    """
    tensor_keys = self.KEYS.TENSOR
    config_keys = self.KEYS.CONFIG
    graph_keys = self.KEYS.GRAPH

    # Scalar (shape ``[]``) float32 variable. Note that its initial value
    # is looked up in the config under the TENSOR key LEARNING_RATE.
    lr_info = self.info.child_tensor(tensor_keys.LEARNING_RATE)
    lr_dtype = current_backend().float32
    lr_initial = self.config(tensor_keys.LEARNING_RATE)
    self.tensors[tensor_keys.LEARNING_RATE] = NotTrainableVariable(
        lr_info, [], lr_dtype, lr_initial)

    # Op that assigns ``learning_rate * decay_ratio`` back to the variable.
    learning_rate = self.tensors[tensor_keys.LEARNING_RATE]
    decayed_value = learning_rate * self.config(config_keys.DECAY_RATIO)
    self.tensors[tensor_keys.DECAY_LEARNING_RATE] = learning_rate.assign(
        decayed_value)

    # The optimizer is selected by its configured name.
    optimizer_name = self.config(config_keys.OPTIMIZER_NAME)
    self.tensors[graph_keys.OPTIMIZER] = self._get_optimizer(optimizer_name)
def sandbox():
    """Return the ``sandbox`` attribute of the current backend.

    The attribute is returned as-is; it is not called here.
    """
    from dxl.learn.backend import current_backend

    backend = current_backend()
    return backend.sandbox
def sandbox():
    """Yield once while inside the current backend's sandbox context.

    This is a generator function: the backend sandbox is entered when the
    generator is first advanced and left when it is resumed or closed.
    """
    backend_sandbox = current_backend().sandbox()
    with backend_sandbox:
        yield
class TestCase(current_backend().TestCase()):
    """Base test case for dxl.learn, derived from the current backend's TestCase.

    Adds teardown of global state, helpers for creating dummy tensors and
    variables, extra assertions, and session/graph context managers.
    """

    def setUp(self):
        # NOTE(review): intentionally empty; this does not call the parent
        # class's setUp -- confirm that is desired.
        pass

    def tearDown(self):
        """Reset backend and dxl.learn global state after each test."""
        from dxl.learn.backend import current_backend, TensorFlow
        from dxl.learn.core.config import clear_config
        from dxl.learn.core import SubgraphMakerFactory, SubgraphMakerTable
        if isinstance(current_backend(), TensorFlow):
            current_backend().unbox().reset_default_graph()
        clear_config()
        SubgraphMakerFactory.reset()
        SubgraphMakerTable.reset()

    def make_dummy_tensor(self, info=None):
        """Return a ``Constant`` with value 0.0.

        When ``info`` is None a random UUID string is used instead.
        """
        from dxl.learn.core import Constant
        if info is None:
            info = str(uuid.uuid4())
        return Constant(0.0, info)

    def make_dummy_variable(self, info=None):
        """Return a ``Variable`` with shape ``[]``.

        When ``info`` is None a random UUID string is used instead.
        """
        from dxl.learn.core import Variable
        if info is None:
            # Fixed: this was `info == ...`, a comparison whose result was
            # discarded, so `info` stayed None.
            info = str(uuid.uuid4())
        return Variable(info, [])

    @property
    def resource_path(self):
        """Path of the test resources, as provided by ``.resource``."""
        from .resource import test_resource_path
        # return Path(os.getenv('DEV_DXLEARN_TEST_RESOURCE_PATH'))
        return test_resource_path

    def assertFloatArrayEqual(self, first, second, msg=None):
        """Assert two array-likes are almost equal element-wise.

        Both arguments are converted with ``np.array`` and compared with
        ``np.testing.assert_array_almost_equal``; ``msg`` is passed as the
        error message (empty string when None).
        """
        if msg is None:
            msg = ''
        return np.testing.assert_array_almost_equal(
            np.array(first), np.array(second), err_msg=msg)

    def assertNameEqual(self, first, second, with_strip_colon_and_index=True):
        """Assert that the names of two objects are equal.

        Names are obtained with ``get_object_name_str``. When
        ``with_strip_colon_and_index`` is true, each name is first passed
        through ``name_str_without_colon_and_index``.
        """
        names = map(get_object_name_str, [first, second])
        if with_strip_colon_and_index:
            names = map(name_str_without_colon_and_index, names)
        names = list(names)
        self.assertEqual(names[0], names[1], 'Name not equal.')

    # NOTE(review): presumably marked as skipped so that pytest does not
    # collect this helper as a test because of its `test_` prefix -- confirm.
    @pytest.mark.skip
    @contextmanager
    def test_session(self):
        """Yield the parent's test session wrapped in a ``TestSession``."""
        from dxl.learn.core.session import TestSession
        with super().test_session() as sess:
            yield TestSession(sess)

    @contextmanager
    def variables_initialized_test_session(self):
        """Yield a test session after running the global variables initializer."""
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            yield sess

    @contextmanager
    def graph_on_cpu(self):
        """Run the body with a fresh default graph pinned to ``/cpu:0``.

        Nothing is yielded to the ``with`` target.
        """
        with tf.device('/cpu:0'):
            with tf.Graph().as_default():
                yield
def test_default_current_backend(self):
    """The backend returned by ``current_backend()`` is a ``TensorFlow``."""
    backend = current_backend()
    self.assertIsInstance(backend, TensorFlow)