                def enqueue_ops_fn():
                    """Generate the infeed enqueue ops graph."""

                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(FLAGS.tpu_cores_per_host):
                        # Chain each get_next() on the tensors already
                        # produced so per-shard batches are drawn in order.
                        with tf.control_dependencies(control_deps):
                            features, labels = iterator.get_next()
                        if is_training:
                            self.feature_structure["features"] = features
                            self.feature_structure["labels"] = labels
                            flattened_inputs = data_nest.flatten(
                                self.feature_structure)
                        else:
                            self.eval_feature_structure["features"] = features
                            self.eval_feature_structure["labels"] = labels
                            flattened_inputs = data_nest.flatten(
                                self.eval_feature_structure)

                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    if is_training:
                        self.infeed_queue.append(infeed)
                    else:
                        self.eval_infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=tpu_ordinal_fn)
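
The tpu_ordinal_fn passed to generate_enqueue_ops above comes from the enclosing scope and is not shown in this snippet; later examples in this collection define it as a simple modulo mapping from a shard's index on its host to a TPU core ordinal. A minimal sketch along those lines, assuming FLAGS.tpu_cores_per_host holds the per-host core count:

def tpu_ordinal_fn(shard_index_in_host):
    # Send the i-th shard built on this host to the i-th core of the host.
    return shard_index_in_host % FLAGS.tpu_cores_per_host
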
Example #2
            def enqueue_ops_fn():
                """Generate the infeed enqueue ops graph."""

                per_host_sharded_inputs = []
                control_deps = []
                with tf.device(self.device_for_host(task=host_id)):
                    for _ in range(FLAGS.tpu_cores_per_host):
                        with tf.control_dependencies(control_deps):
                            self.feature_structure = iterator.get_next()
                            if not isinstance(self.feature_structure, dict):
                                features, labels = self.feature_structure
                                self.feature_structure = {}
                                self.feature_structure["features"] = features
                                self.feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=tpu_ordinal_fn)
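
Names such as iterator, self.feature_structure, self.infeed_queue, data_nest, and tpu in these excerpts come from the enclosing runner class and its module imports. A minimal sketch of that surrounding setup; the import paths, class name, and method names here are assumptions, not the original source:

import tensorflow as tf  # TF 1.x
from tensorflow.contrib import tpu                          # tpu.InfeedQueue
from tensorflow.python.data.util import nest as data_nest   # data_nest.flatten

class InfeedRunner(object):
    """Holds the per-host infeed state used by enqueue_ops_fn (sketch)."""

    def __init__(self):
        self.feature_structure = {}   # repacked on the TPU side after dequeue
        self.infeed_queue = []        # collects one InfeedQueue per host

    def build_enqueue_ops(self, input_fn, params):
        dataset = input_fn(params)
        iterator = dataset.make_initializable_iterator()  # closed over below
        # ... define enqueue_ops_fn() as in the snippets above ...
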
Example #3

                def enqueue_ops_fn_v1():
                    """Enqueue ops function for one host."""
                    features = output
                    self.feature_structure["features"] = features
                    self.feature_structure["labels"] = {}
                    flattened_inputs = data_nest.flatten(
                        self.feature_structure)
                    infeed = tpu.InfeedQueue(
                        tuple_types=[t.dtype for t in flattened_inputs],
                        tuple_shapes=[t.shape for t in flattened_inputs],
                        shard_dimensions=None)

                    infeed.set_number_of_shards(
                        self.hparams.num_shards_per_host)
                    self.infeed_queue.append(infeed)

                    def tpu_ordinal_fn(shard_index_in_host):
                        return shard_index_in_host % self.hparams.num_shards_per_host

                    per_host_enqueue_ops = (
                        infeed.split_inputs_and_generate_enqueue_ops(
                            flattened_inputs,
                            placement_function=lambda x: device,
                            tpu_ordinal_function=tpu_ordinal_fn))
                    return per_host_enqueue_ops
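
Unlike the other snippets, enqueue_ops_fn_v1 builds one host-level batch and lets split_inputs_and_generate_enqueue_ops shard it across cores. Conceptually this is close to splitting each tensor along the batch dimension yourself and handing the per-shard lists to generate_enqueue_ops; a rough sketch under that assumption (batch axis 0, evenly divisible by the shard count), reusing names from the snippet above:

num_shards = self.hparams.num_shards_per_host
split_per_tensor = [tf.split(t, num_shards, axis=0) for t in flattened_inputs]
per_host_sharded_inputs = [
    [shards[i] for shards in split_per_tensor] for i in range(num_shards)]
manual_infeed = tpu.InfeedQueue(
    number_of_tuple_elements=len(per_host_sharded_inputs[0]))
manual_enqueue_ops = manual_infeed.generate_enqueue_ops(
    per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
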
Example #4

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []

                    if FLAGS.broadcast_input_all_replicas:
                        # Read a single batch and enqueue the same tensors for
                        # every shard on this host.
                        features, labels = iterator.get_next()
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        for _ in range(FLAGS.tpu_num_shards_per_host):
                            per_host_sharded_inputs.append(flattened_inputs)
                    else:
                        for _ in range(FLAGS.tpu_num_shards_per_host):
                            with tf.control_dependencies(control_deps):
                                features, labels = iterator.get_next()
                            self.feature_structure["features"] = features
                            self.feature_structure["labels"] = labels
                            flattened_inputs = data_nest.flatten(
                                self.feature_structure)
                            control_deps.extend(flattened_inputs)
                            per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=low_level_utils.tpu_ordinal_fn)
Example #5

        def enqueue_ops_fn():
          """Enqueue ops function for one host."""
          per_host_sharded_inputs = []
          control_deps = []
          # Currently working only on a donut; change this later to support
          # distributed eval.
          for _ in range(FLAGS.tpu_num_shards_per_host):
            with tf.control_dependencies(control_deps):
              features = iterator.get_next()
            self.feature_structure["features"] = features
            flattened_inputs = data_nest.flatten(self.feature_structure)
            control_deps.extend(flattened_inputs)
            per_host_sharded_inputs.append(flattened_inputs)

          infeed = tpu.InfeedQueue(
              number_of_tuple_elements=len(per_host_sharded_inputs[0]))
          self.infeed_queue.append(infeed)
          return infeed.generate_enqueue_ops(
              per_host_sharded_inputs,
              tpu_ordinal_function=low_level_utils.tpu_ordinal_fn)
Example #6

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.hparams.num_shards_per_host):
                        with tf.control_dependencies(control_deps):
                            features = iterator.get_next()
                        self.feature_structure["features"] = features
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)

                    def tpu_ordinal_fn(shard_index_in_host):
                        return shard_index_in_host % self.hparams.num_shards_per_host

                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=tpu_ordinal_fn)
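
None of these snippets show the consumer side. In runners of this style, the TPU computation typically dequeues the flattened tuple and repacks it with the same structure dict used at enqueue time; a rough sketch of that counterpart, assuming the names used above:

def dequeue_fn():
    """Dequeue and repack one per-core batch on the TPU (sketch)."""
    values = self.infeed_queue[0].generate_dequeue_op()
    unflattened = data_nest.pack_sequence_as(self.feature_structure, values)
    return unflattened["features"], unflattened.get("labels")
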
Example #7
def build_infeed(input_fn, params, batch_count, embedding, mode):
    """Build infeed."""
    if mode == tpu_embedding.TRAINING:
        infeed_queue = tpu.InfeedQueue(tuple_types=[tf.int32],
                                       tuple_shapes=[[params["batch_size"],
                                                      1]])
        infeed_queue.set_number_of_shards(embedding.num_cores)
    else:
        infeed_queue = tpu.InfeedQueue(tuple_types=[tf.float32],
                                       tuple_shapes=[[params["batch_size"],
                                                      1]])
        infeed_queue.set_number_of_shards(embedding.num_cores)

    def enqueue_ops_fn():
        """Create enqueue ops."""
        ds = input_fn(params)
        iterator = ds.make_one_shot_iterator()
        if mode == tpu_embedding.TRAINING:
            features, labels = iterator.get_next()
        else:
            features = iterator.get_next()

        # TODO(shizhiw): speed up input pipeline by avoiding splitting and
        # sparse tensor.
        # TPU embedding enqueue.
        users = features[movielens.USER_COLUMN]
        items = features[movielens.ITEM_COLUMN]

        sparse_features_list = []
        users_per_core_list = tf.split(users, embedding.num_cores_per_host)
        items_per_core_list = tf.split(items, embedding.num_cores_per_host)
        for j in range(embedding.num_cores_per_host):
            users_sparse = tf.SparseTensor(
                indices=[[i, 0] for i in range(embedding.batch_size_per_core)],
                values=users_per_core_list[j],
                dense_shape=[embedding.batch_size_per_core, 1])
            items_sparse = tf.SparseTensor(
                indices=[[i, 0] for i in range(embedding.batch_size_per_core)],
                values=items_per_core_list[j],
                dense_shape=[embedding.batch_size_per_core, 1])
            sparse_features = {
                "mf_user": users_sparse,
                "mlp_user": users_sparse,
                "mf_item": items_sparse,
                "mlp_item": items_sparse,
            }
            sparse_features_list.append(sparse_features)
        enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)

        # TPU dense enqueue.
        if mode == tpu_embedding.TRAINING:
            # Infeed does not support bool.
            labels = tf.cast(labels, tf.int32)
            enqueue_ops.extend(
                infeed_queue.split_inputs_and_generate_enqueue_ops([labels]))
        else:
            duplicate_mask = tf.cast(features[rconst.DUPLICATE_MASK],
                                     tf.float32)
            enqueue_ops.extend(
                infeed_queue.split_inputs_and_generate_enqueue_ops(
                    [duplicate_mask]))

        return enqueue_ops

    if len(embedding.hosts) != 1:
        raise ValueError(
            "len(embedding.hosts) should be 1, but got {}.".format(
                embedding.hosts))
    # TODO(shizhiw): check enqueue op location in tpu_embedding.py as user
    # might fail to specify device for enqueue ops.
    with tf.device(embedding.hosts[0]):
        wrapped_enqueue_ops = wrap_computation_in_while_loop(
            enqueue_ops_fn, n=batch_count, parallel_iterations=1)

    def get_infeed_thread_fn(sess):
        def infeed_thread_fn():
            tf.logging.info("Enqueueing...")
            sess.run(wrapped_enqueue_ops)

        return infeed_thread_fn

    return get_infeed_thread_fn, infeed_queue
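
A rough sketch of how the pair returned by build_infeed might be driven, based on the thread function it constructs; input_fn, params, batch_count, embedding, and tpu_master are placeholders here:

import threading

get_infeed_thread_fn, infeed_queue = build_infeed(
    input_fn, params, batch_count, embedding, tpu_embedding.TRAINING)

with tf.Session(tpu_master) as sess:   # tpu_master: TPU worker target
    infeed_thread = threading.Thread(target=get_infeed_thread_fn(sess))
    infeed_thread.start()              # enqueues batch_count batches
    # ... run the TPU computation that dequeues from infeed_queue ...
    infeed_thread.join()
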
Example #8
                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(control_deps):
                            features, labels = iterator.get_next()
                        if self.use_spatial_partition:
                            # Regroup the ground-truth boxes and classes by
                            # feature-map level and move the spatial dims next
                            # to the batch dim so they can be partitioned
                            # across cores.
                            num_elements = []
                            for i, d in enumerate(ssd_constants.FEATURE_SIZES):
                                num_elements.append(
                                    d * d * ssd_constants.NUM_DEFAULTS[i])
                            gt_boxes = tf.split(labels[ssd_constants.BOXES],
                                                num_elements, 1)
                            gt_classes = tf.split(
                                labels[ssd_constants.CLASSES], num_elements, 1)

                            def transpose_gt_box(gt_box, i):
                                return tf.transpose(
                                    tf.reshape(gt_box, [
                                        -1, ssd_constants.NUM_DEFAULTS[i],
                                        ssd_constants.FEATURE_SIZES[i],
                                        ssd_constants.FEATURE_SIZES[i], 4
                                    ]), [0, 2, 3, 1, 4])

                            def transpose_gt_class(gt_class, i):
                                return tf.transpose(
                                    tf.reshape(gt_class, [
                                        -1, ssd_constants.NUM_DEFAULTS[i],
                                        ssd_constants.FEATURE_SIZES[i],
                                        ssd_constants.FEATURE_SIZES[i]
                                    ]), [0, 2, 3, 1])

                            labels[ssd_constants.BOXES] = {
                                i: transpose_gt_box(gt_boxes[i], i)
                                for i in range(len(ssd_constants.NUM_DEFAULTS))
                            }
                            labels[ssd_constants.CLASSES] = {
                                i: transpose_gt_class(gt_classes[i], i)
                                for i in range(len(ssd_constants.NUM_DEFAULTS))
                            }
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    if self.use_spatial_partition:
                        flattened_input_dims = []
                        for i in per_host_sharded_inputs[0]:
                            if i.shape.ndims >= len(
                                    self.input_partition_dims[0]):
                                if i.shape.as_list() == self.feature_structure[
                                        "features"].shape.as_list():
                                    flattened_input_dims.append(
                                        self.input_partition_dims[0])
                                else:
                                    flattened_input_dims.append(
                                        FLAGS.input_partition_dims + [1] *
                                        (i.shape.ndims -
                                         len(self.input_partition_dims[0])))
                            else:
                                flattened_input_dims.append([1] *
                                                            i.shape.ndims)
                        # pylint: disable=protected-access
                        infeed = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]),
                            host_id=host_id,
                            input_partition_dims=flattened_input_dims,
                            device_assignment=self.device_assignment)
                        self.infeed_queue.append(infeed)
                        return infeed.generate_enqueue_ops(
                            per_host_sharded_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=utils.tpu_ordinal_fn)
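
To make the flattened_input_dims bookkeeping above concrete, a small worked example with assumed shapes, and assuming self.input_partition_dims[0] equals FLAGS.input_partition_dims (illustrative values only, not from the original source):

# Suppose FLAGS.input_partition_dims == [1, 2, 2, 1], i.e. split height and
# width in half across the four cores of one logical replica. Per tensor:
#   features [8, 300, 300, 3]      rank 4, same shape as "features"
#                                  -> [1, 2, 2, 1]
#   a transposed box tensor        rank 5, shape differs from "features"
#     [8, 38, 38, 4, 4]            -> [1, 2, 2, 1] + [1] = [1, 2, 2, 1, 1]
#   a per-example scalar [8]       rank 1 < 4 -> [1]
# _PartitionedInfeedQueue then slices each partitioned tensor into 2 x 2
# pieces and places one piece on each core of the replica.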