def testGraphConstructionFailures(self):
  """Invalid static shapes must raise ValueError while building the graph.

  Both cases use `enqueue_many=True`:
    * a scalar input, which has no leading batch dimension;
    * two inputs whose leading dimensions disagree (5 vs. 4).
  """
  batch_size = 32

  def accept_prob_fn(_):
    return constant_op.constant(1.0)

  # A scalar has no batch dimension, so it cannot be enqueued "many".
  with self.assertRaises(ValueError):
    scalar_inputs = [array_ops.zeros([])]
    sampling_ops.rejection_sample(
        scalar_inputs, accept_prob_fn, batch_size, enqueue_many=True)

  # Leading (batch) dimensions differ between the two inputs.
  with self.assertRaises(ValueError):
    mismatched_inputs = [array_ops.zeros([5, 1]), array_ops.zeros([4, 1])]
    sampling_ops.rejection_sample(
        mismatched_inputs, accept_prob_fn, batch_size, enqueue_many=True)
def testRuntimeFailures(self):
  """Out-of-range acceptance probabilities must fail when evaluated.

  The graph is built with `runtime_checks=True`. Evaluating the checked
  probability tensor with a value of -0.1 or 1.1 must raise
  InvalidArgumentError.
  """
  batch_size = 32
  prob_ph = array_ops.placeholder(dtypes.float32, [])

  def accept_prob_fn(_):
    return prob_ph

  # Fix the graph-level seed before constructing the sampler.
  random_seed.set_random_seed(1234)
  sampling_ops.rejection_sample(
      [array_ops.zeros([])],
      accept_prob_fn,
      batch_size,
      runtime_checks=True,
      name='rejection_sample')
  graph = ops.get_default_graph()
  checked_prob = graph.get_tensor_by_name(
      'rejection_sample/prob_with_checks:0')

  # Each value lies outside [0, 1] and must trip the runtime check.
  with self.cached_session() as sess:
    for bad_prob in (-0.1, 1.1):
      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(checked_prob, feed_dict={prob_ph: bad_prob})
def testRuntimeFailures(self):
  """Out-of-range acceptance probabilities must fail when evaluated.

  The graph is built with `runtime_checks=True`. Feeding -0.1 or 1.1 as the
  acceptance probability and evaluating the checked probability tensor must
  raise InvalidArgumentError.
  """
  prob_ph = array_ops.placeholder(dtypes.float32, [])
  accept_prob_fn = lambda _: prob_ph
  batch_size = 32

  # Set up graph.
  random_seed.set_random_seed(1234)
  sampling_ops.rejection_sample(
      [array_ops.zeros([])],
      accept_prob_fn,
      batch_size,
      runtime_checks=True,
      name='rejection_sample')
  prob_tensor = ops.get_default_graph().get_tensor_by_name(
      'rejection_sample/prob_with_checks:0')

  # Run session that should fail. `cached_session` replaces the deprecated
  # `test_session`.
  with self.cached_session() as sess:
    for illegal_prob in [-0.1, 1.1]:
      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob})
def testNormalBehavior(self):
  """Only inputs with acceptance probability 1.0 should appear in batches.

  The input is 1.0 or 2.0 depending on a uniform random draw. The acceptance
  function maps these to probabilities 0.0 and 1.0, so every sampled batch is
  expected to contain only 2.0.
  """
  tensor_list = [
      control_flow_ops.cond(
          math_ops.greater(.5, random_ops.random_uniform([])),
          lambda: constant_op.constant(1.0),
          lambda: constant_op.constant(2.0))
  ]
  # 1.0 -> 0.0 and 2.0 -> 1.0.
  accept_prob_fn = lambda x: x[0] - 1.0
  batch_size = 10

  # Set up graph.
  sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn,
                                         batch_size)

  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      for _ in range(5):
        sample_np = sess.run(sample)[0]
        self.assertListEqual([2.0] * batch_size, list(sample_np))
    finally:
      # Stop and join the queue-runner threads even if an assertion fails,
      # so they do not keep running after the test ends.
      coord.request_stop()
      coord.join(threads)
def testNormalBehavior(self):
  """Only inputs with acceptance probability 1.0 should appear in batches.

  The input is 1.0 or 2.0 depending on a uniform random draw. The acceptance
  function maps these to probabilities 0.0 and 1.0, so every sampled batch is
  expected to contain only 2.0.
  """
  tensor_list = [
      control_flow_ops.cond(
          math_ops.greater(.5, random_ops.random_uniform([])),
          lambda: constant_op.constant(1.0),
          lambda: constant_op.constant(2.0))
  ]
  # 1.0 -> 0.0 and 2.0 -> 1.0.
  accept_prob_fn = lambda x: x[0] - 1.0
  batch_size = 10

  # Set up graph.
  sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn,
                                         batch_size)

  # `cached_session` replaces the deprecated `test_session`.
  with self.cached_session() as sess:
    coord = coordinator.Coordinator()
    threads = queue_runner_impl.start_queue_runners(coord=coord)
    try:
      for _ in range(5):
        sample_np = sess.run(sample)[0]
        self.assertListEqual([2.0] * batch_size, list(sample_np))
    finally:
      # Stop and join the queue-runner threads even if an assertion fails,
      # so they do not keep running after the test ends.
      coord.request_stop()
      coord.join(threads)