Code Example #1
File: noam_schedule.py Project: xiaming9880/neurst
    def __init__(self, args):
        """Initialize configuration of the learning rate schedule.

        Args:
          args: A dict of full parameters.
        """
        super(NoamSchedule, self).__init__()

        self._dmodel = args["dmodel"]
        self._warmup_steps = tf.cast(args["warmup_steps"], tf.float32)
        self._initial_step = compat.get_registered_initial_step()
        logging.info("Initialize NoamSchedule from global step={}. "
                     "The result learning rate will be scaled by {}"
                     "".format(int(self._initial_step), args["initial_factor"]))
        self._initial_step = tf.convert_to_tensor(self._initial_step, dtype=tf.float32)
        _initial_learning_rate = args["initial_factor"]
        self._initial_learning_rate = tf.convert_to_tensor(_initial_learning_rate, tf.float32)
        _end_learning_rate = args["end_factor"]
        if (_end_learning_rate is not None and args["start_decay_at"] is not None
            and args["decay_steps"] is not None):
            start_decay_at = args["start_decay_at"]
            decay_steps = args["decay_steps"]
            logging.info("\tThe scaling factor will start to decay from {} to {} at step {} "
                         "and finish at step {}.".format(_initial_learning_rate, _end_learning_rate,
                                                         start_decay_at, start_decay_at + decay_steps))
        else:
            _end_learning_rate = _initial_learning_rate
            start_decay_at = 0
            decay_steps = 1
        self._end_learning_rate = tf.convert_to_tensor(_end_learning_rate, tf.float32)
        self._start_decay_at = tf.convert_to_tensor(start_decay_at, tf.float32)
        self._decay_steps = tf.convert_to_tensor(decay_steps, tf.float32)
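
Only the constructor appears above; the rate computation itself lives in __call__, which is not shown. For reference, the standard Noam formula from "Attention Is All You Need" is dmodel^-0.5 * min(step^-0.5, step * warmup_steps^-1.5). Below is a minimal sketch under that assumption, ignoring the start_decay_at/decay_steps branch; noam_lr is a hypothetical name, not part of neurst.

# Hedged sketch: the standard Noam formula, not necessarily neurst's exact __call__.
import tensorflow as tf

def noam_lr(step, dmodel, warmup_steps, initial_factor=1.0):
    """dmodel^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), scaled by initial_factor."""
    step = tf.cast(tf.maximum(step, 1), tf.float32)   # avoid step=0 -> inf
    warmup_steps = tf.cast(warmup_steps, tf.float32)
    decay = tf.math.rsqrt(step)                       # step^-0.5 after warmup
    warmup = step * warmup_steps ** -1.5              # linear ramp during warmup
    return (initial_factor * tf.math.rsqrt(tf.cast(dmodel, tf.float32))
            * tf.minimum(decay, warmup))
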
Code Example #2
    def __init__(self, args):
        """Initialize configuration of the learning rate schedule.

        Args:
          args: A dict of full parameters.
        """
        super(InverseSquareRootSchedule, self).__init__()
        self._initial_step = compat.get_registered_initial_step()
        logging.info(
            f"Initialize InverseSquareRootSchedule from global step={self._initial_step}. "
        )
        self._initial_step = tf.convert_to_tensor(self._initial_step,
                                                  dtype=tf.float32)
        self._lr = tf.cast(args["peak_lr"], tf.float32)
        self._init_lr = tf.cast(args["init_lr"], tf.float32)
        self._warmup_steps = tf.cast(args["warmup_steps"], tf.float32)
        self._lr_step = (self._lr - self._init_lr) / self._warmup_steps
        self._decay_factor = self._lr * self._warmup_steps**0.5
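
The precomputed _lr_step and _decay_factor match the common (fairseq-style) inverse square root schedule: a linear warmup from init_lr to peak_lr over warmup_steps, followed by decay proportional to step^-0.5. A minimal sketch of that rule, assuming this is how the unshown __call__ uses the two constants (inverse_sqrt_lr is a hypothetical name):

# Hedged sketch matching the constants computed above; the project's actual
# __call__ may differ in details.
import tensorflow as tf

def inverse_sqrt_lr(step, init_lr, peak_lr, warmup_steps):
    """Linear warmup from init_lr to peak_lr, then decay proportional to step^-0.5."""
    step = tf.cast(step, tf.float32)
    warmup_steps = tf.cast(warmup_steps, tf.float32)
    lr_step = (peak_lr - init_lr) / warmup_steps       # same as self._lr_step
    decay_factor = peak_lr * warmup_steps ** 0.5       # same as self._decay_factor
    return tf.where(step < warmup_steps,
                    init_lr + step * lr_step,                      # warmup phase
                    decay_factor * tf.maximum(step, 1.0) ** -0.5)  # decay phase

At step = warmup_steps both branches evaluate to peak_lr, so the schedule is continuous across the warmup boundary.
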
Code Example #3
    def __init__(self, args):
        """Initialize configuration of the learning rate schedule.

        Args:
          args: A dict of full parameters.
        """
        super(PiecewiseSchedule, self).__init__()
        self._schedule_steps = args["schedule_steps"]
        self._schedule_lrs = args["schedule_lrs"]
        assert len(self._schedule_steps) + 1 == len(self._schedule_lrs)
        self._initial_step = compat.get_registered_initial_step()
        logging.info("Initialize PiecewiseSchedule from global step={}. "
                     "The learning rate will be:".format(self._initial_step))
        for idx, (step, lr) in enumerate(
                zip(self._schedule_steps, self._schedule_lrs)):
            if idx == 0:
                logging.info("    linear warmup from 0~{} for {} steps".format(
                    lr, step))
            else:
                logging.info("    {} from step={} to step={}".format(
                    lr, self._schedule_steps[idx - 1], step))
        logging.info("    {} for step>{}".format(self._schedule_lrs[-1],
                                                 self._schedule_steps[-1]))
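
The constructor only validates the list lengths and logs the plan; per those log messages, the schedule is piecewise constant with a linear warmup across the first segment. A pure-Python sketch of that lookup, assuming this is what the unshown __call__ computes (piecewise_lr is a hypothetical name):

def piecewise_lr(step, schedule_steps, schedule_lrs):
    """Piecewise-constant LR; the first segment is a linear warmup from 0."""
    assert len(schedule_steps) + 1 == len(schedule_lrs)
    if step < schedule_steps[0]:
        # linear warmup from 0 toward schedule_lrs[0]
        return schedule_lrs[0] * step / schedule_steps[0]
    for boundary, lr in zip(schedule_steps[1:], schedule_lrs[1:-1]):
        if step < boundary:
            return lr
    return schedule_lrs[-1]    # constant after the last boundary
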
Code Example #4
    def on_train_begin(self, logs=None):
        super(CentralizedCallback, self).on_train_begin(logs)
        self.__global_step = compat.get_registered_initial_step()
Code Example #5
 def __init__(self):
     super(CentralizedCallback, self).__init__()
     self.__global_step = compat.get_registered_initial_step()
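
In both callback fragments, compat.get_registered_initial_step() seeds a private counter so that a run resumed from a checkpoint keeps counting from the registered step instead of zero. A hypothetical, minimal illustration of the pattern follows; the class name, the constructor argument, and the on_train_batch_end hook are assumptions, not neurst's actual CentralizedCallback:

import tensorflow as tf

class StepTrackingCallback(tf.keras.callbacks.Callback):
    """Hypothetical sketch: carry a global step across restarts."""

    def __init__(self, initial_step=0):
        super().__init__()
        # stand-in for compat.get_registered_initial_step()
        self._global_step = initial_step

    def on_train_batch_end(self, batch, logs=None):
        self._global_step += 1    # advance once per training batch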