def __init__(self, model, directory, metric, max_to_keep=8, checkpoint_name="ckpt"):
    """ Initializes a custom checkpoint manager that keeps the best checkpoints.

    Args:
        model: A keras model.
        directory: The path to a directory in which to write checkpoints.
        metric: A metric object.
        max_to_keep: The maximum number of checkpoints to keep.
        checkpoint_name: The name of each checkpoint.
    """
    if directory is None:
        directory = compat.get_saver_or_default().directory
    if not directory.endswith("/"):
        directory += "/"
    directory += "best"
    super(KeepBestCheckpointSaver, self).__init__(
        checkpoint=tf.train.Checkpoint(**{x.name.split(":")[0]: x for x in model.weights}),
        directory=directory,
        max_to_keep=max_to_keep)
    self._metric = metric
    self._checkpoint_name = checkpoint_name
    logging.info("Created custom keep-best checkpoint manager for directory: {}".format(directory))
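# Illustrative sketch (not part of the library): how the name-keyed tf.train.Checkpoint
# above maps Keras weights for saving. The tiny model and directory below are hypothetical;
# only standard TF APIs are used.
import tensorflow as tf

demo_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,), name="proj")])
# Strip the ":0" suffix so every variable is tracked under a stable, name-based key.
named_weights = {w.name.split(":")[0]: w for w in demo_model.weights}
demo_ckpt = tf.train.Checkpoint(**named_weights)
demo_manager = tf.train.CheckpointManager(demo_ckpt, directory="/tmp/keep_best_demo", max_to_keep=2)
demo_manager.save(checkpoint_number=100)  # the keep-best saver wraps this with a metric check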
def __init__(self, model_configs, save_checkpoint_steps=1000, checkpoint_manager=None):
    """ Initializes the custom checkpoint callback.

    Args:
        model_configs: A dict of model configurations for restoring.
        save_checkpoint_steps: An int scalar, save a checkpoint every this many training steps.
        checkpoint_manager: A CheckpointManager instance.
    """
    super(CustomCheckpointCallback, self).__init__()
    self._checkpoint_manager = checkpoint_manager
    if self._checkpoint_manager is None:
        self._checkpoint_manager = compat.get_saver_or_default()
    self._model_configs = model_configs
    self._save_checkpoint_steps = save_checkpoint_steps
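# Illustrative sketch (assumption, not the actual CustomCheckpointCallback body): one way a
# Keras callback can trigger a CheckpointManager save every `save_checkpoint_steps` batches,
# using only standard tf.keras Callback hooks. The class name is hypothetical.
import tensorflow as tf


class PeriodicCheckpointCallback(tf.keras.callbacks.Callback):
    def __init__(self, manager, save_checkpoint_steps=1000):
        super(PeriodicCheckpointCallback, self).__init__()
        self._manager = manager  # a tf.train.CheckpointManager
        self._save_checkpoint_steps = save_checkpoint_steps
        self._global_step = 0

    def on_train_batch_end(self, batch, logs=None):
        self._global_step += 1
        if self._global_step % self._save_checkpoint_steps == 0:
            path = self._manager.save(checkpoint_number=self._global_step)
            tf.get_logger().info("Saved checkpoint to %s", path)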
def __init__(self, model, directory, metric, max_to_keep=8, checkpoint_name="ckpt"):
    """ Initializes a custom checkpoint manager for averaged checkpoints.

    Args:
        model: A keras model.
        directory: The path to a directory in which to write checkpoints.
        metric: A metric object.
        max_to_keep: The maximum number of checkpoints to keep (and to average over).
        checkpoint_name: The name of each checkpoint.
    """
    if directory is None:
        directory = compat.get_saver_or_default().directory
    if not directory.endswith("/"):
        directory += "/"
    directory += "best_avg"
    self._checkpoint_name = checkpoint_name
    self._traced_vars = {x.name.split(":")[0]: x for x in model.weights}
    self._traced_var_names = list(self._traced_vars.keys())
    self._traced_var_numpys = []
    self._metric = metric
    # Keep a CPU-resident copy of every traced variable for the running average.
    v_numpys = {n: v.numpy() for n, v in self._traced_vars.items()}
    with tf.distribute.OneDeviceStrategy(device="/cpu:0").scope():
        self._avg_traced_vars = {
            n: tf.Variable(v, dtype=v.dtype, name=n + "_avg")
            for n, v in v_numpys.items()}
    super(AverageCheckpointSaver, self).__init__(
        directory=directory,
        max_to_keep=max_to_keep,
        checkpoint=tf.train.Checkpoint(**self._avg_traced_vars))
    logging.info("Created checkpoint manager for the averaged checkpoint "
                 "of the latest {} checkpoints, writing to dir: {}".format(
                     self._max_to_keep, self.directory))
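# Illustrative sketch (assumption): the element-wise running average behind the "_avg"
# variables above. Given the latest N weight snapshots, the averaged variable holds their
# mean. The shapes and names here are hypothetical.
import numpy as np
import tensorflow as tf

snapshots = [np.full((2, 2), float(i)) for i in range(1, 4)]  # e.g. 3 recent checkpoints
avg_var = tf.Variable(np.zeros((2, 2)), name="kernel_avg")
avg_var.assign(np.mean(np.stack(snapshots, axis=0), axis=0))  # mean over the kept window
print(avg_var.numpy())  # every entry equals 2.0 = (1 + 2 + 3) / 3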
def run(self):
    """ Trains a neural model.

    Step 1: Create the training model.
    Step 2: Restore the checkpoint/pretrained model/global_step if they exist.
    Step 3: Fetch training data.
    Step 4: Build training callbacks.
    Step 5: TRAIN!!!
    """
    if self._hvd_backend == "horovod":
        import horovod.tensorflow.keras as hvd
    elif self._hvd_backend == "byteps":
        import byteps.tensorflow.keras as hvd

    tfds = training_utils.build_datasets(compat.ModeKeys.TRAIN, self.strategy,
                                         self.custom_dataset, self.task)
    if isinstance(self.custom_dataset, MultipleDataset):
        # Concatenate the sub-datasets of a MultipleDataset into one dataset.
        _tfds = None
        for _, ds in tfds.items():
            if _tfds is None:
                _tfds = ds
            else:
                _tfds = _tfds.concatenate(ds)
        tfds = _tfds
    tfds = tfds.prefetch(tf.data.experimental.AUTOTUNE)
    # Step 1: create a model
    with training_utils.get_strategy_scope(self.strategy):
        inps = self.task.create_inputs(compat.ModeKeys.TRAIN)
        formatted_inps = self.task.example_to_input(inps, compat.ModeKeys.TRAIN)
        model_out = self.model(formatted_inps, is_training=True)
        for metric_layer in self.task.build_metric_layer():
            model_out = metric_layer([formatted_inps, model_out])
        if (LooseVersion(tf.__version__) < LooseVersion("2.3")
                or LooseVersion(tf.__version__) >= LooseVersion("2.5")):
            logging.info(f"Warning: Need further check on AccumgradKerasModel when TF version={tf.__version__}. "
                         f"Here we ignore update_cycle={self._update_cycle}, "
                         f"clip_value={self._clip_value}, clip_norm={self._clip_norm}.")
            keras_model = tf.keras.Model(inps, model_out)
        elif compat.IS_PREV_TF_2_4_0:
            from neurst.training.gradaccum_keras_model import TF23GradAccumKerasModel
            keras_model = TF23GradAccumKerasModel(inps, model_out,
                                                  update_cycle=self._update_cycle,
                                                  clip_value=self._clip_value,
                                                  clip_norm=self._clip_norm,
                                                  freeze_variables=self._freeze_variables)
        else:
            keras_model = GradAccumKerasModel(inps, model_out,
                                              update_cycle=self._update_cycle,
                                              clip_value=self._clip_value,
                                              clip_norm=self._clip_norm,
                                              freeze_variables=self._freeze_variables)

        loss = self._criterion.reduce_loss(formatted_inps, model_out)
        if compat.is_tf_tensor(loss) or isinstance(loss, (list, tuple)):
            keras_model.add_loss(loss)
        elif isinstance(loss, dict):
            for _name, _loss in loss.items():
                keras_model.add_loss(_loss)
                keras_model.add_metric(_loss, name=_name + "_mean", aggregation="mean")
        else:
            raise ValueError("criterion.reduce_loss returns "
                             "unsupported value of type: {}".format(type(loss)))
        # Step 2: restore checkpoint/pretrained model/global_step if they exist
        self._restore_ckpt_or_pretrain()
        self._lr_schedule = build_lr_schedule(self._lr_schedule_args)
        if self._pruning_schedule is not None:
            self._optimizer = create_pruning_optimizer(self._optimizer, self.model, self._pruning_schedule,
                                                       pruning_variable_pattern=self._pruning_variable_pattern,
                                                       nopruning_variable_pattern=self._nopruning_variable_pattern,
                                                       keep_prune_property=True)
        self._optimizer = training_utils.handle_fp16_and_distributed_optimizer(
            self._optimizer, self._lr_schedule, self._hvd_backend)
        if self._hvd_backend is None:
            keras_model.compile(self._optimizer)
        else:
            # NOTE: we already add Horovod DistributedOptimizer in `_handle_fp16_and_distributed_optimizer`.
            # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
            # uses hvd.DistributedOptimizer() to compute gradients.
            keras_model.compile(self._optimizer, experimental_run_tf_function=False)
        keras_model.summary()
        summary_model_variables(self.model, self._freeze_variables)
    # initialize the checkpoint manager
    _ = compat.get_saver_or_default(self.model, self.model_dir,
                                    max_to_keep=self._checkpoints_max_to_keep)
    # Step 4: build training callbacks
    if not self._tb_log_dir:
        self._tb_log_dir = os.path.join(self.model_dir, "train")
    training_callbacks = [MetricReductionCallback(self.strategy, self._summary_steps, self._tb_log_dir,
                                                  device="GPU:0", lr_schedule=self._lr_schedule)]
    if self._hvd_backend is None or hvd.rank() == 0:
        training_callbacks.append(
            CustomCheckpointCallback(self.task.model_configs(self.model),
                                     save_checkpoint_steps=self._save_checkpoint_steps))
        if self._validator is not None:
            training_callbacks.append(self._validator.build(self.strategy, self.task, self.model))
    if self._hvd_backend is not None:
        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        # NOTE!!! HERE we already integrate the metric averaging behaviour into the MetricReductionCallback.
        # training_callbacks.insert(0, hvd.callbacks.MetricAverageCallback(device="GPU:0"))

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        training_callbacks.insert(0, hvd.callbacks.BroadcastGlobalVariablesCallback(0, device="GPU:0"))
    if self._lr_schedule is not None:
        training_callbacks.append(LearningRateScheduler(self._lr_schedule))

    if self._experimental_count_batch_num:
        logging.info("Scanning the dataset......")
        iterator = iter(training_utils.maybe_distribution_dataset(self.strategy, tfds))
        cnt = 0
        for _ in iterator:
            cnt += 1
        logging.info(f"Total {cnt} batches per EPOCH.")

    # Step 5: TRAIN!!!
    history = keras_model.fit(
        map_data_for_keras(tfds.repeat()),
        initial_epoch=0,
        epochs=1,
        steps_per_epoch=self._train_steps,  # * args["update_cycle"],
        verbose=2,
        callbacks=training_callbacks)
    logging.info(history.history)
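# Illustrative sketch (assumption): the gradient-accumulation idea behind `update_cycle`
# in GradAccumKerasModel, written as a plain training step. Gradients from `update_cycle`
# micro-batches are summed before a single optimizer update, emulating a larger batch.
# All names below are hypothetical; only standard TF APIs are used.
import tensorflow as tf

def accumulated_train_step(model, loss_fn, optimizer, micro_batches, update_cycle=4, clip_norm=1.0):
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for features, labels in micro_batches[:update_cycle]:
        with tf.GradientTape() as tape:
            # Scale each micro-batch loss so the accumulated gradient matches one big batch.
            loss = loss_fn(labels, model(features, training=True)) / update_cycle
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
    # Optional global-norm clipping, analogous to `clip_norm` above.
    accum, _ = tf.clip_by_global_norm(accum, clip_norm)
    optimizer.apply_gradients(zip(accum, model.trainable_variables))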
def run(self):
    """ Repeatedly calls the validator's validate function whenever new checkpoints are observed.

    Step 1: Build the model.
    Step 2: Fetch the training status.
    while True:
        Step 3: Restore checkpoints.
        Step 4: Validate.
    """
    if self.task is None or self.model is None:
        model_cfg_waiting_rounds = self._maximum_waiting_time // self._waiting_interval
        for i in range(model_cfg_waiting_rounds):
            try:
                args = ModelConfigs.load(self._model_dir)
                break
            except FileNotFoundError:
                logging.info(f"Fail to load model configs from directory: {self.model_dir}. "
                             f"Wait for another {self._waiting_interval}s, "
                             f"patience={model_cfg_waiting_rounds - 1 - i}.")
                time.sleep(self._waiting_interval)
        self._task = build_task(args)
        self._model = self.task.build_model(args)
    # initialize the checkpoint manager
    saver = compat.get_saver_or_default(self.model, self.model_dir)
    # enable tensorboard
    if self._tb_log_dir is None:
        self._tb_log_dir = os.path.join(self.model_dir, "validation_{}".format(int(time.time())))
    file_writer = tf.summary.create_file_writer(self._tb_log_dir)
    file_writer.set_as_default()
    # build the validation procedure
    self._validator.build(self.strategy, self.task, self.model)
    last_triggered_step = None
    accumulated_waiting_time = 0
    this_waiting_interval = next_waiting_interval = self._waiting_interval
    while True:
        bad_cnt = 0
        while bad_cnt < 5:
            try:
                ckpt_state = tf.train.get_checkpoint_state(self.model_dir)
                break
            except ValueError:
                bad_cnt += 1
                time.sleep(5)
                logging.info(traceback.format_exc())
        if bad_cnt >= 5:
            ckpt_state = tf.train.get_checkpoint_state(self.model_dir)
        ckpts_to_be_restore = None
        if ckpt_state is None:
            logging.info(f"No checkpoint in directory: {self.model_dir}. Please wait.")
        else:
            all_ckpts = [(t, x) for t, x in zip(ckpt_state.all_model_checkpoint_timestamps,
                                                ckpt_state.all_model_checkpoint_paths)]
            global_steps_to_be_restore = []
            ckpts_to_be_restore = []
            # Walk newest-to-oldest and collect checkpoints newer than the last
            # validated step, keeping the resulting lists ordered oldest-first.
            for ckpt in all_ckpts[::-1]:
                step = compat.hack_global_step(ckpt[1])
                if last_triggered_step is None or step > last_triggered_step:
                    ckpts_to_be_restore.insert(0, ckpt)
                    global_steps_to_be_restore.insert(0, step)
            if len(ckpts_to_be_restore) > 0:
                accumulated_waiting_time = 0
                _start_time = time.time()
                for step, (timestamp, ckpt) in zip(global_steps_to_be_restore, ckpts_to_be_restore):
                    stat = saver.restore(ckpt)
                    if not stat:
                        logging.info(f"Fail to restore checkpoint from {ckpt}. Skip...")
                        continue
                    logging.info(f"Checkpoint with global_step={step} triggered on {timestamp}")
                    self._validator.validate(step)
                    last_triggered_step = step
                this_waiting_interval = max(this_waiting_interval - int(time.time() - _start_time), 10)
                tf.summary.flush(file_writer)
        if ckpts_to_be_restore is None:
            pass
        elif len(ckpts_to_be_restore) > 1:
            # Validation is lagging behind training: shorten the waiting interval.
            this_waiting_interval = int(this_waiting_interval * 1.
                                        * (len(ckpts_to_be_restore) // 2) / len(ckpts_to_be_restore))
            next_waiting_interval = this_waiting_interval
        elif len(ckpts_to_be_restore) == 0:
            # No new checkpoint: poll again sooner, but let the next interval grow
            # back toward the configured default.
            next_waiting_interval = min(int(this_waiting_interval * 4. / 3.), self._waiting_interval)
            this_waiting_interval = this_waiting_interval // 2
        accumulated_waiting_time += this_waiting_interval
        if accumulated_waiting_time > self._maximum_waiting_time:
            logging.info(f"Waited for maximum patience: {self._maximum_waiting_time}s")
            break
        time.sleep(this_waiting_interval)
        this_waiting_interval = next_waiting_interval
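# Illustrative sketch (assumption): the core of the polling loop above, reduced to a helper
# that lists checkpoints newer than the last validated step, oldest first. Parsing the step
# from the "<name>-<step>" suffix is an assumption about the checkpoint file naming.
import re
import tensorflow as tf

def new_checkpoints_since(model_dir, last_triggered_step):
    state = tf.train.get_checkpoint_state(model_dir)
    if state is None:
        return []
    pending = []
    for path in state.all_model_checkpoint_paths:
        match = re.search(r"-(\d+)$", path)
        step = int(match.group(1)) if match else -1
        if last_triggered_step is None or step > last_triggered_step:
            pending.append((step, path))
    return sorted(pending)  # oldest first, mirroring the insert(0, ...) logic above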