def testRefSelect(self):
    """Checks the static output shape inferred for `ref_select`.

    The expectations below encode that the output shape is reported only when
    every candidate ref has the same, fully-defined shape; in every other case
    the output rank is unknown.
    """
    index = tf.placeholder(tf.int32)

    # All inputs unknown.
    p1 = tf.placeholder(tf.float32_ref)
    p2 = tf.placeholder(tf.float32_ref)
    p3 = tf.placeholder(tf.float32_ref)
    s = control_flow_ops.ref_select(index, [p1, p2, p3])
    self.assertIsNone(s.get_shape().ndims)

    # All inputs known but different.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref, shape=[2, 1])
    s = control_flow_ops.ref_select(index, [p1, p2])
    self.assertIsNone(s.get_shape().ndims)

    # All inputs known but same.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    s = control_flow_ops.ref_select(index, [p1, p2])
    # Compare plain lists rather than a list against a TensorShape object.
    self.assertEqual([1, 2], s.get_shape().as_list())

    # Possibly the same but not guaranteed.
    p1 = tf.placeholder(tf.float32_ref, shape=[1, 2])
    p2 = tf.placeholder(tf.float32_ref)
    p2.set_shape([None, 2])
    s = control_flow_ops.ref_select(index, [p1, p2])
    # Check the rank directly (like the other unknown-shape cases) instead of
    # comparing a TensorShape against None, which relies on TensorShape.__eq__
    # coercing None into an unknown shape.
    self.assertIsNone(s.get_shape().ndims)
def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: When `queues` is not a list of `QueueBase` objects,
        or when the data types of `queues` are not all the same.
    """
    # Check the type before truthiness: some objects (e.g. numpy arrays) raise
    # ValueError from bool(), which would escape instead of the documented
    # TypeError.
    if (not isinstance(queues, list) or not queues or
            not all(isinstance(x, QueueBase) for x in queues)):
        raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    # The returned queue is built with the first queue's dtypes, so every
    # candidate must agree on them.
    if not all(dtypes == q.dtypes for q in queues[1:]):
        raise TypeError("Queues do not have matching component dtypes.")

    queue_refs = [x.queue_ref for x in queues]
    selected_queue = control_flow_ops.ref_select(index, queue_refs)
    # TODO(josh11b): Unify the shapes of the queues too?
    # Until then, shapes are passed as None because they are not reconciled
    # across the candidate queues here.
    return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue)