def BuildTpuSubgraph(self):
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  # Instantiate the input generator first so its enqueue ops exist before
  # the model is constructed inside the sharded computation.
  self._input = self._task_params.input.Instantiate()
  self._input.CreateTpuEnqueueOps()
  # The task instantiated in _DecodeFn must reuse this input instead of
  # creating its own child.
  self._task_params.input.Define('skip_create_child', True, '')

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        self._model_task.AddChild('input', self._input)
        input_batch = self._model_task.input_generator.TpuDequeueBatch()
        metrics_dict = self._model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Restore the metrics structure around the flat per-shard results.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
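# Sketch: the Flatten()/Pack() round trip used above, with tf.nest standing
# in for py_utils.NestedMap. Illustrative only (the metric names are made
# up); it shows why _DecodeFn can return flat tensors and still have the
# caller recover the original metrics structure.
#
#   import tensorflow.compat.v1 as tf
#
#   structure = {'loss': tf.constant(0.0), 'num_samples': tf.constant(0)}
#   flat = tf.nest.flatten(structure)            # what _DecodeFn returns
#   # ...flat results come back from tpu.split_compile_and_shard...
#   restored = tf.nest.pack_sequence_as(structure, flat)  # what Pack restores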
def BuildTpuSubgraph(self):
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    # Instantiate the input generator first so its enqueue ops exist before
    # the model is constructed inside the sharded computation.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    # The task instantiated in TpuTrainStep must reuse this input instead of
    # creating its own child.
    self._task_params.input.Define('skip_create_child', True, '')

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Accumulate this step's metrics into the loop-carried sums, after the
      # per-example tensors have been enqueued on the outfeed.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._model.GetTask().train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._model.GetTask().per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)
    # Get the metric results from a single replica; they are all the same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]

    return self.tpu_ops
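# Sketch: the loop-carried metric accumulation TpuTrainStep relies on.
# tpu_training_loop.repeat threads the previous sums into each step via *args
# and carries the returned sums forward. A minimal standalone shape of the
# pattern (values are illustrative; the real loop also threads train_op and
# runs on device):
#
#   from tensorflow.python.tpu import training_loop as tpu_training_loop
#
#   def _SumStep(total):
#     # Add this step's metric (a constant here) to the running total.
#     return [total + tf.constant(1.0)]
#
#   [total] = tpu_training_loop.repeat(
#       10, _SumStep, inputs=[tf.constant(0.0)])
#   # `total` is the sum over 10 steps; FinalizeMetrics then turns such sums
#   # into averages over self._steps_per_loop.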
def BuildTpuSubgraph(self):
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    # Instantiate the input generator first so its enqueue ops exist before
    # the model is constructed inside the sharded computation.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    # The task instantiated in TpuEvalStep must reuse this input instead of
    # creating its own child.
    self._task_params.input.Define('skip_create_child', True, '')

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Summed eval metrics.
      """
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._task = self._model.GetTask()
        self._task.AddChild('input', self._input)
        self._model.ConstructFPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._model.GetTask().eval_metrics, args)
        summed_metrics = []
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
        return summed_metrics

    @tpu_function.on_device_training_loop
    def TpuEval():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEval,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    # Get the metric results from a single replica; they are all the same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]

    return self.tpu_ops
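# Sketch: why TpuEvalStep builds the model under cluster_factory.SetEval(True).
# Layers that branch on the eval flag (dropout, batch-norm statistics) must
# take their inference path in this graph. A generic stand-in for the pattern,
# not Lingvo's implementation:
#
#   import contextlib
#
#   _EVAL = [False]
#
#   @contextlib.contextmanager
#   def SetEval(value):
#     old, _EVAL[0] = _EVAL[0], value
#     try:
#       yield
#     finally:
#       _EVAL[0] = old
#
#   def MaybeDropout(x, rate=0.1):
#     # Under SetEval(True), dropout becomes the identity.
#     return x if _EVAL[0] else tf.nn.dropout(x, rate=rate)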
def BuildTpuSubgraph(self):
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  device_assignment = py_utils.GetTpuDeviceAssignment()
  self.spmd = self._task_params.input.use_partitioned_infeed_queue
  with py_utils.OpportunisticVariableReuseScope(True):
    with cluster_factory.SetEval(True):
      self._model = self._task_params.Instantiate()
      self._model_task = self._model.GetTask()
      self._model_task.input.CreateTpuEnqueueOps()

      def _DecodeStep():
        """Decode call to be compiled for TPU."""
        input_batch = self._model_task.input_generator.TpuDequeueBatch()
        metrics_dict = self._model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        # With SPMD partitioning, pin the outfeed enqueue to core 0.
        device = tpu.core(0) if self.spmd else ''
        with tf.device(device):
          outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
              self.metrics_nm.Flatten())
          return [outfeed_enqueue]

      @tpu_function.on_device_training_loop
      def DecodeLoopFn():
        return tpu_training_loop.repeat(
            self._steps_per_loop, _DecodeStep, inputs=[])

      self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
          DecodeLoopFn,
          num_shards=self.data_parallelism,
          device_assignment=device_assignment)
      # Get a list of outfeed ops.
      self.metrics = self._OutfeedDequeue()
      # Pack the list of outfeed ops with the structure in self.metrics_nm.
      self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
      return
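# Sketch: the host-side counterpart that _OutfeedDequeue must implement. The
# dequeue dtypes/shapes have to match the flattened metrics enqueued in
# _DecodeStep, with one dequeue per core. Illustrative helper only; the
# helper name and num_cores parameter are assumptions, not this class's API:
#
#   from tensorflow.python.tpu.ops import tpu_ops
#
#   def _DequeueOnHost(flat_metrics, num_cores):
#     per_core = []
#     for core in range(num_cores):
#       per_core.append(
#           tpu_ops.outfeed_dequeue_tuple(
#               dtypes=[t.dtype for t in flat_metrics],
#               shapes=[t.shape for t in flat_metrics],
#               device_ordinal=core))
#     return per_core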
def build_model(self, model_fn, params):
  """Build the TPU model and infeed enqueue ops."""
  tf.logging.info("TrainLowLevelRunner: build_model method")

  def tpu_train_step(loss):
    """Generate the TPU graph."""
    del loss
    values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                    values)
    features = unflattened_inputs["features"]
    core_id = unflattened_inputs["core_id"]
    # Broadcast the features to all replicas via cross_replica_sum, then
    # pick out this core's shard by core_id.
    new_features = {}
    for k in features:
      s = features[k].shape.as_list()
      s = [self.hparams.num_shards, s[0] // self.hparams.num_shards] + s[1:]
      new_features[k] = tf.squeeze(
          tf.gather(
              tf.reshape(tpu_ops.cross_replica_sum(features[k]), s), core_id),
          [0])
    estimator_spec = model_fn(new_features, None,
                              tf.estimator.ModeKeys.TRAIN, params)
    loss, train_op = estimator_spec.loss, estimator_spec.train_op
    with tf.control_dependencies([train_op]):
      return tf.identity(loss)

  @tpu_function.on_device_training_loop
  def train_loop():
    return training_loop.repeat(self.iterations, tpu_train_step,
                                [_INITIAL_LOSS])

  def tpu_eval_step():
    """Generate the TPU graph."""
    values = self.eval_infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(
        self.eval_feature_structure, values)
    features = unflattened_inputs["features"]
    estimator_spec = model_fn(features, None,
                              tf.estimator.ModeKeys.PREDICT, params)
    for k, v in six.iteritems(estimator_spec.predictions):
      self.outfeed_names.append(k)
      self.outfeed_tensors.append(v)
    with tf.device(device_for_tpu_core(get_host(self.resolver,
                                                self.hparams))):
      outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(
          self.outfeed_tensors)
    with tf.control_dependencies([outfeed_enqueue_ops]):
      return tf.no_op()

  @tpu_function.on_device_training_loop
  def eval_loop():
    if self.eval_steps > 0:
      return training_loop.repeat(self.eval_steps, tpu_eval_step, [])
    else:
      return tf.no_op()

  def train_eval_step():
    # Run one epoch of training, then one round of eval.
    with tf.control_dependencies(train_loop()):
      return eval_loop()

  def train_eval_loop():
    return training_loop.repeat(self.hparams.max_train_epochs,
                                train_eval_step, [])

  def create_dequeue_ops(host_id):
    """Create outfeed dequeue ops."""
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    for v in self.outfeed_tensors:
      dequeue_ops.append([])
      tensor_dtypes.append(v.dtype)
      tensor_shapes.append(v.shape)
    for i in range(self.hparams.num_shards_per_host):
      with tf.device(
          device_for_host(get_host(self.resolver, self.hparams, host_id))):
        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
        for j, item in enumerate(outfeed_tensors):
          dequeue_ops[j].append(item)
    # Concatenate each tensor's per-core dequeues along the batch dimension.
    for j in range(len(outfeed_tensors)):
      dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
    return dequeue_ops

  with self.graph.as_default():
    if self.eval_steps <= 0:
      (self.loss,) = tpu.shard(
          train_loop,
          inputs=[],
          num_shards=self.hparams.num_shards,
          outputs_from_all_shards=False,
      )
    else:
      (self.compile_op, self.train_eval_op) = tpu.split_compile_and_shard(
          train_eval_loop,
          inputs=[],
          num_shards=self.hparams.num_shards,
          outputs_from_all_shards=False)

    if self.eval_steps > 0:
      for i in range(0, self.num_hosts):
        self.dequeue_ops.append({})
        host_dequeue_ops = create_dequeue_ops(i)
        for j, dequeue_tensor in enumerate(host_dequeue_ops):
          self.dequeue_ops[i][self.outfeed_names[j]] = dequeue_tensor

    global_initializer = tf.global_variables_initializer()
    local_initializer = tf.local_variables_initializer()
    self.sess.run(global_initializer)
    self.sess.run(local_initializer)
    graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                         self.hparams.out_dir, "graph.pbtxt")
    self.saver = tf.train.Saver()
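# Usage sketch (hypothetical driver, not part of this class): once
# build_model has run, an eval thread would drain the outfeed by repeatedly
# fetching the per-host dequeue dicts. `runner` and `session` here are
# illustrative assumptions.
#
#   def drain_outfeed(runner, session):
#     results = []
#     for _ in range(runner.eval_steps):
#       for host_id in range(runner.num_hosts):
#         # Each fetch returns {outfeed_name: tensor concatenated across the
#         # host's cores}, matching create_dequeue_ops above.
#         results.append(session.run(runner.dequeue_ops[host_id]))
#     return results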
def BuildTpuSubgraph(self):
  if self._ml_perf_log:
    mlp_log.mlperf_print('global_batch_size',
                         self._ml_perf.global_batch_size)
    mlp_log.mlperf_print(
        'max_sequence_length',
        self._ml_perf.max_sequence_length,
        metadata={'method': 'discard'})
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)
    mlp_log.mlperf_print('opt_adam_beta_1', self._ml_perf.opt_adam_beta_1)
    mlp_log.mlperf_print('opt_adam_beta_2', self._ml_perf.opt_adam_beta_2)
    mlp_log.mlperf_print('opt_adam_epsilon', self._ml_perf.opt_adam_epsilon)
    mlp_log.mlperf_print('train_samples', self._ml_perf.train_samples)
    mlp_log.mlperf_print('eval_samples', self._ml_perf.eval_samples)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    # Instantiate the train and decode input generators; the tasks built
    # below must reuse them instead of creating their own input children.
    self._train_task_params.input.Define('skip_create_child', True, '')
    self._train_input = self._train_task_params.input.Instantiate()
    self._train_input.CreateTpuEnqueueOps()
    self._decode_input = self._decode_task_params.input.Instantiate()
    self._decode_input.CreateTpuEnqueueOps()
    self._decode_task_params.input.Define('skip_create_child', True, '')

    def wrap_computation_in_while_loop(op_fn, n, host_device):
      """Wraps the ops generated by `op_fn` in a tf.while_loop."""

      def computation(i):
        ops = op_fn()
        if not isinstance(ops, list):
          ops = [ops]
        with tf.control_dependencies(ops):
          return tf.Print(i + 1, [i], 'while_loop:')

      with tf.device(host_device):
        return tf.while_loop(
            lambda i: tf.less(i, n),
            computation, [tf.constant(0)],
            parallel_iterations=1)

    def TrainAndDecodeEpoch(i, host_device):
      """Runs the train and decode infeed for one epoch.

      Args:
        i: host index.
        host_device: host device string.

      Returns:
        The decode infeed loop, with control deps on the train infeed loop.
      """
      train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
      decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
      tf.logging.info('self._train_steps_per_loop: %d',
                      self._train_steps_per_loop)
      tf.logging.info('self._decode_steps_per_loop: %d',
                      self._decode_steps_per_loop)
      train = wrap_computation_in_while_loop(train_infeed_fn,
                                             self._train_steps_per_loop,
                                             host_device)
      with tf.device(host_device):
        with tf.control_dependencies([train]):
          decode = wrap_computation_in_while_loop(
              decode_infeed_fn, self._decode_steps_per_loop, host_device)
      return decode

    def TrainAndDecodeEpochLoop(i, host_device):
      """Runs the train and decode infeed for num_epochs_per_session_run.

      Args:
        i: host index.
        host_device: host device string.

      Returns:
        The tf.while_loop result.
      """
      train_and_decode_epoch_fn = lambda: TrainAndDecodeEpoch(i, host_device)
      epoch = wrap_computation_in_while_loop(train_and_decode_epoch_fn,
                                             self.num_epochs_per_session_run,
                                             host_device)
      return epoch

    num_infeed_hosts = len(self._train_input.per_host_device)
    tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)
    self.infeed_ops = []
    for i in range(num_infeed_hosts):
      host_device = self._train_input.per_host_device[i]
      self.infeed_ops.append(TrainAndDecodeEpochLoop(i, host_device))

    def TpuTrainStep():
      """Trains a shard of a batch on a single TPU core.

      Does not calculate loss metrics.

      Returns:
        [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._task = self._train_model.GetTask()
      self._task.AddChild('input', self._train_input)
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._train_steps_per_loop,
          TpuTrainStep,
          inputs=[],
          name='train_loop')
      return loop_result

    py_utils.ResetStepSeed()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._decode_model = self._decode_task_params.Instantiate()
          self._decode_model_task = self._decode_model.GetTask()
          self._decode_model_task.AddChild('input', self._decode_input)
          input_batch = (
              self._decode_model_task.input_generator.TpuDequeueBatch())
          metrics_dict = self._decode_model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          # With SPMD partitioning, pin the outfeed enqueue to core 0.
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._decode_steps_per_loop, _DecodeStep, inputs=[])

    def TrainAndDecode():
      # Run one training loop, then one decode loop, on device.
      with tf.control_dependencies([TpuTrain()]):
        return DecodeLoopFn()

    @OnDeviceTrainAndEvalLoops
    def TrainAndDecodeLoop():
      tpu_training_loop.repeat(
          self.num_epochs_per_session_run, TrainAndDecode, inputs=[])

    self._compile_op, self.train_and_decode_loop = (
        tpu.split_compile_and_shard(
            TrainAndDecodeLoop,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment()))
    # Get a list of outfeed ops.
    self.metric_dicts = self._OutfeedDequeue()

  # Save the graph def.
  tf.io.write_graph(tf.get_default_graph().as_graph_def(), self._logdir,
                    'train.pbtxt')
  return