def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph.

  Compiles `_DecodeFn` with `tpu.split_compile_and_shard` across
  `self.data_parallelism` shards and stores the packed per-metric result
  tensors in `self.metrics`. Also sets `self._compile_op`.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  # Reset the step seed so graph construction is deterministic per build.
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        # On TPU, input comes from infeed; otherwise split the batch
        # across clients on the host.
        if py_utils.use_tpu():
          input_batch = self._model_task.input_generator.CreateTpuFeeds()
        else:
          input_batch = self._model_task.input_generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        metrics_dict = self._model_task.Decode(input_batch)
        # Remember the metric structure so the flattened per-shard
        # results can be re-packed after compilation.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Restore the NestedMap structure over the flattened shard outputs.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def BuildTpuSubgraph(self):
  """Builds the TPU eval subgraph.

  Instantiates the task model and its input ops, compiles
  `self.TpuEvalLoop()` for the TPU, and returns the per-replica eval ops
  in `self.tpu_ops`. Also sets `self._compile_op` and
  `self._init_input_ops`.
  """
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Defer model variable instantiation; input variables are created
    # explicitly below before the enqueue ops.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._init_input_ops = self._task.input.InitOps()

    # XLA thinks self.TpuEvalLoop() requires 1 argument due to self.
    # Trick it with a zero-argument wrapper function.
    def TpuEvalLoopWrapper():
      return self.TpuEvalLoop()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEvalLoopWrapper,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    return self.tpu_ops
def _CompileDecodeFn(self):
  """Wrap the DecodeFn with split_compile_and_shard.

  Sets `self._compile_op`, `self.cpu_pt` (host-side passthrough tensors)
  and `self.decode_tensors` (decode outputs re-packed into the NestedMap
  structure recorded in `self.decode_nm`).
  """
  # Defer model variable instantiation; input variables are created first.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()
  self._task.input.CreateCpuPassthroughEnqueueOps()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model.InstantiateVariables()
      input_batch = self._task.input.TpuDequeueBatch()
      decode_dict = self._task.Decode(input_batch)
      # Remember the decode output structure for re-packing below.
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return self.decode_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Passthrough tensors bypass the TPU and are dequeued on the host CPU.
  self.cpu_pt = self._task.input.DequeueCpuPassthrough()
  self.decode_tensors = py_utils.NestedMap(self.decode_nm)
  self.decode_tensors = self.decode_tensors.Pack(batch_parallel_res)
def _CompileDecodeLoop(self):
  """Wrap the DecodeLoop with split_compile_and_shard.

  Builds an on-device decode loop of `self._steps_per_loop` steps whose
  per-step results leave the device via outfeed. Sets `self._compile_op`,
  `self.decode_loop`, `self.decode_tensors` and `self.cpu_pt`.
  """
  device_assignment = py_utils.GetTpuDeviceAssignment()
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()
  self._task.input.CreateCpuPassthroughEnqueueOps()

  def _DecodeStep():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model.InstantiateVariables()
      input_batch = self._task.input.TpuDequeueBatch()
      decode_dict = self._task.Decode(input_batch)
      # Remember the decode output structure for re-packing below.
      self.decode_nm = py_utils.NestedMap(decode_dict)
      # Results leave the device via outfeed rather than return values.
      return [self._OutfeedEnqueue(decode_dict)]

  @tpu_function.on_device_training_loop
  def DecodeLoopFn():
    return tpu_training_loop.repeat(
        self._steps_per_loop, _DecodeStep, inputs=[])

  self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
      DecodeLoopFn,
      num_shards=self.data_parallelism,
      device_assignment=device_assignment)
  # Get a list of outfeed ops.
  self.decode_tensors = self._OutfeedDequeue()
  # Pack the list of outfeed ops with structure in self.decode_nm.
  self.decode_tensors = tf.nest.pack_sequence_as(self.decode_nm,
                                                 self.decode_tensors)
  self.cpu_pt = self._task.input.DequeueCpuPassthrough()
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph (input generator created eagerly).

  Unlike the infeed-from-task variant, the input generator is instantiated
  before model construction and attached to the task as the 'input' child
  inside the compiled function. Sets `self._compile_op` and `self.metrics`.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  # Instantiate input generator first.
  self._input = self._task_params.input.Instantiate()
  self._input.CreateTpuEnqueueOps()
  # Skip auto-creating the input child; it is attached inside _DecodeFn.
  self.SkipCreateChild(self._task_params)

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._task = self._model.GetTask()
        self._task.AddChild('input', self._input)
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
        # Remember the metric structure for re-packing below.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def BuildTpuSubgraph(self):
  """Builds the TPU train subgraph.

  Compiles an on-device training loop of `self._steps_per_loop` steps and
  returns `self.tpu_ops`: the single-replica finalized metrics plus the
  outfeed dequeue op for per-example tensors.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Instantiate input generator first.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    # Skip auto-creating the input child; it is attached in TpuTrainStep.
    self.SkipCreateChild(self._task_params)

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Force the outfeed enqueue to happen before metric accumulation.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._task.per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)
    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the TPU eval subgraph with an on-device eval loop.

  Compiles `self._steps_per_loop` eval steps into a single on-device loop
  and returns the single-replica finalized metrics in `self.tpu_ops`.
  Also sets `self._compile_op`.
  """
  tf.logging.info('EvalProgram BuildTpuSubGraph')
  with cluster_factory.SetEval(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Defer model variable instantiation; input variables are created first.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def TpuEvalStep(*args):
      """Eval a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        Summed eval metrics.
      """
      with tf.name_scope('tpu_eval'):
        with py_utils.OpportunisticVariableReuseScope(True):
          self._model.InstantiateVariables()
          self._model.ConstructFPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._task.eval_metrics, args)
        summed_metrics = []
        # Accumulate metrics across loop iterations.
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
        return summed_metrics

    @tpu_function.on_device_training_loop
    def TpuEval():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuEvalStep,
          inputs=self._eval_metrics.initial_values,
          name='eval_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuEval,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

    # Get metric result from a single replica; they are all same here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res]]
    return self.tpu_ops
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph with an on-device decode loop.

  Runs `self._steps_per_loop` decode steps per invocation; per-step
  results leave the device via outfeed and are re-packed into
  `self.metrics`. Sets `self._compile_op`, `self.decode_loop` and
  `self.spmd`.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  device_assignment = py_utils.GetTpuDeviceAssignment()
  # SPMD mode is implied by a partitioned infeed queue on the input.
  self.spmd = self._task_params.input.use_partitioned_infeed_queue
  with cluster_factory.SetEval(True):
    # Defer model variable instantiation; input variables are created first.
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._task_params.Instantiate()
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
        # Remember the metric structure for re-packing below.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        # Under SPMD, pin the outfeed enqueue to logical core 0.
        device = tpu.core(0) if self.spmd else ''
        with tf.device(device):
          outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
              self.metrics_nm.Flatten())
          return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
    return
def BuildTpuSubgraph(self):
  """Builds a combined train+decode TPU subgraph (MLPerf benchmark mode).

  Runs `self._train_steps_per_loop` train steps followed by one decode
  pass inside a single compiled TPU computation. Sets `self._compile_op`
  and packs the decode outputs into `self.metrics`.
  """
  if self._ml_perf_log:
    # Emit the MLPerf-required hyperparameter log lines.
    mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep():
      """Train a shard of a batch on a single TPU core.

      Do not calculate loss metrics.

      Returns:
        [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._train_steps_per_loop,
          TpuTrainStep,
          inputs=[],
          name='train_loop')
      return loop_result

  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model = self._decode_task_params.Instantiate()
        self._decode_model_task = self._decode_model.GetTask()
        # On TPU, input comes from infeed; otherwise split the batch on
        # the host.
        if py_utils.use_tpu():
          input_batch = self._decode_model_task.input_generator.CreateTpuFeeds()
        else:
          input_batch = self._decode_model_task.input_generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        metrics_dict = self._decode_model_task.Decode(input_batch)
        # Remember the metric structure for re-packing below.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  @tpu_function.on_device_training_loop
  def TrainAndDecode():
    # Run the train loop to completion, then decode once.
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def BuildTpuSubgraph(self):
  """Builds the TPU train subgraph with a post-training-loop tail.

  Compiles an on-device training loop, appends a CPU tail computation via
  `_ConstructPostTrainingLoop`, writes a model-analysis report to the
  program directory, and returns `self.tpu_ops`.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  self.spmd = (self.params.spmd or
               self._task_params.input.use_partitioned_infeed_queue)
  self._eval_metrics = metrics.TpuEvalMetrics()
  data_parallelism = self.data_parallelism
  # Defer model variable instantiation; input variables are created first.
  with cluster_factory.SetImmediatelyInstantiateVariables(False):
    self._model = self._InstantiateTaskModel(self._task_params)
  self._task = self._model.GetTask()
  self._task.input.InstantiateVariables()
  self._task.input.CreateTpuEnqueueOps()

  def TpuTrainStep(*args):
    """Train a shard of a batch on a single TPU core.

    Args:
      *args: metrics values from previous steps.

    Returns:
      New summed metrics values and a train_op.
    """
    with tf.name_scope('tpu_train'):
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Force the outfeed enqueue to happen before metric accumulation.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

  @tpu_function.on_device_training_loop
  def TpuTrain():
    loop_result = tpu_training_loop.repeat(
        self._steps_per_loop,
        TpuTrainStep,
        inputs=self._eval_metrics.initial_values,
        name='train_loop')
    # Final metrics are the avg across self._steps_per_loop steps.
    return self._eval_metrics.FinalizeMetrics(loop_result)

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TpuTrain,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  outfeed_dequeue_op = self._OutfeedDequeueLoop(
      self._task.per_example_tensors, self._steps_per_loop,
      self.num_splits_per_client)
  self._task.input.CreateTpuEmbeddingEnqueueOps()

  def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
    """Returns the op for tpu training with tail cpu computation."""
    # Adds a tail computation that is run after the tpu_training loop
    # step finishes. This allows us to run certain computation that
    # acts on the variable between tpu_train_loop iterations and
    # amortizing the cost of the operations. Alternative of running
    # tpu.outside_compilation & using tf.cond is expensive.
    with tf.control_dependencies(train_loop_op):
      self._model.ConstructPostTrainingLoop()
      with tf.control_dependencies([self._task.post_training_loop_op]):
        return ([[tf.identity(o) for o in train_loop_op],
                 outfeed_dequeue_op])

  # Get metric result from a single replica; they are all same here.
  all_tpu_ops = [t[0] for t in batch_parallel_res]
  self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops, outfeed_dequeue_op))

  # Write a parameter-count / model-analysis report next to the program.
  self._model_analysis, self._total_num_params = summary_utils.ModelAnalysis(
      self._model)
  try:
    with tf.io.gfile.GFile(
        os.path.join(self._program_dir, 'model_analysis.txt'), 'w') as f:
      f.write(self._model_analysis)
  except tf.errors.NotFoundError as e:
    # Best-effort: a missing program dir should not fail graph building.
    tf.logging.info('Failed to write model analysis %s', e)

  return self.tpu_ops
def init_graph(self, model_params):
  """Builds moe decode graph.

  Args:
    model_params: the hyperparams of the specified model.
  """
  assert self.graph
  self.model_params = model_params
  batch_size = model_params.task.batch_size
  # Partition count comes from the device mesh when one is configured,
  # otherwise from the builder's flat device count.
  if (hasattr(model_params.task.builder, 'device_mesh_shape') and
      model_params.task.builder.device_mesh_shape):
    num_partitions = np.prod(model_params.task.builder.device_mesh_shape)
  else:
    num_partitions = model_params.task.builder.num_devices
  device_order_mode = (model_params.task.train.tpu_device_order_mode or
                       tpu_device_assignment.DeviceOrderMode.AUTO)
  self._init_tpu(num_partitions, device_order_mode)
  assert self.cluster_params  # configured by init_tpu
  self.cluster = self.cluster_params.Instantiate()

  with self.graph.as_default(), self.cluster, tf.device(
      self.cluster.GetPlacer()):
    _ = py_utils.GetOrCreateGlobalStepVar()
    # Cheap constant op; presumably used as a session liveness probe —
    # NOTE(review): confirm against the runner that fetches it.
    self.heartbeat = tf.constant(np.pi)
    device_assignment = py_utils.GetTpuDeviceAssignment()
    tf.logging.info('Instantiating model')
    model = model_params.Instantiate()
    xformer = model.GetTask()
    self.task = xformer

    self.init_vars_op = tf.global_variables_initializer()
    self.saver = tf.train.Saver(sharded=True, reshape=self._saver_reshape)

    infeed = self._config_infeed(
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        batch_size=batch_size)

    self.outfeed = []

    def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
      # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
      # length 7 is passed when there is a tgt_mask (e.g. fprop).
      self.outfeed = self._config_outfeed(xformer, infeed_batch)
      # Outfeed enqueue must be placed on a TPU core device.
      with tf.device(tf.tpu.core(0)):
        outfeed_op = tpu_ops.outfeed_enqueue_tuple(
            tf.nest.flatten(self.outfeed))
      return [outfeed_op]

    @tpu_function.on_device_training_loop
    def decode_loop_fn():
      # num_batches == 0 (falsy) means decode forever.
      if not self.num_batches:
        infinite_repeat(decode_fn, infeed)
      else:
        training_loop.repeat(self.num_batches, decode_fn, infeed_queue=infeed)

    self.compile_op, self.decode_loop = tpu_lib.split_compile_and_shard(
        decode_loop_fn, num_shards=1, device_assignment=device_assignment)

    # decode_fn must have run during tracing and populated self.outfeed.
    assert self.outfeed
    with tf.device(device_assignment.tpu_device(0, 0)):
      self.outfeed_op = tpu_ops.outfeed_dequeue_tuple(
          dtypes=[x.dtype for x in tf.nest.flatten(self.outfeed)],
          shapes=[x.shape for x in tf.nest.flatten(self.outfeed)])
def build_model(self, model_fn, eval_model_fn, params, hparams, config):
  """Build the TPU model for training and eval.

  Compiles a nested train+eval loop (per-epoch: train loop then eval
  loop), initializes variables in `self.sess`, writes the graph def, and
  builds host-side outfeed dequeue ops into `self.dequeue_ops`.
  """
  tf.logging.info("LowLevelRunner: build_model method for training and eval.")

  def tpu_train_step(loss):
    """Generate the TPU graph."""
    del loss  # previous loss is unused; loop carries it for shape only
    values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                    values)
    features = unflattened_inputs["features"]
    labels = unflattened_inputs["labels"]
    estimator_spec = model_fn(
        features, labels, tf.estimator.ModeKeys.TRAIN, params=params,
        config=config)
    loss, train_op = estimator_spec.loss, estimator_spec.train_op
    self.scaffold_fn = estimator_spec.scaffold_fn
    # Return the loss only after the train op has run.
    with tf.control_dependencies([train_op]):
      return tf.identity(loss)

  @tpu_function.on_device_training_loop
  def train_loop():
    return training_loop.repeat(self.train_steps_tensor, tpu_train_step,
                                [_INITIAL_LOSS])

  def tpu_eval_step():
    """Generate the TPU graph."""
    values = self.eval_infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(
        self.eval_feature_structure, values)
    features = unflattened_inputs["features"]
    estimator_spec = eval_model_fn(
        features, None, tf.estimator.ModeKeys.PREDICT, params=params,
        config=config)
    # Record prediction names/tensors for the host-side dequeue below.
    for k, v in six.iteritems(estimator_spec.predictions):
      self.outfeed_names.append(k)
      self.outfeed_tensors.append(v)
    with tf.device(low_level_utils.device_for_tpu_core(self._get_host(0))):
      outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(self.outfeed_tensors)
    with tf.control_dependencies([outfeed_enqueue_ops]):
      return tf.no_op()

  @tpu_function.on_device_training_loop
  def eval_loop():
    return training_loop.repeat(self.eval_steps_tensor, tpu_eval_step, [])

  def train_eval_step():
    # One "step" of the outer loop = full train loop, then full eval loop.
    with tf.control_dependencies(train_loop()):
      return eval_loop()

  @tpu_function.on_device_training_loop
  def train_eval_loop():
    return training_loop.repeat(self.num_epochs_tensor, train_eval_step, [])

  with self.graph.as_default():
    (
        self.compile_op,
        self.train_eval_op,
    ) = tpu.split_compile_and_shard(
        train_eval_loop,
        inputs=[],
        num_shards=FLAGS.tpu_num_shards,
        outputs_from_all_shards=False,
    )
    if self.scaffold_fn:
      self.scaffold_fn()
    self.sess.run(tf.global_variables_initializer())
    self.sess.run(tf.local_variables_initializer())
    graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                         FLAGS.output_dir, "graph.pbtxt")

  def create_dequeue_ops(host_id):
    """Create outfeed dequeue ops."""
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    for v in self.outfeed_tensors:
      dequeue_ops.append([])
      tensor_dtypes.append(v.dtype)
      tensor_shapes.append(v.shape)
    # One dequeue per shard on this host; gather per-tensor lists.
    for i in range(FLAGS.tpu_num_shards_per_host):
      with tf.device(low_level_utils.device_for_host(self._get_host(host_id))):
        outfeed = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
        for j, item in enumerate(outfeed):
          dequeue_ops[j].append(item)
    # Concatenate each tensor's per-shard pieces along the batch axis.
    for j in range(len(outfeed)):
      dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
    return dequeue_ops

  with self.output_graph.as_default():
    # Get dequeue ops from each hosts.
    for i in range(self.num_hosts):
      tf.logging.info("LowLevelRunner: get dequeue ops for host: %d.", i)
      local_batch_size = hparams.batch_size // self.num_hosts
      local_dequeue_ops = []
      for n in range(local_batch_size):
        local_dequeue_ops.append({})
      # Split the selected per-host tensors into one dict per example.
      for j, dequeue_tensor in enumerate(create_dequeue_ops(i)):
        if self.outfeed_names[j] in ("inputs", "targets", "outputs"):
          dequeue_tensors = tf.split(dequeue_tensor, local_batch_size, axis=0)
          for n in range(local_batch_size):
            local_dequeue_ops[n][self.outfeed_names[j]] = dequeue_tensors[n]
      for j, dequeue_dict in enumerate(local_dequeue_ops):
        self.dequeue_ops.append(dequeue_dict)
def build_eval_model(self, model_fn, params):
  """Build the Eval TPU model and infeed enqueue ops.

  Compiles a combined train/eval loop where eval runs conditionally at
  scheduled iterations, and builds per-host outfeed dequeue ops into
  `self.dequeue_ops`.
  """
  tf.logging.info("TrainAndEvalLowLevelRunner: build_model method")

  # TODO(wangtao): refactor to extract common logic with tpu_train_step.
  def tpu_eval_step():
    """Generate the TPU graph."""
    values = self.eval_infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(
        self.eval_feature_structure, values)
    features = unflattened_inputs["features"]
    estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT,
                              params)
    # Record prediction names/tensors for the host-side dequeue below.
    for k, v in six.iteritems(estimator_spec.predictions):
      self.outfeed_names.append(k)
      self.outfeed_tensors.append(v)
    with tf.device(utils.device_for_tpu_core(self._get_host(0))):
      outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(self.outfeed_tensors)
    with tf.control_dependencies([outfeed_enqueue_ops]):
      return tf.no_op()

  def eval_loop():
    return training_loop.repeat(self.eval_steps, tpu_eval_step, [])

  def train_eval_step(iteration):
    with tf.control_dependencies(self.train_loop()):
      # Eval when this iteration is in the schedule, or always when
      # eval_every_checkpoint is set.
      should_eval = tf.reduce_any(
          tf.equal(tf.constant(self.eval_iterations), iteration))
      should_eval = tf.logical_or(
          should_eval, tf.constant(self.params["eval_every_checkpoint"]))
      ops = tf.cond(should_eval, lambda: eval_loop(), lambda: tf.no_op())  # pylint: disable=unnecessary-lambda
      with tf.control_dependencies([ops]):
        return iteration + 1

  @on_device_train_and_eval_loops
  def train_eval_loop():
    return training_loop.repeat(self.max_train_iterations, train_eval_step,
                                [0])

  # Convert the eval-step schedule into epoch numbers for logging.
  self.eval_epochs = [
      steps * ssd_constants.DEFAULT_BATCH_SIZE / FLAGS.train_batch_size //
      params["steps_per_epoch"] for steps in self.eval_at_steps
  ]
  self.log_epochs = dict(
      zip(self.eval_epochs, [False for _ in self.eval_epochs]))
  # Epochs elapsed between consecutive scheduled evals.
  self.epoch_count = dict(
      zip(self.eval_epochs,
          [self.eval_epochs[0]] + np.diff(self.eval_epochs).tolist()))

  # TODO(wangtao): refactor to extract common logic
  # with train create_dequeu_ops.
  def create_dequeue_ops(host_id):
    """Create outfeed dequeue ops."""
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    for v in self.outfeed_tensors:
      tensor_dtypes.append(v.dtype)
      tensor_shapes.append(v.shape)
    with tf.device(utils.device_for_host(self._get_host(host_id))):
      for i in range(self.replicas_per_worker):
        # Under spatial partitioning the device ordinal must be looked up
        # through the device assignment rather than assumed to equal i.
        if self.use_spatial_partition:
          replica_id = self.device_assignment.lookup_replicas(host_id, 0)[i]
          ordinal = self.device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=0)
        else:
          ordinal = i
        outfeed = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=ordinal)
        if len(outfeed) == 2:
          # 2 outfeed tensors
          # is_pad: [batch]
          # detections: [batch, 200, 7]
          # Order of the pair is not fixed; disambiguate by rank.
          if outfeed[0].shape.ndims == 3:
            detections, is_pad = outfeed
          else:
            is_pad, detections = outfeed
          # Drop padded examples before concatenation.
          num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(
              tf.cast(is_pad, tf.int32))
          dequeue_ops.append(
              tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1]))
        else:
          # no padding, only detections are in the outfeed
          dequeue_ops.append(outfeed)
      dequeue_ops = tf.concat(dequeue_ops, axis=0)
    return dequeue_ops

  with self.graph.as_default():
    (
        self.train_eval_compile_op,
        self.train_eval_op,
    ) = tpu.split_compile_and_shard(
        train_eval_loop,
        inputs=[],
        num_shards=self.num_shards,
        outputs_from_all_shards=False,
        device_assignment=self.device_assignment,
    )
    # Get dequeue ops from each hosts.
    for i in range(self.num_hosts):
      self.dequeue_ops.append(create_dequeue_ops(i))
def BuildTpuSubgraph(self):
  """Builds the TPU train subgraph with a post-training-loop tail.

  Input-generator-first variant: compiles an on-device training loop and
  appends a CPU tail computation via `_ConstructPostTrainingLoop`.
  Returns `self.tpu_ops`.
  """
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism
    # Instantiate input generator first.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    # Skip auto-creating the input child; it is attached in TpuTrainStep.
    self.SkipCreateChild(self._task_params)

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._task = self._model.GetTask()
      self._task.AddChild('input', self._input)
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._task.eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      # Force the outfeed enqueue to happen before metric accumulation.
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._task.train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the avg across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._task.per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)

    def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
      """Returns the op for tpu training with tail cpu computation."""
      # Adds a tail computation that is run after the tpu_training loop
      # step finishes. This allows us to run certain computation that
      # acts on the variable between tpu_train_loop iterations and
      # amortizing the cost of the operations. Alternative of running
      # tpu.outside_compilation & using tf.cond is expensive.
      with tf.control_dependencies(train_loop_op):
        self._model.ConstructPostTrainingLoop()
        with tf.control_dependencies([self._task.post_training_loop_op]):
          return ([[tf.identity(o) for o in train_loop_op],
                   outfeed_dequeue_op])

    # Get metric result from a single replica; they are all same here.
    all_tpu_ops = [t[0] for t in batch_parallel_res]
    self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops,
                                               outfeed_dequeue_op))
    return self.tpu_ops
def build_model(self, model_fn, params):
  """Build the TPU model and infeed enqueue ops.

  Compiles either a train-only loop (when eval_steps <= 0) or a combined
  train+eval loop, builds per-host outfeed dequeue dicts into
  `self.dequeue_ops`, initializes variables and writes the graph def.
  """
  tf.logging.info("TrainLowLevelRunner: build_model method")

  def tpu_train_step(loss):
    """Generate the TPU graph."""
    del loss  # previous loss is unused; loop carries it for shape only
    values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                    values)
    features = unflattened_inputs["features"]
    core_id = unflattened_inputs["core_id"]
    new_features = {}
    for k in features:
      s = features[k].shape.as_list()
      s = [self.hparams.num_shards, s[0] // self.hparams.num_shards] + s[1:]
      # Reassemble the global batch via cross_replica_sum, then select
      # this core's shard using core_id.
      new_features[k] = tf.squeeze(
          tf.gather(
              tf.reshape(tpu_ops.cross_replica_sum(features[k]), s), core_id),
          [0])
    estimator_spec = model_fn(new_features, None, tf.estimator.ModeKeys.TRAIN,
                              params)
    loss, train_op = estimator_spec.loss, estimator_spec.train_op
    # Return the loss only after the train op has run.
    with tf.control_dependencies([train_op]):
      return tf.identity(loss)

  @tpu_function.on_device_training_loop
  def train_loop():
    return training_loop.repeat(self.iterations, tpu_train_step,
                                [_INITIAL_LOSS])

  def tpu_eval_step():
    """Generate the TPU graph."""
    values = self.eval_infeed_queue[0].generate_dequeue_op(tpu_device=0)
    unflattened_inputs = data_nest.pack_sequence_as(
        self.eval_feature_structure, values)
    features = unflattened_inputs["features"]
    estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT,
                              params)
    # Record prediction names/tensors for the host-side dequeue below.
    for k, v in six.iteritems(estimator_spec.predictions):
      self.outfeed_names.append(k)
      self.outfeed_tensors.append(v)
    with tf.device(device_for_tpu_core(get_host(self.resolver, self.hparams))):
      outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(self.outfeed_tensors)
    with tf.control_dependencies([outfeed_enqueue_ops]):
      return tf.no_op()

  @tpu_function.on_device_training_loop
  def eval_loop():
    # Eval may be disabled entirely (eval_steps <= 0).
    if self.eval_steps > 0:
      return training_loop.repeat(self.eval_steps, tpu_eval_step, [])
    else:
      return tf.no_op()

  def train_eval_step():
    with tf.control_dependencies(train_loop()):
      return eval_loop()

  def train_eval_loop():
    return training_loop.repeat(self.hparams.max_train_epochs, train_eval_step,
                                [])

  def create_dequeue_ops(host_id):
    """Create outfeed dequeue ops."""
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    for v in self.outfeed_tensors:
      dequeue_ops.append([])
      tensor_dtypes.append(v.dtype)
      tensor_shapes.append(v.shape)
    # One dequeue per shard on this host; gather per-tensor lists.
    for i in range(self.hparams.num_shards_per_host):
      with tf.device(
          device_for_host(get_host(self.resolver, self.hparams, host_id))):
        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
        for j, item in enumerate(outfeed_tensors):
          dequeue_ops[j].append(item)
    # Concatenate each tensor's per-shard pieces along the batch axis.
    for j in range(len(outfeed_tensors)):
      dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
    return dequeue_ops

  with self.graph.as_default():
    if self.eval_steps <= 0:
      # Train-only mode: shard the train loop directly (no compile op).
      (self.loss,) = tpu.shard(
          train_loop,
          inputs=[],
          num_shards=self.hparams.num_shards,
          outputs_from_all_shards=False,
      )
    else:
      (
          self.compile_op,
          self.train_eval_op,
      ) = tpu.split_compile_and_shard(
          train_eval_loop,
          inputs=[],
          num_shards=self.hparams.num_shards,
          outputs_from_all_shards=False)
    if self.eval_steps > 0:
      # Build a name->tensor dequeue dict per host.
      for i in range(0, self.num_hosts):
        self.dequeue_ops.append({})
        host_dequeue_ops = create_dequeue_ops(i)
        for j, dequeue_tenor in enumerate(host_dequeue_ops):
          self.dequeue_ops[i][self.outfeed_names[j]] = dequeue_tenor
    global_initializer = tf.global_variables_initializer()
    local_initializer = tf.local_variables_initializer()
    self.sess.run(global_initializer)
    self.sess.run(local_initializer)
    graph_io.write_graph(
        self.graph.as_graph_def(add_shapes=True), self.hparams.out_dir,
        "graph.pbtxt")
    self.saver = tf.train.Saver()