Exemplo n.º 1
0
  def testDtypes(self, count=10):
    """Checks that the resampling ops can be defined with float64 weights.

    Builds a `resample_at_rate` op and a `weighted_resample` op whose weights
    (and, for `weighted_resample`, `overall_rate`) are cast to float64. The
    test passes as long as defining these ops does not raise.
    """
    values = self.get_values(count)
    float64_weights = math_ops.cast(self.get_weights(count), dtypes.float64)

    # Defining the ops is the whole check -- there is nothing to run.
    resample.resample_at_rate([values], float64_weights)
    unit_rate = math_ops.cast(1.0, dtypes.float64)
    resample.weighted_resample(
        [values], float64_weights, overall_rate=unit_rate)
Exemplo n.º 2
0
  def testCorrectRates(self, rate=0.25, count=10, n=500, rtol=0.1):
    """Tests that the rates returned by weighted_resample are correct.

    Runs `weighted_resample` `n` times over `count` values and weights each
    resampled element by the inverse of its returned rate. Within relative
    tolerance `rtol`, the accumulated results must approximate:
      - sum(1/rate)             ~ the size of the original collection,
      - sum(1/rate * value)     ~ the sum of the original inputs,
      - the ratio of the two    ~ the mean of the original inputs.

    Args:
      rate: overall rate passed to `weighted_resample`.
      count: number of input values (and weights).
      n: number of resampling rounds accumulated.
      rtol: relative tolerance used for every comparison.
    """
    vals = self.get_values(count)
    weights = self.get_weights(count)

    resampled, rates = resample.weighted_resample([vals],
                                                  constant_op.constant(weights),
                                                  rate)

    # Weighting each kept element by 1/rate compensates for the subsampling.
    invrates = 1.0 / rates

    init = control_flow_ops.group(variables.local_variables_initializer(),
                                  variables.global_variables_initializer())
    expected_sum_op = math_ops.reduce_sum(vals)
    with self.test_session() as s:
      s.run(init)
      # Every round resamples the same inputs, so the target is n * sum(vals).
      expected_sum = n * s.run(expected_sum_op)

      weight_sum = 0.0
      weighted_value_sum = 0.0
      for _ in range(n):
        val, inv_rate = s.run([resampled[0], invrates])
        # Reduce in NumPy instead of iterating the arrays with builtin sum().
        weight_sum += inv_rate.sum()
        weighted_value_sum += (val * inv_rate).sum()

    # sum(inv_rate) ~= n * count:
    expected_count = count * n
    self.assertAlmostEqual(
        expected_count, weight_sum, delta=abs(rtol * expected_count))

    # sum(resampled * inv_rate) ~= n * sum(vals):
    # (abs() keeps delta non-negative even if the inputs sum to a negative.)
    self.assertAlmostEqual(
        expected_sum, weighted_value_sum, delta=abs(rtol * expected_sum))

    # Mean ~= weighted mean:
    expected_mean = expected_sum / float(n * count)
    self.assertAlmostEqual(
        expected_mean,
        weighted_value_sum / weight_sum,
        delta=abs(rtol * expected_mean))
Exemplo n.º 3
0
    def testRoundtrip(self, rate=0.25, count=5, n=500):
        """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

        Two value tensors from `get_values(count)` are resampled together by
        `weighted_resample`. That output is then resampled again by
        `resample_at_rate` at the inverse of the returned rates. Over `n`
        session runs the test checks that:
          - both inputs come out of each stage identical,
          - the first stage's element counts agree with `weights` at `rate`
            (via `assert_expected`), and
          - the round trip behaves approximately like the identity: counts
            agree with uniform weights at rate 1.0 within `0.1 * n * count`.
        """

        foo = self.get_values(count)
        bar = self.get_values(count)
        weights = self.get_weights(count)

        resampled_in, rates = resample.weighted_resample(
            [foo, bar], constant_op.constant(weights), rate, seed=123)

        # Resampling at the inverse rates should approximately undo stage one.
        resampled_back_out = resample.resample_at_rate(resampled_in,
                                                       1.0 / rates,
                                                       seed=456)

        init = control_flow_ops.group(variables.local_variables_initializer(),
                                      variables.global_variables_initializer())
        with self.cached_session() as s:
            s.run(init)  # initialize

            # Element -> number of times it was emitted, across all n runs.
            counts_resampled = collections.Counter()
            counts_reresampled = collections.Counter()
            for _ in range(n):
                resampled_vs, reresampled_vs = s.run(
                    [resampled_in, resampled_back_out])

                # Both inputs must be resampled into identical outputs.
                self.assertAllEqual(resampled_vs[0], resampled_vs[1])
                self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])

                counts_resampled.update(resampled_vs[0])
                counts_reresampled.update(reresampled_vs[0])

            # assert that resampling worked as expected
            self.assert_expected(weights, rate, counts_resampled, n)

            # and that re-resampling gives the approx identity.
            self.assert_expected([1.0 for _ in weights],
                                 1.0,
                                 counts_reresampled,
                                 n,
                                 abs_delta=0.1 * n * count)
Exemplo n.º 4
0
  def testRoundtrip(self, rate=0.25, count=5, n=500):
    """Round-trips values through `weighted_resample` and `resample_at_rate`.

    Exercises `resample(x, weights)` followed by
    `resample(resample(x, rate), 1/rate)`. Stage-one counts are checked
    against `weights` at `rate`. Stage-two counts must approximately match
    uniform weights at rate 1.0, within `0.1 * n * count`. Both inputs must
    come out of each stage identical.
    """
    first_input = self.get_values(count)
    second_input = self.get_values(count)
    weights = self.get_weights(count)

    stage_one, stage_one_rates = resample.weighted_resample(
        [first_input, second_input], constant_op.constant(weights), rate,
        seed=123)
    # Resample the stage-one output at the inverse of its rates.
    stage_two = resample.resample_at_rate(
        stage_one, 1.0 / stage_one_rates, seed=456)

    init_op = control_flow_ops.group(
        variables.local_variables_initializer(),
        variables.global_variables_initializer())
    with self.test_session() as sess:
      sess.run(init_op)

      # Histograms of emitted elements, accumulated over all n runs.
      stage_one_counts = collections.Counter()
      stage_two_counts = collections.Counter()
      for _ in range(n):
        one_out, two_out = sess.run([stage_one, stage_two])
        for pair in (one_out, two_out):
          # The two inputs must come out of each stage identical.
          self.assertAllEqual(pair[0], pair[1])
        stage_one_counts.update(one_out[0])
        stage_two_counts.update(two_out[0])

      # Stage one: compare counts against `weights` at `rate`.
      self.assert_expected(weights, rate, stage_one_counts, n)

      # Round trip: approximately the identity (uniform weights, rate 1.0).
      uniform = [1.0 for _ in weights]
      self.assert_expected(
          uniform, 1.0, stage_two_counts, n, abs_delta=0.1 * n * count)