def config():
    # NOTE(review): appears to be a sacred config function (it appends to
    # `ex.observers`); local variable names become config keys, so names and
    # statement order must not be changed.
    # Inference/pseudo-labeling config for a strong-label CRNN ensemble.
    debug = False
    # Unique run id; timeStamped('') presumably returns '_<timestamp>', the
    # leading char is stripped — TODO confirm against paderbox.utils.timer.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    # Must be supplied on the command line (sacred config override).
    strong_label_crnn_hyper_params_dir = ''
    assert len(
        strong_label_crnn_hyper_params_dir
    ) > 0, 'Set strong_label_crnn_hyper_params_dir on the command line.'
    # Re-load the tuning run's config to recover model dirs/checkpoints.
    strong_label_crnn_tuning_config = load_json(
        Path(strong_label_crnn_hyper_params_dir) / '1' / 'config.json')
    strong_label_crnn_dirs = strong_label_crnn_tuning_config[
        'strong_label_crnn_dirs']
    assert len(strong_label_crnn_dirs
               ) > 0, 'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = strong_label_crnn_tuning_config[
        'strong_label_crnn_checkpoints']
    # Data provider and database are inherited from the tuning run.
    data_provider = strong_label_crnn_tuning_config['data_provider']
    database_name = strong_label_crnn_tuning_config['database_name']
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name
                      / 'inference' / timestamp)
    # Refuse to overwrite an existing run directory.
    assert not Path(storage_dir).exists()
    # The weak-label tuning dir is taken from the strong-label tuning config,
    # then validated the same way.
    weak_label_crnn_hyper_params_dir = strong_label_crnn_tuning_config[
        'weak_label_crnn_hyper_params_dir']
    assert len(
        weak_label_crnn_hyper_params_dir
    ) > 0, 'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(
        weak_label_crnn_dirs) > 0, 'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']
    # Drop the loaded dicts so they do not become config entries themselves.
    del strong_label_crnn_tuning_config
    del weak_label_crnn_tuning_config
    # Hyper-parameter sets to evaluate: f-score plus the two PSDS scenarios.
    sed_hyper_params_name = ['f', 'psds1', 'psds2']
    device = 0  # CUDA device index
    dataset_name = 'eval_public'
    ground_truth_filepath = None
    # Optional segmentation of long audio; overlap only applies when
    # max_segment_length is set.
    max_segment_length = None
    if max_segment_length is None:
        segment_overlap = None
    else:
        segment_overlap = 100
    save_scores = False
    save_detections = False
    # Pseudo-labeling switches (both off by default: pure inference).
    weak_pseudo_labeling = False
    strong_pseudo_labeling = False
    pseudo_labelled_dataset_name = dataset_name
    # Widening (in seconds, presumably — TODO confirm) applied to pseudo labels.
    pseudo_widening = .0
    ex.observers.append(FileStorageObserver.create(storage_dir))
def config():
    # NOTE(review): sacred-style config (appends to `ex.observers`); variable
    # names are the config keys — do not rename or reorder.
    # Hyper-parameter tuning config for strong-label CRNNs, reusing a
    # previously tuned weak-label CRNN ensemble.
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    # Weak-label tuning dir must be given on the command line; its config
    # supplies the weak CRNN ensemble.
    weak_label_crnn_hyper_params_dir = ''
    assert len(
        weak_label_crnn_hyper_params_dir
    ) > 0, 'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(
        weak_label_crnn_dirs) > 0, 'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']
    # Remove so the raw dict is not captured as a config entry.
    del weak_label_crnn_tuning_config
    # Strong CRNN training runs are discovered by globbing timestamped
    # ('202*') sub-directories; a list of group dirs is also accepted.
    strong_label_crnn_group_dir = ''
    if isinstance(strong_label_crnn_group_dir, list):
        strong_label_crnn_dirs = sorted([
            str(d) for g in strong_label_crnn_group_dir
            for d in Path(g).glob('202*') if d.is_dir()
        ])
    else:
        strong_label_crnn_dirs = sorted([
            str(d) for d in Path(strong_label_crnn_group_dir).glob('202*')
            if d.is_dir()
        ])
    assert len(strong_label_crnn_dirs
               ) > 0, 'strong_label_crnn_dirs must not be empty.'
    # Same checkpoint name for every ensemble member.
    strong_label_crnn_checkpoints = 'ckpt_best_macro_fscore_strong.pth'
    # Inherit data provider / database from the first training run's config.
    strong_crnn_config = load_json(
        Path(strong_label_crnn_dirs[0]) / '1' / 'config.json')
    data_provider = strong_crnn_config['data_provider']
    database_name = strong_crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name
                      / 'hyper_params' / timestamp)
    assert not Path(storage_dir).exists()
    del strong_crnn_config
    # Tuning-time data overrides: keep very short clips, disable caching.
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None
    device = 0  # CUDA device index
    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None
    # Median-filter lengths (in frames, presumably — TODO confirm) searched
    # during tuning; a single small value in debug mode.
    medfilt_lengths = [31] if debug else [
        301, 251, 201, 151, 101, 81, 61, 51, 41, 31, 21, 11
    ]
    ex.observers.append(FileStorageObserver.create(storage_dir))
def config():
    # NOTE(review): sacred-style config (appends to `ex.observers`); variable
    # names are the config keys — do not rename or reorder.
    # Hyper-parameter tuning config for weak-label CRNNs.
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    # Training runs discovered by globbing timestamped ('202*') dirs; a list
    # of group dirs is also accepted.
    group_dir = ''
    if isinstance(group_dir, list):
        crnn_dirs = sorted([
            str(d) for g in group_dir
            for d in Path(g).glob('202*') if d.is_dir()
        ])
    else:
        crnn_dirs = sorted(
            [str(d) for d in Path(group_dir).glob('202*') if d.is_dir()])
    assert len(crnn_dirs) > 0, 'crnn_dirs must not be empty.'
    # Same checkpoint name for every ensemble member.
    crnn_checkpoints = 'ckpt_best_macro_fscore_weak.pth'
    # Inherit data provider / database from the first training run's config.
    crnn_config = load_json(Path(crnn_dirs[0]) / '1' / 'config.json')
    data_provider = crnn_config['data_provider']
    database_name = crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'weak_label_crnn' / database_name
                      / 'hyper_params' / timestamp)
    assert not Path(storage_dir).exists()
    del crnn_config
    # Tuning-time data overrides: keep very short clips, disable caching.
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None
    device = 0  # CUDA device index
    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None
    # Candidate boundary filter lengths searched during tuning (units
    # presumably frames — TODO confirm).
    boundaries_filter_lengths = [20] if debug else [
        100, 80, 60, 50, 40, 30, 20, 10, 0
    ]
    # Detection scenario 1: short windows, shift 1, searched median filters.
    tune_detection_scenario_1 = True
    detection_window_lengths_scenario_1 = [11] if debug else [
        51, 41, 31, 21, 11
    ]
    detection_window_shift_scenario_1 = 1
    detection_medfilt_lengths_scenario_1 = [11] if debug else [
        101, 81, 61, 51, 41, 31, 21, 11
    ]
    # Detection scenario 2: one long non-overlapping window, no median filter.
    tune_detection_scenario_2 = True
    detection_window_lengths_scenario_2 = [250]
    detection_window_shift_scenario_2 = 250
    detection_medfilt_lengths_scenario_2 = [1]
    ex.observers.append(FileStorageObserver.create(storage_dir))
def config():
    # NOTE(review): sacred-style config (appends to `ex.observers`); variable
    # names are the config keys — do not rename or reorder.
    # Training config for a strong-label CRNN on DESED.
    delay = 0
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    group_name = timestamp
    database_name = 'desed'
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name
                      / 'training' / group_name / timestamp)
    # Optional warm start from a checkpoint; several settings below switch on
    # whether this is set (shorter schedule, clipping, frozen layers).
    init_ckpt_path = None
    frozen_cnn_2d_layers = 0
    frozen_cnn_1d_layers = 0

    # Data provider
    if database_name == 'desed':
        external_data = True
        batch_size = 32
        data_provider = {
            'factory': DESEDProvider,
            # JSON with pseudo labels; variant depends on use of external data.
            'json_path':
                str(database_jsons_dir
                    / 'desed_pseudo_labeled_with_external.json')
                if external_data
                else str(database_jsons_dir
                         / 'desed_pseudo_labeled_without_external.json'),
            # Repetition factors per training subset.
            'train_set': {
                'train_weak': 10 if external_data else 20,
                'train_strong': 10 if external_data else 0,
                'train_synthetic20': 2,
                'train_synthetic21': 1,
                'train_unlabel_in_domain': 2,
            },
            'cached_datasets': None if debug else [
                'train_weak', 'train_synthetic20'],
            'train_fetcher': {
                'batch_size': batch_size,
                'prefetch_workers': batch_size,
                # Minimum examples per subset in each batch, scaled relative
                # to a reference batch size of 32.
                'min_dataset_examples_in_batch': {
                    'train_weak': int(3 * batch_size / 32),
                    'train_strong':
                        int(6 * batch_size / 32) if external_data else 0,
                    'train_synthetic20': int(1 * batch_size / 32),
                    'train_synthetic21': int(2 * batch_size / 32),
                    'train_unlabel_in_domain': 0,
                },
            },
            'storage_dir': storage_dir,
        }
        num_events = 10  # number of DESED event classes
        DESEDProvider.get_config(data_provider)
        validation_set_name = 'validation'
        validation_ground_truth_filepath = None
        weak_label_crnn_hyper_params_dir = ''
        eval_set_name = 'eval_public'
        eval_ground_truth_filepath = None
        # Shorter schedule and tighter clipping when fine-tuning from a
        # checkpoint.
        num_iterations = 45000 if init_ckpt_path is None else 20000
        checkpoint_interval = 1000
        summary_interval = 100
        # Either step-based LR decay or back-off on patience — mutually
        # exclusive.
        back_off_patience = None
        lr_decay_step = 30000 if back_off_patience is None else None
        lr_decay_factor = 1 / 5
        lr_rampup_steps = 1000 if init_ckpt_path is None else None
        gradient_clipping = 1e10 if init_ckpt_path is None else 1
    else:
        raise ValueError(f'Unknown database {database_name}.')

    # Trainer configuration
    net_config = 'shallow'
    if net_config == 'shallow':
        # Width multiplier for all channel counts.
        m = 1
        cnn = {
            'cnn_2d': {
                'out_channels': [
                    16 * m, 16 * m, 32 * m, 32 * m, 64 * m, 64 * m,
                    128 * m, 128 * m, min(256 * m, 512),
                ],
                # Pool along frequency only ((2, 1)) after every second layer.
                'pool_size': 4 * [1, (2, 1)] + [1],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 3 * [256 * m],
                'kernel_size': 3,
                'norm': 'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn': 'relu',
                'dropout': .0,
                'output_layer': False,
            },
        }
    elif net_config == 'deep':
        # Deeper variant with residual connections and doubled width.
        m = 2
        cnn = {
            'cnn_2d': {
                'out_channels': (4 * [16 * m] + 4 * [32 * m] + 4 * [64 * m]
                                 + 4 * [128 * m]
                                 + [256 * m, min(256 * m, 512)]),
                'pool_size': 4 * [1, 1, 1, (2, 1)] + [1, 1],
                'kernel_size': 9 * [3, 1],
                # Indices of layers each residual connection skips to
                # (presumably — TODO confirm against CNN2d implementation).
                'residual_connections': [
                    None, None, 4, None, 6, None, 8, None, 10, None, 12,
                    None, 14, None, 16, None, None, None
                ],
                'norm': 'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
            'cnn_1d': {
                'out_channels': 8 * [256 * m],
                'kernel_size': [1] + 3 * [3, 1] + [1],
                'residual_connections': [None, 3, None, 5, None, 7, None,
                                         None],
                'norm': 'batch',
                'norm_kwargs': {
                    'eps': 1e-3
                },
                'activation_fn': 'relu',
                'pre_activation': True,
                'dropout': .0,
                'output_layer': False,
            },
        }
    else:
        raise ValueError(f'Unknown net_config {net_config}')
    if init_ckpt_path is not None:
        cnn['conditional_dims'] = 0
    trainer = {
        'model': {
            'factory': strong_label.CRNN,
            'feature_extractor': {
                'sample_rate':
                    data_provider['audio_reader']['target_sample_rate'],
                'stft_size': data_provider['train_transform']['stft']['size'],
                'number_of_filters': 128,
                # Data augmentation: random mel-frequency warping.
                'frequency_warping_fn': {
                    'factory': MelWarping,
                    'warp_factor_sampling_fn': {
                        'factory': LogTruncatedNormal,
                        'scale': .08,
                        'truncation': np.log(1.3),
                    },
                    'boundary_frequency_ratio_sampling_fn': {
                        'factory': TruncatedExponential,
                        'scale': .5,
                        'truncation': 5.,
                    },
                    'highest_frequency':
                        data_provider['audio_reader']['target_sample_rate']
                        / 2
                },
                # 'blur_sigma': .5,
                # SpecAugment-style time/frequency masking and noise.
                'n_time_masks': 1,
                'max_masked_time_steps': 70,
                'max_masked_time_rate': .2,
                'n_frequency_masks': 1,
                'max_masked_frequency_bands': 20,
                'max_masked_frequency_rate': .2,
                'max_noise_scale': .2,
            },
            'cnn': cnn,
            'rnn': {
                'hidden_size': 256 * m,
                'num_layers': 2,
                'dropout': .0,
                # Per-frame classification head: 10 event classes.
                'output_net': {
                    'out_channels': [256 * m, num_events],
                    'kernel_size': 1,
                    'norm': 'batch',
                    'activation_fn': 'relu',
                    'dropout': .0,
                }
            },
            'labelwise_metrics': ('fscore_strong', ),
        },
        'optimizer': {
            'factory': Adam,
            'lr': 5e-4,
            'gradient_clipping': gradient_clipping,
            # 'weight_decay': 1e-6,
        },
        'summary_trigger': (summary_interval, 'iteration'),
        'checkpoint_trigger': (checkpoint_interval, 'iteration'),
        'stop_trigger': (num_iterations, 'iteration'),
        'storage_dir': storage_dir,
    }
    # Already embedded in trainer; remove so it is not a top-level entry.
    del cnn
    # Optional replacement of the GRU by a Transformer encoder.
    use_transformer = False
    if use_transformer:
        trainer['model']['rnn']['factory'] = TransformerStack
        trainer['model']['rnn']['hidden_size'] = 320
        trainer['model']['rnn']['num_heads'] = 10
        trainer['model']['rnn']['num_layers'] = 3
        trainer['model']['rnn']['dropout'] = 0.1
    Trainer.get_config(trainer)
    resume = False
    # A fresh run must not clobber an existing storage dir.
    assert resume or not Path(trainer['storage_dir']).exists()
    ex.observers.append(FileStorageObserver.create(trainer['storage_dir']))
# Voice-activity-detection training script (Fearless database).
from paderbox.transform.module_stft import STFT
from padercontrib.database.fearless import Fearless
from padertorch import Trainer
from padertorch.contrib.examples.voice_activity_detection.model import SAD_Classifier
from padertorch.contrib.je.data.transforms import AudioReader, Normalizer, Collate
from padertorch.contrib.je.modules.conv import CNN1d, CNN2d
from padertorch.modules.fully_connected import fully_connected_stack
from padertorch.train.optimizer import Adam
import torch
from torch.nn import MaxPool2d
from torch.autograd import Variable
from paderbox.array import segment_axis
from einops import rearrange

# Run directory under $STORAGE_ROOT, named by timestamp.
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / 'voice_activity' / timeStamped('')[1:])
os.makedirs(storage_dir, exist_ok=True)

DEBUG = False
DATA_TEST = False


def get_datasets():
    # Build train/validation streams from the Fearless database.
    # NOTE(review): this function is truncated in this view — the datasets
    # are created but the remaining pipeline/return is not visible here.
    db = Fearless()
    train = db.get_dataset_train(subset='stream')
    validate = db.get_dataset_validation(subset='stream')

    def prepare_example(example):
        # Flatten the audio path and attach the activity (VAD) labels.
        example['audio_path'] = example['audio_path']['observation']
        example['activity'] = db.get_activity(example)
        return example
# NOTE(review): the opening of get_model() is truncated in this view; the
# fragment below is the tail of its WaveNet construction.
        fading='full')
    model = WaveNet(wavenet=wavenet, sample_rate=16000, fft_length=512,
                    n_mels=64, fmin=50)
    return model


def train(model, storage_dir):
    # Train the given model with padertorch's Trainer.
    # `get_datasets` presumably returns (train, validate, test) — the third
    # element is unused here. TODO confirm.
    train_set, validate_set, _ = get_datasets()
    trainer = Trainer(model=model, optimizer=Adam(lr=5e-4),
                      storage_dir=str(storage_dir),
                      summary_trigger=(1000, 'iteration'),
                      checkpoint_trigger=(10000, 'iteration'),
                      stop_trigger=(100000, 'iteration'))
    # Dry run to fail fast on shape/pipeline errors before real training.
    trainer.test_run(train_set, validate_set)
    trainer.register_validation_hook(validate_set)
    trainer.train(train_set)


if __name__ == '__main__':
    # Fresh timestamped run directory under $STORAGE_ROOT.
    storage_dir = str(
        Path(os.environ['STORAGE_ROOT']) / 'wavenet' / timeStamped('')[1:])
    os.makedirs(storage_dir, exist_ok=True)
    model = get_model()
    train(model, storage_dir)
# Speaker-classification training script (LibriSpeech).
import numpy as np
from padercontrib.database.librispeech import LibriSpeech
from paderbox.utils.timer import timeStamped
from padertorch import Trainer
from padertorch.contrib.examples.speaker_classification.model import SpeakerClf
from padertorch.contrib.je.data.transforms import LabelEncoder, AudioReader, \
    STFT, MelTransform, Normalizer, Collate
from padertorch.contrib.je.data.utils import split_dataset
from padertorch.contrib.je.modules.conv import CNN1d
from padertorch.modules.fully_connected import fully_connected_stack
from padertorch.train.optimizer import Adam
from torch.nn import GRU

# Run directory under $STORAGE_ROOT, named by timestamp.
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / 'speaker_clf' / timeStamped('')[1:])
os.makedirs(storage_dir, exist_ok=True)


def get_datasets():
    # Build train/validation splits from LibriSpeech train_clean_100.
    # NOTE(review): this function is truncated in this view — the remaining
    # pipeline/return is not visible here.
    db = LibriSpeech()
    train_clean_100 = db.get_dataset('train_clean_100')

    def prepare_example(example):
        # Flatten the audio path; derive the speaker label from the id
        # prefix (format '<speaker>-<chapter>...').
        example['audio_path'] = example['audio_path']['observation']
        example['speaker_id'] = example['speaker_id'].split('-')[0]
        return example

    train_clean_100 = train_clean_100.map(prepare_example)
    # Deterministic train/validation split (fold 0).
    train_set, validate_set = split_dataset(train_clean_100, fold=0)
# Audio-tagging training script (AudioSet).
from pathlib import Path

import numpy as np
import torch
from padercontrib.database.audio_set import AudioSet
from paderbox.utils.timer import timeStamped
from padertorch import Model, Trainer, optimizer
from padertorch.contrib.je.data.transforms import (
    AudioReader, STFT, MelTransform, Normalizer, LabelEncoder, Collate
)
from padertorch.contrib.je.modules.conv import CNN2d
from torch import nn
from padertorch.contrib.je.modules.norm import Norm

# Run directory under $STORAGE_ROOT, named by timestamp.
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / 'audio_tagging' / timeStamped('')[1:]
)
os.makedirs(storage_dir, exist_ok=True)


class MultiHotLabelEncoder(LabelEncoder):
    # Encodes an example's label list as a multi-hot float32 vector over the
    # 527 AudioSet classes (527 is hard-coded here).
    def __call__(self, example):
        # Base class maps label strings to integer indices.
        labels = super().__call__(example)[self.label_key]
        nhot_encoding = np.zeros(527).astype(np.float32)
        nhot_encoding[labels] = 1
        example[self.label_key] = nhot_encoding
        return example


# NOTE(review): the WALNet class is truncated in this view — only its header
# and opening docstring quote are visible.
class WALNet(Model):
    """
# Audio-tagging sacred experiment setup (CRNN).
from padertorch.contrib.examples.audio_tagging.data import get_datasets
from padertorch.contrib.examples.audio_tagging.models import CRNN
from paderbox.utils.timer import timeStamped
from padertorch.contrib.je.modules.augment import (MelWarping,
                                                   LogTruncNormalSampler,
                                                   TruncExponentialSampler)
from padertorch.train.optimizer import Adam
from padertorch.train.trainer import Trainer
from sacred import Experiment as Exp
from sacred.commands import print_config
from sacred.observers import FileStorageObserver

# Sacred experiment with a file-storage observer writing into a
# timestamped directory under $STORAGE_ROOT.
nickname = 'audio_tagging'
ex = Exp(nickname)
storage_dir = str(
    Path(os.environ['STORAGE_ROOT']) / nickname / timeStamped('')[1:])
observer = FileStorageObserver.create(storage_dir)
ex.observers.append(observer)


@ex.config
def config():
    # Sacred config: variable names become config keys.
    # NOTE(review): this function is truncated in this view — the stft dict
    # (and any further settings) continue beyond the visible source.
    resume = False

    # Data configuration
    audio_reader = {
        'source_sample_rate': None,
        'target_sample_rate': 44100,
    }
    stft = {
        'shift': 882,