import json
import os
import shutil
import zipfile

# download(), get_tasks(), GLUE_TASK2PATH, SUPERGLUE_TASK2PATH, GLUE_READERS,
# SUPERGLUE_READER, and _URL_FILE_STATS are defined elsewhere in this script.


def format_mrpc(data_dir):
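    """Download the raw MSR Paraphrase Corpus files and rebuild the
    GLUE-style train/dev/test TSVs for MRPC under <data_dir>/mrpc.
    """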
    mrpc_dir = os.path.join(data_dir, "mrpc")
    os.makedirs(mrpc_dir, exist_ok=True)
    mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
    mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
    download(GLUE_TASK2PATH["mrpc"]['train'],
             mrpc_train_file,
             sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['train']])
    download(GLUE_TASK2PATH["mrpc"]['test'],
             mrpc_test_file,
             sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['test']])
    assert os.path.isfile(
        mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(
        mrpc_test_file), "Test data not found at %s" % mrpc_test_file
    download(GLUE_TASK2PATH["mrpc"]['dev'],
             os.path.join(mrpc_dir, "dev_ids.tsv"),
             sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['dev']])

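    # dev_ids.tsv lists the (id1, id2) pairs that GLUE assigns to the dev
    # split; every other pair in the MSR train file stays in train.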
    dev_ids = set()
    with open(os.path.join(mrpc_dir, "dev_ids.tsv"),
              encoding="utf8") as ids_fh:
        for row in ids_fh:
            dev_ids.add(tuple(row.strip().split("\t")))

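    # Split the original MSR train file into GLUE-style train.tsv / dev.tsv.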
    with open(mrpc_train_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "train.tsv"), "w",
                 encoding="utf8") as train_fh, \
            open(os.path.join(mrpc_dir, "dev.tsv"), "w",
                 encoding="utf8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split("\t")
            out_fh = dev_fh if (id1, id2) in dev_ids else train_fh
            out_fh.write("%s\t%s\t%s\t%s\t%s\n" %
                         (label, id1, id2, s1, s2))

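    # Rewrite the MSR test file in GLUE format: drop the gold label column
    # and prepend a running index.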
    with open(mrpc_test_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "test.tsv"), "w",
                 encoding="utf8") as test_fh:
        data_fh.readline()  # Skip the original MSR header line.
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            _, id1, id2, s1, s2 = row.strip().split("\t")
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))


def main(args):
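    """Download the selected benchmark tasks and convert them to parquet."""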
    if args.data_dir is None:
        args.data_dir = args.benchmark
    args.cache_path = os.path.join(args.cache_path, args.benchmark)
    print('Downloading {} to {}. Selected tasks = {}'.format(
        args.benchmark, args.data_dir, args.tasks))
    os.makedirs(args.cache_path, exist_ok=True)
    os.makedirs(args.data_dir, exist_ok=True)
    tasks = get_tasks(args.benchmark, args.tasks)
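    # Select the URL table and reader registry for the chosen benchmark.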
    if args.benchmark == 'glue':
        TASK2PATH = GLUE_TASK2PATH
        TASK2READER = GLUE_READERS
    elif args.benchmark == 'superglue':
        TASK2PATH = SUPERGLUE_TASK2PATH
        TASK2READER = SUPERGLUE_READER
    else:
        raise NotImplementedError
    for task in tasks:
        print('Processing {}...'.format(task))
        if 'diagnostic' in task:
            if args.benchmark == 'glue':
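                # GLUE ships a single diagnostic set; both the filtered and
                # the full TSV are stored under rte_diagnostic/.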
                reader = TASK2READER[task]
                base_dir = os.path.join(args.data_dir, 'rte_diagnostic')
                os.makedirs(base_dir, exist_ok=True)
                download(TASK2PATH['diagnostic'][0],
                         path=os.path.join(base_dir, 'diagnostic.tsv'),
                         sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][0]])
                download(TASK2PATH['diagnostic'][1],
                         path=os.path.join(base_dir, 'diagnostic-full.tsv'),
                         sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][1]])
                df = reader(base_dir)
                df.to_parquet(os.path.join(base_dir,
                                           'diagnostic-full.parquet'))
            else:
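                # SuperGLUE has two diagnostic sets: AX-b (broad-coverage)
                # and AX-g (Winogender).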
                for key, name in [('broadcoverage-diagnostic', 'AX-b'),
                                  ('winogender-diagnostic', 'AX-g')]:
                    data_file = os.path.join(args.cache_path,
                                             "{}.zip".format(key))
                    url = TASK2PATH[key]
                    reader = TASK2READER[key]
                    download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
                    with zipfile.ZipFile(data_file) as zipdata:
                        zipdata.extractall(args.data_dir)
                    df = reader(os.path.join(args.data_dir, name))
                    df.to_parquet(
                        os.path.join(args.data_dir, name,
                                     '{}.parquet'.format(name)))
        elif task == 'mrpc':
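            # MRPC is distributed as raw MSR files, so rebuild the GLUE
            # train/dev/test split first (see format_mrpc above).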
            reader = TASK2READER[task]
            format_mrpc(args.data_dir)
            df_dict, meta_data = reader(os.path.join(args.data_dir, 'mrpc'))
            for key, df in df_dict.items():
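                # The reader labels this split 'val'; GLUE files call it 'dev'.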
                if key == 'val':
                    key = 'dev'
                df.to_parquet(
                    os.path.join(args.data_dir, 'mrpc',
                                 '{}.parquet'.format(key)))
            with open(os.path.join(args.data_dir, 'mrpc', 'metadata.json'),
                      'w') as f:
                json.dump(meta_data, f)
        else:
            base_dir = os.path.join(args.data_dir, task)
            if os.path.exists(base_dir):
                print('Found!')
                continue
            # Download and extract the task archive.
            data_file = os.path.join(args.cache_path, "{}.zip".format(task))
            url = TASK2PATH[task]
            reader = TASK2READER[task]
            download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
            with zipfile.ZipFile(data_file) as zipdata:
                # The archive's top-level directory name varies by task, so
                # read it from the first entry, extract, and then move the
                # tree to the uniform <data_dir>/<task> location.
                zip_dir_name = os.path.dirname(
                    zipdata.infolist()[0].filename)
                zipdata.extractall(args.data_dir)
            shutil.move(os.path.join(args.data_dir, zip_dir_name), base_dir)
            df_dict, meta_data = reader(base_dir)
            for key, df in df_dict.items():
                if key == 'val':
                    key = 'dev'
                df.to_parquet(os.path.join(base_dir, '{}.parquet'.format(key)))
            if meta_data is not None:
                with open(os.path.join(base_dir, 'metadata.json'), 'w') as f:
                    json.dump(meta_data, f)
        print("\tCompleted!")