예제 #1
0
  def testDtypes(self, count=10):
    """Check that the resampling ops can be built with float64 weights.

    The test passes as long as constructing both ops does not raise; nothing
    is run in a session.
    """
    values = self.get_values(count)
    float64_weights = math_ops.cast(self.get_weights(count), dtypes.float64)

    # Neither of the calls below may raise.
    resample.resample_at_rate([values], float64_weights)
    # The overall rate is also cast to float64 so it matches the weights.
    unit_rate = math_ops.cast(1.0, dtypes.float64)
    resample.weighted_resample([values], float64_weights,
                               overall_rate=unit_rate)
예제 #2
0
  def testZeroRateUnknownShapes(self, count=10):
    """Tests that resampling runs with completely runtime shapes.

    Both the values and the rates are placeholders with no static shape, so
    every shape is only known when the graph is run. The feed uses all-zero
    rates and checks that the resampled output is empty.
    """
    # Use placeholders without a shape set:
    vals = array_ops.placeholder(dtype=dtypes.int32)
    rates = array_ops.placeholder(dtype=dtypes.float32)

    # One output per input tensor; a single input is passed here.
    resampled = resample.resample_at_rate([vals], rates)

    with self.test_session() as s:
      # Unpack the single output. Taking len() of the whole run result would
      # count output tensors (always 1), not resampled elements.
      rs, = s.run(resampled, {
          vals: list(range(count)),
          rates: numpy.zeros(shape=[count], dtype=numpy.float32),
      })
      self.assertEqual(0, len(rs))
예제 #3
0
    def testRoundtrip(self, rate=0.25, count=5, n=500):
        """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

        Two inputs are weighted-resampled together. The result is then
        resampled again at the inverse of the returned rates. Over `n` session
        runs, the element counts from the first pass are checked against the
        weights and `rate`. The counts from the second pass are checked against
        all-ones weights at rate 1.0 (approximately the identity).

        Args:
            rate: Overall rate passed to `weighted_resample`.
            count: Number of values and weights to generate.
            n: Number of session runs to accumulate counts over.
        """
        foo = self.get_values(count)
        bar = self.get_values(count)
        weights = self.get_weights(count)

        resampled_in, rates = resample.weighted_resample(
            [foo, bar], constant_op.constant(weights), rate, seed=123)

        # Resampling at 1/rates should approximately undo the first pass.
        resampled_back_out = resample.resample_at_rate(resampled_in,
                                                       1.0 / rates,
                                                       seed=456)

        init = control_flow_ops.group(variables.local_variables_initializer(),
                                      variables.global_variables_initializer())
        with self.cached_session() as s:
            s.run(init)  # initialize

            # Occurrence counts per value, accumulated across all runs.
            counts_resampled = collections.Counter()
            counts_reresampled = collections.Counter()
            for _ in range(n):
                resampled_vs, reresampled_vs = s.run(
                    [resampled_in, resampled_back_out])

                # `foo` and `bar` come from identical `get_values(count)`
                # calls, so equal outputs indicate both inputs were
                # resampled together.
                self.assertAllEqual(resampled_vs[0], resampled_vs[1])
                self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])

                counts_resampled.update(resampled_vs[0])
                counts_reresampled.update(reresampled_vs[0])

            # assert that resampling worked as expected
            self.assert_expected(weights, rate, counts_resampled, n)

            # and that re-resampling gives the approx identity.
            self.assert_expected([1.0 for _ in weights],
                                 1.0,
                                 counts_reresampled,
                                 n,
                                 abs_delta=0.1 * n * count)
예제 #4
0
  def testRoundtrip(self, rate=0.25, count=5, n=500):
    """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

    Two inputs are weighted-resampled together. The result is then resampled
    again at the inverse of the returned rates. Over `n` session runs, the
    element counts from the first pass are checked against the weights and
    `rate`. The counts from the second pass are checked against all-ones
    weights at rate 1.0 (approximately the identity).

    Args:
      rate: Overall rate passed to `weighted_resample`.
      count: Number of values and weights to generate.
      n: Number of session runs to accumulate counts over.
    """
    foo = self.get_values(count)
    bar = self.get_values(count)
    weights = self.get_weights(count)

    resampled_in, rates = resample.weighted_resample(
        [foo, bar], constant_op.constant(weights), rate, seed=123)

    # Resampling at 1/rates should approximately undo the first pass.
    resampled_back_out = resample.resample_at_rate(
        resampled_in, 1.0 / rates, seed=456)

    init = control_flow_ops.group(variables.local_variables_initializer(),
                                  variables.global_variables_initializer())
    with self.test_session() as s:
      s.run(init)  # initialize

      # Occurrence counts per value, accumulated across all runs.
      counts_resampled = collections.Counter()
      counts_reresampled = collections.Counter()
      for _ in range(n):
        resampled_vs, reresampled_vs = s.run([resampled_in, resampled_back_out])

        # `foo` and `bar` come from identical `get_values(count)` calls, so
        # equal outputs indicate both inputs were resampled together.
        self.assertAllEqual(resampled_vs[0], resampled_vs[1])
        self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])

        counts_resampled.update(resampled_vs[0])
        counts_reresampled.update(reresampled_vs[0])

      # assert that resampling worked as expected
      self.assert_expected(weights, rate, counts_resampled, n)

      # and that re-resampling gives the approx identity.
      self.assert_expected(
          [1.0 for _ in weights],
          1.0,
          counts_reresampled,
          n,
          abs_delta=0.1 * n * count)