Example #1
  def testConditionallyEnqueueAndBatchTypes(self):
    tensor = tf.constant(1.0)
    keep_input = tf.constant(True)
    batch_size = 4

    # Check that output types are the same for 1 and 2-length input lists.
    output1 = sampling_ops._conditional_batch([tensor], keep_input, batch_size)  # pylint: disable=protected-access
    output2 = sampling_ops._conditional_batch(  # pylint: disable=protected-access
        [tensor, tensor], keep_input, batch_size)
    self.assertEqual(type(output1), type(output2))
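These snippets are methods lifted from a larger TensorFlow test class, so they are not runnable on their own. A minimal sketch of the scaffolding Examples #1 through #4 appear to assume (TF 1.x, with sampling_ops taken from contrib; the class name ConditionalBatchTest is hypothetical):

import tensorflow as tf
from tensorflow.contrib.training.python.training import sampling_ops


class ConditionalBatchTest(tf.test.TestCase):
  # The test methods shown in these examples would live here.
  pass


if __name__ == '__main__':
  tf.test.main()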
Example #2
  def testConditionallyEnqueueAndBatchTypes(self):
    tensor = tf.constant(1.0)
    accept_prob = tensor - 1
    batch_size = 4

    # Check that output types are the same for 1 and 2-length input lists.
    output1 = sampling_ops._conditional_batch([tensor], accept_prob, batch_size)  # pylint: disable=protected-access
    output2 = sampling_ops._conditional_batch(  # pylint: disable=protected-access
        [tensor, tensor], accept_prob, batch_size)
    self.assertEqual(type(output1), type(output2))
Example #3
  def testConditionallyEnqueueAndBatch(self):
    tf.set_random_seed(1234)
    tensor = tf.cond(
        tf.greater(.5, tf.random_uniform([])),
        lambda: tf.constant(1.0),
        lambda: tf.constant(2.0))
    keep_input = tf.equal(tensor, 2.0)
    batch_size = 4

    # Set up the test graph.
    [batch] = sampling_ops._conditional_batch([tensor], keep_input, batch_size)  # pylint: disable=protected-access

    # Check conditional operation.
    with self.test_session():
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)

      batch_np = batch.eval()

      coord.request_stop()
      coord.join(threads)

    # Check that every element in the batch comes from the value keep_input
    # accepts (2.0), and none from the rejected value (1.0).
    self.assertListEqual(list(batch_np), [2.0] * batch_size)
Example #4
  def testConditionallyEnqueueAndBatch(self):
    tf.set_random_seed(1234)
    tensor = tf.cond(
        tf.greater(.5, tf.random_uniform([])),
        lambda: tf.constant(1.0),
        lambda: tf.constant(2.0))
    accept_prob = tensor - 1
    batch_size = 4

    # Set up the test graph.
    [batch] = sampling_ops._conditional_batch([tensor], accept_prob, batch_size)  # pylint: disable=protected-access

    # Check conditional operation.
    with self.test_session():
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(coord=coord)

      batch_np = batch.eval()

      coord.request_stop()
      coord.join(threads)

    # Check that all elements in batch come from tensors with acceptance prob
    # 1, so that none come from acceptance prob 0.
    self.assertListEqual(list(batch_np), [2.0] * batch_size)
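Example #4 feeds the private helper an acceptance probability directly. In the same TF 1.x contrib module, the public entry point for this rejection-sampling pattern is tf.contrib.training.rejection_sample. A minimal usage sketch, assuming that API; the constant acceptance probability of 0.5 is only for illustration:

import tensorflow as tf

example = tf.constant(1.0)
# The acceptance function maps the list of tensors to a scalar probability.
accept_prob_fn = lambda tensors: tf.constant(0.5)
batch = tf.contrib.training.rejection_sample(
    [example], accept_prob_fn, batch_size=4)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run(batch))
  coord.request_stop()
  coord.join(threads)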
Example #5
  def testConditionallyEnqueueAndBatch(self):
    random_seed.set_random_seed(1234)
    tensor = control_flow_ops.cond(
        math_ops.greater(.5, random_ops.random_uniform([])),
        lambda: constant_op.constant(1.0), lambda: constant_op.constant(2.0))
    keep_input = math_ops.equal(tensor, 2.0)
    batch_size = 4

    # Set up the test graph.
    [batch] = sampling_ops._conditional_batch([tensor], keep_input, batch_size)  # pylint: disable=protected-access

    # Check conditional operation.
    with self.test_session():
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      batch_np = batch.eval()

      coord.request_stop()
      coord.join(threads)

    # Check that every element in the batch comes from the value keep_input
    # accepts (2.0), and none from the rejected value (1.0).
    self.assertListEqual(list(batch_np), [2.0] * batch_size)
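Example #5 is the same test written against TensorFlow's internal module layout instead of the public tf.* namespace. A sketch of the imports it appears to assume (module paths as laid out in the TF 1.x source tree):

from tensorflow.contrib.training.python.training import sampling_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl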