コード例 #1
0
def _resolve_task_names(tasks_arg):
    """Turn the ``tasks`` CLI value into a list of task names.

    ``"all_tasks"`` selects ``tasks.ALL_TASKS``; anything else is treated as a
    comma-separated list. Surrounding whitespace and empty entries are dropped.
    """
    if tasks_arg == "all_tasks":
        return tasks.ALL_TASKS
    return [name.strip() for name in tasks_arg.split(",") if name.strip()]


def _log_results_to_comet(results, task_names, output_path):
    """Attach the results file at ``output_path`` to one Comet experiment per task.

    A task whose result dict contains ``train_args`` resumes the experiment
    named by ``train_args["previous_experiment"]``. Every other task opens a
    new experiment configured from the COMET_* environment variables.
    """
    api_key = os.environ.get('COMET_API_KEY')
    for task, task_res in results.items():
        # Skip entries that are not tasks, e.g. the "args" key added by main().
        if task not in task_names:
            continue
        if "train_args" in task_res:
            experiment = comet_ml.ExistingExperiment(
                api_key=api_key,
                previous_experiment=task_res["train_args"]["previous_experiment"],
            )
        else:
            experiment = comet_ml.Experiment(
                api_key=api_key,
                project_name=os.environ.get('COMET_PROJECT', "few-shot"),
                workspace=os.environ.get('COMET_WORKSPACE', "yuvalkirstain"),
            )
        experiment.log_asset(output_path)


def main():
    """Evaluate the selected tasks, dump the results as JSON and log them to Comet.

    Options come from ``parse_args()``. If ``args.output_path`` is set and
    already exists, nothing is run so earlier results are never overwritten.
    """
    args = parse_args()

    # Only guard against clobbering when a path was given: os.path.exists(None)
    # raises TypeError, and a falsy output_path is allowed (see below).
    if args.output_path and os.path.exists(args.output_path):
        print(f"Output path {args.output_path} exists!!!")
        return

    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    task_names = _resolve_task_names(args.tasks)
    task_dict = tasks.get_task_dict(task_names)

    lm = models.get_model(args.model)

    train_args = simple_parse_args_string(args.train_args)
    model_args = simple_parse_args_string(args.model_args)

    # When training arguments are supplied they are extended with the model
    # arguments and the run's seed before being handed to the evaluator.
    if train_args:
        train_args.update(model_args)
        train_args["seed"] = args.seed

    results = evaluator.evaluate(lm, task_dict, args.provide_description,
                                 args.num_fewshot, args.limit, train_args,
                                 args.model_args, args.seed)

    results["args"] = args.__dict__
    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)
        # The Comet upload attaches the written file, so it only runs when
        # there is a file to attach.
        _log_results_to_comet(results, task_names, args.output_path)
コード例 #2
0
 def create_from_arg_string(cls, arg_string):
     """Create an instance from an argument string.

     The string is parsed with ``utils.simple_parse_args_string``. Only the
     ``device`` entry is used, and it defaults to ``"cpu"`` when absent.
     """
     parsed = utils.simple_parse_args_string(arg_string)
     device = parsed.get("device", "cpu")
     return cls(device=device)
コード例 #3
0
 def create_from_arg_string(cls, arg_string, additional_config=None):
     """Create an instance from an argument string plus optional extra settings.

     Args:
         arg_string: String parsed by ``utils.simple_parse_args_string`` into
             constructor keyword arguments.
         additional_config: Optional mapping of extra keyword arguments.
             Entries whose value is ``None`` are ignored. The remaining
             entries take precedence over same-named keys parsed from
             ``arg_string``.

     Returns:
         A new ``cls`` instance built from the merged keyword arguments.
     """
     parsed = utils.simple_parse_args_string(arg_string)
     overrides = {
         key: value
         for key, value in (additional_config or {}).items()
         if value is not None
     }
     # Merge into one dict instead of expanding two mappings in the call:
     # ``cls(**a, **b)`` raises "got multiple values for keyword argument"
     # whenever a key appears in both sources.
     return cls(**{**parsed, **overrides})
コード例 #4
0
 def create_from_arg_string(cls, arg_string):
     """Create an instance from an argument string.

     The string is parsed with ``utils.simple_parse_args_string``. Only the
     ``engine`` entry is used, and it defaults to ``"davinci"`` when absent.
     """
     parsed = utils.simple_parse_args_string(arg_string)
     engine = parsed.get("engine", "davinci")
     return cls(engine=engine)
コード例 #5
0
 def create_from_arg_string(cls, arg_string):
     """Create an instance from an argument string.

     The string is parsed with ``utils.simple_parse_args_string``. ``device``
     defaults to ``None`` and ``pretrained`` defaults to ``"t5-small"``.
     """
     parsed = utils.simple_parse_args_string(arg_string)
     device = parsed.get("device", None)
     pretrained = parsed.get("pretrained", "t5-small")
     return cls(device=device, pretrained=pretrained)
コード例 #6
0
 def create_from_arg_string(cls, arg_string):
     """Create an instance from an argument string.

     Every entry produced by ``utils.simple_parse_args_string`` is passed
     through unchanged to the constructor as a keyword argument.
     """
     constructor_kwargs = utils.simple_parse_args_string(arg_string)
     return cls(**constructor_kwargs)