def _train_on_tpu_shards(run_config, train_step):
  """Executes the `train_step` on all shards."""

  def train_shard():
    return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
                                train_step,
                                [1e7],  # initial_loss
                                name='loop')

  (loss,) = tpu.shard(train_shard,
                      inputs=[],
                      num_shards=run_config.tpu_config.num_shards,
                      outputs_from_all_shards=False)
  return loss
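# --- Illustrative sketch (not part of the TPU library) -----------------------
# The two primitives used above can be hard to picture from their call sites
# alone. Below is a minimal pure-Python sketch of the semantics assumed here:
# `training_loop.repeat` runs a step function `iterations` times, threading
# loop-carried values through, and `tpu.shard` replicates a computation across
# shards, returning only shard 0's outputs when
# `outputs_from_all_shards=False`. The names `repeat_sketch` and
# `shard_sketch` are hypothetical stand-ins, not the real API.


def repeat_sketch(iterations, body, inputs):
  """Runs `body` repeatedly, threading the loop-carried values through."""
  values = list(inputs)
  for _ in range(iterations):
    result = body(*values)
    values = list(result) if isinstance(result, (list, tuple)) else [result]
  return values


def shard_sketch(computation, num_shards, outputs_from_all_shards=True):
  """Runs `computation` once per shard; optionally keeps only shard 0."""
  outputs = [computation() for _ in range(num_shards)]
  return outputs if outputs_from_all_shards else outputs[0]


# Example mirroring the structure of `_train_on_tpu_shards`:
#   decay_step = lambda loss: loss * 0.9   # stand-in for one training step
#   train_shard = lambda: repeat_sketch(10, decay_step, [1e7])
#   (loss,) = shard_sketch(train_shard, num_shards=2,
#                          outputs_from_all_shards=False)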
def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
  config = model_fn_wrapper.config.tpu_config
  iterations_per_loop = config.iterations_per_loop
  num_shards = config.num_shards

  single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
      dequeue_fn)

  multi_tpu_train_steps_on_single_shard = (
      lambda: training_loop.repeat(  # pylint: disable=g-long-lambda
          iterations_per_loop,
          single_tpu_train_step,
          [_INITIAL_LOSS],
          name='loop'))

  (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
                      inputs=[],
                      num_shards=num_shards,
                      outputs_from_all_shards=False)
  return loss
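# --- Illustrative sketch (not part of the TPU library) -----------------------
# A hypothetical sketch of the wrapper contract assumed by
# `_train_on_tpu_system`: `convert_to_single_tpu_train_step(dequeue_fn)`
# returns a step function of the loop-carried loss that dequeues one batch
# from the infeed and recomputes the loss via the user's `model_fn`. The
# class and `model_fn` signature below are illustrative stand-ins, not the
# real `tpu_estimator` internals.


class _ModelFnWrapperSketch(object):
  """Binds a user `model_fn` to the infeed, yielding a per-step callable."""

  def __init__(self, model_fn):
    self._model_fn = model_fn

  def convert_to_single_tpu_train_step(self, dequeue_fn):
    def train_step(loss):
      del loss  # the previous loss is only loop-carried state
      features, labels = dequeue_fn()
      return self._model_fn(features, labels)
    return train_step


# Example: one step consumes one dequeued batch and returns a new loss.
#   batches = iter([({'x': 1.0}, 0.0), ({'x': 2.0}, 1.0)])
#   wrapper = _ModelFnWrapperSketch(lambda features, labels: features['x'])
#   step = wrapper.convert_to_single_tpu_train_step(lambda: next(batches))
#   loss = step(_INITIAL_LOSS)  # dequeues the first batch, returns 1.0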