def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Initializes a BERT-based model for GLUE tasks.
    """
    self.data_dir = cfg.dataset.data_dir
    if not os.path.exists(self.data_dir):
        raise FileNotFoundError(
            "GLUE datasets not found. For more details on how to get the data, see: "
            "https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e"
        )

    if cfg.task_name not in cfg.supported_tasks:
        raise ValueError(f'{cfg.task_name} not in supported tasks. Choose from {cfg.supported_tasks}')
    self.task_name = cfg.task_name

    cfg.train_ds.file_name = os.path.join(self.data_dir, cfg.train_ds.file_name)
    # MNLI task has two separate dev sets: matched and mismatched
    if self.task_name == "mnli":
        cfg.validation_ds.file_name = [
            os.path.join(self.data_dir, 'dev_matched.tsv'),
            os.path.join(self.data_dir, 'dev_mismatched.tsv'),
        ]
    else:
        cfg.validation_ds.file_name = os.path.join(self.data_dir, cfg.validation_ds.file_name)
    logging.info(f'Using {cfg.validation_ds.file_name} for model evaluation.')

    self._setup_tokenizer(cfg.tokenizer)
    super().__init__(cfg=cfg, trainer=trainer)

    num_labels = GLUE_TASKS_NUM_LABELS[self.task_name]

    self.bert_model = get_lm_model(
        pretrained_model_name=cfg.language_model.pretrained_model_name,
        config_file=cfg.language_model.config_file,
        config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
        checkpoint_file=cfg.language_model.lm_checkpoint,
    )

    # uses [CLS] token for classification (the first token)
    if self.task_name == "sts-b":
        self.pooler = SequenceRegression(hidden_size=self.bert_model.config.hidden_size)
        self.loss = MSELoss()
    else:
        self.pooler = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size, num_classes=num_labels, log_softmax=False
        )
        self.loss = CrossEntropyLoss()

    # Optimizer setup needs to happen after all model weights are ready
    self.setup_optimization(cfg.optim)
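# Hedged usage sketch (assumption, not from this file): a minimal OmegaConf
# config covering the fields the constructor above reads. The field values and
# the class name GLUEModel are illustrative; the tokenizer, language_model, and
# optim sections are left empty here and would need real settings for a run.
from omegaconf import OmegaConf

example_cfg = OmegaConf.create(
    {
        'task_name': 'sst-2',
        'supported_tasks': ['cola', 'sst-2', 'mrpc', 'sts-b', 'qqp', 'mnli', 'qnli', 'rte', 'wnli'],
        'dataset': {'data_dir': '/data/GLUE/SST-2'},  # must exist on disk, else FileNotFoundError
        'train_ds': {'file_name': 'train.tsv'},       # joined with data_dir above
        'validation_ds': {'file_name': 'dev.tsv'},    # becomes a two-item list for MNLI
        'tokenizer': {},        # tokenizer settings elided in this sketch
        'language_model': {},   # pretrained model settings elided in this sketch
        'optim': {},            # optimizer settings elided in this sketch
    }
)
# model = GLUEModel(cfg=example_cfg, trainer=None)  # assumes the enclosing class is GLUEModel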
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """
    Initializes a BERT-based model for GLUE tasks.
    """
    if cfg.task_name not in cfg.supported_tasks:
        raise ValueError(f'{cfg.task_name} not in supported tasks. Choose from {cfg.supported_tasks}')
    self.task_name = cfg.task_name

    # needed to setup validation on multiple datasets
    # MNLI task has two separate dev sets: matched and mismatched
    if not self._is_model_being_restored():
        if self.task_name == "mnli":
            cfg.validation_ds.ds_item = [
                os.path.join(cfg.dataset.data_dir, 'dev_matched.tsv'),
                os.path.join(cfg.dataset.data_dir, 'dev_mismatched.tsv'),
            ]
        else:
            cfg.validation_ds.ds_item = os.path.join(cfg.dataset.data_dir, cfg.validation_ds.ds_item)
        cfg.train_ds.ds_item = os.path.join(cfg.dataset.data_dir, cfg.train_ds.ds_item)
        logging.info(f'Using {cfg.validation_ds.ds_item} for model evaluation.')

    self.setup_tokenizer(cfg.tokenizer)
    super().__init__(cfg=cfg, trainer=trainer)

    num_labels = GLUE_TASKS_NUM_LABELS[self.task_name]

    self.bert_model = get_lm_model(
        pretrained_model_name=cfg.language_model.pretrained_model_name,
        config_file=self.register_artifact('language_model.config_file', cfg.language_model.config_file),
        config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
        checkpoint_file=cfg.language_model.lm_checkpoint,
        vocab_file=self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file),
    )

    # uses [CLS] token for classification (the first token)
    if self.task_name == "sts-b":
        self.pooler = SequenceRegression(hidden_size=self.bert_model.config.hidden_size)
        self.loss = MSELoss()
    else:
        self.pooler = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size, num_classes=num_labels, log_softmax=False
        )
        self.loss = CrossEntropyLoss()
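# Hedged sketch (assumption, not part of the source): the MNLI dev-set branch
# above in isolation. Given a data_dir, MNLI evaluates on both the matched and
# mismatched dev files, while every other task evaluates on a single file. The
# function name and the example path are illustrative.
import os

def resolve_validation_files(task_name: str, data_dir: str, default_file: str = 'dev.tsv'):
    """Return the dev-set path(s) a GLUE task evaluates on, mirroring the logic above."""
    if task_name == 'mnli':
        # MNLI ships matched and mismatched dev sets, so validation runs on both
        return [
            os.path.join(data_dir, 'dev_matched.tsv'),
            os.path.join(data_dir, 'dev_mismatched.tsv'),
        ]
    return os.path.join(data_dir, default_file)

# resolve_validation_files('mnli', '/data/GLUE/MNLI')
# -> ['/data/GLUE/MNLI/dev_matched.tsv', '/data/GLUE/MNLI/dev_mismatched.tsv']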
def _setup_loss(self):
    # STS-B is a regression task; all other GLUE tasks are classification,
    # matching the loss selection in __init__.
    if self.task_name == "sts-b":
        return MSELoss()
    return CrossEntropyLoss()
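# Hedged sanity check (assumption, not from the source): the dispatch above as
# a standalone function, assuming torch.nn-style losses as used in __init__.
# The function name loss_for_task is illustrative.
from torch.nn import CrossEntropyLoss, MSELoss

def loss_for_task(task_name: str):
    """STS-B is the only regression task in GLUE; everything else is classification."""
    return MSELoss() if task_name == 'sts-b' else CrossEntropyLoss()

assert isinstance(loss_for_task('sts-b'), MSELoss)
assert isinstance(loss_for_task('qqp'), CrossEntropyLoss)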