def testRefSelect(self):
    """Checks the static shape inferred for `control_flow_ops.ref_select`.

    The cases below expect a concrete output shape only when every candidate
    input has the same fully known shape; in the other cases the output's
    rank is expected to be unknown.
    """
    index = tf.placeholder(tf.int32)

    # All inputs unknown.
    p1 = tf.placeholder(tf.float32_ref)
    p2 = tf.placeholder(tf.float32_ref)
    p3 = tf.placeholder(tf.float32_ref)
    s = control_flow_ops.ref_select(index, [p1, p2, p3])
    self.assertIsNone(s.get_shape().ndims)

    # All inputs known but different.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref, shape=[2, 1])
    s = control_flow_ops.ref_select(index, [p1, p2])
    self.assertIsNone(s.get_shape().ndims)

    # All inputs known and identical: the common shape is propagated.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    s = control_flow_ops.ref_select(index, [p1, p2])
    # Compare plain lists so a failure prints readable values.
    self.assertEqual([1, 2], s.get_shape().as_list())

    # Possibly the same but not guaranteed: [1, 2] vs. [None, 2] might match
    # at run time, but the expected static result is an unknown rank.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref)
    p2.set_shape([None, 2])
    s = control_flow_ops.ref_select(index, [p1, p2])
    # Check the rank explicitly instead of comparing the TensorShape to None,
    # which only works because TensorShape.__eq__ converts None to an
    # unknown shape.
    self.assertIsNone(s.get_shape().ndims)
  def testRefSelect(self):
    """Checks the static shape inferred for `control_flow_ops.ref_select`.

    The cases below expect a concrete output shape only when every candidate
    input has the same fully known shape; in the other cases the output's
    rank is expected to be unknown.
    """
    index = tf.placeholder(tf.int32)

    # All inputs unknown.
    p1 = tf.placeholder(tf.float32_ref)
    p2 = tf.placeholder(tf.float32_ref)
    p3 = tf.placeholder(tf.float32_ref)
    s = control_flow_ops.ref_select(index, [p1, p2, p3])
    self.assertIsNone(s.get_shape().ndims)

    # All inputs known but different.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref, shape=[2, 1])
    s = control_flow_ops.ref_select(index, [p1, p2])
    self.assertIsNone(s.get_shape().ndims)

    # All inputs known and identical: the common shape is propagated.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    s = control_flow_ops.ref_select(index, [p1, p2])
    # Compare plain lists so a failure prints readable values.
    self.assertEqual([1, 2], s.get_shape().as_list())

    # Possibly the same but not guaranteed: [1, 2] vs. [None, 2] might match
    # at run time, but the expected static result is an unknown rank.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref)
    p2.set_shape([None, 2])
    s = control_flow_ops.ref_select(index, [p1, p2])
    # Check the rank explicitly instead of comparing the TensorShape to None,
    # which only works because TensorShape.__eq__ converts None to an
    # unknown shape.
    self.assertIsNone(s.get_shape().ndims)
Example #3
0
    def from_list(index, queues):
        """Create a queue using the queue reference from `queues[index]`.

        Args:
          index: An integer scalar tensor that determines the input that gets
            selected.
          queues: A list of `QueueBase` objects.

        Returns:
          A `QueueBase` object.

        Raises:
          TypeError: When `queues` is not a non-empty list of `QueueBase`
            objects, or when the data types of `queues` are not all the same.
        """
        # Check the type before truthiness: evaluating `not queues` on an
        # array-like argument (e.g. a numpy array) raises ValueError rather
        # than the documented TypeError.
        if (not isinstance(queues, list) or not queues
                or not all(isinstance(x, QueueBase) for x in queues)):
            raise TypeError("A list of queues expected")

        dtypes = queues[0].dtypes
        if not all(dtypes == q.dtypes for q in queues[1:]):
            raise TypeError("Queues do not have matching component dtypes.")

        queue_refs = [x.queue_ref for x in queues]
        # Select the reference of `queues[index]` among all candidates.
        selected_queue = control_flow_ops.ref_select(index, queue_refs)
        # TODO(josh11b): Unify the shapes of the queues too?
        # Until then, the returned queue has no static component shapes.
        return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue)
  def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: When `queues` is not a non-empty list of `QueueBase`
        objects, or when the data types of `queues` are not all the same.
    """
    # Check the type before truthiness: evaluating `not queues` on an
    # array-like argument (e.g. a numpy array) raises ValueError rather than
    # the documented TypeError.
    if (not isinstance(queues, list) or
        not queues or
        not all(isinstance(x, QueueBase) for x in queues)):
      raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    if not all(dtypes == q.dtypes for q in queues[1:]):
      raise TypeError("Queues do not have matching component dtypes.")

    queue_refs = [x.queue_ref for x in queues]
    # Select the reference of `queues[index]` among all candidates.
    selected_queue = control_flow_ops.ref_select(index, queue_refs)
    # TODO(josh11b): Unify the shapes of the queues too?
    # Until then, the returned queue has no static component shapes.
    return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue)