def _build_and_restore_model(self):
    """ Builds a single model or an ensemble model. """
    model_dirs = flatten_string_list(self.model_dir)
    if len(model_dirs) == 1:
        model = self.model
        stat = restore_checkpoint_if_possible(model, model_dirs[0])
        if not stat:
            logging.info("WARNING: Failed to restore checkpoint from {}. "
                         "We assume this was done on purpose.".format(model_dirs[0]))
    else:
        logging.info("We assume models for ensemble are all based on the same task.")
        multiple_models = []
        for idx, one_model_dir in enumerate(model_dirs):
            name_prefix = "ensemble_{}".format(idx)
            logging.info("Creating model for {} from {}".format(name_prefix, one_model_dir))
            cfg = ModelConfigs.load(one_model_dir)
            this_model = self.task.build_model(cfg, name=name_prefix)
            stat = restore_checkpoint_if_possible(this_model, one_model_dir)
            if not stat:
                logging.info("WARNING: Failed to restore checkpoint from {}. "
                             "We assume this was done on purpose.".format(one_model_dir))
            multiple_models.append(this_model)
        model = EncoderDecoderEnsembleModel.new(multiple_models)
    return model
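
# A minimal usage sketch (the directory paths below are hypothetical): when the
# configured model_dir holds several checkpoint directories, _build_and_restore_model
# builds one model per directory and wraps them into an EncoderDecoderEnsembleModel;
# with a single directory it simply restores self.model in place. flatten_string_list
# is assumed here to normalize a comma-separated string into a flat list of paths.
ensemble_dirs = flatten_string_list("/path/to/model_a,/path/to/model_b")
for d in ensemble_dirs:
    print(d)  # each entry is one checkpoint directory of the ensemble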
def _restore_ckpt_or_pretrain(self):
    """ Restores the checkpoint from model_dir or from a pretrained model dir. """
    stat = restore_checkpoint_if_possible(self.model, self.model_dir)
    continue_training = False
    if stat:
        logging.info(f"Successfully restored checkpoint from model_dir={self.model_dir}")
        continue_training = True
    else:
        logging.info(f"No checkpoint restored from model_dir={self.model_dir}")
        if self._pretrain_model:
            for pt, pt_varname in zip(self._pretrain_model,
                                      self._pretrain_variable_pattern):
                logging.info(f"Trying to restore from pretrain_model={pt}")
                logging.info("NOTE: one must first check the variable names in this checkpoint, "
                             "otherwise no variables will be restored.")
                restore_checkpoint_if_possible(self.model, pt, var_name_pattern=pt_varname)
    if self._initial_global_step is None and continue_training:
        _step = compat.hack_global_step(self.model_dir)
        if _step:
            # Must be registered before creating the optimizer and training.
            compat.register_initial_step(_step or 0)
            logging.info(f"Restored initial global step={_step}")
    else:
        compat.register_initial_step(self._initial_global_step or 0)
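
# A minimal sketch of the pretrained-model branch above (the paths and the
# pattern are hypothetical): restore only the variables whose names match the
# given pattern from a pretrained checkpoint; non-matching variables keep their
# fresh initialization, and a falsy return value means nothing was restored.
cfgs = ModelConfigs.load("/path/to/model_dir")        # hypothetical target model dir
task = build_task(cfgs)
model = task.build_model(cfgs)
stat = restore_checkpoint_if_possible(
    model,
    "/path/to/pretrained_checkpoint",                 # hypothetical pretrained checkpoint dir
    var_name_pattern="encoder")                       # hypothetical pattern; check variable names first
if not stat:
    print("No variables restored -- double-check the variable names in the checkpoint.")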
def build_task_and_model(model_dir, wait_k):
    """ Builds the wait-k task and restores one model per checkpoint directory. """
    model_dirs = flatten_string_list(model_dir)
    cfgs = ModelConfigs.load(model_dirs[0])
    cfgs["task.params"]["wait_k"] = wait_k
    task = build_task(cfgs)
    models = []
    for md in model_dirs:
        models.append(task.build_model(ModelConfigs.load(md)))
        restore_checkpoint_if_possible(models[-1], md)
    return task, models
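
# A minimal usage sketch (the path and wait_k value are hypothetical): build the
# wait-k task and restore one model per checkpoint directory, e.g. for
# simultaneous-translation decoding or ensembling.
task, models = build_task_and_model("/path/to/waitk_transformer", wait_k=3)
print(f"Restored {len(models)} model(s) for task {task.__class__.__name__}")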
def run(self):
    """ Evaluates an existing model.

    Step 1: Build model.
    Step 2: Build evaluation dataset.
    Step 3: Restore checkpoints.
    Step 4: Evaluate and reduce metrics.
    """
    assert not isinstance(self.custom_dataset, MultipleDataset), (
        "SequenceEvaluator only supports a single dataset.")
    with training_utils.get_strategy_scope(self.strategy):
        tfds = training_utils.build_datasets(compat.ModeKeys.EVAL, self.strategy,
                                             self.custom_dataset, self.task)
        keras_model = self.build_evaluation_model(self.task, self.model, self._criterion)
        keras_model.summary()
        summary_model_variables(keras_model)
        # Step 3: Restore checkpoints.
        stat = restore_checkpoint_if_possible(self.model, self.model_dir)
        if not stat:
            logging.info(f"WARNING: Failed to restore checkpoint from {self.model_dir}. "
                         "We assume this was done on purpose.")
        # Step 4: Evaluate and reduce metrics.
        predict_fn = keras_model.make_predict_function()
        iterator = iter(training_utils.maybe_distribution_dataset(
            self.strategy, tfds.prefetch(tf.data.experimental.AUTOTUNE)))
        with tf.io.gfile.GFile(self._output_file, "w") as fw:
            while True:
                try:
                    preds = predict_fn(iterator)
                    for pred in self._criterion.reduce_sample_metrics(preds):
                        fw.write(str(pred) + "\n")
                except (StopIteration, tf.errors.OutOfRangeError):
                    break
def run(self):
    """ Evaluates an existing model.

    Step 1: Build model.
    Step 2: Build evaluation dataset.
    Step 3: Restore checkpoints.
    Step 4: Evaluate and reduce metrics.
    """
    with training_utils.get_strategy_scope(self.strategy):
        tfds = training_utils.build_datasets(compat.ModeKeys.EVAL, self.strategy,
                                             self.custom_dataset, self.task, cache=True)
        keras_model = self.build_evaluation_model(self.task, self.model, self._criterion)
        keras_model.summary()
        summary_model_variables(keras_model)
        # Step 3: Restore checkpoints.
        stat = restore_checkpoint_if_possible(self.model, self.model_dir)
        if not stat:
            logging.info(f"WARNING: Failed to restore checkpoint from {self.model_dir}. "
                         "We assume this was done on purpose.")
        # Step 4: Evaluate and reduce metrics.
        start_time = time.time()
        results, avg_res, whole_res = training_utils.reduce_eval_results(
            self._criterion, self.custom_dataset,
            training_utils.make_predictions(self.strategy, keras_model,
                                            tfds, self.custom_dataset))
        logging.info("Evaluation elapsed: %.2fs", time.time() - start_time)

        def _display(res, name=None):
            if name:
                logging.info(f"Evaluation Results ({name}):")
            for k, v in res.items():
                logging.info("   %s: %.2f", k, v)

        if not isinstance(self.custom_dataset, MultipleDataset):
            _display(results)
        else:
            for name, res in results.items():
                _display(res, name)
            _display(avg_res,
                     f"on average by weights {self.custom_dataset.sample_weights}")
            _display(whole_res, "mixed")
import sys

import tensorflow as tf

from neurst.layers.quantization import QuantLayer
from neurst.models.transformer import Transformer
from neurst.tasks import build_task
from neurst.utils.checkpoints import restore_checkpoint_if_possible
from neurst.utils.configurable import ModelConfigs

model_dir = sys.argv[1]
model_configs = ModelConfigs.load(model_dir)
QuantLayer.global_init(model_configs["enable_quant"], **model_configs["quant_params"])
task = build_task(model_configs)
model: Transformer = task.build_model(model_configs)
restore_checkpoint_if_possible(model, model_dir)

# Read the traced clip_max of the first ffn kernel in encoder layer 0 and
# derive the corresponding clip_min for symmetric int8 quantization.
clip_max = model._encoder._stacking_layers[0][1]._layer._conv1.traced["kernel"].clip_max
weight_clip_max = tf.maximum(clip_max, 0.0)
weight_clip_max = tf.cast(weight_clip_max, tf.float32)
bits_tmp = float(2 ** (QuantLayer.quant_bits - 1))
weight_clip_min = -weight_clip_max * bits_tmp / (bits_tmp - 1)

print("The quantized weight of encoder layer0's first ffn:")
print(tf.quantization.quantize(
    model._encoder._stacking_layers[0][1]._layer._conv1.kernel,
    weight_clip_min, clip_max, tf.qint8))
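
# Hypothetical invocation, assuming the snippet above is saved as
# inspect_quantized_ffn.py and the argument points to a checkpoint trained with
# quantization enabled (i.e. enable_quant/quant_params present in its model configs):
#
#   python inspect_quantized_ffn.py /path/to/quantized_model_dir
#
# It prints the int8-quantized kernel of the first feed-forward sublayer in
# encoder layer 0, using the clip range traced by QuantLayer during training.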