Code example #1
    def __init__(self,
                 action,
                 name='Pipeline',
                 aug_min=1,
                 aug_p=1,
                 flow=None,
                 verbose=0):
        Augmenter.__init__(self,
                           name=name,
                           method=Method.FLOW,
                           action=action,
                           aug_min=aug_min,
                           verbose=verbose)
        self.aug_p = aug_p
        # flow may be None, a single augmenter, or a list of augmenters;
        # the pipeline itself is a list subclass holding the sub-augmenters.
        if flow is None:
            list.__init__(self, [])
        elif isinstance(flow, (Augmenter, CharAugmenter)):
            list.__init__(self, [flow])
        elif isinstance(flow, list):
            for subflow in flow:
                if not isinstance(subflow, Augmenter):
                    raise ValueError(
                        'At least one of the flow does not belongs to Augmenter'
                    )
            list.__init__(self, flow)
        else:
            raise Exception(
                'Expected None, Augmenter or list of Augmenter while {} is passed'
                .format(type(flow)))
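
The constructor above is the flow pipeline base class in nlpaug (the parent of Sequential and Sometimes): the pipeline is itself a list subclass, and flow may be None, a single augmenter, or a list of augmenters. Below is a minimal sketch of how such a flow is usually assembled through the public API, assuming a recent nlpaug release; note that augment may return a string or a one-element list depending on the version.

import nlpaug.augmenter.char as nac
import nlpaug.flow as naf

# Two character-level augmenters composed into a sequential pipeline;
# Sequential inherits the __init__ shown above and stores them via list.__init__.
flow = naf.Sequential([
    nac.RandomCharAug(action='insert'),
    nac.KeyboardAug(),
])

print(flow.augment('The quick brown fox jumps over the lazy dog'))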
Code example #2
    @classmethod
    def setUpClass(cls):
        # Load environment variables (e.g. model/resource paths) from the repo-level .env file
        env_config_path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', '..', '.env'))
        load_dotenv(env_config_path)

        # Shared fixture: the base Augmenter with its core parameters (method, action, aug_min/max/p)
        cls.aug = Augmenter(name='base', method='flow', action='insert',
                            aug_min=1, aug_max=10, aug_p=0.5)
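
The test above instantiates the abstract base Augmenter directly only to exercise its bookkeeping; in application code the same parameters (aug_min, aug_max, aug_p) are passed to a concrete subclass instead. A rough sketch, assuming a recent nlpaug version in which RandomWordAug accepts these arguments:

import nlpaug.augmenter.word as naw

# Concrete word-level augmenter carrying the same aug_* parameters as the base class
aug = naw.RandomWordAug(action='swap', aug_min=1, aug_max=10, aug_p=0.5)
print(aug.augment('nlpaug perturbs text for data augmentation'))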
Code example #3
def store_augments(origin_path: Path, augmenter: Augmenter, num_augments: int,
                   ident_generator: Iterator, output_dir: Path) -> List[Path]:
    """
    Stores augmented versions of a WAV file in the filesystem

    :param origin_path: the path to the original WAV file
    :param augmenter: An augmenter object
    :param num_augments: The number of augmented files to create
    :param ident_generator: A generator object that creates unique identifiers
                            for the file name
    :param output_dir: The directory to store the augmented files
    :return: A list of file paths for the augments that were created

    :raise ParameterError: if `MAX_ATTEMPTS` ParameterErrors are raised in a row
                           while attempting to augment
    """
    logger.info('Loading %s for augmentation', origin_path)

    audio_data, _ = librosa.load(origin_path)

    logger.info('Augmenting %s', origin_path)

    attempts = 0
    augments = None

    while augments is None:
        try:
            attempts += 1
            augments = augmenter.augment(audio_data, n=num_augments)
        except ParameterError:
            logger.info(
                'Error encountered while augmenting "%s"; trying again.',
                origin_path)
            if attempts >= MAX_ATTEMPTS:
                raise

    if num_augments == 1:
        # nlpaug.Augmenter doesn't return a list if n=1
        augments = [augments]

    _, label = origin_path.stem.split('__')
    output_paths = []

    for augment in augments:
        identifier = next(ident_generator)
        output_path = output_dir / f'{identifier}__{label}.wav'
        output_paths.append(output_path)
        sf.write(output_path, augment, SAMPLING_RATE)
        logger.info('"%s" written to disk', output_path)

    return output_paths
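
Below is a hypothetical call site for store_augments, assuming an audio augmenter whose augment(data, n=...) matches the signature used above (MAX_ATTEMPTS and SAMPLING_RATE are module-level constants in the original project). itertools.count stands in for the project's identifier generator, and the input path follows the <id>__<label>.wav stem convention the function splits on; all names below are illustrative.

from itertools import count
from pathlib import Path

new_paths = store_augments(
    origin_path=Path('data/raw/0001__siren.wav'),  # hypothetical input file
    augmenter=augmenter,                           # e.g. an nlpaug audio augmenter such as naa.NoiseAug()
    num_augments=5,
    ident_generator=count(start=1000),             # stand-in unique-id generator
    output_dir=Path('data/augmented'),
)
print(new_paths)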
Code example #4
File: backtranslate.py  Project: Shikhar-S/nlc2cmd
def augment_dataset(
        augmenter: Augmenter,
        data: List[Dict],
        text_key: str,
        out_file: Path,
        batch_size: int,
        is_original_key: str = 'original') -> None:
    with find_free_file(out_file).open("x") as ostream:
        for i in tqdm(range(0, len(data), batch_size)):
            batch = data[i:i + batch_size]
            examples = [e[text_key] for e in batch]
            variants = augmenter.augment(examples)
            for entry, var in zip(batch, variants):
                new_entry = copy(entry)
                new_entry.update({text_key: var, is_original_key: False})
                ostream.write(f"{json.dumps(new_entry)}\n")
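
This batched variant hands a whole list of texts to augment in one call, which suits heavyweight augmenters such as back-translation. A hedged sketch of how it might be driven; the model names are illustrative defaults, find_free_file is the project's own helper, and the field names and file paths are placeholders.

import json
from pathlib import Path

import nlpaug.augmenter.word as naw

# Back-translation augmenter (any seq2seq model pair supported by nlpaug works)
bt_aug = naw.BackTranslationAug(
    from_model_name='facebook/wmt19-en-de',
    to_model_name='facebook/wmt19-de-en',
)

data = [json.loads(line) for line in Path('train.jsonl').open()]  # hypothetical input
augment_dataset(
    augmenter=bt_aug,
    data=data,
    text_key='invocation',  # hypothetical key holding the text to augment
    out_file=Path('train_backtranslated.jsonl'),
    batch_size=32,
)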
Code example #5
File: augment.py  Project: Shikhar-S/nlc2cmd
def augment_dataset(augmenter: Augmenter,
                    data: List[Dict],
                    text_key: str,
                    out_file: Path,
                    aug_config: Dict,
                    original_key: str = 'original'):
    with find_free_file(out_file).open("x") as ostream:
        # Dump config so it will be easy to know how data was augmented
        ostream.write(f"{json.dumps(aug_config)}\n")
        for entry in tqdm(data):
            # Save original example:
            ostream.write(f"{json.dumps(entry)}\n")
            example = entry[text_key]
            variants = augmenter.augment(example)
            if isinstance(variants, str):
                variants = [variants]
            for var in variants:
                new_entry = copy(entry)
                new_entry.update({text_key: var, original_key: False})
                ostream.write(f"{json.dumps(new_entry)}\n")
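
This per-example variant records the augmentation config as the first output line and writes each original example before its variants. A rough usage sketch, assuming a WordNet synonym augmenter from nlpaug; the config dict, field names, and file paths are placeholders.

import json
from pathlib import Path

import nlpaug.augmenter.word as naw

aug_config = {'augmenter': 'SynonymAug', 'aug_src': 'wordnet', 'aug_p': 0.3}  # free-form record of how the data was augmented
syn_aug = naw.SynonymAug(aug_src='wordnet', aug_p=0.3)

data = [json.loads(line) for line in Path('train.jsonl').open()]  # hypothetical input
augment_dataset(
    augmenter=syn_aug,
    data=data,
    text_key='invocation',  # hypothetical text field
    out_file=Path('train_synonyms.jsonl'),
    aug_config=aug_config,
)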