Example #1
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Expose the keyword arguments passed to __init__ via self.hparams.
        self.save_hyperparameters()

        if self.hparams.use_pretrained:
            # Load the pretrained weights along with the config.
            self.model = BertForSequenceClassification.from_pretrained(
                self.hparams.pretrained_model_name, num_labels=self.hparams.num_labels
            )

            # Optionally freeze the encoder so only the classification head trains.
            if self.hparams.freeze:
                for param in self.model.bert.parameters():
                    param.requires_grad = False
        else:
            # Config only: the architecture matches the checkpoint, but the
            # weights are randomly initialized.
            self.model = BertForSequenceClassification(
                BertConfig.from_pretrained(
                    self.hparams.pretrained_model_name,
                    num_labels=self.hparams.num_labels,
                )
            )

        # A unique experiment_id avoids clashes between concurrent metric runs.
        self.metric = datasets.load_metric(
            "glue",
            self.hparams.task_name,
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
        )
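
The freeze branch above is the standard way to train only the classification head; a standalone sketch of the same idiom ("bert-base-uncased" and num_labels=2 are illustrative choices, not taken from the source):

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
# Freeze the encoder, exactly as the hparams.freeze branch does above.
for param in model.bert.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"{trainable:,} of {total:,} parameters remain trainable")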
Example #2
    def __init__(self, config: Config, train_config: TrainerConfig):
        super().__init__()
        self.config = config
        self.train_config = train_config

        # 1. init the bert model
        bert_config = BertConfig.from_pretrained(config.pretrain_model)
        bert_config.num_labels = config.num_labels
        self.classifier = BertForSequenceClassification(bert_config)
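
Note the distinction this constructor relies on, which recurs across these examples: BertConfig.from_pretrained fetches only the configuration, so passing it to the BertForSequenceClassification constructor yields randomly initialized weights, whereas BertForSequenceClassification.from_pretrained also loads the checkpoint. A minimal comparison:

from transformers import BertConfig, BertForSequenceClassification

# Config only: the architecture matches the checkpoint, the weights are random.
config = BertConfig.from_pretrained("bert-base-uncased")
config.num_labels = 2
model_random = BertForSequenceClassification(config)

# Config plus weights: pretrained encoder, freshly initialized classification head.
model_pretrained = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)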
Example #3
            print(idx, " datasets loaded")
        mini_datasets.append(torch.load(
            os.path.join(args.mini_datasets_path, dataset_fname)
        ))
    
    # set up config
    if input_model == 'charbert':
        config = BertConfig.from_pretrained(
            os.path.join('pretrained_models', "general_character_bert"),
            num_labels=2
        )
        config.update({"return_dict": False})

        # set up backbone
        logging.info('Loading classification model with charbert backbone')

        model = BertForSequenceClassification(config=config)
        # this backbone is overwritten by the fine-tuned backbone below
        model.bert = CharacterBertModel.from_pretrained(
            os.path.join('pretrained_models', "general_character_bert"),
            config=config
        )

        if load_checkpoint:
            # load pretrained weights
            print("Args.output_dir before state_dict = ", args.output_dir)
            state_dict = torch.load(
                os.path.join(args.output_dir, 'pytorch_model.bin'), map_location='cpu'
            )
            model.load_state_dict(state_dict, strict=True)
    else:
        # config = BertConfig.from_pretrained(
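
The config.update({"return_dict": False}) call above makes the model return plain tuples instead of ModelOutput objects, so the surrounding training code can unpack outputs positionally. A standalone sketch of that behavior (random inputs, illustrative checkpoint):

import torch
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
config.update({"return_dict": False})
model = BertForSequenceClassification(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
labels = torch.tensor([0])
loss, logits = model(input_ids, labels=labels)  # a tuple, not a ModelOutput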
Example #4
def classifier(model_name, num_labels) -> BertForSequenceClassification:
    bert_config = BertConfig.from_pretrained(model_name)
    bert_config.num_labels = num_labels
    return BertForSequenceClassification(bert_config)
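
Usage is a one-liner; any BERT checkpoint identifier should work for model_name:

model = classifier("bert-base-uncased", num_labels=3)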
Example #5
                                     tokenizer=tokenizer,
                                     debug=args.debug)

    num_train_steps_per_epoch = math.ceil(
        len(train_dataset) / args.train_batch_size)
    num_train_steps = num_train_steps_per_epoch * args.num_train_epochs
    num_warmup_steps = int(args.warmup_ratio * num_train_steps)

    # set up backbone
    if args.backbone == 'charbert':
        logging.info('Loading %s model', "general_character_bert")
        config = BertConfig.from_pretrained(
            os.path.join('pretrained_models', "general_character_bert"),
            num_labels=2,
        )
        config.update({"return_dict": False})
        model = BertForSequenceClassification(config=config)
        #model = CharacterBertForSequenceClassification(config=config)
        model.bert = CharacterBertModel.from_pretrained(
            os.path.join('pretrained_models', "general_character_bert"),
            config=config,
        )
    else:
        logging.info('Loading %s model', "general_bert")
        config = BertConfig.from_pretrained(
            os.path.join('pretrained_models', "bert-base-uncased"),
            num_labels=2,
        )
        config.update({"return_dict": False})
        model = BertForSequenceClassification(config=config)
        model.bert = BertModel.from_pretrained(
            os.path.join('pretrained_models', "bert-base-uncased"),
            config=config,
        )
    model.to(args.device)
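
The step bookkeeping above (num_train_steps, num_warmup_steps) typically feeds a warmup scheduler; a sketch assuming the standard transformers helper and a hypothetical AdamW optimizer (the learning rate is illustrative):

import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_train_steps,
)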