Code Example #1
File: base.py Project: demoranews/bedrock
def url_test(url, location=None, status_code=301, req_headers=None, req_kwargs=None,
             resp_headers=None, query=None):
    """
    Function for producing a config dict for the redirect test.

    You can use simple bash-style brace expansion in the `url` and `location`
    values. If you need the `location` to change along with the `url`, you must
    use the same number of expansions in both, or the `location` will be treated as non-expandable.

    If you use brace expansion this function will return a list of dicts instead of a dict.
    You must use the `flatten` function provided to prepare your test fixture if you do this.

    example:

        url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
        url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/products/'),

    :param url: The URL in question (absolute or relative).
    :param location: If a redirect, the expected value of the "Location" header.
    :param status_code: Expected status code from the request.
    :param req_headers: Extra headers to send with the request.
    :param req_kwargs: Extra arguments to pass to requests.get()
    :param resp_headers: Dict of headers expected in the response.
    :param query: Dict of expected query params in `location` URL.
    :return: dict or list of dicts
    """
    test_data = {
        'url': url,
        'location': location,
        'status_code': status_code,
        'req_headers': req_headers,
        'req_kwargs': req_kwargs,
        'resp_headers': resp_headers,
        'query': query,
    }
    expanded_urls = list(braceexpand(url))
    num_urls = len(expanded_urls)
    if num_urls == 1:
        return test_data

    new_urls = []
    if location:
        expanded_locations = list(braceexpand(test_data['location']))
        num_locations = len(expanded_locations)

    for i, url in enumerate(expanded_urls):
        data = test_data.copy()
        data['url'] = url
        if location and num_urls == num_locations:
            data['location'] = expanded_locations[i]
        new_urls.append(data)

    return new_urls
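A minimal usage sketch of the brace-expansion behaviour described in the docstring. The `flatten` helper below is a hypothetical stand-in for the one the docstring mentions; only `url_test` itself comes from the example above.

from itertools import chain

def flatten(url_tests):
    # Hypothetical flatten: collapse a mix of dicts and lists of dicts
    # into one flat list of test fixtures.
    return list(chain.from_iterable(
        t if isinstance(t, list) else [t] for t in url_tests))

URLS = flatten([
    url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
    url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/products/'),
])
# The first call expands to two dicts that share one location; in the second,
# the four expanded locations track the four expanded urls index by index.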
Code Example #2
File: py-MRO-graph.py Project: dotmpe/script-mpe
def main(*args):
    classes = []
    for classref_spec in itertools.chain(*[braceexpand(a) for a in args]):
        classes.append(load_class_by_refstr(classref_spec))

    print(classes)
    MROgraph(*classes, filename='py-MRO-graph.png')
Code Example #3
File: ossh.py Project: kt97679/one-ssh
 def get_hosts(self, host_list):
     out = []
     for host_line in host_list:
         for host_string in re.split(r"\s+", host_line.strip()):
             for host_addr in braceexpand(host_string):
                 host_label = self.get_label(host_addr)
                 out.append(OSSHHost(host_addr, host_label))
     return out
Code Example #4
File: read_files.py Project: cbh66/ngrams
def expand_file_expression(name):
    """ Expands Unix-style wildcards in nearly the same way the shell would

    Handles home directory shortcuts, brace expansion, environment variables,
    and wildcard characters.
    Args:
        name: A string representing a filename to potentially expand
    Returns:
        A list of file names matching the pattern.  The list may be empty.
        All names represent existing files (though could potentially be
        broken symlinks).
    """
    try:
        names = braceexpand.braceexpand(name)
    except braceexpand.UnbalancedBracesError:
        names = [name]
    names = [os.path.expanduser(os.path.expandvars(elem)) for elem in names]
    results = []
    for elem in names:
        results.extend(glob.glob(elem))
    return results
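A short, hypothetical call showing the order of operations (brace expansion first, then user/variable expansion, then globbing); the pattern and any resulting files are made up.

matches = expand_file_expression('~/corpus/{train,test}/*.txt')
# braceexpand yields '~/corpus/train/*.txt' and '~/corpus/test/*.txt';
# expanduser/expandvars then resolve '~', and glob.glob keeps only paths that exist.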
Code Example #5
File: utilities.py Project: manutamminen/epride
def expand_primers(primer):
    """ From an input of a degenerate oligo, returns a list of non-degenerate oligos.
    """
    deg_dic = {'W': '{A,T}',
               'S': '{C,G}',
               'M': '{A,C}',
               'K': '{G,T}',
               'R': '{A,G}',
               'Y': '{C,T}',
               'B': '{C,G,T}',
               'D': '{A,G,T}',
               'H': '{A,C,T}',  # IUPAC H = not G
               'V': '{A,C,G}',  # IUPAC V = not T
               'N': '{A,C,G,T}'}
    expand_template = ""
    for letter in primer:
        if letter in {'A', 'T', 'C', 'G', '-'}:
            expand_template += letter
        else:
            expand_template += deg_dic[letter]
    expanded_primers = list(braceexpand(expand_template))
    return expanded_primers
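A small worked example (the primer sequence is hypothetical):

print(expand_primers('ACWGT'))
# 'W' maps to '{A,T}', so the template becomes 'AC{A,T}GT' and the result is
# ['ACAGT', 'ACTGT']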
Code Example #6
File: parsing.py Project: cloudera/clusterdock
def parse_profiles(parser, action='start'):
    """Given an argparse parser and a cluster action, generate subparsers for each topology."""
    topologies_directory = os.path.dirname(__file__)

    subparsers = parser.add_subparsers(help='The topology to use when starting the cluster',
                                       dest='topology')

    parsers = dict()
    for topology in os.listdir(topologies_directory):
        if os.path.isdir(os.path.join(topologies_directory, topology)):
            # Generate help and optional arguments based on the options under our topology's
            # profile.cfg file's node_groups section.
            config_filename = os.path.join(os.path.dirname(__file__), topology,
                                           TOPOLOGIES_CONFIG_NAME)
            config = ConfigParser.ConfigParser(allow_no_value=True)
            config.read(config_filename)

            parsers[topology] = subparsers.add_parser(
                topology, help=config.get('general', 'description'),
                formatter_class=argparse.ArgumentDefaultsHelpFormatter
            )

            # Arguments in the [all] group should be available to all actions.
            parse_args_from_config(parsers[topology], config, 'all')

            if action == 'start':
                for option in config.options('node_groups'):
                    # While we use our custom StoreBraceExpandedAction to process the given values,
                    # we need to separately brace-expand the default to make it show up correctly
                    # in help messages.
                    default = list(braceexpand(config.get('node_groups', option)))
                    parsers[topology].add_argument("--{0}".format(option), metavar='NODES',
                                                   default=default, action=StoreBraceExpandedAction,
                                                   help="Nodes of the {0} group".format(option))
                parse_args_from_config(parsers[topology], config, 'start')
            elif action == 'build':
                parse_args_from_config(parsers[topology], config, 'build')
Code Example #7
File: parsing.py Project: cloudera/clusterdock
 def __call__(self, parser, namespace, values, option_string=None):
     setattr(namespace, self.dest, list(braceexpand(values)))
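A sketch of how this action plugs into argparse, mirroring the add_argument call in parse_profiles above; the option name and values here are hypothetical.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--nodes', metavar='NODES', action=StoreBraceExpandedAction,
                    default=['node-1'])
args = parser.parse_args(['--nodes', 'node-{1..3}'])
# args.nodes == ['node-1', 'node-2', 'node-3']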
Code Example #8
create_device_type('iosv', cisco)
create_device_type('ubuntu1804_cloud_init', linux)
iosv = nb.dcim.device_types.get(model="iosv")
ubuntu1804_cloud_init = nb.dcim.device_types.get(model="ubuntu1804_cloud_init")

# create ios, ubuntu interface template
ios_interfaces = ['GigabitEthernet0/{0..3}']
ubuntu_interfaces = ['ens2', 'ens3']

# create an empty set
asserted_ios_interface_list = set()
asserted_ubuntu_interface_list = set()

# build set of ios interfaces
for port in ios_interfaces:
    asserted_ios_interface_list.update(braceexpand(port))

for port in ubuntu_interfaces:
    asserted_ubuntu_interface_list.add(port)

# convert set to dict and set port speed
interface_data = {}

# iterate over ios interfaces
for port in asserted_ios_interface_list:
    data = {}
    intf_type = '1000base-t'
    mgmt_status = False
    # set gig0/0 as oob_mgmt port
    if port == 'GigabitEthernet0/0':
        mgmt_status = True
Code Example #9
File: shlib.py Project: KenKundert/shlib
 def brace_expand(pattern):
     """Bash-style brace expansion"""
     for path in braceexpand(pattern):
         yield Path(path)
Code Example #10
                'ScriptName': 'taxibench_ibis.py',
                'CommitHash': args.commit_omnisci
            })

    # Delete old table
    if not args.dnd:
        print("Deleting", database_name, "old database")
        try:
            conn.drop_database(database_name, force=True)
            time.sleep(2)
            conn = omnisci_server_worker.connect_to_server()
        except Exception as err:
            print("Failed to delete", database_name, "old database: ", err)

    args.dp = args.dp.replace("'", "")
    data_files_names = list(braceexpand(args.dp))
    data_files_names = sorted(
        [x for f in data_files_names for x in glob.glob(f)])

    data_files_number = len(data_files_names[:args.df])

    try:
        print("Creating", database_name, "new database")
        conn.create_database(
            database_name)  # Ibis list_databases method is not supported yet
    except Exception as err:
        print("Database creation is skipped, because of error:", err)

    if len(data_files_names) == 0:
        print("Could not find any data files matching", args.dp)
        sys.exit(2)
Code Example #11
File: iface.py Project: wtnb75/tarjinja
 def renderfn(self, s: str, vals: dict) -> Generator[str, None, None]:
     return braceexpand.braceexpand(self.render(s, vals))
Code Example #12
class EmoteCollector(Bot):
	def __init__(self, **kwargs):
		super().__init__(setup_db=True, **kwargs)
		self.jinja_env = jinja2.Environment(
			loader=jinja2.FileSystemLoader(str(BASE_DIR / 'sql')),
			line_statement_prefix='-- :')

	def process_config(self):
		super().process_config()
		self.config['backend_user_accounts'] = set(self.config['backend_user_accounts'])
		with contextlib.suppress(KeyError):
			self.config['copyright_license_file'] = BASE_DIR / self.config['copyright_license_file']
		utils.SUCCESS_EMOJIS = self.config.get('success_or_failure_emojis', ('❌', '✅'))

	### Events

	async def on_message(self, message):
		if self.should_reply(message):
			await self.set_locale(message)
			await self.process_commands(message)

	async def set_locale(self, message):
		locale = await self.cogs['Locales'].locale(message)
		utils.i18n.current_locale.set(locale)

	# https://github.com/Rapptz/RoboDanny/blob/ca75fae7de132e55270e53d89bc19dd2958c2ae0/bot.py#L77-L85
	async def on_command_error(self, context, error):
		if isinstance(error, commands.NoPrivateMessage):
			await context.author.send(_('This command cannot be used in private messages.'))
		elif isinstance(error, commands.DisabledCommand):
			await context.send(_('Sorry. This command is disabled and cannot be used.'))
		elif isinstance(error, commands.NotOwner):
			logger.error('%s tried to run %s but is not the owner', context.author, context.command.name)
			with contextlib.suppress(discord.HTTPException):
				await context.try_add_reaction(utils.SUCCESS_EMOJIS[False])
		elif isinstance(error, (commands.UserInputError, commands.CheckFailure)):
			await context.send(error)
		elif (
			isinstance(error, commands.CommandInvokeError)
			# abort if it's overridden
			and
				getattr(
					type(context.cog),
					'cog_command_error',
					# treat ones with no cog (e.g. eval'd ones) as being in a cog that did not override
					commands.Cog.cog_command_error)
				is commands.Cog.cog_command_error
		):
			if not isinstance(error.original, discord.HTTPException):
				logger.error('"%s" caused an exception', context.message.content)
				logger.error(''.join(traceback.format_tb(error.original.__traceback__)))
				# pylint: disable=logging-format-interpolation
				logger.error('{0.__class__.__name__}: {0}'.format(error.original))

				await context.send(_('An internal error occurred while trying to run that command.'))
			elif isinstance(error.original, discord.Forbidden):
				await context.send(_("I'm missing permissions to perform that action."))

	### Utility functions

	async def get_context(self, message, cls=None):
		return await super().get_context(message, cls=cls or utils.context.CustomContext)

	# https://github.com/Rapptz/discord.py/blob/814b03f5a8a6faa33d80495691f1e1cbdce40ce2/discord/ext/commands/core.py#L1338-L1346
	def has_permissions(self, message, **perms):
		guild = message.guild
		me = guild.me if guild is not None else self.user
		permissions = message.channel.permissions_for(me)

		for perm, value in perms.items():
			if getattr(permissions, perm, None) != value:
				return False

		return True

	def queries(self, template_name):
		return self.jinja_env.get_template(str(template_name)).module

	### Init / Shutdown

	startup_extensions = list(braceexpand("""{
		emote_collector.extensions.{
			locale,
			file_upload_hook,
			logging,
			db,
			emote,
			api,
			gimme,
			meta,
			stats,
			meme,
			bingo.{
				db,
				commands}},
		jishaku,
		bot_bin.{
			misc,
			debug,
			sql}}
	""".replace('\t', '').replace('\n', '')))

	def load_extensions(self):
		utils.i18n.set_default_locale()
		super().load_extensions()
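The startup_extensions string above relies on nested brace expansion to build dotted module paths. A smaller, made-up pattern showing the same flattening:

from braceexpand import braceexpand

print(list(braceexpand('pkg.{a,b.{x,y}}')))
# ['pkg.a', 'pkg.b.x', 'pkg.b.y']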
Code Example #13
File: base.py Project: Delphine/bedrock
def url_test(url, location=None, status_code=requests.codes.moved_permanently,
             req_headers=None, req_kwargs=None, resp_headers=None, query=None,
             follow_redirects=False, final_status_code=requests.codes.ok):
    """
    Function for producing a config dict for the redirect test.

    You can use simple bash-style brace expansion in the `url` and `location`
    values. If you need the `location` to change along with the `url`, you must
    use the same number of expansions in both, or the `location` will be treated as non-expandable.

    If you use brace expansion this function will return a list of dicts instead of a dict.
    You must use the `flatten` function provided to prepare your test fixture if you do this.

    If you combine brace expansion with a compiled regular expression pattern you must
    escape any backslashes as this is the escape character for brace expansion.

    example:

        url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
        url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/'),
        url_test('/firefox/notes/', re.compile(r'\/firefox\/[\d\.]+\/releasenotes\/')),
        url_test('/firefox/android/{,beta/}notes/', re.compile(r'\\/firefox\\/android\\/[\\d\\.]+{,beta}\\/releasenotes\\/')),

    :param url: The URL in question (absolute or relative).
    :param location: If a redirect, either the expected value or a compiled regular expression to match the "Location" header.
    :param status_code: Expected status code from the request.
    :param req_headers: Extra headers to send with the request.
    :param req_kwargs: Extra arguments to pass to requests.get()
    :param resp_headers: Dict of headers expected in the response.
    :param query: Dict of expected query params in `location` URL.
    :param follow_redirects: Boolean indicating whether redirects should be followed.
    :param final_status_code: Expected status code after following any redirects.
    :return: dict or list of dicts
    """
    test_data = {
        'url': url,
        'location': location,
        'status_code': status_code,
        'req_headers': req_headers,
        'req_kwargs': req_kwargs,
        'resp_headers': resp_headers,
        'query': query,
        'follow_redirects': follow_redirects,
        'final_status_code': final_status_code,
    }
    expanded_urls = list(braceexpand(url))
    num_urls = len(expanded_urls)
    if num_urls == 1:
        return test_data

    try:
        # location is a compiled regular expression pattern
        location_pattern = location.pattern
        test_data['location'] = location_pattern
    except AttributeError:
        location_pattern = None

    new_urls = []
    if location:
        expanded_locations = list(braceexpand(test_data['location']))
        num_locations = len(expanded_locations)

    for i, url in enumerate(expanded_urls):
        data = test_data.copy()
        data['url'] = url
        if location and num_urls == num_locations:
            if location_pattern is not None:
                # recompile the pattern after expansion
                data['location'] = re.compile(expanded_locations[i])
            else:
                data['location'] = expanded_locations[i]
        new_urls.append(data)

    return new_urls
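The doubled backslashes in the last docstring example are needed because braceexpand uses a single backslash as its own escape character, so only an escaped backslash survives expansion as a literal backslash. A sketch of the effect:

from braceexpand import braceexpand

print(list(braceexpand(r'\\d{2,3}')))
# ['\\d2', '\\d3']  -- each expanded item keeps a literal backslash before the 'd',
# so the strings can still be recompiled with re.compile afterwards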
Code Example #14
    def __init__(
        self,
        text_tar_filepaths: str,
        metadata_path: str,
        encoder_tokenizer: str,
        decoder_tokenizer: str,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        reverse_lang_direction: bool = False,
    ):
        super(TarredTranslationDataset, self).__init__()

        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.reverse_lang_direction = reverse_lang_direction
        self.src_pad_id = encoder_tokenizer.pad_id
        self.tgt_pad_id = decoder_tokenizer.pad_id

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.metadata = metadata

        if isinstance(text_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in text_tar_filepaths:
                    text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in text_tar_filepaths:
                    text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

        if isinstance(text_tar_filepaths, str):
            # Brace expand
            text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(text_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )
            begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
            logging.info('Begin Index : %d' % (begin_idx))
            logging.info('End Index : %d' % (end_idx))
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )

        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")

        else:
            raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

        self.tarpath = text_tar_filepaths

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(text_tar_filepaths)
            .shuffle(shuffle_n)
            .rename(pkl='pkl', key='__key__')
            .to_tuple('pkl', 'key')
            .map(f=self._build_sample)
        )
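A rough illustration of the path-sanitizing step above: the '(', '[', '<' and '_OP_'/'_CL_' placeholders stand in for braces that are awkward to pass through a shell or config file. The shard path is made up.

import braceexpand

path = "train.tokens._OP_0..3_CL_.tar"
path = path.replace("_OP_", "{").replace("_CL_", "}")  # 'train.tokens.{0..3}.tar'
shards = list(braceexpand.braceexpand(path))
# ['train.tokens.0.tar', 'train.tokens.1.tar', 'train.tokens.2.tar', 'train.tokens.3.tar']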
Code Example #15
File: dsspecs.py Project: OCRoArchive/webdataset
def construct_dataset(
    fname,
    cache_dir=default_cache_dir,
    cache_size=default_cache_size,
    cache_name=default_cache_name,
    cache_verbose=default_cache_verbose,
    chunksize=10,
    handler=reraise_exception,
    repeat=False,
):
    """Construct a composite dataset from multiple sources using a YAML spec.

    This function gets invoked when you construct a WebDataset from a
    ".ds.yml" specification file.

    This is an experimental function that constructs composite datasets.
    You may want to opt for the simpler .shards.yml spec, which specifies
    combining datasets at the shard level; it interacts more simply with
    workers and distributed training.
    """
    from .dataset import WebDataset, WebLoader
    with open(fname) as stream:
        spec = yaml.safe_load(stream)
    result = []
    check_allowed(spec, "prefix datasets epoch", "top datasets spec")
    prefix = spec.get("prefix", "")
    for ds in spec["datasets"]:
        check_allowed(
            ds,
            """
            name buckets shards resampled epoch_shuffle shuffle split_by_worker
            split_by_node cachedir cachesize cachename cacheverbose subsample
            shuffle epoch chunksize nworkers probability
        """,
            "dataset spec",
        )
        buckets = ds.get("buckets", [""])
        assert len(buckets) == 1, "FIXME support for multiple buckets unimplemented"
        bucket = buckets[0]
        urls = ds["shards"]
        urls = [u for url in urls for u in braceexpand.braceexpand(url)]
        urls = [prefix + bucket + u for u in urls]
        print(
            f"# input {ds.get('name', '')} {prefix+bucket+str(ds['shards'])} {len(urls)} "
            + f"{ds.get('epoch')} {ds.get('resampled')}",
            file=sys.stderr,
        )
        if ds.get("resampled", False):
            urls = ResampledShards(urls)
        else:
            urls = PytorchShardList(
                urls,
                epoch_shuffle=ds.get("epoch_shuffle", False),
                shuffle=ds.get("shuffle", True),
                split_by_worker=ds.get("split_by_worker", True),
                split_by_node=ds.get("split_by_node", True),
            )
        dataset = WebDataset(
            urls,
            ds.get("cachedir", cache_dir),
            ds.get("cachesize", cache_size),
            ds.get("cachename", cache_name),
            ds.get("cacheverbose", cache_verbose),
        )
        if "subsample" in ds:
            dataset = dataset.rsample(ds["subsample"])
        if "shuffle" in ds:
            dataset = dataset.shuffle(ds["shuffle"])
        if "epoch" in ds:
            dataset = dataset.with_epoch(ds["epoch"])
        bs = ds.get("chunksize", chunksize)
        if bs > 0:
            dataset = dataset.listed(bs)
        nworkers = ds.get("nworkers", 0)
        if nworkers >= 0:
            dataset = WebLoader(dataset,
                                num_workers=nworkers,
                                batch_size=None,
                                collate_fn=list)
        p = ds.get("probability", 1.0)
        result.append(Source(dataset=dataset, probability=p))
    if len(result) > 1:
        result = RoundRobin(result)
    else:
        result = result[0].dataset
    if bs > 0:
        result = result.unlisted()
    if "epoch" in spec:
        result = result.with_epoch(spec["epoch"]).with_length(spec["epoch"])
    return result
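A sketch of a spec this function would accept, using only keys permitted by the check_allowed calls above; all values are hypothetical.

import yaml

spec = {
    "prefix": "https://example.com/data/",
    "datasets": [
        {"name": "train",
         "shards": ["shards-{000..009}.tar"],
         "shuffle": 1000,
         "chunksize": 10,
         "probability": 1.0},
    ],
    "epoch": 10000,
}
with open("composite.ds.yml", "w") as stream:
    yaml.safe_dump(spec, stream)
# construct_dataset("composite.ds.yml") would brace-expand the shard list and
# wrap each source in the WebDataset/WebLoader pipeline shown above.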
Code Example #16
File: audio_to_label.py Project: piraka9011/NeMo
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        is_regression_task: bool = False,
    ):
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(' and '[' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')' and ']' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths,
                                      nodesplitter=None)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")

        self._dataset = (self._dataset.rename(
            audio=VALID_FILE_FORMATS, key='__key__').to_tuple(
                'audio', 'key').pipe(self._filter).map(f=self._build_sample))
Code Example #17
def real_glob(rglob):
    glob_list = braceexpand(rglob)
    files = []
    for g in glob_list:
        files = files + glob.glob(g)
    return sorted(files)
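A hypothetical call; the actual result depends on which files exist on disk.

real_glob('data/{2019,2020}/*.csv')
# expands to 'data/2019/*.csv' and 'data/2020/*.csv', globs each pattern,
# and returns the combined matches sorted.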
Code Example #18
def brace_expand(pattern):
    """Bash-style brace expansion"""
    from braceexpand import braceexpand

    for path in braceexpand(pattern):
        yield to_path(path)
Code Example #19
def files_names_from_pattern(filename):
    from braceexpand import braceexpand
    data_files_names = list(braceexpand(filename))
    data_files_names = sorted(
        [x for f in data_files_names for x in glob.glob(f)])
    return data_files_names
Code Example #20
 def transform_objects(self, transform_id, template):
     for obj_name in braceexpand(template):
         yield self.transform_object(transform_id, obj_name)
Code Example #21
        "synthetic_results",
        args.label,
        "CPU",
        "Benchmarks",
        args.synthetic_query + ".json",
    )
    import_cmdline = None
    benchmark_cmdline = synthetic_benchmark_cmdline
else:
    if args.import_file is None or args.table_schema_file is None or args.queries_dir is None:
        print(
            "For dataset type of benchmark the following parameters are mandatory: --import-file,"
            " --table-schema-file and --queries-dir and --fragment-size is optional."
        )
        sys.exit(3)
    datafiles_names = list(braceexpand(args.import_file))
    datafiles_names = sorted(
        [x for f in datafiles_names for x in glob.glob(f)])
    datafiles = len(datafiles_names)
    print("NUMBER OF DATAFILES FOUND:", datafiles)
    results_file_name = os.path.join(args.benchmarks_path, "benchmark.json")
    import_cmdline = dataset_import_cmdline
    benchmark_cmdline = dataset_benchmark_cmdline

db_reporter = None
if args.db_user != "":
    if args.db_table is None:
        print(
            "--db-table parameter is mandatory to store results in MySQL database"
        )
        sys.exit(4)
Code Example #22
 def brace_expand(pattern):
     """Bash-style brace expansion"""
     for path in braceexpand(pattern):
         yield Path(path)
Code Example #23
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        add_misc: bool = False,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
        self._add_misc = add_misc

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(' and '[' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')' and ']' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
Code Example #24
File: data_layer.py Project: paulhendricks/NeMo
    def __init__(
        self,
        audio_tar_filepaths,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        shuffle_n=0,
        num_workers=0,
        augmentor: Optional[Union[AudioAugmentor,
                                  Dict[str, Dict[str, Any]]]] = None,
    ):
        super().__init__()
        self._sample_rate = sample_rate

        if augmentor is not None:
            augmentor = _process_augmentations(augmentor)

        self.collection = ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=make_parser(labels=labels,
                               name='en',
                               do_normalize=normalize_transcripts),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=self._sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)

        self.trim = trim_silence
        self.eos_id = eos_id
        self.bos_id = bos_id

        # Used in creating a sampler (in Actions).
        self._batch_size = batch_size
        self._num_workers = num_workers
        pad_id = 0 if pad_id is None else pad_id
        self.collate_fn = partial(seq_collate_fn, token_pad_value=pad_id)

        # Check for distributed and partition shards accordingly
        if torch.distributed.is_initialized():
            global_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

            if isinstance(audio_tar_filepaths, str):
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")

            begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
Code Example #25
File: base.py Project: theaswanson/bedrock
def url_test(url, location=None, status_code=requests.codes.moved_permanently,
             req_headers=None, req_kwargs=None, resp_headers=None, query=None,
             follow_redirects=False, final_status_code=requests.codes.ok):
    r"""
    Function for producing a config dict for the redirect test.

    You can use simple bash-style brace expansion in the `url` and `location`
    values. If you need the `location` to change along with the `url`, you must
    use the same number of expansions in both, or the `location` will be treated as non-expandable.

    If you use brace expansion this function will return a list of dicts instead of a dict.
    You must use the `flatten` function provided to prepare your test fixture if you do this.

    If you combine brace expansion with a compiled regular expression pattern you must
    escape any backslashes as this is the escape character for brace expansion.

    example:

        url_test('/about/drivers{/,.html}', 'https://wiki.mozilla.org/Firefox/Drivers'),
        url_test('/projects/index.{de,fr,hr,sq}.html', '/{de,fr,hr,sq}/firefox/'),
        url_test('/firefox/notes/', re.compile(r'\/firefox\/[\d\.]+\/releasenotes\/')),
        url_test('/firefox/android/{,beta/}notes/', re.compile(r'\\/firefox\\/android\\/[\\d\\.]+{,beta}\\/releasenotes\\/')),

    :param url: The URL in question (absolute or relative).
    :param location: If a redirect, either the expected value or a compiled regular expression to match the "Location" header.
    :param status_code: Expected status code from the request.
    :param req_headers: Extra headers to send with the request.
    :param req_kwargs: Extra arguments to pass to requests.get()
    :param resp_headers: Dict of headers expected in the response.
    :param query: Dict of expected query params in `location` URL.
    :param follow_redirects: Boolean indicating whether redirects should be followed.
    :param final_status_code: Expected status code after following any redirects.
    :return: dict or list of dicts
    """
    test_data = {
        'url': url,
        'location': location,
        'status_code': status_code,
        'req_headers': req_headers,
        'req_kwargs': req_kwargs,
        'resp_headers': resp_headers,
        'query': query,
        'follow_redirects': follow_redirects,
        'final_status_code': final_status_code,
    }
    expanded_urls = list(braceexpand(url))
    num_urls = len(expanded_urls)
    if num_urls == 1:
        return test_data

    try:
        # location is a compiled regular expression pattern
        location_pattern = location.pattern
        test_data['location'] = location_pattern
    except AttributeError:
        location_pattern = None

    new_urls = []
    if location:
        expanded_locations = list(braceexpand(test_data['location']))
        num_locations = len(expanded_locations)

    for i, url in enumerate(expanded_urls):
        data = test_data.copy()
        data['url'] = url
        if location and num_urls == num_locations:
            if location_pattern is not None:
                # recompile the pattern after expansion
                data['location'] = re.compile(expanded_locations[i])
            else:
                data['location'] = expanded_locations[i]
        new_urls.append(data)

    return new_urls
Code Example #26
File: _common.py Project: tbroadley/sandpaper
 def _evaluate_glob(self, pattern):
     discovered = set()
     for variation in braceexpand.braceexpand(pattern):
         for path in glob.glob(variation):
             discovered.add(path)
     return discovered