def url_test(url, location=None, status_code=301,
             req_headers=None, req_kwargs=None, resp_headers=None, query=None):
    """
    Function for producing a config dict for the redirect test.

    You can use simple bash style brace expansion in the `url` and `location`
    values. If you need the `location` to change with the `url` changes you must
    use the same number of expansions or the `location` will be treated as
    non-expandable.

    If you use brace expansion this function will return a list of dicts instead of
    a dict. You must use the `flatten` function provided to prepare your test fixture
    if you do this.

    example:

        url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
        url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/products/'),

    :param url: The URL in question (absolute or relative).
    :param location: If a redirect, the expected value of the "Location" header.
    :param status_code: Expected status code from the request.
    :param req_headers: Extra headers to send with the request.
    :param req_kwargs: Extra arguments to pass to requests.get()
    :param resp_headers: Dict of headers expected in the response.
    :param query: Dict of expected query params in `location` URL.
    :return: dict or list of dicts
    """
    test_data = {
        'url': url,
        'location': location,
        'status_code': status_code,
        'req_headers': req_headers,
        'req_kwargs': req_kwargs,
        'resp_headers': resp_headers,
        'query': query,
    }
    expanded_urls = list(braceexpand(url))
    num_urls = len(expanded_urls)
    if num_urls == 1:
        return test_data

    new_urls = []
    if location:
        expanded_locations = list(braceexpand(test_data['location']))
        num_locations = len(expanded_locations)

    for i, url in enumerate(expanded_urls):
        data = test_data.copy()
        data['url'] = url
        if location and num_urls == num_locations:
            data['location'] = expanded_locations[i]
        new_urls.append(data)

    return new_urls
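# A minimal usage sketch. `flatten` below is an assumed stand-in for the helper
# referenced in the docstring, not the original project's implementation:
def flatten(urls):
    # Yield plain dicts, and unpack the lists that brace-expanded tests produce.
    for url in urls:
        if isinstance(url, dict):
            yield url
        else:
            yield from url


URLS = list(flatten((
    url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
    url_test('/firefox/', status_code=200),
)))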
def main(*args):
    classes = []
    for classref_spec in itertools.chain(*[braceexpand(a) for a in args]):
        classes.append(load_class_by_refstr(classref_spec))
    print(classes)
    MROgraph(*classes, filename='py-MRO-graph.png')
def get_hosts(self, host_list):
    out = []
    for host_line in host_list:
        for host_string in re.split(r"\s+", host_line.strip()):
            for host_addr in braceexpand(host_string):
                host_label = self.get_label(host_addr)
                out.append(OSSHHost(host_addr, host_label))
    return out
def expand_file_expression(name):
    """
    Expands Unix-style wildcards in nearly the same way the shell would.

    Handles home directory shortcuts, brace expansion, environment variables,
    and wildcard characters.

    Args:
        name: A string representing a filename to potentially expand

    Returns:
        A list of file names matching the pattern. The list may be empty.
        All names represent existing files (though could potentially be broken
        symlinks).
    """
    try:
        names = braceexpand.braceexpand(name)
    except braceexpand.UnbalancedBracesError:
        names = [name]
    names = [os.path.expanduser(os.path.expandvars(elem)) for elem in names]
    results = []
    for elem in names:
        results.extend(glob.glob(elem))
    return results
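# A minimal usage sketch (the paths are illustrative, not from the original project):
# '~/logs/app.{log,err}' first brace-expands to ['~/logs/app.log', '~/logs/app.err'],
# then '~' and any $VARS are expanded, and only names that glob.glob() finds on disk
# are returned.
for path in expand_file_expression('~/logs/app.{log,err}'):
    print(path)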
def expand_primers(primer):
    """
    From an input of a degenerate oligo, returns a list of non-degenerate oligos.
    """
    # IUPAC degenerate base codes (H = not G, V = not T)
    deg_dic = {'W': '{A,T}', 'S': '{C,G}', 'M': '{A,C}', 'K': '{G,T}',
               'R': '{A,G}', 'Y': '{C,T}', 'B': '{C,G,T}', 'D': '{A,G,T}',
               'H': '{A,C,T}', 'V': '{A,C,G}', 'N': '{A,C,G,T}'}
    expand_template = ""
    for letter in primer:
        if letter in {'A', 'T', 'C', 'G', '-'}:
            expand_template += letter
        else:
            expand_template += deg_dic[letter]
    expanded_primers = list(braceexpand(expand_template))
    return expanded_primers
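# A minimal usage sketch: 'W' maps to '{A,T}', so the template becomes 'AC{A,T}GT'
# and brace expansion yields the two concrete oligos (the primer itself is made up).
print(expand_primers('ACWGT'))
# ['ACAGT', 'ACTGT']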
def parse_profiles(parser, action='start'):
    """Given an argparse parser and a cluster action, generate subparsers for each topology."""
    topologies_directory = os.path.dirname(__file__)
    subparsers = parser.add_subparsers(help='The topology to use when starting the cluster',
                                       dest='topology')
    parsers = dict()

    for topology in os.listdir(topologies_directory):
        if os.path.isdir(os.path.join(topologies_directory, topology)):
            # Generate help and optional arguments based on the options under our topology's
            # profile.cfg file's node_groups section.
            config_filename = os.path.join(os.path.dirname(__file__), topology, TOPOLOGIES_CONFIG_NAME)
            config = ConfigParser.ConfigParser(allow_no_value=True)
            config.read(config_filename)

            parsers[topology] = subparsers.add_parser(
                topology,
                help=config.get('general', 'description'),
                formatter_class=argparse.ArgumentDefaultsHelpFormatter
            )

            # Arguments in the [all] group should be available to all actions.
            parse_args_from_config(parsers[topology], config, 'all')

            if action == 'start':
                for option in config.options('node_groups'):
                    # While we use our custom StoreBraceExpandedAction to process the given values,
                    # we need to separately brace-expand the default to make it show up correctly
                    # in help messages.
                    default = list(braceexpand(config.get('node_groups', option)))
                    parsers[topology].add_argument("--{0}".format(option),
                                                   metavar='NODES',
                                                   default=default,
                                                   action=StoreBraceExpandedAction,
                                                   help="Nodes of the {0} group".format(option))
                parse_args_from_config(parsers[topology], config, 'start')
            elif action == 'build':
                parse_args_from_config(parsers[topology], config, 'build')
def __call__(self, parser, namespace, values, option_string=None):
    setattr(namespace, self.dest, list(braceexpand(values)))
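# A self-contained sketch of a brace-expanding argparse action like the
# StoreBraceExpandedAction referenced in parse_profiles above (the option name
# and node pattern are illustrative):
import argparse

from braceexpand import braceexpand


class StoreBraceExpandedAction(argparse.Action):
    """Store the brace expansion of the supplied value, e.g. 'node-{1..3}'."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, list(braceexpand(values)))


parser = argparse.ArgumentParser()
parser.add_argument('--nodes', action=StoreBraceExpandedAction, default=[])
print(parser.parse_args(['--nodes', 'node-{1..3}']).nodes)
# ['node-1', 'node-2', 'node-3']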
create_device_type('iosv', cisco)
create_device_type('ubuntu1804_cloud_init', linux)

iosv = nb.dcim.device_types.get(model="iosv")
ubuntu1804_cloud_init = nb.dcim.device_types.get(model="ubuntu1804_cloud_init")

# create ios, ubuntu interface template
ios_interfaces = ['GigabitEthernet0/{0..3}']
ubuntu_interfaces = ['ens2', 'ens3']

# create an empty set
asserted_ios_interface_list = set()
asserted_ubuntu_interface_list = set()

# build set of ios interfaces
for port in ios_interfaces:
    asserted_ios_interface_list.update(braceexpand(port))

for port in ubuntu_interfaces:
    asserted_ubuntu_interface_list.add(port)

# convert set to dict and set port speed
interface_data = {}

# iterate over ios interfaces
for port in asserted_ios_interface_list:
    data = {}
    intf_type = '1000base-t'
    mgmt_status = False

    # set gig0/0 as oob_mgmt port
    if port == 'GigabitEthernet0/0':
        mgmt_status = True
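# A quick illustration of the interface expansion used above (NetBox calls omitted):
from braceexpand import braceexpand

print(list(braceexpand('GigabitEthernet0/{0..3}')))
# ['GigabitEthernet0/0', 'GigabitEthernet0/1', 'GigabitEthernet0/2', 'GigabitEthernet0/3']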
def brace_expand(pattern):
    """Bash-style brace expansion"""
    for path in braceexpand(pattern):
        yield Path(path)
        'ScriptName': 'taxibench_ibis.py',
        'CommitHash': args.commit_omnisci
    })

# Delete old table
if not args.dnd:
    print("Deleting", database_name, "old database")
    try:
        conn.drop_database(database_name, force=True)
        time.sleep(2)
        conn = omnisci_server_worker.connect_to_server()
    except Exception as err:
        print("Failed to delete", database_name, "old database: ", err)

args.dp = args.dp.replace("'", "")
data_files_names = list(braceexpand(args.dp))
data_files_names = sorted([x for f in data_files_names for x in glob.glob(f)])
data_files_number = len(data_files_names[:args.df])

try:
    print("Creating", database_name, "new database")
    conn.create_database(database_name)  # Ibis list_databases method is not supported yet
except Exception as err:
    print("Database creation is skipped, because of error:", err)

if len(data_files_names) == 0:
    print("Could not find any data files matching", args.dp)
    sys.exit(2)
def renderfn(self, s: str, vals: dict) -> Generator[str, None, None]:
    return braceexpand.braceexpand(self.render(s, vals))
class EmoteCollector(Bot):
    def __init__(self, **kwargs):
        super().__init__(setup_db=True, **kwargs)
        self.jinja_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(str(BASE_DIR / 'sql')),
            line_statement_prefix='-- :')

    def process_config(self):
        super().process_config()
        self.config['backend_user_accounts'] = set(self.config['backend_user_accounts'])
        with contextlib.suppress(KeyError):
            self.config['copyright_license_file'] = BASE_DIR / self.config['copyright_license_file']
        utils.SUCCESS_EMOJIS = self.config.get('success_or_failure_emojis', ('❌', '✅'))

    ### Events

    async def on_message(self, message):
        if self.should_reply(message):
            await self.set_locale(message)
            await self.process_commands(message)

    async def set_locale(self, message):
        locale = await self.cogs['Locales'].locale(message)
        utils.i18n.current_locale.set(locale)

    # https://github.com/Rapptz/RoboDanny/blob/ca75fae7de132e55270e53d89bc19dd2958c2ae0/bot.py#L77-L85
    async def on_command_error(self, context, error):
        if isinstance(error, commands.NoPrivateMessage):
            await context.author.send(_('This command cannot be used in private messages.'))
        elif isinstance(error, commands.DisabledCommand):
            await context.send(_('Sorry. This command is disabled and cannot be used.'))
        elif isinstance(error, commands.NotOwner):
            logger.error('%s tried to run %s but is not the owner', context.author, context.command.name)
            with contextlib.suppress(discord.HTTPException):
                await context.try_add_reaction(utils.SUCCESS_EMOJIS[False])
        elif isinstance(error, (commands.UserInputError, commands.CheckFailure)):
            await context.send(error)
        elif (
            isinstance(error, commands.CommandInvokeError)
            # abort if it's overridden
            and getattr(
                type(context.cog),
                'cog_command_error',
                # treat ones with no cog (e.g. eval'd ones) as being in a cog that did not override
                commands.Cog.cog_command_error)
            is commands.Cog.cog_command_error
        ):
            if not isinstance(error.original, discord.HTTPException):
                logger.error('"%s" caused an exception', context.message.content)
                logger.error(''.join(traceback.format_tb(error.original.__traceback__)))
                # pylint: disable=logging-format-interpolation
                logger.error('{0.__class__.__name__}: {0}'.format(error.original))

                await context.send(_('An internal error occurred while trying to run that command.'))
            elif isinstance(error.original, discord.Forbidden):
                await context.send(_("I'm missing permissions to perform that action."))

    ### Utility functions

    async def get_context(self, message, cls=None):
        return await super().get_context(message, cls=cls or utils.context.CustomContext)

    # https://github.com/Rapptz/discord.py/blob/814b03f5a8a6faa33d80495691f1e1cbdce40ce2/discord/ext/commands/core.py#L1338-L1346
    def has_permissions(self, message, **perms):
        guild = message.guild
        me = guild.me if guild is not None else self.user
        permissions = message.channel.permissions_for(me)

        for perm, value in perms.items():
            if getattr(permissions, perm, None) != value:
                return False

        return True

    def queries(self, template_name):
        return self.jinja_env.get_template(str(template_name)).module

    ### Init / Shutdown

    # Nested brace pattern expanded into a flat list of extension import paths.
    startup_extensions = list(braceexpand(
        '{emote_collector.extensions.{'
        'locale,file_upload_hook,logging,db,emote,api,gimme,meta,stats,meme,'
        'bingo.{db,commands}},'
        'jishaku,'
        'bot_bin.{misc,debug,sql}}'
    ))

    def load_extensions(self):
        utils.i18n.set_default_locale()
        super().load_extensions()
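# A small illustration of the nested brace pattern used for startup_extensions
# above, with the module list shortened for brevity:
from braceexpand import braceexpand

print(list(braceexpand('{emote_collector.extensions.{db,emote},bot_bin.{misc,sql}}')))
# ['emote_collector.extensions.db', 'emote_collector.extensions.emote',
#  'bot_bin.misc', 'bot_bin.sql']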
def url_test(url, location=None, status_code=requests.codes.moved_permanently,
             req_headers=None, req_kwargs=None, resp_headers=None, query=None,
             follow_redirects=False, final_status_code=requests.codes.ok):
    r"""
    Function for producing a config dict for the redirect test.

    You can use simple bash style brace expansion in the `url` and `location`
    values. If you need the `location` to change with the `url` changes you must
    use the same number of expansions or the `location` will be treated as
    non-expandable.

    If you use brace expansion this function will return a list of dicts instead of
    a dict. You must use the `flatten` function provided to prepare your test fixture
    if you do this.

    If you combine brace expansion with a compiled regular expression pattern you
    must escape any backslashes as this is the escape character for brace expansion.

    example:

        url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
        url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/'),
        url_test('/firefox/notes/', re.compile(r'\/firefox\/[\d\.]+\/releasenotes\/')),
        url_test('/firefox/android/{,beta/}notes/',
                 re.compile(r'\\/firefox\\/android\\/[\\d\\.]+{,beta}\\/releasenotes\\/')),

    :param url: The URL in question (absolute or relative).
    :param location: If a redirect, either the expected value or a compiled regular expression
        to match the "Location" header.
    :param status_code: Expected status code from the request.
    :param req_headers: Extra headers to send with the request.
    :param req_kwargs: Extra arguments to pass to requests.get()
    :param resp_headers: Dict of headers expected in the response.
    :param query: Dict of expected query params in `location` URL.
    :param follow_redirects: Boolean indicating whether redirects should be followed.
    :param final_status_code: Expected status code after following any redirects.
    :return: dict or list of dicts
    """
    test_data = {
        'url': url,
        'location': location,
        'status_code': status_code,
        'req_headers': req_headers,
        'req_kwargs': req_kwargs,
        'resp_headers': resp_headers,
        'query': query,
        'follow_redirects': follow_redirects,
        'final_status_code': final_status_code,
    }
    expanded_urls = list(braceexpand(url))
    num_urls = len(expanded_urls)
    if num_urls == 1:
        return test_data

    try:
        # location is a compiled regular expression pattern
        location_pattern = location.pattern
        test_data['location'] = location_pattern
    except AttributeError:
        location_pattern = None

    new_urls = []
    if location:
        expanded_locations = list(braceexpand(test_data['location']))
        num_locations = len(expanded_locations)

    for i, url in enumerate(expanded_urls):
        data = test_data.copy()
        data['url'] = url
        if location and num_urls == num_locations:
            if location_pattern is not None:
                # recompile the pattern after expansion
                data['location'] = re.compile(expanded_locations[i])
            else:
                data['location'] = expanded_locations[i]
        new_urls.append(data)

    return new_urls
def __init__(
    self,
    text_tar_filepaths: str,
    metadata_path: str,
    encoder_tokenizer: str,
    decoder_tokenizer: str,
    shuffle_n: int = 1,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    reverse_lang_direction: bool = False,
):
    super(TarredTranslationDataset, self).__init__()

    self.encoder_tokenizer = encoder_tokenizer
    self.decoder_tokenizer = decoder_tokenizer
    self.reverse_lang_direction = reverse_lang_direction
    self.src_pad_id = encoder_tokenizer.pad_id
    self.tgt_pad_id = decoder_tokenizer.pad_id

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
        self.metadata = metadata

    if isinstance(text_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        brace_keys_open = ['(', '[', '<', '_OP_']
        for bkey in brace_keys_open:
            if bkey in text_tar_filepaths:
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")

        # Replace ')', ']', '>' and '_CL_' with '}'
        brace_keys_close = [')', ']', '>', '_CL_']
        for bkey in brace_keys_close:
            if bkey in text_tar_filepaths:
                text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

    if isinstance(text_tar_filepaths, str):
        # Brace expand
        text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

    if shard_strategy == 'scatter':
        logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
        if len(text_tar_filepaths) % world_size != 0:
            logging.warning(
                f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                f"by number of distributed workers ({world_size})."
            )

        begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
        end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
        logging.info('Begin Index : %d' % (begin_idx))
        logging.info('End Index : %d' % (end_idx))
        text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
        logging.info(
            "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
        )
    elif shard_strategy == 'replicate':
        logging.info("All tarred dataset shards will be replicated across all nodes.")
    else:
        raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    self.tarpath = text_tar_filepaths

    # Put together WebDataset
    self._dataset = (
        wd.Dataset(text_tar_filepaths)
        .shuffle(shuffle_n)
        .rename(pkl='pkl', key='__key__')
        .to_tuple('pkl', 'key')
        .map(f=self._build_sample)
    )
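# A standalone sketch of the shard-path handling above; the file names and the
# rank/world_size values are made up, only braceexpand itself is the real API:
import braceexpand

paths = "shards/train__OP_0..7_CL_.tar"        # '_OP_'/'_CL_' stand in for '{'/'}'
paths = paths.replace("_OP_", "{").replace("_CL_", "}")
shards = list(braceexpand.braceexpand(paths))  # eight shard file names

world_size, global_rank = 4, 1                 # pretend distributed setup
per_rank = len(shards) // world_size
begin_idx = per_rank * global_rank
print(shards[begin_idx:begin_idx + per_rank])
# ['shards/train_2.tar', 'shards/train_3.tar']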
def construct_dataset(
    fname,
    cache_dir=default_cache_dir,
    cache_size=default_cache_size,
    cache_name=default_cache_name,
    cache_verbose=default_cache_verbose,
    chunksize=10,
    handler=reraise_exception,
    repeat=False,
):
    """Construct a composite dataset from multiple sources using a YAML spec.

    This function gets invoked when you construct a WebDataset from a ".ds.yml"
    specification file.

    This is an experimental function that constructs composite datasets. You may
    want to opt for the simpler .shards.yml spec, which specifies combining
    datasets at the shard level; it interacts more simply with workers and
    distributed training.
    """
    from .dataset import WebDataset, WebLoader

    with open(fname) as stream:
        spec = yaml.safe_load(stream)
    result = []
    check_allowed(spec, "prefix datasets epoch", "top datasets spec")
    prefix = spec.get("prefix", "")
    for ds in spec["datasets"]:
        check_allowed(
            ds,
            """
            name buckets shards resampled epoch_shuffle shuffle split_by_worker
            split_by_node cachedir cachesize cachename cacheverbose subsample
            shuffle epoch chunksize nworkers probability
            """,
            "dataset spec",
        )
        buckets = ds.get("buckets", [""])
        assert len(buckets) == 1, "FIXME support for multiple buckets unimplemented"
        bucket = buckets[0]
        urls = ds["shards"]
        urls = [u for url in urls for u in braceexpand.braceexpand(url)]
        urls = [prefix + bucket + u for u in urls]
        print(
            f"# input {ds.get('name', '')} {prefix+bucket+str(ds['shards'])} {len(urls)} "
            + f"{ds.get('epoch')} {ds.get('resampled')}",
            file=sys.stderr,
        )
        if ds.get("resampled", False):
            urls = ResampledShards(urls)
        else:
            urls = PytorchShardList(
                urls,
                epoch_shuffle=ds.get("epoch_shuffle", False),
                shuffle=ds.get("shuffle", True),
                split_by_worker=ds.get("split_by_worker", True),
                split_by_node=ds.get("split_by_node", True),
            )
        dataset = WebDataset(
            urls,
            ds.get("cachedir", cache_dir),
            ds.get("cachesize", cache_size),
            ds.get("cachename", cache_name),
            ds.get("cacheverbose", cache_verbose),
        )
        if "subsample" in ds:
            dataset = dataset.rsample(ds["subsample"])
        if "shuffle" in ds:
            dataset = dataset.shuffle(ds["shuffle"])
        if "epoch" in ds:
            dataset = dataset.with_epoch(ds["epoch"])
        bs = ds.get("chunksize", chunksize)
        if bs > 0:
            dataset = dataset.listed(bs)
        nworkers = ds.get("nworkers", 0)
        if nworkers >= 0:
            dataset = WebLoader(dataset, num_workers=nworkers, batch_size=None, collate_fn=list)
        p = ds.get("probability", 1.0)
        result.append(Source(dataset=dataset, probability=p))
    if len(result) > 1:
        result = RoundRobin(result)
    else:
        result = result[0].dataset
    if bs > 0:
        result = result.unlisted()
    if "epoch" in spec:
        result = result.with_epoch(spec["epoch"]).with_length(spec["epoch"])
    return result
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    is_regression_task: bool = False,
):
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.file_occurence = count_occurence(self.collection.mapping)

    self.featurizer = featurizer
    self.trim = trim

    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    self.label2id, self.id2label = {}, {}
    for label_id, label in enumerate(self.labels):
        self.label2id[label] = label_id
        self.id2label[label_id] = label

    for idx in range(len(self.labels[:5])):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Replace '(' and '[' with '{'
        brace_keys_open = ['(', '[', '<', '_OP_']
        for bkey in brace_keys_open:
            if bkey in audio_tar_filepaths:
                audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "{")

        # Replace ')' and ']' with '}'
        brace_keys_close = [')', ']', '>', '_CL_']
        for bkey in brace_keys_close:
            if bkey in audio_tar_filepaths:
                audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "}")

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            # Brace expand
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )

            begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx
            )
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
        else:
            raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    # Put together WebDataset
    self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)

    if shuffle_n > 0:
        self._dataset = self._dataset.shuffle(shuffle_n)
    else:
        logging.info("WebDataset will not shuffle files within the tar files.")

    self._dataset = (
        self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def real_glob(rglob):
    glob_list = braceexpand(rglob)
    files = []
    for g in glob_list:
        files = files + glob.glob(g)
    return sorted(files)
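# A minimal usage sketch (paths are illustrative): '{data,backup}/*.csv' brace-expands
# to ['data/*.csv', 'backup/*.csv']; each pattern is globbed and the matches are
# merged into one sorted list.
for f in real_glob('{data,backup}/*.csv'):
    print(f)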
def brace_expand(pattern):
    """Bash-style brace expansion"""
    from braceexpand import braceexpand

    for path in braceexpand(pattern):
        yield to_path(path)
def files_names_from_pattern(filename):
    from braceexpand import braceexpand

    data_files_names = list(braceexpand(filename))
    data_files_names = sorted([x for f in data_files_names for x in glob.glob(f)])
    return data_files_names
def transform_objects(self, transform_id, template):
    for obj_name in braceexpand(template):
        yield self.transform_object(transform_id, obj_name)
"synthetic_results", args.label, "CPU", "Benchmarks", args.synthetic_query + ".json", ) import_cmdline = None benchmark_cmdline = synthetic_benchmark_cmdline else: if args.import_file is None or args.table_schema_file is None or args.queries_dir is None: print( "For dataset type of benchmark the following parameters are mandatory: --import-file," " --table-schema-file and --queries-dir and --fragment-size is optional." ) sys.exit(3) datafiles_names = list(braceexpand(args.import_file)) datafiles_names = sorted( [x for f in datafiles_names for x in glob.glob(f)]) datafiles = len(datafiles_names) print("NUMBER OF DATAFILES FOUND:", datafiles) results_file_name = os.path.join(args.benchmarks_path, "benchmark.json") import_cmdline = dataset_import_cmdline benchmark_cmdline = dataset_benchmark_cmdline db_reporter = None if args.db_user is not "": if args.db_table is None: print( "--db-table parameter is mandatory to store results in MySQL database" ) sys.exit(4)
def __init__(
    self,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    parser: Callable,
    sample_rate: int,
    int_values: bool = False,
    augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
    shuffle_n: int = 0,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
    max_utts: int = 0,
    trim: bool = False,
    bos_id: Optional[int] = None,
    eos_id: Optional[int] = None,
    add_misc: bool = False,
    pad_id: int = 0,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    self.collection = collections.ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=parser,
        min_duration=min_duration,
        max_duration=max_duration,
        max_number=max_utts,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )

    self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
    self.trim = trim
    self.eos_id = eos_id
    self.bos_id = bos_id
    self.pad_id = pad_id
    self._add_misc = add_misc

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Replace '(' and '[' with '{'
        brace_keys_open = ['(', '[', '<', '_OP_']
        for bkey in brace_keys_open:
            if bkey in audio_tar_filepaths:
                audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "{")

        # Replace ')' and ']' with '}'
        brace_keys_close = [')', ']', '>', '_CL_']
        for bkey in brace_keys_close:
            if bkey in audio_tar_filepaths:
                audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "}")

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            # Brace expand
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )

            begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx
            )
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
        else:
            raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    # Put together WebDataset
    self._dataset = (
        wd.Dataset(audio_tar_filepaths)
        .shuffle(shuffle_n)
        .rename(audio='wav', key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def __init__(
    self,
    audio_tar_filepaths,
    manifest_filepath,
    labels,
    batch_size,
    sample_rate=16000,
    int_values=False,
    bos_id=None,
    eos_id=None,
    pad_id=None,
    min_duration=0.1,
    max_duration=None,
    normalize_transcripts=True,
    trim_silence=False,
    shuffle_n=0,
    num_workers=0,
    augmentor: Optional[Union[AudioAugmentor, Dict[str, Dict[str, Any]]]] = None,
):
    super().__init__()
    self._sample_rate = sample_rate

    if augmentor is not None:
        augmentor = _process_augmentations(augmentor)

    self.collection = ASRAudioText(
        manifests_files=manifest_filepath.split(','),
        parser=make_parser(labels=labels, name='en', do_normalize=normalize_transcripts),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )

    self.featurizer = WaveformFeaturizer(sample_rate=self._sample_rate, int_values=int_values, augmentor=augmentor)
    self.trim = trim_silence
    self.eos_id = eos_id
    self.bos_id = bos_id

    # Used in creating a sampler (in Actions).
    self._batch_size = batch_size
    self._num_workers = num_workers
    pad_id = 0 if pad_id is None else pad_id
    self.collate_fn = partial(seq_collate_fn, token_pad_value=pad_id)

    # Check for distributed and partition shards accordingly
    if torch.distributed.is_initialized():
        global_rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        if isinstance(audio_tar_filepaths, str):
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if len(audio_tar_filepaths) % world_size != 0:
            logging.warning(
                f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                f"by number of distributed workers ({world_size})."
            )

        begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
        end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
        audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]

    # Put together WebDataset
    self._dataset = (
        wd.Dataset(audio_tar_filepaths)
        .shuffle(shuffle_n)
        .rename(audio='wav', key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def url_test(url, location=None, status_code=requests.codes.moved_permanently,
             req_headers=None, req_kwargs=None, resp_headers=None, query=None,
             follow_redirects=False, final_status_code=requests.codes.ok):
    r"""
    Function for producing a config dict for the redirect test.

    You can use simple bash style brace expansion in the `url` and `location`
    values. If you need the `location` to change with the `url` changes you must
    use the same number of expansions or the `location` will be treated as
    non-expandable.

    If you use brace expansion this function will return a list of dicts instead of
    a dict. You must use the `flatten` function provided to prepare your test fixture
    if you do this.

    If you combine brace expansion with a compiled regular expression pattern you
    must escape any backslashes as this is the escape character for brace expansion.

    example:

        url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
        url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/'),
        url_test('/firefox/notes/', re.compile(r'\/firefox\/[\d\.]+\/releasenotes\/')),
        url_test('/firefox/android/{,beta/}notes/',
                 re.compile(r'\\/firefox\\/android\\/[\\d\\.]+{,beta}\\/releasenotes\\/')),

    :param url: The URL in question (absolute or relative).
    :param location: If a redirect, either the expected value or a compiled regular expression
        to match the "Location" header.
    :param status_code: Expected status code from the request.
    :param req_headers: Extra headers to send with the request.
    :param req_kwargs: Extra arguments to pass to requests.get()
    :param resp_headers: Dict of headers expected in the response.
    :param query: Dict of expected query params in `location` URL.
    :param follow_redirects: Boolean indicating whether redirects should be followed.
    :param final_status_code: Expected status code after following any redirects.
    :return: dict or list of dicts
    """
    test_data = {
        'url': url,
        'location': location,
        'status_code': status_code,
        'req_headers': req_headers,
        'req_kwargs': req_kwargs,
        'resp_headers': resp_headers,
        'query': query,
        'follow_redirects': follow_redirects,
        'final_status_code': final_status_code,
    }
    expanded_urls = list(braceexpand(url))
    num_urls = len(expanded_urls)
    if num_urls == 1:
        return test_data

    try:
        # location is a compiled regular expression pattern
        location_pattern = location.pattern
        test_data['location'] = location_pattern
    except AttributeError:
        location_pattern = None

    new_urls = []
    if location:
        expanded_locations = list(braceexpand(test_data['location']))
        num_locations = len(expanded_locations)

    for i, url in enumerate(expanded_urls):
        data = test_data.copy()
        data['url'] = url
        if location and num_urls == num_locations:
            if location_pattern is not None:
                # recompile the pattern after expansion
                data['location'] = re.compile(expanded_locations[i])
            else:
                data['location'] = expanded_locations[i]
        new_urls.append(data)

    return new_urls
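# A short illustration of the backslash-escaping note in the docstring above,
# assuming braceexpand's default escape handling: '\' escapes the next character,
# so a doubled backslash survives expansion as a single literal backslash.
from braceexpand import braceexpand

print(list(braceexpand(r'\\d{1,2}')))
# ['\\d1', '\\d2']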
def _evaluate_glob(self, pattern):
    discovered = set()
    for variation in braceexpand.braceexpand(pattern):
        for path in glob.glob(variation):
            discovered.add(path)
    return discovered