def initialize(self, input_fn, params):
  """Initialize all the things required for evaluation."""
  tf.logging.info("EvalLowLevelRunner: initialize method")

  self.tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  def get_enqueue_ops_fn():
    """Generate the enqueue ops graph function."""
    dataset = input_fn(params)
    with tf.device(utils.device_for_host(self._get_host(0))):
      iterator = dataset.make_initializable_iterator()
      self.dataset_initializer.append(iterator.initializer)

      def enqueue_ops_fn():
        """Enqueue ops function for one host."""
        per_host_sharded_inputs = []
        control_deps = []
        for _ in range(FLAGS.num_shards_per_host):
          with tf.control_dependencies(control_deps):
            features = iterator.get_next()
          self.feature_structure["features"] = features
          flattened_inputs = data_nest.flatten(self.feature_structure)
          control_deps.extend(flattened_inputs)
          per_host_sharded_inputs.append(flattened_inputs)

        infeed = tpu.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        self.infeed_queue.append(infeed)
        return infeed.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=utils.tpu_ordinal_fn)

      return enqueue_ops_fn

  with self.graph.as_default():
    self.enqueue_ops.append(
        utils.wrap_computation_in_while_loop(
            get_enqueue_ops_fn(),
            n=self.eval_steps,
            host_name=self._get_host(0)))

  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      isolate_session_state=True,
      operation_timeout_in_ms=600 * 60 * 1000)  # 10 hours
  cluster_spec = self.tpu_cluster_resolver.cluster_spec()
  if cluster_spec:
    session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
  self.sess = tf.Session(
      self.tpu_cluster_resolver.get_master(),
      graph=self.graph,
      config=session_config)
  if FLAGS.mode != "eval_once":
    self.sess.run(self.tpu_init)
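# The `utils` helpers referenced above are project-local and not shown here.
# The following is a minimal sketch of what they might look like; the names,
# signatures, and exact behavior are assumptions, not the actual
# implementation in this repository.
def device_for_host(host_name):
  """Places input-pipeline ops on the CPU device of the given worker."""
  return host_name + "/device:CPU:0"


def tpu_ordinal_fn(shard_index_in_host):
  """Maps a shard index within one host to a TPU core ordinal."""
  return shard_index_in_host % FLAGS.num_shards_per_host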
def build_enqueue_ops(self, input_fn, params, host_id):
  """Build enqueue ops."""
  tf.logging.info("TrainLowLevelRunner: build_enqueue_ops")

  def get_enqueue_ops_fn(host_id):
    """Generate the enqueue ops graph function."""
    params["dataset_num_shards"] = self.num_hosts
    params["dataset_index"] = host_id
    dataset = input_fn(params)
    with tf.device(utils.device_for_host(self._get_host(host_id))):
      iterator = dataset.make_initializable_iterator()
      self.dataset_initializer.append(iterator.initializer)

      def enqueue_ops_fn():
        """Enqueue ops function for one host."""
        per_host_sharded_inputs = []
        control_deps = []
        for _ in range(FLAGS.num_shards_per_host):
          with tf.control_dependencies(control_deps):
            features, labels = iterator.get_next()
          self.feature_structure["features"] = features
          self.feature_structure["labels"] = labels
          flattened_inputs = data_nest.flatten(self.feature_structure)
          control_deps.extend(flattened_inputs)
          per_host_sharded_inputs.append(flattened_inputs)

        infeed = tpu.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        self.infeed_queue.append(infeed)
        return infeed.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=utils.tpu_ordinal_fn)

      return enqueue_ops_fn

  with self.input_graph.as_default():
    self.enqueue_ops.append(
        utils.wrap_computation_in_while_loop(
            get_enqueue_ops_fn(host_id),
            n=self.iterations,
            host_name=self._get_host(host_id)))
def build_enqueue_ops(self, input_fn, params, host_id):
  """Build enqueue ops."""
  tf.logging.info("TrainLowLevelRunner: build_enqueue_ops")

  def get_enqueue_ops_fn(host_id):
    """Generate the enqueue ops graph function."""
    params["dataset_num_shards"] = self.num_hosts
    params["dataset_index"] = host_id
    with tf.device(utils.device_for_host(self._get_host(host_id))):
      dataset = input_fn(params)
      iterator = dataset.make_initializable_iterator()
      self.dataset_initializer.append(iterator.initializer)

      def enqueue_ops_fn():
        """Enqueue ops function for one host."""
        per_host_sharded_inputs = []
        control_deps = []
        for _ in range(self.replicas_per_worker):
          with tf.control_dependencies(control_deps):
            features, labels = iterator.get_next()
          if self.use_spatial_partition:
            num_elements = []
            for i, d in enumerate(ssd_constants.FEATURE_SIZES):
              num_elements.append(d * d * ssd_constants.NUM_DEFAULTS[i])
            gt_boxes = tf.split(labels[ssd_constants.BOXES], num_elements, 1)
            gt_classes = tf.split(
                labels[ssd_constants.CLASSES], num_elements, 1)

            def transpose_gt_box(gt_box, i):
              return tf.transpose(
                  tf.reshape(gt_box, [
                      -1, ssd_constants.NUM_DEFAULTS[i],
                      ssd_constants.FEATURE_SIZES[i],
                      ssd_constants.FEATURE_SIZES[i], 4
                  ]), [0, 2, 3, 1, 4])

            def transpose_gt_class(gt_class, i):
              return tf.transpose(
                  tf.reshape(gt_class, [
                      -1, ssd_constants.NUM_DEFAULTS[i],
                      ssd_constants.FEATURE_SIZES[i],
                      ssd_constants.FEATURE_SIZES[i]
                  ]), [0, 2, 3, 1])

            labels[ssd_constants.BOXES] = {
                i: transpose_gt_box(gt_boxes[i], i)
                for i in range(len(ssd_constants.NUM_DEFAULTS))
            }
            labels[ssd_constants.CLASSES] = {
                i: transpose_gt_class(gt_classes[i], i)
                for i in range(len(ssd_constants.NUM_DEFAULTS))
            }
          self.feature_structure["features"] = features
          self.feature_structure["labels"] = labels
          flattened_inputs = data_nest.flatten(self.feature_structure)
          control_deps.extend(flattened_inputs)
          per_host_sharded_inputs.append(flattened_inputs)

        if self.use_spatial_partition:
          flattened_input_dims = []
          for i in per_host_sharded_inputs[0]:
            if i.shape.ndims >= len(self.input_partition_dims[0]):
              if i.shape.as_list() == self.feature_structure[
                  "features"].shape.as_list():
                flattened_input_dims.append(self.input_partition_dims[0])
              else:
                flattened_input_dims.append(
                    FLAGS.input_partition_dims + [1] *
                    (i.shape.ndims - len(self.input_partition_dims[0])))
            else:
              flattened_input_dims.append([1] * i.shape.ndims)
          # pylint: disable=protected-access
          infeed = tpu_feed._PartitionedInfeedQueue(
              number_of_tuple_elements=len(per_host_sharded_inputs[0]),
              host_id=host_id,
              input_partition_dims=flattened_input_dims,
              device_assignment=self.device_assignment)
          self.infeed_queue.append(infeed)
          return infeed.generate_enqueue_ops(per_host_sharded_inputs)

        infeed = tpu.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        self.infeed_queue.append(infeed)
        return infeed.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=utils.tpu_ordinal_fn)

      return enqueue_ops_fn

  with self.input_graph.as_default():
    self.enqueue_ops.append(
        utils.wrap_computation_in_while_loop(
            get_enqueue_ops_fn(host_id),
            n=self.iterations,
            host_name=self._get_host(host_id)))
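# Illustrative call pattern for the method above (a sketch only: `runner`,
# `train_input_fn`, and `params` are placeholder names, and the real driver
# code may differ). One set of enqueue ops is built per host, and each host
# reads its own shard of the dataset via `dataset_index`.
for host_id in range(runner.num_hosts):
  runner.build_enqueue_ops(train_input_fn, params, host_id)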
def multiple_steps_fn():
  """Function for multiple TPU steps in a host training loop."""
  return utils.wrap_computation_in_while_loop(
      single_step_fn, n=iterations_per_loop, parallel_iterations=1)
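# `utils.wrap_computation_in_while_loop` is used throughout this file but not
# shown. Below is a minimal sketch of such a helper, assuming TF 1.x graph
# mode and the module-level `tf` import already present in this file; the
# actual implementation may differ. It rebuilds the ops returned by `op_fn`
# once inside a tf.while_loop body and repeats them `n` times, so a single
# session.run drives the whole host loop; `parallel_iterations=1` forces the
# iterations to run strictly in sequence.
def wrap_computation_in_while_loop(op_fn, n, host_name=None,
                                   parallel_iterations=10):
  """Repeats the ops returned by `op_fn` for `n` steps via tf.while_loop."""

  def computation(i):
    ops = op_fn()
    if not isinstance(ops, list):
      ops = [ops]
    # Make each iteration depend on the previous iteration's ops.
    with tf.control_dependencies(ops):
      return i + 1

  device = utils.device_for_host(host_name) if host_name else "/cpu:0"
  with tf.device(device):
    return tf.while_loop(
        lambda i: tf.less(i, n),
        computation, [tf.constant(0)],
        parallel_iterations=parallel_iterations)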