Пример #1
0
 def tearDown(self):
     """Reset global dxl.learn state after each test.

     If the active backend is TensorFlow its default graph is reset; the
     global config is then cleared and both subgraph-maker registries
     (factory and table) are reset, in that order.
     """
     # Imports are kept local and in this order, as in the rest of the suite.
     from dxl.learn.backend import current_backend, TensorFlow
     from dxl.learn.core.config import clear_config
     from dxl.learn.core import SubgraphMakerFactory, SubgraphMakerTable

     on_tensorflow = isinstance(current_backend(), TensorFlow)
     if on_tensorflow:
         current_backend().unbox().reset_default_graph()
     clear_config()
     for registry in (SubgraphMakerFactory, SubgraphMakerTable):
         registry.reset()
Пример #2
0
    def kernel(self, inputs=None):
        """Build the learning-rate variable, its decay op and the optimizer.

        ``inputs`` is accepted but not read by this method.

        Stores into ``self.tensors``:
          * ``KEYS.TENSOR.LEARNING_RATE``: a ``NotTrainableVariable`` built
            with ``[]`` (presumably a scalar shape — confirm), the backend's
            ``float32`` dtype and the value ``self.config`` returns for the
            learning-rate key.
          * ``KEYS.TENSOR.DECAY_LEARNING_RATE``: the result of assigning
            ``learning_rate * config(DECAY_RATIO)`` back to the variable.
          * ``KEYS.GRAPH.OPTIMIZER``: whatever ``self._get_optimizer``
            returns for the configured ``OPTIMIZER_NAME``.
        """
        tensor_keys = self.KEYS.TENSOR
        config_keys = self.KEYS.CONFIG
        graph_keys = self.KEYS.GRAPH

        # NOTE(review): the initial value is looked up with the TENSOR key,
        # while DECAY_RATIO / OPTIMIZER_NAME use CONFIG keys — confirm the
        # two LEARNING_RATE keys resolve to the same config entry.
        lr_info = self.info.child_tensor(tensor_keys.LEARNING_RATE)
        lr_dtype = current_backend().float32
        lr_initial = self.config(tensor_keys.LEARNING_RATE)
        self.tensors[tensor_keys.LEARNING_RATE] = NotTrainableVariable(
            lr_info, [], lr_dtype, lr_initial)

        # Read the variable back from the mapping, as the original did.
        learning_rate = self.tensors[tensor_keys.LEARNING_RATE]
        decayed_value = learning_rate * self.config(config_keys.DECAY_RATIO)
        self.tensors[tensor_keys.DECAY_LEARNING_RATE] = learning_rate.assign(
            decayed_value)

        optimizer_name = self.config(config_keys.OPTIMIZER_NAME)
        self.tensors[graph_keys.OPTIMIZER] = self._get_optimizer(
            optimizer_name)
Пример #3
0
def sandbox():
    """Return the active backend's ``sandbox`` attribute without calling it.

    ``current_backend`` is imported lazily, at call time.
    """
    from dxl.learn.backend import current_backend
    active_backend = current_backend()
    return active_backend.sandbox
Пример #4
0
def sandbox():
    """Generator that yields once inside ``current_backend().sandbox()``.

    Nothing runs until the first ``next()``: that call looks up the backend,
    enters its ``sandbox()`` context and yields ``None``. The context is
    exited when the generator is resumed again or closed.

    NOTE(review): no ``@contextmanager`` decorator is visible here, so
    ``with sandbox():`` will not work on the bare generator — confirm
    whether it is wrapped elsewhere.
    """
    backend = current_backend()
    with backend.sandbox():
        yield
Пример #5
0
class TestCase(current_backend().TestCase()):
    """Base test case for dxl.learn tests.

    The parent class is whatever ``current_backend().TestCase()`` returns,
    so it depends on the active backend. Adds helpers for dummy
    tensors/variables, float-array and name comparisons, and for running
    code inside (TensorFlow) sessions and graphs.
    """

    def setUp(self):
        # Delegate to the backend's test case so its per-test setup still
        # runs; an empty override would silently skip it.
        super().setUp()

    def tearDown(self):
        """Reset global dxl.learn state, then run the parent's tearDown.

        If the active backend is TensorFlow its default graph is reset; the
        global config and both subgraph-maker registries are then cleared.
        The parent ``tearDown`` runs in ``finally`` so its own cleanup
        happens even if one of these resets raises.
        """
        from dxl.learn.backend import current_backend, TensorFlow
        from dxl.learn.core.config import clear_config
        from dxl.learn.core import SubgraphMakerFactory, SubgraphMakerTable
        try:
            if isinstance(current_backend(), TensorFlow):
                current_backend().unbox().reset_default_graph()
            clear_config()
            SubgraphMakerFactory.reset()
            SubgraphMakerTable.reset()
        finally:
            super().tearDown()

    def make_dummy_tensor(self, info=None):
        """Return ``Constant(0.0, info)``; ``info`` defaults to a random UUID."""
        from dxl.learn.core import Constant
        if info is None:
            info = str(uuid.uuid4())
        return Constant(0.0, info)

    def make_dummy_variable(self, info=None):
        """Return ``Variable(info, [])``; ``info`` defaults to a random UUID."""
        from dxl.learn.core import Variable
        if info is None:
            # Assign (not compare) so a unique name is actually used,
            # mirroring make_dummy_tensor.
            info = str(uuid.uuid4())
        return Variable(info, [])

    @property
    def resource_path(self):
        """Path to the shared test resources (``resource.test_resource_path``)."""
        from .resource import test_resource_path
        return test_resource_path

    def assertFloatArrayEqual(self, first, second, msg=None):
        """Assert two array-likes are almost equal.

        Uses ``np.testing.assert_array_almost_equal`` with its default
        precision; ``msg`` (if given) is passed as the failure message.
        """
        return np.testing.assert_array_almost_equal(
            np.array(first), np.array(second),
            err_msg='' if msg is None else msg)

    def assertNameEqual(self, first, second, with_strip_colon_and_index=True):
        """Assert that two objects have the same name string.

        Names are obtained with ``get_object_name_str``; when
        ``with_strip_colon_and_index`` is true, any ``:index`` suffix is
        stripped via ``name_str_without_colon_and_index`` before comparing.
        """
        names = map(get_object_name_str, [first, second])
        if with_strip_colon_and_index:
            names = map(name_str_without_colon_and_index, names)
        names = list(names)
        self.assertEqual(names[0], names[1], 'Name not equal.')

    # The name starts with ``test_``; the skip mark keeps pytest from
    # collecting this helper as a test.
    @pytest.mark.skip
    @contextmanager
    def test_session(self):
        """Wrap the parent's ``test_session`` in a dxl ``TestSession``."""
        from dxl.learn.core.session import TestSession
        with super().test_session() as sess:
            yield TestSession(sess)

    @contextmanager
    def variables_initialized_test_session(self):
        """Yield a test session in which global variables are initialized."""
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            yield sess

    @contextmanager
    def graph_on_cpu(self):
        """Run the body in a fresh default ``tf.Graph`` pinned to ``/cpu:0``.

        Yields the new graph.
        """
        # The new graph must be made the default *before* entering
        # ``tf.device``: in graph mode ``tf.device`` applies to the current
        # default graph, so entering it first would pin the old graph and
        # leave ops in the new graph unplaced.
        with tf.Graph().as_default() as graph:
            with tf.device('/cpu:0'):
                yield graph
Пример #6
0
 def test_default_current_backend(self):
     """``current_backend()`` returns a ``TensorFlow`` backend instance."""
     backend = current_backend()
     self.assertIsInstance(backend, TensorFlow)