Code example #1
  def create_optimizer(self):
    """Create the optimizer used for training.

    This function optionally wraps the base optimizer with SyncReplicasOptimizer
    (aggregates gradients across devices).

    Returns:
      An instance of `tf.train.Optimizer`.
    """
    config = self.get_run_config()
    optimizer = self._create_optimizer_fn()
    if self._use_avg_model_params:
      optimizer = optimizers.create_moving_average_optimizer(optimizer)

      def create_swapping_saver_scaffold(saver=None):
        saver = optimizers.create_swapping_saver(optimizer)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        return tf.train.Scaffold(saver=saver)

      self._scaffold_fn = create_swapping_saver_scaffold
    if (self._use_sync_replicas_optimizer and (not self.is_device_tpu) and
        config is not None and config.num_worker_replicas > 1):
      optimizer = tf.train.SyncReplicasOptimizer(
          optimizer,
          replicas_to_aggregate=config.num_worker_replicas - 1,
          total_num_replicas=config.num_worker_replicas)
      self._sync_replicas_optimizer = optimizer
    return optimizer
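
The `optimizers.create_moving_average_optimizer` and `optimizers.create_swapping_saver` helpers above are tensor2robot-specific; as far as this listing shows, they follow the standard TF 1.x moving-average pattern. Below is a minimal sketch of that pattern using stock `tf.contrib.opt` primitives; the function name and decay value are illustrative assumptions, not the project's API.

import tensorflow as tf  # TF 1.x

def build_averaged_optimizer(base_optimizer, decay=0.999):
  """Wraps `base_optimizer` so checkpoints hold the averaged weights."""
  # MovingAverageOptimizer keeps an exponential moving average of every
  # variable it updates, alongside the raw training values.
  optimizer = tf.contrib.opt.MovingAverageOptimizer(
      base_optimizer, average_decay=decay)
  # swapping_saver() writes the averaged values under the variables' own
  # names, so eval/serving jobs restore the averaged weights transparently.
  saver = optimizer.swapping_saver()
  tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
  return optimizer, tf.train.Scaffold(saver=saver)

# Example: wrap a plain Adam optimizer (hypothetical learning rate).
optimizer, scaffold = build_averaged_optimizer(
    tf.train.AdamOptimizer(learning_rate=1e-4))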
Code example #2
File: abstract_model.py  Project: HK2-D/tensor2robot
  def create_optimizer(self, params):
    """Create the optimizer used for training.

    This function optionally wraps the base optimizer with SyncReplicasOptimizer
    (aggregates gradients across devices).

    Args:
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Returns:
      An instance of `tf.train.Optimizer`.
    """

    config = self.get_run_config()
    optimizer = self._create_optimizer_fn(self.use_summaries(params))
    if self._use_avg_model_params:
      optimizer = optimizers.create_moving_average_optimizer(optimizer)

      def create_swapping_saver_scaffold(saver=None):
        saver = optimizers.create_swapping_saver(optimizer)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        return tf.train.Scaffold(saver=saver)

      self._scaffold_fn = create_swapping_saver_scaffold
    if (self._use_sync_replicas_optimizer and (not self.is_device_tpu) and
        config is not None and config.num_worker_replicas > 1):
      optimizer = tf.train.SyncReplicasOptimizer(
          optimizer,
          replicas_to_aggregate=config.num_worker_replicas - 1,
          total_num_replicas=config.num_worker_replicas)
      self._sync_replicas_optimizer = optimizer
    return optimizer
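
When the `SyncReplicasOptimizer` branch is taken, the optimizer returned here must be paired with its session-run hook in the training loop; `self._sync_replicas_optimizer` is presumably kept around for exactly that. A rough sketch of how a TF 1.x `model_fn` consumes such an optimizer follows; the feature key, toy loss, and replica counts are illustrative assumptions rather than values from this project.

import tensorflow as tf  # TF 1.x

def model_fn(features, labels, mode, params):
  """Toy linear model whose train_op uses a synchronous optimizer."""
  predictions = tf.layers.dense(features['x'], 1)
  loss = tf.losses.mean_squared_error(labels, predictions)

  optimizer = tf.train.SyncReplicasOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=0.1),
      replicas_to_aggregate=3,   # e.g. num_worker_replicas - 1
      total_num_replicas=4)      # e.g. num_worker_replicas
  train_op = optimizer.minimize(
      loss, global_step=tf.train.get_or_create_global_step())

  # Without this hook the shared gradient accumulators and the token queue
  # are never initialized, and workers block at the first step.
  sync_hook = optimizer.make_session_run_hook(is_chief=True)
  return tf.estimator.EstimatorSpec(
      mode, loss=loss, train_op=train_op, training_hooks=[sync_hook])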