def testAInvUpdateEmptyObservations(self, batch_size, context_dim):
  """Checks that `linalg.update_inverse` yields an all-zero update for no data.

  A symmetric `context_dim x context_dim` matrix is built with NumPy and
  inverted; the inverse is handed to `linalg.update_inverse` together with an
  observation tensor that has zero rows. The evaluated result is expected to
  be a zero matrix of the same square shape.

  Args:
    self: The test case instance (provides `evaluate` and `assertAllClose`).
    batch_size: Not used by this test; kept so the signature matches the
      sibling test (presumably both are fed by the same parameterization --
      TODO confirm against the decorator, which is not visible here).
    context_dim: Side length of the square matrix and the number of columns
      of the (empty) observation tensor.
  """
  square_shape = [context_dim, context_dim]

  # Build a symmetric matrix: 2*I plus a ramp of integers, added to its own
  # transpose.
  ramp = np.array(range(context_dim * context_dim)).reshape(square_shape)
  base = 2 * np.eye(context_dim) + ramp
  symmetric = base + base.T
  inverse_np = np.linalg.inv(symmetric)

  # With no observations the update is expected to be exactly zero.
  expected_update = np.zeros(square_shape, dtype=np.float32)

  a_inv = tf.constant(inverse_np, dtype=tf.float32, shape=square_shape)
  empty_observations = tf.constant(
      [], dtype=tf.float32, shape=[0, context_dim])
  update_op = linalg.update_inverse(a_inv, empty_observations)

  actual_update = self.evaluate(update_op)
  self.assertAllClose(expected_update, actual_update)
def testAInvUpdate(self, batch_size, context_dim):
  """Checks `linalg.update_inverse` against a directly computed inverse.

  A symmetric matrix `A` and an observation matrix `x` of shape
  `(batch_size, context_dim)` are built with NumPy. The reference value is
  `inv(A + x^T x)` computed by NumPy. The test asserts that
  `inv(A) + linalg.update_inverse(inv(A), x)` evaluates to that reference.

  Args:
    self: The test case instance (provides `evaluate` and `assertAllClose`).
    batch_size: Number of rows in the observation matrix.
    context_dim: Side length of `A` and number of columns of the observation
      matrix.
  """
  square_shape = [context_dim, context_dim]
  batch_shape = [batch_size, context_dim]

  # Build a symmetric matrix: 2*I plus a ramp of integers, added to its own
  # transpose.
  ramp = np.array(range(context_dim * context_dim)).reshape(square_shape)
  base = 2 * np.eye(context_dim) + ramp
  symmetric = base + base.T
  inverse_np = np.linalg.inv(symmetric)

  # Observations are a simple integer ramp laid out row by row.
  observations_np = np.array(
      range(batch_size * context_dim)).reshape(batch_shape)

  # Reference: invert the updated matrix directly.
  gram = observations_np.T @ observations_np
  expected_updated_inverse = np.linalg.inv(symmetric + gram)

  a_inv = tf.constant(inverse_np, dtype=tf.float32, shape=square_shape)
  x = tf.constant(observations_np, dtype=tf.float32, shape=batch_shape)
  update_op = linalg.update_inverse(a_inv, x)

  # The helper returns an additive update, so add it to the old inverse.
  actual_updated_inverse = self.evaluate(a_inv + update_op)
  self.assertAllClose(expected_updated_inverse, actual_updated_inverse)