Exemplo n.º 1
0
 def resolve_serialize_refs(self, root):
     """Replace repeated occurrences of the same object by references.

     Every ``Serializable`` node in ``root`` first gets a
     ``resolved_serialize_params`` attribute initialized from its
     ``serialize_params``. Then, for each ``Serializable`` node found at
     ``path_to``, every other path at which the very same object (identity,
     ``is``) occurs gets its entry in the parent's
     ``resolved_serialize_params`` replaced by a ``tree_tools.Ref`` pointing
     to ``path_to``.

     Args:
       root: root of the object tree to prepare for serialization; modified
         in place.
     """
     for _, node in tree_tools.traverse_serializable(root):
         if isinstance(node, Serializable):
             # NOTE(review): if serialize_params returns a stored dict, this
             # aliases it and the Ref insertion below also mutates it — confirm.
             node.resolved_serialize_params = node.serialize_params
     refs_inserted_at = set()
     # NOTE(review): only ``path_from`` values are ever added to this set, so
     # it mirrors ``refs_inserted_at``; its name suggests ``path_to`` may have
     # been intended — confirm before changing.
     refs_inserted_to = set()
     for path_to, node in tree_tools.traverse_serializable(root):
         # Skip paths whose ancestors() intersect a location that has already
         # been replaced by a Ref.
         if refs_inserted_at & path_to.ancestors():
             continue
         if not isinstance(node, Serializable):
             continue
         for path_from, matching_node in tree_tools.traverse_serializable(root):
             if path_from in refs_inserted_to:
                 continue
             if path_from != path_to and matching_node is node:
                 ref = tree_tools.Ref(path=path_to)
                 ref.resolved_serialize_params = ref.serialize_params
                 # The Ref is written into the parent's resolved params, not
                 # into the parent's attribute itself.
                 ref_location = (path_from.parent()
                                 .append("resolved_serialize_params")
                                 .append(path_from[-1]))
                 tree_tools.set_descendant(root, ref_location, ref)
                 refs_inserted_at.add(path_from)
                 refs_inserted_to.add(path_from)
Exemplo n.º 2
0
 def init_components_bottom_up(self, root):
     """Initialize all Serializable components, children before parents.

     Walks ``root`` in ``tree_tools.TraversalOrder.ROOT_LAST`` order and
     replaces each Serializable node by the result of ``self.init_component``.
     A ``tree_tools.Ref`` node is resolved via ``self.named_paths`` and the
     component at the resolved path is initialized instead; if resolution or
     initialization raises ``tree_tools.PathError``, the node is replaced by
     ``None``.

     Args:
       root: the object tree to initialize; modified in place.

     Returns:
       The (possibly replaced) root of the tree. The root node itself is
       replaced when it is Serializable.
     """
     for path, node in tree_tools.traverse_tree_deep_once(
             root,
             root,
             tree_tools.TraversalOrder.ROOT_LAST,
             named_paths=self.named_paths):
         if not isinstance(node, Serializable):
             continue
         if isinstance(node, tree_tools.Ref):
             # Take the cache-hit snapshot before the try, so the comparison
             # below is always defined, even when resolve_path raises.
             hits_before = self.init_component.cache_info().hits
             try:
                 resolved_path = node.resolve_path(self.named_paths)
                 initialized_component = self.init_component(resolved_path)
             except tree_tools.PathError:
                 initialized_component = None
             if self.init_component.cache_info().hits > hits_before:
                 logger.debug(
                     f"for {path}: reusing previously initialized {initialized_component}"
                 )
         else:
             initialized_component = self.init_component(path)
         if len(path) == 0:
             root = initialized_component
         else:
             tree_tools.set_descendant(root, path, initialized_component)
     return root
Exemplo n.º 3
0
 def format_strings(self, exp_values, format_dict):
     """Substitute ``{EXP}``-style placeholders throughout ``exp_values``.

     - Every string in the tree whose ``str.format(**format_dict)`` result
       differs from the original is replaced by
       ``FormatString(formatted, original)``.
     - For every ``Serializable`` node, each ``__init__`` argument that is not
       among the node's children but has a string default is formatted the
       same way; if that changes it, the ``FormatString`` is set on the node.

     Args:
       exp_values: tree to process; modified in place.
       format_dict: keyword arguments passed to ``str.format``.
     """

     def try_format(template):
         # Strings that are not meant as templates (e.g. a vocab entry holding
         # a curly bracket, or a literal "{}") make str.format raise; such
         # strings are kept unchanged.
         try:
             return template.format(**format_dict)
         except (ValueError, KeyError, IndexError):
             return template

     for path, node in tree_tools.traverse_tree(exp_values):
         if isinstance(node, str):
             formatted = try_format(node)
             if formatted != node:
                 tree_tools.set_descendant(exp_values, path,
                                           FormatString(formatted, node))
         elif isinstance(node, Serializable):
             init_args_defaults = tree_tools.get_init_args_defaults(node)
             # Computed once per node: the setattr below only adds the argument
             # currently being processed, which is never looked up again.
             present_args = {
                 child[0] for child in tree_tools.name_children(
                     node, include_reserved=False)
             }
             for expected_arg in init_args_defaults:
                 if expected_arg in present_args:
                     continue
                 arg_default = init_args_defaults[expected_arg].default
                 if not isinstance(arg_default, str):
                     continue
                 formatted = try_format(arg_default)
                 if formatted != arg_default:
                     setattr(node, expected_arg,
                             FormatString(formatted, arg_default))
Exemplo n.º 4
0
    def load_referenced_model(self, experiment):
        """Complete ``experiment`` from an external YAML file, if requested.

        Does nothing unless ``experiment`` has a ``load`` or ``overwrite``
        attribute. In that case its only permitted arguments are ``load``
        (required) and, optionally, ``overwrite``. The file named by
        ``experiment.load`` is parsed, and every child of the loaded object
        that ``experiment`` does not already have is copied onto it. Then each
        entry of ``experiment.overwrite`` (indexed with ``"path"`` and
        ``"val"``) is written into ``experiment`` at that path, and the
        ``overwrite`` attribute is removed.

        Raises:
          ValueError: if arguments other than load/overwrite are present, or
            ``load`` is missing.
          RuntimeError: if the referenced file cannot be read.
        """
        if not (hasattr(experiment, "load") or hasattr(experiment, "overwrite")):
            return
        exp_args = {
            child[0] for child in tree_tools.name_children(
                experiment, include_reserved=False)
        }
        if exp_args not in ({"load"}, {"load", "overwrite"}):
            raise ValueError(
                f"When loading a model from an external YAML file, only 'load' and 'overwrite' are permitted ('load' is required) as arguments to the experiment. Found: {exp_args}"
            )
        try:
            with open(experiment.load) as stream:
                # SECURITY NOTE(review): yaml.load without a safe Loader can
                # construct arbitrary tagged objects; only load trusted files.
                # Newer PyYAML releases also require an explicit Loader here.
                saved_obj = yaml.load(stream)
        except OSError as e:
            raise RuntimeError(
                f"Could not read configuration file {experiment.load}: {e}"
            ) from e
        for saved_key, saved_val in tree_tools.name_children(
                saved_obj, include_reserved=True):
            # Arguments given explicitly on the experiment take precedence.
            if not hasattr(experiment, saved_key):
                setattr(experiment, saved_key, saved_val)

        if hasattr(experiment, "overwrite"):
            for d in experiment.overwrite:
                path = tree_tools.Path(d["path"])
                tree_tools.set_descendant(experiment, path, d["val"])
            delattr(experiment, "overwrite")
Exemplo n.º 5
0
 def instantiate_random_search(self, exp_values):
     """Replace every ``RandomParam`` in ``exp_values`` by a drawn value.

     RandomParams that carry the same ``_xnmt_id`` are drawn only once and
     all receive that one value; those without an ``_xnmt_id`` are drawn
     independently.

     Args:
       exp_values: tree to process; modified in place.

     Returns:
       dict mapping each replaced path to the value placed there.
     """
     param_report = {}
     drawn_by_id = {}
     for path, node in tree_tools.traverse_tree(exp_values):
         if not isinstance(node, RandomParam):
             continue
         if hasattr(node, "_xnmt_id"):
             # Draw once per id, so every RandomParam sharing the id gets the
             # same value.
             if node._xnmt_id not in drawn_by_id:
                 drawn_by_id[node._xnmt_id] = node.draw_value()
             value = drawn_by_id[node._xnmt_id]
         else:
             value = node.draw_value()
         tree_tools.set_descendant(exp_values, path, value)
         param_report[path] = value
     return param_report
Exemplo n.º 6
0
 def share_init_params_top_down(self, root):
     """Propagate the values of shared ``__init__`` parameters in the tree.

     Each Serializable node declares, via ``shared_params()``, sets of
     parameter paths (made absolute with ``get_absolute``) that must hold the
     same value. Sets that have any path in common are merged into one group.
     For each group, values are collected from the paths that exist (missing
     paths are ignored):

     - if the collected values are not all equal, a warning is logged and the
       group is left untouched;
     - if there is exactly one distinct value, it is written to every path of
       the group whose last component is an ``__init__`` argument of the
       object at the path's parent.

     Args:
       root: the object tree; modified in place.

     Raises:
       ValueError: if a shared value contains a Serializable sub-object.
     """
     param_groups = []
     for path, node in tree_tools.traverse_tree(root):
         if not isinstance(node, Serializable):
             continue
         for shared_param_set in node.shared_params():
             abs_set = {p.get_absolute(path) for p in shared_param_set}
             overlapping = [group for group in param_groups if group & abs_set]
             if not overlapping:
                 param_groups.append(abs_set)
                 continue
             # Union into the first overlapping group. Any other overlapping
             # groups are now connected through abs_set and must be folded in
             # too, otherwise a path could end up in two groups.
             target = overlapping[0]
             target |= abs_set
             redundant = overlapping[1:]
             for other in redundant:
                 target |= other
             if redundant:
                 param_groups = [
                     group for group in param_groups
                     if all(group is not r for r in redundant)
                 ]
     for shared_param_set in param_groups:
         # A list deduplicated with ==, so unhashable values are supported.
         shared_val_choices = []
         for shared_param_path in shared_param_set:
             try:
                 new_shared_val = tree_tools.get_descendant(
                     root, shared_param_path)
             except tree_tools.PathError:
                 continue
             for _, child_of_shared_param in tree_tools.traverse_tree(
                     new_shared_val, include_root=False):
                 if isinstance(child_of_shared_param, Serializable):
                     raise ValueError(
                         f"{shared_param_path} shared params {shared_param_set} contains Serializable sub-object {child_of_shared_param} which is not permitted"
                     )
             if new_shared_val not in shared_val_choices:
                 shared_val_choices.append(new_shared_val)
         if len(shared_val_choices) > 1:
             logger.warning(
                 f"inconsistent shared params for {shared_param_set}: {shared_val_choices}; Ignoring these shared parameters."
             )
         elif len(shared_val_choices) == 1:
             shared_val = shared_val_choices[0]
             for shared_param_path in shared_param_set:
                 parent = tree_tools.get_descendant(root,
                                                    shared_param_path.parent())
                 if shared_param_path[-1] in tree_tools.get_init_args_defaults(
                         parent):
                     tree_tools.set_descendant(root, shared_param_path,
                                               shared_val)
Exemplo n.º 7
0
 def create_referenced_default_args(self, root):
     """Materialize default arguments that references point at.

     For every ``tree_tools.Ref`` with an explicit path (named paths are
     skipped), the ancestors of the referenced path are visited from shortest
     to longest. Any ancestor missing from ``root`` is created from the
     ``__init__`` default of the matching argument of its Serializable parent.
     The default is deep-copied, so the tree never aliases the function's
     default object. Once an ancestor cannot be created, its deeper
     descendants are skipped.

     Args:
       root: the object tree; modified in place.

     Raises:
       ValueError: if a required reference cannot be satisfied.
     """
     import copy

     for _, node in tree_tools.traverse_tree(root):
         if not isinstance(node, tree_tools.Ref):
             continue
         referenced_path = node.get_path()
         if not referenced_path:
             continue  # skip named paths
         if isinstance(referenced_path, str):
             referenced_path = tree_tools.Path(referenced_path)
         give_up = False
         for ancestor in sorted(referenced_path.ancestors(), key=len):
             try:
                 tree_tools.get_descendant(root, ancestor)
             except tree_tools.PathError:
                 ancestor_parent = tree_tools.get_descendant(
                     root, ancestor.parent())
                 if isinstance(ancestor_parent, Serializable):
                     init_args_defaults = tree_tools.get_init_args_defaults(
                         ancestor_parent)
                     if ancestor[-1] in init_args_defaults:
                         referenced_arg_default = init_args_defaults[
                             ancestor[-1]].default
                     else:
                         referenced_arg_default = inspect.Parameter.empty
                     if referenced_arg_default is inspect.Parameter.empty:
                         if node.is_required():
                             raise ValueError(
                                 f"Reference '{node}' is required but does not exist and has no default arguments"
                             )
                         # The ancestor still does not exist, so looking up the
                         # next (deeper) ancestor's parent would raise PathError.
                         give_up = True
                     else:
                         tree_tools.set_descendant(
                             root, ancestor,
                             copy.deepcopy(referenced_arg_default))
                 else:
                     if node.is_required():
                         raise ValueError(
                             f"Reference '{node}' is required but does not exist"
                         )
                     give_up = True
             if give_up:
                 break