def split_key(keychain, reference):
    key = ''.join(keychain[-1:])
    ref_value = get_value(reference, safe(keychain))
    dp.delete(reference, safe(keychain))
    keychain.pop()
    keychain.extend(key.split(".", maxsplit=1))
    dp.new(reference, safe(keychain), ref_value)
def map(self, from_, to, separator='/', skip_missing=ZATO_NOT_GIVEN, default=ZATO_NOT_GIVEN):
    """ Maps 'from_' into 'to', splitting 'from_' using the 'separator' and applying
    transformation functions along the way.
    """
    if skip_missing == ZATO_NOT_GIVEN:
        skip_missing = self.skip_missing

    if default == ZATO_NOT_GIVEN:
        default = self.default

    # Store for later use, such as in log entries.
    orig_from = from_
    force_func = None
    force_func_name = None
    needs_time_reformat = False
    from_format, to_format = None, None

    # Perform any string substitutions first.
    if self.subs:
        from_ = from_.format(**self.subs)
        to = to.format(**self.subs)

    # Pick at most one processing function.
    for key in self.func_keys:
        if from_.startswith(key):
            from_ = from_.replace('{}:'.format(key), '', 1)
            force_func = self.funcs[key]
            force_func_name = key
            break

    # Perhaps it's a date value that needs to be converted.
    if from_.startswith('time:'):
        needs_time_reformat = True
        from_format, from_ = self._get_time_format(from_)
        to_format, to_ = self._get_time_format(to)

    # Obtain the value.
    value = self.source.get(from_.split(separator)[1:])

    if needs_time_reformat:
        value = self.time_util.reformat(value, from_format, to_format)

    # If the value is missing, either skip it entirely or fall back
    # to the default value.
    if not value:
        if skip_missing:
            return
        else:
            value = default if default != ZATO_NOT_GIVEN else value

    # We have some value, let's process it using the function found above.
    if force_func:
        try:
            value = force_func(value)
        except Exception:
            logger.warn('Error in force_func:`%s` `%s` over `%s` in `%s` -> `%s` e:`%s`',
                force_func_name, force_func, value, orig_from, to, format_exc())
            raise

    dpath_util.new(self.target, to, value)
def _set_if_absent(d, path, value):
    # 'value' is a zero-argument factory, so each created entry gets a fresh object.
    if '*' in path:
        pre, post = path.split('*')
        elem_count = len(du.values(d, f'{pre}*'))
        for i in range(elem_count):
            _set_if_absent(d, f'{pre}{i}{post}', value)
    elif du.search(d, path) == {}:
        du.new(d, path, value())
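# A minimal usage sketch for _set_if_absent (illustrative, not from the original
# source); assumes dpath.util is imported as `du` and that the third argument is
# a zero-argument factory. The wildcard branch enumerates list elements by index.
#
#   import dpath.util as du
#   d = {'rows': [{'a': 1}, {}]}
#   _set_if_absent(d, 'rows/*/a', lambda: 0)
#   # expected: d == {'rows': [{'a': 1}, {'a': 0}]}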
def create(self, user_id, challenge_id, stack, service, flag=True):
    result = {'id': self.ids.encode(user_id, challenge_id)}
    name = stack_name(user_id, challenge_id)
    if name in self.stacks.ls():
        raise InstanceExistsError(
            f'An instance of challenge ID {challenge_id} already exists for user ID {user_id}')
    self.ensure_gateway_up(user_id)
    stack_context = {
        'chad_id': result['id'],
        'chad_docker_registry': self.docker_registry
    }
    stack_template = Template(json.dumps(stack))
    stack = json.loads(stack_template.safe_substitute(**stack_context))
    # Docker Swarm overlay networks do not support multicast.
    dpath.new(stack, 'networks/challenge', {
        'driver': self.network_plugin,
        'external': True,
        'name': f'chad_{user_id}'
    })
    if flag:
        if isinstance(flag, str):
            result['flag'] = flag
        elif flag is True:
            result['flag'] = self.flags.next_flag()
        elif isinstance(flag, int):
            result['flag'] = self.flags.next_flag(flag)
        secret_tmp = tempfile.NamedTemporaryFile('w', prefix='flag', suffix='.txt', encoding='ascii')
        secret_tmp.write(f'{result["flag"]}\n')
        secret_tmp.flush()
        dpath.new(stack, 'secrets/flag/file', secret_tmp.name)
        dpath.merge(stack, {
            'services': {
                service: {
                    'secrets': [{
                        'source': 'flag',
                        'target': 'flag.txt',
                        'mode': 0o440
                    }]
                }
            }
        })
    self.redis.set(f'{name}_last_ping', int(time.time()))
    self.stacks.deploy(name, stack, registry_auth=True)
    if flag:
        secret_tmp.close()
    return result
def unpack(keychain, reference, response):
    res_keys = get_value(response, keychain).keys()
    table = get_value(reference, safe(keychain)).split(".")[0]
    dp.delete(reference, safe(keychain))
    for res_key in res_keys:
        keychain.append(res_key)
        value = "{0}.{1}".format(table, res_key)
        dp.new(reference, safe(keychain), value)
        keychain.pop()
def _dict_from_object(self, keys, obj, auction_index):
    to_patch = {}
    for to_key, from_key in keys.items():
        try:
            value = util.get(obj, from_key.format(auction_index))
        except KeyError:
            continue
        util.new(to_patch, to_key, value)
    return to_patch
def add_submission_with(request, username, id_string):
    import uuid
    import requests
    from django.conf import settings
    from django.template import loader, Context
    from dpath import util as dpath_util
    from dict2xml import dict2xml

    def geopoint_xpaths(username, id_string):
        d = DataDictionary.objects.get(user__username=username, id_string=id_string)
        return [
            e.get_abbreviated_xpath() for e in d.get_survey_elements()
            if e.bind.get(u'type') == u'geopoint'
        ]

    value = request.GET.get('coordinates')
    xpaths = geopoint_xpaths(username, id_string)
    xml_dict = {}
    for path in xpaths:
        dpath_util.new(xml_dict, path, value)

    context = {
        'username': username,
        'id_string': id_string,
        'xml_content': dict2xml(xml_dict)
    }
    instance_xml = loader.get_template("instance_add.xml").render(Context(context))

    url = settings.ENKETO_API_INSTANCE_IFRAME_URL
    return_url = reverse('thank_you_submission',
                         kwargs={"username": username, "id_string": id_string})
    if settings.DEBUG:
        openrosa_url = "https://dev.formhub.org/{}".format(username)
    else:
        openrosa_url = request.build_absolute_uri("/{}".format(username))
    payload = {
        'return_url': return_url,
        'form_id': id_string,
        'server_url': openrosa_url,
        'instance': instance_xml,
        'instance_id': uuid.uuid4().hex
    }
    r = requests.post(url, data=payload,
                      auth=(settings.ENKETO_API_TOKEN, ''), verify=False)
    return HttpResponse(r.text, mimetype='application/json')
def add_submission_with(request, username, id_string):
    """ Returns JSON response with Enketo form url preloaded with coordinates. """

    def geopoint_xpaths(username, id_string):
        """ Returns xpaths with elements of type 'geopoint'. """
        data_dictionary = DataDictionary.objects.get(
            user__username__iexact=username, id_string__iexact=id_string)
        return [
            e.get_abbreviated_xpath() for e in data_dictionary.get_survey_elements()
            if e.bind.get(u'type') == u'geopoint'
        ]

    value = request.GET.get('coordinates')
    xpaths = geopoint_xpaths(username, id_string)
    xml_dict = {}
    for path in xpaths:
        dpath_util.new(xml_dict, path, value)

    context = {
        'username': username,
        'id_string': id_string,
        'xml_content': dict2xml(xml_dict)
    }
    instance_xml = loader.get_template("instance_add.xml").render(context)

    url = settings.ENKETO_API_INSTANCE_IFRAME_URL
    return_url = reverse(
        'thank_you_submission',
        kwargs={"username": username, "id_string": id_string})
    if settings.DEBUG:
        openrosa_url = "https://dev.formhub.org/{}".format(username)
    else:
        openrosa_url = request.build_absolute_uri("/{}".format(username))
    payload = {
        'return_url': return_url,
        'form_id': id_string,
        'server_url': openrosa_url,
        'instance': instance_xml,
        'instance_id': uuid.uuid4().hex
    }
    response = requests.post(
        url, data=payload,
        auth=(settings.ENKETO_API_TOKEN, ''),
        verify=getattr(settings, 'VERIFY_SSL', True))
    return HttpResponse(response.text, content_type='application/json')
def store(keychain, reference, response, records):
    log.debug("Store record")
    ref_values = get_ref(keychain, reference, response, records)
    for ref in ref_values:
        if type(keychain[-1]) is bool:
            res = keychain[-1]
        else:
            res = get_value_store(response, keychain)
        if type(res) is bool:
            res = str(res)
        dp.new(records, ref, res)
        add_primary(list(ref), reference, response, records)
    return records
def set(self, name, doc, value, return_on_missing=False, in_place=True):
    if return_on_missing:
        if not self.get(name, doc):
            return doc

    pointer = self.data[name]
    try:
        return pointer.set(doc, value, in_place)
    except PathNotFoundException:
        # The pointer's parent path does not exist yet; create it with dpath.
        dpath_util.new(doc, '/' + '/'.join(pointer.parts), value)
        return doc
def add_primary(ref, reference, response, records):
    ref_value = get_value(reference, ["uuid"])
    res_value = get_value(response, ["uuid"])
    if not ref_value:
        ref_value = get_value(reference, ["permalink"])
        res_value = get_value(response, ["permalink"])
    for key in ref_value:
        table, attribute = tuple(key.split("."))
        if get_value(records, ref[:-1]):
            if ref[0] in table:
                ref[2] = attribute
                dp.new(records, ref, res_value)
def setval(self, path, val):
    # Sets or creates a key in data with the given value.
    if self.data is None:
        log.debug("Data is empty! Ignoring changes.")
        return
    elif dpath.set(self.data, path, val):
        # dpath.set returns the number of existing paths it changed.
        log.debug("Changed: '%s'.", path)
        if self.alog:
            log.debug("New value is " + str(val))
    else:
        dpath.new(self.data, path, val)
        log.debug("Created '%s'.", path)
        if self.alog:
            log.debug("New value is " + str(val))
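# Illustration of the set-then-new pattern used above (not from the original
# source): dpath.set returns the number of existing paths it changed, so a
# falsy result signals that the key must be created with dpath.new instead.
import dpath.util as dpath

data = {'a': {'b': 1}}
assert dpath.set(data, 'a/b', 2) == 1   # existing path updated in place
assert dpath.set(data, 'a/c', 3) == 0   # no match, nothing written...
dpath.new(data, 'a/c', 3)               # ...so create the key explicitly
assert data == {'a': {'b': 2, 'c': 3}}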
def create_message(xstream_function: str, refno='', ptype='', polref='',
                   prospect=None, risk=None) -> str:
    print('Info: creating xstream message')
    try:
        template = {
            'xmlexecute': {
                'job': {'queue': '1'},
                'parameters': {'yzt': {'Char20.1': xstream_function}},
                'apmdata': {'prospect': {'p.cm': prospect or {'refno': refno}}},
                'apmpolicy': {'p.py': {'polref': polref, 'ptype': ptype}}
            }
        }
        if risk:
            for risk_frame, risk_fields in risk.items():
                dict_xpath.new(template, 'xmlexecute/apmpolicy/{}'.format(risk_frame), risk_fields)
        parsed_template = xmltodict.unparse(template, full_document=False)
    except Exception as e:
        print('Error: error in create_message - {}'.format(e))
        raise
    return parsed_template
def add_submission_with(request, username, id_string):
    import uuid
    import requests
    from django.template import loader, Context
    from dpath import util as dpath_util
    from dict2xml import dict2xml

    def geopoint_xpaths(username, id_string):
        d = DataDictionary.objects.get(
            user__username__iexact=username, id_string__exact=id_string)
        return [e.get_abbreviated_xpath()
                for e in d.get_survey_elements()
                if e.bind.get(u'type') == u'geopoint']

    value = request.GET.get('coordinates')
    xpaths = geopoint_xpaths(username, id_string)
    xml_dict = {}
    for path in xpaths:
        dpath_util.new(xml_dict, path, value)

    context = {'username': username,
               'id_string': id_string,
               'xml_content': dict2xml(xml_dict)}
    instance_xml = loader.get_template("instance_add.xml").render(Context(context))

    url = settings.ENKETO_API_INSTANCE_IFRAME_URL
    return_url = reverse('thank_you_submission',
                         kwargs={"username": username, "id_string": id_string})
    if settings.DEBUG:
        openrosa_url = "https://dev.formhub.org/{}".format(username)
    else:
        openrosa_url = request.build_absolute_uri("/{}".format(username))
    payload = {'return_url': return_url, 'form_id': id_string,
               'server_url': openrosa_url, 'instance': instance_xml,
               'instance_id': uuid.uuid4().hex}
    r = requests.post(url, data=payload,
                      auth=(settings.ENKETO_API_TOKEN, ''), verify=False)
    return HttpResponse(r.text, content_type='application/json')
def save_snp_500_tickers(tickers: list) -> None:
    """Update the YAML market coordinates config with the SNP500 tickers."""
    mkt_class = "equity".upper()
    mkt_type = "single stock".upper()
    market_coordinates = mkt_classes.mkt_data_cfg()

    # Let's load the defaults, then see if there is tsdb YAML to overwrite the base defaults.
    defaults = mkt_coord_defaults.defaults.copy()
    mkt_default_cfg_load = mkt_classes.mkt_defaults_cfg()
    dp.merge(defaults, mkt_default_cfg_load)

    equity_defaults = [
        i for i in dp.search(defaults, '{0}/{1}'.format(mkt_class, mkt_type), yielded=True)
    ].pop()[1]

    for ticker in tickers:
        mkt_asset = ticker
        points_default = [
            i for i in dp.search(market_coordinates,
                                 '{0}/{1}/{2}/points'.format(mkt_class, mkt_type, mkt_asset),
                                 yielded=True)
        ]
        points_default = points_default.pop()[1] if len(points_default) else []
        points = list(set(points_default))
        value = {'points': points}
        value.update(equity_defaults)
        xpath = '{0}/{1}/{2}'.format(mkt_class, mkt_type, mkt_asset)
        dp.new(market_coordinates, xpath, value)

    mkt_data_cfg = {'market_coordinates': market_coordinates, 'defaults': defaults}
    with open(mkt_classes.tsdb_path() + 'market_coord_cfg.YAML', "w") as f:
        yaml.dump(mkt_data_cfg, f)
    # Added the SNP500 tickers to the config.
async def get_gm_borders(gm_data: dict):
    gm_table_info = {}
    for index, region in enumerate(REGIONS):
        # No GM data could be retrieved
        if not gm_data[index]:
            continue
        ladder_teams = gm_data[index]["ladderTeams"]
        rank190: list = deepcopy(ladder_teams)
        # Sometimes, some players don't seem to have MMR
        rank190 = [i for i in rank190 if "mmr" in i]
        # Sort descending
        rank190.sort(key=lambda x: x["mmr"], reverse=True)
        rank190 = rank190[:190]
        # GM somehow doesn't have at least one member
        if not rank190:
            continue
        max_mmr = rank190[0]["mmr"]
        min_mmr = rank190[-1]["mmr"]
        new(gm_table_info, ["201", "6", "0", region, "min_rating"], min_mmr)
        new(gm_table_info, ["201", "6", "0", region, "max_rating"], max_mmr)
    return gm_table_info
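# Side note (a sketch under the assumption that `new` here is dpath's `new`):
# the path may be given as a list of string segments instead of a glob string,
# which sidesteps separator escaping when keys like region names are interpolated.
#
#   table = {}
#   new(table, ["201", "6", "0", "eu", "min_rating"], 4900)
#   # expected: table == {'201': {'6': {'0': {'eu': {'min_rating': 4900}}}}}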
def convert_dataframe_to_dict(df):
    temp_dict = df.to_dict(orient='index', into=OrderedDict)
    out = []
    for county_precinct, cols in temp_dict.items():
        row = OrderedDict()
        row[df.index.names[0]] = county_precinct[0]
        row[df.index.names[1]] = county_precinct[1]
        top_headers = list(set(header[0] for header in cols.keys()))
        sorted_top_header = get_top_headers_sorted_by_std_order(top_headers)
        for top_header in sorted_top_header:
            row[top_header] = OrderedDict()
        for header, val in cols.items():
            if header[1] == '':
                header = [header[0]]
            du.new(row, list(header), val)
        out.append(row)
    return out
def get_config(ctx, self):
    if 'specification' not in self.config:
        self.config['specification'] = SPECIFCATION_LATEST
    if 'source' not in self.config:
        dpath_util.new(self.config, 'unikraft/source', self.core.source)
    if 'version' not in self.config:
        dpath_util.new(self.config, 'unikraft/version', self.core.version)
    for arch in self.architectures.all():
        dpath_util.new(self.config, 'architectures/%s' % arch.name, True)
    for plat in self.platforms.all():
        dpath_util.new(self.config, 'platforms/%s' % plat.name, True)
    # for lib in self.libraries.all():
    #     print(lib)
    return self.config
def add_log(self, ct, tag, val):
    try:
        du.get(self.memory, f'/log/{ct}/{tag}').append(val)
    except KeyError:
        # First entry for this counter/tag: create the list.
        du.new(self.memory, f'/log/{ct}/{tag}', [val])
    return
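# Usage sketch (hypothetical instance `proc` of the surrounding class, not from
# the original source): the try/except implements get-or-create for the list.
#
#   proc.add_log(0, 'reward', 1.0)   # KeyError -> creates /log/0/reward = [1.0]
#   proc.add_log(0, 'reward', 0.5)   # list exists -> appends, giving [1.0, 0.5]
#   # proc.memory == {'log': {'0': {'reward': [1.0, 0.5]}}}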
def __init__(self, var_names):
    self.memory = dict()
    du.new(self.memory, '/memory/var_names', sorted(var_names))
    return
def set_val(self, ct, tag, variable, value):
    if variable not in du.get(self.memory, '/memory/var_names'):
        print(f'ValProcessor: {variable} is not in my memory')
    du.new(self.memory, f'/data/{ct}/{tag}/{variable}', value)
    return
def new(self, file, val):
    # To easily edit files or folders
    filer.new(self.fs, self.formatpath(file), val)
def set(self, to, value):
    """ Sets 'to' to a static 'value'. """
    dpath_util.new(self.target, to, value)
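# Behavior note with a tiny example (not from the original source): dpath's
# `new` auto-creates every missing intermediate dict on the way to `to`.
#
#   import dpath.util as dpath_util
#   target = {}
#   dpath_util.new(target, 'order/customer/id', 42)
#   # target == {'order': {'customer': {'id': 42}}}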
class Mapper(object):
    def __init__(self, source, target=None, time_util=None, skip_missing=True,
                 default=None, *args, **kwargs):
        self.target = target if target is not None else {}
        self.map_type = kwargs.get('msg_type', MSG_MAPPER.DICT_TO_DICT)
        self.skip_ns = kwargs.get('skip_ns', True)
        self.time_util = time_util
        self.skip_missing = skip_missing
        self.default = default
        self.subs = {}
        self.funcs = {
            'int': int,
            'long': long,
            'bool': asbool,
            'dec': Decimal,
            'arrow': arrow_get
        }
        self.func_keys = self.funcs.keys()
        self.times = {}
        self.cache = {}

        if isinstance(source, DictNav):
            self.source = source
        else:
            if self.map_type.startswith('dict-to-'):
                self.source = DictNav(source)

    def set_time(self, name, format):
        self.times[name] = format

    def set_substitution(self, name, value):
        self.subs[name] = value

    def set_func(self, name, func):
        self.funcs[name] = func
        self.func_keys = self.funcs.keys()

    def map(self, from_, to, separator='/', skip_missing=ZATO_NOT_GIVEN, default=ZATO_NOT_GIVEN):
        """ Maps 'from_' into 'to', splitting 'from_' using the 'separator' and
        applying transformation functions along the way.
        """
        if skip_missing == ZATO_NOT_GIVEN:
            skip_missing = self.skip_missing

        if default == ZATO_NOT_GIVEN:
            default = self.default

        # Store for later use, such as in log entries.
        orig_from = from_
        force_func = None
        force_func_name = None
        needs_time_reformat = False
        from_format, to_format = None, None

        # Perform any string substitutions first.
        if self.subs:
            from_ = from_.format(**self.subs)
            to = to.format(**self.subs)

        # Pick at most one processing function.
        for key in self.func_keys:
            if from_.startswith(key):
                from_ = from_.replace('{}:'.format(key), '', 1)
                force_func = self.funcs[key]
                force_func_name = key
                break

        # Perhaps it's a date value that needs to be converted.
        if from_.startswith('time:'):
            needs_time_reformat = True
            from_format, from_ = self._get_time_format(from_)
            to_format, to_ = self._get_time_format(to)

        # Obtain the value.
        value = self.source.get(from_.split(separator)[1:])

        if needs_time_reformat:
            value = self.time_util.reformat(value, from_format, to_format)

        # If the value is missing, either skip it entirely or fall back
        # to the default value.
        if not value:
            if skip_missing:
                return
            else:
                value = default if default != ZATO_NOT_GIVEN else value

        # We have some value, let's process it using the function found above.
        if force_func:
            try:
                value = force_func(value)
            except Exception, e:
                logger.warn('Error in force_func:`%s` `%s` over `%s` in `%s` -> `%s` e:`%s`',
                    force_func_name, force_func, value, orig_from, to, format_exc(e))
                raise

        dpath_util.new(self.target, to, value)
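# A hedged usage sketch of Mapper.map (assumes Zato's DictNav and constants are
# importable; the 'int:' prefix selects the int conversion registered in
# self.funcs, and the leading separator is consumed by from_.split(separator)[1:]):
#
#   mapper = Mapper({'user': {'id': '42'}})
#   mapper.map('int:/user/id', 'customer/id')
#   # mapper.target == {'customer': {'id': 42}}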
async def get_sc2_legacy_ladder_api_data(client: aiohttp.ClientSession, access_token: str,
                                         fetch_delay: float, prepared_data: dict, gm_data: dict):
    url = "https://{}.api.blizzard.com/sc2/legacy/ladder/{}/{}"
    # url = f"https://{region}.api.blizzard.com/sc2/legacy/ladder/{region_id}/{ladder_id}"

    # Debugging info
    profiles_with_no_favorite_race_p1 = 0
    total_profiles = 0

    # Each table has keys: us, eu, kr
    # Table with average games per placement-account
    avg_games_table = {}
    # Table with average winrate per placement-account
    avg_winrate_table = {}
    # Table with total games
    total_games_table = {}

    for region_id, region_name in enumerate(REGIONS, start=1):
        new_table_avg_games = []
        new_table_avg_winrate = []
        new_table_total_games = []
        for mode in MODES[:1]:
            row_number = 0
            for league_id, league in enumerate(LEAGUES[:6]):
                for tier_id in reversed(range(3)):
                    # Skip if it doesn't exist, e.g. for GM when GM is locked
                    if not get(prepared_data, f"{mode}/{league_id}/{tier_id}", default={}):
                        continue
                    new_row_avg_games = [ROW_DESCRIPTIONS[row_number]]
                    new_row_avg_winrate = [ROW_DESCRIPTIONS[row_number]]
                    new_row_total_games = [ROW_DESCRIPTIONS[row_number]]
                    row_number += 1

                    # Get normal non-ladder stats
                    urls = [
                        url.format(region_name, region_id, ladder_id) for ladder_id in
                        get(prepared_data, f"{mode}/{league_id}/{tier_id}/{region_name}/ladder_ids", default=[])
                    ]
                    responses = await fetch_multiple(client, access_token, urls, fetch_delay)

                    # Collect games per race, keys are: P, T, Z, R
                    league_tier_wins = {}
                    league_tier_losses = {}
                    league_tier_profiles = {}
                    for response in responses:
                        if "ladderMembers" not in response:
                            logger.error("Error with response, no key found with 'ladderMembers'")
                            continue
                        for profile in response["ladderMembers"]:
                            # Ignore profile if buggy (race not shown?)
                            total_profiles += 1
                            if "favoriteRaceP1" not in profile:
                                logger.error(
                                    f"Error with profile in region '{region_name}' - has no 'favoriteRaceP1' entry.")
                                profiles_with_no_favorite_race_p1 += 1
                                continue
                            wins = profile["wins"]
                            losses = profile["losses"]
                            race = profile["favoriteRaceP1"][0]
                            # Load old data from dict
                            total_race_wins = get(league_tier_wins, race, default=0)
                            total_race_losses = get(league_tier_losses, race, default=0)
                            total_race_profiles = get(league_tier_profiles, race, default=0)
                            # Store new sum of data
                            new(league_tier_wins, race, total_race_wins + wins)
                            new(league_tier_losses, race, total_race_losses + losses)
                            new(league_tier_profiles, race, total_race_profiles + 1)

                    # Calculate average games per profile
                    new_row_avg_games += [
                        get_avg_games_entry(league_tier_wins, league_tier_losses, league_tier_profiles, race)
                        for race in RACES + ["TOTAL"]
                    ]
                    new_table_avg_games.append(new_row_avg_games)

                    # Calculate average winrate per race
                    new_row_avg_winrate += [
                        get_avg_winrate_entry(league_tier_wins, league_tier_losses, race)
                        for race in RACES
                    ]
                    new_table_avg_winrate.append(new_row_avg_winrate)

                    # Calculate total games per race
                    new_row_total_games += [
                        get_total_games_entry(league_tier_wins, league_tier_losses, race)
                        for race in RACES + ["TOTAL"]
                    ]
                    new_table_total_games.append(new_row_total_games)

        # Add region data
        new_table_avg_games.append(STATISTICS_HEADER_WITH_TOTAL)
        new_table_avg_games.reverse()
        avg_games_table[region_name] = new_table_avg_games

        new_table_avg_winrate.append(STATISTICS_HEADER)
        new_table_avg_winrate.reverse()
        avg_winrate_table[region_name] = new_table_avg_winrate

        new_table_total_games.append(STATISTICS_HEADER_WITH_TOTAL)
        new_table_total_games.reverse()
        total_games_table[region_name] = new_table_total_games

    # Add GM stats
    await add_gm_stats(gm_data, avg_games_table, total_games_table, avg_winrate_table)

    # Debugging the Blizzard API
    if total_profiles > 0:
        _fraction = profiles_with_no_favorite_race_p1 / total_profiles
        if profiles_with_no_favorite_race_p1 > 0:
            logger.warning(
                f"Found {profiles_with_no_favorite_race_p1} / {total_profiles} "
                f"({get_percentage(_fraction)}) profiles which are incompletely returned by the legacy API.")

    logger.info("Outputting info to 'avg_games_table.json'")
    with open("avg_games_table.json", "w") as f:
        json.dump(avg_games_table, f, indent=4, sort_keys=True)
    logger.info("Outputting info to 'avg_winrate_table.json'")
    with open("avg_winrate_table.json", "w") as f:
        json.dump(avg_winrate_table, f, indent=4, sort_keys=True)
    logger.info("Outputting info to 'total_games_table.json'")
    with open("total_games_table.json", "w") as f:
        json.dump(total_games_table, f, indent=4, sort_keys=True)

    return {
        "avg_games": avg_games_table,
        "avg_winrate": avg_winrate_table,
        "total_games": total_games_table,
    }
def create(self, keypath: str, value: Any) -> None:
    '''Create new key/value pair located at path.'''
    dpath.new(self, keypath, value, DpathMixin.separator)
def add(self, keypath: str, value: Any) -> None:
    '''Add key/value pair located at keypath.'''
    dpath.new(self, keypath, value, DpathMixin.separator)
def modify_values_yaml(experiment_folder: str, script_location: str,
                       script_parameters: Tuple[str, ...], experiment_name: str,
                       run_name: str, pack_type: str, cluster_registry_port: int,
                       pack_params: List[Tuple[str, str]], env_variables: List[str]):
    log.debug("Modify values.yaml - start")
    values_yaml_filename = os.path.join(experiment_folder, f"charts/{pack_type}/values.yaml")
    values_yaml_temp_filename = os.path.join(experiment_folder, f"charts/{pack_type}/values_temp.yaml")

    with open(values_yaml_filename, "r") as values_yaml_file:
        template = jinja2.Template(values_yaml_file.read())
        rendered_values = template.render(NAUTA={
            'ExperimentName': experiment_name,
            'CommandLine': common.prepare_script_paramaters(script_parameters, script_location),
            'RegistryPort': str(cluster_registry_port),
            'ExperimentImage': f'127.0.0.1:{cluster_registry_port}/{run_name}',
            'ImageRepository': f'127.0.0.1:{cluster_registry_port}'
        })
        v = yaml.safe_load(rendered_values)

    workersCount = None
    pServersCount = None
    # Regex used for detecting dicts/arrays in pack params
    regex = re.compile(r"^\[.*|^\{.*")
    for key, value in pack_params:
        if re.match(regex, value):
            try:
                value = ast.literal_eval(value)
            except Exception as e:
                raise AttributeError(Texts.CANT_PARSE_VALUE.format(value=value, error=e))
        # Handle boolean params
        elif value in {"true", "false"}:
            value = _parse_yaml_boolean(value)
        if key == WORK_CNT_PARAM:
            workersCount = value
        if key == P_SERV_CNT_PARAM:
            pServersCount = value
        dutil.new(v, key, value, '.')

    # Setting the sum of replicas involved in multinode training if both
    # pServersCount and workersCount are present in the pack or given in the CLI.
    if (WORK_CNT_PARAM in v or workersCount) and (P_SERV_CNT_PARAM in v or pServersCount):
        number_of_replicas = int(v.get(WORK_CNT_PARAM)) if not workersCount else int(workersCount)
        number_of_replicas += int(v.get(P_SERV_CNT_PARAM)) if not pServersCount else int(pServersCount)
        v[POD_COUNT_PARAM] = number_of_replicas

    if env_variables:
        env_list = []
        for variable in env_variables:
            key, value = variable.split("=")
            one_env_map = {"name": key, "value": value}
            env_list.append(one_env_map)
        if v.get("env"):
            v["env"].extend(env_list)
        else:
            v["env"] = env_list

    with open(values_yaml_temp_filename, "w") as values_yaml_file:
        yaml.dump(v, values_yaml_file)

    shutil.move(values_yaml_temp_filename, values_yaml_filename)
    log.debug("Modify values.yaml - end")
def set_service(self, service, value):
    du.new(self.memory, f'/service/{service}', value)
    return
def make(self, file):
    # To easily edit files or folders
    if "." in file:
        filer.new(self.fs, self.formatpath(file), "")
    else:
        filer.new(self.fs, self.formatpath(file), {})
def spec(self, path, value):
    du.new(self._spec, path, value, '.')
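# Illustration of the '.'-separator form of dpath's new (not from the original
# source); the fourth positional argument overrides the default '/' separator.
#
#   import dpath.util as du
#   spec = {}
#   du.new(spec, 'info.title', 'My API', '.')
#   # spec == {'info': {'title': 'My API'}}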
def modify_values_yaml(experiment_folder: str, script_location: str,
                       script_parameters: Tuple[str, ...], experiment_name: str,
                       pack_type: str, username: str, cluster_registry_port: int,
                       pack_params: List[Tuple[str, str]] = None,
                       env_variables: List[str] = None):
    log.debug("Modify values.yaml - start")
    pack_params = pack_params if pack_params else []
    values_yaml_filename = os.path.join(experiment_folder, f"charts/{pack_type}/values.yaml")
    values_yaml_temp_filename = os.path.join(experiment_folder, f"charts/{pack_type}/values_temp.yaml")

    with open(values_yaml_filename, "r") as values_yaml_file:
        template = jinja2.Template(values_yaml_file.read())
        rendered_values = template.render(NAUTA={
            'ExperimentName': experiment_name,
            'CommandLine': common.prepare_script_paramaters(script_parameters, script_location),
            'RegistryPort': str(cluster_registry_port),
            'ExperimentImage': f'127.0.0.1:{cluster_registry_port}/{username}/{experiment_name}:latest',
            'ImageRepository': f'127.0.0.1:{cluster_registry_port}/{username}/{experiment_name}:latest'
        })
        v = yaml.safe_load(rendered_values)

    workersCount = None
    pServersCount = None
    # Regex used for detecting dicts/arrays in pack params
    regex = re.compile(r"^\[.*|^\{.*")
    for key, value in pack_params:
        if re.match(regex, value):
            try:
                value = ast.literal_eval(value)
            except Exception as e:
                raise AttributeError(Texts.CANT_PARSE_VALUE.format(value=value, error=e))
        # Handle boolean params
        elif value in {"true", "false"}:
            value = str(_parse_yaml_boolean(value))
        if key == WORK_CNT_PARAM:
            workersCount = value
        if key == P_SERV_CNT_PARAM:
            pServersCount = value
        dutil.new(v, key, value, '.')

    # Setting the sum of replicas involved in multinode training if both
    # pServersCount and workersCount are present in the pack or given in the CLI.
    if (WORK_CNT_PARAM in v or workersCount) and (P_SERV_CNT_PARAM in v or pServersCount):
        number_of_replicas = int(v.get(WORK_CNT_PARAM)) if not workersCount else int(workersCount)
        number_of_replicas += int(v.get(P_SERV_CNT_PARAM)) if not pServersCount else int(pServersCount)
        v[POD_COUNT_PARAM] = number_of_replicas
    elif (WORK_CNT_PARAM in v or workersCount) and (POD_COUNT_PARAM not in v):
        number_of_replicas = int(v.get(WORK_CNT_PARAM)) if not workersCount else int(workersCount)
        v[POD_COUNT_PARAM] = number_of_replicas + 1

    env_variables = env_variables if env_variables else []
    parsed_envs = []
    for variable in env_variables:
        key, value = variable.split("=")
        one_env_map = {"name": key, "value": value}
        parsed_envs.append(one_env_map)

    # Set OMP_NUM_THREADS to be equal to the CPU limit if it was not explicitly passed
    if "OMP_NUM_THREADS" not in (env["name"] for env in parsed_envs):
        try:
            cpu_limit = calculate_omp_num_threads(v)
            if cpu_limit:
                parsed_envs.append({"name": "OMP_NUM_THREADS", "value": str(cpu_limit)})
        except (ValueError, TypeError, KeyError):
            log.exception("Failed to infer OMP_NUM_THREADS value.")

    envs_to_set = {'env', 'worker.env', 'master.env'}  # Env placeholders in values.yaml that we expect
    for env in envs_to_set:
        if dutil.search(v, env, separator='.'):
            dutil.get(v, env, separator='.').extend(parsed_envs)

    with open(values_yaml_temp_filename, "w") as values_yaml_file:
        yaml.safe_dump(v, values_yaml_file)

    shutil.move(values_yaml_temp_filename, values_yaml_filename)
    log.debug("Modify values.yaml - end")