def testLinearization(self, shape):
        """Check `empirical.linearize` against `EmpiricalTest.f_lin_exact`.

        Random parameters `(w1, w2, b)` and a random expansion point `x0`
        are drawn from `PRNGKey(0)`. `EmpiricalTest.f` is then linearized
        around `x0`. For `TAYLOR_RANDOM_SAMPLES` iterations, and for every
        combination of the `do_alter` / `do_shift_x` flags, a fresh `x` is
        drawn. The linearized function is required to agree with
        `EmpiricalTest.f_lin_exact` at that `x`.

        Args:
          shape: shape of the weight matrices `w1` and `w2`. It must be a
            square 2-D shape, because `w1` is added to its own transpose.
            `x0`, `x` and `b` all have length `shape[-1]`.
        """
        key = random.PRNGKey(0)
        key, s1, s2, s3 = random.split(key, 4)
        w1 = random.normal(s1, shape)
        # Symmetrize w1. NOTE(review): presumably required by the closed-form
        # gradient inside EmpiricalTest.f_lin_exact -- confirm.
        w1 = 0.5 * (w1 + w1.T)
        w2 = random.normal(s2, shape)
        b = random.normal(s3, (shape[-1], ))
        params = (w1, w2, b)

        key, split = random.split(key)
        x0 = random.normal(split, (shape[-1], ))

        f_lin = empirical.linearize(EmpiricalTest.f, x0)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    key, split = random.split(key)
                    x = random.normal(split, (shape[-1], ))
                    # `check_dtypes` is passed by keyword: newer JAX test
                    # utilities make it keyword-only, and a positional `True`
                    # raises TypeError there.
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x),
                        check_dtypes=True)
Beispiel #2
0
    def testLinearization(self, shape):
        """Check `empirical.linearize` against a hand-derived linearization.

        A quadratic test function `f` is linearized around a random point
        `x0`. At `TAYLOR_RANDOM_SAMPLES` random points `x`, the result must
        match the analytic first-order Taylor expansion `f_lin_exact`.

        Args:
          shape: shape of the weight matrices `w1` and `w2`. It must be a
            square 2-D shape, because `w1` is added to its own transpose.
            `x0`, `x` and `b` all have length `shape[-1]`.
        """
        # The test function is quadratic in `x` (not a deep linear network):
        #   f(x) = 0.5 * x^T w1 x + w2 @ x + b
        # The scalar quadratic term is broadcast onto every output component.
        def f(x, params):
            w1, w2, b = params
            return 0.5 * np.dot(np.dot(x.T, w1), x) + np.dot(w2, x) + b

        def f_lin_exact(x0, x, params):
            """Return the first-order Taylor expansion of `f` around `x0`, at `x`."""
            w1, w2, _ = params
            f0 = f(x0, params)
            dx = x - x0
            # Row i of the Jacobian at x0 is x0^T w1 + w2[i]. The vector
            # x0^T w1 broadcasts against every row of w2. Using x0^T w1 as the
            # gradient of 0.5 * x^T w1 x is only valid for a symmetric w1,
            # which is why w1 is symmetrized below. (`.T` on the 1-D x0 is a
            # no-op.)
            return f0 + np.dot(np.dot(x0.T, w1) + w2, dx)

        key = random.PRNGKey(0)
        key, s1, s2, s3 = random.split(key, 4)
        w1 = random.normal(s1, shape)
        w1 = 0.5 * (w1 + w1.T)  # symmetric; see f_lin_exact
        w2 = random.normal(s2, shape)
        b = random.normal(s3, (shape[-1], ))
        params = (w1, w2, b)

        key, split = random.split(key)
        x0 = random.normal(split, (shape[-1], ))

        f_lin = empirical.linearize(f, x0)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            key, split = random.split(key)
            x = random.normal(split, (shape[-1], ))
            # `check_dtypes` is passed by keyword: newer JAX test utilities
            # make it keyword-only, and a positional `True` raises TypeError.
            self.assertAllClose(f_lin_exact(x0, x, params), f_lin(x, params),
                                check_dtypes=True)

    def testLinearization(self, shape):
        """TensorFlow port: compare `empirical.linearize` with `f_lin_exact`.

        Seeds are derived from a `stateless_uniform` draw and split with
        `tf_random_split`. Those seeds produce the parameters
        `(w1, w2, b)`, the expansion point `x0`, and each evaluation point
        `x`. For every sample and every `(do_alter, do_shift_x)`
        combination, the linearization of `EmpiricalTest.f` around `x0`
        must agree with `EmpiricalTest.f_lin_exact`.

        Args:
          shape: shape of `w1` and `w2`. It must be a square 2-D shape,
            because `w1` is added to its own transpose. Vectors have length
            `shape[-1]`.
        """
        def split_seed(seed, count):
            # Split `seed` into `count` sub-seeds, returned as a plain list.
            pieces = tf_random_split(
                seed=tf.convert_to_tensor(seed, dtype=tf.int32), num=count)
            return [pieces[idx] for idx in range(count)]

        def sample(sample_shape, seed):
            # Draw a normal sample and convert it to an ndarray.
            return np.asarray(normal(sample_shape, seed=seed))

        width = shape[-1]
        root = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)

        root, seed_w1, seed_w2, seed_b = split_seed(root, 4)
        sym = sample(shape, seed_w1)
        sym = 0.5 * (sym + sym.T)
        params = (sym, sample(shape, seed_w2), sample((width, ), seed_b))

        root, seed_x0 = split_seed(root, 2)
        x0 = sample((width, ), seed_x0)

        f_lin = empirical.linearize(EmpiricalTest.f, x0)

        flag_pairs = [(alter, shift)
                      for alter in (True, False)
                      for shift in (True, False)]
        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter, do_shift_x in flag_pairs:
                root, seed_x = split_seed(root, 2)
                x = sample((width, ), seed_x)
                expected = EmpiricalTest.f_lin_exact(x0,
                                                     x,
                                                     params,
                                                     do_alter,
                                                     do_shift_x=do_shift_x)
                actual = f_lin(x, params, do_alter, do_shift_x=do_shift_x)
                # NOTE(review): no third positional argument here, unlike the
                # JAX variant -- presumably because a TF-style assertAllClose
                # takes rtol in that slot; confirm.
                self.assertAllClose(expected, actual)