def _CompileDecodeFn(self):
  """Wrap the DecodeFn with split_compile_and_shard."""
  # Defer variable creation: the model object is built here on the host,
  # but its variables are instantiated inside the sharded TPU computation
  # below (via InstantiateVariables in _DecodeFn).
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  # Host-side infeed ops: TPU enqueue plus a separate CPU passthrough queue
  # (presumably for tensors that stay on the host — confirm with input impl).
  self._task.input.CreateTpuEnqueueOps()
  self._task.input.CreateCpuPassthroughEnqueueOps()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model.InstantiateVariables()
      input_batch = self._task.input.TpuDequeueBatch()
      decode_dict = self._task.Decode(input_batch)
    # Record the decode output structure; Flatten()/Pack() below depend on
    # it having the same key ordering.
    self.decode_nm = py_utils.NestedMap(decode_dict)
    return self.decode_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  self.cpu_pt = self._task.input.DequeueCpuPassthrough()
  # Re-pack the flat per-shard results into the decode NestedMap structure.
  self.decode_tensors = py_utils.NestedMap(self.decode_nm)
  self.decode_tensors = self.decode_tensors.Pack(batch_parallel_res)
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph; returns None.

  Side effects: sets self._input, self._model, self._task, self._compile_op,
  self.metrics_nm and self.metrics.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  # Instantiate input generator first.
  self._input = self._task_params.input.Instantiate()
  self._input.CreateTpuEnqueueOps()
  self.SkipCreateChild(self._task_params)

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._task = self._model.GetTask()
        # Attach the already-created input generator as the task's child so
        # TpuDequeueBatch reads from the enqueue ops built above.
        self._task.AddChild('input', self._input)
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Re-pack the flat per-shard results into the metrics NestedMap structure.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def LoopBody(i, *input_arrays): """Process outfeed data for a single TpuTrainStep. Args: i: current loop index. *input_arrays: One tf.TensorArray per outfeed tensor. Returns: i+1 (new index) plus post-write tf.TensorArray handles. """ # Outfeed ops execute on each JF node, so they must be located on the # nodes. outfeed_devices = [] device_assignment = py_utils.GetTpuDeviceAssignment() assert device_assignment for replica in range(device_assignment.num_replicas): for core in range(device_assignment.num_cores_per_replica): with tf.device(device_assignment.host_device( replica, core)): outfeed_devices.append( tpu_ops.outfeed_dequeue_tuple( tensor_types, tensor_shapes, device_ordinal=device_assignment.tpu_ordinal( replica, core))) offset = i * num_devices output_arrays = list(input_arrays) # Each output_array holds a different per-example tensor. We get results # for each tensor from each TPU for each TpuTrainStep call. for j in range(len(output_arrays)): for k in range(len(outfeed_devices)): output_arrays[j] = output_arrays[j].write( offset + k, outfeed_devices[k][j]) return tuple([i + 1] + output_arrays)
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph; returns None.

  Side effects: sets self._model, self._model_task, self._compile_op,
  self.metrics_nm and self.metrics.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        # On TPU the batch comes via the infeed queue; otherwise split the
        # host batch across client splits.
        if py_utils.use_tpu():
          input_batch = self._model_task.input_generator.CreateTpuFeeds()
        else:
          input_batch = self._model_task.input_generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        metrics_dict = self._model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Re-pack the flat per-shard results into the metrics NestedMap structure.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def _CompileDecodeLoop(self):
  """Wrap the DecodeLoop with split_compile_and_shard."""
  device_assignment = py_utils.GetTpuDeviceAssignment()
  # Defer variable creation until inside the sharded computation.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()
  self._task.input.CreateCpuPassthroughEnqueueOps()

  def _DecodeStep():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model.InstantiateVariables()
      input_batch = self._task.input.TpuDequeueBatch()
      decode_dict = self._task.Decode(input_batch)
    # Record the structure of the decode outputs for unpacking below.
    self.decode_nm = py_utils.NestedMap(decode_dict)
    # Results leave the device via outfeed rather than as return values.
    return [self._OutfeedEnqueue(decode_dict)]

  @tpu_function.on_device_training_loop
  def DecodeLoopFn():
    return tpu_training_loop.repeat(
        self._steps_per_loop, _DecodeStep, inputs=[])

  self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
      DecodeLoopFn,
      num_shards=self.data_parallelism,
      device_assignment=device_assignment)
  # Get a list of outfeed ops.
  self.decode_tensors = self._OutfeedDequeue()
  # Pack the list of outfeed ops with structure in self.decode_nm.
  self.decode_tensors = tf.nest.pack_sequence_as(self.decode_nm,
                                                 self.decode_tensors)
  self.cpu_pt = self._task.input.DequeueCpuPassthrough()
def BuildTpuSubgraph(self):
  """Builds the TPU eval subgraph; returns self.tpu_ops."""
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Defer variable creation until inside the sharded computation.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._init_input_ops = self._task.input.InitOps()

    # XLA thinks self.TpuEvalLoop() requires 1 argument due to self
    # Trick it with wrapper function
    def TpuEvalLoopWrapper():
      return self.TpuEvalLoop()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEvalLoopWrapper,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the TPU train subgraph; returns self.tpu_ops."""
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    # Instantiate input generator first.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    self.SkipCreateChild(self._task_params)

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      # Attach the pre-built input generator so the step reads the infeed.
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Force the outfeed enqueue to run before the metric sums are used.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._task.per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)
    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]

    return self.tpu_ops
def TPUOrdinalFunction(shard_index_in_host):
  """Maps a per-host shard index to a TPU device ordinal.

  Uses `task_id` from the enclosing scope to locate this host's replicas.
  """
  assignment = py_utils.GetTpuDeviceAssignment()
  if not assignment:
    # No device assignment: the shard index is used as the ordinal directly.
    return shard_index_in_host
  # We put both enqueue/dequeue ops at core 0 in each replica.
  replicas = assignment.lookup_replicas(task_id, 0)  # pylint: disable=cell-var-from-loop
  return assignment.tpu_ordinal(replica=replicas[shard_index_in_host])
def BuildTpuSubgraph(self):
  """Builds the TPU train subgraph and checkpointer; returns self.tpu_ops."""
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._model.GetTask().eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(
          self._model.GetTask().per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Force the outfeed enqueue to run before the metric sums are used.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._model.GetTask().train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    batch_parallel_res = tf.tpu.batch_parallel(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._model.GetTask().per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)
    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]

    # TODO(blee): This is going to need to be fixed for multiple-model
    # execution. Need to get only the vars associated with the model.
    self._checkpointer = self._CreateCheckpointer(self._checkpoint_dir,
                                                  self._model)
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the TPU eval subgraph; returns self.tpu_ops."""
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Defer variable creation until inside the sharded computation.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Summed eval metrics.
      """
      with tf.name_scope('tpu_eval'):
        with py_utils.OpportunisticVariableReuseScope(True):
          self._model.InstantiateVariables()
          self._model.ConstructFPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._task.eval_metrics, args)
        # Accumulate metrics across loop steps by summing with the carried
        # values; averaging happens in FinalizeMetrics below.
        summed_metrics = []
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
        return summed_metrics

    @tpu_function.on_device_training_loop
    def TpuEval():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEval,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the TPU eval subgraph and checkpointer; returns self.tpu_ops."""
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Per-step eval metrics.
      """
      self._model = self._task_params.Instantiate()
      self._model.ConstructFPropGraph()
      # NOTE(review): unlike sibling variants, this returns per-step metrics
      # without summing into `args` — confirm against TpuEvalMetrics usage.
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._model.GetTask().eval_metrics, args)
      return per_step_eval_metrics

    @tpu_function.on_device_training_loop
    def TpuEval():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    batch_parallel_res = tf.tpu.batch_parallel(
        TpuEval,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   self._model)
    return self.tpu_ops
def _OutfeedDequeue(self):
  """Collect outfeed dequeue from all devices.

  Returns:
    A list with one tensor per entry in `self.metrics_nm`, each formed by
    concatenating the per-core outfeed results along dimension 0.
  """
  # Hoist the flattened template; it is re-read several times below.
  flat_metrics = self.metrics_nm.Flatten()
  # Use a comprehension so each slot is an independent list; `[[]] * n`
  # would alias a single shared list across all slots.
  outfeed_ops = [[] for _ in flat_metrics]
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  # Under SPMD partitioning only core 0 of each replica produces outfeed.
  # Loop-invariant, so computed once outside the replica loop.
  num_cores_per_replica = 1 if self.spmd else (
      device_assignment.num_cores_per_replica)
  for replica in range(device_assignment.num_replicas):
    for core in range(num_cores_per_replica):
      # Dequeue ops must live on the host that drives this (replica, core).
      with tf.device(device_assignment.host_device(replica, core)):
        outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
            dtypes=[x.dtype for x in flat_metrics],
            shapes=[x.shape for x in flat_metrics],
            device_ordinal=device_assignment.tpu_ordinal(replica, core))
        for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
          outfeed_ops[idx_outfeed].append(out_feed)
  return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
def BuildTpuSubgraph(self):
  """Builds the looped TPU decode subgraph with outfeed; returns None."""
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  device_assignment = py_utils.GetTpuDeviceAssignment()
  # SPMD mode is implied by the partitioned infeed queue.
  self.spmd = self._task_params.input.use_partitioned_infeed_queue
  with cluster_factory.SetEval(True):
    # Defer variable creation until inside the sharded computation.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._task_params.Instantiate()
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
      self.metrics_nm = py_utils.NestedMap(metrics_dict)
      # In SPMD mode the outfeed enqueue is pinned to logical core 0.
      device = tpu.core(0) if self.spmd else ''
      with tf.device(device):
        outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
            self.metrics_nm.Flatten())
        return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(self._steps_per_loop, _DecodeStep,
                                      inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
    return
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph and packs metrics onto `self`."""
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()

  def _ShardedDecode():
    """Per-shard decode computation, run under the TPU rewrite."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        batch = self._model_task.GetInputBatch()
        decoded = self._model_task.Decode(batch)
        self.metrics_nm = py_utils.NestedMap(decoded)
        return self.metrics_nm.Flatten()

  sharded_outputs = tf.tpu.batch_parallel(
      _ShardedDecode,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Restore the NestedMap structure over the flat sharded outputs.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(sharded_outputs)
  return None
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph plus its checkpointer."""
  py_utils.ResetStepSeed()

  def _ShardedDecode():
    """Per-shard decode computation, run under the TPU rewrite."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model = self._task_params.Instantiate()
      self._model_task = self._model.GetTask()
      batch = self._model_task.GetInputBatch()
      decoded = self._model_task.Decode(batch)
      self.metrics_nm = py_utils.NestedMap(decoded)
      return self.metrics_nm.Flatten()

  sharded_outputs = tf.tpu.batch_parallel(
      _ShardedDecode,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                 self._model)
  # Restore the NestedMap structure over the flat sharded outputs.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(sharded_outputs)
  return None
def _OutfeedDequeue(self):
  """Collect outfeed dequeue from all devices.

  Returns:
    A list of tensors corresponding to stacked decoded outputs. The decoder
    outputs are stacked on the first dimension (usually corresponds to
    batch size).
  """
  # Hoist the flattened template; it is re-read several times below.
  flat_decode = self.decode_nm.Flatten()
  # Use a comprehension so each slot is an independent list; `[[]] * n`
  # would alias a single shared list across all slots.
  outfeed_ops = [[] for _ in flat_decode]
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  # Under SPMD partitioning only core 0 of each replica produces outfeed.
  num_cores_per_replica = (1 if self.spmd else
                           (device_assignment.num_cores_per_replica))
  for replica in range(device_assignment.num_replicas):
    for core in range(num_cores_per_replica):
      # Dequeue ops must live on the host that drives this (replica, core).
      with tf.device(device_assignment.host_device(replica, core)):
        outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
            dtypes=[x.dtype for x in flat_decode],
            shapes=[x.shape for x in flat_decode],
            device_ordinal=device_assignment.tpu_ordinal(replica, core))
        for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
          outfeed_ops[idx_outfeed].append(out_feed)
  return [tf.concat(per_outfeed, axis=0) for per_outfeed in outfeed_ops]
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
      .format(cluster.num_splits_per_client, cluster.num_devices_per_split,
              num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # A split that spans hosts cannot be fed from a single host's infeed.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    # Each infeed host feeds an equal share of the total shards.
    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          # One enqueue dict per core on this host (independent dicts).
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()

          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          # Parse the task index back out of the device string.
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)
        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    self._tpu_infeed_op = tpu_infeed_op
    with tf.device(tf.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def BuildTpuSubgraph(self):
  """Builds the TPU train subgraph with a post-training-loop tail."""
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  # SPMD mode is implied by the partitioned infeed queue or explicit param.
  self.spmd = (
      self.params.spmd or
      self._task_params.input.use_partitioned_infeed_queue)

  self._eval_metrics = metrics.TpuEvalMetrics()
  data_parallelism = self.data_parallelism

  # Defer variable creation until inside the sharded computation.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()

  def TpuTrainStep(*args):
    """Train a shard of a batch on a single TPU core.

    Args:
      *args: metrics values from previous steps.

    Returns:
      New summed metrics values and a train_op.
    """
    with tf.name_scope('tpu_train'):
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Force the outfeed enqueue to run before the metric sums are used.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

  @tpu_function.on_device_training_loop
  def TpuTrain():
    loop_result = tpu_training_loop.repeat(
        self._steps_per_loop,
        TpuTrainStep,
        inputs=self._eval_metrics.initial_values,
        name='train_loop')
    # Final metrics are the avg across self._steps_per_loop steps.
    return self._eval_metrics.FinalizeMetrics(loop_result)

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TpuTrain,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  outfeed_dequeue_op = self._OutfeedDequeueLoop(
      self._task.per_example_tensors, self._steps_per_loop,
      self.num_splits_per_client)
  self._task.input.CreateTpuEmbeddingEnqueueOps()

  def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
    """Returns the op for tpu training with tail cpu computation."""
    # Adds a tail computation that is run after the tpu_training loop
    # step finishes. This allows us to run certain computation that
    # acts on the variable between tpu_train_loop iterations and
    # amortizing the cost of the operations. Alternative of running
    # tpu.outside_compilation & using tf.cond is expenseive.
    with tf.control_dependencies(train_loop_op):
      self._model.ConstructPostTrainingLoop()
      with tf.control_dependencies([self._task.post_training_loop_op]):
        return ([[tf.identity(o) for o in train_loop_op],
                 outfeed_dequeue_op])

  # Get metric result from a single replica; they are all same here.
  all_tpu_ops = [t[0] for t in batch_parallel_res]
  self.tpu_ops = (
      _ConstructPostTrainingLoop(all_tpu_ops, outfeed_dequeue_op))
  self._model_analysis, self._total_num_params = summary_utils.ModelAnalysis(
      self._model)
  # Best-effort dump of the model analysis; training proceeds on failure.
  try:
    with tf.io.gfile.GFile(
        os.path.join(self._program_dir, 'model_analysis.txt'), 'w') as f:
      f.write(self._model_analysis)
  except tf.errors.NotFoundError as e:
    tf.logging.info('Failed to write model analysis %s', e)

  return self.tpu_ops
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Returns:
    The first host's batch template packed with the dequeued tensors.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # A split that spans hosts cannot be fed from a single host's infeed.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    # Each infeed host feeds an equal share of the total shards.
    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    first_batch = None
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    tpu_embedding_input_keys = (
        tpu_embedding.feature_to_config_dict.keys()
        if tpu_embedding is not None else [])

    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        # Embedding features go through the TPU embedding enqueue path, not
        # the regular infeed queue, so pop them out of the batch.
        tpu_embedding_features = []
        for tpu_embedding_input_key in tpu_embedding_input_keys:
          tpu_embedding_feature = batch.pop(tpu_embedding_input_key)
          tpu_embedding_features.append(
              (tpu_embedding_input_key, tpu_embedding_feature))

        if first_batch is None:
          first_batch = batch
        flat_batch = batch.FlattenItems()

        if tpu_embedding is not None:
          # BUGFIX: was `[{}] * num_cores_per_host`, which aliases a single
          # shared dict across all cores so every core received the last
          # core's EnqueueData. Build independent dicts instead.
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for tpu_embedding_input_key, tpu_embedding_feature in tpu_embedding_features:
            tpu_embedding_feature_splitted = tf.split(
                tpu_embedding_feature, num_cores_per_host)
            for core, split in enumerate(tpu_embedding_feature_splitted):
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  tf.squeeze(split, axis=[1]))
              enqueue_dict_per_core[core][
                  tpu_embedding_input_key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        shapes, types = [], []
        for k, x in flat_batch:
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
          shapes.append(x.shape)
          types.append(x.dtype)
        q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def _tpu_ordinal_function(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=_tpu_ordinal_function)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.compat.v1.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return first_batch.Pack(tensors)
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Builds one InfeedQueue per infeed host (every TPU host when
  p.use_per_host_infeed, otherwise a single host), generates the host-side
  enqueue ops and registers them in the ENQUEUE_OPS collection (one entry
  per infeed-driver thread), then returns the device-side dequeued batch.

  Returns:
    A NestedMap with the same structure as the preprocessed input batch,
    containing the tensors dequeued on TPU core 0.
  """
  p = self.params
  cluster = cluster_factory.Current()
  num_tpu_hosts = cluster.num_tpu_hosts
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # With per-host infeed, every TPU host enqueues its own slice of the
  # shards; otherwise a single host feeds everything.
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    # Each infeed host serves an equal share of the total shards.
    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    first_batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        # Keep the first host's batch only for its structure; Pack() below
        # rebuilds the NestedMap from the flat dequeued tensors.
        if first_batch is None:
          first_batch = batch
        flat_batch = batch.FlattenItems()

        shapes, types = [], []
        for k, x in flat_batch:
          # Infeed requires statically known shapes.
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
          shapes.append(x.shape)
          types.append(x.dtype)
        q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def _tpu_ordinal_function(shard_index_in_host):
            # Maps a host-local shard index to the global TPU core ordinal.
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=_tpu_ordinal_function)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  # Dequeue from the first queue only; all queues carry identically
  # structured batches.
  with tf.device(tf.contrib.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return first_batch.Pack(tensors)
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Like the plain infeed path, but additionally enqueues any TPU-embedding
  input features (registered in the TPU_EMBEDDING graph collection) via
  the TPUEmbedding enqueue API before feeding the dense batch.

  Returns:
    A NestedMap with the structure of the preprocessed input batch,
    containing the tensors dequeued on TPU core 0.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # Per-host infeed cannot split a model-parallel computation across hosts.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    # A TPUEmbedding object, if one was registered by the model.
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                          if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if 'bucket_keys' in batch:
          # Hack: bucket_keys are not needed on TPU.
          del batch['bucket_keys']
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            # Split the feature along axis 0 so each core gets its slice.
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id -1:
              # entries equal to -1 are dropped from the embedding lookup.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          # Infeed requires statically known shapes.
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()

        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            # Maps a host-local shard index to the global TPU core ordinal.
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    # For executor-driven multiple programs, we need more fine-grained
    # access rather than using a single global graph collection.
    self.tpu_infeed_op = tpu_infeed_op

  # Dequeue from the first queue only; every host's queue carries an
  # identically structured batch.
  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def BuildTpuSubgraph(self):
  """Builds the TPU training subgraph for this program.

  Constructs the input enqueue ops, an on-device training loop that runs
  self._steps_per_loop steps per session call, the outfeed dequeue loop
  for per-example tensors, and a post-training-loop tail computation.

  Returns:
    The combined training + outfeed ops (also stored in self.tpu_ops).
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    # Instantiate input generator first: enqueue ops must exist before the
    # model's dequeue ops are built inside the TPU computation.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    self.SkipCreateChild(self._task_params)

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Accumulate metrics across loop steps; the control dependency makes
      # sure the outfeed enqueue happens every step.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._task.per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)

    # Get metric result from a single replica; they are all same here.

    def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
      """Returns the op for tpu training with tail cpu computation."""
      # Adds a tail computation that is run after the tpu_training loop
      # step finishes. This allows us to run certain computation that
      # acts on the variable between tpu_train_loop iterations and
      # amortizing the cost of the operations. Alternative of running
      # tpu.outside_compilation & using tf.cond is expensive.
      with tf.control_dependencies(train_loop_op):
        self._model.ConstructPostTrainingLoop()
        with tf.control_dependencies([self._task.post_training_loop_op]):
          return ([[tf.identity(o) for o in train_loop_op],
                   outfeed_dequeue_op])

    # Get metric result from a single replica; they are all same here.
    all_tpu_ops = [t[0] for t in batch_parallel_res]
    self.tpu_ops = (
        _ConstructPostTrainingLoop(all_tpu_ops, outfeed_dequeue_op))

  return self.tpu_ops
def __init__(self, *args, **kwargs):
  """Constructs a TPU trainer.

  Initializes the TPU system (with retry on transient errors), computes
  the device assignment from the cluster's model-parallel split size, and
  builds the training graph: an on-device training loop replicated over
  all data-parallel shards.
  """
  super(TrainerTpu, self).__init__(*args, **kwargs)

  # Multiple TPU trainer tasks not tested/implemented.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  def ComputationShape(split_size):
    """Decides the computation shape based on the split_size."""
    # Hard-coded 3-D topology shapes for supported power-of-two splits.
    computation_shape = None
    if split_size == 1:
      computation_shape = [1, 1, 1]
    elif split_size == 2:
      computation_shape = [1, 1, 2]
    elif split_size == 4:
      computation_shape = [1, 2, 2]
    elif split_size == 8:
      computation_shape = [2, 2, 2]
    elif split_size == 16:
      computation_shape = [4, 2, 2]
    else:
      assert False, ('Model parallelism with %d devices is currently not'
                     ' supported.' % split_size)
    assert computation_shape is not None
    return computation_shape

  # Cap the loop length so training never overshoots max_steps.
  self._steps_per_loop = min(self.params.train.tpu_steps_per_loop,
                             self.params.train.max_steps)

  tf.logging.info(
      'Creating TrainerTpu using data parallelism %s '
      'and %s steps_per_loop', data_parallelism, self._steps_per_loop)

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._GetSession() as sess:
        # initialize_system returns the TPU topology, which determines the
        # replica <-> core mapping below.
        topology = sess.run(
            tf.contrib.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = tf.contrib.tpu.device_assignment(
            topology,
            computation_shape=ComputationShape(num_devices_per_split),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  _WaitTillInit()

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.job_spec.name):
      self._eval_metrics = metrics.TpuEvalMetrics()

      def TpuTrainStep(*args):
        # One training step on one core; args carry the running metric
        # sums from previous loop iterations.
        self._model = self.params.cls(self.params)
        self._model.ConstructFPropBPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._model.GetTask().eval_metrics, args)
        summed_metrics = []
        assert len(per_step_eval_metrics) == len(args)
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
        return summed_metrics + [self._model.GetTask().train_op]

      def TpuTrain():
        loop_result = tf.contrib.tpu.repeat(
            self._steps_per_loop,
            TpuTrainStep,
            inputs=self._eval_metrics.initial_values,
            name='train_loop')
        # Final metrics are the avg across self._steps_per_loop steps.
        return self._eval_metrics.FinalizeMetrics(loop_result)

      batch_parallel_res = tf.contrib.tpu.batch_parallel(
          TpuTrain,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      # Get metric result from a single replica; they are all same here.
      self._tpu_train_ops = [t[0] for t in batch_parallel_res]

      self.initialize_tables = tf.tables_initializer()
      self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
      # Close-queue ops are a CPU-queue concept and must not appear here.
      assert not tf.get_collection(py_utils.CLOSE_QUEUE_OPS)
      tf.logging.info('Trainer number of enqueue ops: %d',
                      len(self.enqueue_ops))

  self._summary_writer = self._CreateSummaryWriter(self._train_dir)

  # Saves the graph def.
  tf.train.write_graph(self._graph.as_graph_def(), self._train_dir,
                       'train.pbtxt')
def CreateTpuEnqueueOps(self):
  """Create the host-side enqueue ops.

  This should be called in an outer non-TPU context.

  Builds one infeed queue per infeed host. Three queue flavors are
  supported: a partitioned infeed queue (for spatial partitioning via
  p.use_partitioned_infeed_queue), per-host infeed queues, or a single
  queue fed from one host. The grouped enqueue op is stored in
  self._tpu_infeed_op (repeated p.tpu_infeed_parallelism times so several
  threads can drive it); the queues end up in self._tpu_queues.
  """
  assert not self._tpu_queues, ('CreateTpuEnqueueOps should only be called '
                                'once.')
  self._tpu_queues = []
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTpuEnqueueOps num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
      format(cluster.num_splits_per_client, cluster.num_devices_per_split,
             num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # Per-host infeed cannot split a model-parallel computation across hosts.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  # Number of data-parallel shards each infeed host is responsible for.
  shards = (cluster.total_worker_devices //
            num_infeed_hosts) // cluster.num_devices_per_split
  tf.logging.info('shards {}'.format(shards))

  input_ops_list = []
  # A TPUEmbedding object, if one was registered by the model.
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      self._batch = self.GetPreprocessedInputBatch()
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      for k, x in self._batch.FlattenItems():
        # Infeed requires statically known shapes.
        assert x.shape.is_fully_defined(), (
            'Shape must be fully defined: %s: %s' % (k, x))
        # TODO(cwhipkey): if it's a string (or other type not supported on
        # TPU), drop it from feeding and on the other end add in an op that
        # fails if used.
      shapes = self._batch.Transform(lambda x: x.shape).Flatten()
      dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

      tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                      shapes)
      tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                      dtypes)

      if p.use_partitioned_infeed_queue:
        device_assignment = py_utils.GetTpuDeviceAssignment()

        host_device = device_assignment.host_device(
            replica=0, job=tf.flags.FLAGS.tf_master)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        tf.logging.info('host_id: {} host_device: {}'.format(
            host_id, host_device))
        # Partition each input along its leading dimension into
        # p.num_partitions pieces for spatial partitioning.
        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
            number_of_tuple_elements=len(dtypes),
            device_assignment=device_assignment,
            host_id=host_id,
            input_partition_dims=[[p.num_partitions] + [1] * (len(s) - 1)
                                  for s in shapes],
            tuple_types=dtypes,
            tuple_shapes=shapes)
      else:
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        assert shards is not None
        q.set_number_of_shards(shards)
      self._tpu_queues.append(q)

      if p.use_partitioned_infeed_queue:
        input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
      elif p.use_per_host_infeed:
        # TODO(ylc/zhifengc): Add this to a policy module and test it.
        def TPUOrdinalFunction(shard_index_in_host):
          # Maps a host-local shard index to the global TPU core ordinal.
          device_assignment = py_utils.GetTpuDeviceAssignment()
          if device_assignment:
            # We put both enqueue/dequeue ops at core 0 in each replica.
            replica = device_assignment.lookup_replicas(
                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
            return device_assignment.tpu_ordinal(replica=replica)
          else:
            return shard_index_in_host

        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
            tpu_ordinal_function=TPUOrdinalFunction)
      else:
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            device_assignment=py_utils.GetTpuDeviceAssignment())
      input_ops_list += input_ops

  tf.logging.info('input_ops_list %s', input_ops_list)
  grouped_infeed_op = tf.group(*input_ops_list)
  self._tpu_infeed_op = []
  # Duplicate the grouped op so multiple threads can each drive one copy.
  for _ in range(p.tpu_infeed_parallelism):
    self._tpu_infeed_op.append(grouped_infeed_op)
def BuildTpuSubgraph(self):
  """Builds a fused train-then-decode TPU subgraph.

  Composes an on-device training loop followed by a decode pass into a
  single TPU computation, compiled and replicated over all data-parallel
  shards. Decode metrics are unpacked into self.metrics. Also emits
  MLPerf log lines when ML-perf logging is enabled.

  Returns:
    None. Results are exposed via self.metrics / self._compile_op.
  """
  if self._ml_perf_log:
    mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep():
      """Train a shard of a batch on a single TPU core.

      Do not calculate loss metrics.

      Returns:
       [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._train_steps_per_loop,
          TpuTrainStep,
          inputs=[],
          name='train_loop')
      return loop_result

  # Reset the step seed so decode graph construction starts from a clean
  # seed state, independent of how many seeds training consumed.
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model = self._decode_task_params.Instantiate()
        self._decode_model_task = self._decode_model.GetTask()
        if py_utils.use_tpu():
          input_batch = self._decode_model_task.input_generator.CreateTpuFeeds(
          )
        else:
          input_batch = self._decode_model_task.input_generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        metrics_dict = self._decode_model_task.Decode(input_batch)
        # Keep the NestedMap structure so the sharded results can be
        # re-packed after split_compile_and_shard.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  @tpu_function.on_device_training_loop
  def TrainAndDecode():
    # Decode runs strictly after the whole training loop finishes.
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  # Re-pack the flattened per-shard results into the metrics structure.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def init_graph(self, model_params):
  """Builds moe decode graph.

  Initializes the TPU (partition count derived from the builder's device
  mesh), instantiates the model, and constructs the infeed-driven decode
  loop plus its compile, outfeed-dequeue, init and saver ops as
  attributes on self.

  Args:
    model_params: the hyperparams of the specified model.
  """
  assert self.graph
  self.model_params = model_params
  batch_size = model_params.task.batch_size
  # Number of partitions comes from the device mesh if one is configured,
  # otherwise from the builder's flat device count.
  if (hasattr(model_params.task.builder, 'device_mesh_shape') and
      model_params.task.builder.device_mesh_shape):
    num_partitions = np.prod(model_params.task.builder.device_mesh_shape)
  else:
    num_partitions = model_params.task.builder.num_devices

  device_order_mode = (
      model_params.task.train.tpu_device_order_mode or
      tpu_device_assignment.DeviceOrderMode.AUTO)
  self._init_tpu(num_partitions, device_order_mode)
  assert self.cluster_params  # configured by init_tpu
  self.cluster = self.cluster_params.Instantiate()

  with self.graph.as_default(), self.cluster, tf.device(
      self.cluster.GetPlacer()):
    _ = py_utils.GetOrCreateGlobalStepVar()
    # Trivial constant used as a liveness probe for the session.
    self.heartbeat = tf.constant(np.pi)

    device_assignment = py_utils.GetTpuDeviceAssignment()

    tf.logging.info('Instantiating model')
    model = model_params.Instantiate()
    xformer = model.GetTask()
    self.task = xformer

    self.init_vars_op = tf.global_variables_initializer()
    self.saver = tf.train.Saver(sharded=True, reshape=self._saver_reshape)

    infeed = self._config_infeed(
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        batch_size=batch_size)

    # Populated by decode_fn during graph construction; read after
    # split_compile_and_shard to build the dequeue op.
    self.outfeed = []

    def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
      # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
      # length 7 is passed when there is a tgt_mask (e.g. fprop).
      self.outfeed = self._config_outfeed(xformer, infeed_batch)

      with tf.device(tf.tpu.core(0)):
        outfeed_op = tpu_ops.outfeed_enqueue_tuple(
            tf.nest.flatten(self.outfeed))

      return [outfeed_op]

    @tpu_function.on_device_training_loop
    def decode_loop_fn():
      # num_batches == 0/None means decode until the infeed is exhausted.
      if not self.num_batches:
        infinite_repeat(decode_fn, infeed)
      else:
        training_loop.repeat(self.num_batches, decode_fn, infeed_queue=infeed)

    self.compile_op, self.decode_loop = tpu_lib.split_compile_and_shard(
        decode_loop_fn, num_shards=1, device_assignment=device_assignment)

    # decode_fn must have run during tracing and set self.outfeed.
    assert self.outfeed
    with tf.device(device_assignment.tpu_device(0, 0)):
      self.outfeed_op = tpu_ops.outfeed_dequeue_tuple(
          dtypes=[x.dtype for x in tf.nest.flatten(self.outfeed)],
          shapes=[x.shape for x in tf.nest.flatten(self.outfeed)])