示例#1
0
    def __init__(self, cls_num_labels=2, tok_num_labels=2, tok2id=None):
        """Build a tagger that sits on top of a pre-trained debiasing model.

        Configuration is read from the module-level ``ARGS`` object:
        ``pointer_generator`` selects the seq2seq variant, ``hidden_size``
        sizes both the seq2seq model and the classifier heads, and
        ``debias_checkpoint`` is the path of the weights loaded into the
        seq2seq model.

        Args:
            cls_num_labels: Output width of the last linear layer of
                ``self.cls_classifier``.
            tok_num_labels: Output width of the last linear layer of
                ``self.tok_classifier``.
            tok2id: Token-to-id mapping. Its length is used as the vocabulary
                size and it is forwarded to the seq2seq constructor. Despite
                the ``None`` default a sized object is required, because
                ``len(tok2id)`` is taken.

        Raises:
            ValueError: If ``ARGS.debias_checkpoint`` is empty or unset.
        """
        super(TaggerFromDebiaser, self).__init__()

        # Both seq2seq variants take the same constructor arguments, so pick
        # the class once and build it in a single place.
        if ARGS.pointer_generator:
            seq2seq_cls = seq2seq_model.PointerSeq2Seq
        else:
            seq2seq_cls = seq2seq_model.Seq2Seq
        self.debias_model = seq2seq_cls(
            vocab_size=len(tok2id),
            hidden_size=ARGS.hidden_size,
            emb_dim=768,  # NOTE(review): presumably the BERT hidden size -- confirm
            dropout=0.2,
            tok2id=tok2id)

        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would turn a missing checkpoint into an opaque
        # TypeError on the string concatenation below.
        if not ARGS.debias_checkpoint:
            raise ValueError(
                'ARGS.debias_checkpoint must be set to load the debiaser')
        print('LOADING DEBIASER FROM ' + ARGS.debias_checkpoint)
        # Deserialize onto the CPU so a checkpoint saved from a GPU run can be
        # loaded on a CPU-only host; load_state_dict copies the values into
        # the model's existing parameters.
        state_dict = torch.load(ARGS.debias_checkpoint, map_location='cpu')
        self.debias_model.load_state_dict(state_dict)
        print('...DONE')

        def build_head(num_labels):
            # Linear -> Dropout -> ReLU -> Linear -> Dropout, in exactly this
            # order so the Sequential's parameter names stay "0" and "3".
            return nn.Sequential(
                nn.Linear(ARGS.hidden_size, ARGS.hidden_size),
                nn.Dropout(0.1),
                nn.ReLU(),
                nn.Linear(ARGS.hidden_size, num_labels),
                nn.Dropout(0.1))

        # Construction order (cls first, then tok) is kept so random weight
        # initialisation consumes the RNG in the same sequence as before.
        self.cls_classifier = build_head(cls_num_labels)
        self.tok_classifier = build_head(tok_num_labels)
示例#2
0
# Reserve one extra vocabulary id, placed right after the existing entries,
# for the '<del>' token.
tok2id['<del>'] = len(tok2id)

# Evaluation data: get_dataloader returns the loader together with the number
# of examples it holds.
eval_dataloader, num_eval_examples = get_dataloader(
    ARGS.test, tok2id, ARGS.test_batch_size,
    ARGS.working_dir + '/test_data.pkl',
    test=True, add_del_tok=ARGS.add_del_tok)

# # # # # # # # ## # # # ## # # MODEL # # # # # # # # ## # # # ## # #

# Both seq2seq variants are constructed with identical arguments; only the
# class depends on the --pointer_generator switch.
_seq2seq_cls = (seq2seq_model.PointerSeq2Seq
                if ARGS.pointer_generator
                else seq2seq_model.Seq2Seq)
debias_model = _seq2seq_cls(
    vocab_size=len(tok2id),
    hidden_size=ARGS.hidden_size,
    emb_dim=768,  # 768 = bert hidden size
    dropout=0.2,
    tok2id=tok2id)

if ARGS.extra_features_top:
    tagging_model = tagging_model.BertForMultitaskWithFeaturesOnTop.from_pretrained(
        ARGS.bert_model,
        cls_num_labels=ARGS.num_categories,
        tok_num_labels=ARGS.num_tok_labels,
        cache_dir=ARGS.working_dir + '/cache',