Example #1
0
def config():
    # Sacred config scope: every local defined below becomes a config entry
    # and may be overridden on the command line. Variable names and the
    # ``del`` statements are therefore part of the interface and must not
    # be changed.
    debug = False
    # Timestamped run id; '_debug' suffix marks debug runs in storage.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')

    # Must be supplied on the command line (the empty default fails the
    # assert below).
    strong_label_crnn_hyper_params_dir = ''
    assert len(
        strong_label_crnn_hyper_params_dir
    ) > 0, 'Set strong_label_crnn_hyper_params_dir on the command line.'
    # Re-load the tuning run's config to recover model dirs and checkpoints.
    strong_label_crnn_tuning_config = load_json(
        Path(strong_label_crnn_hyper_params_dir) / '1' / 'config.json')
    strong_label_crnn_dirs = strong_label_crnn_tuning_config[
        'strong_label_crnn_dirs']
    assert len(strong_label_crnn_dirs
               ) > 0, 'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = strong_label_crnn_tuning_config[
        'strong_label_crnn_checkpoints']
    data_provider = strong_label_crnn_tuning_config['data_provider']
    database_name = strong_label_crnn_tuning_config['database_name']
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'inference' / timestamp)
    assert not Path(storage_dir).exists()

    # The weak-label CRNN tuning dir is taken from the strong-label tuning
    # config rather than the command line.
    weak_label_crnn_hyper_params_dir = strong_label_crnn_tuning_config[
        'weak_label_crnn_hyper_params_dir']
    assert len(
        weak_label_crnn_hyper_params_dir
    ) > 0, 'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(
        weak_label_crnn_dirs) > 0, 'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']

    # Drop the raw tuning configs so they don't become config entries.
    del strong_label_crnn_tuning_config
    del weak_label_crnn_tuning_config

    sed_hyper_params_name = ['f', 'psds1', 'psds2']

    device = 0  # presumably the CUDA device index — confirm

    dataset_name = 'eval_public'
    ground_truth_filepath = None

    # Optional segment-wise processing of long clips; overlap only applies
    # when a maximum segment length is set.
    max_segment_length = None
    if max_segment_length is None:
        segment_overlap = None
    else:
        segment_overlap = 100
    save_scores = False
    save_detections = False

    # Pseudo-labeling switches (write predicted labels back to a dataset).
    weak_pseudo_labeling = False
    strong_pseudo_labeling = False
    pseudo_labelled_dataset_name = dataset_name

    pseudo_widening = .0

    ex.observers.append(FileStorageObserver.create(storage_dir))
Example #2
0
def config():
    # Sacred config scope: every local defined below becomes a config entry
    # and may be overridden on the command line. Variable names and the
    # ``del`` statements are therefore part of the interface and must not
    # be changed. This scope tunes strong-label CRNN hyper-parameters
    # (median filter lengths) given already tuned weak-label CRNNs.
    debug = False
    # Timestamped run id; '_debug' suffix marks debug runs in storage.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')

    # Must be supplied on the command line (the empty default fails the
    # assert below).
    weak_label_crnn_hyper_params_dir = ''
    assert len(
        weak_label_crnn_hyper_params_dir
    ) > 0, 'Set weak_label_crnn_hyper_params_dir on the command line.'
    # Recover the weak-label model dirs/checkpoints from the tuning run.
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(
        weak_label_crnn_dirs) > 0, 'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']
    del weak_label_crnn_tuning_config

    # Strong-label CRNN training runs are discovered by globbing for
    # timestamped ('202*') subdirectories of the given group dir(s).
    strong_label_crnn_group_dir = ''
    if isinstance(strong_label_crnn_group_dir, list):
        strong_label_crnn_dirs = sorted([
            str(d) for g in strong_label_crnn_group_dir
            for d in Path(g).glob('202*') if d.is_dir()
        ])
    else:
        strong_label_crnn_dirs = sorted([
            str(d) for d in Path(strong_label_crnn_group_dir).glob('202*')
            if d.is_dir()
        ])
    assert len(strong_label_crnn_dirs
               ) > 0, 'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = 'ckpt_best_macro_fscore_strong.pth'
    # Take the data provider / database name from the first run's config.
    strong_crnn_config = load_json(
        Path(strong_label_crnn_dirs[0]) / '1' / 'config.json')
    data_provider = strong_crnn_config['data_provider']
    database_name = strong_crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'hyper_params' / timestamp)
    assert not Path(storage_dir).exists()
    del strong_crnn_config
    # Provider overrides for tuning: allow very short audio, no caching.
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None

    device = 0  # presumably the CUDA device index — confirm

    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None

    # Median filter lengths to sweep (single short filter in debug mode).
    medfilt_lengths = [31] if debug else [
        301, 251, 201, 151, 101, 81, 61, 51, 41, 31, 21, 11
    ]

    ex.observers.append(FileStorageObserver.create(storage_dir))
Example #3
0
def config():
    # Sacred config scope: every local defined below becomes a config entry
    # and may be overridden on the command line. Variable names and the
    # ``del`` statement are therefore part of the interface and must not
    # be changed. This scope tunes weak-label CRNN detection
    # hyper-parameters (boundary/median filter and window lengths).
    debug = False
    # Timestamped run id; '_debug' suffix marks debug runs in storage.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')

    # CRNN training runs are discovered by globbing for timestamped
    # ('202*') subdirectories of the given group dir(s).
    group_dir = ''
    if isinstance(group_dir, list):
        crnn_dirs = sorted([
            str(d) for g in group_dir for d in Path(g).glob('202*')
            if d.is_dir()
        ])
    else:
        crnn_dirs = sorted(
            [str(d) for d in Path(group_dir).glob('202*') if d.is_dir()])
    assert len(crnn_dirs) > 0, 'crnn_dirs must not be empty.'
    crnn_checkpoints = 'ckpt_best_macro_fscore_weak.pth'
    # Take the data provider / database name from the first run's config.
    crnn_config = load_json(Path(crnn_dirs[0]) / '1' / 'config.json')
    data_provider = crnn_config['data_provider']
    database_name = crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'weak_label_crnn' / database_name /
                      'hyper_params' / timestamp)
    assert not Path(storage_dir).exists()
    del crnn_config
    # Provider overrides for tuning: allow very short audio, no caching.
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None

    device = 0  # presumably the CUDA device index — confirm

    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None

    # Filter/window length grids to sweep (tiny grids in debug mode).
    boundaries_filter_lengths = [20] if debug else [
        100, 80, 60, 50, 40, 30, 20, 10, 0
    ]

    tune_detection_scenario_1 = True
    detection_window_lengths_scenario_1 = [11] if debug else [
        51, 41, 31, 21, 11
    ]
    detection_window_shift_scenario_1 = 1
    detection_medfilt_lengths_scenario_1 = [11] if debug else [
        101, 81, 61, 51, 41, 31, 21, 11
    ]

    tune_detection_scenario_2 = True
    detection_window_lengths_scenario_2 = [250]
    detection_window_shift_scenario_2 = 250
    detection_medfilt_lengths_scenario_2 = [1]

    ex.observers.append(FileStorageObserver.create(storage_dir))
Example #4
0
def config():
    # Sacred config scope: every local defined below becomes a config entry
    # and may be overridden on the command line. Variable names and the
    # ``del`` statement are therefore part of the interface and must not
    # be changed. This scope configures strong-label CRNN training.
    delay = 0
    debug = False
    # Timestamped run id; '_debug' suffix marks debug runs in storage.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    group_name = timestamp
    database_name = 'desed'
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'training' / group_name / timestamp)

    # Optional warm start from a checkpoint; early CNN layers may be frozen.
    init_ckpt_path = None
    frozen_cnn_2d_layers = 0
    frozen_cnn_1d_layers = 0

    # Data provider
    if database_name == 'desed':
        external_data = True
        batch_size = 32
        data_provider = {
            'factory':
            DESEDProvider,
            'json_path':
            str(database_jsons_dir /
                'desed_pseudo_labeled_with_external.json') if external_data
            else str(database_jsons_dir /
                     'desed_pseudo_labeled_without_external.json'),
            # Per-dataset repetition factors in the training mixture.
            'train_set': {
                'train_weak': 10 if external_data else 20,
                'train_strong': 10 if external_data else 0,
                'train_synthetic20': 2,
                'train_synthetic21': 1,
                'train_unlabel_in_domain': 2,
            },
            'cached_datasets':
            None if debug else ['train_weak', 'train_synthetic20'],
            'train_fetcher': {
                'batch_size': batch_size,
                'prefetch_workers': batch_size,
                # Minimum number of examples from each dataset per batch,
                # scaled linearly with the batch size.
                'min_dataset_examples_in_batch': {
                    'train_weak': int(3 * batch_size / 32),
                    'train_strong':
                    int(6 * batch_size / 32) if external_data else 0,
                    'train_synthetic20': int(1 * batch_size / 32),
                    'train_synthetic21': int(2 * batch_size / 32),
                    'train_unlabel_in_domain': 0,
                },
            },
            'storage_dir':
            storage_dir,
        }
        num_events = 10
        DESEDProvider.get_config(data_provider)

        validation_set_name = 'validation'
        validation_ground_truth_filepath = None
        weak_label_crnn_hyper_params_dir = ''
        eval_set_name = 'eval_public'
        eval_ground_truth_filepath = None

        # Shorter schedule and tighter gradient clipping when fine-tuning
        # from an initial checkpoint.
        num_iterations = 45000 if init_ckpt_path is None else 20000
        checkpoint_interval = 1000
        summary_interval = 100
        back_off_patience = None
        lr_decay_step = 30000 if back_off_patience is None else None
        lr_decay_factor = 1 / 5
        lr_rampup_steps = 1000 if init_ckpt_path is None else None
        gradient_clipping = 1e10 if init_ckpt_path is None else 1
    else:
        raise ValueError(f'Unknown database {database_name}.')

    # Trainer configuration
    net_config = 'shallow'
    if net_config == 'shallow':
        m = 1  # channel width multiplier
        cnn = {
            'cnn_2d': {
                'out_channels': [
                    16 * m,
                    16 * m,
                    32 * m,
                    32 * m,
                    64 * m,
                    64 * m,
                    128 * m,
                    128 * m,
                    min(256 * m, 512),
                ],
                # (2, 1) pools along one axis only — presumably the
                # frequency axis; confirm the axis order in CNN2d.
                'pool_size':
                4 * [1, (2, 1)] + [1],
                'kernel_size':
                3,
                'norm':
                'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn':
                'relu',
                'dropout':
                .0,
                'output_layer':
                False,
            },
            'cnn_1d': {
                'out_channels': 3 * [256 * m],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
        }
    elif net_config == 'deep':
        m = 2  # channel width multiplier
        cnn = {
            'cnn_2d': {
                'out_channels':
                (4 * [16 * m] + 4 * [32 * m] + 4 * [64 * m] + 4 * [128 * m] +
                 [256 * m, min(256 * m, 512)]),
                'pool_size':
                4 * [1, 1, 1, (2, 1)] + [1, 1],
                'kernel_size':
                9 * [3, 1],
                'residual_connections': [
                    None, None, 4, None, 6, None, 8, None, 10, None, 12, None,
                    14, None, 16, None, None, None
                ],
                'norm':
                'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn':
                'relu',
                'pre_activation':
                True,
                'dropout':
                .0,
                'output_layer':
                False,
            },
            'cnn_1d': {
                'out_channels': 8 * [256 * m],
                'kernel_size': [1] + 3 * [3, 1] + [1],
                'residual_connections':
                [None, 3, None, 5, None, 7, None, None],
                'norm': 'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
        }
    else:
        raise ValueError(f'Unknown net_config {net_config}')

    if init_ckpt_path is not None:
        cnn['conditional_dims'] = 0

    trainer = {
        'model': {
            'factory': strong_label.CRNN,
            'feature_extractor': {
                'sample_rate':
                data_provider['audio_reader']['target_sample_rate'],
                'stft_size': data_provider['train_transform']['stft']['size'],
                'number_of_filters': 128,
                # Mel-warping augmentation with randomly sampled warp
                # factor and boundary frequency ratio.
                'frequency_warping_fn': {
                    'factory':
                    MelWarping,
                    'warp_factor_sampling_fn': {
                        'factory': LogTruncatedNormal,
                        'scale': .08,
                        'truncation': np.log(1.3),
                    },
                    'boundary_frequency_ratio_sampling_fn': {
                        'factory': TruncatedExponential,
                        'scale': .5,
                        'truncation': 5.,
                    },
                    'highest_frequency':
                    data_provider['audio_reader']['target_sample_rate'] / 2
                },
                # 'blur_sigma': .5,
                # Randomized time/frequency masking and noise scaling.
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_frequency_masks': 1,
                'max_masked_frequency_bands': 20,
                'max_masked_frequency_rate': .2,
                'max_noise_scale': .2,
            },
            'cnn': cnn,
            'rnn': {
                'hidden_size': 256 * m,
                'num_layers': 2,
                'dropout': .0,
                'output_net': {
                    'out_channels': [256 * m, num_events],
                    'kernel_size': 1,
                    'norm': 'batch',
                    'activation_fn': 'relu',
                    'dropout': .0,
                }
            },
            'labelwise_metrics': ('fscore_strong', ),
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
            'gradient_clipping': gradient_clipping,
            # 'weight_decay': 1e-6,
        },
        'summary_trigger': (summary_interval, 'iteration'),
        'checkpoint_trigger': (checkpoint_interval, 'iteration'),
        'stop_trigger': (num_iterations, 'iteration'),
        'storage_dir': storage_dir,
    }
    del cnn
    # Optionally replace the recurrent part by a Transformer stack.
    use_transformer = False
    if use_transformer:
        trainer['model']['rnn']['factory'] = TransformerStack
        trainer['model']['rnn']['hidden_size'] = 320
        trainer['model']['rnn']['num_heads'] = 10
        trainer['model']['rnn']['num_layers'] = 3
        trainer['model']['rnn']['dropout'] = 0.1
    Trainer.get_config(trainer)

    resume = False
    # A fresh run must not overwrite an existing storage dir.
    assert resume or not Path(trainer['storage_dir']).exists()
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
Example #5
0
from paderbox.transform.module_stft import STFT
from padercontrib.database.fearless import Fearless
from padertorch import Trainer
from padertorch.contrib.examples.voice_activity_detection.model import SAD_Classifier
from padertorch.contrib.je.data.transforms import AudioReader, Normalizer, Collate
from padertorch.contrib.je.modules.conv import CNN1d, CNN2d
from padertorch.modules.fully_connected import fully_connected_stack
from padertorch.train.optimizer import Adam
import torch
from torch.nn import MaxPool2d
from torch.autograd import Variable
from paderbox.array import segment_axis
from einops import rearrange

# All artifacts of this run go into a timestamped directory under
# $STORAGE_ROOT/voice_activity, created up front.
storage_dir = str(
    Path(os.environ['STORAGE_ROOT'])
    / 'voice_activity'
    / timeStamped('')[1:]
)
os.makedirs(storage_dir, exist_ok=True)

# Global switches; their effect is defined where they are used below.
DEBUG = False
DATA_TEST = False


def get_datasets():
    # Build train/validation streams from the Fearless database.
    # NOTE(review): this definition is truncated in this excerpt; only the
    # dataset setup and the per-example preparation are visible.
    db = Fearless()
    train = db.get_dataset_train(subset='stream')
    validate = db.get_dataset_validation(subset='stream')

    def prepare_example(example):
        # Keep only the observation-channel audio path and attach the
        # activity target provided by the database.
        example['audio_path'] = example['audio_path']['observation']
        example['activity'] = db.get_activity(example)
        return example
Example #6
0
                                      fading='full')
    model = WaveNet(wavenet=wavenet,
                    sample_rate=16000,
                    fft_length=512,
                    n_mels=64,
                    fmin=50)
    return model


def train(model, storage_dir):
    """Train *model* with the standard loop, writing artifacts to
    *storage_dir*.

    Runs a smoke test of the data/model pipeline first, registers a
    validation hook, and then trains for 100k iterations.
    """
    train_set, validate_set, _ = get_datasets()

    trainer = Trainer(
        model=model,
        optimizer=Adam(lr=5e-4),
        storage_dir=str(storage_dir),
        summary_trigger=(1000, 'iteration'),
        checkpoint_trigger=(10000, 'iteration'),
        stop_trigger=(100000, 'iteration'),
    )

    # Fail fast on pipeline problems before committing to a full run.
    trainer.test_run(train_set, validate_set)
    trainer.register_validation_hook(validate_set)
    trainer.train(train_set)


if __name__ == '__main__':
    # Checkpoints and summaries go to a timestamped directory under
    # $STORAGE_ROOT/wavenet.
    storage_dir = str(
        Path(os.environ['STORAGE_ROOT']) / 'wavenet' / timeStamped('')[1:]
    )
    os.makedirs(storage_dir, exist_ok=True)
    model = get_model()
    train(model, storage_dir)
Example #7
0
import numpy as np
from padercontrib.database.librispeech import LibriSpeech
from paderbox.utils.timer import timeStamped
from padertorch import Trainer
from padertorch.contrib.examples.speaker_classification.model import SpeakerClf
from padertorch.contrib.je.data.transforms import LabelEncoder, AudioReader, \
    STFT, MelTransform, Normalizer, Collate
from padertorch.contrib.je.data.utils import split_dataset
from padertorch.contrib.je.modules.conv import CNN1d
from padertorch.modules.fully_connected import fully_connected_stack
from padertorch.train.optimizer import Adam
from torch.nn import GRU

# Store all training artifacts in a timestamped directory below
# $STORAGE_ROOT/speaker_clf; create it eagerly.
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / 'speaker_clf' / timeStamped('')[1:]
)
os.makedirs(storage_dir, exist_ok=True)


def get_datasets():
    # Create train/validation splits from LibriSpeech train_clean_100.
    # NOTE(review): this definition is truncated in this excerpt; the
    # return statement is not visible.
    db = LibriSpeech()
    train_clean_100 = db.get_dataset('train_clean_100')

    def prepare_example(example):
        # Keep only the observation-channel audio path; the speaker id is
        # the prefix of the utterance id before the first '-'.
        example['audio_path'] = example['audio_path']['observation']
        example['speaker_id'] = example['speaker_id'].split('-')[0]
        return example

    train_clean_100 = train_clean_100.map(prepare_example)

    train_set, validate_set = split_dataset(train_clean_100, fold=0)
Example #8
0
from pathlib import Path

import numpy as np
import torch
from padercontrib.database.audio_set import AudioSet
from paderbox.utils.timer import timeStamped
from padertorch import Model, Trainer, optimizer
from padertorch.contrib.je.data.transforms import (
    AudioReader, STFT, MelTransform, Normalizer, LabelEncoder, Collate
)
from padertorch.contrib.je.modules.conv import CNN2d
from torch import nn
from padertorch.contrib.je.modules.norm import Norm

# Store all training artifacts in a timestamped directory below
# $STORAGE_ROOT/audio_tagging; create it eagerly.
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / 'audio_tagging' / timeStamped('')[1:])
os.makedirs(storage_dir, exist_ok=True)


class MultiHotLabelEncoder(LabelEncoder):
    """Label encoder that replaces the integer label list by a 527-dim
    multi-hot float32 vector (527 = AudioSet class count — confirm)."""

    def __call__(self, example):
        # Let the base encoder map label strings to integer indices first.
        label_indices = super().__call__(example)[self.label_key]
        multi_hot = np.zeros(527, dtype=np.float32)
        multi_hot[label_indices] = 1
        example[self.label_key] = multi_hot
        return example


class WALNet(Model):
    """
Example #9
0
from padertorch.contrib.examples.audio_tagging.data import get_datasets
from padertorch.contrib.examples.audio_tagging.models import CRNN
from paderbox.utils.timer import timeStamped
from padertorch.contrib.je.modules.augment import (MelWarping,
                                                   LogTruncNormalSampler,
                                                   TruncExponentialSampler)
from padertorch.train.optimizer import Adam
from padertorch.train.trainer import Trainer
from sacred import Experiment as Exp
from sacred.commands import print_config
from sacred.observers import FileStorageObserver

# Sacred experiment setup: every run is recorded by a FileStorageObserver
# writing to a timestamped directory under $STORAGE_ROOT/<nickname>.
# ``ex`` is the module-level experiment object used by the decorators below.
nickname = 'audio_tagging'
ex = Exp(nickname)
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / nickname / timeStamped('')[1:])
observer = FileStorageObserver.create(storage_dir)
ex.observers.append(observer)


@ex.config
def config():
    resume = False

    # Data configuration
    audio_reader = {
        'source_sample_rate': None,
        'target_sample_rate': 44100,
    }
    stft = {
        'shift': 882,