Example #1
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes model to use BERT model for GLUE tasks.
        """
        self.data_dir = cfg.dataset.data_dir
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(
                "GLUE datasets not found. For more details on how to get the data, see: "
                "https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e"
            )

        if cfg.task_name not in cfg.supported_tasks:
            raise ValueError(
                f'{cfg.task_name} is not a supported task. Choose from {cfg.supported_tasks}'
            )
        self.task_name = cfg.task_name

        # MNLI task has two separate dev sets: matched and mismatched
        cfg.train_ds.file_name = os.path.join(self.data_dir,
                                              cfg.train_ds.file_name)
        if self.task_name == "mnli":
            cfg.validation_ds.file_name = [
                os.path.join(self.data_dir, 'dev_matched.tsv'),
                os.path.join(self.data_dir, 'dev_mismatched.tsv'),
            ]
        else:
            cfg.validation_ds.file_name = os.path.join(
                self.data_dir, cfg.validation_ds.file_name)
        logging.info(
            f'Using {cfg.validation_ds.file_name} for model evaluation.')
        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        num_labels = GLUE_TASKS_NUM_LABELS[self.task_name]

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        # uses the [CLS] token (the first token) for classification
        if self.task_name == "sts-b":
            self.pooler = SequenceRegression(
                hidden_size=self.bert_model.config.hidden_size)
            self.loss = MSELoss()
        else:
            self.pooler = SequenceClassifier(
                hidden_size=self.bert_model.config.hidden_size,
                num_classes=num_labels,
                log_softmax=False)
            self.loss = CrossEntropyLoss()

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)
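
Example #1 expects a Hydra/OmegaConf DictConfig. The sketch below shows one way such a config could be assembled by hand; the keys mirror the attribute accesses in the example, but the concrete values (paths, model names, optimizer settings) are illustrative assumptions, not an official schema.

from omegaconf import OmegaConf

# Hypothetical config; keys follow the accesses in Example #1, values are
# placeholders chosen for illustration.
cfg = OmegaConf.create({
    "task_name": "mnli",
    "supported_tasks": ["cola", "sst-2", "mrpc", "sts-b",
                        "qqp", "mnli", "qnli", "rte", "wnli"],
    "dataset": {"data_dir": "/data/glue/MNLI"},
    "train_ds": {"file_name": "train.tsv"},
    "validation_ds": {"file_name": "dev.tsv"},  # replaced by the two MNLI dev sets above
    "tokenizer": {"tokenizer_name": "bert-base-uncased"},
    "language_model": {
        "pretrained_model_name": "bert-base-uncased",
        "config_file": None,    # optional path to an LM config file
        "config": None,         # optional inline LM config dict
        "lm_checkpoint": None,  # optional pretrained checkpoint
    },
    "optim": {"name": "adam", "lr": 5e-5},
})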
Example #2
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes model to use BERT model for GLUE tasks.
        """

        if cfg.task_name not in cfg.supported_tasks:
            raise ValueError(
                f'{cfg.task_name} is not a supported task. Choose from {cfg.supported_tasks}'
            )
        self.task_name = cfg.task_name

        # needed to setup validation on multiple datasets
        # MNLI task has two separate dev sets: matched and mismatched
        if not self._is_model_being_restored():
            if self.task_name == "mnli":
                cfg.validation_ds.ds_item = [
                    os.path.join(cfg.dataset.data_dir, 'dev_matched.tsv'),
                    os.path.join(cfg.dataset.data_dir, 'dev_mismatched.tsv'),
                ]
            else:
                cfg.validation_ds.ds_item = os.path.join(
                    cfg.dataset.data_dir, cfg.validation_ds.ds_item)
            cfg.train_ds.ds_item = os.path.join(cfg.dataset.data_dir,
                                                cfg.train_ds.ds_item)
            logging.info(
                f'Using {cfg.validation_ds.ds_item} for model evaluation.')

        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        num_labels = GLUE_TASKS_NUM_LABELS[self.task_name]

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        # uses the [CLS] token (the first token) for classification
        if self.task_name == "sts-b":
            self.pooler = SequenceRegression(
                hidden_size=self.bert_model.config.hidden_size)
            self.loss = MSELoss()
        else:
            self.pooler = SequenceClassifier(
                hidden_size=self.bert_model.config.hidden_size,
                num_classes=num_labels,
                log_softmax=False)
            self.loss = CrossEntropyLoss()
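
Example #2 differs from Example #1 mainly in wrapping the LM config and tokenizer vocab paths in register_artifact, which lets the framework track external files so they can be packaged with the saved model and resolved again on restore. The toy sketch below illustrates only the register-then-resolve idea; it is not NeMo's implementation, and ArtifactRegistry is a hypothetical name.

from typing import Dict, Optional

class ArtifactRegistry:
    """Toy stand-in for a model that tracks external file dependencies."""

    def __init__(self):
        # Maps a config path like 'tokenizer.vocab_file' to a file on disk.
        self.artifacts: Dict[str, str] = {}

    def register_artifact(self, config_path: str, src: Optional[str]) -> Optional[str]:
        # Remember the dependency (if any) and hand back a usable path, so
        # the call can be inlined into keyword arguments as in Example #2.
        if src is not None:
            self.artifacts[config_path] = src
        return src

registry = ArtifactRegistry()
vocab_file = registry.register_artifact('tokenizer.vocab_file', '/data/bert/vocab.txt')
# On save, everything in registry.artifacts would be copied into the archive.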
Example #3
    def _setup_loss(self):
        return MSELoss()
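
Example #3 factors the loss out into a helper; the fixed MSELoss matches the sts-b regression branch of Examples #1 and #2. A natural generalization, sketched below under the assumption that the class exposes self.task_name and uses the same loss classes as the earlier examples, selects the loss by task instead:

    def _setup_loss(self):
        # sts-b is the only GLUE regression task; every other task is
        # classification, mirroring the branch in Examples #1 and #2.
        if self.task_name == "sts-b":
            return MSELoss()
        return CrossEntropyLoss()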