Example #1
0
    def test_moments(self):
        """Check that ``moments`` agrees across Theano, TensorFlow and CNTK.

        An all-zeros, an all-ones and a random 4-D array are each reduced
        over two axis layouts, with and without ``keep_dims``. The mean and
        variance from the TensorFlow and CNTK helpers are compared against
        the Theano result.
        """
        shape = (10, 10, 10, 10)
        inputs = [
            np.zeros(shape),
            np.ones(shape),
            np.random.random(shape),
        ]
        # The two layouts are presumably channels-first (Theano-style) and
        # channels-last (TensorFlow-style) reduction axes; every layout is
        # fed to every backend.
        axes_options = [[0, 2, 3], [0, 1, 2]]

        for arr in inputs:
            for red_axes in axes_options:
                for keep in [True, False]:
                    th_m, th_v = KCTH.moments(KTH.variable(arr), red_axes,
                                              keep_dims=keep)
                    tf_m, tf_v = KCTF.moments(KTF.variable(arr), red_axes,
                                              keep_dims=keep)
                    ck_m, ck_v = KCNTK.moments(KCTK.variable(arr), red_axes,
                                               keep_dims=keep)

                    means = [KTH.eval(th_m), KTF.eval(tf_m), KCTK.eval(ck_m)]
                    variances = [KTH.eval(th_v), KTF.eval(tf_v),
                                 KCTK.eval(ck_v)]

                    # Theano is the reference. An absolute tolerance is
                    # required because the expected values can be exactly
                    # zero (zeros input, variance of a constant input), and
                    # a purely relative check cannot absorb rounding there.
                    ref_mean, ref_var = means[0], variances[0]
                    for other_mean, other_var in zip(means[1:], variances[1:]):
                        assert_allclose(ref_mean, other_mean,
                                        rtol=1e-4, atol=1e-10)
                        assert_allclose(ref_var, other_var,
                                        rtol=1e-4, atol=1e-10)
Example #2
0
    def test_moments(self):
        """Check that Theano and TensorFlow ``moments`` agree.

        An all-zeros, an all-ones and a random 4-D array are reduced over
        each of two axis layouts, with and without ``keep_dims``. The mean
        and variance produced by ``KCTH.moments`` (Theano) and
        ``KCTF.moments`` (TensorFlow) must match within tolerance.
        """
        input_shape = (10, 10, 10, 10)
        x_0 = np.zeros(input_shape)
        x_1 = np.ones(input_shape)
        x_random = np.random.random(input_shape)

        # Presumably channels-first (Theano) vs channels-last (TensorFlow)
        # reduction axes; both layouts are exercised on both backends.
        th_axes = [0, 2, 3]
        tf_axes = [0, 1, 2]

        for ip in [x_0, x_1, x_random]:
            for axes in [th_axes, tf_axes]:
                for keep_dims in [True, False]:
                    ip_th = KTH.variable(ip)
                    th_mean, th_var = KCTH.moments(ip_th,
                                                   axes,
                                                   keep_dims=keep_dims)

                    ip_tf = KTF.variable(ip)
                    tf_mean, tf_var = KCTF.moments(ip_tf,
                                                   axes,
                                                   keep_dims=keep_dims)

                    th_mean_val = KTH.eval(th_mean)
                    tf_mean_val = KTF.eval(tf_mean)
                    th_var_val = KTH.eval(th_var)
                    tf_var_val = KTF.eval(tf_var)

                    # An absolute tolerance is needed. For the zeros input,
                    # and for the variance of any constant input, the ideal
                    # result is exactly 0. With the default atol=0, a purely
                    # relative check rejects any rounding residue against an
                    # exact zero, whatever rtol is.
                    assert_allclose(th_mean_val, tf_mean_val,
                                    rtol=1e-4, atol=1e-10)
                    assert_allclose(th_var_val, tf_var_val,
                                    rtol=1e-4, atol=1e-10)