def __init__(self, args, **kwargs):
    """ Initializes a util class for training neural models. """
    super(Trainer, self).__init__(**kwargs)
    self._tb_log_dir = args["tb_log_dir"]
    self._train_steps = args["train_steps"]
    self._summary_steps = args["summary_steps"]
    self._save_checkpoint_steps = args["save_checkpoint_steps"]
    self._checkpoints_max_to_keep = args["checkpoints_max_to_keep"]
    self._initial_global_step = args["initial_global_step"]
    self._pretrain_model = flatten_string_list(args["pretrain_model"])
    self._pretrain_variable_pattern = args["pretrain_variable_pattern"]
    if self._pretrain_model and self._pretrain_variable_pattern is None:
        self._pretrain_variable_pattern = [None] * len(self._pretrain_model)
    assert ((self._pretrain_model is None and self._pretrain_variable_pattern is None)
            or len(self._pretrain_model) == len(self._pretrain_variable_pattern)
            or len(self._pretrain_model) == 1), (
        "`pretrain_variable_pattern` must match with `pretrain_model`.")
    self._update_cycle = args["update_cycle"]
    self._clip_value = args["clip_value"]
    self._clip_norm = args["clip_norm"]
    self._hvd_backend = self.strategy if self.strategy in ["byteps", "horovod"] else None
    with training_utils.get_strategy_scope(self.strategy):
        self._criterion = build_criterion(args)
        self._criterion.set_model(self.model)
        # Keep the raw args so that run() can rebuild the LR schedule after
        # restoring the global step.
        self._lr_schedule_args = args
        self._lr_schedule = build_lr_schedule(args)
        optimizer = build_optimizer(args)
        assert optimizer is not None, "optimizer parameters must be provided for training."
        self._optimizer = training_utils.handle_fp16_and_distributed_optimizer(
            optimizer, self._lr_schedule, self._hvd_backend)
    self._validator = build_validator(args)
    self._experimental_count_batch_num = args["experimental_count_batch_num"]
    # The attributes below are referenced by run() but were not initialized
    # here; the exact arg keys are assumptions inferred from that usage.
    self._freeze_variables = args.get("freeze_variables")
    self._pruning_schedule = args.get("pruning_schedule")
    self._pruning_variable_pattern = args.get("pruning_variable_pattern")
    self._nopruning_variable_pattern = args.get("nopruning_variable_pattern")
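# Example of the checkpoint/pattern pairing enforced by the assertion in
# `__init__` above (a sketch; the paths and regex patterns are hypothetical):
#
#   pretrain_model            = ["/ckpts/encoder_pretrain", "/ckpts/decoder_pretrain"]
#   pretrain_variable_pattern = ["(.*)encoder(.*)",         "(.*)decoder(.*)"]
#
# Pattern i restricts which variables are restored from checkpoint i. A single
# checkpoint may instead be paired with several patterns
# (len(pretrain_model) == 1), and when no patterns are given each checkpoint
# is paired with pattern None, i.e. all matching variables are restored.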
def run(self):
    """ Trains a neural model.

    Step 1: Create the training model.
    Step 2: Restore the checkpoint/pretrained model/global step if they exist.
    Step 3: Fetch the training data.
    Step 4: Build the training callbacks.
    Step 5: TRAIN!!!
    """
    if self._hvd_backend == "horovod":
        import horovod.tensorflow.keras as hvd
    elif self._hvd_backend == "byteps":
        import byteps.tensorflow.keras as hvd

    # Step 3: fetch the training data.
    tfds = training_utils.build_datasets(compat.ModeKeys.TRAIN, self.strategy,
                                         self.custom_dataset, self.task)
    if isinstance(self.custom_dataset, MultipleDataset):
        # Concatenate the per-source datasets into a single training dataset.
        _tfds = None
        for _, ds in tfds.items():
            if _tfds is None:
                _tfds = ds
            else:
                _tfds = _tfds.concatenate(ds)
        tfds = _tfds
    tfds = tfds.prefetch(tf.data.experimental.AUTOTUNE)

    # Step 1: create the training model.
    with training_utils.get_strategy_scope(self.strategy):
        inps = self.task.create_inputs(compat.ModeKeys.TRAIN)
        formatted_inps = self.task.example_to_input(inps, compat.ModeKeys.TRAIN)
        model_out = self.model(formatted_inps, is_training=True)
        for metric_layer in self.task.build_metric_layer():
            model_out = metric_layer([formatted_inps, model_out])
        if (LooseVersion(tf.__version__) < LooseVersion("2.3")
                or LooseVersion(tf.__version__) >= LooseVersion("2.5")):
            logging.info(f"Warning: Need further check on AccumgradKerasModel when TF version={tf.__version__}. "
                         f"Here we ignore update_cycle={self._update_cycle}, "
                         f"clip_value={self._clip_value}, clip_norm={self._clip_norm}.")
            keras_model = tf.keras.Model(inps, model_out)
        elif compat.IS_PREV_TF_2_4_0:
            from neurst.training.gradaccum_keras_model import TF23GradAccumKerasModel
            keras_model = TF23GradAccumKerasModel(inps, model_out,
                                                  update_cycle=self._update_cycle,
                                                  clip_value=self._clip_value,
                                                  clip_norm=self._clip_norm,
                                                  freeze_variables=self._freeze_variables)
        else:
            keras_model = GradAccumKerasModel(inps, model_out,
                                              update_cycle=self._update_cycle,
                                              clip_value=self._clip_value,
                                              clip_norm=self._clip_norm,
                                              freeze_variables=self._freeze_variables)

        loss = self._criterion.reduce_loss(formatted_inps, model_out)
        if compat.is_tf_tensor(loss) or isinstance(loss, (list, tuple)):
            keras_model.add_loss(loss)
        elif isinstance(loss, dict):
            for _name, _loss in loss.items():
                keras_model.add_loss(_loss)
                keras_model.add_metric(_loss, name=_name + "_mean", aggregation="mean")
        else:
            raise ValueError("criterion.reduce_loss returns unsupported value "
                             "of type: {}".format(type(loss)))

        # Step 2: restore the checkpoint/pretrained model/global step if they exist.
        self._restore_ckpt_or_pretrain()
        self._lr_schedule = build_lr_schedule(self._lr_schedule_args)
        if self._pruning_schedule is not None:
            self._optimizer = create_pruning_optimizer(
                self._optimizer, self.model, self._pruning_schedule,
                pruning_variable_pattern=self._pruning_variable_pattern,
                nopruning_variable_pattern=self._nopruning_variable_pattern,
                keep_prune_property=True)
        self._optimizer = training_utils.handle_fp16_and_distributed_optimizer(
            self._optimizer, self._lr_schedule, self._hvd_backend)
        if self._hvd_backend is None:
            keras_model.compile(self._optimizer)
        else:
            # NOTE: the Horovod DistributedOptimizer is already added in
            # `training_utils.handle_fp16_and_distributed_optimizer`.
            # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
            # uses hvd.DistributedOptimizer() to compute gradients.
            keras_model.compile(self._optimizer, experimental_run_tf_function=False)
        keras_model.summary()
        summary_model_variables(self.model, self._freeze_variables)

    # Initialize the checkpoint manager.
    _ = compat.get_saver_or_default(self.model, self.model_dir,
                                    max_to_keep=self._checkpoints_max_to_keep)

    # Step 4: build the training callbacks.
    if not self._tb_log_dir:
        self._tb_log_dir = os.path.join(self.model_dir, "train")
    training_callbacks = [MetricReductionCallback(self.strategy,
                                                  self._summary_steps,
                                                  self._tb_log_dir,
                                                  device="GPU:0",
                                                  lr_schedule=self._lr_schedule)]
    if self._hvd_backend is None or hvd.rank() == 0:
        training_callbacks.append(
            CustomCheckpointCallback(self.task.model_configs(self.model),
                                     save_checkpoint_steps=self._save_checkpoint_steps))
        if self._validator is not None:
            training_callbacks.append(
                self._validator.build(self.strategy, self.task, self.model))
    if self._hvd_backend is not None:
        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        # NOTE!!! HERE we already integrate the metric averaging behaviour into the MetricReductionCallback.
        # training_callbacks.insert(0, hvd.callbacks.MetricAverageCallback(device="GPU:0"))

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        training_callbacks.insert(
            0, hvd.callbacks.BroadcastGlobalVariablesCallback(0, device="GPU:0"))
    if self._lr_schedule is not None:
        training_callbacks.append(LearningRateScheduler(self._lr_schedule))

    if self._experimental_count_batch_num:
        logging.info("Scanning the dataset......")
        iterator = iter(training_utils.maybe_distribution_dataset(self.strategy, tfds))
        cnt = 0
        for _ in iterator:
            cnt += 1
        logging.info(f"Total {cnt} batches per EPOCH.")

    # Step 5: TRAIN!!!
    history = keras_model.fit(
        map_data_for_keras(tfds.repeat()),
        initial_epoch=0,
        epochs=1,
        steps_per_epoch=self._train_steps,  # * args["update_cycle"]
        verbose=2,
        callbacks=training_callbacks)
    logging.info(history.history)
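# A minimal usage sketch (an assumption, not a documented entry point): the
# arg keys below mirror the `args[...]` lookups in `__init__`, while the
# values and the `task`/`model` construction are hypothetical placeholders.
#
#   args = {
#       "tb_log_dir": None,                  # defaults to <model_dir>/train
#       "train_steps": 100000,
#       "summary_steps": 200,
#       "save_checkpoint_steps": 1000,
#       "checkpoints_max_to_keep": 8,
#       "initial_global_step": None,
#       "pretrain_model": None,
#       "pretrain_variable_pattern": None,
#       "update_cycle": 1,                   # >1 enables gradient accumulation
#       "clip_value": None,
#       "clip_norm": None,
#       "experimental_count_batch_num": False,
#       # ... plus the criterion/lr_schedule/optimizer/validator options
#   }
#   trainer = Trainer(args, task=task, model=model)  # task/model built elsewhere
#   trainer.run()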