def LoopBody(i, *input_arrays):
  """Dequeues one TpuTrainStep's outfeed and records it in TensorArrays.

  Args:
    i: current loop index.
    *input_arrays: One tf.TensorArray per outfeed tensor.

  Returns:
    A tuple whose first element is i + 1 (the next index), followed by the
    tf.TensorArray handles returned by the writes.
  """
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment

  # Outfeed dequeue ops execute on the host (JF node) that owns each TPU
  # core, so every dequeue is placed on that core's host device.
  per_core_outfeeds = []
  for replica in range(device_assignment.num_replicas):
    for core in range(device_assignment.num_cores_per_replica):
      host = device_assignment.host_device(replica, core)
      ordinal = device_assignment.tpu_ordinal(replica, core)
      with tf.device(host):
        per_core_outfeeds.append(
            tpu_ops.outfeed_dequeue_tuple(
                tensor_types, tensor_shapes, device_ordinal=ordinal))

  offset = i * num_devices
  written = list(input_arrays)
  # Each TensorArray collects one per-example tensor; every core contributes
  # one element per tensor for this TpuTrainStep, at slot offset + core_idx.
  for array_idx, tensor_array in enumerate(written):
    for core_idx, core_outfeed in enumerate(per_core_outfeeds):
      tensor_array = tensor_array.write(offset + core_idx,
                                        core_outfeed[array_idx])
    written[array_idx] = tensor_array
  return (i + 1,) + tuple(written)
def create_dequeue_ops(host_id):
  """Create outfeed dequeue ops for every TPU shard on one host.

  Args:
    host_id: index of the host whose shards are dequeued; resolved to a
      device via `self._get_host` and `low_level_utils.device_for_host`.

  Returns:
    A list with one tensor per entry of `self.outfeed_tensors`; each is that
    tensor's per-shard outfeed values concatenated along axis 0.
  """
  tensor_dtypes = [v.dtype for v in self.outfeed_tensors]
  tensor_shapes = [v.shape for v in self.outfeed_tensors]
  # One bucket per outfeed tensor; each collects that tensor from every shard.
  per_tensor_shards = [[] for _ in self.outfeed_tensors]
  for i in range(FLAGS.tpu_num_shards_per_host):
    with tf.device(
        low_level_utils.device_for_host(self._get_host(host_id))):
      outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
    for bucket, item in zip(per_tensor_shards, outfeed_tensors):
      bucket.append(item)
  # Driven by the buckets, not by the loop-local `outfeed_tensors`, which is
  # unbound when no shard iteration runs.
  return [tf.concat(bucket, axis=0) for bucket in per_tensor_shards]
def _OutfeedDequeue(self):
  """Collect outfeed dequeue from all devices.

  Returns:
    A list with one tensor per entry of `self.metrics_nm.Flatten()`; each is
    that entry's per-core outfeed values concatenated along axis 0.
  """
  flat_metrics = self.metrics_nm.Flatten()
  dtypes = [x.dtype for x in flat_metrics]
  shapes = [x.shape for x in flat_metrics]
  # An independent list per outfeed (avoids the `[[]] * n` shared-list trap).
  outfeed_ops = [[] for _ in flat_metrics]
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  # Loop-invariant: when self.spmd is set only core 0 of each replica is
  # dequeued, otherwise every core of the replica.
  num_cores_per_replica = (
      1 if self.spmd else device_assignment.num_cores_per_replica)
  for replica in range(device_assignment.num_replicas):
    for core in range(num_cores_per_replica):
      with tf.device(device_assignment.host_device(replica, core)):
        outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
            dtypes=dtypes,
            shapes=shapes,
            device_ordinal=device_assignment.tpu_ordinal(replica, core))
      for per_outfeed, out_feed in zip(outfeed_ops, outfeeds_per_core):
        per_outfeed.append(out_feed)
  return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
def create_dequeue_ops(host_id):
  """Create outfeed dequeue ops for every replica on one host.

  When the outfeed carries two tensors (is_pad and detections), padded rows
  are sliced off before concatenation.

  Args:
    host_id: index of the host whose replicas are dequeued.

  Returns:
    A single tensor: the detections from every replica on the host,
    concatenated along axis 0.
  """
  tensor_dtypes = [v.dtype for v in self.outfeed_tensors]
  tensor_shapes = [v.shape for v in self.outfeed_tensors]
  dequeue_ops = []
  with tf.device(utils.device_for_host(self._get_host(host_id))):
    for i in range(self.replicas_per_worker):
      if self.use_spatial_partition:
        # With spatial partitioning the ordinal comes from the device
        # assignment (logical core 0 of the i-th replica on this host)
        # rather than being i.
        replica_id = self.device_assignment.lookup_replicas(host_id, 0)[i]
        ordinal = self.device_assignment.tpu_ordinal(
            replica=replica_id, logical_core=0)
      else:
        ordinal = i
      outfeed = tpu_ops.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=ordinal)
      if len(outfeed) == 2:
        # 2 outfeed tensors
        # is_pad: [batch]
        # detections: [batch, 200, 7]
        # They are told apart by rank since their order is not fixed here.
        if outfeed[0].shape.ndims == 3:
          detections, is_pad = outfeed
        else:
          is_pad, detections = outfeed
        num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(
            tf.cast(is_pad, tf.int32))
        dequeue_ops.append(
            tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1]))
      else:
        # No padding: the detections tensor is the only outfeed element.
        # Append the tensor itself, not the 1-element list: tf.concat would
        # auto-pack a list into an extra leading dimension, giving a result
        # shaped differently from the padded branch.
        dequeue_ops.append(outfeed[0])
    return tf.concat(dequeue_ops, axis=0)
def _OutfeedDequeue(self):
  """Dequeue the decoder outfeed from every TPU core and stack it.

  Returns:
    A list holding one tensor per entry of `self.decode_nm.Flatten()`. Each
    tensor is that entry's per-core outfeed values concatenated along the
    first dimension (usually corresponds to batch size).
  """
  flat_decode = self.decode_nm.Flatten()
  num_decode_tensors = len(flat_decode)
  assignment = py_utils.GetTpuDeviceAssignment()
  assert assignment
  # With SPMD only core 0 of each replica is dequeued.
  cores_per_replica = 1 if self.spmd else assignment.num_cores_per_replica
  out_dtypes = [x.dtype for x in flat_decode]
  out_shapes = [x.shape for x in flat_decode]

  # First gather results core-major: one tuple of decode tensors per core.
  per_core = []
  for replica in range(assignment.num_replicas):
    for core in range(cores_per_replica):
      with tf.device(assignment.host_device(replica, core)):
        per_core.append(
            tpu_ops.outfeed_dequeue_tuple(
                dtypes=out_dtypes,
                shapes=out_shapes,
                device_ordinal=assignment.tpu_ordinal(replica, core)))

  # Then transpose to tensor-major and concatenate each tensor across cores.
  stacked = []
  for idx in range(num_decode_tensors):
    values = [core_outputs[idx] for core_outputs in per_core]
    stacked.append(tf.concat(values, axis=0))
  return stacked
def _run_eval_phase():
  """The real function that runs the evaluation phase.

  Builds the eval computation, then runs it inside a
  `tf.train.MonitoredSession`:
    * TPU (`FLAGS.use_tpu`): `tpu.replicate` over `mesh_context.num_cores`
      with a SIMD infeed, plus one outfeed dequeue op per core on every host.
    * Otherwise: GPU placement via `PlacementMeshImplInputReader`.
  It then runs `FLAGS.num_eval_iterations_per_loop` steps and passes each
  step's results to `unet.PostProcessor`, calling `finish()` at the end.

  Reads `mesh_context` from the enclosing scope (it is not a parameter).
  """
  # Setup input pipeline.
  ds_creator = unet.get_dataset_creator('eval')
  mtf_shapes = unet.get_input_mtf_shapes('eval')
  model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
      'eval', mesh_context)
  if FLAGS.use_tpu:
    assert mesh_context.device_assignment
    assert mesh_context.num_cores
    simd_input_reader = input_reader.SimdMeshImplInputReader(
        mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
    eval_computation = tpu.replicate(
        computation=model_eval_fn,
        inputs=[[]] * mesh_context.num_cores,
        infeed_queue=simd_input_reader.infeed_queue,
        device_assignment=mesh_context.device_assignment)

    output_dtypes, output_shapes = output_dtypes_shapes.get()
    outfeed_dequeue_ops = []

    # Create outfeed_dequeue_ops.
    for host_id in range(mesh_context.num_hosts):
      # pylint: disable=protected-access
      with ops.device(input_reader._host_id_to_tf_device(
          host_id, external_worker=True)):
        for device_ordinal in range(mesh_context.num_cores_per_host):
          outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
              dtypes=output_dtypes,
              shapes=output_shapes,
              device_ordinal=device_ordinal)

          # We don't need output other than from core 0.
          # Every core's outfeed is still dequeued (presumably so no core
          # stalls on an undrained outfeed -- TODO confirm). Only the first
          # op created (host 0, ordinal 0) keeps full tensors; the rest are
          # reduced to scalar means so little data is fetched.
          if outfeed_dequeue_ops:
            outfeed_dequeue_ops.append(
                [tf.reduce_mean(x) for x in outfeed_dequeue_op])
          else:
            outfeed_dequeue_ops.append(outfeed_dequeue_op)
  else:
    # NOTE(review): the TPU eval path above uses is_eval_mode=True but this
    # eval path passes is_eval_mode=False -- confirm this is intentional.
    placement_input_reader = input_reader.PlacementMeshImplInputReader(
        mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)
    eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

  ###########################################################
  # Evaluation.
  master_to_slice_hook, _ = eval_hooks.get()
  ckpt_loader_hook = _CkptLoaderHook()
  all_hooks = [ckpt_loader_hook, master_to_slice_hook]

  # flush_summary only exists when FLAGS.write_summary is set; every use
  # below is guarded by the same flag.
  if FLAGS.write_summary:
    flush_summary = contrib_summary.flush()

  with tf.train.MonitoredSession(
      session_creator=tf.train.ChiefSessionCreator(
          master=FLAGS.master,
          config=tf.ConfigProto(allow_soft_placement=True)),
      hooks=all_hooks) as sess:

    if FLAGS.write_summary:
      contrib_summary.initialize(session=sess)

    if FLAGS.use_tpu:
      simd_input_reader.start_infeed_thread(
          sess, FLAGS.num_eval_iterations_per_loop)
    else:
      placement_input_reader.initialize(sess)

    pprocessor = unet.PostProcessor()
    for step in range(FLAGS.num_eval_iterations_per_loop):
      # Only get results from the 0-th core.
      # On TPU, index [0] selects the first dequeue op built above (host 0,
      # device_ordinal 0), the only one returning full tensors.
      if FLAGS.use_tpu:
        sess.run(eval_computation)
        results = sess.run(outfeed_dequeue_ops)[0]
      else:
        results = sess.run(eval_computation)
      pprocessor.record(results, FLAGS.pred_output_dir)

      if FLAGS.write_summary:
        sess.run(flush_summary)
      tf.logging.info('eval steps: {}'.format(step))
    pprocessor.finish()
def init_graph(self, model_params):
  """Builds moe decode graph.

  Initializes the TPU cluster and instantiates the model into `self.graph`.
  It then builds an on-device decode loop that enqueues each decoded batch
  to the outfeed (`self.compile_op`, `self.decode_loop`), and a matching
  host-side dequeue op (`self.outfeed_op`).

  Args:
    model_params: the hyperparams of the specified model.
  """
  assert self.graph
  self.model_params = model_params
  batch_size = model_params.task.batch_size
  # Partition count: product of the builder's device_mesh_shape when it
  # defines a non-empty one, otherwise the builder's num_devices.
  if (hasattr(model_params.task.builder, 'device_mesh_shape') and
      model_params.task.builder.device_mesh_shape):
    num_partitions = np.prod(
        model_params.task.builder.device_mesh_shape)
  else:
    num_partitions = model_params.task.builder.num_devices

  device_order_mode = (model_params.task.train.tpu_device_order_mode or
                       tpu_device_assignment.DeviceOrderMode.AUTO)
  self._init_tpu(num_partitions, device_order_mode)
  assert self.cluster_params  # configured by init_tpu
  self.cluster = self.cluster_params.Instantiate()

  with self.graph.as_default(), self.cluster, tf.device(
      self.cluster.GetPlacer()):
    _ = py_utils.GetOrCreateGlobalStepVar()
    self.heartbeat = tf.constant(np.pi)

    device_assignment = py_utils.GetTpuDeviceAssignment()

    tf.logging.info('Instantiating model')
    model = model_params.Instantiate()
    xformer = model.GetTask()
    self.task = xformer

    self.init_vars_op = tf.global_variables_initializer()
    self.saver = tf.train.Saver(sharded=True, reshape=self._saver_reshape)

    infeed = self._config_infeed(num_partitions=num_partitions,
                                 device_assignment=device_assignment,
                                 batch_size=batch_size)

    # Reset here; decode_fn overwrites it with the outfeed structure when
    # the decode loop is built below.
    self.outfeed = []

    def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
      # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
      # length 7 is passed when there is a tgt_mask (e.g. fprop).
      self.outfeed = self._config_outfeed(xformer, infeed_batch)

      # Enqueue the flattened outfeed structure from TPU core 0.
      with tf.device(tf.tpu.core(0)):
        outfeed_op = tpu_ops.outfeed_enqueue_tuple(
            tf.nest.flatten(self.outfeed))

      return [outfeed_op]

    @tpu_function.on_device_training_loop
    def decode_loop_fn():
      # Repeat decode_fn indefinitely when self.num_batches is falsy,
      # otherwise exactly self.num_batches times.
      if not self.num_batches:
        infinite_repeat(decode_fn, infeed)
      else:
        training_loop.repeat(self.num_batches, decode_fn, infeed_queue=infeed)

    self.compile_op, self.decode_loop = tpu_lib.split_compile_and_shard(
        decode_loop_fn, num_shards=1, device_assignment=device_assignment)

    # Building the decode loop is expected to have traced decode_fn, which
    # fills self.outfeed; the dequeue below mirrors its dtypes and shapes.
    assert self.outfeed
    with tf.device(device_assignment.tpu_device(0, 0)):
      self.outfeed_op = tpu_ops.outfeed_dequeue_tuple(
          dtypes=[x.dtype for x in tf.nest.flatten(self.outfeed)],
          shapes=[x.shape for x in tf.nest.flatten(self.outfeed)])
def initialize(self,
               train_input_fn,
               eval_input_fn,
               model_fn,
               train_batch_size,
               eval_batch_size,
               input_partition_dims=None,
               init_fn=None,
               train_has_labels=True,
               eval_has_labels=True,
               params=None,
               num_partitions=None):
  """Build graphs for the TPU device and the input pipelines.

  Steps:
    1. Create the device assignment.
    2. Build the per-host infeed pipelines.
    3. Shard the fused train/eval loop.
    4. When `self.eval_steps > 0`, build the eval outfeed dequeue graph.
    5. Create sessions and initialize or restore variables.
    6. Start running the train/eval op on a background thread.

  Args:
    train_input_fn: training input function, passed to
      `self.build_enqueue_ops`.
    eval_input_fn: eval input function; only used when `self.eval_steps > 0`.
    model_fn: called as `model_fn(features, labels, True)` in each train
      step; the first element of its result is used as the train op.
    train_batch_size: global train batch size, divided by `self.num_replicas`.
    eval_batch_size: global eval batch size, divided by `self.num_replicas`.
    input_partition_dims: optional partition dims; when set, their product is
      the number of cores per replica.
    init_fn: optional callable run inside `self.graph` before variables are
      restored or initialized.
    train_has_labels: whether the training infeed packs (features, labels).
    eval_has_labels: stored as `self.eval_has_labels`.
    params: optional dict of input-pipeline params. A caller-supplied dict is
      mutated in place ("dataset_num_shards", "dataset_index", "batch_size").
    num_partitions: cores per replica when `input_partition_dims` is unset.
  """
  # Cores per replica: product of the partition dims if given, else
  # num_partitions, else 1.
  if input_partition_dims:
    num_cores_per_replica = functools.reduce(operator.mul,
                                             input_partition_dims)
  elif num_partitions:
    num_cores_per_replica = num_partitions
  else:
    num_cores_per_replica = 1
  self.device_assignment = device_assignment.device_assignment(
      topology=self.device_topology,
      computation_shape=_NUM_CORES_TO_COMPUTATION_SHAPE[
          num_cores_per_replica],
      num_replicas=self.num_replicas)
  self.train_batch_size = train_batch_size
  self.eval_batch_size = eval_batch_size
  self.eval_has_labels = eval_has_labels
  self.model_fn = model_fn

  if params is None:
    params = {}
  num_input_hosts = self.num_replicas // FLAGS.replicas_per_host
  params["dataset_num_shards"] = num_input_hosts
  per_replica_train_batch_size = train_batch_size // self.num_replicas
  per_replica_eval_batch_size = eval_batch_size // self.num_replicas
  for i in range(num_input_hosts):
    params["dataset_index"] = i
    params["batch_size"] = per_replica_train_batch_size
    self.build_enqueue_ops(train_input_fn, True, input_partition_dims,
                           params)
    if self.eval_steps > 0:
      params["batch_size"] = per_replica_eval_batch_size
      self.build_enqueue_ops(eval_input_fn, False, input_partition_dims,
                             params)

  def train_step(_):
    """One train step."""
    inp = self.infeed_op[True].generate_dequeue_op()
    flatten_structure = tf.nest.flatten(self.feature_structure[True])
    # Slice each dequeued tensor down to the shape recorded in
    # feature_structure.
    inp = [
        tf.slice(tensor, [0] * tensor.shape.ndims, spec.shape)
        for tensor, spec in zip(inp, flatten_structure)
    ]
    if train_has_labels:
      features, labels = tf.nest.pack_sequence_as(
          self.feature_structure[True], inp)
    else:
      features = tf.nest.pack_sequence_as(self.feature_structure[True], inp)
      labels = None
    self.maybe_add_embedding_features(features, True)
    train_op, _ = model_fn(features, labels, True)
    embedding_train_op = self.maybe_get_embedding_train_op()
    with tf.device(device_for_tpu_core(self.get_host(0))):
      with tf.control_dependencies([train_op, embedding_train_op]):
        return tf.constant(0)

  @tpu_function.on_device_training_loop
  def train_loop():
    return training_loop.repeat(self.iterations_per_loop, train_step,
                                tf.constant(0))

  def train_eval_step():
    with tf.control_dependencies(train_loop()):
      if self.eval_steps > 0:
        return self.eval_loop()
      else:
        return tf.no_op()

  @on_device_train_and_eval_loops
  def train_eval_loop():
    return training_loop.repeat(self.max_train_iterations, train_eval_step)

  with self.graph.as_default():
    (self.train_eval_op,) = tpu.shard(
        train_eval_loop,
        inputs=[],
        num_shards=self.num_replicas,
        outputs_from_all_shards=False,
        device_assignment=self.device_assignment)
    if FLAGS.model_dir:
      tf.io.write_graph(self.graph, FLAGS.model_dir, "graph.pbtxt")

  output_graph = tf.Graph()
  if self.eval_steps > 0:
    with output_graph.as_default():
      flatten_output = tf.nest.flatten(self.predict_output)
      # One list per flattened output; each gets one tensor per host.
      self.dequeue_ops = [[] for _ in flatten_output]
      tensor_dtypes = [v.dtype for v in flatten_output]
      tensor_shapes = [v.shape for v in flatten_output]
      is_padded_index = -1
      if _IS_PADDED in self.predict_output:
        padded_tensor = self.predict_output[_IS_PADDED]
        # Find by identity: list.index() would invoke Tensor.__eq__, which
        # is element-wise (and not usable as a bool in graph mode) when TF2
        # tensor equality is enabled. Raises ValueError if absent, as
        # list.index() did.
        is_padded_index = [t is padded_tensor
                           for t in flatten_output].index(True)
      for i in range(num_input_hosts):
        with tf.device(device_for_host(self.get_host(i))):
          host_dequeue_ops = [[] for _ in flatten_output]
          for j in range(FLAGS.replicas_per_host):
            replica_id = self.device_assignment.lookup_replicas(i, 0)[j]
            ordinal = self.device_assignment.tpu_ordinal(
                replica=replica_id, logical_core=0)
            dequeue_ops = tpu_ops.outfeed_dequeue_tuple(
                dtypes=tensor_dtypes,
                shapes=tensor_shapes,
                device_ordinal=ordinal)
            if is_padded_index >= 0:
              # Keep the first (rows - number of padded rows) rows of every
              # output; assumes padded examples come last -- TODO confirm.
              num_non_pad = tf.shape(
                  dequeue_ops[is_padded_index])[0] - tf.reduce_sum(
                      tf.cast(dequeue_ops[is_padded_index], tf.int32))
              dequeue_ops = [
                  tf.slice(k, [0] * k.shape.ndims,
                           [num_non_pad] + [-1] * (k.shape.ndims - 1))
                  for k in dequeue_ops
              ]
            for k, item in enumerate(dequeue_ops):
              host_dequeue_ops[k].append(item)
          # Iterate the flattened buckets (not the top-level keys of
          # predict_output) so nested outputs are not silently dropped.
          for per_output, host_ops in zip(self.dequeue_ops,
                                          host_dequeue_ops):
            per_output.append(tf.concat(host_ops, axis=0))

  self.sess = tf.Session(self.master, graph=self.graph, config=self.config)
  for is_training in [True, False]:
    if is_training or self.eval_steps > 0:
      for i in range(self.num_input_graphs):
        with self.input_graph[is_training][i].as_default():
          self.input_sess[is_training][i] = tf.Session(
              self.master,
              graph=self.input_graph[is_training][i],
              config=self.config)
          self.input_sess[is_training][i].run(
              self.dataset_initializer[is_training][i])
  self.output_sess = tf.Session(
      self.master, graph=output_graph, config=self.config)

  with self.graph.as_default():
    _ = tf.train.get_or_create_global_step()
    if init_fn:
      init_fn()
    checkpoint_path = tf.train.latest_checkpoint(
        FLAGS.model_dir) if FLAGS.model_dir else None
    if FLAGS.restore_checkpoint and checkpoint_path:
      tf.train.Saver().restore(self.sess, checkpoint_path)
    else:
      self.sess.run(tf.global_variables_initializer())
      self.sess.run(tf.local_variables_initializer())
    self.maybe_load_embedding_vars()
    self.global_step = self.sess.run(tf.train.get_global_step(self.graph))

  def train_eval_thread_fn(sess, train_eval_op):
    sess.run([train_eval_op])

  # Start the just-in-time compilation of the model function by running the
  # train/eval op on a background thread.
  self.train_eval_thread = threading.Thread(
      target=train_eval_thread_fn, args=(self.sess, self.train_eval_op))
  self.train_eval_thread.start()

  # Sleep to give JIT compilation time to finish.
  time.sleep(FLAGS.sleep_after_init)