def get_model(self, stage_id, old_model=None):
  """Create the model for the given progressive stage.

  Stages whose config carries a `small_encoder_config` get a reduced
  ("small") model; otherwise the full model is built, optionally with an
  overridden number of encoder layers. At stage 0 the model is initialized
  from the configured checkpoint; at later stages the weights of
  `old_model` are transferred into the new model.

  Args:
    stage_id: Index into `self._progressive_config.stage_list`.
    old_model: Model from the previous stage whose weights are copied in
      when `stage_id > 0`. May be None.

  Returns:
    The built (and weight-initialized) model for this stage.
  """
  stage_cfg = self._progressive_config.stage_list[stage_id]

  if stage_cfg.small_encoder_config is None:
    # Full-size model; the stage may still override the layer count.
    full_cfg = copy.deepcopy(self._model_config)
    if stage_cfg.override_num_layers is not None:
      full_cfg.encoder.bert.num_layers = stage_cfg.override_num_layers
    model = self.build_model(full_cfg)
  else:
    # Reduced encoder for this stage, derived from the base BERT config.
    small_bert_cfg = ecfg.from_bert_encoder_config(
        self._model_config.encoder.bert, stage_cfg.small_encoder_config)
    small_cfg = copy.deepcopy(self._model_config)
    small_cfg.encoder = encoders.EncoderConfig(bert=small_bert_cfg)
    model = self.build_small_model(small_cfg.as_dict())

  # Run one forward pass so every variable is created before any
  # weight copying below.
  # NOTE(review): source formatting was collapsed; this call is assumed to
  # apply to both branches — confirm against the original layout.
  _ = model(model.inputs)

  if stage_id == 0:
    self.initialize(model)
  if stage_id > 0 and old_model is not None:
    logging.info('Stage %d copying weights.', stage_id)
    self.transform_model(small_model=old_model, model=model)
  return model
def initialize(self, model):
  """Initialize `model` from a (small-model) checkpoint, if configured.

  Reads `self.task_config.init_checkpoint`, locates the `params.yaml`
  written alongside the checkpoint, rebuilds the small model described
  there (applying the final progressive stage's `small_encoder_config`
  when present), restores its weights, and transfers them into `model`
  via `self.transform_model`. No-op when no init checkpoint is set.

  Args:
    model: The target model that receives the restored weights.

  Raises:
    ValueError: If the checkpointed encoder's `hidden_size` does not
      match `model`'s encoder `hidden_size`.
  """
  init_dir_or_path = self.task_config.init_checkpoint
  logging.info('init dir_or_path: %s', init_dir_or_path)
  if not init_dir_or_path:
    return

  # Accept either a checkpoint directory or a direct checkpoint path.
  if tf.io.gfile.isdir(init_dir_or_path):
    init_dir = init_dir_or_path
    init_path = tf.train.latest_checkpoint(init_dir_or_path)
  else:
    init_path = init_dir_or_path
    init_dir = os.path.dirname(init_path)
  logging.info('init dir: %s', init_dir)
  logging.info('init path: %s', init_path)

  # Restore from small model: read the params.yaml that produced the
  # checkpoint; fall back to the parent directory when it is not next to
  # the checkpoint itself.
  init_yaml_path = os.path.join(init_dir, 'params.yaml')
  if not tf.io.gfile.exists(init_yaml_path):
    init_yaml_path = os.path.join(os.path.dirname(init_dir), 'params.yaml')
  with tf.io.gfile.GFile(init_yaml_path, 'r') as rf:
    init_yaml_config = yaml.safe_load(rf)

  init_model_config = init_yaml_config['task']['model']
  if 'progressive' in init_yaml_config['trainer']:
    stage_list = init_yaml_config['trainer']['progressive']['stage_list']
    if stage_list:
      # The checkpoint holds the final stage's model; if that stage used a
      # reduced encoder, rewrite the model config to match it.
      small_encoder_config = stage_list[-1]['small_encoder_config']
      if small_encoder_config is not None:
        small_encoder_config = ecfg.from_bert_encoder_config(
            init_model_config['encoder']['bert'], small_encoder_config)
        init_model_config['encoder'][
            'bert'] = small_encoder_config.as_dict()

  # Check that the model sizes match. Raise explicitly rather than
  # `assert`, which is silently stripped under `python -O` and would
  # otherwise let a mismatched restore proceed.
  ckpt_hidden_size = init_model_config['encoder']['bert']['hidden_size']
  model_hidden_size = model.encoder_network.get_config()['hidden_size']
  if ckpt_hidden_size != model_hidden_size:
    raise ValueError(
        'Checkpoint encoder hidden_size (%d) does not match model encoder '
        'hidden_size (%d).' % (ckpt_hidden_size, model_hidden_size))

  # Build the small model, restore its weights, then transfer to `model`.
  small_model = self.build_small_model(init_model_config)
  ckpt = tf.train.Checkpoint(model=small_model)
  ckpt.restore(init_path).assert_existing_objects_matched()
  self.transform_model(small_model, model)