def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config):
  """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Args:
    inputs_holder: An `_InputsHolder` holding features and labels.
    run_config: A `RunConfig` instance.

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
  if inputs_holder.sharded:
    sharded_inputs = inputs_holder.as_sharded_flattened_inputs()

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)
  else:
    unsharded_inputs = inputs_holder.as_flattened_inputs()
    infeed_queue = tpu_feed.InfeedQueue(
        tuple_types=[t.dtype for t in unsharded_inputs],
        tuple_shapes=[t.shape for t in unsharded_inputs])
    infeed_queue.set_number_of_shards(inputs_holder.num_shards)

  def dequeue_fn():
    """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
    values = infeed_queue.generate_dequeue_op()
    return inputs_holder.unflatten_features_and_labels(values)

  def tpu_ordinal_function(index):
    """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
    return index % 8

  def enqueue_fn():
    """enqueue_fn is used to add ops to the graph to send tensors."""
    if inputs_holder.sharded:
      return infeed_queue.generate_enqueue_ops(
          sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)
    else:
      job = _tpu_job(run_config)

      def placement_function(index):
        if job is None:
          return '/replica:0/task:0/device:CPU:0'
        else:
          return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)

      return infeed_queue.split_inputs_and_generate_enqueue_ops(
          unsharded_inputs, placement_function=placement_function)

  return (dequeue_fn, enqueue_fn)

def testUsingInfeedQueueWithRegularizer(self):
  """Test that Layer regularizers can reference data created in loops."""

  def make_regularizer(scale):
    return lambda inputs: scale * math_ops.reduce_sum(math_ops.square(inputs))

  def training_step(inputs, scale):
    outputs = convolutional.conv2d(
        inputs,
        filters=16,
        kernel_size=(3, 3),
        data_format="channels_first",
        kernel_regularizer=make_regularizer(scale))
    loss = math_ops.reduce_mean(math_ops.square(outputs))
    return loss.op

  inputs = array_ops.zeros(shape=(128, 32, 32, 16))
  scale = array_ops.ones(shape=())
  infeed = tpu_feed.InfeedQueue(
      tuple_types=[dtypes.float32, dtypes.float32],
      tuple_shapes=[inputs.shape, scale.shape])

  def loop():
    return training_loop.repeat(5, training_step, infeed_queue=infeed)

  # This should not throw an error.
  tpu.rewrite(loop)

def testVarArgsAndDefaults(self):
  """Tests that arg checker works for a function with varargs and defaults."""

  def func(x, y, z=17, *q):  # pylint: disable=keyword-arg-before-vararg
    return x + y + z + len(q)

  self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 1, None))
  queue = tpu_feed.InfeedQueue(1)
  self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 0, queue))

def testVarArgsAndDefaults(self):
  """Tests that arg checker works for a function with varargs and defaults."""

  def func(x, y, z=17, *q):
    return x + y + z + len(q)

  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 2, None))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 3, None))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 4, None))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 5, None))
  self.assertEqual(
      "at least 2 arguments",
      tpu_function.check_function_argument_count(func, 1, None))
  queue = tpu_feed.InfeedQueue(1)
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 1, queue))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 2, queue))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 3, queue))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 4, queue))
  self.assertEqual(
      "at least 2 arguments",
      tpu_function.check_function_argument_count(func, 0, queue))

def testDefaultArgs(self):
  """Tests that arg checker works for a function with no varargs."""

  def func(x, y, z=17):
    return x + y + z

  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 3, None))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 2, None))
  self.assertEqual(
      "at least 2 arguments",
      tpu_function.check_function_argument_count(func, 1, None))
  self.assertEqual(
      "at most 3 arguments",
      tpu_function.check_function_argument_count(func, 4, None))
  queue = tpu_feed.InfeedQueue(1)
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 2, queue))
  self.assertEqual(
      None, tpu_function.check_function_argument_count(func, 1, queue))
  self.assertEqual(
      "at least 2 arguments",
      tpu_function.check_function_argument_count(func, 0, queue))
  self.assertEqual(
      "at most 3 arguments",
      tpu_function.check_function_argument_count(func, 4, queue))

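# A hedged illustration (not part of the test file above) of the counting rule
# these tests exercise: when an InfeedQueue is passed, its
# number_of_tuple_elements counts toward the arguments delivered to the
# function, so fewer explicit inputs are required.  `_example_func` and
# `_example_queue` are illustrative names only; the function mirrors the
# func(x, y, z=17) used in testDefaultArgs.
def _example_func(x, y, z=17):
  return x + y + z

_example_queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
# _example_func accepts between 2 and 3 arguments; the queue already supplies
# 2 of them, so 0 or 1 explicit inputs pass the check while 2 do not.
assert tpu_function.check_function_argument_count(
    _example_func, 0, _example_queue) is None
assert tpu_function.check_function_argument_count(
    _example_func, 1, _example_queue) is None
assert tpu_function.check_function_argument_count(
    _example_func, 2, _example_queue) == "at most 3 arguments"
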
def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
  """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Mainly, three things need to be done here.
  1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU
  2. Create a dequeue_fn used by the train_step inside TPU execution to
     dequeue the tensors.
  3. Sets up the input thread to infeed.

  Args:
    run_config: run_config
    features: features
    labels: labels

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
  infeed_names = None
  infeed_tuple = []
  if isinstance(features, dict):
    # We need a fixed ordering for enqueueing and dequeueing.
    infeed_names = [name for name in features]
    infeed_tuple.extend([features[name] for name in infeed_names])
  else:
    infeed_tuple.append(features)
  # TODO(jhseu): Handle multi-head and None labels
  infeed_tuple.append(labels)
  # TODO(jhseu): Update when b/36470756 is settled.
  infeed_queue = tpu_feed.InfeedQueue(
      tuple_types=[t.dtype for t in infeed_tuple],
      tuple_shapes=[t.shape for t in infeed_tuple])
  infeed_queue.set_number_of_shards(run_config.tpu_config.num_shards)

  def dequeue_fn():
    """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
    values = infeed_queue.generate_dequeue_op()
    if infeed_names is None:
      return values

    # Restore the feature dictionary and label.
    dequeued_features = {}
    for i in range(len(values) - 1):
      dequeued_features[infeed_names[i]] = values[i]
    label = values[-1]
    return dequeued_features, label

  def enqueue_fn():
    """enqueue_fn is used to add ops to the graph to send tensors."""
    job = _tpu_job(run_config)

    def placement_function(index):
      if job is None:
        return '/replica:0/task:0/device:CPU:0'
      else:
        return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)

    return infeed_queue.split_inputs_and_generate_enqueue_ops(
        infeed_tuple, placement_function=placement_function)

  return (dequeue_fn, enqueue_fn)

def testModification(self):
  """Tests modification of the queue post-construction."""
  i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
  i.set_tuple_types([dtypes.float32, dtypes.int32])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  i.set_tuple_types([dtypes.float32, dtypes.float32])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
  with self.assertRaises(ValueError):
    i.set_tuple_types([dtypes.float32])
  i.set_tuple_shapes([[1], [2, 3]])
  self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
  i.set_tuple_shapes([[1, 2], [3, 4]])
  self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
  with self.assertRaises(ValueError):
    i.set_tuple_shapes([[1, 2]])
  i.set_number_of_shards(2)
  self.assertEqual(i.number_of_shards, 2)
  i.set_number_of_shards(3)
  self.assertEqual(i.number_of_shards, 3)
  t1 = constant_op.constant(1, dtypes.int32, shape=[6])
  t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
  i.set_configuration_from_input_tensors([t1, t2])
  self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
  self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])
  i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
  self.assertEqual(i.number_of_shards, 2)
  self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  i.set_shard_dimensions([1, 0])
  i.set_number_of_shards(3)
  with self.assertRaises(ValueError):
    i.set_number_of_shards(4)

def testFreezing(self):
  """Tests freezing the queue."""
  i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
  t1 = constant_op.constant(1, dtypes.int32, shape=[2])
  t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
  i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
  self.assertEqual(i.number_of_shards, 2)
  self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  self.assertEqual(i.shard_dimensions, [0, 0])
  i.freeze()
  i.set_number_of_shards(2)
  i.set_tuple_shapes([[4, 4], [4]])
  i.set_tuple_types([dtypes.float32, dtypes.int32])
  i.set_shard_dimensions([0, 0])
  with self.assertRaises(ValueError):
    i.set_number_of_shards(1)
  with self.assertRaises(ValueError):
    i.set_tuple_shapes([[8, 8], [8]])
  with self.assertRaises(ValueError):
    i.set_tuple_types([dtypes.int32, dtypes.float32])
  with self.assertRaises(ValueError):
    i.set_shard_dimensions([1, 0])
  self.assertEqual(i.number_of_shards, 2)
  self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  self.assertEqual(i.shard_dimensions, [0, 0])

def testSimple(self):
  """Tests that arg checker works for functions with no varargs or defaults."""

  def func(x, y, z):
    return x + y + z

  self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
  self.assertEqual('exactly 3 arguments',
                   xla.check_function_argument_count(func, 2, None))
  queue = tpu_feed.InfeedQueue(2)
  self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
  self.assertEqual('exactly 3 arguments',
                   xla.check_function_argument_count(func, 2, queue))

def testConstructor(self):
  """Tests that the constructor can be called with different arguments."""
  i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
  self.assertEqual(i.number_of_tuple_elements, 2)
  self.assertEqual(i.tuple_types, None)
  self.assertEqual(i.tuple_shapes, None)
  self.assertEqual(i.number_of_shards, None)
  i = tpu_feed.InfeedQueue(
      tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
  self.assertEqual(i.number_of_tuple_elements, 3)
  self.assertEqual(i.tuple_types,
                   [dtypes.float32, dtypes.int32, dtypes.int32])
  self.assertEqual(i.tuple_shapes, None)
  self.assertEqual(i.number_of_shards, None)
  i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
  self.assertEqual(i.number_of_tuple_elements, 2)
  self.assertEqual(i.tuple_types, None)
  self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
  self.assertEqual(i.number_of_shards, None)
  i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
  self.assertEqual(i.number_of_tuple_elements, 3)
  self.assertEqual(i.tuple_types, None)
  self.assertEqual(i.tuple_shapes, None)
  self.assertEqual([p.shard_dimension for p in i.sharding_policies],
                   [1, 0, 7])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue()
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(
        number_of_tuple_elements=2, tuple_types=[dtypes.float32])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])

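# A minimal end-to-end sketch (an assumption, not part of the test file above)
# of how a configured InfeedQueue is used: the host splits and enqueues the
# input tuple, and the TPU computation dequeues one shard's worth of tensors
# in the same tuple order.  The shapes, the two-way sharding, and the TF 1.x
# contrib import paths are illustrative assumptions.
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes

images = constant_op.constant(0.0, dtypes.float32, shape=[8, 224, 224, 3])
labels = constant_op.constant(0, dtypes.int32, shape=[8])

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
queue.set_configuration_from_input_tensors([images, labels])
queue.set_number_of_shards(2)

# Host side: split each tensor across the 2 shards and build the enqueue ops.
enqueue_ops = queue.split_inputs_and_generate_enqueue_ops([images, labels])

def computation():
  # TPU side (inside tpu.rewrite/tpu.shard): dequeue one shard's tuple.
  sharded_images, sharded_labels = queue.generate_dequeue_op()
  return sharded_images, sharded_labels
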
def testVarArgs(self):
  """Tests that arg checker works for a function with varargs."""

  def func(x, y, *z):
    return x + y + len(z)

  self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 1, None))
  queue = tpu_feed.InfeedQueue(1)
  self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 0, queue))

def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder):
  """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Args:
    inputs_holder: An `_InputsHolder` holding features and labels.

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
  sharded_inputs = inputs_holder.as_sharded_flattened_inputs()

  infeed_queue = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(sharded_inputs[0]))
  infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)

  def dequeue_fn():
    """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
    values = infeed_queue.generate_dequeue_op()
    return inputs_holder.unflatten_features_and_labels(values)

  def tpu_ordinal_function(index):
    """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
    return index % 8

  def enqueue_fn():
    """enqueue_fn is used to add ops to the graph to send tensors."""
    return infeed_queue.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)

  return (dequeue_fn, enqueue_fn)

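# A rough usage sketch (assumptions, not part of the module above) of how the
# returned pair is typically consumed: enqueue_fn() builds host-side infeed
# ops that a dedicated input thread runs, while the TPU computation calls
# dequeue_fn() to obtain its (features, labels).  `inputs_holder`, `model_fn`,
# and `NUM_SHARDS` are hypothetical stand-ins; the contrib import path assumes
# the TF 1.x layout used elsewhere in this code.
from tensorflow.contrib.tpu.python.tpu import tpu

dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn(
    inputs_holder)
enqueue_ops = enqueue_fn()  # placed on host CPU; run by the infeed thread

def train_step():
  features, labels = dequeue_fn()  # executes inside the replicated TPU graph
  return model_fn(features, labels).loss

outputs = tpu.shard(
    train_step, num_shards=NUM_SHARDS, outputs_from_all_shards=False)
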
def eval_enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.replicas_per_worker):
    with tf.control_dependencies(control_deps):
      features = iterator.get_next()
    if self.use_spatial_partition:
      self.eval_input_dims_flattener.validate_and_flatten_input_dims(
          features, None)
    self.eval_feature_structure["features"] = features
    flattened_inputs = data_nest.flatten(self.eval_feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  if self.use_spatial_partition:
    flattened_input_dims = (
        self.eval_input_dims_flattener.flattened_input_dims)
    # pylint: disable=protected-access
    infeed = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    self.eval_infeed_queue.append(infeed)
    return infeed.generate_enqueue_ops(per_host_sharded_inputs)

  infeed = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.eval_infeed_queue.append(infeed)
  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)

def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
  """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Mainly, three things need to be done here.
  1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU
  2. Create a dequeue_fn used by the train_step inside TPU execution to
     dequeue the tensors.
  3. Sets up the input thread to infeed.

  Args:
    run_config: run_config
    features: features
    labels: labels

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
  infeed_names = None
  sharded_inputs = []
  if isinstance(features[0], dict):
    # We need a fixed ordering for enqueueing and dequeueing.
    infeed_names = [name for name in features[0]]

  for shard in range(run_config.tpu_config.num_shards):
    inputs = []
    if infeed_names is None:
      inputs.append(features[shard])
    else:
      for name in infeed_names:
        inputs.append(features[shard][name])
    if labels is not None:
      inputs.append(labels[shard])
    sharded_inputs.append(inputs)

  infeed_queue = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(sharded_inputs[0]))
  infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)

  def dequeue_fn():
    """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
    values = infeed_queue.generate_dequeue_op()

    expected_num_tensors = 0
    if labels is not None:
      expected_num_tensors += 1
    if infeed_names is None:
      expected_num_tensors += 1
    else:
      expected_num_tensors += len(infeed_names)
    assert len(values) == expected_num_tensors

    dequeue_label = None
    if labels is not None:
      dequeue_label = values[-1]
    if infeed_names is None:
      return values[0], dequeue_label

    # Restore the feature dictionary and label.
    dequeued_features = {}
    for i in range(len(infeed_names)):
      dequeued_features[infeed_names[i]] = values[i]
    return dequeued_features, dequeue_label

  def tpu_ordinal_function(index):
    """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
    return index % 8

  def enqueue_fn():
    """enqueue_fn is used to add ops to the graph to send tensors."""
    return infeed_queue.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)

  return (dequeue_fn, enqueue_fn)

def enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.replicas_per_worker):
    with tf.control_dependencies(control_deps):
      features, labels = iterator.get_next()
    if self.use_spatial_partition:
      num_elements = []
      for i, d in enumerate(ssd_constants.FEATURE_SIZES):
        num_elements.append(d * d * ssd_constants.NUM_DEFAULTS[i])
      gt_boxes = tf.split(labels[ssd_constants.BOXES], num_elements, 1)
      gt_classes = tf.split(labels[ssd_constants.CLASSES], num_elements, 1)

      def transpose_gt_box(gt_box, i):
        return tf.transpose(
            tf.reshape(gt_box, [
                -1, ssd_constants.NUM_DEFAULTS[i],
                ssd_constants.FEATURE_SIZES[i],
                ssd_constants.FEATURE_SIZES[i], 4
            ]), [0, 2, 3, 1, 4])

      def transpose_gt_class(gt_class, i):
        return tf.transpose(
            tf.reshape(gt_class, [
                -1, ssd_constants.NUM_DEFAULTS[i],
                ssd_constants.FEATURE_SIZES[i],
                ssd_constants.FEATURE_SIZES[i]
            ]), [0, 2, 3, 1])

      # TODO(dehao): This causes 3s overhead in startup, fix it.
      labels[ssd_constants.BOXES] = {
          i: transpose_gt_box(gt_boxes[i], i)
          for i in range(len(ssd_constants.NUM_DEFAULTS))
      }
      labels[ssd_constants.CLASSES] = {
          i: transpose_gt_class(gt_classes[i], i)
          for i in range(len(ssd_constants.NUM_DEFAULTS))
      }
    self.feature_structure["features"] = features
    self.feature_structure["labels"] = labels
    flattened_inputs = data_nest.flatten(self.feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  if self.use_spatial_partition:
    flattened_input_dims = []
    for i in per_host_sharded_inputs[0]:
      if i.shape.ndims >= len(self.input_partition_dims[0]):
        if i.shape.as_list() == self.feature_structure[
            "features"].shape.as_list():
          flattened_input_dims.append(self.input_partition_dims[0])
        else:
          flattened_input_dims.append(
              FLAGS.input_partition_dims + [1] *
              (i.shape.ndims - len(self.input_partition_dims[0])))
      else:
        flattened_input_dims.append([1] * i.shape.ndims)
    # pylint: disable=protected-access
    infeed = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    self.infeed_queue.append(infeed)
    return infeed.generate_enqueue_ops(per_host_sharded_inputs)

  infeed = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.infeed_queue.append(infeed)
  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)