Example #1
    def initialize(self, input_fn, params):
        """Initialize all the things required for evaluation."""
        tf.logging.info("EvalLowLevelRunner: initialize method")

        self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

        def get_enqueue_ops_fn():
            """Generate the enqueue ops graph function."""

            dataset = input_fn(params)
            with tf.device(utils.device_for_host(self._get_host(0))):
                iterator = dataset.make_initializable_iterator()
                self.dataset_initializer.append(iterator.initializer)

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
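                    # Chain each shard's flattened tensors into control_deps so
                    # successive iterator.get_next() calls are issued in order.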
                    for _ in range(FLAGS.num_shards_per_host):
                        with tf.control_dependencies(control_deps):
                            features = iterator.get_next()
                        self.feature_structure["features"] = features
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=utils.tpu_ordinal_fn)

                return enqueue_ops_fn

        with self.graph.as_default():
            self.enqueue_ops.append(
                utils.wrap_computation_in_while_loop(
                    get_enqueue_ops_fn(),
                    n=self.eval_steps,
                    host_name=self._get_host(0)))

        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        isolate_session_state=True,
                                        operation_timeout_in_ms=600 * 60 *
                                        1000)  # 10 hours
        cluster_spec = self.tpu_cluster_resolver.cluster_spec()
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
        self.sess = tf.Session(self.tpu_cluster_resolver.get_master(),
                               graph=self.graph,
                               config=session_config)
        if FLAGS.mode != "eval_once":
            self.sess.run(self.tpu_init)
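
All of the examples on this page rely on utils.wrap_computation_in_while_loop, which is not shown here. A minimal sketch of what such a helper usually looks like, assuming it simply repeats the ops produced by op_fn for n iterations on the given host via tf.while_loop (the exact signature and the device_for_host placement are assumptions, not taken from the original source):

    def wrap_computation_in_while_loop(op_fn, n, host_name=None):
        """Sketch: run the ops produced by `op_fn` for `n` iterations on the host."""

        def computation(i):
            # Recreate the enqueue ops each iteration and make the loop counter
            # depend on them so they actually execute.
            with tf.control_dependencies(op_fn()):
                return i + 1

        with tf.device(device_for_host(host_name)):
            return tf.while_loop(
                lambda i: tf.less(i, n), computation, [tf.constant(0)],
                parallel_iterations=1)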
Example #2
    def build_enqueue_ops(self, input_fn, params, host_id):
        """Build enqueue ops."""
        tf.logging.info("TrainLowLevelRunner: build_enqueue_ops")

        def get_enqueue_ops_fn(host_id):
            """Generate the enqueue ops graph function."""

            params["dataset_num_shards"] = self.num_hosts
            params["dataset_index"] = host_id
            dataset = input_fn(params)
            with tf.device(utils.device_for_host(self._get_host(host_id))):
                iterator = dataset.make_initializable_iterator()
                self.dataset_initializer.append(iterator.initializer)

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(FLAGS.num_shards_per_host):
                        with tf.control_dependencies(control_deps):
                            features, labels = iterator.get_next()
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=utils.tpu_ordinal_fn)

                return enqueue_ops_fn

        with self.input_graph.as_default():
            self.enqueue_ops.append(
                utils.wrap_computation_in_while_loop(
                    get_enqueue_ops_fn(host_id),
                    n=self.iterations,
                    host_name=self._get_host(host_id)))
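
The dataset_num_shards and dataset_index values written into params are consumed by the input_fn, which is not shown on this page. A typical per-host input_fn shards the source files so that each host reads a disjoint slice, roughly as sketched below (the file pattern key, batch size key, and parse_example_fn are hypothetical names, not from the original source):

    def input_fn(params):
        """Sketch: per-host sharded input pipeline for the low-level runner."""
        dataset = tf.data.Dataset.list_files(params["file_pattern"], shuffle=False)
        # Each host reads a disjoint subset of the input files.
        dataset = dataset.shard(params["dataset_num_shards"],
                                params["dataset_index"])
        dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=8)
        dataset = dataset.map(parse_example_fn)  # yields (features, labels)
        dataset = dataset.batch(params["batch_size"], drop_remainder=True)
        return dataset.prefetch(buffer_size=64)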
Example #3
    def build_enqueue_ops(self, input_fn, params, host_id):
        """Build enqueue ops."""
        tf.logging.info("TrainLowLevelRunner: build_enqueue_ops")

        def get_enqueue_ops_fn(host_id):
            """Generate the enqueue ops graph function."""

            params["dataset_num_shards"] = self.num_hosts
            params["dataset_index"] = host_id
            with tf.device(utils.device_for_host(self._get_host(host_id))):
                dataset = input_fn(params)
                iterator = dataset.make_initializable_iterator()
                self.dataset_initializer.append(iterator.initializer)

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(control_deps):
                            features, labels = iterator.get_next()
                        if self.use_spatial_partition:
                            num_elements = []
                            for i, d in enumerate(ssd_constants.FEATURE_SIZES):
                                num_elements.append(
                                    d * d * ssd_constants.NUM_DEFAULTS[i])
                            gt_boxes = tf.split(labels[ssd_constants.BOXES],
                                                num_elements, 1)
                            gt_classes = tf.split(
                                labels[ssd_constants.CLASSES], num_elements, 1)

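                            # Reshape each slice to [batch, defaults, H, W, ...]
                            # and transpose so the spatial dims (H, W) precede
                            # the anchor dim; the labels can then be spatially
                            # partitioned along the same dims as the features.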
                            def transpose_gt_box(gt_box, i):
                                return tf.transpose(
                                    tf.reshape(gt_box, [
                                        -1, ssd_constants.NUM_DEFAULTS[i],
                                        ssd_constants.FEATURE_SIZES[i],
                                        ssd_constants.FEATURE_SIZES[i], 4
                                    ]), [0, 2, 3, 1, 4])

                            def transpose_gt_class(gt_class, i):
                                return tf.transpose(
                                    tf.reshape(gt_class, [
                                        -1, ssd_constants.NUM_DEFAULTS[i],
                                        ssd_constants.FEATURE_SIZES[i],
                                        ssd_constants.FEATURE_SIZES[i]
                                    ]), [0, 2, 3, 1])

                            labels[ssd_constants.BOXES] = {
                                i: transpose_gt_box(gt_boxes[i], i)
                                for i in range(len(ssd_constants.NUM_DEFAULTS))
                            }
                            labels[ssd_constants.CLASSES] = {
                                i: transpose_gt_class(gt_classes[i], i)
                                for i in range(len(ssd_constants.NUM_DEFAULTS))
                            }
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

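                    # Pick partition dims for each flattened tensor: spatially
                    # partition tensors shaped like the image features, pad any
                    # extra trailing dims with 1, and fully replicate (all 1s)
                    # tensors with fewer dims than the partition spec.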
                    if self.use_spatial_partition:
                        flattened_input_dims = []
                        for i in per_host_sharded_inputs[0]:
                            if i.shape.ndims >= len(
                                    self.input_partition_dims[0]):
                                if i.shape.as_list() == self.feature_structure[
                                        "features"].shape.as_list():
                                    flattened_input_dims.append(
                                        self.input_partition_dims[0])
                                else:
                                    flattened_input_dims.append(
                                        FLAGS.input_partition_dims + [1] *
                                        (i.shape.ndims -
                                         len(self.input_partition_dims[0])))
                            else:
                                flattened_input_dims.append([1] *
                                                            i.shape.ndims)
                        # pylint: disable=protected-access
                        infeed = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]),
                            host_id=host_id,
                            input_partition_dims=flattened_input_dims,
                            device_assignment=self.device_assignment)
                        self.infeed_queue.append(infeed)
                        return infeed.generate_enqueue_ops(
                            per_host_sharded_inputs)

                    infeed = tpu.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=utils.tpu_ordinal_fn)

                return enqueue_ops_fn

        with self.input_graph.as_default():
            self.enqueue_ops.append(
                utils.wrap_computation_in_while_loop(
                    get_enqueue_ops_fn(host_id),
                    n=self.iterations,
                    host_name=self._get_host(host_id)))
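
The num_elements split sizes in Example #3 come from ssd_constants. Assuming the standard SSD300 configuration used by the MLPerf reference (the FEATURE_SIZES and NUM_DEFAULTS values below are assumptions; check ssd_constants in the actual repo), the per-feature-map slice sizes work out as:

    feature_sizes = (38, 19, 10, 5, 3, 1)   # assumed ssd_constants.FEATURE_SIZES
    num_defaults = (4, 6, 6, 6, 4, 4)       # assumed ssd_constants.NUM_DEFAULTS
    num_elements = [d * d * k for d, k in zip(feature_sizes, num_defaults)]
    # -> [5776, 2166, 600, 150, 36, 4]; these sum to 8732 anchors, which is the
    #    dimension of labels[ssd_constants.BOXES] that tf.split slices up.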
    def multiple_steps_fn():
        """Function for multiple TPU steps in a host training loop."""
        return utils.wrap_computation_in_while_loop(
            single_step_fn, n=iterations_per_loop, parallel_iterations=1)
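
Here parallel_iterations=1 is presumably forwarded to the underlying tf.while_loop, forcing the wrapped single_step_fn computations to run strictly one after another rather than being overlapped across loop iterations.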