def _CompileDecodeFn(self):
  """Wrap the DecodeFn with split_compile_and_shard."""
  # Defer variable creation so input/model variables can be instantiated
  # explicitly, in order, below.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  # Create the infeed enqueue ops before compiling the TPU computation
  # that dequeues them.
  self._task.input.CreateTpuEnqueueOps()
  self._task.input.CreateCpuPassthroughEnqueueOps()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model.InstantiateVariables()
      input_batch = self._task.input.TpuDequeueBatch()
      decode_dict = self._task.Decode(input_batch)
    # Record the decode output structure on self so the flattened
    # per-shard results can be re-packed after sharded execution.
    self.decode_nm = py_utils.NestedMap(decode_dict)
    return self.decode_nm.Flatten()

  # Compile _DecodeFn once and shard it across `data_parallelism` cores.
  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  self.cpu_pt = self._task.input.DequeueCpuPassthrough()
  # Re-pack the flat sharded results using the structure recorded in
  # self.decode_nm by _DecodeFn during compilation.
  self.decode_tensors = py_utils.NestedMap(self.decode_nm)
  self.decode_tensors = self.decode_tensors.Pack(batch_parallel_res)
def BuildTpuSubgraph(self):
  """Builds the eval TPU subgraph and returns the per-replica metric ops."""
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Defer variable creation so input/model variables can be
    # instantiated explicitly, in order, below.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._init_input_ops = self._task.input.InitOps()

    # XLA thinks self.TpuEvalLoop() requires 1 argument due to self
    # Trick it with wrapper function
    def TpuEvalLoopWrapper():
      return self.TpuEvalLoop()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEvalLoopWrapper,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    return self.tpu_ops
def _CompileDecodeLoop(self):
  """Wrap the DecodeLoop with split_compile_and_shard."""
  device_assignment = py_utils.GetTpuDeviceAssignment()
  # Defer variable creation so input/model variables can be instantiated
  # explicitly, in order, below.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  # Create the infeed enqueue ops before compiling the TPU computation
  # that dequeues them.
  self._task.input.CreateTpuEnqueueOps()
  self._task.input.CreateCpuPassthroughEnqueueOps()

  def _DecodeStep():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model.InstantiateVariables()
      input_batch = self._task.input.TpuDequeueBatch()
      decode_dict = self._task.Decode(input_batch)
    # Record the decode output structure on self so the outfeed results
    # can be re-packed after the loop is compiled.
    self.decode_nm = py_utils.NestedMap(decode_dict)
    # Each step pushes its results to the outfeed instead of returning
    # them; the host dequeues them via _OutfeedDequeue below.
    return [self._OutfeedEnqueue(decode_dict)]

  @tpu_function.on_device_training_loop
  def DecodeLoopFn():
    # Run _DecodeStep self._steps_per_loop times on-device.
    return tpu_training_loop.repeat(
        self._steps_per_loop, _DecodeStep, inputs=[])

  self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
      DecodeLoopFn,
      num_shards=self.data_parallelism,
      device_assignment=device_assignment)
  # Get a list of outfeed ops.
  self.decode_tensors = self._OutfeedDequeue()
  # Pack the list of outfeed ops with structure in self.decode_nm.
  self.decode_tensors = tf.nest.pack_sequence_as(self.decode_nm,
                                                 self.decode_tensors)
  self.cpu_pt = self._task.input.DequeueCpuPassthrough()
def BuildTpuSubgraph(self):
  """Builds the decode TPU subgraph (single sharded decode call)."""
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  with cluster_factory.SetEval(True):
    # Defer variable creation so input/model variables can be
    # instantiated explicitly, in order, below.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._task_params.Instantiate()
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
      # Record the decode output structure on self so the flattened
      # per-shard results can be re-packed after sharded execution.
      self.metrics_nm = py_utils.NestedMap(metrics_dict)
      return self.metrics_nm.Flatten()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    # Re-pack the flat sharded results using the structure recorded in
    # self.metrics_nm by _DecodeFn during compilation.
    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
def BuildTpuSubgraph(self):
  """Builds the eval TPU subgraph with an on-device eval loop."""
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Defer variable creation so input/model variables can be
    # instantiated explicitly, in order, below.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Summed eval metrics.
      """
      with tf.name_scope('tpu_eval'):
        with py_utils.OpportunisticVariableReuseScope(True):
          self._model.InstantiateVariables()
          self._model.ConstructFPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._task.eval_metrics, args)
        # Accumulate this step's metrics onto the running totals carried
        # through the loop in `args`.
        summed_metrics = []
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
        return summed_metrics

    @tpu_function.on_device_training_loop
    def TpuEval():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEval,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the decode TPU subgraph with an on-device loop and outfeed."""
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  device_assignment = py_utils.GetTpuDeviceAssignment()
  self.spmd = self._task_params.input.use_partitioned_infeed_queue
  with cluster_factory.SetEval(True):
    # Defer variable creation so input/model variables can be
    # instantiated explicitly, in order, below.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._task_params.Instantiate()
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
      # Record the decode output structure on self so the outfeed
      # results can be re-packed after the loop is compiled.
      self.metrics_nm = py_utils.NestedMap(metrics_dict)
      # Under SPMD, pin the outfeed enqueue to core 0; otherwise leave
      # device placement unconstrained.
      device = tpu.core(0) if self.spmd else ''
      with tf.device(device):
        outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
            self.metrics_nm.Flatten())
        return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      # Run _DecodeStep self._steps_per_loop times on-device.
      return tpu_training_loop.repeat(self._steps_per_loop,
                                      _DecodeStep,
                                      inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
    return
def BuildTpuSubgraph(self):
  """Builds a combined train-then-decode TPU subgraph.

  Runs `self._train_steps_per_loop` training steps on-device, then a
  single decode call, all inside one sharded TPU computation. Returns
  None; results are stored on `self.metrics` / `self._compile_op`.
  """
  if self._ml_perf_log:
    # Emit the MLPerf-required hyperparameter log lines.
    mlp_log.mlperf_print('global_batch_size',
                         self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  self._eval_metrics = metrics.TpuEvalMetrics()
  data_parallelism = self.data_parallelism
  # Defer variable creation so input/model variables can be instantiated
  # explicitly, in order, below.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._train_model = self._train_task_params.Instantiate()
  self._train_task = self._train_model.GetTask()
  self._train_task.input.InstantiateVariables()
  self._train_task.input.CreateTpuEnqueueOps()
  self._model = self._train_model

  def TpuTrainStep():
    """Train a shard of a batch on a single TPU core.

    Do not calculate loss metrics.

    Returns:
      [train_op].
    """
    with py_utils.OpportunisticVariableReuseScope(True):
      self._train_model.InstantiateVariables()
      self._train_model.ConstructFPropBPropGraph()
    return [self._train_task.train_op]

  def TpuTrain():
    # Run TpuTrainStep self._train_steps_per_loop times on-device.
    loop_result = tpu_training_loop.repeat(
        self._train_steps_per_loop, TpuTrainStep, inputs=[],
        name='train_loop')
    return loop_result

  py_utils.ResetStepSeed()
  # Build the decode model separately, again with deferred variable
  # instantiation.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._decode_model = self._InstantiateTaskModel(
        self._decode_task_params)
  self._decode_task = self._decode_model.GetTask()
  self._decode_task.input.InstantiateVariables()
  self._decode_task.input.CreateTpuEnqueueOps()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model.InstantiateVariables()
        input_batch = self._decode_task.input.TpuDequeueBatch()
        metrics_dict = self._decode_task.Decode(input_batch)
    # Record the decode output structure on self so the flattened
    # per-shard results can be re-packed after sharded execution.
    self.metrics_nm = py_utils.NestedMap(metrics_dict)
    return self.metrics_nm.Flatten()

  @tpu_function.on_device_training_loop
  def TrainAndDecode():
    # Decode only after the training loop has finished.
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Re-pack the flat sharded results using the structure recorded in
  # self.metrics_nm by _DecodeFn during compilation.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def BuildTpuSubgraph(self):
  """Builds the train TPU subgraph and a post-training-loop tail.

  Returns the ops in `self.tpu_ops`: the per-replica loop results plus
  the outfeed dequeue op, gated on the post-training-loop computation.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  self.spmd = (
      self.params.spmd or
      self._task_params.input.use_partitioned_infeed_queue)
  self._eval_metrics = metrics.TpuEvalMetrics()
  data_parallelism = self.data_parallelism
  # Defer variable creation so input/model variables can be instantiated
  # explicitly, in order, below.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()

  def TpuTrainStep(*args):
    """Train a shard of a batch on a single TPU core.

    Args:
      *args: metrics values from previous steps.

    Returns:
      New summed metrics values and a train_op.
    """
    with tf.name_scope('tpu_train'):
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      # Push per-example tensors to the outfeed; the host consumes them
      # via _OutfeedDequeueLoop below.
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Gate the metric accumulation on the outfeed enqueue so the
      # enqueue is not pruned from the loop body.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

  @tpu_function.on_device_training_loop
  def TpuTrain():
    loop_result = tpu_training_loop.repeat(
        self._steps_per_loop,
        TpuTrainStep,
        inputs=self._eval_metrics.initial_values,
        name='train_loop')
    # Final metrics are the avg across self._steps_per_loop steps.
    return self._eval_metrics.FinalizeMetrics(loop_result)

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TpuTrain,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  outfeed_dequeue_op = self._OutfeedDequeueLoop(
      self._task.per_example_tensors, self._steps_per_loop,
      self.num_splits_per_client)

  self._task.input.CreateTpuEmbeddingEnqueueOps()

  def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
    """Returns the op for tpu training with tail cpu computation."""
    # Adds a tail computation that is run after the tpu_training loop
    # step finishes. This allows us to run certain computation that
    # acts on the variable between tpu_train_loop iterations and
    # amortizing the cost of the operations. Alternative of running
    # tpu.outside_compilation & using tf.cond is expensive.
    with tf.control_dependencies(train_loop_op):
      self._model.ConstructPostTrainingLoop()
      with tf.control_dependencies([self._task.post_training_loop_op]):
        return ([[tf.identity(o) for o in train_loop_op],
                 outfeed_dequeue_op])

  # Get metric result from a single replica; they are all same here.
  all_tpu_ops = [t[0] for t in batch_parallel_res]
  self.tpu_ops = (
      _ConstructPostTrainingLoop(all_tpu_ops, outfeed_dequeue_op))

  # Write a human-readable model analysis (parameter breakdown) next to
  # the program's other artifacts; failure to write is non-fatal.
  self._model_analysis, self._total_num_params = summary_utils.ModelAnalysis(
      self._model)
  try:
    with tf.io.gfile.GFile(
        os.path.join(self._program_dir, 'model_analysis.txt'), 'w') as f:
      f.write(self._model_analysis)
  except tf.errors.NotFoundError as e:
    tf.logging.info('Failed to write model analysis %s', e)

  return self.tpu_ops