def testGraphConstructionFailures(self):
    """Checks that malformed `enqueue_many=True` inputs fail at graph build.

    Both calls to `sampling_ops.rejection_sample` below are expected to raise
    `ValueError` while the graph is being constructed; no session is run.
    """

    def always_accept(_):
        # Acceptance probability is the constant 1.0, whatever the data is.
        return constant_op.constant(1.0)

    num_per_batch = 32

    # With `enqueue_many=True` the data needs a batch dimension, so a scalar
    # tensor is expected to be rejected.
    scalar_data = [array_ops.zeros([])]
    with self.assertRaises(ValueError):
        sampling_ops.rejection_sample(
            scalar_data, always_accept, num_per_batch, enqueue_many=True)

    # With `enqueue_many=True` the batch dimensions must agree; here the two
    # tensors have leading dimensions 5 and 4.
    mismatched_data = [array_ops.zeros([5, 1]), array_ops.zeros([4, 1])]
    with self.assertRaises(ValueError):
        sampling_ops.rejection_sample(
            mismatched_data, always_accept, num_per_batch, enqueue_many=True)
  def testGraphConstructionFailures(self):
    """Verifies `rejection_sample` raises `ValueError` for bad batched input.

    Two `enqueue_many=True` cases are exercised: data without a batch
    dimension, and data whose batch dimensions differ. Each is expected to
    fail during graph construction, so no session is created.
    """

    def unit_accept_prob(_):
      # The input is ignored; the acceptance probability is always 1.0.
      return constant_op.constant(1.0)

    samples_per_batch = 32

    # Data must have batch dimension if `enqueue_many` is `True`.
    self.assertRaises(
        ValueError,
        sampling_ops.rejection_sample,
        [array_ops.zeros([])],
        unit_accept_prob,
        samples_per_batch,
        enqueue_many=True)

    # Batch dimensions should be equal if `enqueue_many` is `True`.
    self.assertRaises(
        ValueError,
        sampling_ops.rejection_sample,
        [array_ops.zeros([5, 1]), array_ops.zeros([4, 1])],
        unit_accept_prob,
        samples_per_batch,
        enqueue_many=True)
    def testRuntimeFailures(self):
        """Checks that the runtime probability checks reject -0.1 and 1.1.

        Builds `rejection_sample` with `runtime_checks=True` and an acceptance
        probability backed by a placeholder, then evaluates the
        'rejection_sample/prob_with_checks:0' tensor with each fed value and
        expects `InvalidArgumentError` every time.
        """
        fed_prob = array_ops.placeholder(dtypes.float32, [])

        def accept_prob_fn(_):
            # The probability is whatever value gets fed at run time.
            return fed_prob

        num_per_batch = 32

        # Build the graph. Only the ops it creates are needed, so the returned
        # sample is discarded.
        random_seed.set_random_seed(1234)
        sampling_ops.rejection_sample(
            [array_ops.zeros([])],
            accept_prob_fn,
            num_per_batch,
            runtime_checks=True,
            name='rejection_sample')

        # Fetch the checked probability tensor by its name in the graph.
        graph = ops.get_default_graph()
        checked_prob = graph.get_tensor_by_name(
            'rejection_sample/prob_with_checks:0')

        # Evaluating the tensor with either fed value should fail.
        with self.cached_session() as sess:
            for bad_value in (-0.1, 1.1):
                with self.assertRaises(errors_impl.InvalidArgumentError):
                    sess.run(checked_prob, feed_dict={fed_prob: bad_value})
  def testRuntimeFailures(self):
    """Checks that runtime checks reject the fed probabilities -0.1 and 1.1.

    The acceptance probability comes from a placeholder so that a different
    value can be fed on each run. With `runtime_checks=True`, evaluating the
    'rejection_sample/prob_with_checks:0' tensor is expected to raise
    `InvalidArgumentError` for each fed value.
    """
    prob_ph = array_ops.placeholder(dtypes.float32, [])
    accept_prob_fn = lambda _: prob_ph
    batch_size = 32

    # Set up graph. The return value is unused; the checked probability tensor
    # is fetched from the graph by name instead.
    random_seed.set_random_seed(1234)
    sampling_ops.rejection_sample(
        [array_ops.zeros([])],
        accept_prob_fn,
        batch_size,
        runtime_checks=True,
        name='rejection_sample')
    prob_tensor = ops.get_default_graph().get_tensor_by_name(
        'rejection_sample/prob_with_checks:0')

    # Run session that should fail. `cached_session` is used rather than the
    # deprecated `test_session`.
    with self.cached_session() as sess:
      for illegal_prob in [-0.1, 1.1]:
        with self.assertRaises(errors_impl.InvalidArgumentError):
          sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob})
    def testNormalBehavior(self):
        """Checks that every sampled batch contains only the value 2.0.

        The single input tensor randomly evaluates to 1.0 or 2.0, and the
        acceptance probability is that value minus 1.0 (0.0 for 1.0, 1.0 for
        2.0). Five batches are drawn and each must equal `[2.0] * 10`.
        """
        # Input tensor: 1.0 when 0.5 is greater than the random draw, else 2.0.
        random_draw = random_ops.random_uniform([])
        pick_one = math_ops.greater(.5, random_draw)
        one_or_two = control_flow_ops.cond(
            pick_one,
            lambda: constant_op.constant(1.0),
            lambda: constant_op.constant(2.0))
        tensor_list = [one_or_two]

        def accept_prob_fn(tensors):
            # Maps 1.0 -> 0.0 and 2.0 -> 1.0.
            return tensors[0] - 1.0

        num_per_batch = 10
        expected_batch = [2.0] * num_per_batch

        # Build the sampling op.
        sample = sampling_ops.rejection_sample(
            tensor_list, accept_prob_fn, num_per_batch)

        with self.cached_session() as sess:
            # Start the queue runners that the sampling op is fed by.
            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(coord=coord)

            for _ in range(5):
                fetched = sess.run(sample)
                sample_np = fetched[0]
                self.assertListEqual(expected_batch, list(sample_np))

            # Shut the queue runner threads down.
            coord.request_stop()
            coord.join(threads)
  def testNormalBehavior(self):
    """Checks that rejection sampling only returns the accepted value 2.0.

    The single input tensor randomly evaluates to 1.0 or 2.0. The acceptance
    probability is the value minus 1.0, i.e. 0.0 for 1.0 and 1.0 for 2.0, so
    each of the five sampled batches is expected to be `[2.0] * batch_size`.
    """
    # Yields 1.0 when 0.5 is greater than the random draw, otherwise 2.0.
    tensor_list = [
        control_flow_ops.cond(
            math_ops.greater(.5, random_ops.random_uniform([])),
            lambda: constant_op.constant(1.0),
            lambda: constant_op.constant(2.0))
    ]
    accept_prob_fn = lambda x: x[0] - 1.0
    batch_size = 10

    # Set up graph.
    sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn,
                                           batch_size)

    # `cached_session` is used rather than the deprecated `test_session`.
    with self.cached_session() as sess:
      # Start the queue runners before fetching samples.
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      for _ in range(5):
        sample_np = sess.run(sample)[0]
        self.assertListEqual([2.0] * batch_size, list(sample_np))

      # Stop the queue runner threads and wait for them to finish.
      coord.request_stop()
      coord.join(threads)