def testDtypes(self, count=10):
  """Checks that both resampling ops can be constructed with float64 weights.

  Args:
    count: size passed to `self.get_values` and `self.get_weights`.
  """
  values = self.get_values(count)
  raw_weights = self.get_weights(count)
  weights_f64 = math_ops.cast(raw_weights, dtypes.float64)

  # Building the two ops below is the whole test: neither call may raise.
  resample.resample_at_rate([values], weights_f64)
  overall_rate_f64 = math_ops.cast(1.0, dtypes.float64)
  resample.weighted_resample(
      [values], weights_f64, overall_rate=overall_rate_f64)
def testCorrectRates(self, rate=0.25, count=10, n=500, rtol=0.1):
  """Tests that the rates returned by weighted_resample are correct.

  The check is statistical. Over `n` session runs it verifies that:
    - sum(1/rate) approximates the size of the original collection,
    - sum(1/rate * value) approximates the sum of the original inputs,
    - sum(1/rate * value) / sum(1/rate) approximates the mean.

  Args:
    rate: rate handed to `resample.weighted_resample`.
    count: size passed to `self.get_values` and `self.get_weights`.
    n: number of session runs to accumulate over.
    rtol: relative tolerance applied to each of the three comparisons.
  """
  vals = self.get_values(count)
  weights = self.get_weights(count)

  resampled, rates = resample.weighted_resample(
      [vals], constant_op.constant(weights), rate)

  invrates = 1.0 / rates

  init = control_flow_ops.group(variables.local_variables_initializer(),
                                variables.global_variables_initializer())
  expected_sum_op = math_ops.reduce_sum(vals)
  # `cached_session` replaces the deprecated `test_session`, matching the
  # other session-based test in this file.
  with self.cached_session() as s:
    s.run(init)
    # The inputs are summed once and scaled by n, since the loop below
    # accumulates over n independent resampling runs.
    expected_sum = n * s.run(expected_sum_op)

    weight_sum = 0.0
    weighted_value_sum = 0.0
    for _ in range(n):
      val, inv_rate = s.run([resampled[0], invrates])
      weight_sum += sum(inv_rate)
      weighted_value_sum += sum(val * inv_rate)

  # sum(inv_rate) ~= N*count:
  expected_count = count * n
  self.assertAlmostEqual(
      expected_count, weight_sum, delta=(rtol * expected_count))

  # sum(vals) * n ~= weighted_sum(resampled, 1.0/weights)
  self.assertAlmostEqual(
      expected_sum, weighted_value_sum, delta=(rtol * expected_sum))

  # Mean ~= weighted mean:
  expected_mean = expected_sum / float(n * count)
  self.assertAlmostEqual(
      expected_mean,
      weighted_value_sum / weight_sum,
      delta=(rtol * expected_mean))
def testRoundtrip(self, rate=0.25, count=5, n=500):
  """Tests `resample(x, weights)` and resample(resample(x, rate), 1/rate)`."""
  # NOTE(review): a second `testRoundtrip` definition follows this one in the
  # file; if both sit in the same class, the later one replaces this one.
  # Confirm which is intended.
  first_values = self.get_values(count)
  second_values = self.get_values(count)
  weights = self.get_weights(count)

  # Forward pass: weighted resampling of both value lists with a fixed seed.
  weights_tensor = constant_op.constant(weights)
  resampled_in, rates = resample.weighted_resample(
      [first_values, second_values], weights_tensor, rate, seed=123)

  # Backward pass: resample the forward output at the inverse of its rates.
  inverse_rates = 1.0 / rates
  resampled_back_out = resample.resample_at_rate(
      resampled_in, inverse_rates, seed=456)

  local_init = variables.local_variables_initializer()
  global_init = variables.global_variables_initializer()
  init = control_flow_ops.group(local_init, global_init)

  fetches = [resampled_in, resampled_back_out]
  with self.cached_session() as sess:
    sess.run(init)

    # Occurrence counts of each value seen after each stage, over all runs.
    forward_counts = collections.Counter()
    roundtrip_counts = collections.Counter()
    for _ in range(n):
      forward, roundtrip = sess.run(fetches)

      # Within one run, the two outputs of each stage must be equal.
      self.assertAllEqual(forward[0], forward[1])
      self.assertAllEqual(roundtrip[0], roundtrip[1])

      forward_counts.update(forward[0])
      roundtrip_counts.update(roundtrip[0])

    # Forward counts are checked against the real weights and rate.
    self.assert_expected(weights, rate, forward_counts, n)

    # Round-trip counts are checked against unit weights and a unit rate,
    # i.e. the re-resampling should approximately undo the resampling.
    unit_weights = [1.0 for _ in weights]
    self.assert_expected(
        unit_weights, 1.0, roundtrip_counts, n, abs_delta=0.1 * n * count)
def testRoundtrip(self, rate=0.25, count=5, n=500):
  """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

  Two value lists are resampled together with `weighted_resample`, and the
  result is resampled again at the inverse of the returned rates. Occurrence
  counts from `n` session runs are then passed to `self.assert_expected` for
  both stages.

  Args:
    rate: rate handed to `resample.weighted_resample`.
    count: size passed to `self.get_values` and `self.get_weights`.
    n: number of session runs to accumulate counts over.
  """
  # NOTE(review): an earlier `testRoundtrip` definition appears in this file;
  # if both sit in the same class, this later one is the one that takes
  # effect. Confirm the duplicate is intended.
  foo = self.get_values(count)
  bar = self.get_values(count)
  weights = self.get_weights(count)

  resampled_in, rates = resample.weighted_resample(
      [foo, bar], constant_op.constant(weights), rate, seed=123)

  resampled_back_out = resample.resample_at_rate(
      resampled_in, 1.0 / rates, seed=456)

  init = control_flow_ops.group(variables.local_variables_initializer(),
                                variables.global_variables_initializer())
  # `cached_session` replaces the deprecated `test_session`, matching the
  # other definition of this test in the file.
  with self.cached_session() as s:
    s.run(init)  # initialize

    # outputs
    counts_resampled = collections.Counter()
    counts_reresampled = collections.Counter()
    for _ in range(n):
      resampled_vs, reresampled_vs = s.run([resampled_in, resampled_back_out])

      # Within one run, the two outputs of each stage must be equal.
      self.assertAllEqual(resampled_vs[0], resampled_vs[1])
      self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])

      counts_resampled.update(resampled_vs[0])
      counts_reresampled.update(reresampled_vs[0])

    # assert that resampling worked as expected
    self.assert_expected(weights, rate, counts_resampled, n)

    # and that re-resampling gives the approx identity.
    self.assert_expected(
        [1.0 for _ in weights],
        1.0,
        counts_reresampled,
        n,
        abs_delta=0.1 * n * count)