def create_optimizer(self, params):
  """Create the optimizer used for training.

  This function optionally wraps the base optimizer with
  SyncReplicasOptimizer (aggregate gradients across devices).

  Args:
    params: An optional dict of hyper parameters that will be passed into
      input_fn and model_fn. Keys are names of parameters, values are basic
      python types. There are reserved keys for TPUEstimator, including
      'batch_size'.

  Returns:
    An instance of `tf.train.Optimizer`.
  """
  config = self.get_run_config()
  optimizer = self._create_optimizer_fn(self.use_summaries(params))
  if self._use_avg_model_params:
    # Keep a moving average of the model parameters; checkpoints then store the
    # averaged weights through a swapping saver installed via the scaffold.
    optimizer = optimizers.create_moving_average_optimizer(optimizer)

    def create_swapping_saver_scaffold(saver=None):
      # Any externally provided saver is replaced by the swapping saver.
      saver = optimizers.create_swapping_saver(optimizer)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      return tf.train.Scaffold(saver=saver)

    self._scaffold_fn = create_swapping_saver_scaffold
  if (self._use_sync_replicas_optimizer and (not self.is_device_tpu) and
      config is not None and config.num_worker_replicas > 1):
    # Aggregate gradients across workers before applying them; leaving one
    # replica out of the aggregation avoids stalling on a single slow worker.
    optimizer = tf.train.SyncReplicasOptimizer(
        optimizer,
        replicas_to_aggregate=config.num_worker_replicas - 1,
        total_num_replicas=config.num_worker_replicas)
    self._sync_replicas_optimizer = optimizer
  return optimizer
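
# A minimal usage sketch (an assumption, not code from this module): how the
# optimizer returned by create_optimizer() is typically consumed when building
# the training op in a TF1 Estimator model_fn. `build_loss` is a hypothetical
# loss helper; `create_optimizer`, `get_run_config`, `_sync_replicas_optimizer`,
# and `_scaffold_fn` are the members defined or assigned above. Assumes
# `import tensorflow as tf` at module level, as in the rest of this file.
def _build_train_op_sketch(self, features, labels, mode, params):
  loss = self.build_loss(features, labels)  # Hypothetical loss builder.

  optimizer = self.create_optimizer(params)
  train_op = optimizer.minimize(
      loss, global_step=tf.train.get_or_create_global_step())

  training_hooks = []
  sync_opt = getattr(self, '_sync_replicas_optimizer', None)
  if sync_opt is not None:
    # SyncReplicasOptimizer needs this hook to set up its gradient-aggregation
    # queues; the chief replica additionally runs the initialization ops.
    config = self.get_run_config()
    training_hooks.append(
        sync_opt.make_session_run_hook(is_chief=config.is_chief))

  # With _use_avg_model_params, the scaffold's swapping saver writes the
  # moving-average weights into checkpoints instead of the raw weights.
  scaffold_fn = getattr(self, '_scaffold_fn', None)
  scaffold = scaffold_fn() if scaffold_fn is not None else None

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      training_hooks=training_hooks,
      scaffold=scaffold)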