def testSerialization(self):
    """Round-trips a Resources object through JSON and checks equality."""
    original = Resources(1, 0, 0, 1, custom_resources={"a": 1, "b": 2})
    jsoned = resources_to_json(original)
    new_resource = json_to_resources(jsoned)
    # assertEquals is a deprecated alias removed in Python 3.12;
    # use the canonical assertEqual.
    self.assertEqual(original, new_resource)
def create_trial_from_spec(
    spec: dict, output_path: str, parser: argparse.ArgumentParser, **trial_kwargs
):
    """Creates a Trial object from parsing the spec.

    Args:
        spec: A resolved experiment specification. The args here should
            correspond to the command line flags in
            ray.tune.experiment.config_parser.
        output_path: A specific output path within the local_dir. Typically
            the name of the experiment.
        parser: An argument parser object from make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If the spec cannot be parsed into command line args or
            ``resources_per_trial`` is malformed.
        ValueError: If ``sync_config.syncer`` is neither None, "auto", nor
            a ``Syncer`` instance.
    """
    # NOTE: removed unused `global _cached_pgf` — this variant never
    # reads or writes _cached_pgf.
    spec = spec.copy()
    resources = spec.pop("resources_per_trial", None)

    try:
        args, _ = parser.parse_known_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)

    if resources:
        if isinstance(resources, PlacementGroupFactory):
            trial_kwargs["placement_group_factory"] = resources
        else:
            # This will be converted to a placement group factory in the
            # Trial object constructor.
            try:
                trial_kwargs["resources"] = json_to_resources(resources)
            except (TuneError, ValueError) as exc:
                raise TuneError("Error parsing resources_per_trial",
                                resources) from exc

    remote_checkpoint_dir = spec.get("remote_checkpoint_dir")

    sync_config = spec.get("sync_config", SyncConfig())
    if (sync_config.syncer is None or sync_config.syncer == "auto"
            or isinstance(sync_config.syncer, Syncer)):
        custom_syncer = sync_config.syncer
    else:
        # Typo fix: "implement you own" -> "implement your own". The `f`
        # prefix is kept only where a placeholder exists.
        raise ValueError(
            f"Unknown syncer type passed in SyncConfig: "
            f"{type(sync_config.syncer)}. "
            "Note that custom sync functions and templates have been "
            "deprecated. Instead you can implement your own `Syncer` class. "
            "Please leave a comment on GitHub if you run into any issues "
            "with this: https://github.com/ray-project/ray/issues")

    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(spec["local_dir"], output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        remote_checkpoint_dir=remote_checkpoint_dir,
        custom_syncer=custom_syncer,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=sync_config.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        max_failures=args.max_failures,
        **trial_kwargs,
    )
def create_trial_from_spec(
    spec: dict, output_path: str, parser: argparse.ArgumentParser, **trial_kwargs
):
    """Creates a Trial object from parsing the spec.

    Args:
        spec: A resolved experiment specification. The args here should
            correspond to the command line flags in ray.tune.config_parser.
        output_path: A specific output path within the local_dir. Typically
            the name of the experiment.
        parser: An argument parser object from make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If the spec cannot be parsed into command line args or
            ``resources_per_trial`` is malformed.
    """
    # NOTE: removed unused `global _cached_pgf` — this variant never
    # reads or writes _cached_pgf.
    spec = spec.copy()
    resources = spec.pop("resources_per_trial", None)

    try:
        args, _ = parser.parse_known_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)

    if resources:
        if isinstance(resources, PlacementGroupFactory):
            trial_kwargs["placement_group_factory"] = resources
        else:
            # This will be converted to a placement group factory in the
            # Trial object constructor.
            try:
                trial_kwargs["resources"] = json_to_resources(resources)
            except (TuneError, ValueError) as exc:
                raise TuneError("Error parsing resources_per_trial",
                                resources) from exc

    remote_checkpoint_dir = spec.get("remote_checkpoint_dir")

    sync_config = spec.get("sync_config", SyncConfig())
    if sync_config.syncer is None or isinstance(sync_config.syncer, str):
        # None (auto-detect) or a template string is forwarded as-is.
        sync_function_tpl = sync_config.syncer
    else:
        # A syncer was specified, but not a template, so it is a function.
        # Functions cannot be used for trial checkpointing on remote nodes,
        # so we set the remote checkpoint dir to None to disable this.
        # (The original had `elif not isinstance(..., str)` here, which is
        # always true once the first condition fails, making its final
        # `else` branch unreachable dead code — simplified to plain else.)
        sync_function_tpl = None
        remote_checkpoint_dir = None

    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(spec["local_dir"], output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        remote_checkpoint_dir=remote_checkpoint_dir,
        sync_function_tpl=sync_function_tpl,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=sync_config.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        max_failures=args.max_failures,
        **trial_kwargs,
    )
def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
    """Creates a Trial object from parsing the spec.

    Arguments:
        spec (dict): A resolved experiment specification. The args here
            should correspond to the command line flags in
            ray.tune.config_parser.
        output_path (str): A specific output path within the local_dir.
            Typically the name of the experiment.
        parser (ArgumentParser): An argument parser object from make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If the spec cannot be parsed into command line args or
            ``resources_per_trial`` is malformed.
    """
    # _cached_pgf memoizes PlacementGroupFactory wrappers keyed by the
    # callable, so repeated specs reuse the same factory object.
    global _cached_pgf

    spec = spec.copy()
    resources = spec.pop("resources_per_trial", None)

    try:
        args, _ = parser.parse_known_args(to_argv(spec))
    except SystemExit:
        raise TuneError("Error parsing args, see above message", spec)

    if resources:
        if isinstance(resources, PlacementGroupFactory):
            trial_kwargs["placement_group_factory"] = resources
        elif callable(resources):
            if resources in _cached_pgf:
                trial_kwargs["placement_group_factory"] = _cached_pgf[
                    resources]
            else:
                pgf = PlacementGroupFactory(resources)
                _cached_pgf[resources] = pgf
                trial_kwargs["placement_group_factory"] = pgf
        else:
            # This will be converted to a placement group factory in the
            # Trial object constructor.
            try:
                trial_kwargs["resources"] = json_to_resources(resources)
            except (TuneError, ValueError) as exc:
                raise TuneError("Error parsing resources_per_trial",
                                resources) from exc

    return Trial(
        # Submitting trial via server in py2.7 creates Unicode, which does not
        # convert to string in a straightforward manner.
        trainable_name=spec["run"],
        # json.load leads to str -> unicode in py2.7
        config=spec.get("config", {}),
        local_dir=os.path.join(spec["local_dir"], output_path),
        # json.load leads to str -> unicode in py2.7
        stopping_criterion=spec.get("stop", {}),
        remote_checkpoint_dir=spec.get("remote_checkpoint_dir"),
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_at_end=args.checkpoint_at_end,
        sync_on_checkpoint=args.sync_on_checkpoint,
        keep_checkpoints_num=args.keep_checkpoints_num,
        checkpoint_score_attr=args.checkpoint_score_attr,
        export_formats=spec.get("export_formats", []),
        # str(None) doesn't create None
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        log_to_file=spec.get("log_to_file"),
        # str(None) doesn't create None
        max_failures=args.max_failures,
        **trial_kwargs)