예제 #1
0
                def eval_enqueue_ops_fn():
                    """Create the eval infeed enqueue ops for one host.

                    Calls `iterator.get_next()` once per replica, flattens each
                    result, and hands the per-replica lists to a newly created
                    infeed queue, which is also recorded in
                    `self.eval_infeed_queue`.

                    Returns:
                        Whatever the infeed queue's `generate_enqueue_ops`
                        returns for the collected inputs.
                    """
                    num_replicas = self.train_params["replicas_per_worker"]
                    sharded_inputs = []
                    # Every tensor flattened so far; each later get_next() is
                    # created under control dependencies on all of them.
                    prior_tensors = []
                    for _ in range(num_replicas):
                        with tf.control_dependencies(prior_tensors):
                            batch = iterator.get_next()
                        if self.use_spatial_partition:
                            dims_flattener = self.eval_input_dims_flattener
                            dims_flattener.validate_and_flatten_input_dims(
                                batch, None)
                        input_flattener = self.eval_input_flattener
                        flat_batch = input_flattener.flatten_features_and_labels(
                            batch, None)
                        prior_tensors.extend(flat_batch)
                        sharded_inputs.append(flat_batch)

                    tuple_size = len(sharded_inputs[0])
                    if not self.use_spatial_partition:
                        # Plain infeed: one queue, replicas mapped to TPU
                        # ordinals by runner_utils.tpu_ordinal_fn.
                        infeed = tf.contrib.tpu.InfeedQueue(
                            number_of_tuple_elements=tuple_size)
                        self.eval_infeed_queue.append(infeed)
                        ordinal_fn = functools.partial(
                            runner_utils.tpu_ordinal_fn,
                            replicas_per_worker=num_replicas)
                        return infeed.generate_enqueue_ops(
                            sharded_inputs,
                            tpu_ordinal_function=ordinal_fn)

                    # Spatial partitioning: the partition dims come from the
                    # flattener that validated the inputs in the loop above.
                    # pylint: disable=protected-access
                    infeed = tpu_feed._PartitionedInfeedQueue(
                        number_of_tuple_elements=tuple_size,
                        host_id=host_id,
                        input_partition_dims=(
                            self.eval_input_dims_flattener.flattened_input_dims),
                        device_assignment=self.device_assignment)
                    self.eval_infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(sharded_inputs)
예제 #2
0
                def enqueue_ops_fn():
                    """Create the training infeed enqueue ops for one host.

                    Calls `iterator.get_next()` once per replica to obtain a
                    `(features, labels)` pair. With spatial partitioning on,
                    the box and class labels are split per feature level and
                    rearranged before flattening. The flattened per-replica
                    lists are handed to a newly created infeed queue, which is
                    also recorded in `self.infeed_queue`.

                    Side effects:
                        Overwrites `self.feature_structure["features"]` and
                        `self.feature_structure["labels"]` on every replica
                        iteration.

                    Returns:
                        Whatever the infeed queue's `generate_enqueue_ops`
                        returns for the collected inputs.
                    """

                    def split_and_rearrange(chunk, level, trailing_dims):
                        """Reshape one level's chunk and move axis 1 to axis 3.

                        The chunk is reshaped to
                        [-1, NUM_DEFAULTS[level], size, size] + trailing_dims
                        and then transposed so the NUM_DEFAULTS axis follows
                        the two `size` axes; trailing axes keep their order.
                        """
                        size = ssd_constants.FEATURE_SIZES[level]
                        target_shape = [
                            -1, ssd_constants.NUM_DEFAULTS[level], size, size
                        ] + trailing_dims
                        trailing_axes = list(range(4, 4 + len(trailing_dims)))
                        return tf.transpose(
                            tf.reshape(chunk, target_shape),
                            [0, 2, 3, 1] + trailing_axes)

                    sharded_inputs = []
                    # Every tensor flattened so far; each later get_next() is
                    # created under control dependencies on all of them.
                    prior_tensors = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(prior_tensors):
                            features, labels = iterator.get_next()
                        if self.use_spatial_partition:
                            # Split sizes along axis 1, one per feature level:
                            # FEATURE_SIZES[i] ** 2 * NUM_DEFAULTS[i].
                            level_sizes = [
                                size * size * ssd_constants.NUM_DEFAULTS[level]
                                for level, size in enumerate(
                                    ssd_constants.FEATURE_SIZES)
                            ]
                            box_chunks = tf.split(
                                labels[ssd_constants.BOXES], level_sizes, 1)
                            class_chunks = tf.split(
                                labels[ssd_constants.CLASSES], level_sizes, 1)

                            num_levels = len(ssd_constants.NUM_DEFAULTS)
                            # Replace each label tensor with a dict keyed by
                            # level index; boxes keep a trailing axis of 4.
                            labels[ssd_constants.BOXES] = {
                                level: split_and_rearrange(
                                    box_chunks[level], level, [4])
                                for level in range(num_levels)
                            }
                            labels[ssd_constants.CLASSES] = {
                                level: split_and_rearrange(
                                    class_chunks[level], level, [])
                                for level in range(num_levels)
                            }
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flat_batch = data_nest.flatten(self.feature_structure)
                        prior_tensors.extend(flat_batch)
                        sharded_inputs.append(flat_batch)

                    tuple_size = len(sharded_inputs[0])
                    if not self.use_spatial_partition:
                        infeed = tpu.InfeedQueue(
                            number_of_tuple_elements=tuple_size)
                        self.infeed_queue.append(infeed)
                        return infeed.generate_enqueue_ops(
                            sharded_inputs,
                            tpu_ordinal_function=utils.tpu_ordinal_fn)

                    def partition_dims_for(tensor):
                        """Pick the partition dims for one flattened tensor."""
                        rank = tensor.shape.ndims
                        feature_dims = self.input_partition_dims[0]
                        if rank < len(feature_dims):
                            # Too few axes to partition: leave it whole.
                            return [1] * rank
                        features_shape = self.feature_structure[
                            "features"].shape.as_list()
                        if tensor.shape.as_list() == features_shape:
                            return feature_dims
                        # NOTE(review): this branch reads
                        # FLAGS.input_partition_dims while the others read
                        # self.input_partition_dims[0]; presumably they hold
                        # the same value -- confirm against the constructor.
                        return FLAGS.input_partition_dims + [1] * (
                            rank - len(feature_dims))

                    flattened_input_dims = [
                        partition_dims_for(tensor)
                        for tensor in sharded_inputs[0]
                    ]
                    # pylint: disable=protected-access
                    infeed = tpu_feed._PartitionedInfeedQueue(
                        number_of_tuple_elements=tuple_size,
                        host_id=host_id,
                        input_partition_dims=flattened_input_dims,
                        device_assignment=self.device_assignment)
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(sharded_inputs)