示例#1
0
    def load_referenced_model(self, experiment):
        """Merge a saved experiment from an external YAML file into ``experiment``.

        Nothing happens unless ``experiment`` has a ``load`` or ``overwrite``
        attribute. Otherwise the file named by ``experiment.load`` is parsed,
        every child of the loaded object that ``experiment`` does not already
        have is copied onto it, and then each entry of ``experiment.overwrite``
        (items providing ``"path"`` and ``"val"``) is applied with
        ``tree_tools.set_descendant``. The ``overwrite`` attribute is removed
        afterwards.

        Args:
          experiment: object that is updated in place.

        Raises:
          ValueError: if the non-reserved arguments of ``experiment`` are
            anything other than ``load`` alone or ``load`` plus ``overwrite``.
          RuntimeError: if the file cannot be read (wraps the ``IOError``).
        """
        if not (hasattr(experiment, "load") or hasattr(experiment, "overwrite")):
            return

        exp_args = {
            name for name, _ in tree_tools.name_children(
                experiment, include_reserved=False)
        }
        if exp_args not in ({"load"}, {"load", "overwrite"}):
            raise ValueError(
                f"When loading a model from an external YAML file, only 'load' and 'overwrite' are permitted ('load' is required) as arguments to the experiment. Found: {exp_args}"
            )

        try:
            with open(experiment.load) as stream:
                # NOTE(security): yaml.load with the full Loader can construct
                # arbitrary Python objects, so only trusted files must be
                # loaded here. The loader is named explicitly because
                # PyYAML >= 6 rejects yaml.load() calls without one.
                saved_obj = yaml.load(stream, Loader=yaml.Loader)
        except IOError as e:
            raise RuntimeError(
                f"Could not read configuration file {experiment.load}: {e}"
            ) from e

        # Values already present on the experiment take precedence over the
        # ones stored in the file.
        for saved_key, saved_val in tree_tools.name_children(
                saved_obj, include_reserved=True):
            if not hasattr(experiment, saved_key):
                setattr(experiment, saved_key, saved_val)

        if hasattr(experiment, "overwrite"):
            for d in experiment.overwrite:
                path = tree_tools.Path(d["path"])
                # A failure here propagates directly; blindly retrying the
                # identical call (as a bare ``except`` used to do) cannot help.
                tree_tools.set_descendant(experiment, path, d["val"])
            delattr(experiment, "overwrite")
示例#2
0
 def init_component(self, path):
     """Instantiate the object found at ``path`` in ``self.deserialized_yaml``.

     Objects that are not ``Serializable`` are returned unchanged. For a
     ``Serializable``, its children (as reported by
     ``tree_tools.name_children`` without reserved names) are passed as
     keyword arguments to the object's class. ``yaml_path`` is passed as well
     when it is among the init arguments reported by
     ``tree_tools.get_init_args_defaults``. The new object receives a
     ``serialize_params`` attribute: the original children, updated with any
     ``serialize_params`` the constructor itself may have set.

     Args:
       path: path to uninitialized object
     Returns:
       initialized object. The original documentation states that this
       method is cached, so multiple requests for the same path return the
       exact same object; the caching mechanism is not part of this block.
     Raises:
       ComponentInitError: if invoking the constructor raises ``TypeError``.
     """
     obj = tree_tools.get_descendant(self.deserialized_yaml, path)
     if not isinstance(obj, Serializable):
         return obj

     init_params = OrderedDict(
         tree_tools.name_children(obj, include_reserved=False))
     # Copied before ``yaml_path`` is injected, so the injected path does not
     # end up in the serialized parameters.
     serialize_params = OrderedDict(init_params)
     init_args = tree_tools.get_init_args_defaults(obj)
     if "yaml_path" in init_args:
         init_params["yaml_path"] = path
     self.check_init_param_types(obj, init_params)

     # Only the constructor call is guarded, so a TypeError raised while
     # logging is not misreported as a constructor failure.
     try:
         initialized_obj = obj.__class__(**init_params)
     except TypeError as e:
         raise ComponentInitError(
             f"An error occurred trying to invoke {type(obj).__name__}.__init__()\n"
             f" The following arguments were passed: {init_params}\n"
             f" The following arguments were expected: {init_args.keys()}\n"
             f" Error message: {e}") from e

     # The message is truncated to keep huge parameter dumps out of the log.
     logger.debug(
         f"initialized {path}: {obj.__class__.__name__}@{id(obj)}({dict(init_params)})"[:1000]
     )
     serialize_params.update(
         getattr(initialized_obj, "serialize_params", {}))
     initialized_obj.serialize_params = serialize_params
     return initialized_obj
示例#3
0
 def format_strings(self, exp_values, format_dict):
     """Apply ``format_dict`` to strings in ``exp_values``, in place.

     - replaces strings containing {EXP} and other supported args by a
       ``FormatString(formatted, original)``
     - also checks if there are default arguments for which no arguments are
       set and instantiates them with replaced {EXP} if applicable

     Strings that cannot be formatted (e.g. a vocab entry containing a curly
     bracket) are left untouched.

     Args:
       exp_values: tree that is traversed with ``tree_tools.traverse_tree``
         and modified in place.
       format_dict: keyword arguments handed to ``str.format``.
     """

     def format_or_keep(template):
         # ``str.format(**kwargs)`` raises ValueError for malformed braces,
         # KeyError for unknown field names and IndexError for positional
         # fields such as "{}" or "{0}" (no positional args are supplied).
         # All of these just mean the string is not a template.
         try:
             return template.format(**format_dict)
         except (ValueError, KeyError, IndexError):
             return template

     for path, node in tree_tools.traverse_tree(exp_values):
         if isinstance(node, str):
             formatted = format_or_keep(node)
             if node != formatted:
                 tree_tools.set_descendant(exp_values, path,
                                           FormatString(formatted, node))
         elif isinstance(node, Serializable):
             init_args_defaults = tree_tools.get_init_args_defaults(node)
             # Computed once per node: the setattr below only ever adds the
             # argument currently being processed, so the answer for the
             # remaining arguments cannot change.
             specified = {
                 name for name, _ in tree_tools.name_children(
                     node, include_reserved=False)
             }
             for expected_arg in init_args_defaults:
                 if expected_arg in specified:
                     continue
                 arg_default = init_args_defaults[expected_arg].default
                 if not isinstance(arg_default, str):
                     continue
                 formatted = format_or_keep(arg_default)
                 if arg_default != formatted:
                     setattr(node, expected_arg,
                             FormatString(formatted, arg_default))
示例#4
0
 def resolve_ref_default_args(self, root):
     """Set ``tree_tools.Ref`` default arguments on nodes that omit them.

     For every ``Serializable`` node under ``root``, each init argument that
     is not among the node's children and whose default value is a
     ``tree_tools.Ref`` is set on the node as an attribute.

     Args:
       root: tree that is traversed with ``tree_tools.traverse_tree`` and
         modified in place.
     """
     for _, node in tree_tools.traverse_tree(root):
         if not isinstance(node, Serializable):
             continue
         init_args_defaults = tree_tools.get_init_args_defaults(node)
         # Computed once per node: the setattr below only ever adds the
         # argument currently being processed, so the answer for the
         # remaining arguments cannot change.
         specified = {
             name for name, _ in tree_tools.name_children(
                 node, include_reserved=False)
         }
         for expected_arg in init_args_defaults:
             if expected_arg in specified:
                 continue
             arg_default = init_args_defaults[expected_arg].default
             if isinstance(arg_default, tree_tools.Ref):
                 setattr(node, expected_arg, arg_default)
示例#5
0
 def get_named_paths(self, root):
     """Collect the paths of all nodes under ``root`` that carry an ``_xnmt_id``.

     Args:
       root: tree that is traversed with ``tree_tools.traverse_tree``.
     Returns:
       dict mapping each ``_xnmt_id`` value to the path of its node.
     Raises:
       ValueError: if the same ``_xnmt_id`` value occurs more than once.
     """
     paths_by_id = {}
     for path, node in tree_tools.traverse_tree(root):
         child_names = [
             name for (name, _) in tree_tools.name_children(
                 node, include_reserved=True)
         ]
         if "_xnmt_id" not in child_names:
             continue
         xnmt_id = tree_tools.get_child(node, "_xnmt_id")
         if xnmt_id in paths_by_id:
             raise ValueError(
                 f"_xnmt_id {xnmt_id} was specified multiple times!")
         paths_by_id[xnmt_id] = path
     return paths_by_id
示例#6
0
 def resolve_bare_default_args(self, root):
     """Set bare ``Serializable`` default arguments on nodes that omit them.

     For every ``Serializable`` node under ``root``, each init argument that
     is not among the node's children and whose default value is a
     ``Serializable`` (but not a ``tree_tools.Ref``) is processed: the
     default is first resolved recursively with this same method and then
     set on the node as an attribute.

     Args:
       root: tree that is traversed with ``tree_tools.traverse_tree`` and
         modified in place.
     Raises:
       ValueError: if such a default lacks a truthy ``_is_bare`` attribute,
         i.e. it is a fully initialized ``Serializable``.
     """
     for path, node in tree_tools.traverse_tree(root):
         if not isinstance(node, Serializable):
             continue
         init_args_defaults = tree_tools.get_init_args_defaults(node)
         # Computed once per node: the setattr below only ever adds the
         # argument currently being processed, so the answer for the
         # remaining arguments cannot change.
         specified = {
             name for name, _ in tree_tools.name_children(
                 node, include_reserved=False)
         }
         for expected_arg in init_args_defaults:
             if expected_arg in specified:
                 continue
             arg_default = init_args_defaults[expected_arg].default
             if (not isinstance(arg_default, Serializable)
                     or isinstance(arg_default, tree_tools.Ref)):
                 continue
             if not getattr(arg_default, "_is_bare", False):
                 raise ValueError(
                     f"only Serializables created via bare(SerializableSubtype) are permitted as default arguments; found a fully initialized Serializable: {arg_default} at {path}"
                 )
             self.resolve_bare_default_args(
                 arg_default)  # apply recursively
             setattr(node, expected_arg, arg_default)