Example #1
    def test_var_args_and_defaults(self):
        """Tests that arg checker works for a function with varargs and defaults."""
        def func(x, y, z=17, *q):  # pylint: disable=keyword-arg-before-vararg
            return x + y + z + len(q)

        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 5, None))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 1, None))
        queue = tpu_feed.InfeedQueue(1)
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 1, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, queue))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 0, queue))
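The checks above indicate that an InfeedQueue passed as the third argument contributes its tuple elements to the supplied argument count (Example #6 below shows the same behaviour for a fixed-arity function). A minimal sketch, assuming the same xla and tpu_feed modules used by these tests:

def needs_three(a, b, c):
    return a + b + c

queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
# 1 explicit argument plus 2 queue tuple elements satisfies the 3 required args.
assert xla.check_function_argument_count(needs_three, 1, queue) is None
# 2 explicit arguments plus 2 queue tuple elements is one too many.
assert xla.check_function_argument_count(needs_three, 2, queue) == 'exactly 3 arguments'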
Example #2
 def testModification(self):
     """Tests modification of the queue post-construction."""
     i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
     i.set_tuple_types([dtypes.float32, dtypes.int32])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     i.set_tuple_types([dtypes.float32, dtypes.float32])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
     with self.assertRaises(ValueError):
         i.set_tuple_types([dtypes.float32])
     i.set_tuple_shapes([[1], [2, 3]])
     self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
     i.set_tuple_shapes([[1, 2], [3, 4]])
     self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
     with self.assertRaises(ValueError):
         i.set_tuple_shapes([[1, 2]])
     i.set_number_of_shards(2)
     self.assertEqual(i.number_of_shards, 2)
     i.set_number_of_shards(3)
     self.assertEqual(i.number_of_shards, 3)
     t1 = constant_op.constant(1, dtypes.int32, shape=[6])
     t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
     i.set_configuration_from_input_tensors([t1, t2])
     self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
     self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])
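     # With two shards and the default shard dimension 0, the per-shard shapes
     # are concatenated along axis 0: two [3, 18] shards of t2 become [6, 18]
     # and two [6] shards of t1 become [12].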
     i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
     self.assertEqual(i.number_of_shards, 2)
     self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     i.set_shard_dimensions([1, 0])
     i.set_number_of_shards(3)
     with self.assertRaises(ValueError):
         i.set_number_of_shards(4)
Example #3
    def testUsingInfeedQueueWithRegularizer(self):
        """Test that Layer regularizers can reference data created in loops."""

        with ops.Graph().as_default():

            def make_regularizer(scale):
                def regularizer(inputs):
                    return scale * math_ops.reduce_sum(math_ops.square(inputs))

                return regularizer

            def training_step(inputs, scale):
                outputs = convolutional.conv2d(
                    inputs,
                    filters=16,
                    kernel_size=(3, 3),
                    data_format="channels_first",
                    kernel_regularizer=make_regularizer(scale))
                loss = math_ops.reduce_mean(math_ops.square(outputs))
                return loss.op

            inputs = array_ops.zeros(shape=(128, 32, 32, 16))
            scale = array_ops.ones(shape=())
            infeed = tpu_feed.InfeedQueue(
                tuple_types=[dtypes.float32, dtypes.float32],
                tuple_shapes=[inputs.shape, scale.shape])

            def loop():
                return training_loop.repeat(5,
                                            training_step,
                                            infeed_queue=infeed)

            # This should not throw an error.
            tpu.rewrite(loop)
Example #4
 def testFreezing(self):
     """Tests freezing the queue."""
     i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
     t1 = constant_op.constant(1, dtypes.int32, shape=[2])
     t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
     i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
     self.assertEqual(i.number_of_shards, 2)
     self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     self.assertEqual(i.shard_dimensions, [0, 0])
     i.freeze()
     i.set_number_of_shards(2)
     i.set_tuple_shapes([[4, 4], [4]])
     i.set_tuple_types([dtypes.float32, dtypes.int32])
     i.set_shard_dimensions([0, 0])
     with self.assertRaises(ValueError):
         i.set_number_of_shards(1)
     with self.assertRaises(ValueError):
         i.set_tuple_shapes([[8, 8], [8]])
     with self.assertRaises(ValueError):
         i.set_tuple_types([dtypes.int32, dtypes.float32])
     with self.assertRaises(ValueError):
         i.set_shard_dimensions([1, 0])
     self.assertEqual(i.number_of_shards, 2)
     self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     self.assertEqual(i.shard_dimensions, [0, 0])
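A condensed sketch of the freeze semantics exercised above, assuming the same tpu_feed and dtypes imports: once freeze() is called, re-asserting the current configuration is accepted, while any change raises ValueError.

q = tpu_feed.InfeedQueue(number_of_tuple_elements=1)
q.set_tuple_types([dtypes.float32])
q.set_tuple_shapes([[8]])
q.set_number_of_shards(2)
q.freeze()
q.set_number_of_shards(2)      # same value as before freezing: allowed
try:
    q.set_number_of_shards(4)  # different value: rejected once frozen
except ValueError:
    pass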
Example #5
      def enqueue_ops_fn(idx):
        """Enqueue ops function for one host."""
        with tf.device(device):
          sharded_inputs = []
          start_idx = 0
          if host_id in range(0, self.hparams.num_infeed_workers * 2, 2):
            core_id = tf.constant(
                host_id * self.hparams.num_shards_per_host,
                shape=[1],
                dtype=tf.int32)
            if self.hparams.use_synthetic_data:
              features = output
            else:

              def true_fn():
                return iterator.get_next()

              def false_fn():
                return {
                    k: tf.zeros_like(self.feature_structure["features"][k])
                    for k in self.feature_structure["features"]
                }

              features = tf.cond(
                  tf.equal(idx % self.hparams.num_infeed_workers, host_id // 2),
                  true_fn, false_fn)
            sharded_inputs.append(
                data_nest.flatten({
                    "features": features,
                    "core_id": core_id
                }))
            start_idx = 1
          for i in range(start_idx, self.hparams.num_shards_per_host):
            sharded_inputs.append(
                data_nest.flatten({
                    "features": {
                        k: tf.zeros_like(self.feature_structure["features"][k])
                        for k in self.feature_structure["features"]
                    },
                    "core_id":
                        tf.constant(
                            host_id * self.hparams.num_shards_per_host + i,
                            shape=[1],
                            dtype=tf.int32)
                }))
        infeed = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(sharded_inputs[0]))
        self.infeed_queue.append(infeed)

        def tpu_ordinal_fn(shard_index_in_host):
          return shard_index_in_host % self.hparams.num_shards_per_host

        return infeed.generate_enqueue_ops(
            sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
Example #6
    def test_simple(self):
        """Tests that arg checker works for functions with no varargs or defaults.
    """
        def func(x, y, z):
            return x + y + z

        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, None))
        self.assertEqual('exactly 3 arguments',
                         xla.check_function_argument_count(func, 2, None))
        queue = tpu_feed.InfeedQueue(2)
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 1, queue))
        self.assertEqual('exactly 3 arguments',
                         xla.check_function_argument_count(func, 2, queue))
Example #7
 def testConstructor(self):
     """Tests that the constructor can be called with different arguments."""
     i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
     self.assertEqual(i.number_of_tuple_elements, 2)
     self.assertEqual(i.tuple_types, None)
     self.assertEqual(i.tuple_shapes, None)
     self.assertEqual(i.number_of_shards, None)
     i = tpu_feed.InfeedQueue(
         tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
     self.assertEqual(i.number_of_tuple_elements, 3)
     self.assertEqual(i.tuple_types,
                      [dtypes.float32, dtypes.int32, dtypes.int32])
     self.assertEqual(i.tuple_shapes, None)
     self.assertEqual(i.number_of_shards, None)
     i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
     self.assertEqual(i.number_of_tuple_elements, 2)
     self.assertEqual(i.tuple_types, None)
     self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
     self.assertEqual(i.number_of_shards, None)
     i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
     self.assertEqual(i.number_of_tuple_elements, 3)
     self.assertEqual(i.tuple_types, None)
     self.assertEqual(i.tuple_shapes, None)
     self.assertEqual([p.shard_dimension for p in i.sharding_policies],
                      [1, 0, 7])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue()
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(number_of_tuple_elements=2,
                                  tuple_types=[dtypes.float32])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(number_of_tuple_elements=2,
                                  tuple_shapes=[[1]])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(number_of_tuple_elements=2,
                                  shard_dimensions=[1])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]],
                                  shard_dimensions=[1])
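Combining the constructor with the enqueue and dequeue calls used elsewhere in these examples, a minimal host-to-device round trip might look like the sketch below (a sketch only; production code also needs the device placement and TPU rewrite machinery shown in Examples #13 and #15):

# Two shards, one tuple element; types and shapes are inferred from the
# sharded inputs handed to generate_enqueue_ops (cf. Examples #8 and #9).
t = tf.constant(0.0, tf.float32, shape=[4, 4])
q = tpu_feed.InfeedQueue(number_of_tuple_elements=1)
enqueue_ops = q.generate_enqueue_ops([[t], [t]])  # one inner list per shard
# The matching dequeue runs on the device, inside the TPU computation.
with tf.device(tf.tpu.core(0)):
    dequeued_tensors = q.generate_dequeue_op()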
Example #8
    def _enqueue_laidout_tensors(self, all_laidout_tensors):
        """Generate enqueue ops to enqueue all_laidout_tensors."""
        def _tpu_ordinal_function_impl(pnum):
            return self._p_dev.ordered_ordinals[pnum]

        def _placement_function_impl(pnum):
            return self._p_dev.ordered_hosts[pnum]

        laidout_tensors0 = all_laidout_tensors[0]
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(laidout_tensors0),
            tuple_types=[x.dtype for x in laidout_tensors0],
            tuple_shapes=[x.shape for x in laidout_tensors0])
        enqueue_ops = infeed_queue.generate_enqueue_ops(
            all_laidout_tensors,
            tpu_ordinal_function=_tpu_ordinal_function_impl,
            placement_function=_placement_function_impl)

        return infeed_queue, enqueue_ops
Example #9
        def enqueue_ops_fn():
          """Enqueue ops function for one host."""
          per_host_sharded_inputs = []
          control_deps = []
          for _ in range(self.hparams.num_shards_per_host):
            with tf.control_dependencies(control_deps):
              features = iterator.get_next()
            self.eval_feature_structure["features"] = features
            flattened_inputs = data_nest.flatten(self.eval_feature_structure)
            control_deps.extend(flattened_inputs)
            per_host_sharded_inputs.append(flattened_inputs)

          infeed = tpu_feed.InfeedQueue(
              number_of_tuple_elements=len(per_host_sharded_inputs[0]))
          self.eval_infeed_queue.append(infeed)

          def tpu_ordinal_fn(shard_index_in_host):
            return shard_index_in_host % self.hparams.num_shards_per_host

          return infeed.generate_enqueue_ops(
              per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
Example #10
    def test_var_args(self):
        """Tests that arg checker works for a function with varargs."""
        def func(x, y, *z):
            return x + y + len(z)

        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, None))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 1, None))
        queue = tpu_feed.InfeedQueue(1)
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 1, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, queue))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 0, queue))
Example #11
    def _config_infeed(self,
                       num_partitions,
                       device_assignment,
                       batch_size,
                       key_size=2,
                       return_tgt_mask=False,
                       use_partitioned_infeed_queue=False):
        """Config the infeed ops and args."""
        zero_batch = get_zero_batch(batch_size=batch_size,
                                    max_len=self._prefix_max_len,
                                    key_size=key_size,
                                    return_tgt_mask=return_tgt_mask)

        host_device = device_assignment.host_device(replica=0, job=self._tpu)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        input_partition_dims = [[num_partitions] + [1] * (len(x.shape) - 1)
                                for x in zero_batch]

        if use_partitioned_infeed_queue:
            infeed = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                number_of_tuple_elements=len(zero_batch),
                host_id=host_id,
                input_partition_dims=input_partition_dims,
                device_assignment=device_assignment)
        else:
            infeed = tpu_feed.InfeedQueue(
                number_of_tuple_elements=len(zero_batch))

        self.infeed_args = []
        for x in zero_batch:
            p = tf.placeholder(tf.as_dtype(x.dtype), shape=x.shape)
            self.infeed_args += [p]
        if use_partitioned_infeed_queue:
            self.infeed_op = infeed.generate_enqueue_ops([self.infeed_args])
        else:
            self.infeed_op = infeed.split_inputs_and_generate_enqueue_ops(
                self.infeed_args, device_assignment=device_assignment)
        return infeed
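One plausible way to drive the enqueue ops built by _config_infeed is to zero-fill the placeholders and run the ops in a session; sess and np below are assumptions, not part of the original snippet.

# Hypothetical driver: feed every placeholder a zero array of matching
# shape and dtype, then run the enqueue op(s).
feed_dict = {p: np.zeros(p.shape.as_list(), dtype=p.dtype.as_numpy_dtype)
             for p in self.infeed_args}
sess.run(self.infeed_op, feed_dict=feed_dict)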
Example #12
                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(FLAGS.tpu_num_shards_per_host):
                        if "eval" in task:
                            with tf.control_dependencies(control_deps):
                                features = iterator.get_next()
                            feature_structure["features"] = features
                        else:
                            with tf.control_dependencies(control_deps):
                                features, labels = iterator.get_next()
                            feature_structure["features"] = features
                            feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=low_level_utils.tpu_ordinal_fn)
Example #13
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #14
    def CreateTpuEnqueueOps(self):
        """Create the host-side enqueue ops.

    This should be called in an outer non-TPU context.
    """
        assert not self._tpu_queues, (
            'CreateTpuEnqueueOps should only be called '
            'once.')
        self._tpu_queues = []
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTpuEnqueueOps num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        shards = (cluster.total_worker_devices //
                  num_infeed_hosts) // cluster.num_devices_per_split
        tf.logging.info('shards {}'.format(shards))

        input_ops_list = []
        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')

        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                self._batch = self.GetPreprocessedInputBatch()
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                for k, x in self._batch.FlattenItems():
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                shapes = self._batch.Transform(lambda x: x.shape).Flatten()
                dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

                tf.logging.info('host_device: %s infeed shapes: %r',
                                host_device, shapes)
                tf.logging.info('host_device: %s infeed dtypes: %r',
                                host_device, dtypes)

                if p.use_partitioned_infeed_queue:
                    device_assignment = py_utils.GetTpuDeviceAssignment()

                    host_device = device_assignment.host_device(
                        replica=0, job=tf.flags.FLAGS.tf_master)
                    host_id = int(
                        host_device.split('/task:')[1].split('/device:')[0])
                    tf.logging.info('host_id: {} host_device: {}'.format(
                        host_id, host_device))
                    q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                        number_of_tuple_elements=len(dtypes),
                        device_assignment=device_assignment,
                        host_id=host_id,
                        input_partition_dims=[[p.num_partitions] + [1] *
                                              (len(s) - 1) for s in shapes],
                        tuple_types=dtypes,
                        tuple_shapes=shapes)
                else:
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                self._tpu_queues.append(q)

                if p.use_partitioned_infeed_queue:
                    input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
                elif p.use_per_host_infeed:
                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    def TPUOrdinalFunction(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(
                                replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=TPUOrdinalFunction)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        device_assignment=py_utils.GetTpuDeviceAssignment())
                input_ops_list += input_ops

        tf.logging.info('input_ops_list %s', input_ops_list)
        grouped_infeed_op = tf.group(*input_ops_list)
        self._tpu_infeed_op = []
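        # As in Example #13, the grouped op is duplicated so that multiple
        # trainer threads can drive the infeed in parallel.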
        for _ in range(p.tpu_infeed_parallelism):
            self._tpu_infeed_op.append(grouped_infeed_op)
Example #15
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #16
                def enqueue_ops_fn(idx):
                    """Generate the infeed enqueue ops graph."""

                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(FLAGS.replicas_per_host):
                        with tf.control_dependencies(control_deps):
                            self.feature_structure[
                                is_training] = iterator.get_next()
                        self.maybe_capture_embedding_inputs(
                            self.feature_structure[is_training], is_training)
                        flattened_inputs = tf.nest.flatten(
                            self.feature_structure[is_training])
                        control_deps.extend(flattened_inputs)
                        if input_partition_dims:
                            padded_inputs = []
                            for inp in flattened_inputs:
                                if inp.shape.ndims < len(input_partition_dims):
                                    padded_inputs.append(inp)
                                    continue
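                                # Pad each partitioned dimension up to a multiple of its
                                # partition count so the tensor splits evenly across cores.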
                                paddings = []
                                for i, j in enumerate(input_partition_dims):
                                    r = inp.shape.as_list()[i] % j
                                    if r > 0:
                                        paddings.append([0, j - r])
                                    else:
                                        paddings.append([0, 0])
                                for i in range(inp.shape.ndims -
                                               len(input_partition_dims)):
                                    paddings.append([0, 0])
                                padded_inputs.append(tf.pad(inp, paddings))
                            per_host_sharded_inputs.append(padded_inputs)
                        else:
                            per_host_sharded_inputs.append(flattened_inputs)

                    if input_partition_dims:
                        flattened_input_dims = []
                        for i in per_host_sharded_inputs[0]:
                            if i.shape.ndims == len(input_partition_dims):
                                flattened_input_dims.append(
                                    input_partition_dims)
                            elif i.shape.ndims > len(input_partition_dims):
                                flattened_input_dims.append(
                                    input_partition_dims + [1] *
                                    (i.shape.ndims - len(input_partition_dims))
                                )
                            else:
                                flattened_input_dims.append([1] *
                                                            i.shape.ndims)
                        # pylint: disable=protected-access
                        self.infeed_op[
                            is_training] = tpu_feed._PartitionedInfeedQueue(
                                number_of_tuple_elements=len(
                                    per_host_sharded_inputs[0]),
                                host_id=host_id,
                                input_partition_dims=flattened_input_dims,
                                device_assignment=self.device_assignment)
                        with tf.control_dependencies(
                                self.infeed_op[is_training].
                                generate_enqueue_ops(per_host_sharded_inputs)):
                            return idx + 1
                    else:
                        self.infeed_op[is_training] = tpu_feed.InfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]))
                        per_host_enqueue_ops = (
                            self.infeed_op[is_training].generate_enqueue_ops(
                                per_host_sharded_inputs,
                                tpu_ordinal_function=_tpu_ordinal_fn))

                    self.maybe_add_embedding_enqueue_ops_int(
                        is_training, per_host_enqueue_ops)
                    with tf.control_dependencies(per_host_enqueue_ops):
                        return idx + 1