def main():
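    """Evaluate a trained event model on MCNC (multiple-choice narrative cloze)
    questions and report overall accuracy."""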
    config = json.load(open(args.training_config, 'r'))
    indices.set_relation_classes(args.relation_config)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_vocabs = argw2idx.keys()
    argw_encoder = create_argw_encoder(config, args.device)
    if args.encoder_file:
        argw_encoder.load(args.encoder_file)

    logging.info("model class: " + config['model_type'])
    ModelClass = eval(config['model_type'])
    model = ModelClass(config, argw_encoder, n_preds,
                       args.device).to(args.device)
    model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))

    questions = pkl.load(open(args.question_file, 'rb'))
    logging.info("#questions={}".format(len(questions)))

    n_correct, n_incorrect = 0, 0
    rel = indices.REL_CONTEXT if args.context_rel else indices.REL_COREF
    rtype = torch.LongTensor([indices.REL2IDX[rel]]).to(args.device)

    widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
    bar = progressbar.ProgressBar(widgets=widgets,
                                  maxval=len(questions)).start()

    logging.info("batch_size = {}".format(args.batch_size))
    batch_size = args.batch_size
    # ceil division: include a final partial batch without adding an empty one
    n_batches = (len(questions) + batch_size - 1) // batch_size
    logging.info("n_batches = {}".format(n_batches))
    i_q = 0
    for i_batch in range(n_batches):
        batch_questions = questions[i_batch * batch_size:(i_batch + 1) *
                                    batch_size]
        e2idx, embeddings = build_embeddings(model, batch_questions, config,
                                             pred2idx, argw2idx, rtype)
        for q in batch_questions:
            pred = intrinsic.predict_mcnc(model, q, e2idx, embeddings, rtype,
                                          args.device)
            if pred == q.ans_idx:
                n_correct += 1
            else:
                n_incorrect += 1
            i_q += 1
            bar.update(i_q)
    bar.finish()
    print("n_correct={}, n_incorrect={}".format(n_correct, n_incorrect))
    print("accuracy={}".format(float(n_correct) / (n_correct + n_incorrect)))
def main():
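    """Evaluate DNEE-based discourse relation prediction on the dev, test, and
    blind splits, writing per-split results to the output folder."""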
    # DNEE
    t1 = time.time()
    indices.set_relation_classes(args.relation_config)
    config = json.load(open(args.training_config, 'r'))
    pred2idx, idx2pred, _ = indices.load_predicates(config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_vocabs = argw2idx.keys()
    argw_encoder = create_argw_encoder(config, args.device)
    argw_encoder.load(args.encoder_file)
    
    logging.info("model class: " + config['model_type'])
    ModelClass = eval(config['model_type'])
    dnee_model = ModelClass(config, argw_encoder, n_preds, args.device).to(args.device)
    dnee_model.load_state_dict(torch.load(args.model_file,
                                    map_location=lambda storage, location: storage))
    logging.info('Loading DNEE: {} s'.format(time.time()-t1))
    
    elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0)
    
    # DS_TRAIN_FLD = "ds_train_elmo"
    # DNEE_TRAIN_FLD = "ds_train_transe" if config['model_type'] == 'EventTransE' else 'ds_train_transr_tmp'
    # train_data = ds.DsDataset(DS_TRAIN_FLD, DNEE_TRAIN_FLD)
    # logging.info("DNEE_TRAIN_FLD={}".format(DNEE_TRAIN_FLD))
    # seq_len = train_data.seq_len
    # event_seq_len = train_data.dnee_seq_len
    
    # These are the max seq length from training data (above code)
    # We hardcode them to avoid loading training data
    seq_len, event_seq_len = 392, 14
    logging.info("seq_len={}, event_seq_len={}".format(seq_len, event_seq_len))
    
    t1 = time.time()
    logging.info('dev...')
    fw_path = os.path.join(args.output_folder, 'dev_res.json')
    eval_relations(args.ds_dev_file, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, event_seq_len)
    logging.info('Eval DEV: {} s'.format(time.time()-t1))
    
    t1 = time.time()
    logging.info('test...')
    fw_path = os.path.join(args.output_folder, 'test_res.json')
    eval_relations(args.ds_test_file, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, event_seq_len)
    logging.info('Eval TEST: {} s'.format(time.time()-t1))
    
    t1 = time.time()
    logging.info('blind...')
    fw_path = os.path.join(args.output_folder, 'blind_res.json')
    eval_relations(args.ds_blind_file, fw_path, elmo, dnee_model, config, pred2idx, argw2idx, seq_len, event_seq_len)
    logging.info('Eval BLIND: {} s'.format(time.time()-t1))
def main():
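    """Evaluate relation prediction and next-event prediction (by given,
    random, and next relation) on the question set."""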
    config = json.load(open(args.training_config, 'r'))
    relation_config = json.load(open(args.relation_config, 'r'))
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_vocabs = argw2idx.keys()
    argw_encoder = create_argw_encoder(config, args.device)
    if args.encoder_file:
        argw_encoder.load(args.encoder_file)

    logging.info("model class: " + config['model_type'])
    ModelClass = eval(config['model_type'])
    model = ModelClass(config, argw_encoder, n_preds,
                       args.device).to(args.device)
    model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))

    questions = pkl.load(open(args.question_file, 'rb'))
    logging.info("#questions={}".format(len(questions)))

    logging.info('predict relation')
    y, y_pred = eval_by_events(questions, model, config, pred2idx, argw2idx,
                               relation_config)
    logging.info("predict relation, accuracy={}".format(
        accuracy_score(y, y_pred)))
    logging.info("predict relation, accuracy={}".format(acc(y, y_pred)))

    logging.info('predict next event')
    y, y_pred = eval_by_rel(questions, model, config, pred2idx, argw2idx,
                            relation_config)
    logging.info("predict next event, accuracy={}".format(
        accuracy_score(y, y_pred)))
    logging.info("predict next event, accuracy={}".format(acc(y, y_pred)))

    logging.info('predict next event by random rel')
    y, y_pred = eval_by_random_rel(questions, model, config, pred2idx,
                                   argw2idx, relation_config)
    logging.info("by random rel, accuracy={}".format(accuracy_score(y,
                                                                    y_pred)))
    logging.info("by random rel, accuracy={}".format(acc(y, y_pred)))

    logging.info('predict next event by next rel')
    y, y_pred = eval_by_next_rel(questions, model, config, pred2idx, argw2idx,
                                 relation_config)
    logging.info("by next rel, accuracy={}".format(accuracy_score(y, y_pred)))
    logging.info("by next rel, accuracy={}".format(acc(y, y_pred)))
def main():
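    """Tune per-relation decision thresholds on sampled dev questions, then
    report test accuracy, precision, recall, and F1 averaged over
    args.n_rounds rounds, using either ELMo or DNEE event embeddings."""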
    config = json.load(open(args.training_config, 'r'))
    indices.set_relation_classes(args.relation_config)

    if args.use_elmo:
        logging.info("using ELMo")
        elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1,
                    dropout=0).to(args.device)
    else:
        pred2idx, idx2pred, _ = indices.load_predicates(
            config['predicate_indices'])
        argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
        n_preds = len(pred2idx)
        argw_vocabs = argw2idx.keys()
        argw_encoder = create_argw_encoder(config, args.device)
        if args.encoder_file:
            argw_encoder.load(args.encoder_file)

        logging.info("model class: " + config['model_type'])
        ModelClass = eval(config['model_type'])
        dnee_model = ModelClass(config, argw_encoder, n_preds,
                                args.device).to(args.device)
        dnee_model.load_state_dict(
            torch.load(args.model_file,
                       map_location=lambda storage, location: storage))
    model = elmo if args.use_elmo else dnee_model
    results = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                          dtype=torch.float32)
    precisions = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                             dtype=torch.float32)
    recalls = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                          dtype=torch.float32)
    f1s = torch.zeros((args.n_rounds, len(indices.REL2IDX)),
                      dtype=torch.float32)
    for i_round in range(args.n_rounds):
        logging.info("ROUND {}".format(i_round))

        # dev
        dev_questions = sample_questions(args.dev_question_file,
                                         n_cat_questions=500)
        if args.use_elmo:
            dev_e2idx, dev_ev_embeddings = build_elmo(dev_questions, elmo)
        else:
            dev_e2idx, dev_ev_embeddings = build_ev_embeddings(
                dev_questions, config, pred2idx, argw2idx, dnee_model)
        thresholds = dev_thresholds(dev_questions, model, dev_e2idx,
                                    dev_ev_embeddings, args.step_size)

        # test results
        test_questions = sample_questions(args.test_question_file,
                                          n_cat_questions=500)
        if args.use_elmo:
            test_e2idx, test_ev_embeddings = build_elmo(test_questions, elmo)
        else:
            test_e2idx, test_ev_embeddings = build_ev_embeddings(
                test_questions, config, pred2idx, argw2idx, dnee_model)
        test_scores = score_questions(test_questions, model, test_e2idx,
                                      test_ev_embeddings)
        for i_cat in test_questions.keys():
            y = torch.LongTensor(
                [label for label, q in test_questions[i_cat]]).to(args.device)
            acc = pred_acc(y, test_scores[i_cat], thresholds[i_cat])

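            # Note the flipped comparison: ELMo scores are treated as
            # similarities (higher => predict positive), while DNEE scores are
            # treated as distances (lower => predict positive).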
            if args.use_elmo:
                y_preds = (test_scores[i_cat] > thresholds[i_cat]).type(
                    torch.int64)
            else:
                y_preds = (test_scores[i_cat] < thresholds[i_cat]).type(
                    torch.int64)

            y = y.detach().cpu().numpy()
            y_preds = y_preds.detach().cpu().numpy()

            prec = metrics.precision_score(y, y_preds)
            rec = metrics.recall_score(y, y_preds)
            f1 = metrics.f1_score(y, y_preds)
            logging.info("i_cat={} ({}), test_acc={}".format(
                i_cat, indices.IDX2REL[i_cat], acc))
            logging.info("i_cat={} ({}), test_prec={}".format(
                i_cat, indices.IDX2REL[i_cat], prec))
            logging.info("i_cat={} ({}), test_rec={}".format(
                i_cat, indices.IDX2REL[i_cat], rec))
            results[i_round][i_cat] = acc
            precisions[i_round][i_cat] = prec
            recalls[i_round][i_cat] = rec
            f1s[i_round][i_cat] = f1

    avg = torch.mean(results, dim=0)
    avg_precisions = torch.mean(precisions, dim=0)
    avg_recalls = torch.mean(recalls, dim=0)
    avg_f1s = torch.mean(f1s, dim=0)
    for i_cat in test_questions.keys():
        logging.info("i_cat={} ({}), avg_test_acc={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg[i_cat], args.n_rounds))
        logging.info("i_cat={} ({}), avg_test_prec={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg_precisions[i_cat],
            args.n_rounds))
        logging.info("i_cat={} ({}), avg_test_rec={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg_recalls[i_cat], args.n_rounds))
        logging.info("i_cat={} ({}), avg_test_f1={} over {} rounds".format(
            i_cat, indices.IDX2REL[i_cat], avg_f1s[i_cat], args.n_rounds))
def main():
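    """Train a discourse relation classifier (ds.AttentionNN) on ELMo features
    for the two relation arguments, optionally augmented with DNEE event
    features, and keep the checkpoint with the best dev F1."""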
    logging.info('using {} for computation.'.format(args.device))
    config = json.load(open(args.training_config, 'r'))

    indices.set_relation_classes(args.relation_config)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])

    n_preds = len(pred2idx)
    argw_encoder = create_argw_encoder(config, args.device)
    argw_encoder.load(args.encoder_file)

    logging.info("model class: " + config['model_type'])
    ModelClass = eval(config['model_type'])
    dnee_model = ModelClass(config, argw_encoder, n_preds,
                            args.device).to(args.device)
    dnee_model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))
    dnee_model.eval()

    elmo = Elmo(args.elmo_option_file, args.elmo_weight_file, 1, dropout=0)
    train_data = ds.DsDataset(args.ds_train_fld, args.dnee_train_fld)

    dev_rels = [json.loads(line) for line in open(args.ds_dev_rel_file)]
    dev_rels = [
        rel for rel in dev_rels if rel['Type'] != 'Explicit'
        and rel['Sense'][0] in indices.DISCOURSE_REL2IDX
    ]
    dev_cm, dev_valid_senses = ds.create_cm(dev_rels,
                                            indices.DISCOURSE_REL2IDX)

    dnee_seq_len = train_data.dnee_seq_len if args.dnee_train_fld else None
    x_dev, y_dev = ds.get_features(dev_rels,
                                   elmo,
                                   train_data.seq_len,
                                   dnee_model,
                                   dnee_seq_len,
                                   config,
                                   pred2idx,
                                   argw2idx,
                                   indices.DISCOURSE_REL2IDX,
                                   device=args.device,
                                   use_dnee=(args.dnee_train_fld is not None))
    x0_dev, x1_dev, x0_dnee_dev, x1_dnee_dev, x_dnee_dev = x_dev
    x0_dev, x1_dev = x0_dev.to(args.device), x1_dev.to(args.device)

    model = ds.AttentionNN(len(indices.DISCOURSE_REL2IDX),
                           event_dim=config['event_dim'],
                           dropout=args.dropout,
                           use_event=(args.dnee_train_fld is not None),
                           use_dnee_scores=not args.no_dnee_scores).to(
                               args.device)
    optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.learning_rate)
    logging.info("initial learning rate = {}".format(args.learning_rate))
    logging.info("dropout rate = {}".format(args.dropout))

    # arg_lens = [config['arg0_max_len'], config['arg1_max_len']]
    losses = []
    dev_f1s = []
    best_dev_f1, best_epoch, best_batch = -1, -1, -1
    logging.info("batch_size = {}".format(args.batch_size))
    for i_epoch in range(args.n_epoches):
        train_loader = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=1)

        epoch_start = time.time()
        for i_batch, (x, y) in enumerate(train_loader):
            if y.shape[0] != args.batch_size:
                # skip the last batch
                continue
            if args.dnee_train_fld:
                x0, x1, x0_dnee, x1_dnee, x_dnee = x
                x0 = x0.to(args.device)
                x1 = x1.to(args.device)
                x0_dnee = x0_dnee.to(args.device)
                x1_dnee = x1_dnee.to(args.device)
                x_dnee = x_dnee.to(args.device)
            else:
                x0, x1 = x
                x0 = x0.to(args.device)
                x1 = x1.to(args.device)
                x0_dnee, x1_dnee, x_dnee = None, None, None
            y = y.squeeze().to(args.device)

            model.train()
            optimizer.zero_grad()
            out = model(x0, x1, x0_dnee, x1_dnee, x_dnee)
            loss = model.loss_func(out, y)

            # step
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

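            # Evaluate on the dev set after every training batch; the model
            # with the best dev F1 is checkpointed to best_model.pt below.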
            model.eval()
            y_pred = model.predict(x0_dev, x1_dev, x0_dnee_dev, x1_dnee_dev,
                                   x_dnee_dev)
            dev_prec, dev_recall, dev_f1 = ds.scoring_cm(
                y_dev, y_pred.cpu(), dev_cm, dev_valid_senses,
                indices.DISCOURSE_IDX2REL)
            dev_f1s.append(dev_f1)

            ## if i_batch % config['n_batches_per_record'] == 0:
            logging.info("{}, {}: loss={}, time={}".format(
                i_epoch, i_batch, loss.item(),
                time.time() - epoch_start))
            logging.info("dev: prec={}, recall={}, f1={}".format(
                dev_prec, dev_recall, dev_f1))
            if dev_f1 > best_dev_f1:
                logging.info("best dev: prec={}, recall={}, f1={}".format(
                    dev_prec, dev_recall, dev_f1))
                best_dev_f1 = dev_f1
                best_epoch = i_epoch
                best_batch = i_batch
                fpath = os.path.join(args.output_folder, 'best_model.pt')
                torch.save(model.state_dict(), fpath)

    logging.info("{}-{}: best dev f1 = {}".format(best_epoch, best_batch,
                                                  best_dev_f1))
    fpath = os.path.join(args.output_folder, "losses.pkl")
    pkl.dump(losses, open(fpath, 'wb'))
    fpath = os.path.join(args.output_folder, "dev_f1s.pkl")
    pkl.dump(dev_f1s, open(fpath, 'wb'))
def main():
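    """Evaluate MCNS and MCNE inference with both word-embedding (WE) and
    event-embedding (EV) scores and report per-setting accuracy."""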
    config = json.load(open(args.training_config, 'r'))
    indices.set_relation_classes(args.relation_config)
    pred2idx, idx2pred, _ = indices.load_predicates(
        config['predicate_indices'])
    argw2idx, idx2argw, _ = indices.load_argw(config['argw_indices'])
    n_preds = len(pred2idx)
    argw_vocabs = argw2idx.keys()
    argw_encoder = create_argw_encoder(config, args.device)
    if args.encoder_file:
        argw_encoder.load(args.encoder_file)

    logging.info("model class: " + config['model_type'])
    ModelClass = eval(config['model_type'])
    model = ModelClass(config, argw_encoder, n_preds,
                       args.device).to(args.device)
    model.load_state_dict(
        torch.load(args.model_file,
                   map_location=lambda storage, location: storage))

    we = utils.load_word_embeddings(args.word_embedding_file, use_torch=True)

    questions = pkl.load(open(args.question_file, 'rb'))
    if not args.no_subsample:
        # ridxs = list(range(len(questions)))
        # random.shuffle(ridxs)
        # ridxs = [ridxs[i] for i in range(1000)]
        # questions = [questions[i] for i in ridxs]
        questions = questions[:10000]
    logging.info("#questions={}".format(len(questions)))

    rel = indices.REL_CONTEXT if args.context_rel else indices.REL_COREF
    rtype = torch.LongTensor([indices.REL2IDX[rel]]).to(args.device)

    widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()]
    bar = progressbar.ProgressBar(widgets=widgets,
                                  maxval=len(questions)).start()

    logging.info("batch_size = {}".format(args.batch_size))
    batch_size = args.batch_size
    n_batches = (len(questions) + batch_size - 1) // batch_size
    logging.info("#questions = {}".format(len(questions)))
    logging.info("n_batches = {}".format(n_batches))
    i_q = 0

    # ((we, ev, mix), (mcns, mcne),(incorrect, correct))
    WE_IDX, EV_IDX, MIX_IDX = 0, 1, 2
    MCNS_IDX, MCNE_IDX = 0, 1
    INCORRECT_IDX, CORRECT_IDX = 0, 1
    results = torch.zeros((3, 2, 2), dtype=torch.int64)
    for i_batch in range(n_batches):
        batch_questions = questions[i_batch * batch_size:(i_batch + 1) *
                                    batch_size]
        e2idx, ev_embeddings, w_embeddings = build_embeddings(
            model, batch_questions, config, pred2idx, argw2idx, rtype, we)
        for q in batch_questions:
            # when calculating the accuracy, we only consider the questions in the middle
            # so that MCNS and MCNE can have a fair comparison
            n_q = len(q.ans_idxs) - 1

            we_preds, ev_preds = intrinsic.predict_mcns(
                model, q, e2idx, ev_embeddings, w_embeddings, rtype,
                args.device, args.inference_model)
            for i in range(n_q):
                for emb_idx, preds in [(WE_IDX, we_preds), (EV_IDX, ev_preds)]:
                    if preds[i] == q.ans_idxs[i]:
                        results[emb_idx][MCNS_IDX][CORRECT_IDX] += 1
                    else:
                        results[emb_idx][MCNS_IDX][INCORRECT_IDX] += 1

            we_preds, ev_preds = intrinsic.predict_mcne(
                model, q, e2idx, ev_embeddings, w_embeddings, rtype,
                args.device, args.inference_model)
            for i in range(n_q):
                for emb_idx, preds in [(WE_IDX, we_preds), (EV_IDX, ev_preds)]:
                    if preds[i] == q.ans_idxs[i]:
                        results[emb_idx][MCNE_IDX][CORRECT_IDX] += 1
                    else:
                        results[emb_idx][MCNE_IDX][INCORRECT_IDX] += 1
            i_q += 1
            bar.update(i_q)
    bar.finish()

    results = results.type(torch.float32)
    print("MCNS:")
    print("\tWE:")
    print("\t\taccuracy={}".format(results[WE_IDX][MCNS_IDX][CORRECT_IDX] /
                                   (results[WE_IDX][MCNS_IDX][CORRECT_IDX] +
                                    results[WE_IDX][MCNS_IDX][INCORRECT_IDX])))
    print("\tEV:")
    print("\t\taccuracy={}".format(results[EV_IDX][MCNS_IDX][CORRECT_IDX] /
                                   (results[EV_IDX][MCNS_IDX][CORRECT_IDX] +
                                    results[EV_IDX][MCNS_IDX][INCORRECT_IDX])))
    # print ("\tMIX:")
    # print("\t\taccuracy={}".format(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]+results[MIX_IDX][MCNS_IDX][INCORRECT_IDX])))

    print("MCNE:")
    print("\tWE:")
    print("\t\taccuracy={}".format(results[WE_IDX][MCNE_IDX][CORRECT_IDX] /
                                   (results[WE_IDX][MCNE_IDX][CORRECT_IDX] +
                                    results[WE_IDX][MCNE_IDX][INCORRECT_IDX])))
    print("\tEV:")
    print("\t\taccuracy={}".format(results[EV_IDX][MCNE_IDX][CORRECT_IDX] /
                                   (results[EV_IDX][MCNE_IDX][CORRECT_IDX] +
                                    results[EV_IDX][MCNE_IDX][INCORRECT_IDX])))
    # print ("\tMIX:")
    # print("\t\taccuracy={}".format(results[MIX_IDX][MCNE_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNE_IDX][CORRECT_IDX]+results[MIX_IDX][MCNE_IDX][INCORRECT_IDX])))

    logging.info("MCNS:")
    logging.info("\tWE:")
    logging.info(
        "\t\taccuracy={}".format(results[WE_IDX][MCNS_IDX][CORRECT_IDX] /
                                 (results[WE_IDX][MCNS_IDX][CORRECT_IDX] +
                                  results[WE_IDX][MCNS_IDX][INCORRECT_IDX])))
    logging.info("\tEV:")
    logging.info(
        "\t\taccuracy={}".format(results[EV_IDX][MCNS_IDX][CORRECT_IDX] /
                                 (results[EV_IDX][MCNS_IDX][CORRECT_IDX] +
                                  results[EV_IDX][MCNS_IDX][INCORRECT_IDX])))
    # logging.info("\tMIX:")
    # logging.info("\t\taccuracy={}".format(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]/(results[MIX_IDX][MCNS_IDX][CORRECT_IDX]+results[MIX_IDX][MCNS_IDX][INCORRECT_IDX])))

    logging.info("MCNE:")
    logging.info("\tWE:")
    logging.info(
        "\t\taccuracy={}".format(results[WE_IDX][MCNE_IDX][CORRECT_IDX] /
                                 (results[WE_IDX][MCNE_IDX][CORRECT_IDX] +
                                  results[WE_IDX][MCNE_IDX][INCORRECT_IDX])))
    logging.info("\tEV:")
    logging.info(
        "\t\taccuracy={}".format(results[EV_IDX][MCNE_IDX][CORRECT_IDX] /
                                 (results[EV_IDX][MCNE_IDX][CORRECT_IDX] +
                                  results[EV_IDX][MCNE_IDX][INCORRECT_IDX])))