def BuildTpuSubgraph(self):
  """Builds the sharded TPU computation that wraps `self.TpuEvalLoop`.

  Instantiates the task model in eval mode, creates the input enqueue and
  init ops, and compiles the eval loop across `self.data_parallelism` shards.

  Returns:
    `self.tpu_ops`: a one-element list whose entry holds element 0 of every
    output of the sharded computation.
  """
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    num_shards = self.data_parallelism

    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._init_input_ops = self._task.input.InitOps()

    # XLA believes the bound method self.TpuEvalLoop takes one argument
    # (because of `self`), so hand it a zero-argument wrapper instead.
    def TpuEvalLoopWrapper():
      return self.TpuEvalLoop()

    device_assignment = py_utils.GetTpuDeviceAssignment()
    self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
        TpuEvalLoopWrapper,
        num_shards=num_shards,
        device_assignment=device_assignment)
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # All replicas report the same metrics, so keep only the first one.
    first_replica_outputs = [per_shard[0] for per_shard in sharded_outputs]
    self.tpu_ops = [first_replica_outputs]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the sharded on-device training loop and its outfeed dequeue op.

  The input generator is instantiated before the model and is attached to the
  task as its 'input' child inside the train step.

  Returns:
    `self.tpu_ops`: `[first_replica_outputs, outfeed_dequeue_op]`, where
    `first_replica_outputs` holds element 0 of every sharded output.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    num_shards = self.data_parallelism

    # The input generator has to exist before the model is built.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    self.SkipCreateChild(self._task_params)

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      step_metrics = self._eval_metrics.SetMetrics(self._task.eval_metrics,
                                                   args)
      enqueue_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      assert len(step_metrics) == len(args)
      # The accumulated metrics must not be produced before the outfeed
      # enqueue has been issued.
      with tf.control_dependencies([enqueue_op]):
        running_totals = [
            current + previous for current, previous in zip(step_metrics, args)
        ]
      return running_totals + [self._task.train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      accumulated = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(accumulated)

    self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=num_shards,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    dequeue_op = self._OutfeedDequeueLoop(self._task.per_example_tensors,
                                          self._steps_per_loop,
                                          self.num_splits_per_client)

    # All replicas report the same metrics, so keep only the first one.
    first_replica_outputs = [per_shard[0] for per_shard in sharded_outputs]
    self.tpu_ops = [first_replica_outputs, dequeue_op]
  return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the sharded on-device eval loop that accumulates eval metrics.

  Returns:
    `self.tpu_ops`: a one-element list whose entry holds element 0 of every
    output of the sharded computation.
  """
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    num_shards = self.data_parallelism

    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Summed eval metrics.
      """
      with tf.name_scope('tpu_eval'):
        with py_utils.OpportunisticVariableReuseScope(True):
          self._model.InstantiateVariables()
          self._model.ConstructFPropGraph()
        step_metrics = self._eval_metrics.SetMetrics(self._task.eval_metrics,
                                                     args)
        # Add this step's values onto the totals carried in from earlier steps.
        return [
            current + previous for current, previous in zip(step_metrics, args)
        ]

    @tpu_function.on_device_training_loop
    def TpuEval():
      accumulated = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(accumulated)

    self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
        TpuEval,
        num_shards=num_shards,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # All replicas report the same metrics, so keep only the first one.
    first_replica_outputs = [per_shard[0] for per_shard in sharded_outputs]
    self.tpu_ops = [first_replica_outputs]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the batch-parallel training loop, dequeue op and checkpointer.

  Returns:
    `self.tpu_ops`: `[first_replica_outputs, outfeed_dequeue_op]`, where
    `first_replica_outputs` holds element 0 of every sharded output.
  """
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    num_shards = self.data_parallelism

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._model.ConstructFPropBPropGraph()
      step_metrics = self._eval_metrics.SetMetrics(
          self._model.GetTask().eval_metrics, args)
      enqueue_op = self._OutfeedEnqueue(
          self._model.GetTask().per_example_tensors)
      assert len(step_metrics) == len(args)
      # The accumulated metrics must not be produced before the outfeed
      # enqueue has been issued.
      with tf.control_dependencies([enqueue_op]):
        running_totals = [
            current + previous for current, previous in zip(step_metrics, args)
        ]
      return running_totals + [self._model.GetTask().train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      accumulated = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(accumulated)

    sharded_outputs = tf.tpu.batch_parallel(
        TpuTrain,
        num_shards=num_shards,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    dequeue_op = self._OutfeedDequeueLoop(
        self._model.GetTask().per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)

    # All replicas report the same metrics, so keep only the first one.
    first_replica_outputs = [per_shard[0] for per_shard in sharded_outputs]
    self.tpu_ops = [first_replica_outputs, dequeue_op]

    # TODO(blee): This is going to need to be fixed for multiple-model
    # execution. Need to get only the vars associated with the model.
    self._checkpointer = self._CreateCheckpointer(self._checkpoint_dir,
                                                  self._model)
  return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the batch-parallel on-device eval loop and the checkpointer.

  Returns:
    `self.tpu_ops`: a one-element list whose entry holds element 0 of every
    output of the sharded computation.
  """
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Summed eval metrics.
      """
      self._model = self._task_params.Instantiate()
      self._model.ConstructFPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._model.GetTask().eval_metrics, args)
      assert len(per_step_eval_metrics) == len(args)
      # The loop feeds this step's outputs back in as the next step's `args`.
      # Returning only the current step's values would discard every earlier
      # step, so add them onto the running totals instead.
      return [x + y for x, y in zip(per_step_eval_metrics, args)]

    @tpu_function.on_device_training_loop
    def TpuEval():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    batch_parallel_res = tf.tpu.batch_parallel(
        TpuEval,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   self._model)
  return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds one sharded TPU computation that trains and then decodes.

  When `self._ml_perf_log` is set, the MLPerf hyperparameters are logged
  first. The decode outputs are stored, packed into a NestedMap, on
  `self.metrics`.

  Returns:
    None.
  """
  if self._ml_perf_log:
    ml_perf = self._ml_perf
    mlp_log.mlperf_print('global_batch_size', ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length', ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate', ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    num_shards = self.data_parallelism

    def TpuTrainStep():
      """Train a shard of a batch on a single TPU core.

      Do not calculate loss metrics.

      Returns:
        [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    def TpuTrain():
      return tpu_training_loop.repeat(
          self._train_steps_per_loop,
          TpuTrainStep,
          inputs=[],
          name='train_loop')

  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model = self._decode_task_params.Instantiate()
        self._decode_model_task = self._decode_model.GetTask()
        if py_utils.use_tpu():
          input_batch = (
              self._decode_model_task.input_generator.CreateTpuFeeds())
        else:
          input_batch = (
              self._decode_model_task.input_generator.SplitInputBatch(
                  self.cluster.num_splits_per_client))
        decode_outputs = self._decode_model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(decode_outputs)
        return self.metrics_nm.Flatten()

  @tpu_function.on_device_training_loop
  def TrainAndDecode():
    # Decoding is only built after the training loop has been issued.
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=num_shards,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  # Re-pack the flat sharded outputs using the structure of the decode
  # metrics recorded by _DecodeFn.
  metrics_template = py_utils.NestedMap(self.metrics_nm)
  self.metrics = metrics_template.Pack(sharded_outputs)
  return None
def BuildTpuSubgraph(self):
  """Builds the sharded training loop plus a post-training-loop tail.

  Also computes the model analysis and tries to write it to
  'model_analysis.txt' under `self._program_dir`.

  Returns:
    `self.tpu_ops`: `[identity_of_first_replica_outputs, outfeed_dequeue_op]`.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  self.spmd = (
      self.params.spmd or self._task_params.input.use_partitioned_infeed_queue)

  self._eval_metrics = metrics.TpuEvalMetrics()
  num_shards = self.data_parallelism

  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()

  def TpuTrainStep(*args):
    """Train a shard of a batch on a single TPU core.

    Args:
      *args: metrics values from previous steps.

    Returns:
      New summed metrics values and a train_op.
    """
    with tf.name_scope('tpu_train'):
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        self._model.ConstructFPropBPropGraph()
      step_metrics = self._eval_metrics.SetMetrics(self._task.eval_metrics,
                                                   args)
      enqueue_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      assert len(step_metrics) == len(args)
      # The accumulated metrics must not be produced before the outfeed
      # enqueue has been issued.
      with tf.control_dependencies([enqueue_op]):
        running_totals = [
            current + previous for current, previous in zip(step_metrics, args)
        ]
      return running_totals + [self._task.train_op]

  @tpu_function.on_device_training_loop
  def TpuTrain():
    accumulated = tpu_training_loop.repeat(
        self._steps_per_loop,
        TpuTrainStep,
        inputs=self._eval_metrics.initial_values,
        name='train_loop')
    # Final metrics are the avg across self._steps_per_loop steps.
    return self._eval_metrics.FinalizeMetrics(accumulated)

  self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
      TpuTrain,
      num_shards=num_shards,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  dequeue_op = self._OutfeedDequeueLoop(self._task.per_example_tensors,
                                        self._steps_per_loop,
                                        self.num_splits_per_client)

  self._task.input.CreateTpuEmbeddingEnqueueOps()

  def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
    """Returns the op for tpu training with tail cpu computation."""
    # Adds a tail computation that is run after the tpu_training loop
    # step finishes. This allows us to run certain computation that
    # acts on the variable between tpu_train_loop iterations and
    # amortizing the cost of the operations. Alternative of running
    # tpu.outside_compilation & using tf.cond is expensive.
    with tf.control_dependencies(train_loop_op):
      self._model.ConstructPostTrainingLoop()
      with tf.control_dependencies([self._task.post_training_loop_op]):
        gated_outputs = [tf.identity(op) for op in train_loop_op]
        return [gated_outputs, outfeed_dequeue_op]

  # All replicas report the same metrics, so keep only the first one.
  first_replica_outputs = [per_shard[0] for per_shard in sharded_outputs]
  self.tpu_ops = _ConstructPostTrainingLoop(first_replica_outputs, dequeue_op)

  self._model_analysis, self._total_num_params = (
      summary_utils.ModelAnalysis(self._model))
  # Writing the analysis is best effort: a missing directory is only logged.
  try:
    analysis_path = os.path.join(self._program_dir, 'model_analysis.txt')
    with tf.io.gfile.GFile(analysis_path, 'w') as analysis_file:
      analysis_file.write(self._model_analysis)
  except tf.errors.NotFoundError as err:
    tf.logging.info('Failed to write model analysis %s', err)

  return self.tpu_ops
def __init__(self, *args, **kwargs):
  """Initializes the TPU system and builds the sharded TPU training loop.

  Args:
    *args: positional arguments forwarded to the parent constructor.
    **kwargs: keyword arguments forwarded to the parent constructor.
  """
  super(TrainerTpu, self).__init__(*args, **kwargs)

  # Multiple TPU trainer tasks not tested/implemented.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  def ComputationShape(split_size):
    """Decides the computation shape based on the split_size."""
    shape_by_split_size = {
        1: [1, 1, 1],
        2: [1, 1, 2],
        4: [1, 2, 2],
        8: [2, 2, 2],
        16: [4, 2, 2],
    }
    computation_shape = shape_by_split_size.get(split_size)
    assert computation_shape is not None, (
        'Model parallelism with %d devices is currently not'
        ' supported.' % split_size)
    return computation_shape

  train_params = self.params.train
  self._steps_per_loop = min(train_params.tpu_steps_per_loop,
                             train_params.max_steps)

  tf.logging.info(
      'Creating TrainerTpu using data parallelism %s '
      'and %s steps_per_loop', data_parallelism, self._steps_per_loop)

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._GetSession() as session:
        tpu_topology = session.run(
            tf.contrib.tpu.initialize_system(embedding_config=None, job=None))
        assignment = tf.contrib.tpu.device_assignment(
            tpu_topology,
            computation_shape=ComputationShape(num_devices_per_split),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as err:
      # Log and re-raise so that the retry decorator sees the failure.
      tf.logging.info('TPU initialization failed: %s', err)
      raise

  _WaitTillInit()

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.job_spec.name):
      self._eval_metrics = metrics.TpuEvalMetrics()

      def TpuTrainStep(*args):
        """Builds one train step and adds its metrics onto `args`."""
        self._model = self.params.cls(self.params)
        self._model.ConstructFPropBPropGraph()
        step_metrics = self._eval_metrics.SetMetrics(
            self._model.GetTask().eval_metrics, args)
        assert len(step_metrics) == len(args)
        running_totals = [
            current + previous for current, previous in zip(step_metrics, args)
        ]
        return running_totals + [self._model.GetTask().train_op]

      def TpuTrain():
        """Repeats the train step and finalizes the accumulated metrics."""
        accumulated = tf.contrib.tpu.repeat(
            self._steps_per_loop,
            TpuTrainStep,
            inputs=self._eval_metrics.initial_values,
            name='train_loop')
        # Final metrics are the avg across self._steps_per_loop steps.
        return self._eval_metrics.FinalizeMetrics(accumulated)

      sharded_outputs = tf.contrib.tpu.batch_parallel(
          TpuTrain,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      # All replicas report the same metrics, so keep only the first one.
      self._tpu_train_ops = [per_shard[0] for per_shard in sharded_outputs]

    self.initialize_tables = tf.tables_initializer()
    self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
    assert not tf.get_collection(py_utils.CLOSE_QUEUE_OPS)
    tf.logging.info('Trainer number of enqueue ops: %d',
                    len(self.enqueue_ops))

  self._summary_writer = self._CreateSummaryWriter(self._train_dir)

  # Saves the graph def.
  tf.train.write_graph(self._graph.as_graph_def(), self._train_dir,
                       'train.pbtxt')
def BuildTpuSubgraph(self):
  """Builds the sharded training loop followed by a post-training-loop tail.

  The input generator is instantiated before the model and is attached to the
  task as its 'input' child inside the train step.

  Returns:
    `self.tpu_ops`: `[identity_of_first_replica_outputs, outfeed_dequeue_op]`.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    num_shards = self.data_parallelism

    # The input generator has to exist before the model is built.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    self.SkipCreateChild(self._task_params)

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      step_metrics = self._eval_metrics.SetMetrics(self._task.eval_metrics,
                                                   args)
      enqueue_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      assert len(step_metrics) == len(args)
      # The accumulated metrics must not be produced before the outfeed
      # enqueue has been issued.
      with tf.control_dependencies([enqueue_op]):
        running_totals = [
            current + previous for current, previous in zip(step_metrics, args)
        ]
      return running_totals + [self._task.train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      accumulated = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(accumulated)

    self._compile_op, sharded_outputs = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=num_shards,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    dequeue_op = self._OutfeedDequeueLoop(self._task.per_example_tensors,
                                          self._steps_per_loop,
                                          self.num_splits_per_client)

    def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
      """Returns the op for tpu training with tail cpu computation."""
      # Adds a tail computation that is run after the tpu_training loop
      # step finishes. This allows us to run certain computation that
      # acts on the variable between tpu_train_loop iterations and
      # amortizing the cost of the operations. Alternative of running
      # tpu.outside_compilation & using tf.cond is expensive.
      with tf.control_dependencies(train_loop_op):
        self._model.ConstructPostTrainingLoop()
        with tf.control_dependencies([self._task.post_training_loop_op]):
          gated_outputs = [tf.identity(op) for op in train_loop_op]
          return [gated_outputs, outfeed_dequeue_op]

    # All replicas report the same metrics, so keep only the first one.
    first_replica_outputs = [per_shard[0] for per_shard in sharded_outputs]
    self.tpu_ops = _ConstructPostTrainingLoop(first_replica_outputs,
                                              dequeue_op)
  return self.tpu_ops