def parse_sweep_config(sweep_config) -> typ.List[typ.List[str]]:
    """
    Find the paths to all the sweep tokens in a sweep config.

    Sweep tokens are identified by the QuinSweep.SWEEP_PREFIX (~ by default):
    any path that has an element containing the prefix is requested from
    get_only_paths, with stop_at set to the prefix (presumably so that each
    returned path ends at the token itself -- confirm against get_only_paths).

    Args:
        sweep_config: the sweep configuration to scan.

    Returns:
        A list with one path per sweep token found. Each path is converted to
        a tuple before being returned.

    Raises:
        AssertionError: if the last element of any returned path is not the
            prefix followed by exactly one of QuinSweep.SWEEP_TOKENS.
    """
    sweep_prefix = QuinSweep.SWEEP_PREFIX

    def touches_sweep_token(path):
        # A path is of interest if any of its elements mentions the prefix.
        return any(sweep_prefix in str(step) for step in path)

    def is_valid_token(token):
        # Remove exactly one leading prefix. str.strip() would also remove
        # trailing and repeated prefix characters, which let malformed keys
        # such as '~~product' or '~product~' pass validation.
        return (isinstance(token, str)
                and token.startswith(sweep_prefix)
                and token[len(sweep_prefix):] in QuinSweep.SWEEP_TOKENS)

    # As output, we get the paths that point to the prefixed tokens in the config
    token_paths = [
        tuple(path)
        for path in get_only_paths(sweep_config,
                                   pred=touches_sweep_token,
                                   stop_at=sweep_prefix)
    ]

    # Confirm that every token (the last element of each path) is correctly
    # specified. The assert is kept so callers still see an AssertionError.
    all_tokens_ok = all(is_valid_token(last(path)) for path in token_paths)
    assert all_tokens_ok, f'Unknown token: sweep config failed parsing. ' \
                          f'Only tokens {QuinSweep.SWEEP_TOKENS} are allowed.'

    return token_paths
def parse_fixed_parameters(sweep_config):
    """
    Build Parameter objects for the non-swept (fixed) entries of a sweep config.

    A path counts as fixed when none of its elements contains the
    QuinSweep.SWEEP_PREFIX. Each such path becomes
    Parameter(path, '<dot-joined path>.0', value), where the value is looked
    up in sweep_config with get_in.

    Args:
        sweep_config: the sweep configuration to scan.

    Returns:
        A list of Parameter objects, one per fixed path reported by
        get_only_paths.
    """

    def is_prefix_free(step):
        return QuinSweep.SWEEP_PREFIX not in str(step)

    def is_fixed_path(path):
        # Every element of the path must be free of the sweep prefix.
        return all(is_prefix_free, path)

    def make_parameter(path):
        dotpath = ".".join(str(step) for step in path)
        return Parameter(path, f'{dotpath}.0', get_in(sweep_config, path))

    fixed_paths = get_only_paths(sweep_config, pred=is_fixed_path)
    return [make_parameter(path) for path in fixed_paths]
def replace_underscores(self, swept_parameter):
    """
    Replace all the underscore references in the sweep of swept_parameter.

    The keys of every mapping-valued sweep token are treated as references.
    Each reference is parsed by QuinSweep.parse_ref_dotpath into a sequence of
    (dotpath, idx) pairs, where idx is an index string or the wildcard '_'.
    For every '_', the condition's value is copied under new keys
    '<full dotpath>.<i>' for each index i of the referred parameter that is
    not already named explicitly by some reference; keys containing '_' are
    then removed.

    Args:
        swept_parameter: the swept parameter whose sweep is rewritten. Its
            .sweep is read as a mapping from prefixed sweep tokens to values
            and its .path is passed through unchanged.

    Returns:
        swept_parameter itself when its sweep makes no references; otherwise
        a new SweptParameter with the same path and the rewritten sweep.

    Raises:
        AssertionError: if an index is repeated for the same dotpath, or if a
            parameter referred to without a prefix reference does not have
            exactly one sweep entry.
    """
    sweep_prefix = QuinSweep.SWEEP_PREFIX

    # Find all the references (i.e. dependencies) made by the swept_parameter
    references = []
    for token in QuinSweep.SWEEP_TOKENS[:-1]:  # omit default since its value is never a dict
        token_key = f"{sweep_prefix}{token}"
        if token_key in swept_parameter.sweep and is_mapping(
                swept_parameter.sweep[token_key]):
            references.extend(swept_parameter.sweep[token_key].keys())

    # Nothing to replace: leave the parameter untouched
    if not references:
        return swept_parameter

    # Parse every reference and expand the partial dotpaths once, up front, so
    # that the bookkeeping below and the lookups in the main loop agree on
    # the same (expanded) dotpaths.
    parsed_references = [
        [(self.expand_partial_dotpath(dotpath), ref_idx)
         for dotpath, ref_idx in QuinSweep.parse_ref_dotpath(reference)]
        for reference in references
    ]

    # Group the indices referred to for each parameter. A plain grouping is
    # used on purpose: merging single-entry dicts with compose(list, cat)
    # flattens index strings into characters (and leaves a lone dict
    # unmerged), which breaks for multi-digit indices such as '11'.
    indices_by_dotpath = {}
    for parsed_ref in parsed_references:
        for full_dotpath, ref_idx in parsed_ref:
            indices_by_dotpath.setdefault(full_dotpath, []).append(ref_idx)

    assert all(len(indices) == len(set(indices))
               for indices in indices_by_dotpath.values()), \
        'All conditions must be distinct.'

    # Indices that are named explicitly (everything except the underscore)
    explicit_indices = {
        dotpath: {int(ref_idx) for ref_idx in indices if ref_idx != '_'}
        for dotpath, indices in indices_by_dotpath.items()
    }

    def compute_possibilities(full_dotpath, reference):
        # Look up the parameter using the dotpath
        parameter = self.swept_parameters_dict[full_dotpath]

        # Use the reference to figure out how many possibilities exist for the underscore
        if reference:
            # Merge all the (mapping-valued) sweeps performed for this parameter
            merged_sweep = merge(
                *[sweep for sweep in parameter.sweep.values()
                  if is_mapping(sweep)])
            # Look up the reference
            return len(merged_sweep[reference])

        assert len(
            parameter.sweep
        ) == 1, 'If no reference, must be a single unconditional sweep.'

        # The number of possibilities is simply the number of values specified
        # in the (product/disjoint) unconditional sweep
        return len(list(parameter.sweep.values())[0])

    def drop_underscore_keys(sweep_body):
        # NOTE(review): '_' in key also matches keys whose parameter name
        # contains an underscore (e.g. 'hidden_size.0') -- confirm that such
        # names cannot occur in references.
        if is_mapping(sweep_body):
            return select_keys(lambda key: '_' not in key, sweep_body)
        return sweep_body

    # Update the sweep by replacing underscores
    updated_sweep = swept_parameter.sweep

    # Loop over all the parsed references
    for parsed_ref in parsed_references:
        # For each parsed reference, there will be multiple (dotpath, idx) pairs
        for i, (full_dotpath, ref_idx) in enumerate(parsed_ref):
            # If the reference index is not an underscore, continue
            if ref_idx != '_':
                continue

            # The prefix reference is everything that precedes this pair,
            # joined back into 'dotpath.idx.dotpath.idx...' form
            prefix_reference = ".".join(cat(parsed_ref[:i]))

            # Compute the number of possible ways to replace the underscore
            n_possibilities = compute_possibilities(full_dotpath,
                                                    prefix_reference)
            # Sorted so the new keys are inserted in a deterministic order
            replacements = sorted(
                set(range(n_possibilities)) - explicit_indices[full_dotpath])

            # Find the path to the underscore condition (the first match is used)
            path_to_condition = get_only_paths(
                updated_sweep,
                lambda path: any('_' in step for step in path),
                stop_at=full_dotpath)[0]

            # Find the value of the underscore condition
            value = tz.get_in(path_to_condition, updated_sweep)

            # Add one key per substitute for the underscore, next to the
            # underscore condition and carrying the same value
            for replacement in replacements:
                key_path = path_to_condition[:-1] + [
                    f'{full_dotpath}.{replacement}'
                ]
                updated_sweep = tz.assoc_in(updated_sweep, key_path, value)

    # Create a new swept parameter with the updated sweep, without the
    # underscore conditions that have now been spelled out
    return SweptParameter(swept_parameter.path,
                          walk_values(drop_underscore_keys, updated_sweep))