Example #1
    def __init__(self,
                 model,
                 directory,
                 metric,
                 max_to_keep=8,
                 checkpoint_name="ckpt"):
        """ Initializes a custom checkpoint manager.

        Args:
            model: A keras model.
            directory: The path to a directory in which to write checkpoints.
            metric: A metric object.
            max_to_keep: The maximum number of checkpoints to keep.
            checkpoint_name: The name of each checkpoint.
        """
        if directory is None:
            directory = compat.get_saver_or_default().directory
            if not directory.endswith("/"):
                directory += "/"
            directory += "best"
        super(KeepBestCheckpointSaver, self).__init__(
            checkpoint=tf.train.Checkpoint(**dict([(x.name.split(":")[0], x)
                                                   for x in model.weights])),
            directory=directory,
            max_to_keep=max_to_keep)
        self._metric = metric
        self._checkpoint_name = checkpoint_name
        logging.info(
            "Creates custom keep-best checkpoint manager for directory: {}".
            format(directory))
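
A minimal standalone sketch of the checkpoint construction above: key each model weight by its variable name without the ":0" suffix and hand that mapping to tf.train.Checkpoint / tf.train.CheckpointManager, mirroring the call in this constructor. The toy model and directory are placeholders, not part of NeurST.

    import tensorflow as tf

    # Toy model; any tf.keras model with built weights works the same way.
    model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])
    # Mirror of the dict built in __init__ above: strip the ":0" suffix from each weight name.
    traced_vars = {w.name.split(":")[0]: w for w in model.weights}
    manager = tf.train.CheckpointManager(
        checkpoint=tf.train.Checkpoint(**traced_vars),
        directory="/tmp/keep_best_demo",  # placeholder directory
        max_to_keep=8)
    manager.save()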
Example #2
    def __init__(self,
                 model_configs,
                 save_checkpoint_steps=1000,
                 checkpoint_manager=None):
        """ Initializes custom checkpoint callback.

        Args:
            model_configs: A dict of configurations for restoring.
            save_checkpoint_steps: An int scalar; save a checkpoint every this many steps.
            checkpoint_manager: A CheckpointManager instance.
        """
        super(CustomCheckpointCallback, self).__init__()
        self._checkpoint_manager = checkpoint_manager
        if self._checkpoint_manager is None:
            self._checkpoint_manager = compat.get_saver_or_default()
        self._model_configs = model_configs
        self._save_checkpoint_steps = save_checkpoint_steps
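
The trainer in Example #4 passes this callback to keras_model.fit together with a checkpoint manager obtained from compat.get_saver_or_default. The sketch below is an independent, minimal illustration of the same "save every N steps" idea using plain tf.keras; it is not the NeurST CustomCheckpointCallback itself.

    import tensorflow as tf

    class SaveEveryNSteps(tf.keras.callbacks.Callback):
        """Minimal stand-in illustrating save-every-N-steps behaviour (not NeurST code)."""

        def __init__(self, manager, save_checkpoint_steps=1000):
            super().__init__()
            self._manager = manager          # a tf.train.CheckpointManager
            self._steps = save_checkpoint_steps
            self._seen = 0

        def on_train_batch_end(self, batch, logs=None):
            self._seen += 1
            if self._seen % self._steps == 0:
                self._manager.save(checkpoint_number=self._seen)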
Example #3
    def __init__(self,
                 model,
                 directory,
                 metric,
                 max_to_keep=8,
                 checkpoint_name="ckpt"):
        """ Initializes a custom checkpoint manager.

        Args:
            model: A keras model.
            directory: The path to a directory in which to write checkpoints.
            metric: A metric object.
            max_to_keep: The maximum number of checkpoints to keep.
            checkpoint_name: The name of each checkpoint.
        """
        if directory is None:
            directory = compat.get_saver_or_default().directory
            if not directory.endswith("/"):
                directory += "/"
            directory += "best_avg"
        self._checkpoint_name = checkpoint_name
        self._traced_vars = dict([(x.name.split(":")[0], x)
                                  for x in model.weights])
        self._traced_var_names = list(self._traced_vars.keys())
        self._traced_var_numpys = []
        self._metric = metric

        v_numpys = dict([(n, v.numpy()) for n, v in self._traced_vars.items()])
        with tf.distribute.OneDeviceStrategy(device="/cpu:0").scope():
            self._avg_traced_vars = dict([(n,
                                           tf.Variable(v,
                                                       dtype=v.dtype,
                                                       name=n + "_avg"))
                                          for n, v in v_numpys.items()])
        super(AverageCheckpointSaver, self).__init__(
            directory=directory,
            max_to_keep=max_to_keep,
            checkpoint=tf.train.Checkpoint(**self._avg_traced_vars))
        logging.info("Create checkpoint manager for averaged checkpoint "
                     "of the latest {} checkpoints to dir: {}".format(
                         self._max_to_keep, self.directory))
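
A standalone sketch of the shadow-variable setup above: each traced weight is copied into a CPU-resident "*_avg" variable created under a OneDeviceStrategy scope, and the checkpoint manager tracks those copies rather than the live weights. The toy model and directory are placeholders.

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])
    traced_vars = {w.name.split(":")[0]: w for w in model.weights}
    # Shadow copies pinned to the CPU, named "<weight>_avg" as in __init__ above.
    with tf.distribute.OneDeviceStrategy(device="/cpu:0").scope():
        avg_vars = {n: tf.Variable(v.numpy(), dtype=v.dtype, name=n + "_avg")
                    for n, v in traced_vars.items()}
    manager = tf.train.CheckpointManager(
        checkpoint=tf.train.Checkpoint(**avg_vars),
        directory="/tmp/avg_ckpt_demo",  # placeholder directory
        max_to_keep=8)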
Example #4
File: trainer.py Project: lileicc/neurst
    def run(self):
        """ Training a neural model.

        Step 1: Create the training model.
        Step 2: Restore the checkpoint/pretrained model/global_step if they exist.
        Step 3: Fetch the training data.
        Step 4: Build the training callbacks.
        Step 5: TRAIN!!!
        """
        if self._hvd_backend == "horovod":
            import horovod.tensorflow.keras as hvd
        elif self._hvd_backend == "byteps":
            import byteps.tensorflow.keras as hvd

        tfds = training_utils.build_datasets(compat.ModeKeys.TRAIN,
                                             self.strategy,
                                             self.custom_dataset, self.task)
        if isinstance(self.custom_dataset, MultipleDataset):
            _tfds = None
            for _, ds in tfds.items():
                if _tfds is None:
                    _tfds = ds
                else:
                    _tfds = _tfds.concatenate(ds)
            tfds = _tfds
        tfds = tfds.prefetch(tf.data.experimental.AUTOTUNE)
        # Step 1: create a model
        with training_utils.get_strategy_scope(self.strategy):
            inps = self.task.create_inputs(compat.ModeKeys.TRAIN)
            formatted_inps = self.task.example_to_input(
                inps, compat.ModeKeys.TRAIN)
            model_out = self.model(formatted_inps, is_training=True)
            for metric_layer in self.task.build_metric_layer():
                model_out = metric_layer([formatted_inps, model_out])
            if (LooseVersion(tf.__version__) < LooseVersion("2.3")
                    or LooseVersion(tf.__version__) >= LooseVersion("2.5")):
                logging.info(
                    f"Warning: Need further check on AccumgradKerasModel when TF version={tf.__version__}. "
                    f"Here we ignore update_cycle={self._update_cycle}, "
                    f"clip_value={self._clip_value}, clip_norm={self._clip_norm}."
                )
                keras_model = tf.keras.Model(inps, model_out)
            elif compat.IS_PREV_TF_2_4_0:
                from neurst.training.gradaccum_keras_model import TF23GradAccumKerasModel
                keras_model = TF23GradAccumKerasModel(
                    inps,
                    model_out,
                    update_cycle=self._update_cycle,
                    clip_value=self._clip_value,
                    clip_norm=self._clip_norm,
                    freeze_variables=self._freeze_variables)
            else:
                keras_model = GradAccumKerasModel(
                    inps,
                    model_out,
                    update_cycle=self._update_cycle,
                    clip_value=self._clip_value,
                    clip_norm=self._clip_norm,
                    freeze_variables=self._freeze_variables)

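            # Reduce the criterion to a loss and attach it to the Keras model below;
            # dict-valued losses are additionally tracked as "*_mean" metrics.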
            loss = self._criterion.reduce_loss(formatted_inps, model_out)
            if compat.is_tf_tensor(loss) or isinstance(loss, (list, tuple)):
                keras_model.add_loss(loss)
            elif isinstance(loss, dict):
                for _name, _loss in loss.items():
                    keras_model.add_loss(_loss)
                    keras_model.add_metric(_loss,
                                           name=_name + "_mean",
                                           aggregation="mean")
            else:
                raise ValueError("criterion.reduce_loss returns "
                                 "unsupported value of type: {}".format(
                                     type(loss)))
            self._restore_ckpt_or_pretrain()
            self._lr_schedule = build_lr_schedule(self._lr_schedule_args)
            if self._pruning_schedule is not None:
                self._optimizer = create_pruning_optimizer(
                    self._optimizer,
                    self.model,
                    self._pruning_schedule,
                    pruning_variable_pattern=self._pruning_variable_pattern,
                    nopruning_variable_pattern=self._nopruning_variable_pattern,
                    keep_prune_property=True)
            self._optimizer = training_utils.handle_fp16_and_distributed_optimizer(
                self._optimizer, self._lr_schedule, self._hvd_backend)
            if self._hvd_backend is None:
                keras_model.compile(self._optimizer)
            else:
                # NOTE: we already add Horovod DistributedOptimizer in `_handle_fp16_and_distributed_optimizer`.
                # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
                # uses hvd.DistributedOptimizer() to compute gradients.
                keras_model.compile(self._optimizer,
                                    experimental_run_tf_function=False)
            keras_model.summary()
            summary_model_variables(self.model, self._freeze_variables)
        # initialize the checkpoint manager
        _ = compat.get_saver_or_default(
            self.model,
            self.model_dir,
            max_to_keep=self._checkpoints_max_to_keep)
        # build training callbacks
        if not self._tb_log_dir:
            self._tb_log_dir = os.path.join(self.model_dir, "train")

        training_callbacks = [
            MetricReductionCallback(self.strategy,
                                    self._summary_steps,
                                    self._tb_log_dir,
                                    device="GPU:0",
                                    lr_schedule=self._lr_schedule)
        ]
        if self._hvd_backend is None or hvd.rank() == 0:
            training_callbacks.append(
                CustomCheckpointCallback(
                    self.task.model_configs(self.model),
                    save_checkpoint_steps=self._save_checkpoint_steps))
            if self._validator is not None:
                training_callbacks.append(
                    self._validator.build(self.strategy, self.task,
                                          self.model))
        if self._hvd_backend is not None:
            # Horovod: average metrics among workers at the end of every epoch.
            #
            # Note: This callback must be in the list before the ReduceLROnPlateau,
            # TensorBoard, or other metrics-based callbacks.
            # NOTE!!! HERE we already integrate the metric averaging behaviour into the MetricReductionCallback.
            # training_callbacks.insert(0, hvd.callbacks.MetricAverageCallback(device="GPU:0"))

            # Horovod: broadcast initial variable states from rank 0 to all other processes.
            # This is necessary to ensure consistent initialization of all workers when
            # training is started with random weights or restored from a checkpoint.
            training_callbacks.insert(
                0,
                hvd.callbacks.BroadcastGlobalVariablesCallback(0,
                                                               device="GPU:0"))
            if self._lr_schedule is not None:
                training_callbacks.append(
                    LearningRateScheduler(self._lr_schedule))

        if self._experimental_count_batch_num:
            logging.info("Scanning the dataset......")
            iterator = iter(
                training_utils.maybe_distribution_dataset(self.strategy, tfds))
            cnt = 0
            for _ in iterator:
                cnt += 1
            logging.info(f"Total {cnt} batches per EPOCH.")

        history = keras_model.fit(
            map_data_for_keras(tfds.repeat()),
            initial_epoch=0,
            epochs=1,
            steps_per_epoch=self._train_steps,  # * args["update_cycle"],
            verbose=2,
            callbacks=training_callbacks)
        logging.info(history.history)
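
A standalone sketch of the dataset handling at the top of run(): when the task yields several datasets, they are folded into a single pipeline with Dataset.concatenate and then prefetched. The toy range datasets are placeholders for the real corpora.

    import tensorflow as tf

    datasets = {
        "corpus_a": tf.data.Dataset.range(5),     # placeholder datasets
        "corpus_b": tf.data.Dataset.range(5, 8),
    }
    merged = None
    for _, ds in datasets.items():
        merged = ds if merged is None else merged.concatenate(ds)
    merged = merged.prefetch(tf.data.experimental.AUTOTUNE)
    print(list(merged.as_numpy_iterator()))  # [0, 1, 2, 3, 4, 5, 6, 7]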
Example #5
    def run(self):
        """ Repeats to call validator's validate function if new checkponts are observed.

        Step 1: Build model.
        Step 2: Fetch training status.
        while True:
            Step 3: Restore checkpoints.
            Step 4: Validate.
        """
        if self.task is None or self.model is None:
            model_cfg_waiting_rounds = self._maximum_waiting_time // self._waiting_interval
            for i in range(model_cfg_waiting_rounds):
                try:
                    args = ModelConfigs.load(self._model_dir)
                    break
                except FileNotFoundError:
                    logging.info(
                        f"Fail to load model configs from directory: {self.model_dir}. "
                        f"Wait for another {self._waiting_interval}s, "
                        f"patience={model_cfg_waiting_rounds - 1 - i}.")
                    time.sleep(self._waiting_interval)
            self._task = build_task(args)
            self._model = self.task.build_model(args)
        # initialize the checkpoint manager
        saver = compat.get_saver_or_default(self.model, self.model_dir)
        # enable tensorboard
        if self._tb_log_dir is None:
            self._tb_log_dir = os.path.join(
                self.model_dir, "validation_{}".format(int(time.time())))
        file_writer = tf.summary.create_file_writer(self._tb_log_dir)
        file_writer.set_as_default()
        # create the validator
        self._validator.build(self.strategy, self.task, self.model)
        last_triggered_step = None
        accumulated_waiting_time = 0
        this_waiting_interval = next_waiting_interval = self._waiting_interval
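        # Poll the checkpoint directory, validate every checkpoint newer than the last triggered
        # step, and adapt the waiting interval to how many new checkpoints were found.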
        while True:
            bad_cnt = 0
            while bad_cnt < 5:
                try:
                    ckpt_state = tf.train.get_checkpoint_state(self.model_dir)
                    break
                except ValueError:
                    bad_cnt += 1
                    time.sleep(5)
                    logging.info(traceback.format_exc())
                    if bad_cnt >= 5:
                        ckpt_state = tf.train.get_checkpoint_state(
                            self.model_dir)

            ckpts_to_be_restore = None
            if ckpt_state is None:
                logging.info(
                    f"No checkpoint in directory: {self.model_dir}. Please wait."
                )
            else:
                all_ckpts = [
                    (t, x)
                    for t, x in zip(ckpt_state.all_model_checkpoint_timestamps,
                                    ckpt_state.all_model_checkpoint_paths)
                ]
                global_steps_to_be_restore = []
                ckpts_to_be_restore = []
                for ckpt in all_ckpts[::-1]:
                    step = compat.hack_global_step(ckpt[1])
                    if last_triggered_step is None or step > last_triggered_step:
                        ckpts_to_be_restore.insert(0, ckpt)
                        global_steps_to_be_restore.insert(0, step)
                if len(ckpts_to_be_restore) > 0:
                    accumulated_waiting_time = 0
                _start_time = time.time()
                for step, (timestamp, ckpt) in zip(global_steps_to_be_restore,
                                                   ckpts_to_be_restore):
                    stat = saver.restore(ckpt)
                    if not stat:
                        logging.info(
                            f"Fail to restore checkpoint from {ckpt}. Skip...")
                        continue
                    logging.info(
                        f"Checkpoint with global_step={step} triggered on {timestamp}"
                    )
                    self._validator.validate(step)
                    last_triggered_step = step
                this_waiting_interval = max(
                    this_waiting_interval - int(time.time() - _start_time), 10)
                tf.summary.flush(file_writer)
            if ckpts_to_be_restore is None:
                pass
            elif len(ckpts_to_be_restore) > 1:
                this_waiting_interval = int(this_waiting_interval * 1. *
                                            (len(ckpts_to_be_restore) // 2) /
                                            len(ckpts_to_be_restore))
                next_waiting_interval = this_waiting_interval
            elif len(ckpts_to_be_restore) == 0:
                next_waiting_interval = min(
                    int(this_waiting_interval * 4. / 3.),
                    self._waiting_interval)
                this_waiting_interval = this_waiting_interval // 2
            accumulated_waiting_time += this_waiting_interval
            if accumulated_waiting_time > self._maximum_waiting_time:
                logging.info(
                    f"Waited for maximum patience: {self._maximum_waiting_time}s"
                )
                break
            time.sleep(this_waiting_interval)
            this_waiting_interval = next_waiting_interval
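
The interval bookkeeping at the bottom of the loop can be hard to follow inline; the helper below restates the two adjustment branches as a pure function, purely for illustration (the function and parameter names are not NeurST API).

    def adapt_intervals(this_interval, next_interval, base_interval, num_new_ckpts):
        """Illustrative restatement of the branch structure above."""
        if num_new_ckpts > 1:
            # Several checkpoints were pending: poll more frequently next round.
            this_interval = int(this_interval * 1. * (num_new_ckpts // 2) / num_new_ckpts)
            next_interval = this_interval
        elif num_new_ckpts == 0:
            # Nothing new: shorten this wait, but let the interval drift back toward the base value.
            next_interval = min(int(this_interval * 4. / 3.), base_interval)
            this_interval = this_interval // 2
        return this_interval, next_interval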