示例#1
0
    def testAInvUpdateEmptyObservations(self, batch_size, context_dim):
        """Check that an empty observation batch yields an all-zero update.

        Builds a symmetric `context_dim x context_dim` matrix, inverts it with
        NumPy, and passes that inverse to `linalg.update_inverse` together
        with an observation tensor that has zero rows. The returned update is
        compared against an all-zero matrix.

        Args:
          batch_size: Not read by this test; presumably present so the
            signature matches the parameters shared with sibling tests (the
            parameterization is not visible here -- TODO confirm).
          context_dim: Number of columns of the observation tensor and the
            side length of the square matrices.
        """
        matrix_shape = (context_dim, context_dim)
        # Symmetric input: (2*I + integer ramp) plus its own transpose.
        base = 2 * np.eye(context_dim) + np.arange(
            context_dim * context_dim).reshape(matrix_shape)
        symmetric = base + base.T
        inverse_np = np.linalg.inv(symmetric)

        a_inv = tf.constant(
            inverse_np, dtype=tf.float32, shape=[context_dim, context_dim])
        # Zero rows: there are no observations to fold into the inverse.
        empty_x = tf.constant([], dtype=tf.float32, shape=[0, context_dim])
        actual_update = self.evaluate(linalg.update_inverse(a_inv, empty_x))

        zero_update = np.zeros([context_dim, context_dim], dtype=np.float32)
        self.assertAllClose(zero_update, actual_update)
示例#2
0
    def testAInvUpdate(self, batch_size, context_dim):
        """Check the inverse update against a direct NumPy inversion.

        Builds a symmetric matrix `A` and an observation matrix `x` of shape
        `(batch_size, context_dim)`, then asserts that
        `a_inv + linalg.update_inverse(a_inv, x)` is close to
        `inv(A + x^T x)` computed directly with NumPy.

        Args:
          batch_size: Number of rows of the observation matrix.
          context_dim: Number of columns of the observation matrix and the
            side length of the square matrices.
        """
        matrix_shape = (context_dim, context_dim)
        # Symmetric input: (2*I + integer ramp) plus its own transpose.
        base = 2 * np.eye(context_dim) + np.arange(
            context_dim * context_dim).reshape(matrix_shape)
        symmetric = base + base.T
        inverse_np = np.linalg.inv(symmetric)

        observations_np = np.array(range(batch_size * context_dim)).reshape(
            (batch_size, context_dim))
        # Reference result: invert A + x^T x directly.
        gram = observations_np.T @ observations_np
        expected_updated_inverse = np.linalg.inv(symmetric + gram)

        a_inv = tf.constant(
            inverse_np, dtype=tf.float32, shape=[context_dim, context_dim])
        x = tf.constant(
            observations_np, dtype=tf.float32, shape=[batch_size, context_dim])
        delta = linalg.update_inverse(a_inv, x)
        # The helper's result is added to the old inverse before comparing.
        actual_updated_inverse = self.evaluate(a_inv + delta)
        self.assertAllClose(expected_updated_inverse, actual_updated_inverse)