def resolve_serialize_refs(self, root):
    """Replace repeated occurrences of the same Serializable object with Refs for serialization.

    Pass 1: every Serializable node reachable via
    ``tree_tools.traverse_serializable`` gets ``resolved_serialize_params``
    bound to its ``serialize_params``. This is the same object, not a copy.

    Pass 2: for each Serializable node found at ``path_to``, the tree is scanned
    again for the *identical* object (``is``) at any other path ``path_from``.
    The entry for ``path_from`` inside its parent's
    ``resolved_serialize_params`` is then replaced by
    ``tree_tools.Ref(path=path_to)``.

    Args:
      root: root of the tree to process; modified in place, nothing is returned.
    """
    # Pass 1: start every node's resolved params from its serialize params.
    # NOTE(review): this aliases serialize_params (no copy); if set_descendant
    # below mutates that object in place, serialize_params changes too — confirm.
    # for _, node in tree_tools.traverse_serializable_breadth_first(root):
    for _, node in tree_tools.traverse_serializable(root):
        if isinstance(node, Serializable):
            node.resolved_serialize_params = node.serialize_params
    # Paths that have already been replaced by a Ref.
    refs_inserted_at = set()
    refs_inserted_to = set()
    # Pass 2: find identical objects reachable from more than one path.
    # for path_to, node in tree_tools.traverse_serializable_breadth_first(root):
    for path_to, node in tree_tools.traverse_serializable(root):
        # Skip paths at or below an already-replaced location (assumes
        # Path.ancestors() includes the path itself — TODO confirm).
        # NOTE(review): both operands of this `and` are identical; one was
        # presumably meant to test refs_inserted_to. Since both sets always
        # receive the same element below, the result is the same either way.
        if not refs_inserted_at & path_to.ancestors(
        ) and not refs_inserted_at & path_to.ancestors():
            if isinstance(node, Serializable):
                # for path_from, matching_node in tree_tools.traverse_serializable_breadth_first(root):
                for path_from, matching_node in tree_tools.traverse_serializable(
                        root):
                    if not path_from in refs_inserted_to:
                        # Identity (not equality): only the very same object is shared.
                        if path_from != path_to and matching_node is node:
                            ref = tree_tools.Ref(path=path_to)
                            ref.resolved_serialize_params = ref.serialize_params
                            # Write the Ref into the parent's resolved_serialize_params
                            # under the key/index of path_from.
                            tree_tools.set_descendant(
                                root,
                                path_from.parent().append(
                                    "resolved_serialize_params").append(
                                        path_from[-1]), ref)
                            # NOTE(review): both sets record path_from, so they always
                            # hold identical contents; refs_inserted_to may have been
                            # meant to record path_to — confirm intent.
                            refs_inserted_at.add(path_from)
                            refs_inserted_to.add(path_from)
def init_components_bottom_up(self, root):
    """Initialize all Serializable components in ``root``, children before parents.

    The tree is traversed with ``TraversalOrder.ROOT_LAST``, so a node's
    descendants are visited before the node itself. Every Serializable node
    is replaced in the tree by ``self.init_component(path)``. A
    ``tree_tools.Ref`` node is handled differently: it is resolved against
    ``self.named_paths``, and the node is replaced by
    ``self.init_component(resolved_path)``. If resolution or initialization
    raises ``tree_tools.PathError``, the node is replaced by ``None``.

    Args:
      root: root of the tree to initialize; modified in place.

    Returns:
      The new root. If the root itself is Serializable, it is replaced by its
      initialized component, so callers must use the returned value.
    """
    for path, node in tree_tools.traverse_tree_deep_once(
            root, root, tree_tools.TraversalOrder.ROOT_LAST,
            named_paths=self.named_paths):
        if not isinstance(node, Serializable):
            continue
        if isinstance(node, tree_tools.Ref):
            # Snapshot the cache hit counter before anything that can raise.
            # Previously it was taken after resolve_path(), so a PathError
            # there left hits_before unbound and the comparison below raised
            # UnboundLocalError. (init_component exposes cache_info(), so it
            # is presumably functools.lru_cache-wrapped.)
            hits_before = self.init_component.cache_info().hits
            try:
                resolved_path = node.resolve_path(self.named_paths)
                initialized_component = self.init_component(resolved_path)
            except tree_tools.PathError:
                initialized_component = None
            if self.init_component.cache_info().hits > hits_before:
                logger.debug(
                    f"for {path}: reusing previously initialized {initialized_component}"
                )
        else:
            initialized_component = self.init_component(path)
        if len(path) == 0:
            # The root itself was initialized; it cannot be set via a path.
            root = initialized_component
        else:
            tree_tools.set_descendant(root, path, initialized_component)
    return root
def format_strings(self, exp_values, format_dict):
    """Substitute ``format_dict`` placeholders (e.g. ``{EXP}``) throughout ``exp_values``.

    - Every string in the tree whose value changes under
      ``str.format(**format_dict)`` is replaced in place by
      ``FormatString(formatted, original)``.
    - For every Serializable node, each ``__init__`` argument that is not set
      on the node but has a string default is formatted the same way. If the
      result differs, it is set on the node as a ``FormatString``.

    Strings that cannot be formatted are left unchanged. Examples are
    unbalanced braces, unknown names, and positional fields such as ``{}``;
    these can occur e.g. when a vocab entry contains a curly bracket.

    Args:
      exp_values: tree to process; modified in place.
      format_dict: keyword arguments passed to ``str.format``.
    """

    def try_format(template):
        # Return template.format(**format_dict), or template unchanged if it
        # is not a valid format string for these arguments.
        try:
            return template.format(**format_dict)
        except (ValueError, KeyError, IndexError):
            # ValueError: unbalanced "{"/"}"; KeyError: unknown named field;
            # IndexError: positional field like "{}" or "{0}" (no positional
            # args are passed). IndexError used to escape and crash here.
            return template

    for path, node in tree_tools.traverse_tree(exp_values):
        if isinstance(node, str):
            formatted = try_format(node)
            if node != formatted:
                tree_tools.set_descendant(exp_values, path,
                                          FormatString(formatted, node))
        elif isinstance(node, Serializable):
            init_args_defaults = tree_tools.get_init_args_defaults(node)
            # Computed once per node instead of once per argument. setattr
            # below only adds the argument being processed, so later
            # membership checks are unaffected.
            explicitly_set = {
                x[0] for x in tree_tools.name_children(
                    node, include_reserved=False)
            }
            for expected_arg in init_args_defaults:
                if expected_arg in explicitly_set:
                    continue
                arg_default = init_args_defaults[expected_arg].default
                if isinstance(arg_default, str):
                    formatted = try_format(arg_default)
                    if arg_default != formatted:
                        setattr(node, expected_arg,
                                FormatString(formatted, arg_default))
def load_referenced_model(self, experiment):
    """Merge a saved YAML experiment into ``experiment`` when it references one.

    Does nothing unless ``experiment`` has a ``load`` or ``overwrite``
    attribute. In that case ``experiment`` may have no arguments other than
    ``load`` (required) and optionally ``overwrite``. The YAML file named by
    ``experiment.load`` is read, and every child of the saved object that
    ``experiment`` does not already have is copied onto it. Then each entry
    ``d`` of ``experiment.overwrite`` sets ``d["val"]`` at
    ``tree_tools.Path(d["path"])``, and the ``overwrite`` attribute is removed.

    Raises:
      ValueError: if ``experiment`` has arguments other than ``load`` /
        ``overwrite``, or has ``overwrite`` without ``load``.
      RuntimeError: if the referenced file cannot be opened or read.
    """
    if not (hasattr(experiment, "load") or hasattr(experiment, "overwrite")):
        return
    exp_args = {
        x[0] for x in tree_tools.name_children(experiment, include_reserved=False)
    }
    if exp_args not in ({"load"}, {"load", "overwrite"}):
        raise ValueError(
            f"When loading a model from an external YAML file, only 'load' and 'overwrite' are permitted ('load' is required) as arguments to the experiment. Found: {exp_args}"
        )
    try:
        with open(experiment.load) as stream:
            # An explicit Loader is mandatory on PyYAML >= 6 (a bare
            # yaml.load(stream) raises TypeError); yaml.Loader is the old
            # implicit default, which custom tag constructors register with.
            # SECURITY: this loader may construct arbitrary Python objects;
            # only load files from trusted sources.
            saved_obj = yaml.load(stream, Loader=yaml.Loader)
    except IOError as e:
        raise RuntimeError(
            f"Could not read configuration file {experiment.load}: {e}"
        ) from e
    # Values already set on the experiment (at least 'load') take precedence
    # over those from the saved file.
    for saved_key, saved_val in tree_tools.name_children(
            saved_obj, include_reserved=True):
        if not hasattr(experiment, saved_key):
            setattr(experiment, saved_key, saved_val)
    if hasattr(experiment, "overwrite"):
        # The old bare `except:` retried the identical call, which could only
        # re-raise (or hide KeyboardInterrupt); let errors propagate directly.
        for d in experiment.overwrite:
            path = tree_tools.Path(d["path"])
            tree_tools.set_descendant(experiment, path, d["val"])
        delattr(experiment, "overwrite")
def instantiate_random_search(self, exp_values):
    """Replace every ``RandomParam`` in ``exp_values`` with a concrete drawn value.

    RandomParams carrying the same ``_xnmt_id`` share a single draw. The first
    one encountered calls ``draw_value()``, and the others reuse that result.
    RandomParams without ``_xnmt_id`` are each drawn independently.

    Args:
      exp_values: tree to process; modified in place.

    Returns:
      dict mapping each replaced path to the value placed there.
    """
    param_report = {}
    # _xnmt_id -> value already drawn for that id.
    drawn_by_id = {}
    for path, v in tree_tools.traverse_tree(exp_values):
        if not isinstance(v, RandomParam):
            continue
        # The id must be read from the RandomParam itself. The old code read
        # it from the drawn value (which has no _xnmt_id) and called
        # draw_value() on an already-drawn cached value, so sharing never
        # worked.
        if hasattr(v, "_xnmt_id"):
            if v._xnmt_id not in drawn_by_id:
                drawn_by_id[v._xnmt_id] = v.draw_value()
            value = drawn_by_id[v._xnmt_id]
        else:
            value = v.draw_value()
        tree_tools.set_descendant(exp_values, path, value)
        param_report[path] = value
    return param_report
def share_init_params_top_down(self, root):
    """Propagate values of shared ``__init__`` parameters across the tree.

    Each Serializable node declares, via ``shared_params()``, sets of
    parameter paths (relative to the node) that should take the same value.
    These are made absolute and merged into disjoint groups. Any groups that
    share a path are united, transitively.

    For each group, the distinct values currently present at its paths are
    collected; paths that do not exist are skipped.

    - Exactly one distinct value: that value is written to every path in the
      group whose last component is an ``__init__`` argument of the parent
      object.
    - Several distinct values: a warning is logged and the group is left as is.

    Args:
      root: root of the tree; modified in place.

    Raises:
      ValueError: if a shared value contains a Serializable sub-object.
    """
    groups = []
    for path, node in tree_tools.traverse_tree(root):
        if not isinstance(node, Serializable):
            continue
        for shared_param_set in node.shared_params():
            abs_set = set(p.get_absolute(path) for p in shared_param_set)
            # Merge abs_set with *every* overlapping group, not just the first.
            # Otherwise a set bridging two existing groups leaves them
            # separate and overlapping. Groups stay pairwise disjoint, so
            # testing each one against abs_set is sufficient.
            anchor = None
            kept = []
            for group in groups:
                if group & abs_set:
                    if anchor is None:
                        anchor = group
                        anchor |= abs_set
                        kept.append(anchor)
                    else:
                        anchor |= group  # fold the bridged group into anchor
                else:
                    kept.append(group)
            if anchor is None:
                kept.append(abs_set)
            groups = kept

    for group in groups:
        shared_val_choices = set()
        for shared_param_path in group:
            try:
                new_shared_val = tree_tools.get_descendant(root, shared_param_path)
            except tree_tools.PathError:
                continue  # path not present in the tree; nothing to compare
            for _, child_of_shared_param in tree_tools.traverse_tree(
                    new_shared_val, include_root=False):
                if isinstance(child_of_shared_param, Serializable):
                    # Messages no longer mention the stale `path` left over
                    # from the first loop, which was unrelated to this group.
                    raise ValueError(
                        f"shared params {group} contains Serializable sub-object {child_of_shared_param} which is not permitted"
                    )
            shared_val_choices.add(new_shared_val)
        if len(shared_val_choices) > 1:
            logger.warning(
                f"inconsistent shared params for {group}: {shared_val_choices}; Ignoring these shared parameters."
            )
        elif len(shared_val_choices) == 1:
            (shared_val,) = shared_val_choices
            for shared_param_path in group:
                parent = tree_tools.get_descendant(root, shared_param_path.parent())
                # Only fill in names that are real __init__ arguments of the parent.
                if shared_param_path[-1] in tree_tools.get_init_args_defaults(parent):
                    tree_tools.set_descendant(root, shared_param_path, shared_val)
def create_referenced_default_args(self, root):
    """Materialize ``__init__`` default arguments that Refs point into.

    A Ref may point at an argument that was never set explicitly and so exists
    only as an ``__init__`` default. For every ``tree_tools.Ref`` with a
    concrete path (named-path Refs are skipped), the referenced path's
    ancestors are walked from shortest to longest. Each ancestor that is
    missing from the tree is handled as follows:

    - If its parent is Serializable and declares a default for it, a deep copy
      of that default is inserted.
    - Otherwise the walk stops for that Ref, because no deeper path can exist.
      A required Ref raises an error in this case.

    Args:
      root: root of the tree; modified in place.

    Raises:
      ValueError: if a required Ref (``node.is_required()``) points at
        something that does not exist, or exists only without a default.
    """
    import copy

    for _, node in tree_tools.traverse_tree(root):
        if not isinstance(node, tree_tools.Ref):
            continue
        referenced_path = node.get_path()
        if not referenced_path:
            continue  # skip named paths
        if isinstance(referenced_path, str):
            referenced_path = tree_tools.Path(referenced_path)
        for ancestor in sorted(referenced_path.ancestors(), key=len):
            try:
                tree_tools.get_descendant(root, ancestor)
            except tree_tools.PathError:
                pass
            else:
                continue  # this prefix already exists in the tree
            # Shorter ancestors were processed first, so the parent exists here.
            ancestor_parent = tree_tools.get_descendant(root, ancestor.parent())
            if not isinstance(ancestor_parent, Serializable):
                if node.is_required():
                    raise ValueError(
                        f"Reference '{node}' is required but does not exist"
                    )
                break
            init_args_defaults = tree_tools.get_init_args_defaults(ancestor_parent)
            if ancestor[-1] in init_args_defaults:
                referenced_arg_default = init_args_defaults[ancestor[-1]].default
            else:
                referenced_arg_default = inspect.Parameter.empty
            # Identity check on the sentinel. `==` could misbehave for defaults
            # with custom equality (e.g. array-like values).
            if referenced_arg_default is inspect.Parameter.empty:
                if node.is_required():
                    raise ValueError(
                        f"Reference '{node}' is required but does not exist and has no default arguments"
                    )
                # Nothing was inserted, so the next (deeper) ancestor's parent is
                # missing. Continuing used to crash with an uncaught PathError.
                break
            # Deep-copy so the tree does not alias the function's default
            # object; later in-place edits would otherwise leak into the
            # shared default.
            tree_tools.set_descendant(root, ancestor,
                                      copy.deepcopy(referenced_arg_default))