def main():
    """Evaluate a DNEE model on MCNC (multiple-choice narrative cloze)
    questions and print overall accuracy.

    Reads everything from the module-level ``args`` namespace: training
    config, relation config, model checkpoint, and the pickled questions.
    """
    with open(args.training_config, 'r') as f:  # close the handle (was leaked)
        config = json.load(f)
    indices.set_relation_classes(args.relation_config)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_encoder = create_argw_encoder(config, args.device)
    if args.encoder_file:
        argw_encoder.load(args.encoder_file)
    logging.info("model class: " + config['model_type'])
    # NOTE(review): eval() on a config value executes arbitrary code — only
    # safe with trusted config files.
    ModelClass = eval(config['model_type'])
    model = ModelClass(config, argw_encoder, n_preds,
                       args.device).to(args.device)
    # map_location keeps GPU-trained weights loadable on CPU-only hosts.
    model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))
    # FIX: pickle files must be opened in binary mode ('rb'); text mode
    # raises UnicodeDecodeError on Python 3.
    with open(args.question_file, 'rb') as f:
        questions = pkl.load(f)
    logging.info("#questions={}".format(len(questions)))
    n_correct, n_incorrect = 0, 0
    rtype = indices.REL2IDX[indices.REL_CONTEXT] \
        if args.context_rel else indices.REL2IDX[indices.REL_COREF]
    rtype = torch.LongTensor([rtype]).to(args.device)
    widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
    bar = progressbar.ProgressBar(widgets=widgets,
                                  maxval=len(questions)).start()
    logging.info("batch_size = {}".format(args.batch_size))
    batch_size = args.batch_size
    # FIX: avoid a trailing empty batch when len(questions) is an exact
    # multiple of batch_size (matches the MCNS/MCNE script's batching).
    n_batches = len(questions) // batch_size
    if len(questions) % batch_size != 0:
        n_batches += 1
    logging.info("n_batches = {}".format(n_batches))
    i_q = 0
    for i_batch in range(n_batches):
        batch_questions = questions[i_batch * batch_size:
                                    (i_batch + 1) * batch_size]
        # Embeddings are built once per batch and shared by every question.
        e2idx, embeddings = build_embeddings(model, batch_questions, config,
                                             pred2idx, argw2idx, rtype)
        for q in batch_questions:
            pred = intrinsic.predict_mcnc(model, q, e2idx, embeddings, rtype,
                                          args.device)
            if pred == q.ans_idx:
                n_correct += 1
            else:
                n_incorrect += 1
            i_q += 1
            bar.update(i_q)
    bar.finish()
    print("n_correct={}, n_incorrect={}".format(n_correct, n_incorrect))
    print("accuracy={}".format(float(n_correct) / (n_correct + n_incorrect)))
def main():
    """Evaluate DNEE-augmented discourse-relation classification on the
    dev, test, and blind splits, writing one JSON result file per split.

    Reads paths and hyper-parameters from the module-level ``args``.
    """
    # DNEE
    t1 = time.time()
    indices.set_relation_classes(args.relation_config)
    with open(args.training_config, 'r') as f:  # close the handle (was leaked)
        config = json.load(f)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_encoder = create_argw_encoder(config, args.device)
    argw_encoder.load(args.encoder_file)
    logging.info("model class: " + config['model_type'])
    # NOTE(review): eval() on a config value executes arbitrary code — only
    # safe with trusted config files.
    ModelClass = eval(config['model_type'])
    dnee_model = ModelClass(config, argw_encoder, n_preds,
                            args.device).to(args.device)
    # map_location keeps GPU-trained weights loadable on CPU-only hosts.
    dnee_model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))
    logging.info('Loading DNEE: {} s'.format(time.time() - t1))
    elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0)
    # These are the max seq lengths from the training data
    # (ds.DsDataset(...).seq_len / .dnee_seq_len); hardcoded here to avoid
    # loading the training set just for two numbers.
    seq_len, event_seq_len = 392, 14
    logging.info("seq_len={}, event_seq_len={}".format(seq_len,
                                                       event_seq_len))
    # FIX: the dev/test/blind evaluations were three copy-pasted stanzas;
    # fold them into one loop. Log output is byte-identical to the original.
    splits = (('dev', args.ds_dev_file, 'dev_res.json'),
              ('test', args.ds_test_file, 'test_res.json'),
              ('blind', args.ds_blind_file, 'blind_res.json'))
    for split_name, in_file, out_name in splits:
        t1 = time.time()
        logging.info('{}...'.format(split_name))
        fw_path = os.path.join(args.output_folder, out_name)
        eval_relations(in_file, fw_path, elmo, dnee_model, config, pred2idx,
                       argw2idx, seq_len, event_seq_len)
        logging.info('Eval {}: {} s'.format(split_name.upper(),
                                            time.time() - t1))
def main():
    """Run the four relation/next-event prediction evaluations and log the
    accuracy of each (both sklearn's accuracy_score and the local acc()).

    Reads paths from the module-level ``args``.
    """
    with open(args.training_config, 'r') as f:  # close handles (were leaked)
        config = json.load(f)
    with open(args.relation_config, 'r') as f:
        relation_config = json.load(f)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_encoder = create_argw_encoder(config, args.device)
    if args.encoder_file:
        argw_encoder.load(args.encoder_file)
    logging.info("model class: " + config['model_type'])
    # NOTE(review): eval() on a config value executes arbitrary code — only
    # safe with trusted config files.
    ModelClass = eval(config['model_type'])
    model = ModelClass(config, argw_encoder, n_preds,
                       args.device).to(args.device)
    model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))
    # FIX: pickle files must be opened in binary mode ('rb'); text mode
    # raises UnicodeDecodeError on Python 3.
    with open(args.question_file, 'rb') as f:
        questions = pkl.load(f)
    logging.info("#questions={}".format(len(questions)))
    # FIX: four copy-pasted evaluation stanzas folded into one loop; the
    # logged text is byte-identical to the original. Accuracy is logged
    # twice on purpose — once via sklearn's accuracy_score and once via the
    # local acc() — presumably as a cross-check (TODO confirm).
    evaluations = (
        ('predict relation', eval_by_events,
         "predict relation, accuracy={}"),
        ('predict next event', eval_by_rel,
         "predict next event, accuracy={}"),
        ('predict next event by random rel', eval_by_random_rel,
         "by random rel, accuracy={}"),
        ('predict next event by next rel', eval_by_next_rel,
         "by next rel, accuracy={}"),
    )
    for header, eval_fn, fmt in evaluations:
        logging.info(header)
        y, y_pred = eval_fn(questions, model, config, pred2idx, argw2idx,
                            relation_config)
        logging.info(fmt.format(accuracy_score(y, y_pred)))
        logging.info(fmt.format(acc(y, y_pred)))
def main():
    """Threshold-based per-relation-category evaluation, averaged over
    several random sampling rounds.

    Each round samples dev questions to pick a per-category decision
    threshold, then scores sampled test questions against it. Works in two
    modes selected by args.use_elmo: ELMo sentence embeddings or a trained
    DNEE event model.
    """
    config = json.load(open(args.training_config, 'r'))
    indices.set_relation_classes(args.relation_config)
    if args.use_elmo:
        logging.info("using ELMo")
        elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1,
                    dropout=0).to(args.device)
    else:
        # DNEE path: load indices, the argument-word encoder, and the
        # trained event model checkpoint.
        pred2idx, idx2pred, _ = indices.load_predicates(
            config['predicate_indices'])
        argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
        n_preds = len(pred2idx)
        argw_vocabs = argw2idx.keys()
        argw_encoder = create_argw_encoder(config, args.device)
        if args.encoder_file:
            argw_encoder.load(args.encoder_file)
        logging.info("model class: " + config['model_type'])
        # NOTE(review): eval() on a config value executes arbitrary code —
        # only safe with trusted config files.
        ModelClass = eval(config['model_type'])
        dnee_model = ModelClass(config, argw_encoder, n_preds,
                                args.device).to(args.device)
        dnee_model.load_state_dict(
            torch.load(args.model_file,
                       map_location=lambda storage, location: storage))
    model = elmo if args.use_elmo else dnee_model
    # One row per round, one column per relation category.
    results = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                          dtype=torch.float32)
    precisions = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                             dtype=torch.float32)
    recalls = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                          dtype=torch.float32)
    f1s = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                      dtype=torch.float32)
    for i_round in range(args.n_rounds):
        logging.info("ROUND {}".format(i_round))
        # dev: sample questions and tune a per-category threshold on them.
        dev_questions = sample_questions(args.dev_question_file,
                                         n_cat_questions=500)
        if args.use_elmo:
            dev_e2idx, dev_ev_embeddings = build_elmo(dev_questions, elmo)
        else:
            dev_e2idx, dev_ev_embeddings = build_ev_embeddings(
                dev_questions, config, pred2idx, argw2idx, dnee_model)
        thresholds = dev_thresholds(dev_questions, model, dev_e2idx,
                                    dev_ev_embeddings, args.step_size)
        # test results: fresh sample, scored against the dev thresholds.
        test_questions = sample_questions(args.test_question_file,
                                          n_cat_questions=500)
        if args.use_elmo:
            test_e2idx, test_ev_embeddings = build_elmo(test_questions, elmo)
        else:
            test_e2idx, test_ev_embeddings = build_ev_embeddings(
                test_questions, config, pred2idx, argw2idx, dnee_model)
        test_scores = score_questions(test_questions, model, test_e2idx,
                                      test_ev_embeddings)
        for i_cat in test_questions.keys():
            y = torch.LongTensor([label for label, q in test_questions[i_cat]
                                  ]).to(args.device)
            acc = pred_acc(y, test_scores[i_cat], thresholds[i_cat])
            # Threshold direction flips between modes: higher-is-positive
            # for ELMo scores, lower-is-positive for DNEE — presumably
            # because DNEE scores are distances (TODO confirm).
            if args.use_elmo:
                y_preds = (test_scores[i_cat] > thresholds[i_cat]).type(
                    torch.int64)
            else:
                y_preds = (test_scores[i_cat] < thresholds[i_cat]).type(
                    torch.int64)
            y = y.detach().cpu().numpy()
            y_preds = y_preds.detach().cpu().numpy()
            prec = metrics.precision_score(y, y_preds)
            rec = metrics.recall_score(y, y_preds)
            f1 = metrics.f1_score(y, y_preds)
            logging.info("i_cat={} ({}), test_acc={}".format(
                i_cat, indices.IDX2REL[i_cat], acc))
            logging.info("i_cat={} ({}), test_prec={}".format(
                i_cat, indices.IDX2REL[i_cat], prec))
            logging.info("i_cat={} ({}), test_rec={}".format(
                i_cat, indices.IDX2REL[i_cat], rec))
            results[i_round][i_cat] = acc
            precisions[i_round][i_cat] = prec
            recalls[i_round][i_cat] = rec
            f1s[i_round][i_cat] = f1
    # Average each metric over the rounds and report per category.
    avg = torch.mean(results, dim=0)
    avg_precisions = torch.mean(precisions, dim=0)
    avg_recalls = torch.mean(recalls, dim=0)
    avg_f1s = torch.mean(f1s, dim=0)
    for i_cat in test_questions.keys():
        logging.info("i_cat={} ({}), avg_test_acc={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg[i_cat], args.n_rounds))
        logging.info("i_cat={} ({}), avg_test_prec={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg_precisions[i_cat],
            args.n_rounds))
        logging.info("i_cat={} ({}), avg_test_rec={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg_recalls[i_cat],
            args.n_rounds))
        logging.info("i_cat={} ({}), avg_test_f1={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg_f1s[i_cat], args.n_rounds))
def main():
    """Train the AttentionNN discourse-relation classifier, optionally
    augmented with features from a frozen, pre-trained DNEE event model.

    Evaluates on the dev set after every batch, checkpoints the model with
    the best dev F1, and dumps loss / dev-F1 curves as pickles.
    """
    logging.info('using {} for computation.'.format(args.device))
    config = json.load(open(args.training_config, 'r'))
    indices.set_relation_classes(args.relation_config)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_encoder = create_argw_encoder(config, args.device)
    argw_encoder.load(args.encoder_file)
    logging.info("model class: " + config['model_type'])
    # NOTE(review): eval() on a config value executes arbitrary code — only
    # safe with trusted config files.
    ModelClass = eval(config['model_type'])
    dnee_model = ModelClass(config, argw_encoder, n_preds,
                            args.device).to(args.device)
    dnee_model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))
    # The DNEE model is a frozen feature extractor here, hence eval().
    dnee_model.eval()
    elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0)
    train_data = ds.DsDataset(args.ds_train_fld, args.dnee_train_fld)
    # Dev relations: keep only non-explicit relations whose top sense is a
    # known discourse relation class.
    dev_rels = [json.loads(line) for line in open(args.ds_dev_rel_file)]
    dev_rels = [
        rel for rel in dev_rels
        if rel['Type'] != 'Explicit'
        and rel['Sense'][0] in indices.DISCOURSE_REL2IDX
    ]
    dev_cm, dev_valid_senses = ds.create_cm(dev_rels,
                                            indices.DISCOURSE_REL2IDX)
    # DNEE features are optional; everything downstream switches on
    # args.dnee_train_fld being set.
    dnee_seq_len = train_data.dnee_seq_len if args.dnee_train_fld else None
    x_dev, y_dev = ds.get_features(dev_rels, elmo, train_data.seq_len,
                                   dnee_model, dnee_seq_len, config,
                                   pred2idx, argw2idx,
                                   indices.DISCOURSE_REL2IDX,
                                   device=args.device,
                                   use_dnee=(args.dnee_train_fld is not None))
    x0_dev, x1_dev, x0_dnee_dev, x1_dnee_dev, x_dnee_dev = x_dev
    x0_dev, x1_dev = x0_dev.to(args.device), x1_dev.to(args.device)
    model = ds.AttentionNN(len(indices.DISCOURSE_REL2IDX),
                           event_dim=config['event_dim'],
                           dropout=args.dropout,
                           use_event=(args.dnee_train_fld is not None),
                           use_dnee_scores=not args.no_dnee_scores).to(
                               args.device)
    optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.learning_rate)
    logging.info("initial learning rate = {}".format(args.learning_rate))
    logging.info("dropout rate = {}".format(args.dropout))
    # arg_lens = [config['arg0_max_len'], config['arg1_max_len']]
    losses = []
    dev_f1s = []
    best_dev_f1, best_epoch, best_batch = -1, -1, -1
    logging.info("batch_size = {}".format(args.batch_size))
    for i_epoch in range(args.n_epoches):
        train_loader = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        epoch_start = time.time()
        for i_batch, (x, y) in enumerate(train_loader):
            if y.shape[0] != args.batch_size:
                # skip the last batch
                continue
            if args.dnee_train_fld:
                x0, x1, x0_dnee, x1_dnee, x_dnee = x
                x0 = x0.to(args.device)
                x1 = x1.to(args.device)
                x0_dnee = x0_dnee.to(args.device)
                x1_dnee = x1_dnee.to(args.device)
                x_dnee = x_dnee.to(args.device)
            else:
                x0, x1 = x
                x0 = x0.to(args.device)
                x1 = x1.to(args.device)
                x0_dnee, x1_dnee, x_dnee = None, None, None
            y = y.squeeze().to(args.device)
            model.train()
            optimizer.zero_grad()
            out = model(x0, x1, x0_dnee, x1_dnee, x_dnee)
            loss = model.loss_func(out, y)
            # step
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            # NOTE(review): the full dev set is scored after *every* batch;
            # results are only logged every n_batches_per_record batches,
            # but the best-F1 checkpoint check below needs the per-batch
            # dev score.
            model.eval()
            y_pred = model.predict(x0_dev, x1_dev, x0_dnee_dev, x1_dnee_dev,
                                   x_dnee_dev)
            dev_prec, dev_recall, dev_f1 = ds.scoring_cm(
                y_dev, y_pred.cpu(), dev_cm, dev_valid_senses,
                indices.DISCOURSE_IDX2REL)
            dev_f1s.append(dev_f1)
            ##
            if i_batch % config['n_batches_per_record'] == 0:
                logging.info("{}, {}: loss={}, time={}".format(
                    i_epoch, i_batch, loss.item(),
                    time.time() - epoch_start))
                logging.info("dev: prec={}, recall={}, f1={}".format(
                    dev_prec, dev_recall, dev_f1))
            if dev_f1 > best_dev_f1:
                logging.info("best dev: prec={}, recall={}, f1={}".format(
                    dev_prec, dev_recall, dev_f1))
                best_dev_f1 = dev_f1
                best_epoch = i_epoch
                best_batch = i_batch
                fpath = os.path.join(args.output_folder, 'best_model.pt')
                torch.save(model.state_dict(), fpath)
    # Final summary and training-curve dumps.
    logging.info("{}-{}: best dev f1 = {}".format(best_epoch, best_batch,
                                                  best_dev_f1))
    fpath = os.path.join(args.output_folder, "losses.pkl")
    pkl.dump(losses, open(fpath, 'wb'))
    fpath = os.path.join(args.output_folder, "dev_f1s.pkl")
    pkl.dump(dev_f1s, open(fpath, 'wb'))
def main():
    """Evaluate MCNS and MCNE (narrative chain inference) with both word
    embeddings (WE) and event embeddings (EV), printing and logging the
    accuracy of every combination.
    """
    with open(args.training_config, 'r') as f:  # close the handle (was leaked)
        config = json.load(f)
    indices.set_relation_classes(args.relation_config)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_encoder = create_argw_encoder(config, args.device)
    if args.encoder_file:
        argw_encoder.load(args.encoder_file)
    logging.info("model class: " + config['model_type'])
    # NOTE(review): eval() on a config value executes arbitrary code — only
    # safe with trusted config files.
    ModelClass = eval(config['model_type'])
    model = ModelClass(config, argw_encoder, n_preds,
                       args.device).to(args.device)
    model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))
    we = utils.load_word_embeddings(args.word_embedding_file, use_torch=True)
    # FIX: pickle files must be opened in binary mode ('rb'); text mode
    # raises UnicodeDecodeError on Python 3.
    with open(args.question_file, 'rb') as f:
        questions = pkl.load(f)
    if not args.no_subsample:
        # Deterministic subsample of the first 10k questions (an earlier
        # random-shuffle variant was abandoned).
        questions = questions[:10000]
    logging.info("#questions={}".format(len(questions)))
    rtype = indices.REL2IDX[indices.REL_CONTEXT] \
        if args.context_rel else indices.REL2IDX[indices.REL_COREF]
    rtype = torch.LongTensor([rtype]).to(args.device)
    widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
    bar = progressbar.ProgressBar(widgets=widgets,
                                  maxval=len(questions)).start()
    logging.info("batch_size = {}".format(args.batch_size))
    batch_size = args.batch_size
    n_batches = len(questions) // batch_size + 1 if len(
        questions) % batch_size != 0 else len(questions) // batch_size
    logging.info("#questions = {}".format(len(questions)))
    logging.info("n_batches = {}".format(n_batches))
    i_q = 0
    # results[embedding (we/ev/mix)][inference (mcns/mcne)][outcome]
    WE_IDX, EV_IDX, MIX_IDX = 0, 1, 2
    MCNS_IDX, MCNE_IDX = 0, 1
    INCORRECT_IDX, CORRECT_IDX = 0, 1
    results = torch.zeros((3, 2, 2), dtype=torch.int64)

    def _tally(preds_pair, mode_idx, q, n_q):
        """Count correct/incorrect slot predictions for WE and EV preds."""
        for emb_idx, preds in zip((WE_IDX, EV_IDX), preds_pair):
            for i in range(n_q):
                if preds[i] == q.ans_idxs[i]:
                    results[emb_idx][mode_idx][CORRECT_IDX] += 1
                else:
                    results[emb_idx][mode_idx][INCORRECT_IDX] += 1

    for i_batch in range(n_batches):
        batch_questions = questions[i_batch * batch_size:
                                    (i_batch + 1) * batch_size]
        e2idx, ev_embeddings, w_embeddings = build_embeddings(
            model, batch_questions, config, pred2idx, argw2idx, rtype, we)
        for q in batch_questions:
            # when calculating the accuracy, we only consider the questions
            # in the middle so that MCNS and MCNE can have a fair comparison
            n_q = len(q.ans_idxs) - 1
            _tally(
                intrinsic.predict_mcns(model, q, e2idx, ev_embeddings,
                                       w_embeddings, rtype, args.device,
                                       args.inference_model),
                MCNS_IDX, q, n_q)
            _tally(
                intrinsic.predict_mcne(model, q, e2idx, ev_embeddings,
                                       w_embeddings, rtype, args.device,
                                       args.inference_model),
                MCNE_IDX, q, n_q)
            i_q += 1
            bar.update(i_q)
    bar.finish()
    results = results.type(torch.float32)

    def _accuracy(emb_idx, mode_idx):
        """Accuracy for one (embedding, inference-mode) cell."""
        correct = results[emb_idx][mode_idx][CORRECT_IDX]
        incorrect = results[emb_idx][mode_idx][INCORRECT_IDX]
        return correct / (correct + incorrect)

    # FIX: the original repeated the same accuracy report eight times
    # (print and logging.info for every cell); this loop emits byte-identical
    # text. The MIX rows were commented out in the original and are dropped.
    for emit in (print, logging.info):
        for mode_idx, mode_name in ((MCNS_IDX, "MCNS"), (MCNE_IDX, "MCNE")):
            emit("{}:".format(mode_name))
            for emb_idx, emb_name in ((WE_IDX, "WE"), (EV_IDX, "EV")):
                emit("\t{}:".format(emb_name))
                emit("\t\taccuracy={}".format(_accuracy(emb_idx, mode_idx)))