def load_referenced_model(self, experiment):
    """Merge an externally saved specification into ``experiment`` in place.

    Nothing happens unless ``experiment`` has a ``load`` or an ``overwrite``
    attribute. Otherwise the YAML file named by ``experiment.load`` is parsed,
    each of its named children is copied onto ``experiment`` (attributes the
    experiment already has are kept), and the entries of
    ``experiment.overwrite`` are applied on top. The ``overwrite`` attribute
    is removed afterwards; ``load`` is left in place.

    Args:
        experiment: experiment node whose only non-reserved children may be
            ``load`` (required) and ``overwrite`` (optional). Each overwrite
            entry is indexed with the keys ``"path"`` and ``"val"``.

    Raises:
        ValueError: if the experiment's non-reserved children are anything
            other than ``{"load"}`` or ``{"load", "overwrite"}``.
        RuntimeError: if the file named by ``experiment.load`` cannot be read.
    """
    if not (hasattr(experiment, "load") or hasattr(experiment, "overwrite")):
        return

    exp_args = {
        name
        for name, _ in tree_tools.name_children(experiment, include_reserved=False)
    }
    if exp_args not in ({"load"}, {"load", "overwrite"}):
        raise ValueError(
            "When loading a model from an external YAML file, only 'load' and "
            "'overwrite' are permitted ('load' is required) as arguments to the "
            f"experiment. Found: {exp_args}"
        )

    try:
        with open(experiment.load) as stream:
            # NOTE(review): yaml.load can construct arbitrary objects; only
            # point 'load' at trusted files. Left unchanged here because the
            # saved file presumably relies on custom YAML tags - confirm.
            saved_obj = yaml.load(stream)
    except IOError as e:
        raise RuntimeError(
            f"Could not read configuration file {experiment.load}: {e}"
        ) from e

    # Values already present on the experiment take precedence over saved ones.
    for saved_key, saved_val in tree_tools.name_children(
            saved_obj, include_reserved=True):
        if not hasattr(experiment, saved_key):
            setattr(experiment, saved_key, saved_val)

    if hasattr(experiment, "overwrite"):
        for d in experiment.overwrite:
            path = tree_tools.Path(d["path"])
            # A failure here propagates directly; retrying the identical call
            # cannot succeed and would only obscure the original traceback.
            tree_tools.set_descendant(experiment, path, d["val"])
        delattr(experiment, "overwrite")
def init_component(self, path):
    """Instantiate the component stored at ``path`` in the deserialized YAML.

    Args:
        path: path to uninitialized object

    Returns:
        The object found at ``path`` unchanged if it is not a
        ``Serializable``; otherwise a new instance of the same class built
        from the object's non-reserved children. The instance's
        ``serialize_params`` attribute is set to those children, updated with
        any ``serialize_params`` the instance defined itself.

        NOTE(review): the original docstring states that this method is
        cached, so that multiple requests for the same path return the exact
        same object. The caching decorator is not visible in this block -
        confirm it is still applied.

    Raises:
        ComponentInitError: if the class constructor raises ``TypeError``.
    """
    obj = tree_tools.get_descendant(self.deserialized_yaml, path)
    if not isinstance(obj, Serializable):
        return obj

    init_params = OrderedDict(
        tree_tools.name_children(obj, include_reserved=False))
    # Snapshot before 'yaml_path' is injected, so the injected path is not
    # part of the parameters recorded for serialization.
    serialize_params = OrderedDict(init_params)
    init_args = tree_tools.get_init_args_defaults(obj)
    if "yaml_path" in init_args:
        init_params["yaml_path"] = path
    self.check_init_param_types(obj, init_params)

    # Only the constructor call is guarded: a TypeError raised while
    # formatting the debug message must not be reported as an __init__ error.
    try:
        initialized_obj = obj.__class__(**init_params)
    except TypeError as e:
        raise ComponentInitError(
            f"An error occurred trying to invoke {type(obj).__name__}.__init__()\n"
            f" The following arguments were passed: {init_params}\n"
            f" The following arguments were expected: {init_args.keys()}\n"
            f" Error message: {e}") from e

    logger.debug(
        f"initialized {path}: {obj.__class__.__name__}@{id(obj)}({dict(init_params)})"[:1000]
    )
    serialize_params.update(
        getattr(initialized_obj, "serialize_params", {}))
    initialized_obj.serialize_params = serialize_params
    return initialized_obj
def format_strings(self, exp_values, format_dict):
    """Substitute placeholders such as ``{EXP}`` in the experiment tree.

    - replaces strings containing {EXP} and other supported args
    - also checks if there are default arguments for which no arguments are
      set and instantiates them with replaced {EXP} if applicable

    A string is only replaced when formatting actually changes it; the
    replacement is a ``FormatString(formatted, original)``. Strings that are
    not valid templates for ``format_dict`` are left untouched.

    Args:
        exp_values: root of the tree to process; modified in place.
        format_dict: mapping of placeholder names to their values, passed to
            ``str.format`` as keyword arguments.
    """

    def _format_or_keep(template):
        # Will fail e.g. if a vocab entry contains a curly bracket. Positional
        # fields such as "{}" or "{0}" raise IndexError because only keyword
        # arguments are supplied, so that case is treated the same way.
        try:
            return template.format(**format_dict)
        except (ValueError, KeyError, IndexError):
            return template

    for path, node in tree_tools.traverse_tree(exp_values):
        if isinstance(node, str):
            formatted = _format_or_keep(node)
            if node != formatted:
                tree_tools.set_descendant(exp_values, path,
                                          FormatString(formatted, node))
        elif isinstance(node, Serializable):
            init_args_defaults = tree_tools.get_init_args_defaults(node)
            # Computed once per node instead of once per expected argument.
            specified = {
                name
                for name, _ in tree_tools.name_children(
                    node, include_reserved=False)
            }
            for expected_arg in init_args_defaults:
                if expected_arg in specified:
                    continue
                arg_default = init_args_defaults[expected_arg].default
                if not isinstance(arg_default, str):
                    continue
                formatted = _format_or_keep(arg_default)
                if arg_default != formatted:
                    setattr(node, expected_arg,
                            FormatString(formatted, arg_default))
def resolve_ref_default_args(self, root):
    """Copy ``Ref`` default arguments onto nodes that do not specify them.

    For every ``Serializable`` node under ``root``, each init argument that
    is not among the node's non-reserved children and whose default is a
    ``tree_tools.Ref`` is set on the node as an attribute. Defaults of any
    other type are ignored here.

    Args:
        root: root of the tree to process; nodes are modified in place.
    """
    for _, node in tree_tools.traverse_tree(root):
        if not isinstance(node, Serializable):
            continue
        init_args_defaults = tree_tools.get_init_args_defaults(node)
        # Computed once per node instead of once per expected argument.
        specified = {
            name
            for name, _ in tree_tools.name_children(
                node, include_reserved=False)
        }
        for expected_arg in init_args_defaults:
            if expected_arg in specified:
                continue
            arg_default = init_args_defaults[expected_arg].default
            if isinstance(arg_default, tree_tools.Ref):
                setattr(node, expected_arg, arg_default)
def get_named_paths(self, root):
    """Collect the paths of all nodes under ``root`` that carry an ``_xnmt_id``.

    Args:
        root: root of the tree to search.

    Returns:
        dict mapping each ``_xnmt_id`` value to the path of the node that
        declares it.

    Raises:
        ValueError: if the same ``_xnmt_id`` value is declared by more than
            one node.
    """
    paths_by_id = {}
    for node_path, candidate in tree_tools.traverse_tree(root):
        child_names = [
            label
            for (label, _) in tree_tools.name_children(
                candidate, include_reserved=True)
        ]
        if "_xnmt_id" not in child_names:
            continue
        xnmt_id = tree_tools.get_child(candidate, "_xnmt_id")
        if xnmt_id in paths_by_id:
            raise ValueError(
                f"_xnmt_id {xnmt_id} was specified multiple times!")
        paths_by_id[xnmt_id] = node_path
    return paths_by_id
def resolve_bare_default_args(self, root):
    """Attach bare ``Serializable`` defaults to nodes that do not specify them.

    For every ``Serializable`` node under ``root``, each init argument that
    is not among the node's non-reserved children and whose default is a
    ``Serializable`` (but not a ``tree_tools.Ref``) is resolved recursively
    and set on the node as an attribute.

    Each node receives its own deep copy of the default. The default object
    belongs to the ``__init__`` signature and is therefore shared by every
    node of that class; the recursive call sets attributes on it, so handing
    out the shared object would let changes made through one node leak into
    all the others.

    Args:
        root: root of the tree to process; nodes are modified in place.

    Raises:
        ValueError: if a ``Serializable`` default does not have a truthy
            ``_is_bare`` attribute.
    """
    import copy  # local import: the file's import block is not visible here

    for path, node in tree_tools.traverse_tree(root):
        if not isinstance(node, Serializable):
            continue
        init_args_defaults = tree_tools.get_init_args_defaults(node)
        # Computed once per node instead of once per expected argument.
        specified = {
            name
            for name, _ in tree_tools.name_children(
                node, include_reserved=False)
        }
        for expected_arg in init_args_defaults:
            if expected_arg in specified:
                continue
            arg_default = init_args_defaults[expected_arg].default
            if not isinstance(arg_default, Serializable) or isinstance(
                    arg_default, tree_tools.Ref):
                continue
            if not getattr(arg_default, "_is_bare", False):
                raise ValueError(
                    "only Serializables created via bare(SerializableSubtype) "
                    "are permitted as default arguments; found a fully "
                    f"initialized Serializable: {arg_default} at {path}"
                )
            # Copy before resolving so the signature's default stays pristine.
            arg_default = copy.deepcopy(arg_default)
            self.resolve_bare_default_args(arg_default)  # apply recursively
            setattr(node, expected_arg, arg_default)