def test_no_privacy_sum(self):
        """Checks the query result equals the element-wise sum of two records."""
        with self.cached_session() as sess:
            records = [
                tf.constant([2.0, 0.0]),
                tf.constant([-1.0, 1.0]),
            ]
            sum_query = no_privacy_query.NoPrivacySumQuery()
            summed = sess.run(_run_query(sum_query, records))
            # [2, 0] + [-1, 1] == [1, 1]
            self.assertAllClose(summed, [1.0, 1.0])

    def test_no_privacy_weighted_sum(self):
        """Checks the query result against a weighted sum with weights [1, 2]."""
        with self.cached_session() as sess:
            records = [
                tf.constant([2.0, 0.0]),
                tf.constant([-1.0, 1.0]),
            ]
            weights = [1, 2]
            sum_query = no_privacy_query.NoPrivacySumQuery()
            # run_query returns a pair; only the first element (the query
            # result tensor) is evaluated here.
            weighted_result, _ = test_utils.run_query(
                sum_query, records, weights=weights)
            # 1 * [2, 0] + 2 * [-1, 1] == [0, 2]
            # NOTE(review): relies on NoPrivacySumQuery honoring the weights
            # that run_query forwards to it — confirm against the query class.
            self.assertAllClose(sess.run(weighted_result), [0.0, 2.0])

 def test_incompatible_records(self, record1, record2, error_type):
     """Checks that running the sum query over the given records raises.

     Args:
       record1: First record fed to the query.
       record2: Second record fed to the query.
       error_type: Exception type the query run is expected to raise.

     NOTE(review): this test takes extra positional arguments, but no
     parameterization decorator appears directly above it; a plain unittest
     runner would call it with only ``self``. Confirm it is parameterized
     (e.g. via absl ``parameterized.named_parameters``).
     """
     sum_query = no_privacy_query.NoPrivacySumQuery()
     self.assertRaises(error_type, _run_query, sum_query, [record1, record2])