예제 #1
0
def main(args):
    """Convert the ChemProt test split into a canonical TSV file.

    Loads the data found at ``"dl/" + args.data_dir`` (made absolute) with
    ``load_chemprot`` and writes it as ``chemprotnew_test.tsv`` inside
    ``"dl/" + args.output_dir`` using ``DataFormat.PremiseOnly``.

    Args:
        args: Parsed command-line namespace providing ``data_dir`` and
            ``output_dir``; both are appended to the ``dl/`` prefix.
    """
    # Plain concatenation is kept on purpose: os.path.join would drop the
    # "dl/" prefix if the argument were an absolute path.
    test_path = os.path.abspath("dl/" + args.data_dir)
    print(test_path)
    chemprot_test_data = load_chemprot(test_path, header=False)

    canonical_data_root = "dl/" + args.output_dir
    # makedirs creates missing parent directories (a nested output_dir, or
    # "dl" itself) and does not fail when the directory already exists.
    os.makedirs(canonical_data_root, exist_ok=True)
    chemprot_test_fout = os.path.join(canonical_data_root,
                                      'chemprotnew_test.tsv')

    dump_rows(chemprot_test_data, chemprot_test_fout, DataFormat.PremiseOnly)
예제 #2
0
def main(args):
    """Convert the XNLI dev and test splits into canonical TSV files.

    Reads ``XNLI/xnli.dev.tsv`` and ``XNLI/xnli.test.tsv`` below
    ``args.root_dir`` with ``load_xnli`` and writes ``xnli_dev.tsv`` and
    ``xnli_test.tsv`` into ``<root_dir>/canonical_data`` using
    ``DataFormat.PremiseAndOneHypothesis``.

    Args:
        args: Parsed command-line namespace providing ``root_dir``, which
            must already exist.
    """
    root = args.root_dir
    assert os.path.exists(root)

    ######################################
    # XNLI Tasks
    ######################################

    xnli_dev_path = os.path.join(root, 'XNLI/xnli.dev.tsv')
    xnli_test_path = os.path.join(root, 'XNLI/xnli.test.tsv')

    ######################################
    # Loading DATA
    ######################################

    xnli_dev_data = load_xnli(xnli_dev_path)
    xnli_test_data = load_xnli(xnli_test_path)
    # This count belongs to the dev split; it was previously reported as
    # "train", which made the log misleading.
    logger.info('Loaded {} XNLI dev samples'.format(len(xnli_dev_data)))
    logger.info('Loaded {} XNLI test samples'.format(len(xnli_test_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    # BUILD XNLI
    xnli_dev_fout = os.path.join(canonical_data_root, 'xnli_dev.tsv')
    xnli_test_fout = os.path.join(canonical_data_root, 'xnli_test.tsv')
    dump_rows(xnli_dev_data, xnli_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(xnli_test_data, xnli_test_fout,
              DataFormat.PremiseAndOneHypothesis)
    logger.info('done with XNLI')
예제 #3
0
def main(args):
    """Convert the ddi2013-type test split into a canonical TSV file.

    Loads the data found at ``"dl/" + args.data_dir`` with ``load_ddi`` and
    writes it as ``ddi_test.tsv`` inside ``"dl/" + args.output_dir`` using
    ``DataFormat.PremiseOnly``.

    Args:
        args: Parsed command-line namespace providing ``data_dir`` and
            ``output_dir``; both are appended to the ``dl/`` prefix.
    """
    data_dir = args.data_dir
    print(data_dir)

    ddi_test_path = "dl/" + data_dir
    ddi_test_data = load_ddi(ddi_test_path, header=False)
    logger.info('Loaded {} ddi2013-type test samples'.format(
        len(ddi_test_data)))

    bert_root = "dl/" + args.output_dir
    # Make sure the destination exists before writing, as the sibling
    # conversion scripts do; makedirs also handles nested output paths and
    # is a no-op when the directory is already there.
    os.makedirs(bert_root, exist_ok=True)

    ddi_test_fout = os.path.join(bert_root, 'ddi_test.tsv')

    dump_rows(ddi_test_data, ddi_test_fout, DataFormat.PremiseOnly)
    logger.info('done with ddi2013-type')
예제 #4
0
def main(args):
    """Convert the SQuAD v1 and v2 train/dev JSON files into canonical TSVs.

    Inputs are read from ``squad/`` and ``squad_v2/`` below ``args.root_dir``
    with ``load_data``; outputs are written to ``<root_dir>/canonical_data``
    using ``DataFormat.PremiseAndOneHypothesis``.

    Args:
        args: Parsed command-line namespace providing ``root_dir``, which
            must already exist.
    """
    root = args.root_dir
    assert os.path.exists(root)

    # Input locations for both dataset versions.
    v1_train_json = os.path.join(root, 'squad/train.json')
    v1_dev_json = os.path.join(root, 'squad/dev.json')
    v2_train_json = os.path.join(root, 'squad_v2/train.json')
    v2_dev_json = os.path.join(root, 'squad_v2/dev.json')

    # SQuAD v1
    v1_train_rows = load_data(v1_train_json)
    v1_dev_rows = load_data(v1_dev_json, is_train=False)
    logger.info('Loaded {} squad train samples'.format(len(v1_train_rows)))
    logger.info('Loaded {} squad dev samples'.format(len(v1_dev_rows)))

    # SQuAD v2
    v2_train_rows = load_data(v2_train_json, v2_on=True)
    v2_dev_rows = load_data(v2_dev_json, is_train=False, v2_on=True)
    logger.info('Loaded {} squad_v2 train samples'.format(len(v2_train_rows)))
    logger.info('Loaded {} squad_v2 dev samples'.format(len(v2_dev_rows)))

    # Output directory lives next to the raw data.
    out_dir = os.path.join(root, "canonical_data")
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    v1_outputs = (
        (v1_train_rows, 'squad_train.tsv'),
        (v1_dev_rows, 'squad_dev.tsv'),
    )
    for rows, file_name in v1_outputs:
        dump_rows(rows, os.path.join(out_dir, file_name),
                  DataFormat.PremiseAndOneHypothesis)
    logger.info('done with squad')

    v2_outputs = (
        (v2_train_rows, 'squad-v2_train.tsv'),
        (v2_dev_rows, 'squad-v2_dev.tsv'),
    )
    for rows, file_name in v2_outputs:
        dump_rows(rows, os.path.join(out_dir, file_name),
                  DataFormat.PremiseAndOneHypothesis)
    logger.info('done with squad_v2')
예제 #5
0
def main(args):
    """Convert the GAD test split into a canonical TSV file.

    Loads the data found at ``"dl/" + args.data_dir`` with ``load_gad`` and
    writes it as ``gad_test.tsv`` inside ``"dl/" + args.output_dir`` using
    ``DataFormat.PremiseOnly``.

    Args:
        args: Parsed command-line namespace providing ``data_dir`` and
            ``output_dir``; both are appended to the ``dl/`` prefix.
    """
    # The former single-argument os.path.join() call returned its argument
    # unchanged, so the concatenated string is used directly.
    test_path = "dl/" + args.data_dir

    gad_test_data = load_gad(test_path, header=False, train=False)

    canonical_data_root = "dl/" + args.output_dir
    # makedirs creates missing parent directories (a nested output_dir, or
    # "dl" itself) and does not fail when the directory already exists.
    os.makedirs(canonical_data_root, exist_ok=True)
    gad_test_fout = os.path.join(canonical_data_root, 'gad_test.tsv')

    dump_rows(gad_test_data, gad_test_fout, DataFormat.PremiseOnly)
예제 #6
0
def main(args):
    """Convert the SciTail, SNLI and GLUE datasets into canonical TSV files.

    Every dataset is read from its sub-directory below ``args.root_dir`` with
    the matching ``load_*`` helper and written with ``dump_rows`` into
    ``<root_dir>/canonical_data``. All splits are loaded first; the output
    files are written afterwards.

    When ``args.old_glue`` is true, the QNLI files are additionally loaded
    with ``load_qnnli`` (after seeding ``random`` with ``args.seed``) and
    written out as ``qnnli_{train,dev,test}.tsv``.

    Args:
        args: Parsed command-line namespace providing ``root_dir`` (must
            already exist), ``old_glue`` and, when ``old_glue`` is set,
            ``seed``.
    """
    is_old_glue = args.old_glue
    root = args.root_dir
    assert os.path.exists(root)

    ######################################
    # SNLI/SciTail Tasks
    ######################################
    scitail_train_path = os.path.join(root, 'SciTail/tsv_format/scitail_1.0_train.tsv')
    scitail_dev_path = os.path.join(root, 'SciTail/tsv_format/scitail_1.0_dev.tsv')
    scitail_test_path = os.path.join(root, 'SciTail/tsv_format/scitail_1.0_test.tsv')

    snli_train_path = os.path.join(root, 'SNLI/train.tsv')
    snli_dev_path = os.path.join(root, 'SNLI/dev.tsv')
    snli_test_path = os.path.join(root, 'SNLI/test.tsv')

    ######################################
    # GLUE tasks
    ######################################
    multi_train_path = os.path.join(root, 'MNLI/train.tsv')
    multi_dev_matched_path = os.path.join(root, 'MNLI/dev_matched.tsv')
    multi_dev_mismatched_path = os.path.join(root, 'MNLI/dev_mismatched.tsv')
    multi_test_matched_path = os.path.join(root, 'MNLI/test_matched.tsv')
    multi_test_mismatched_path = os.path.join(root, 'MNLI/test_mismatched.tsv')

    mrpc_train_path = os.path.join(root, 'MRPC/train.tsv')
    mrpc_dev_path = os.path.join(root, 'MRPC/dev.tsv')
    mrpc_test_path = os.path.join(root, 'MRPC/test.tsv')

    qnli_train_path = os.path.join(root, 'QNLI/train.tsv')
    qnli_dev_path = os.path.join(root, 'QNLI/dev.tsv')
    qnli_test_path = os.path.join(root, 'QNLI/test.tsv')

    qqp_train_path = os.path.join(root, 'QQP/train.tsv')
    qqp_dev_path = os.path.join(root, 'QQP/dev.tsv')
    qqp_test_path = os.path.join(root, 'QQP/test.tsv')

    rte_train_path = os.path.join(root, 'RTE/train.tsv')
    rte_dev_path = os.path.join(root, 'RTE/dev.tsv')
    rte_test_path = os.path.join(root, 'RTE/test.tsv')

    wnli_train_path = os.path.join(root, 'WNLI/train.tsv')
    wnli_dev_path = os.path.join(root, 'WNLI/dev.tsv')
    wnli_test_path = os.path.join(root, 'WNLI/test.tsv')

    stsb_train_path = os.path.join(root, 'STS-B/train.tsv')
    stsb_dev_path = os.path.join(root, 'STS-B/dev.tsv')
    stsb_test_path = os.path.join(root, 'STS-B/test.tsv')

    sst_train_path = os.path.join(root, 'SST-2/train.tsv')
    sst_dev_path = os.path.join(root, 'SST-2/dev.tsv')
    sst_test_path = os.path.join(root, 'SST-2/test.tsv')

    cola_train_path = os.path.join(root, 'CoLA/train.tsv')
    cola_dev_path = os.path.join(root, 'CoLA/dev.tsv')
    cola_test_path = os.path.join(root, 'CoLA/test.tsv')

    ######################################
    # Loading DATA
    ######################################
    scitail_train_data = load_scitail(scitail_train_path)
    scitail_dev_data = load_scitail(scitail_dev_path)
    scitail_test_data = load_scitail(scitail_test_path)
    logger.info('Loaded {} SciTail train samples'.format(len(scitail_train_data)))
    logger.info('Loaded {} SciTail dev samples'.format(len(scitail_dev_data)))
    logger.info('Loaded {} SciTail test samples'.format(len(scitail_test_data)))

    snli_train_data = load_snli(snli_train_path)
    snli_dev_data = load_snli(snli_dev_path)
    snli_test_data = load_snli(snli_test_path)
    logger.info('Loaded {} SNLI train samples'.format(len(snli_train_data)))
    logger.info('Loaded {} SNLI dev samples'.format(len(snli_dev_data)))
    logger.info('Loaded {} SNLI test samples'.format(len(snli_test_data)))

    multinli_train_data = load_mnli(multi_train_path)
    multinli_matched_dev_data = load_mnli(multi_dev_matched_path)
    multinli_mismatched_dev_data = load_mnli(multi_dev_mismatched_path)
    multinli_matched_test_data = load_mnli(multi_test_matched_path, is_train=False)
    multinli_mismatched_test_data = load_mnli(multi_test_mismatched_path, is_train=False)

    logger.info('Loaded {} MNLI train samples'.format(len(multinli_train_data)))
    logger.info('Loaded {} MNLI matched dev samples'.format(len(multinli_matched_dev_data)))
    logger.info('Loaded {} MNLI mismatched dev samples'.format(len(multinli_mismatched_dev_data)))
    logger.info('Loaded {} MNLI matched test samples'.format(len(multinli_matched_test_data)))
    logger.info('Loaded {} MNLI mismatched test samples'.format(len(multinli_mismatched_test_data)))

    mrpc_train_data = load_mrpc(mrpc_train_path)
    mrpc_dev_data = load_mrpc(mrpc_dev_path)
    mrpc_test_data = load_mrpc(mrpc_test_path, is_train=False)
    logger.info('Loaded {} MRPC train samples'.format(len(mrpc_train_data)))
    logger.info('Loaded {} MRPC dev samples'.format(len(mrpc_dev_data)))
    logger.info('Loaded {} MRPC test samples'.format(len(mrpc_test_data)))

    qnli_train_data = load_qnli(qnli_train_path)
    qnli_dev_data = load_qnli(qnli_dev_path)
    qnli_test_data = load_qnli(qnli_test_path, is_train=False)
    logger.info('Loaded {} QNLI train samples'.format(len(qnli_train_data)))
    logger.info('Loaded {} QNLI dev samples'.format(len(qnli_dev_data)))
    logger.info('Loaded {} QNLI test samples'.format(len(qnli_test_data)))

    if is_old_glue:
        # Seed before loading so the qnnli construction is repeatable.
        random.seed(args.seed)
        qnnli_train_data = load_qnnli(qnli_train_path)
        qnnli_dev_data = load_qnnli(qnli_dev_path)
        qnnli_test_data = load_qnnli(qnli_test_path, is_train=False)
        logger.info('Loaded {} QNLI train samples'.format(len(qnnli_train_data)))
        logger.info('Loaded {} QNLI dev samples'.format(len(qnnli_dev_data)))
        logger.info('Loaded {} QNLI test samples'.format(len(qnnli_test_data)))

    qqp_train_data = load_qqp(qqp_train_path)
    qqp_dev_data = load_qqp(qqp_dev_path)
    qqp_test_data = load_qqp(qqp_test_path, is_train=False)
    logger.info('Loaded {} QQP train samples'.format(len(qqp_train_data)))
    logger.info('Loaded {} QQP dev samples'.format(len(qqp_dev_data)))
    logger.info('Loaded {} QQP test samples'.format(len(qqp_test_data)))

    rte_train_data = load_rte(rte_train_path)
    rte_dev_data = load_rte(rte_dev_path)
    rte_test_data = load_rte(rte_test_path, is_train=False)
    logger.info('Loaded {} RTE train samples'.format(len(rte_train_data)))
    logger.info('Loaded {} RTE dev samples'.format(len(rte_dev_data)))
    logger.info('Loaded {} RTE test samples'.format(len(rte_test_data)))

    wnli_train_data = load_wnli(wnli_train_path)
    wnli_dev_data = load_wnli(wnli_dev_path)
    wnli_test_data = load_wnli(wnli_test_path, is_train=False)
    logger.info('Loaded {} WNLI train samples'.format(len(wnli_train_data)))
    logger.info('Loaded {} WNLI dev samples'.format(len(wnli_dev_data)))
    logger.info('Loaded {} WNLI test samples'.format(len(wnli_test_data)))

    sst_train_data = load_sst(sst_train_path)
    sst_dev_data = load_sst(sst_dev_path)
    sst_test_data = load_sst(sst_test_path, is_train=False)
    logger.info('Loaded {} SST train samples'.format(len(sst_train_data)))
    logger.info('Loaded {} SST dev samples'.format(len(sst_dev_data)))
    logger.info('Loaded {} SST test samples'.format(len(sst_test_data)))

    cola_train_data = load_cola(cola_train_path, header=False)
    cola_dev_data = load_cola(cola_dev_path, header=False)
    cola_test_data = load_cola(cola_test_path, is_train=False)
    logger.info('Loaded {} COLA train samples'.format(len(cola_train_data)))
    logger.info('Loaded {} COLA dev samples'.format(len(cola_dev_data)))
    logger.info('Loaded {} COLA test samples'.format(len(cola_test_data)))

    stsb_train_data = load_sts(stsb_train_path)
    stsb_dev_data = load_sts(stsb_dev_path)
    stsb_test_data = load_sts(stsb_test_path, is_train=False)
    logger.info('Loaded {} STS-B train samples'.format(len(stsb_train_data)))
    logger.info('Loaded {} STS-B dev samples'.format(len(stsb_dev_data)))
    logger.info('Loaded {} STS-B test samples'.format(len(stsb_test_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    # BUILD SciTail
    scitail_train_fout = os.path.join(canonical_data_root, 'scitail_train.tsv')
    scitail_dev_fout = os.path.join(canonical_data_root, 'scitail_dev.tsv')
    scitail_test_fout = os.path.join(canonical_data_root, 'scitail_test.tsv')
    dump_rows(scitail_train_data, scitail_train_fout)
    dump_rows(scitail_dev_data, scitail_dev_fout)
    dump_rows(scitail_test_data, scitail_test_fout)
    logger.info('done with scitail')

    # BUILD SNLI
    snli_train_fout = os.path.join(canonical_data_root, 'snli_train.tsv')
    snli_dev_fout = os.path.join(canonical_data_root, 'snli_dev.tsv')
    snli_test_fout = os.path.join(canonical_data_root, 'snli_test.tsv')
    dump_rows(snli_train_data, snli_train_fout)
    dump_rows(snli_dev_data, snli_dev_fout)
    dump_rows(snli_test_data, snli_test_fout)
    logger.info('done with snli')

    # BUILD MNLI
    multinli_train_fout = os.path.join(canonical_data_root, 'mnli_train.tsv')
    multinli_matched_dev_fout = os.path.join(canonical_data_root, 'mnli_matched_dev.tsv')
    multinli_mismatched_dev_fout = os.path.join(canonical_data_root, 'mnli_mismatched_dev.tsv')
    multinli_matched_test_fout = os.path.join(canonical_data_root, 'mnli_matched_test.tsv')
    multinli_mismatched_test_fout = os.path.join(canonical_data_root, 'mnli_mismatched_test.tsv')
    dump_rows(multinli_train_data, multinli_train_fout)
    dump_rows(multinli_matched_dev_data, multinli_matched_dev_fout)
    dump_rows(multinli_mismatched_dev_data, multinli_mismatched_dev_fout)
    dump_rows(multinli_matched_test_data, multinli_matched_test_fout)
    dump_rows(multinli_mismatched_test_data, multinli_mismatched_test_fout)
    logger.info('done with mnli')

    # BUILD MRPC
    mrpc_train_fout = os.path.join(canonical_data_root, 'mrpc_train.tsv')
    mrpc_dev_fout = os.path.join(canonical_data_root, 'mrpc_dev.tsv')
    mrpc_test_fout = os.path.join(canonical_data_root, 'mrpc_test.tsv')
    dump_rows(mrpc_train_data, mrpc_train_fout)
    dump_rows(mrpc_dev_data, mrpc_dev_fout)
    dump_rows(mrpc_test_data, mrpc_test_fout)
    logger.info('done with mrpc')

    # BUILD QNLI
    qnli_train_fout = os.path.join(canonical_data_root, 'qnli_train.tsv')
    qnli_dev_fout = os.path.join(canonical_data_root, 'qnli_dev.tsv')
    qnli_test_fout = os.path.join(canonical_data_root, 'qnli_test.tsv')
    dump_rows(qnli_train_data, qnli_train_fout)
    dump_rows(qnli_dev_data, qnli_dev_fout)
    dump_rows(qnli_test_data, qnli_test_fout)
    logger.info('done with qnli')

    if is_old_glue:
        # BUILD QNNLI (the qnli_*_fout names are reused for the qnnli files)
        qnli_train_fout = os.path.join(canonical_data_root, 'qnnli_train.tsv')
        qnli_dev_fout = os.path.join(canonical_data_root, 'qnnli_dev.tsv')
        qnli_test_fout = os.path.join(canonical_data_root, 'qnnli_test.tsv')
        dump_rows(qnnli_train_data, qnli_train_fout)
        dump_rows(qnnli_dev_data, qnli_dev_fout)
        # The test file must contain the test split; the training rows were
        # previously written here by mistake.
        dump_rows(qnnli_test_data, qnli_test_fout)
        logger.info('done with qnli')

    # BUILD QQP
    qqp_train_fout = os.path.join(canonical_data_root, 'qqp_train.tsv')
    qqp_dev_fout = os.path.join(canonical_data_root, 'qqp_dev.tsv')
    qqp_test_fout = os.path.join(canonical_data_root, 'qqp_test.tsv')
    dump_rows(qqp_train_data, qqp_train_fout)
    dump_rows(qqp_dev_data, qqp_dev_fout)
    dump_rows(qqp_test_data, qqp_test_fout)
    logger.info('done with qqp')

    # BUILD RTE
    rte_train_fout = os.path.join(canonical_data_root, 'rte_train.tsv')
    rte_dev_fout = os.path.join(canonical_data_root, 'rte_dev.tsv')
    rte_test_fout = os.path.join(canonical_data_root, 'rte_test.tsv')
    dump_rows(rte_train_data, rte_train_fout)
    dump_rows(rte_dev_data, rte_dev_fout)
    dump_rows(rte_test_data, rte_test_fout)
    logger.info('done with rte')

    # BUILD WNLI
    wnli_train_fout = os.path.join(canonical_data_root, 'wnli_train.tsv')
    wnli_dev_fout = os.path.join(canonical_data_root, 'wnli_dev.tsv')
    wnli_test_fout = os.path.join(canonical_data_root, 'wnli_test.tsv')
    dump_rows(wnli_train_data, wnli_train_fout)
    dump_rows(wnli_dev_data, wnli_dev_fout)
    dump_rows(wnli_test_data, wnli_test_fout)
    logger.info('done with wnli')

    # BUILD SST
    sst_train_fout = os.path.join(canonical_data_root, 'sst_train.tsv')
    sst_dev_fout = os.path.join(canonical_data_root, 'sst_dev.tsv')
    sst_test_fout = os.path.join(canonical_data_root, 'sst_test.tsv')
    dump_rows(sst_train_data, sst_train_fout)
    dump_rows(sst_dev_data, sst_dev_fout)
    dump_rows(sst_test_data, sst_test_fout)
    logger.info('done with sst')

    # BUILD CoLA
    cola_train_fout = os.path.join(canonical_data_root, 'cola_train.tsv')
    cola_dev_fout = os.path.join(canonical_data_root, 'cola_dev.tsv')
    cola_test_fout = os.path.join(canonical_data_root, 'cola_test.tsv')
    dump_rows(cola_train_data, cola_train_fout)
    dump_rows(cola_dev_data, cola_dev_fout)
    dump_rows(cola_test_data, cola_test_fout)
    logger.info('done with cola')

    # BUILD STS-B
    stsb_train_fout = os.path.join(canonical_data_root, 'stsb_train.tsv')
    stsb_dev_fout = os.path.join(canonical_data_root, 'stsb_dev.tsv')
    stsb_test_fout = os.path.join(canonical_data_root, 'stsb_test.tsv')
    dump_rows(stsb_train_data, stsb_train_fout)
    dump_rows(stsb_dev_data, stsb_dev_fout)
    dump_rows(stsb_test_data, stsb_test_fout)
    logger.info('done with stsb')
예제 #7
0
def main(args):
    """Convert the GLUE datasets into canonical TSV files.

    Every task is read from its sub-directory below ``args.root_dir`` with
    the matching ``load_*`` helper and written with ``dump_rows`` into
    ``<root_dir>/canonical_data`` using the ``DataFormat`` passed for that
    task. All splits are loaded first; the output files are written
    afterwards.

    When ``args.old_glue`` is true, the QNLI files are additionally loaded
    with ``load_qnnli`` (after seeding ``random`` with ``args.seed``) and
    written out as ``qnnli_{train,dev,test}.tsv`` using
    ``DataFormat.PremiseAndMultiHypothesis``.

    Args:
        args: Parsed command-line namespace providing ``root_dir`` (must
            already exist), ``old_glue`` and, when ``old_glue`` is set,
            ``seed``.
    """
    is_old_glue = args.old_glue
    root = args.root_dir
    assert os.path.exists(root)

    ######################################
    # GLUE tasks
    ######################################
    multi_train_path = os.path.join(root, "MNLI/train.tsv")
    multi_dev_matched_path = os.path.join(root, "MNLI/dev_matched.tsv")
    multi_dev_mismatched_path = os.path.join(root, "MNLI/dev_mismatched.tsv")
    multi_test_matched_path = os.path.join(root, "MNLI/test_matched.tsv")
    multi_test_mismatched_path = os.path.join(root, "MNLI/test_mismatched.tsv")

    mrpc_train_path = os.path.join(root, "MRPC/train.tsv")
    mrpc_dev_path = os.path.join(root, "MRPC/dev.tsv")
    mrpc_test_path = os.path.join(root, "MRPC/test.tsv")

    qnli_train_path = os.path.join(root, "QNLI/train.tsv")
    qnli_dev_path = os.path.join(root, "QNLI/dev.tsv")
    qnli_test_path = os.path.join(root, "QNLI/test.tsv")

    qqp_train_path = os.path.join(root, "QQP/train.tsv")
    qqp_dev_path = os.path.join(root, "QQP/dev.tsv")
    qqp_test_path = os.path.join(root, "QQP/test.tsv")

    rte_train_path = os.path.join(root, "RTE/train.tsv")
    rte_dev_path = os.path.join(root, "RTE/dev.tsv")
    rte_test_path = os.path.join(root, "RTE/test.tsv")

    wnli_train_path = os.path.join(root, "WNLI/train.tsv")
    wnli_dev_path = os.path.join(root, "WNLI/dev.tsv")
    wnli_test_path = os.path.join(root, "WNLI/test.tsv")

    stsb_train_path = os.path.join(root, "STS-B/train.tsv")
    stsb_dev_path = os.path.join(root, "STS-B/dev.tsv")
    stsb_test_path = os.path.join(root, "STS-B/test.tsv")

    sst_train_path = os.path.join(root, "SST-2/train.tsv")
    sst_dev_path = os.path.join(root, "SST-2/dev.tsv")
    sst_test_path = os.path.join(root, "SST-2/test.tsv")

    cola_train_path = os.path.join(root, "CoLA/train.tsv")
    cola_dev_path = os.path.join(root, "CoLA/dev.tsv")
    cola_test_path = os.path.join(root, "CoLA/test.tsv")

    ######################################
    # Loading DATA
    ######################################

    multinli_train_data = load_mnli(multi_train_path)
    multinli_matched_dev_data = load_mnli(multi_dev_matched_path)
    multinli_mismatched_dev_data = load_mnli(multi_dev_mismatched_path)
    multinli_matched_test_data = load_mnli(multi_test_matched_path,
                                           is_train=False)
    multinli_mismatched_test_data = load_mnli(multi_test_mismatched_path,
                                              is_train=False)

    logger.info("Loaded {} MNLI train samples".format(
        len(multinli_train_data)))
    logger.info("Loaded {} MNLI matched dev samples".format(
        len(multinli_matched_dev_data)))
    logger.info("Loaded {} MNLI mismatched dev samples".format(
        len(multinli_mismatched_dev_data)))
    logger.info("Loaded {} MNLI matched test samples".format(
        len(multinli_matched_test_data)))
    logger.info("Loaded {} MNLI mismatched test samples".format(
        len(multinli_mismatched_test_data)))

    mrpc_train_data = load_mrpc(mrpc_train_path)
    mrpc_dev_data = load_mrpc(mrpc_dev_path)
    mrpc_test_data = load_mrpc(mrpc_test_path, is_train=False)
    logger.info("Loaded {} MRPC train samples".format(len(mrpc_train_data)))
    logger.info("Loaded {} MRPC dev samples".format(len(mrpc_dev_data)))
    logger.info("Loaded {} MRPC test samples".format(len(mrpc_test_data)))

    qnli_train_data = load_qnli(qnli_train_path)
    qnli_dev_data = load_qnli(qnli_dev_path)
    qnli_test_data = load_qnli(qnli_test_path, is_train=False)
    logger.info("Loaded {} QNLI train samples".format(len(qnli_train_data)))
    logger.info("Loaded {} QNLI dev samples".format(len(qnli_dev_data)))
    logger.info("Loaded {} QNLI test samples".format(len(qnli_test_data)))

    if is_old_glue:
        # Seed before loading so the qnnli construction is repeatable.
        random.seed(args.seed)
        qnnli_train_data = load_qnnli(qnli_train_path)
        qnnli_dev_data = load_qnnli(qnli_dev_path)
        qnnli_test_data = load_qnnli(qnli_test_path, is_train=False)
        logger.info("Loaded {} QNLI train samples".format(
            len(qnnli_train_data)))
        logger.info("Loaded {} QNLI dev samples".format(len(qnnli_dev_data)))
        logger.info("Loaded {} QNLI test samples".format(len(qnnli_test_data)))

    qqp_train_data = load_qqp(qqp_train_path)
    qqp_dev_data = load_qqp(qqp_dev_path)
    qqp_test_data = load_qqp(qqp_test_path, is_train=False)
    logger.info("Loaded {} QQP train samples".format(len(qqp_train_data)))
    logger.info("Loaded {} QQP dev samples".format(len(qqp_dev_data)))
    logger.info("Loaded {} QQP test samples".format(len(qqp_test_data)))

    rte_train_data = load_rte(rte_train_path)
    rte_dev_data = load_rte(rte_dev_path)
    rte_test_data = load_rte(rte_test_path, is_train=False)
    logger.info("Loaded {} RTE train samples".format(len(rte_train_data)))
    logger.info("Loaded {} RTE dev samples".format(len(rte_dev_data)))
    logger.info("Loaded {} RTE test samples".format(len(rte_test_data)))

    wnli_train_data = load_wnli(wnli_train_path)
    wnli_dev_data = load_wnli(wnli_dev_path)
    wnli_test_data = load_wnli(wnli_test_path, is_train=False)
    logger.info("Loaded {} WNLI train samples".format(len(wnli_train_data)))
    logger.info("Loaded {} WNLI dev samples".format(len(wnli_dev_data)))
    logger.info("Loaded {} WNLI test samples".format(len(wnli_test_data)))

    sst_train_data = load_sst(sst_train_path)
    sst_dev_data = load_sst(sst_dev_path)
    sst_test_data = load_sst(sst_test_path, is_train=False)
    logger.info("Loaded {} SST train samples".format(len(sst_train_data)))
    logger.info("Loaded {} SST dev samples".format(len(sst_dev_data)))
    logger.info("Loaded {} SST test samples".format(len(sst_test_data)))

    cola_train_data = load_cola(cola_train_path, header=False)
    cola_dev_data = load_cola(cola_dev_path, header=False)
    cola_test_data = load_cola(cola_test_path, is_train=False)
    logger.info("Loaded {} COLA train samples".format(len(cola_train_data)))
    logger.info("Loaded {} COLA dev samples".format(len(cola_dev_data)))
    logger.info("Loaded {} COLA test samples".format(len(cola_test_data)))

    stsb_train_data = load_stsb(stsb_train_path)
    stsb_dev_data = load_stsb(stsb_dev_path)
    stsb_test_data = load_stsb(stsb_test_path, is_train=False)
    logger.info("Loaded {} STS-B train samples".format(len(stsb_train_data)))
    logger.info("Loaded {} STS-B dev samples".format(len(stsb_dev_data)))
    logger.info("Loaded {} STS-B test samples".format(len(stsb_test_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    # BUILD MNLI
    multinli_train_fout = os.path.join(canonical_data_root, "mnli_train.tsv")
    multinli_matched_dev_fout = os.path.join(canonical_data_root,
                                             "mnli_matched_dev.tsv")
    multinli_mismatched_dev_fout = os.path.join(canonical_data_root,
                                                "mnli_mismatched_dev.tsv")
    multinli_matched_test_fout = os.path.join(canonical_data_root,
                                              "mnli_matched_test.tsv")
    multinli_mismatched_test_fout = os.path.join(canonical_data_root,
                                                 "mnli_mismatched_test.tsv")
    dump_rows(multinli_train_data, multinli_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(
        multinli_matched_dev_data,
        multinli_matched_dev_fout,
        DataFormat.PremiseAndOneHypothesis,
    )
    dump_rows(
        multinli_mismatched_dev_data,
        multinli_mismatched_dev_fout,
        DataFormat.PremiseAndOneHypothesis,
    )
    dump_rows(
        multinli_matched_test_data,
        multinli_matched_test_fout,
        DataFormat.PremiseAndOneHypothesis,
    )
    dump_rows(
        multinli_mismatched_test_data,
        multinli_mismatched_test_fout,
        DataFormat.PremiseAndOneHypothesis,
    )
    logger.info("done with mnli")

    # BUILD MRPC
    mrpc_train_fout = os.path.join(canonical_data_root, "mrpc_train.tsv")
    mrpc_dev_fout = os.path.join(canonical_data_root, "mrpc_dev.tsv")
    mrpc_test_fout = os.path.join(canonical_data_root, "mrpc_test.tsv")
    dump_rows(mrpc_train_data, mrpc_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(mrpc_dev_data, mrpc_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(mrpc_test_data, mrpc_test_fout,
              DataFormat.PremiseAndOneHypothesis)
    logger.info("done with mrpc")

    # BUILD QNLI
    qnli_train_fout = os.path.join(canonical_data_root, "qnli_train.tsv")
    qnli_dev_fout = os.path.join(canonical_data_root, "qnli_dev.tsv")
    qnli_test_fout = os.path.join(canonical_data_root, "qnli_test.tsv")
    dump_rows(qnli_train_data, qnli_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(qnli_dev_data, qnli_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(qnli_test_data, qnli_test_fout,
              DataFormat.PremiseAndOneHypothesis)
    logger.info("done with qnli")

    if is_old_glue:
        # BUILD QNNLI (the qnli_*_fout names are reused for the qnnli files)
        qnli_train_fout = os.path.join(canonical_data_root, "qnnli_train.tsv")
        qnli_dev_fout = os.path.join(canonical_data_root, "qnnli_dev.tsv")
        qnli_test_fout = os.path.join(canonical_data_root, "qnnli_test.tsv")
        dump_rows(qnnli_train_data, qnli_train_fout,
                  DataFormat.PremiseAndMultiHypothesis)
        dump_rows(qnnli_dev_data, qnli_dev_fout,
                  DataFormat.PremiseAndMultiHypothesis)
        # The test file must contain the test split; the training rows were
        # previously written here by mistake.
        dump_rows(qnnli_test_data, qnli_test_fout,
                  DataFormat.PremiseAndMultiHypothesis)
        logger.info("done with qnli")

    # BUILD QQP
    qqp_train_fout = os.path.join(canonical_data_root, "qqp_train.tsv")
    qqp_dev_fout = os.path.join(canonical_data_root, "qqp_dev.tsv")
    qqp_test_fout = os.path.join(canonical_data_root, "qqp_test.tsv")
    dump_rows(qqp_train_data, qqp_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(qqp_dev_data, qqp_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(qqp_test_data, qqp_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info("done with qqp")

    # BUILD RTE
    rte_train_fout = os.path.join(canonical_data_root, "rte_train.tsv")
    rte_dev_fout = os.path.join(canonical_data_root, "rte_dev.tsv")
    rte_test_fout = os.path.join(canonical_data_root, "rte_test.tsv")
    dump_rows(rte_train_data, rte_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(rte_dev_data, rte_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(rte_test_data, rte_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info("done with rte")

    # BUILD WNLI
    wnli_train_fout = os.path.join(canonical_data_root, "wnli_train.tsv")
    wnli_dev_fout = os.path.join(canonical_data_root, "wnli_dev.tsv")
    wnli_test_fout = os.path.join(canonical_data_root, "wnli_test.tsv")
    dump_rows(wnli_train_data, wnli_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(wnli_dev_data, wnli_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(wnli_test_data, wnli_test_fout,
              DataFormat.PremiseAndOneHypothesis)
    logger.info("done with wnli")

    # BUILD SST
    sst_train_fout = os.path.join(canonical_data_root, "sst_train.tsv")
    sst_dev_fout = os.path.join(canonical_data_root, "sst_dev.tsv")
    sst_test_fout = os.path.join(canonical_data_root, "sst_test.tsv")
    dump_rows(sst_train_data, sst_train_fout, DataFormat.PremiseOnly)
    dump_rows(sst_dev_data, sst_dev_fout, DataFormat.PremiseOnly)
    dump_rows(sst_test_data, sst_test_fout, DataFormat.PremiseOnly)
    logger.info("done with sst")

    # BUILD CoLA
    cola_train_fout = os.path.join(canonical_data_root, "cola_train.tsv")
    cola_dev_fout = os.path.join(canonical_data_root, "cola_dev.tsv")
    cola_test_fout = os.path.join(canonical_data_root, "cola_test.tsv")
    dump_rows(cola_train_data, cola_train_fout, DataFormat.PremiseOnly)
    dump_rows(cola_dev_data, cola_dev_fout, DataFormat.PremiseOnly)
    dump_rows(cola_test_data, cola_test_fout, DataFormat.PremiseOnly)
    logger.info("done with cola")

    # BUILD STS-B
    stsb_train_fout = os.path.join(canonical_data_root, "stsb_train.tsv")
    stsb_dev_fout = os.path.join(canonical_data_root, "stsb_dev.tsv")
    stsb_test_fout = os.path.join(canonical_data_root, "stsb_test.tsv")
    dump_rows(stsb_train_data, stsb_train_fout,
              DataFormat.PremiseAndOneHypothesis)
    dump_rows(stsb_dev_data, stsb_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(stsb_test_data, stsb_test_fout,
              DataFormat.PremiseAndOneHypothesis)
    logger.info("done with stsb")
예제 #8
0
def main(args):
    """Convert the train/valid/test text files into NER, POS and chunk TSVs.

    Each of ``train.txt``, ``valid.txt`` and ``test.txt`` under
    ``args.data_dir`` is loaded three times -- with ``load_conll_ner``,
    ``load_conll_pos`` and ``load_conll_chunk`` -- and every result is
    written to ``args.output_dir`` as ``<task>_<split>.tsv`` through
    ``dump_rows`` with ``DataFormat.Seqence``.

    Args:
        args: Namespace-like object providing ``data_dir`` (input folder)
            and ``output_dir`` (destination folder, created when missing).
    """
    data_dir = os.path.abspath(args.data_dir)
    # NOTE(review): a missing input folder is created here (behaviour kept
    # as-is); the loaders below presumably still fail on the absent files --
    # confirm whether failing early would be preferable.
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)

    # (split name used in log messages and output file names, input file)
    splits = (
        ('train', os.path.join(data_dir, 'train.txt')),
        ('dev', os.path.join(data_dir, 'valid.txt')),
        ('test', os.path.join(data_dir, 'test.txt')),
    )
    # (output file prefix, label used in log messages, loader)
    tasks = (
        ('ner', 'NER', load_conll_ner),
        ('pos', 'POS', load_conll_pos),
        ('chunk', 'chunk', load_conll_chunk),
    )

    # All loading happens before the output folder is touched.
    loaded = []
    for prefix, label, loader in tasks:
        rows_by_split = [(split, loader(path)) for split, path in splits]
        # Every count is taken from the split it reports and carries the
        # label of the task that produced it.
        for split, rows in rows_by_split:
            logger.info(
                'Loaded {} {} {} samples'.format(len(rows), label, split))
        loaded.append((prefix, label, rows_by_split))

    bert_root = args.output_dir
    if not os.path.isdir(bert_root):
        os.mkdir(bert_root)

    for prefix, label, rows_by_split in loaded:
        for split, rows in rows_by_split:
            fout = os.path.join(bert_root, '{}_{}.tsv'.format(prefix, split))
            dump_rows(rows, fout, DataFormat.Seqence)
        logger.info('done with {}'.format(label))
예제 #9
0
def main(args):
    """Convert the LCQMC, BQ and PAWSX TSV splits into canonical TSV files.

    For every task, ``train.tsv``, ``dev.tsv`` and ``test.tsv`` are read
    from ``<root>/<TASK>/`` with ``load_test`` and written to
    ``<root>/canonical_data/<task>_<split>.tsv`` with ``dump_rows``.
    Train and dev are loaded with ``is_train=True`` and dumped as
    ``DataFormat.SimPair``; test is loaded with ``is_train=False`` and
    dumped as ``DataFormat.SimPairTest``.

    Args:
        args: Namespace-like object providing ``root_dir``, which must
            already exist. ``args.old_glue`` is not read: nothing in this
            function uses it.
    """
    root = args.root_dir
    assert os.path.exists(root)

    # (source folder, also used as the label in "Loaded ..." logs;
    #  output file prefix, also used in "done with ..." logs)
    tasks = (('LCQMC', 'lcqmc'), ('BQ', 'bq'), ('PAWSX', 'pawsx'))
    # (split name, is_train flag handed to load_test, output format)
    splits = (
        ('train', True, DataFormat.SimPair),
        ('dev', True, DataFormat.SimPair),
        ('test', False, DataFormat.SimPairTest),
    )

    ######################################
    # Loading DATA
    ######################################
    loaded = []
    for folder, prefix in tasks:
        rows_by_split = []
        for split, is_train, data_format in splits:
            path = os.path.join(root, '{}/{}.tsv'.format(folder, split))
            rows = load_test(path, is_train=is_train)
            rows_by_split.append((split, rows, data_format))
        for split, rows, _ in rows_by_split:
            logger.info(
                'Loaded {} {} {} samples'.format(len(rows), folder, split))
        loaded.append((prefix, rows_by_split))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    ######################################
    # Writing canonical TSVs
    ######################################
    for prefix, rows_by_split in loaded:
        for split, rows, data_format in rows_by_split:
            fout = os.path.join(
                canonical_data_root, '{}_{}.tsv'.format(prefix, split))
            dump_rows(rows, fout, data_format)
        logger.info('done with {}'.format(prefix))
예제 #10
0
def main(args):
    """Convert the SuperGLUE jsonl splits under ``args.root_dir`` into TSVs.

    For CB, BoolQ, COPA, ReCoRD, WiC and MultiRC the files ``train.jsonl``,
    ``val.jsonl`` and ``test.jsonl`` are read from ``<root>/<TASK>/`` with
    the task's loader and written to
    ``<root>/canonical_data/<task>_<split>.tsv`` with ``dump_rows``, using
    the ``DataFormat`` listed for the task in the table below.

    Args:
        args: Namespace-like object providing ``root_dir``, which must
            already exist.
    """
    root = args.root_dir
    assert os.path.exists(root)

    ######################################
    # SuperGLUE tasks
    ######################################
    one_hypothesis = DataFormat.PremiseAndOneHypothesis
    # (source folder, loader, label in "Loaded ..." logs, output file prefix,
    #  output format, label in "done with ..." logs)
    tasks = (
        ("CB", load_cb, "CB", "cb", one_hypothesis, "CB"),
        ("BoolQ", load_boolq, "BoolQ", "boolq", one_hypothesis, "boolq"),
        # The completion log for COPA names COPA itself, so each task is
        # reported exactly once.
        ("COPA", load_copa_mtdnn, "COPA", "copa",
         DataFormat.PremiseAndMultiHypothesis, "copa"),
        ("ReCoRD", load_record_mtdnn, "Record", "record",
         DataFormat.ClozeChoice, "record"),
        ("WiC", load_wic_mtdnn, "WiC", "wic", one_hypothesis, "WiC"),
        ("MultiRC", load_multirc_mtdnn, "MultiRC", "multirc",
         one_hypothesis, "MultiRC"),
    )
    # (split name in logs and output file names, input file name)
    splits = (
        ("train", "train.jsonl"),
        ("dev", "val.jsonl"),
        ("test", "test.jsonl"),
    )

    ######################################
    # Loading DATA
    ######################################
    loaded = []
    for folder, loader, load_label, prefix, data_format, done_label in tasks:
        rows_by_split = [
            (split, loader(os.path.join(root, "{}/{}".format(folder, fname))))
            for split, fname in splits
        ]
        for split, rows in rows_by_split:
            logger.info("Loaded {} {} {} samples".format(
                len(rows), load_label, split))
        loaded.append((prefix, data_format, done_label, rows_by_split))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    ######################################
    # Writing canonical TSVs
    ######################################
    for prefix, data_format, done_label, rows_by_split in loaded:
        for split, rows in rows_by_split:
            fout = os.path.join(
                canonical_data_root, "{}_{}.tsv".format(prefix, split))
            dump_rows(rows, fout, data_format)
        logger.info("done with {}".format(done_label))
예제 #11
0
def main(args):
    """Convert the SciTail and SNLI TSV splits into canonical TSV files.

    SciTail is read from
    ``<root>/SciTail/tsv_format/scitail_1.0_<split>.tsv`` with
    ``load_scitail`` and SNLI from ``<root>/SNLI/<split>.tsv`` with
    ``load_snli``. Every split is written to
    ``<root>/canonical_data/<task>_<split>.tsv`` with ``dump_rows`` using
    ``DataFormat.PremiseAndOneHypothesis``.

    Args:
        args: Namespace-like object providing ``root_dir``, which must
            already exist. ``args.old_glue`` is not read: nothing in this
            function uses it.
    """
    root = args.root_dir
    assert os.path.exists(root)

    ######################################
    # SNLI/SciTail Tasks
    ######################################
    # (label in "Loaded ..." logs, loader, input path template relative to
    #  root, output file prefix also used in "done with ..." logs)
    tasks = (
        ("SciTail", load_scitail,
         "SciTail/tsv_format/scitail_1.0_{}.tsv", "scitail"),
        ("SNLI", load_snli, "SNLI/{}.tsv", "snli"),
    )
    splits = ("train", "dev", "test")

    ######################################
    # Loading DATA
    ######################################
    loaded = []
    for label, loader, path_template, prefix in tasks:
        rows_by_split = [
            (split, loader(os.path.join(root, path_template.format(split))))
            for split in splits
        ]
        for split, rows in rows_by_split:
            logger.info(
                "Loaded {} {} {} samples".format(len(rows), label, split))
        loaded.append((prefix, rows_by_split))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    # BUILD SciTail, then SNLI
    for prefix, rows_by_split in loaded:
        for split, rows in rows_by_split:
            fout = os.path.join(
                canonical_data_root, "{}_{}.tsv".format(prefix, split))
            dump_rows(rows, fout, DataFormat.PremiseAndOneHypothesis)
        logger.info("done with {}".format(prefix))