Example no. 1
def init_model(self, params):
    # Build a BertConfig, attach the task-specific fields the classifier
    # expects, and keep both the config and the model on the instance.
    config = BertConfig()
    config.dropout_prob = params.dropout_prob
    config.num_labels = self.num_labels
    config.identifier = params.identifier
    self.config = config
    self.model = Classification(config)
Example no. 2
    # Wrap the Wikipedia corpus in the pre-training dataset and batch it.
    wikidata = WikiDataset()
    pretraindata = PretrainDataset(wikidata)
    dataloader = torch.utils.data.DataLoader(
        pretraindata, batch_size=BATCH_SIZE, collate_fn=pretrain_collate_fn
    )

    # Configure a small BERT model.
    config = BertConfig(
        num_hidden_layers=4,
        hidden_size=312,
        intermediate_size=1200,
        max_position_embeddings=1024,
    )

    if args.target == "mobert":
        config.num_labels = pretraindata.token_num + 1
        model = MoBert(config)
    elif args.target == "bert":
        model = BertForPreTraining(config)
    else:
        # Without this guard `model` would be undefined below.
        raise ValueError(f"unknown target: {args.target}")
    model = model.to(device)

    # Pre-train the selected model.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()

    step = 1
    total_loss = 0       # overall loss
    total_loss_pre = 0   # pre-training loss term
    total_loss_cl = 0    # auxiliary "cl" loss term
    start = time.time()
    for src, mlm, mask, nsp, mt, token_type_ids in dataloader:
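        # A minimal sketch of one training step (the source truncates
        # here); the model's call signature and its two loss terms are
        # assumptions inferred from the running totals above.
        optimizer.zero_grad()
        loss_pre, loss_cl = model(src, mask, token_type_ids, mlm, nsp, mt)
        loss = loss_pre + loss_cl
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_loss_pre += loss_pre.item()
        total_loss_cl += loss_cl.item()
        step += 1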
Example no. 3

def crt_model(F):
    # Instantiate the model selected by F.model_name; F carries the
    # hyper-parameters (embedding size, RNN dims, dropout, etc.).
    name = F.model_name.lower()
    if name == "gru":
        model = GRU_model(
            emb_size=F.emb_size,
            wv_model_file=F.wv_model_file,
            h_size=F.rnn_dim,
            num_label=F.num_label,
            fix_emb=F.fix_emb,
            dropout=F.dropout,
            use_pretrain=F.emb_pretrain,
            bidirectional=F.rnn_bid,
            num_layer=F.rnn_num_layer,
        )
    elif name == "rcnn":
        model = RCNN(
            emb_size=F.emb_size,
            wv_model_file=F.wv_model_file,
            h_size=F.rnn_dim,
            num_label=F.num_label,
            fix_emb=F.fix_emb,
            dropout=F.dropout,
            use_pretrain=F.emb_pretrain,
            num_layer=F.rnn_num_layer,
        )
    elif name == "bigru_atten2":
        model = BiGRU_Atten2(
            emb_size=F.emb_size,
            wv_model_file=F.wv_model_file,
            h_size=F.rnn_dim,
            num_label=F.num_label,
            fix_emb=F.fix_emb,
            dropout=F.dropout,
            use_pretrain=F.emb_pretrain,
            num_layer=F.rnn_num_layer,
        )
    elif name == "seq2seqatten":
        model = Seq2SeqAtten(
            emb_size=F.emb_size,
            wv_model_file=F.wv_model_file,
            h_size=F.rnn_dim,
            num_label=F.num_label,
            fix_emb=F.fix_emb,
            use_pretrain=F.emb_pretrain,
            dropout=F.dropout,
            num_layer=F.rnn_num_layer,
            bidirectional=F.rnn_bid,
        )
    elif name == "seq2seqatten2":
        model = Seq2SeqAtten2(
            emb_size=F.emb_size,
            wv_model_file=F.wv_model_file,
            h_size=F.rnn_dim,
            num_label=F.num_label,
            fix_emb=F.fix_emb,
            use_pretrain=F.emb_pretrain,
            dropout=F.dropout,
            num_layer=F.rnn_num_layer,
            bidirectional=F.rnn_bid,
        )
    elif name == "maskgru":
        model = MaskGRU_model(
            emb_size=F.emb_size,
            wv_model_file=F.wv_model_file,
            h_size=F.rnn_dim,
            num_label=F.num_label,
            fix_emb=F.fix_emb,
            use_pretrain=F.emb_pretrain,
            dropout=F.dropout,
            bidirectional=F.rnn_bid,
        )
    elif name == "bert":
        from transformers import BertConfig, BertForSequenceClassification
        vocab = read_pkl(F.vocab_file)
        F.vocab_size = len(vocab)
        conf = BertConfig(**F.__dict__)
        conf.num_labels = F.num_label
        model = BertForSequenceClassification(conf)
    elif name == "nezha":
        from model.nezha.modeling_nezha import BertConfig, BertForSequenceClassification
        vocab = read_pkl(F.vocab_file)
        F.vocab_size = len(vocab)
        conf = BertConfig.from_dict(F.__dict__)
        model = BertForSequenceClassification(conf, F.num_label)
    else:
        raise ValueError(f"model not found: {F.model_name}")

    return model
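
For reference, a minimal, hypothetical invocation of crt_model: F is assumed to be a flat flag object (e.g. an argparse.Namespace) whose attribute names mirror the ones the function reads; every value below is a placeholder.

from argparse import Namespace

# Hypothetical flags; attribute names mirror those accessed in crt_model,
# values are placeholders only.
F = Namespace(
    model_name="gru",
    emb_size=300,
    wv_model_file="wv.model",  # placeholder path
    rnn_dim=256,
    num_label=10,
    fix_emb=False,
    dropout=0.1,
    emb_pretrain=True,
    rnn_bid=True,
    rnn_num_layer=2,
)
model = crt_model(F)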