# コード例 #1 (Code example #1)
    def get_model(self, stage_id, old_model=None):
        """Construct the model for one progressive-training stage.

        If the stage config specifies a small encoder, a reduced model is
        built from it; otherwise the full model is built (optionally with an
        overridden layer count). Stage 0 is initialized from a checkpoint;
        later stages inherit weights from the previous stage's model.

        Args:
          stage_id: index into `self._progressive_config.stage_list`.
          old_model: the previous stage's model, used as the weight source
            for stages after the first. May be None.

        Returns:
          The built (and, where applicable, weight-initialized) model.
        """
        cfg: StackingStageConfig = self._progressive_config.stage_list[
            stage_id]
        if cfg.small_encoder_config is None:
            # Full-size model; honor a per-stage layer-count override.
            full_cfg = copy.deepcopy(self._model_config)
            if cfg.override_num_layers is not None:
                full_cfg.encoder.bert.num_layers = cfg.override_num_layers
            model = self.build_model(full_cfg)
            # Trigger a forward pass so variables are created.
            _ = model(model.inputs)
        else:
            # Shrunken encoder derived from the base BERT config.
            bert_cfg: ecfg.TransformerEncoderConfig = ecfg.from_bert_encoder_config(
                self._model_config.encoder.bert,
                cfg.small_encoder_config)
            small_cfg = copy.deepcopy(self._model_config)
            small_cfg.encoder = encoders.EncoderConfig(bert=bert_cfg)
            model = self.build_small_model(small_cfg.as_dict())

        if stage_id == 0:
            self.initialize(model)
        elif stage_id > 0 and old_model is not None:
            logging.info('Stage %d copying weights.', stage_id)
            self.transform_model(small_model=old_model, model=model)
        return model
# コード例 #2 (Code example #2)
    def initialize(self, model):
        """Initialize `model` from a (possibly smaller) saved checkpoint.

        Resolves `self.task_config.init_checkpoint` to a concrete checkpoint
        path, loads the `params.yaml` stored next to the checkpoint (or one
        directory above it), rebuilds the small model that the checkpoint was
        saved from, restores its weights, and transfers them into `model`
        via `transform_model`.

        Args:
          model: the target model whose weights should be initialized.

        Raises:
          ValueError: if the checkpointed encoder's hidden size does not
            match `model`'s encoder hidden size.
        """
        init_dir_or_path = self.task_config.init_checkpoint
        logging.info('init dir_or_path: %s', init_dir_or_path)
        if not init_dir_or_path:
            # No checkpoint configured: nothing to restore.
            return

        if tf.io.gfile.isdir(init_dir_or_path):
            init_dir = init_dir_or_path
            init_path = tf.train.latest_checkpoint(init_dir_or_path)
        else:
            init_path = init_dir_or_path
            init_dir = os.path.dirname(init_path)

        logging.info('init dir: %s', init_dir)
        logging.info('init path: %s', init_path)

        # The training params live either alongside the checkpoint or in its
        # parent directory; try both locations.
        init_yaml_path = os.path.join(init_dir, 'params.yaml')
        if not tf.io.gfile.exists(init_yaml_path):
            init_yaml_path = os.path.join(os.path.dirname(init_dir),
                                          'params.yaml')
        with tf.io.gfile.GFile(init_yaml_path, 'r') as rf:
            init_yaml_config = yaml.safe_load(rf)
        init_model_config = init_yaml_config['task']['model']
        # If the checkpoint came from a progressive run, the saved weights
        # correspond to the LAST stage's (possibly shrunk) encoder config, so
        # rewrite the encoder config before rebuilding the small model.
        if 'progressive' in init_yaml_config['trainer']:
            stage_list = init_yaml_config['trainer']['progressive'][
                'stage_list']
            if stage_list:
                small_encoder_config = stage_list[-1]['small_encoder_config']
                if small_encoder_config is not None:
                    small_encoder_config = ecfg.from_bert_encoder_config(
                        init_model_config['encoder']['bert'],
                        small_encoder_config)
                    init_model_config['encoder'][
                        'bert'] = small_encoder_config.as_dict()

        # Validate with an explicit raise instead of `assert`, which is
        # stripped (and so silently skipped) under `python -O`.
        ckpt_hidden = init_model_config['encoder']['bert']['hidden_size']
        model_hidden = model.encoder_network.get_config()['hidden_size']
        if ckpt_hidden != model_hidden:
            raise ValueError(
                f'Checkpoint encoder hidden_size {ckpt_hidden} does not '
                f'match model hidden_size {model_hidden}.')

        # Rebuild the small model, restore its weights from the checkpoint,
        # then copy them into the target model.
        small_model = self.build_small_model(init_model_config)
        ckpt = tf.train.Checkpoint(model=small_model)
        ckpt.restore(init_path).assert_existing_objects_matched()

        self.transform_model(small_model, model)