Esempio n. 1
0
 def format_strings(self, exp_values, format_dict):
     """
     Substitute placeholders such as {EXP} throughout an experiment tree.

     - Each string node whose text changes under ``str.format(**format_dict)``
       is replaced in ``exp_values`` by a ``FormatString`` built from the
       formatted and the original text.
     - For each ``Serializable`` node, string-valued init-argument defaults
       that are not among the node's named children are formatted the same
       way and, if the text changes, set on the node as an attribute.

     Strings for which formatting raises ``ValueError`` or ``KeyError`` are
     left untouched.
     """

     def substitute(template):
         # formatting fails e.g. if a vocab entry contains a curly bracket;
         # in that case the text is kept as-is
         try:
             return template.format(**format_dict)
         except (ValueError, KeyError):
             return template

     for path, node in tree_tools.traverse_tree(exp_values):
         if isinstance(node, str):
             replaced = substitute(node)
             if node != replaced:
                 tree_tools.set_descendant(exp_values, path,
                                           FormatString(replaced, node))
             continue
         if not isinstance(node, Serializable):
             continue
         defaults = tree_tools.get_init_args_defaults(node)
         for arg_name in defaults:
             present = [
                 child[0] for child in tree_tools.name_children(
                     node, include_reserved=False)
             ]
             if arg_name in present:
                 continue  # explicitly set; handled as a regular tree node
             default_val = defaults[arg_name].default
             if not isinstance(default_val, str):
                 continue
             replaced = substitute(default_val)
             if default_val != replaced:
                 setattr(node, arg_name,
                         FormatString(replaced, default_val))
Esempio n. 2
0
    def parse_experiment(self, filename, exp_name):
        """
        Load one experiment from a YAML configuration file and prepare it.

        The entry ``exp_name`` of the parsed file is post-processed in place:
        kwargs of every Serializable node are resolved, a referenced model is
        loaded, random-search parameters are instantiated, bare default
        arguments are copied into the hierarchy, and placeholder strings
        ({EXP}, {PID}, {EXP_DIR}) are formatted.

        Args:
          filename: path of the YAML configuration file.
          exp_name: key of the experiment inside the configuration file.

        Returns:
          The prepared experiment wrapped in an ``UninitializedYamlObject``.

        Raises:
          RuntimeError: if the configuration file cannot be read; the
            underlying I/O error is attached as the cause.
        """
        try:
            with open(filename) as stream:
                # NOTE(review): yaml.load may construct arbitrary objects and
                # must only be used on trusted configuration files.
                config = yaml.load(stream)
        except IOError as e:
            raise RuntimeError(
                f"Could not read configuration file {filename}: {e}") from e

        experiment = config[exp_name]

        for _, node in tree_tools.traverse_tree(experiment):
            if isinstance(node, Serializable):
                self.resolve_kwargs(node)

        self.load_referenced_model(experiment)

        random_search_report = self.instantiate_random_search(experiment)
        if random_search_report:
            setattr(experiment, 'random_search_report', random_search_report)

        # if arguments were not given in the YAML file and are set to a
        # bare(Serializable) by default, copy the bare object into the object
        # hierarchy so it can be used w/ param sharing etc.
        self.resolve_bare_default_args(experiment)

        self.format_strings(
            experiment, {
                "EXP": exp_name,
                "PID": os.getpid(),
                "EXP_DIR": os.path.dirname(filename)
            })

        return UninitializedYamlObject(experiment)
Esempio n. 3
0
 def share_init_params_top_down(self, root):
     """
     Propagate shared init parameters through the tree rooted at ``root``.

     First, the shared-parameter sets declared by every ``Serializable`` node
     (via ``shared_params()``) are made absolute relative to the node's path
     and grouped: a set overlapping a previously collected group is merged
     into that group. Then, for every group, the values currently reachable
     at the group's paths are collected. If exactly one distinct value
     exists, it is written to every path of the group whose last component
     is an init argument of the object at the parent path. If several
     distinct values exist, a warning is logged and the group is skipped.

     Raises:
       ValueError: if a shared value contains a Serializable sub-object.
     """
     abs_shared_param_sets = []
     for path, node in tree_tools.traverse_tree(root):
         if not isinstance(node, Serializable):
             continue
         for shared_param_set in node.shared_params():
             abs_shared_param_set = set(
                 p.get_absolute(path) for p in shared_param_set)
             # NOTE(review): a new set is merged only into the first
             # overlapping group; groups it bridges are not merged together.
             for prev_set in abs_shared_param_sets:
                 if prev_set & abs_shared_param_set:
                     prev_set |= abs_shared_param_set
                     break
             else:
                 abs_shared_param_sets.append(abs_shared_param_set)

     for shared_param_set in abs_shared_param_sets:
         shared_val_choices = set()
         for shared_param_path in shared_param_set:
             try:
                 new_shared_val = tree_tools.get_descendant(
                     root, shared_param_path)
             except tree_tools.PathError:
                 continue  # nothing set at this path; it contributes no value
             for _, child_of_shared_param in tree_tools.traverse_tree(
                     new_shared_val, include_root=False):
                 if isinstance(child_of_shared_param, Serializable):
                     # report the path actually being inspected, not the
                     # leftover loop variable of the collection pass above
                     raise ValueError(
                         f"{shared_param_path} shared params {shared_param_set} contains Serializable sub-object {child_of_shared_param} which is not permitted"
                     )
             shared_val_choices.add(new_shared_val)
         if len(shared_val_choices) > 1:
             logger.warning(
                 f"inconsistent shared params for {shared_param_set}: {shared_val_choices}; Ignoring these shared parameters."
             )
         elif len(shared_val_choices) == 1:
             shared_val = next(iter(shared_val_choices))
             for shared_param_path in shared_param_set:
                 parent_obj = tree_tools.get_descendant(
                     root, shared_param_path.parent())
                 # only set values that the parent accepts as an init argument
                 if shared_param_path[
                         -1] in tree_tools.get_init_args_defaults(parent_obj):
                     tree_tools.set_descendant(root, shared_param_path,
                                               shared_val)
Esempio n. 4
0
 def resolve_ref_default_args(self, root):
     """
     Attach ``Ref`` defaults of unset init arguments to their nodes.

     For every ``Serializable`` node in the tree, each init argument that is
     not among the node's named children and whose default is a
     ``tree_tools.Ref`` is set on the node as an attribute.
     """
     for _, node in tree_tools.traverse_tree(root):
         if not isinstance(node, Serializable):
             continue
         defaults = tree_tools.get_init_args_defaults(node)
         for arg_name in defaults:
             present = [
                 child[0] for child in tree_tools.name_children(
                     node, include_reserved=False)
             ]
             if arg_name in present:
                 continue
             default_val = defaults[arg_name].default
             if isinstance(default_val, tree_tools.Ref):
                 setattr(node, arg_name, default_val)
Esempio n. 5
0
 def get_named_paths(self, root):
     """
     Map every ``_xnmt_id`` found in the tree to the path of its node.

     Returns:
       dict from ``_xnmt_id`` value to the path of the node that carries it.

     Raises:
       ValueError: if the same ``_xnmt_id`` occurs on more than one node.
     """
     paths_by_id = {}
     for path, node in tree_tools.traverse_tree(root):
         child_names = [
             name for (name, _) in tree_tools.name_children(
                 node, include_reserved=True)
         ]
         if "_xnmt_id" not in child_names:
             continue
         xnmt_id = tree_tools.get_child(node, "_xnmt_id")
         if xnmt_id in paths_by_id:
             raise ValueError(
                 f"_xnmt_id {xnmt_id} was specified multiple times!")
         paths_by_id[xnmt_id] = path
     return paths_by_id
Esempio n. 6
0
 def instantiate_random_search(self, exp_values):
     """
     Replace every ``RandomParam`` in the tree by a concrete drawn value.

     ``RandomParam`` nodes that carry the same ``_xnmt_id`` receive the same
     value: it is drawn once, for the first such node encountered, and reused
     for the others.

     Returns:
       dict mapping the path of each replaced ``RandomParam`` to the value
       that was put in its place (empty if the tree holds no ``RandomParam``).
     """
     param_report = {}
     drawn_by_id = {}
     for path, v in tree_tools.traverse_tree(exp_values):
         if not isinstance(v, RandomParam):
             continue
         # the id must be read from the RandomParam itself, before drawing;
         # the drawn value generally has no _xnmt_id attribute
         has_id = hasattr(v, "_xnmt_id")
         if has_id and v._xnmt_id in drawn_by_id:
             drawn = drawn_by_id[v._xnmt_id]
         else:
             drawn = v.draw_value()
             if has_id:
                 drawn_by_id[v._xnmt_id] = drawn
         tree_tools.set_descendant(exp_values, path, drawn)
         param_report[path] = drawn
     return param_report
Esempio n. 7
0
 def resolve_bare_default_args(self, root):
     """
     Attach bare ``Serializable`` defaults of unset init arguments to nodes.

     For every ``Serializable`` node in the tree, each init argument that is
     not among the node's named children and whose default is a
     ``Serializable`` (but not a ``tree_tools.Ref``) is first resolved
     recursively and then set on the node as an attribute.

     Raises:
       ValueError: if such a default does not have a truthy ``_is_bare``
         attribute.
     """
     for path, node in tree_tools.traverse_tree(root):
         if not isinstance(node, Serializable):
             continue
         defaults = tree_tools.get_init_args_defaults(node)
         for arg_name in defaults:
             present = [
                 child[0] for child in tree_tools.name_children(
                     node, include_reserved=False)
             ]
             if arg_name in present:
                 continue
             arg_default = defaults[arg_name].default
             if not isinstance(arg_default, Serializable) or isinstance(
                     arg_default, tree_tools.Ref):
                 continue
             if not getattr(arg_default, "_is_bare", False):
                 raise ValueError(
                     f"only Serializables created via bare(SerializableSubtype) are permitted as default arguments; found a fully initialized Serializable: {arg_default} at {path}"
                 )
             # the default may itself have unset bare defaults: resolve first
             self.resolve_bare_default_args(arg_default)
             setattr(node, arg_name, arg_default)
Esempio n. 8
0
 def create_referenced_default_args(self, root):
     """
     Create missing objects along referenced paths from init-arg defaults.

     For every ``tree_tools.Ref`` in the tree that has a path, the path's
     ancestors are visited from shortest to longest. Whenever an ancestor
     cannot be retrieved from ``root``, the object at its parent path is
     inspected: if that object is a ``Serializable`` with a default for the
     missing component, the default is inserted at the ancestor's path;
     if it is not a ``Serializable``, the remaining ancestors are skipped.

     Raises:
       ValueError: if a required reference points to something that does
         not exist and cannot be created from a default argument.
     """
     for _, node in tree_tools.traverse_tree(root):
         if not isinstance(node, tree_tools.Ref):
             continue
         target_path = node.get_path()
         if not target_path:
             continue  # skip named paths
         if isinstance(target_path, str):
             target_path = tree_tools.Path(target_path)
         for ancestor in sorted(target_path.ancestors(), key=len):
             try:
                 tree_tools.get_descendant(root, ancestor)
             except tree_tools.PathError:
                 parent_obj = tree_tools.get_descendant(
                     root, ancestor.parent())
                 if not isinstance(parent_obj, Serializable):
                     if node.is_required():
                         raise ValueError(
                             f"Reference '{node}' is required but does not exist"
                         )
                     break  # nothing can be created below this point
                 defaults = tree_tools.get_init_args_defaults(parent_obj)
                 leaf = ancestor[-1]
                 if leaf in defaults:
                     fallback = defaults[leaf].default
                 else:
                     fallback = inspect.Parameter.empty
                 if fallback == inspect.Parameter.empty:
                     if node.is_required():
                         raise ValueError(
                             f"Reference '{node}' is required but does not exist and has no default arguments"
                         )
                 else:
                     tree_tools.set_descendant(root, ancestor, fallback)
Esempio n. 9
0
 def check_args(self, root):
     """
     Validate the arguments of every ``Serializable`` in the tree.

     Each ``Serializable`` node reachable from ``root`` is passed to
     ``tree_tools.check_serializable_args_valid``.
     """
     for _, candidate in tree_tools.traverse_tree(root):
         if not isinstance(candidate, Serializable):
             continue
         tree_tools.check_serializable_args_valid(candidate)