def testDtypes(self, count=10):
  """Checks that the resampling ops can be built with float64 weights.

  Only graph construction happens here (nothing is run in a session), so the
  test passes as long as building both ops does not raise.
  """
  values = self.get_values(count)
  float64_weights = math_ops.cast(self.get_weights(count), dtypes.float64)

  # Building either op with float64 weights should not error.
  resample.resample_at_rate([values], float64_weights)
  float64_rate = math_ops.cast(1.0, dtypes.float64)
  resample.weighted_resample(
      [values], float64_weights, overall_rate=float64_rate)
def testZeroRateUnknownShapes(self, count=10):
  """Tests that resampling runs with completely runtime shapes.

  Every rate is fed as zero, so the resampled output for the single input
  tensor is expected to be empty.
  """
  # Use placeholders without a static shape set, so shapes are only known
  # when the graph is run.
  vals = array_ops.placeholder(dtype=dtypes.int32)
  rates = array_ops.placeholder(dtype=dtypes.float32)

  resampled = resample.resample_at_rate([vals], rates)

  # `test_session` is deprecated in favor of `cached_session`.
  with self.cached_session() as s:
    # resample_at_rate returns one output per input tensor, so running it
    # yields a one-element list. Unpack it so the assertion checks the number
    # of sampled rows rather than the number of outputs (always 1).
    rs, = s.run(resampled, {
        vals: list(range(count)),
        rates: numpy.zeros(shape=[count], dtype=numpy.float32),
    })
    self.assertEqual((0,), rs.shape)
def testRoundtrip(self, rate=0.25, count=5, n=500):
  """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

  Two value tensors built the same way (`get_values(count)`) are resampled
  together. On every run their resampled outputs must match element-wise.
  Counts from the weighted resample, and from resampling that result again
  at the inverse rates, are checked with `assert_expected`.
  """
  # NOTE(review): another method named `testRoundtrip` follows this one; in a
  # class body the later definition replaces this one, so this version never
  # runs. Consider removing or renaming one of them.
  first_values = self.get_values(count)
  second_values = self.get_values(count)
  weights = self.get_weights(count)
  weights_tensor = constant_op.constant(weights)
  resampled_in, rates = resample.weighted_resample(
      [first_values, second_values], weights_tensor, rate, seed=123)
  inverse_rates = 1.0 / rates
  resampled_back_out = resample.resample_at_rate(
      resampled_in, inverse_rates, seed=456)

  init = control_flow_ops.group(variables.local_variables_initializer(),
                                variables.global_variables_initializer())

  with self.cached_session() as sess:
    sess.run(init)

    counts_resampled = collections.Counter()
    counts_reresampled = collections.Counter()
    for _ in range(n):
      once, twice = sess.run([resampled_in, resampled_back_out])
      # Both inputs must receive the same selection on every run.
      self.assertAllEqual(once[0], once[1])
      self.assertAllEqual(twice[0], twice[1])
      counts_resampled.update(once[0])
      counts_reresampled.update(twice[0])

    # The weighted resample should follow the given weights and rate.
    self.assert_expected(weights, rate, counts_resampled, n)
    # Resampling again at 1/rates should approximate the identity.
    self.assert_expected([1.0 for _ in weights], 1.0, counts_reresampled, n,
                         abs_delta=0.1 * n * count)
def testRoundtrip(self, rate=0.25, count=5, n=500):
  """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

  Two value tensors built the same way are resampled together, and their
  outputs must agree on every run. The weighted-resample counts are checked
  with `assert_expected`. Resampling the result again at `1 / rates` is
  checked to approximate the identity.
  """
  # NOTE(review): an earlier method in this class is also named
  # `testRoundtrip`; this later definition replaces it, so only this version
  # runs. Consider removing the duplicate.
  foo = self.get_values(count)
  bar = self.get_values(count)
  weights = self.get_weights(count)

  resampled_in, rates = resample.weighted_resample(
      [foo, bar], constant_op.constant(weights), rate, seed=123)
  resampled_back_out = resample.resample_at_rate(
      resampled_in, 1.0 / rates, seed=456)

  init = control_flow_ops.group(variables.local_variables_initializer(),
                                variables.global_variables_initializer())

  # `test_session` is deprecated in favor of `cached_session`.
  with self.cached_session() as s:
    s.run(init)  # initialize

    # outputs
    counts_resampled = collections.Counter()
    counts_reresampled = collections.Counter()
    for _ in range(n):
      resampled_vs, reresampled_vs = s.run([resampled_in, resampled_back_out])

      # Both inputs must receive the same selection on every run.
      self.assertAllEqual(resampled_vs[0], resampled_vs[1])
      self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])

      counts_resampled.update(resampled_vs[0])
      counts_reresampled.update(reresampled_vs[0])

    # assert that resampling worked as expected
    self.assert_expected(weights, rate, counts_resampled, n)

    # and that re-resampling gives the approx identity.
    self.assert_expected(
        [1.0 for _ in weights],
        1.0,
        counts_reresampled,
        n,
        abs_delta=0.1 * n * count)