def main():
    # Assumes the repo's usual imports: os, torch, torch.optim as optim, AdamW from
    # transformers, DotMap from the dotmap package, plus the FewRel encoder/loader/
    # framework/model classes used below.
    opt = {
        'train': 'train_wiki', 'val': 'val_wiki', 'test': 'val_semeval', 'adv': None,
        'trainN': 10, 'N': 5, 'K': 1, 'Q': 1, 'batch_size': 1,
        'train_iter': 5, 'val_iter': 1000, 'test_iter': 10, 'val_step': 2000,
        'model': 'pair', 'encoder': 'bert', 'max_length': 64,
        'lr': -1, 'weight_decay': 1e-5, 'dropout': 0.0, 'na_rate': 0, 'optim': 'adam',
        'hidden_size': 230,  # required by Pair(...) below; 230 is the default the sibling scripts use
        'load_ckpt': './src/FewRel-master/Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth',
        'save_ckpt': './src/FewRel-master/Checkpoints/post_ckpt_5-Way-1-Shot_FewRel.pth',
        'fp16': False, 'only_test': False,
        'ckpt_name': 'Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth',
        'pair': True, 'pretrain_ckpt': '', 'cat_entity_rep': False, 'dot': False,
        'no_dropout': False, 'mask_entity': False, 'use_sgd_for_bert': False
    }
    opt = DotMap(opt)

    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    sentence_encoder = BERTPAIRSentenceEncoder(pretrain_ckpt, max_length)

    train_data_loader = get_loader_pair(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    val_data_loader = get_loader_pair(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    test_data_loader = get_loader_pair(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'
    if len(opt.ckpt_name) > 0:
        prefix += '-' + opt.ckpt_name

    model = Pair(sentence_encoder, hidden_size=opt.hidden_size)

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = encoder_name in ['bert', 'roberta']
        if opt.lr == -1:
            # -1 means "use the encoder-appropriate default learning rate".
            opt.lr = 2e-5 if bert_optim else 1e-1
        framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt, na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim, learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert)
    else:
        ckpt = opt.load_ckpt
        if ckpt is None:
            print("Warning: --load_ckpt is not specified. Will load Huggingface pre-trained checkpoint.")
            ckpt = 'none'

    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
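# The snippet above relies on the third-party `dotmap` package for attribute-style
# access to the option dict. If that dependency is unavailable, a minimal stand-in
# with the same read-access pattern could look like this sketch (hypothetical helper,
# not part of the original code; real DotMap also handles nested dicts and returns
# empty DotMaps for missing keys instead of raising):
class SimpleDotMap(dict):
    """Dict with attribute-style read access: opt.lr instead of opt['lr']."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

# Usage sketch:
# opt = SimpleDotMap({'lr': 2e-5, 'max_length': 64})
# assert opt.lr == 2e-5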
def main():
    # Assumes the repo's usual imports: argparse, os, json, numpy as np, torch,
    # torch.optim as optim, plus the encoder/loader/framework/model classes below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q', default=5, type=int, help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter', default=30000, type=int, help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int, help='num of iters in validation')
    parser.add_argument('--test_iter', default=10000, type=int, help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int, help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder', default='cnn', help='encoder: cnn or bert or roberta')
    parser.add_argument('--max_length', default=128, type=int, help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int, help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / adamw')
    parser.add_argument('--hidden_size', default=230, type=int, help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    # only for bert / roberta
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--pretrain_ckpt', default=None, help='bert / roberta pre-trained checkpoint')
    parser.add_argument('--cat_entity_rep', action='store_true', help='concatenate entity representation as sentence rep')
    # only for prototypical networks
    parser.add_argument('--dot', action='store_true', help='use dot instead of L2 distance for proto')
    opt = parser.parse_args()

    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(open('./pretrain/glove/glove_word2id.json'))
        except Exception:  # a bare `except:` here would also swallow KeyboardInterrupt
            raise Exception("Cannot find glove files. Run glove/download_glove.sh to download glove files.")
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id, max_length)
    elif encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder(pretrain_ckpt, max_length)
        else:
            sentence_encoder = BERTSentenceEncoder(pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    elif encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        if opt.pair:
            sentence_encoder = RobertaPAIRSentenceEncoder(pretrain_ckpt, max_length)
        else:
            sentence_encoder = RobertaSentenceEncoder(pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
        val_data_loader = get_loader_pair(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
        test_data_loader = get_loader_pair(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    else:
        train_data_loader = get_loader(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.adv:
        adv_data_loader = get_loader_unsupervised(opt.adv, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader, adv_data_loader, adv=opt.adv, d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'

    if model_name == 'proto':
        model = Proto(sentence_encoder, dot=opt.dot)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, hidden_size=opt.hidden_size)
    elif model_name == 'snail':
        model = SNAIL(sentence_encoder, N, K, hidden_size=opt.hidden_size)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder, hidden_size=opt.hidden_size, dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = encoder_name in ['bert', 'roberta']
        framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt, na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
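# `get_loader` / `get_loader_pair` above are repo-specific, but the N-way K-shot
# episode construction they implement follows a standard recipe: sample N classes,
# then K support and Q query instances per class. A minimal sketch, assuming the
# dataset is a {relation_name: [instances]} dict (names here are illustrative, not
# the repo's):
import random

def sample_episode(data_by_relation, N, K, Q):
    """Sample one N-way K-shot episode: K support and Q query instances per class."""
    classes = random.sample(list(data_by_relation), N)
    support, query, labels = [], [], []
    for label, rel in enumerate(classes):
        picked = random.sample(data_by_relation[rel], K + Q)
        support.extend(picked[:K])   # K support examples for this class
        query.extend(picked[K:])     # Q query examples for this class
        labels.extend([label] * Q)   # query labels are episode-local class ids
    return support, query, labels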
def main():
    # Assumes the repo's usual imports: argparse, os, json, random, numpy as np,
    # torch, a configured `logger`, BertTokenizer/BertConfig/BertModel from
    # transformers, plus read_data, get_loader, util, Proto, FewShotREFramework.
    def str2bool(v):
        # argparse's `type=bool` treats any non-empty string (including "False")
        # as True, so boolean flags are parsed explicitly.
        return str(v).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser()
    # model settings
    parser.add_argument('--do_train', default=True, type=str2bool, help='do train')
    parser.add_argument('--do_eval', default=True, type=str2bool, help='do eval')
    parser.add_argument('--do_predict', default=False, type=str2bool, help='do predict')
    parser.add_argument('--do_cn_eval', default=False, type=str2bool, help='do CN eval')
    parser.add_argument('--proto_emb', default=False, help='Get root cause proto emb or sentence emb. Requires do_predict=True')
    parser.add_argument('--train_file', default='./data/source_add_CN_V2.xlsx', help='source file')
    parser.add_argument('--trainN', default=5, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=3, type=int, help='K shot')
    parser.add_argument('--Q', default=2, type=int, help='Num of query per class')
    parser.add_argument('--batch_size', default=8, type=int, help='batch size')
    parser.add_argument('--train_iter', default=10000, type=int, help='num of iters in training')
    parser.add_argument('--warmup_rate', default=0.1, type=float)
    parser.add_argument('--max_length', default=128, type=int, help='max length')
    parser.add_argument('--lr', default=1e-5, type=float, help='learning rate')
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')
    parser.add_argument('--seed', default=100, type=int)
    # saving and loading
    parser.add_argument('--load_ckpt', default='./check_points/model_54000.bin', help='load ckpt')
    parser.add_argument('--save_ckpt', default='./check_points/', help='save ckpt')
    parser.add_argument('--save_emb', default='./data/emb.json', help='save embedding')
    parser.add_argument('--save_root_emb', default='./data/root_emb.json', help='save embedding')
    parser.add_argument('--use_cuda', default=True, help='whether to use cuda')
    parser.add_argument('--eval_step', default=100, type=int)
    parser.add_argument('--save_step', default=500, type=int)
    parser.add_argument('--threshold', default=5, type=int)
    # bert pretrain
    parser.add_argument('--vocab_file', default='./pretrain/vocab.txt', type=str, help='Init vocab to resume training from.')
    parser.add_argument('--config_path', default='./pretrain/bert_config.json', type=str, help='Init config to resume training from.')
    parser.add_argument('--init_checkpoint', default='./pretrain/pytorch_model.bin', type=str, help='Init checkpoint to resume training from.')
    opt = parser.parse_args()

    trainN = opt.trainN
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    max_length = opt.max_length

    logger.info("{}-way-{}-shot Few-Shot Diagnosis".format(trainN, K))
    logger.info("max_length: {}".format(max_length))

    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    if not os.path.exists(opt.save_ckpt):
        os.mkdir(opt.save_ckpt)

    bert_tokenizer = BertTokenizer.from_pretrained(opt.vocab_file)
    bert_config = BertConfig.from_pretrained(opt.config_path)
    bert_model = BertModel.from_pretrained(opt.init_checkpoint, config=bert_config)
    model = Proto(bert_model, opt)
    if opt.use_cuda:
        model.cuda()

    if opt.do_train:
        train_data, eval_data = read_data(opt.train_file, opt.threshold)
        train_data_loader = get_loader(train_data, bert_tokenizer, max_length, N=trainN, K=K, Q=Q, batch_size=batch_size)
        framework = FewShotREFramework(tokenizer=bert_tokenizer, train_data_loader=train_data_loader, train_data=train_data, eval_data=eval_data)
        framework.train(model, batch_size, trainN, K, Q, opt)

    if opt.do_eval:
        train_data, eval_data = read_data(opt.train_file, opt.threshold)
        # Load the fine-tuned encoder weights, stripping the framework's module prefix.
        state_dict = torch.load(opt.load_ckpt)
        own_state = bert_model.state_dict()
        for name, param in state_dict.items():
            name = name.replace('sentence_encoder.module.', '')
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        step = opt.load_ckpt.split('/')[-1].replace('model_', '').split('.')[0]
        bert_model.eval()
        train_data_emb, train_rc_emb = util.get_emb(bert_model, bert_tokenizer, train_data, opt)
        eval_data_emb, _ = util.get_emb(bert_model, bert_tokenizer, eval_data, opt)
        acc1 = util.single_acc(train_data_emb, eval_data_emb)
        acc2 = util.proto_acc(train_rc_emb, eval_data_emb)
        acc3 = util.policy_acc(train_data_emb, eval_data_emb)
        logger.info("single eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (acc1[0], acc1[1], acc1[2]))
        logger.info("proto eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (acc2[0], acc2[1], acc2[2]))
        logger.info("policy eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (acc3[0], acc3[1], acc3[2]))
        with open('./data/train_emb_%s.json' % step, 'w', encoding='utf8') as f:
            json.dump(train_data_emb, f, ensure_ascii=False)
        with open('./data/train_rc_emb_%s.json' % step, 'w', encoding='utf8') as f:
            json.dump(train_rc_emb, f, ensure_ascii=False)
        with open('./data/eval_emb_%s.json' % step, 'w', encoding='utf8') as f:
            json.dump(eval_data_emb, f, ensure_ascii=False)

    if opt.do_predict:
        test_data = read_data(opt.train_file, opt.threshold, False)
        # predict proto emb or sentence emb
        state_dict = torch.load(opt.load_ckpt)
        own_state = bert_model.state_dict()
        for name, param in state_dict.items():
            name = name.replace('sentence_encoder.module.', '')
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        bert_model.eval()
        id_to_emd, root_cause_emb = util.get_emb(bert_model, bert_tokenizer, test_data, opt)
        if opt.save_emb and opt.save_root_emb:
            with open(opt.save_emb, 'w', encoding='utf8') as f:
                json.dump(id_to_emd, f, ensure_ascii=False)
            with open(opt.save_root_emb, 'w', encoding='utf8') as f:
                json.dump(root_cause_emb, f, ensure_ascii=False)
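# `util.proto_acc` above is repo-specific; a plausible reading, given the proto
# model and the root-cause embeddings it receives, is top-k accuracy of
# nearest-prototype retrieval. A sketch under that assumption (the function name,
# argument layout, and dot-product metric are all hypothetical):
import numpy as np

def proto_topk_acc(proto_emb, eval_emb, eval_labels, ks=(1, 3, 5)):
    """proto_emb: [C, d] class prototypes; eval_emb: [M, d]; eval_labels: [M]."""
    sims = eval_emb @ proto_emb.T                      # similarity matrix [M, C]
    ranking = np.argsort(-sims, axis=1)                # classes sorted best-first
    hits = ranking == np.asarray(eval_labels)[:, None] # [M, C] boolean hit matrix
    return [hits[:, :k].any(axis=1).mean() for k in ks]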
if len(sys.argv) > 3:
    K = int(sys.argv[3])

print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
print("Model: {}".format(model_name))

max_length = 40
train_data_loader = JSONFileDataLoader('./data/train.json', './data/glove.6B.50d.json', max_length=max_length)
val_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)
test_data_loader = JSONFileDataLoader('./data/test.json', './data/glove.6B.50d.json', max_length=max_length)

framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)

if model_name == 'proto':
    model = Proto(sentence_encoder)
    framework.train(model, model_name, 4, 20, N, K, 5)
elif model_name == 'gnn':
    model = GNN(sentence_encoder, N)
    framework.train(model, model_name, 2, N, N, K, 1, learning_rate=1e-3, weight_decay=0, optimizer=optim.Adam)
elif model_name == 'snail':
    print("HINT: SNAIL works only in PyTorch 0.3.1")
    model = SNAIL(sentence_encoder, N, K)
    framework.train(model, model_name, 25, N, N, K, 1, learning_rate=1e-2, weight_decay=0, optimizer=optim.SGD)
elif model_name == 'metanet':
    model = MetaNet(N, K, train_data_loader.word_vec_mat, max_length)
    framework.train(model, model_name, 1, N, N, K, 1, learning_rate=5e-3, weight_decay=0, optimizer=optim.Adam, train_iter=300000)
else:
    raise NotImplementedError
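# All of these scripts instantiate `Proto(sentence_encoder)`. The prototypical-
# network scoring rule itself (Snell et al., 2017) is simple enough to sketch:
# average each class's support embeddings into a prototype, then classify queries
# by negative squared L2 distance (or dot product, when the --dot flag is set) to
# the prototypes. A minimal illustration, not the repo's exact forward pass;
# shapes are noted inline:
import torch

def proto_logits(support_emb, query_emb, N, K, use_dot=False):
    """support_emb: [N*K, d] grouped by class; query_emb: [M, d] -> logits [M, N]."""
    prototypes = support_emb.view(N, K, -1).mean(dim=1)  # [N, d] class prototypes
    if use_dot:
        return query_emb @ prototypes.t()                # dot-product similarity
    return -torch.cdist(query_emb, prototypes).pow(2)    # negative squared L2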
def main():
    # Assumes the repo's usual imports: argparse, os, json, numpy as np, torch,
    # torch.optim as optim, plus the encoder/loader/framework/model classes below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q', default=5, type=int, help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter', default=20000, type=int, help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int, help='num of iters in validation')
    parser.add_argument('--test_iter', default=2000, type=int, help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int, help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder', default='cnn', help='encoder: cnn or bert')
    parser.add_argument('--max_length', default=128, type=int, help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int, help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size', default=230, type=int, help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--language', type=str, default='eng', help='language')
    parser.add_argument('--sup_cost', type=int, default=0, help='use sup classifier')
    opt = parser.parse_args()

    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    sup_cost = bool(opt.sup_cost)
    print(sup_cost)

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    # Chinese word vectors are 100-dimensional; English GloVe vectors are 50-dimensional.
    embsize = 50
    if opt.language == 'chn':
        embsize = 100

    if encoder_name == 'cnn':
        try:
            if opt.language == 'chn':
                glove_mat = np.load('./pretrain/chinese_emb/emb.npy')
                glove_word2id = json.load(open('./pretrain/chinese_emb/word2id.json'))
            else:
                glove_mat = np.load('./pretrain/glove/glove_mat.npy')
                glove_word2id = json.load(open('./pretrain/glove/glove_word2id.json'))
        except Exception:
            raise Exception("Cannot find glove files. Run glove/download_glove.sh to download glove files.")
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id, max_length, word_embedding_dim=embsize)
    elif encoder_name == 'bert':
        if opt.pair:
            if opt.language == 'chn':
                sentence_encoder = BERTPAIRSentenceEncoder('bert-base-chinese', max_length)  # or a local './pretrain/bert-base-uncased' copy
            else:
                sentence_encoder = BERTPAIRSentenceEncoder('bert-base-uncased', max_length)
        else:
            if opt.language == 'chn':
                sentence_encoder = BERTSentenceEncoder('bert-base-chinese', max_length)  # or a local './pretrain/bert-base-uncased' copy
            else:
                sentence_encoder = BERTSentenceEncoder('bert-base-uncased', max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.adv:
        adv_data_loader = get_loader_unsupervised(opt.adv, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader, adv_data_loader, adv=opt.adv, d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N, use_sup_cost=sup_cost)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length, use_sup_cost=sup_cost)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder, hidden_size=opt.hidden_size, dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = (encoder_name == 'bert')
        framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt, na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim, sup_cls=sup_cost)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    os.makedirs('logs', exist_ok=True)  # the log directory may not exist on a fresh clone
    with open('logs/' + ckpt.replace('checkpoint/', '') + '.txt', 'a') as wfile:
        wfile.write(str(N) + '\t' + str(K) + '\t' + str(acc * 100) + '\n')
    print("RESULT: %.2f" % (acc * 100))
def main():
    # Assumes the repo's usual imports: argparse, os, json, pickle, random,
    # numpy as np, torch, torch.optim as optim, plus the encoder/loader/framework
    # classes and the REGRAB model below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q', default=5, type=int, help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter', default=30000, type=int, help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int, help='num of iters in validation')
    parser.add_argument('--test_iter', default=10000, type=int, help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int, help='val after training how many iters')
    parser.add_argument('--model', default='regrab', help='model name')
    parser.add_argument('--encoder', default='bert', help='encoder: cnn or bert')
    parser.add_argument('--max_length', default=128, type=int, help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int, help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size', default=230, type=int, help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--eps', default=0.1, type=float, help='step size for SG-MCMC')
    parser.add_argument('--temp', default=10.0, type=float, help='temperature for softmax')
    parser.add_argument('--step', default=5, type=int, help='steps for SG-MCMC')
    parser.add_argument('--smp', default=10, type=int, help='samples for SG-MCMC')
    parser.add_argument('--ratio', default=0.01, type=float, help='decay ratio of step size for SG-MCMC')
    parser.add_argument('--wtp', default=0.1, type=float, help='weight of the prior term')
    parser.add_argument('--wtn', default=1.0, type=float, help='weight of the noise term')
    parser.add_argument('--wtb', default=0.0, type=float, help='weight of the background term')
    parser.add_argument('--metric', default='dot', help='similarity metric (dot or l2)')
    parser.add_argument('--seed', default=1234, type=int, help='random seed')
    opt = parser.parse_args()

    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    # Seed every RNG in play for reproducibility.
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(open('./pretrain/glove/glove_word2id.json'))
        except Exception:
            raise Exception("Cannot find glove files. Run glove/download_glove.sh to download glove files.")
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id, max_length)
    elif encoder_name == 'bert':
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder('./pretrain/bert-base-uncased', max_length)
        else:
            sentence_encoder = BERTSentenceEncoder('./pretrain/bert-base-uncased', max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.adv:
        adv_data_loader = get_loader_unsupervised(opt.adv, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader, adv_data_loader, adv=opt.adv, d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    # Loading relation embeddings.
    data = pickle.load(open('./data/embeddings.pkl', 'rb'))
    rellist = data['relations']
    relemb = data['embeddings']
    # Prepend a zero row so relation ids can start from 1 (row 0 acts as padding).
    array0 = np.zeros((1, relemb.shape[1]), dtype=relemb.dtype)
    relemb = np.concatenate([array0, relemb], axis=0)
    rel2id = dict([(rel, k + 1) for k, rel in enumerate(rellist)])

    # Loading relation graphs.
    with open('./data/graph.txt', 'r') as fi:
        us, vs, ws = [], [], []
        for line in fi:
            items = line.strip().split('\t')
            us += [rel2id[items[0]]]
            vs += [rel2id[items[1]]]
            ws += [float(items[2])]
        index = torch.LongTensor([us, vs])
        value = torch.Tensor(ws)
        shape = torch.Size([len(rel2id) + 1, len(rel2id) + 1])
        reladj = torch.sparse.FloatTensor(index, value, shape)
        if torch.cuda.is_available():
            reladj = reladj.cuda()
    # End

    model = REGRAB(sentence_encoder, hidden_size=opt.hidden_size, eps=opt.eps, temp=opt.temp, step=opt.step, smp=opt.smp, ratio=opt.ratio, wtp=opt.wtp, wtn=opt.wtn, wtb=opt.wtb, metric=opt.metric)
    model.set_relemb(rel2id, relemb)
    model.set_reladj(reladj)

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = (encoder_name == 'bert')
        framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt, na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
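# A note on the sparse adjacency built above: `torch.sparse.FloatTensor(index,
# value, shape)` is the legacy constructor, which recent PyTorch releases
# deprecate in favor of `torch.sparse_coo_tensor` with the same
# (indices, values, size) arguments. A drop-in equivalent at that point in the code:
# reladj = torch.sparse_coo_tensor(index, value, shape)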
train_data_loader = JSONFileDataLoader('./data/train.json', './data/glove.6B.50d.json', max_length=max_length)
val_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)
test_data_loader = JSONFileDataLoader('./data/test.json', './data/glove.6B.50d.json', max_length=max_length)

framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)

if model_name == 'proto':
    model = Proto(sentence_encoder)
    framework.train(model, model_name, 4, 20, N, K, 5, noise_rate=noise_rate)
elif model_name == 'proto_hatt':
    model = ProtoHATT(sentence_encoder, K)
    framework.train(model, model_name, 4, 20, N, K, 5, lr_step_size=5000, train_iter=15000, noise_rate=noise_rate)
else:
    raise NotImplementedError
val_data_loader = myJsonFileDataLoader('./data/val_data.json', './data/w2v.json', max_length=max_length)
test_data_loader = myJsonFileDataLoader('./data/test_data.json', './data/w2v.json', max_length=max_length)

framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)  # ,model)

if model_name == 'proto':
    model = Proto(sentence_encoder)
    framework.train(model, model_name, 4, 5, N, K, 5, train_iter=15000, noise_rate=noise_rate)
elif model_name == 'proto_hatt':
    model = ProtoHATT(sentence_encoder, K)
    framework.train(model, model_name, 4, 5, N, K, 5, noise_rate=noise_rate)
elif model_name == 'proto_idatt':
    model = ProtoIDATT(sentence_encoder, K)
    framework.train(model, model_name, 4, 5, N, K, 5, noise_rate=noise_rate)
else:
    raise NotImplementedError
if adv:
    d = Discriminator(hidden_size)
    framework = FewShotREFramework(train_data_loader, val_data_loader, None, adv_data_loader, adv, d)
    prefix += '-adv_' + adv
else:
    framework = FewShotREFramework(train_data_loader, val_data_loader, None)

framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=optim, load_ckpt=None, na_rate=na_rate, val_step=val_step, fp16=None, pair=None, train_iter=train_iter, val_iter=val_iter, bert_optim=bert_optim)
acc = framework.eval(model, batch_size, N, K, Q, test_iter, na_rate=na_rate,
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_data', help='train file')
    parser.add_argument('--val', default='val_data', help='val file')
    parser.add_argument('--test', default='test_data', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q', default=5, type=int, help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter', default=30000, type=int, help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int, help='num of iters in validation')
    parser.add_argument('--test_iter', default=1000, type=int, help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int, help='val after training how many iters')
    parser.add_argument('--repeat_test', default=8, type=int, help='repeat test stage')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder', default='bert', help='encoder: cnn or bert or roberta')
    parser.add_argument('--cross_modality', default='concate', help='cross-modality module')
    parser.add_argument('--max_length', default=128, type=int, help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int, help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / adamw')
    parser.add_argument('--hidden_size', default=230, type=int, help='hidden size')
    parser.add_argument('--multi_choose', default=1, type=int, help='choose multi random faces')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--root_data', default='./data', help='the root path stores data')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    # only for bert / roberta
    parser.add_argument('--pair', action='store_true', help='use pair model')
    parser.add_argument('--pretrain_ckpt', default=None, help='bert / roberta pre-trained checkpoint')
    parser.add_argument('--cat_entity_rep', action='store_true', help='concatenate entity representation as sentence rep')
    parser.add_argument('--use_img', action='store_true', help='use img info')
    # only for prototypical networks
    parser.add_argument('--dot', action='store_true', help='use dot instead of L2 distance for proto')
    parser.add_argument('--differ_scene', action='store_true', help='use face image in different scenes')
    opt = parser.parse_args()

    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Multimodal Social Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(open('./pretrain/glove/glove_word2id.json'))
        except Exception:
            raise Exception("Cannot find glove files. Run glove/download_glove.sh to download glove files.")
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id, max_length)
    elif encoder_name == 'bert':
        pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-chinese'
        sentence_encoder = BERTSentenceEncoder(pretrain_ckpt, max_length)
    elif encoder_name == 'roberta':
        pretrain_ckpt = opt.pretrain_ckpt or 'roberta-base'
        sentence_encoder = RobertaSentenceEncoder(pretrain_ckpt, max_length, cat_entity_rep=opt.cat_entity_rep)
    else:
        raise NotImplementedError

    train_data_loader = get_loader(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, root=opt.root_data, batch_size=batch_size, use_img=opt.use_img, differ_scene=opt.differ_scene, multi_choose=opt.multi_choose)
    val_data_loader = get_loader(opt.val, sentence_encoder, N=N, K=K, Q=Q, root=opt.root_data, batch_size=batch_size, use_img=opt.use_img, differ_scene=opt.differ_scene, multi_choose=opt.multi_choose)
    test_data_loader = get_loader(opt.test, sentence_encoder, N=N, K=K, Q=Q, root=opt.root_data, batch_size=batch_size, use_img=opt.use_img, differ_scene=opt.differ_scene, multi_choose=opt.multi_choose)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        from transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    if opt.differ_scene:
        prefix = '-'.join(['differ', model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    else:
        prefix = '-'.join(['same', model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.use_img:
        prefix = 'img_' + prefix
    if opt.cross_modality == 'transformer':
        prefix = 'transformer' + '-' + prefix
    if opt.multi_choose > 1:
        prefix = 'choose-' + str(opt.multi_choose) + '-' + prefix
    prefix = '0227' + prefix

    if model_name == 'proto':
        if opt.use_img:
            face_encoder = FacenetEncoder()
            multimodal_encoder = MultimodalEncoder(sentence_encoder=sentence_encoder, face_encoder=face_encoder, cross_modality=opt.cross_modality)
            model = Proto(multimodal_encoder, dot=opt.dot)
        else:
            model = Proto(sentence_encoder, dot=opt.dot)
    else:
        raise NotImplementedError

    # if not os.path.exists('checkpoint'):
    #     os.mkdir('checkpoint')
    # ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    checkpoint_dir = os.path.join(opt.root_data, 'checkpoint')
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    ckpt = os.path.join(checkpoint_dir, '{}.pth.tar'.format(prefix))
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = encoder_name in ['bert', 'roberta']
        framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt, na_rate=opt.na_rate, val_step=opt.val_step, grad_iter=opt.grad_iter, fp16=opt.fp16, pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim, multi_choose=opt.multi_choose)
    else:
        ckpt = opt.load_ckpt

    result_txt = os.path.join(checkpoint_dir, '{}.txt'.format(opt.root_data.split('/')[1] + '_' + prefix))
    with open(result_txt, 'a', encoding='utf-8') as f:
        for _ in range(opt.repeat_test):
            acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair, multi_choose=opt.multi_choose)
            print("RESULT: %.2f" % (acc * 100))
            f.write(str(acc * 100))
            f.write('\n')
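# `MultimodalEncoder` with cross_modality='concate' is specific to this repo; the
# obvious reading is concatenation of the text and face embeddings into one
# representation that Proto then scores. A minimal sketch under that assumption
# (class name, argument layout, and shapes are hypothetical):
import torch
import torch.nn as nn

class ConcatFusion(nn.Module):
    """Fuse a text encoder and a face encoder by concatenating their outputs."""
    def __init__(self, sentence_encoder, face_encoder):
        super().__init__()
        self.sentence_encoder = sentence_encoder
        self.face_encoder = face_encoder

    def forward(self, text_inputs, face_inputs):
        text_emb = self.sentence_encoder(text_inputs)   # [B, d_text]
        face_emb = self.face_encoder(face_inputs)       # [B, d_face]
        return torch.cat([text_emb, face_emb], dim=-1)  # [B, d_text + d_face]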
N = int(sys.argv[2])
if len(sys.argv) > 3:
    K = int(sys.argv[3])
if len(sys.argv) > 4:
    noise_rate = float(sys.argv[4])

print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
print("Model: {}".format(model_name))

max_length = 100
train_data_loader = JSONFileDataLoader('./data/train.json', './data/glove.6B.50d.json', max_length=max_length)
val_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)
test_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)  # note: evaluates on the val split

framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
sentence_encoder = TransformerSentenceEncoder(train_data_loader.word_vec_mat, max_length)

if model_name == 'proto':
    model = Proto(sentence_encoder)
    framework.train(model, model_name, 1, 10, N, K, 5, noise_rate=noise_rate)
else:
    raise NotImplementedError
def main():
    # Assumes the repo's usual imports: argparse, os, json, numpy as np, torch,
    # torch.optim as optim, plus the encoder/loader/framework/model classes below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default='train_wiki', help='train file')
    parser.add_argument('--val', default='val_wiki', help='val file')
    parser.add_argument('--test', default='test_wiki', help='test file')
    parser.add_argument('--adv', default=None, help='adv file')
    parser.add_argument('--trainN', default=10, type=int, help='N in train')
    parser.add_argument('--N', default=5, type=int, help='N way')
    parser.add_argument('--K', default=5, type=int, help='K shot')
    parser.add_argument('--Q', default=5, type=int, help='Num of query per class')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size')
    parser.add_argument('--train_iter', default=30000, type=int, help='num of iters in training')
    parser.add_argument('--val_iter', default=1000, type=int, help='num of iters in validation')
    parser.add_argument('--test_iter', default=3000, type=int, help='num of iters in testing')
    parser.add_argument('--val_step', default=2000, type=int, help='val after training how many iters')
    parser.add_argument('--model', default='proto', help='model name')
    parser.add_argument('--encoder', default='cnn', help='encoder: cnn or bert')
    parser.add_argument('--max_length', default=128, type=int, help='max length')
    parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')
    parser.add_argument('--na_rate', default=0, type=int, help='NA rate (NA = Q * na_rate)')
    parser.add_argument('--grad_iter', default=1, type=int, help='accumulate gradient every x iterations')
    parser.add_argument('--optim', default='sgd', help='sgd / adam / bert_adam')
    parser.add_argument('--hidden_size', default=230, type=int, help='hidden size')
    parser.add_argument('--load_ckpt', default=None, help='load ckpt')
    parser.add_argument('--save_ckpt', default=None, help='save ckpt')
    parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true', help='only test')
    parser.add_argument('--pair', action='store_true', help='use pair model')
    opt = parser.parse_args()

    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    if encoder_name == 'cnn':
        try:
            glove_mat = np.load('./pretrain/glove/glove_mat.npy')
            glove_word2id = json.load(open('./pretrain/glove/glove_word2id.json'))
        except Exception:
            raise Exception("Cannot find glove files. Run glove/download_glove.sh to download glove files.")
        sentence_encoder = CNNSentenceEncoder(glove_mat, glove_word2id, max_length)
    elif encoder_name == 'bert':
        if opt.pair:
            sentence_encoder = BERTPAIRSentenceEncoder('./pretrain/bert-base-uncased', max_length)
        else:
            sentence_encoder = BERTSentenceEncoder('./pretrain/bert-base-uncased', max_length)
    else:
        raise NotImplementedError

    if opt.pair:
        train_data_loader = get_loader_pair(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader_pair(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader_pair(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
    else:
        train_data_loader = get_loader(opt.train, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        val_data_loader = get_loader(opt.val, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)
        test_data_loader = get_loader(opt.test, sentence_encoder, N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.adv:
        adv_data_loader = get_loader_unsupervised(opt.adv, sentence_encoder, N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size)

    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'bert_adam':
        from pytorch_transformers import AdamW
        pytorch_optim = AdamW
    else:
        raise NotImplementedError

    if opt.adv:
        d = Discriminator(opt.hidden_size)
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader, adv_data_loader, adv=opt.adv, d=d)
    else:
        framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.adv is not None:
        prefix += '-adv_' + opt.adv
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)

    if model_name == 'proto':
        model = Proto(sentence_encoder, hidden_size=opt.hidden_size)
    elif model_name == 'gnn':
        model = GNN(sentence_encoder, N)
    elif model_name == 'snail':
        print("HINT: SNAIL works only in PyTorch 0.3.1")
        model = SNAIL(sentence_encoder, N, K)
    elif model_name == 'metanet':
        model = MetaNet(N, K, sentence_encoder.embedding, max_length)
    elif model_name == 'siamese':
        model = Siamese(sentence_encoder, hidden_size=opt.hidden_size, dropout=opt.dropout)
    elif model_name == 'pair':
        model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    else:
        raise NotImplementedError

    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        bert_optim = (encoder_name == 'bert')
        framework.train(model, prefix, batch_size, trainN, N, K, Q, pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt, na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim)
    else:
        ckpt = opt.load_ckpt

    # Average over several evaluation rounds and report a 95% confidence interval
    # (1.96 * std / sqrt(n), i.e. a normal approximation over the round accuracies).
    acc = 0
    his_acc = []
    total_test_round = 5
    for i in range(total_test_round):
        cur_acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
        his_acc.append(cur_acc)
        acc += cur_acc
    acc /= total_test_round
    nhis_acc = np.array(his_acc)
    error = nhis_acc.std() * 1.96 / (nhis_acc.shape[0] ** 0.5)
    print("RESULT: %.2f\\pm%.2f" % (acc * 100, error * 100))

    with open('./result.txt', 'a+') as result_file:
        result_file.write("test data: %12s | model: %45s | acc: %.6f | error: %.6f\n" % (opt.test, prefix, acc, error))
    with open('./result_detail.txt', 'a+') as result_file:
        result_detail = {'test': opt.test, 'model': prefix, 'acc': acc, 'his': his_acc}
        result_file.write("%s\n" % (json.dumps(result_detail)))
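# The detail log written above is one JSON object per line, so it can be read back
# for later analysis with a few lines (path as written in the snippet; the helper
# name is illustrative):
import json

def load_results(path='./result_detail.txt'):
    """Return the list of result records appended by the script above."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]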