Esempio n. 1
0
def main():
    """Write rendered few-shot prompts for the selected tasks to disk.

    For every task named in ``args.tasks`` (or every task in
    ``tasks.ALL_TASKS`` when it is ``"all_tasks"``), documents are gathered
    from the comma-separated splits in ``args.sets`` (``train``, ``val``,
    ``test``). Each document is rendered with ``task.fewshot_context``, and
    the results go to ``<args.output_base_path>/<task_name>``. Every
    rendered example is preceded by ``EXAMPLE_DIVIDER`` formatted with its
    index.

    ``args.num_examples > 0`` caps how many documents are written per task.
    Otherwise all documents are written.

    If ``args.description_dict_path`` is set, it is loaded as JSON and looked
    up by task name to get the ``description`` passed to
    ``fewshot_context``. Tasks without an entry get ``""``.

    Raises:
        ValueError: if ``args.sets`` names a split other than ``train``,
            ``val`` or ``test``.
    """
    args = parse_args()
    np.random.seed(args.seed)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    description_dict = {}
    if args.description_dict_path:
        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        with open(args.description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        # A fresh, identically seeded RNG per task keeps each task's
        # few-shot sampling reproducible regardless of task order.
        rnd = random.Random(args.seed)

        # Collect only the splits the task actually provides. Appending a
        # value for an unavailable split would re-use whatever an earlier
        # split loaded (duplicating examples) or reference an unbound name.
        iters = []
        for split in args.sets.split(","):
            split = split.strip()
            if not split:
                # Tolerate stray commas such as "train,".
                continue
            if split == "train":
                if task.has_training_docs():
                    iters.append(task.training_docs())
            elif split == "val":
                if task.has_validation_docs():
                    iters.append(task.validation_docs())
            elif split == "test":
                if task.has_test_docs():
                    iters.append(task.test_docs())
            else:
                raise ValueError(
                    "Unknown set {!r}; expected 'train', 'val' or 'test'".format(
                        split
                    )
                )

        # NOTE(review): join_iters is defined elsewhere; presumably it
        # chains the per-split iterables in order.
        docs = join_iters(iters)

        description = description_dict.get(task_name, "")

        out_path = os.path.join(args.output_base_path, task_name)
        # Prompt text may contain arbitrary Unicode; write it as UTF-8.
        with open(out_path, "w", encoding="utf-8") as f:
            # Positive num_examples caps the count; otherwise write every doc.
            examples = (
                zip(range(args.num_examples), docs)
                if args.num_examples > 0
                else enumerate(docs)
            )
            for i, doc in examples:
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
                    description=description,
                )
                f.write(ctx + "\n")
Esempio n. 2
0
def main():
    """Write rendered few-shot prompts for the selected tasks to disk.

    For every task named in ``args.tasks`` (or every task in
    ``tasks.ALL_TASKS`` when it is ``"all_tasks"``), documents are gathered
    from the comma-separated splits in ``args.sets`` (``train``, ``val``,
    ``test``). Each document is rendered with ``task.fewshot_context``, and
    the results go to ``<args.output_base_path>/<task_name>``. Every
    rendered example is preceded by ``EXAMPLE_DIVIDER`` formatted with its
    index.

    ``args.num_examples > 0`` caps how many documents are written per task.
    Otherwise all documents are written. ``args.provide_description`` is
    forwarded unchanged to ``fewshot_context``.

    Raises:
        ValueError: if ``args.sets`` names a split other than ``train``,
            ``val`` or ``test``.
    """
    args = parse_args()
    np.random.seed(args.seed)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)
    os.makedirs(args.output_base_path, exist_ok=True)
    for task_name, task in task_dict.items():
        # A fresh, identically seeded RNG per task keeps each task's
        # few-shot sampling reproducible regardless of task order.
        rnd = random.Random(args.seed)

        # Collect only the splits the task actually provides. Appending a
        # value for an unavailable split would re-use whatever an earlier
        # split loaded (duplicating examples) or reference an unbound name.
        iters = []
        for split in args.sets.split(","):
            split = split.strip()
            if not split:
                # Tolerate stray commas such as "train,".
                continue
            if split == "train":
                if task.has_training_docs():
                    iters.append(task.training_docs())
            elif split == "val":
                if task.has_validation_docs():
                    iters.append(task.validation_docs())
            elif split == "test":
                if task.has_test_docs():
                    iters.append(task.test_docs())
            else:
                raise ValueError(
                    "Unknown set {!r}; expected 'train', 'val' or 'test'".format(
                        split
                    )
                )

        # NOTE(review): join_iters is defined elsewhere; presumably it
        # chains the per-split iterables in order.
        docs = join_iters(iters)

        out_path = os.path.join(args.output_base_path, task_name)
        # Prompt text may contain arbitrary Unicode; write it as UTF-8.
        with open(out_path, "w", encoding="utf-8") as f:
            # Positive num_examples caps the count; otherwise write every doc.
            examples = (
                zip(range(args.num_examples), docs)
                if args.num_examples > 0
                else enumerate(docs)
            )
            for i, doc in examples:
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    provide_description=args.provide_description,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
                )
                f.write(ctx + "\n")