Example 1
# These helpers live in TensorFlow's TPU support modules
# (`tensorflow.python.tpu` in recent releases; `tensorflow.contrib.tpu`
# in the 1.x contrib era).
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import training_loop


def _train_on_tpu_shards(run_config, train_step):
  """Executes the `train_step` on all shards."""
  def train_shard():
    # Run `train_step` for `iterations_per_loop` iterations inside a
    # single on-device loop, carrying the loss as the loop state.
    return training_loop.repeat(run_config.tpu_config.iterations_per_loop,
                                train_step,
                                [1e7],  # initial_loss
                                name='loop')

  # Replicate the per-shard loop across all TPU shards; with
  # `outputs_from_all_shards=False`, only a single shard's loss is
  # returned instead of concatenating outputs from every shard.
  (loss,) = tpu.shard(train_shard,
                      inputs=[],
                      num_shards=run_config.tpu_config.num_shards,
                      outputs_from_all_shards=False)
  return loss
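
For context, here is a minimal, hypothetical sketch of how this helper could be invoked. The `RunConfig`/`TPUConfig` namedtuples below are stand-ins for the real `tf.contrib.tpu.RunConfig` and carry only the two fields the helper reads; a real `train_step` would run the model and optimizer rather than just decaying the carried loss.

import collections

# Hypothetical stand-ins holding only the fields _train_on_tpu_shards reads.
TPUConfig = collections.namedtuple('TPUConfig',
                                   ['iterations_per_loop', 'num_shards'])
RunConfig = collections.namedtuple('RunConfig', ['tpu_config'])

def toy_train_step(loss):
  # Toy loop body: receives the loop-carried loss and returns the new one.
  return loss * 0.99

run_config = RunConfig(
    tpu_config=TPUConfig(iterations_per_loop=100, num_shards=8))
loss = _train_on_tpu_shards(run_config, toy_train_step)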
Example 2
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import training_loop

# Module-level constant in the original source: a large sentinel value
# used as the initial loop-carried loss.
_INITIAL_LOSS = 1e7


def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
  config = model_fn_wrapper.config.tpu_config
  iterations_per_loop = config.iterations_per_loop
  num_shards = config.num_shards

  # Wrap the user's model_fn into a single train step that dequeues one
  # batch from the infeed queue on every invocation.
  single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step(
      dequeue_fn)

  def multi_tpu_train_steps_on_single_shard():
    # Run `iterations_per_loop` train steps inside one on-device loop.
    return training_loop.repeat(iterations_per_loop,
                                single_tpu_train_step,
                                [_INITIAL_LOSS],
                                name='loop')

  # Replicate the per-shard loop across all TPU shards; with
  # `outputs_from_all_shards=False`, only a single shard's loss is
  # returned instead of concatenating outputs from every shard.
  (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
                      inputs=[],
                      num_shards=num_shards,
                      outputs_from_all_shards=False)
  return loss
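
As a rough illustration (not the actual `_ModelFnWrapper` internals), the callable returned by `convert_to_single_tpu_train_step` is expected to look roughly like the step below: it takes the previous loop-carried loss, dequeues one batch via `dequeue_fn`, builds the per-step graph, and returns the new loss gated on the train op. The `model_fn` name here is an assumed placeholder for the user-supplied model function.

import tensorflow as tf

def _example_single_tpu_train_step(loss):
  # Illustrative shape only; the previous loss is discarded, not accumulated.
  del loss
  features, labels = dequeue_fn()              # one batch from the infeed queue
  estimator_spec = model_fn(features, labels)  # assumed user model_fn
  # Return the new loss, ordered after the optimizer update.
  with tf.control_dependencies([estimator_spec.train_op]):
    return tf.identity(estimator_spec.loss)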