示例#1
0
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Build a Trial out of a resolved experiment specification.

    The spec is passed through ``to_argv`` and parsed with ``parser``; the
    parsed flags and the remaining spec entries are then forwarded to the
    ``Trial`` constructor.

    Arguments:
        spec (dict): A resolved experiment specification. Its entries
            should correspond to the command line flags in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the
            Trial. A "resources" entry is set from
            spec["resources_per_trial"] when that key is present.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If parsing the spec makes the parser raise SystemExit.
    """
    try:
        parsed = parser.parse_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)

    if "resources_per_trial" in spec:
        resources_spec = spec["resources_per_trial"]
        trial_kwargs["resources"] = json_to_resources(resources_spec)

    # "run" is the only spec entry read without a fallback; a spec that
    # lacks it fails here with KeyError.
    trainable = spec["run"]
    trial_config = spec.get("config", {})
    # The trial directory is output_path placed under the parsed local_dir.
    trial_dir = os.path.join(parsed.local_dir, output_path)
    stop_conditions = spec.get("stop", {})

    return Trial(
        trainable_name=trainable,
        config=trial_config,
        local_dir=trial_dir,
        stopping_criterion=stop_conditions,
        # Checkpoint settings come from the parsed flags.
        checkpoint_freq=parsed.checkpoint_freq,
        checkpoint_at_end=parsed.checkpoint_at_end,
        keep_checkpoints_num=parsed.keep_checkpoints_num,
        checkpoint_score_attr=parsed.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # Optional spec entries are passed through unconverted so that a
        # missing key stays None instead of becoming the string "None".
        restore_path=spec.get("restore"),
        upload_dir=parsed.upload_dir,
        trial_name_creator=spec.get("trial_name_creator"),
        loggers=spec.get("loggers"),
        sync_function=spec.get("sync_function"),
        max_failures=parsed.max_failures,
        **trial_kwargs)
示例#2
0
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Build a Trial out of a resolved experiment specification.

    The spec is passed through ``to_argv`` and parsed with ``parser``; the
    parsed flags and the remaining spec entries are then forwarded to the
    ``Trial`` constructor.

    Arguments:
        spec (dict): A resolved experiment specification. Its entries
            should correspond to the command line flags in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from
            make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the
            Trial. A "resources" entry is set from
            spec["resources_per_trial"] when that key is present.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If parsing the spec makes the parser raise SystemExit.
    """
    try:
        parsed = parser.parse_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)

    if "resources_per_trial" in spec:
        resources_spec = spec["resources_per_trial"]
        trial_kwargs["resources"] = json_to_resources(resources_spec)

    # "run" is the only spec entry read without a fallback; a spec that
    # lacks it fails here with KeyError.
    trainable = spec["run"]
    trial_config = spec.get("config", {})
    # The trial directory is output_path placed under the parsed local_dir.
    trial_dir = os.path.join(parsed.local_dir, output_path)
    stop_conditions = spec.get("stop", {})

    return Trial(
        trainable_name=trainable,
        config=trial_config,
        local_dir=trial_dir,
        stopping_criterion=stop_conditions,
        # Checkpoint settings come from the parsed flags.
        checkpoint_freq=parsed.checkpoint_freq,
        checkpoint_at_end=parsed.checkpoint_at_end,
        # Optional spec entries are passed through unconverted so that a
        # missing key stays None instead of becoming the string "None".
        restore_path=spec.get("restore"),
        upload_dir=parsed.upload_dir,
        trial_name_creator=spec.get("trial_name_creator"),
        custom_loggers=spec.get("custom_loggers"),
        sync_function=spec.get("sync_function"),
        max_failures=parsed.max_failures,
        **trial_kwargs)