def __init__(self, cls_num_labels=2, tok_num_labels=2, tok2id=None):
    super(TaggerFromDebiaser, self).__init__()

    global ARGS
    global CUDA

    # Pretrained debiasing seq2seq (pointer-generator or vanilla variant).
    if ARGS.pointer_generator:
        self.debias_model = seq2seq_model.PointerSeq2Seq(
            vocab_size=len(tok2id), hidden_size=ARGS.hidden_size,
            emb_dim=768, dropout=0.2, tok2id=tok2id)
    else:
        self.debias_model = seq2seq_model.Seq2Seq(
            vocab_size=len(tok2id), hidden_size=ARGS.hidden_size,
            emb_dim=768, dropout=0.2, tok2id=tok2id)

    # Load the pretrained debiaser weights from the given checkpoint.
    assert ARGS.debias_checkpoint
    print('LOADING DEBIASER FROM ' + ARGS.debias_checkpoint)
    self.debias_model.load_state_dict(torch.load(ARGS.debias_checkpoint))
    print('...DONE')

    # Sentence-level (cls) and token-level (tok) classification heads.
    self.cls_classifier = nn.Sequential(
        nn.Linear(ARGS.hidden_size, ARGS.hidden_size),
        nn.Dropout(0.1),
        nn.ReLU(),
        nn.Linear(ARGS.hidden_size, cls_num_labels),
        nn.Dropout(0.1))

    self.tok_classifier = nn.Sequential(
        nn.Linear(ARGS.hidden_size, ARGS.hidden_size),
        nn.Dropout(0.1),
        nn.ReLU(),
        nn.Linear(ARGS.hidden_size, tok_num_labels),
        nn.Dropout(0.1))
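# ---------------------------------------------------------------------------
# Usage sketch (not part of the original model code): a minimal, self-contained
# example of how a two-layer head like self.tok_classifier above turns per-token
# hidden states into label logits. hidden_size, batch size, and sequence length
# below are illustrative assumptions, not values taken from ARGS.
import torch
import torch.nn as nn

hidden_size, tok_num_labels = 512, 2
tok_head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.Dropout(0.1),
    nn.ReLU(),
    nn.Linear(hidden_size, tok_num_labels),
    nn.Dropout(0.1))

hidden_states = torch.randn(8, 128, hidden_size)  # [batch, seq_len, hidden]
tok_logits = tok_head(hidden_states)              # [batch, seq_len, tok_num_labels]
# ---------------------------------------------------------------------------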
# Append a dedicated deletion token to the vocabulary.
tok2id['<del>'] = len(tok2id)

eval_dataloader, num_eval_examples = get_dataloader(
    ARGS.test, tok2id, ARGS.test_batch_size,
    ARGS.working_dir + '/test_data.pkl',
    test=True, add_del_tok=ARGS.add_del_tok)

# # # # # # # # ## # # # ## # # MODEL # # # # # # # # ## # # # ## # #
if ARGS.pointer_generator:
    debias_model = seq2seq_model.PointerSeq2Seq(
        vocab_size=len(tok2id), hidden_size=ARGS.hidden_size,
        emb_dim=768, dropout=0.2, tok2id=tok2id)  # 768 = bert hidden size
else:
    debias_model = seq2seq_model.Seq2Seq(
        vocab_size=len(tok2id), hidden_size=ARGS.hidden_size,
        emb_dim=768, dropout=0.2, tok2id=tok2id)

if ARGS.extra_features_top:
    tagging_model = tagging_model.BertForMultitaskWithFeaturesOnTop.from_pretrained(
        ARGS.bert_model,
        cls_num_labels=ARGS.num_categories,
        tok_num_labels=ARGS.num_tok_labels,
        cache_dir=ARGS.working_dir + '/cache',