Exemplo n.º 1
0
    def test_function(self):
        """Compare ``function`` across the KTH and KTF backends.

        The expression ``square(x) + y`` is built on each backend, with
        ``(x, x * 2)`` passed as an update to ``function``. The test then
        checks that both backends produce matching outputs and that the
        variable values read back afterwards also match.
        """
        init = np.random.random((4, 2))
        feed = np.random.random((4, 2))

        # Same initial data for the variable on both backends.
        var_th = KTH.variable(init)
        var_tf = KTF.variable(init)
        ph_th = KTH.placeholder(ndim=2)
        ph_tf = KTF.placeholder(ndim=2)

        out_th = KTH.square(var_th) + ph_th
        out_tf = KTF.square(var_tf) + ph_tf

        doubled_th = var_th * 2
        doubled_tf = var_tf * 2
        func_th = KTH.function(
            [ph_th], [out_th], updates=[(var_th, doubled_th)])
        func_tf = KTF.function(
            [ph_tf], [out_tf], updates=[(var_tf, doubled_tf)])

        # Outputs of the compiled functions must agree.
        result_th = func_th([feed])[0]
        result_tf = func_tf([feed])[0]
        assert result_th.shape == result_tf.shape
        assert_allclose(result_th, result_tf, atol=1e-05)

        # Variable values read back after the call must agree as well.
        current_th = KTH.get_value(var_th)
        current_tf = KTF.get_value(var_tf)
        assert current_th.shape == current_tf.shape
        assert_allclose(current_th, current_tf, atol=1e-05)
Exemplo n.º 2
0
    def test_function(self):
        """Verify that KTH and KTF agree when calling ``function``.

        Each backend gets a variable seeded with the same random array, a
        2-D placeholder, the expression ``square(variable) + placeholder``
        and an update pair ``(variable, variable * 2)``. Function outputs
        and subsequently fetched variable values are compared for equal
        shape and closeness (``atol=1e-05``).
        """
        def check_match(from_th, from_tf):
            # Shapes must be identical and values close within tolerance.
            assert from_th.shape == from_tf.shape
            assert_allclose(from_th, from_tf, atol=1e-05)

        seed_values = np.random.random((4, 2))
        fed_values = np.random.random((4, 2))

        x_theano = KTH.variable(seed_values)
        x_tflow = KTF.variable(seed_values)
        y_theano = KTH.placeholder(ndim=2)
        y_tflow = KTF.placeholder(ndim=2)

        expr_theano = KTH.square(x_theano) + y_theano
        expr_tflow = KTF.square(x_tflow) + y_tflow

        updates_theano = [(x_theano, x_theano * 2)]
        updates_tflow = [(x_tflow, x_tflow * 2)]
        call_theano = KTH.function(
            [y_theano], [expr_theano], updates=updates_theano)
        call_tflow = KTF.function(
            [y_tflow], [expr_tflow], updates=updates_tflow)

        first_theano = call_theano([fed_values])[0]
        first_tflow = call_tflow([fed_values])[0]
        check_match(first_theano, first_tflow)

        stored_theano = KTH.get_value(x_theano)
        stored_tflow = KTF.get_value(x_tflow)
        check_match(stored_theano, stored_tflow)