Example #1
def run(_run, storage_dir, job_id, number_of_jobs, session_id, test_run=False):
    print_config(_run)

    assert job_id >= 1 and job_id <= number_of_jobs, (job_id, number_of_jobs)

    enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    if test_run is False:
        dataset_slice = slice(job_id - 1, None, number_of_jobs)
    else:
        dataset_slice = test_run

    if mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_id)

    enhancer.enhance_session(
        session_id,
        Path(storage_dir) / 'audio',
        dataset_slice=dataset_slice,
        audio_dir_exist_ok=True,
    )
    if mpi.IS_MASTER:
        print('Finished experiment dir:', storage_dir)
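The special `_run` argument in the example above is only injected when the function is registered on a Sacred `Experiment`. A minimal sketch of that wiring, with an illustrative experiment name and config defaults that are assumptions and not taken from the project above:

from sacred import Experiment
from sacred.commands import print_config

ex = Experiment('enhance')  # hypothetical experiment name


@ex.config
def config():
    storage_dir = 'exp/enhance'  # illustrative defaults only
    job_id = 1
    number_of_jobs = 1


@ex.automain
def run(_run, storage_dir, job_id, number_of_jobs):
    # Sacred passes the active Run object as `_run`, so the resolved
    # configuration can be printed at start-up, as in the examples here.
    print_config(_run)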
Example #2
def run(_run, test_run=False):
    if mpi.IS_MASTER:
        print_config(_run)
        _dir = get_dir()
        print('Experiment dir:', _dir)
    else:
        _dir = None

    _dir = mpi.bcast(_dir, mpi.MASTER)

    enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    session_ids = get_session_ids()
    if mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_ids)

    enhancer.enhance_session(
        session_ids,
        _dir / 'audio',
        test_run=test_run,
    )
    if mpi.IS_MASTER:
        print('Finished experiment dir:', _dir)
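Examples #2 and #5 share the same master/worker pattern: only the MPI master resolves the experiment directory, and the resulting path is then broadcast so that every rank writes into the same location. A stripped-down sketch of that pattern, assuming the `dlp_mpi` package used in the later examples (the `make_dir` callable is a placeholder):

from pathlib import Path

import dlp_mpi


def resolve_storage_dir(make_dir):
    # Only the master creates/resolves the directory ...
    if dlp_mpi.IS_MASTER:
        storage_dir = Path(make_dir())
    else:
        storage_dir = None
    # ... and afterwards every rank receives the same path.
    return dlp_mpi.bcast(storage_dir, dlp_mpi.MASTER)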
Example #3
def train(
    _run,
    audio_reader,
    stft,
    num_workers,
    batch_size,
    max_padding_rate,
    trainer,
    resume,
):

    print_config(_run)
    trainer = Trainer.from_config(trainer)

    train_iter, validation_iter = get_datasets(
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=trainer.storage_dir)
    trainer.test_run(train_iter, validation_iter)

    trainer.register_validation_hook(validation_iter,
                                     metric='macro_fscore',
                                     maximize=True)

    trainer.train(train_iter, resume=resume)
Example #4
def main(_run, _log, trainer, database_json, training_sets, validation_sets,
         audio_reader, stft, max_length_in_sec, batch_size, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    db = JsonDatabase(database_json)
    training_data = db.get_dataset(training_sets)
    validation_data = db.get_dataset(validation_sets)
    training_data = prepare_dataset(training_data,
                                    audio_reader=audio_reader,
                                    stft=stft,
                                    max_length_in_sec=max_length_in_sec,
                                    batch_size=batch_size,
                                    shuffle=True)
    validation_data = prepare_dataset(validation_data,
                                      audio_reader=audio_reader,
                                      stft=stft,
                                      max_length_in_sec=max_length_in_sec,
                                      batch_size=batch_size,
                                      shuffle=False)

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data)
    trainer.train(training_data, resume=resume)
Example #5
def run(_run, chime6, test_run=False):
    if dlp_mpi.IS_MASTER:
        print_config(_run)
        _dir = get_dir()
        print('Experiment dir:', _dir)
    else:
        _dir = None

    _dir = dlp_mpi.bcast(_dir, dlp_mpi.MASTER)

    if chime6:
        enhancer = get_enhancer_chime6()
    else:
        enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    session_ids = get_session_ids()
    if dlp_mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_ids)

    enhancer.enhance_session(session_ids,
                             _dir / 'audio',
                             dataset_slice=test_run,
                             audio_dir_exist_ok=True)
    if dlp_mpi.IS_MASTER:
        print('Finished experiment dir:', _dir)
Example #6
def main(_run, _log, trainer, database_json, training_set, validation_metric,
         maximize_metric, audio_reader, stft, num_workers, batch_size,
         max_padding_rate, resume):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    training_data, validation_data, _ = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=audio_reader,
        stft=stft,
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        training_set=training_set,
        storage_dir=storage_dir,
        stft_stretch_factor_sampling_fn=Uniform(low=0.5, high=1.5),
        stft_segment_length=audio_reader['target_sample_rate'],
        stft_segment_shuffle_prob=0.,
        mixup_probs=(1 / 2, 1 / 2),
        max_mixup_length=15.,
        min_mixup_overlap=.8,
    )

    trainer.test_run(training_data, validation_data)
    trainer.register_validation_hook(validation_data,
                                     metric=validation_metric,
                                     maximize=maximize_metric)
    trainer.train(training_data, resume=resume)
Example #7
def train(logdir, device, iterations, resume_iteration, checkpoint_interval, batch_size, sequence_length,
          model_complexity, learning_rate, learning_rate_decay_steps, learning_rate_decay_rate, leave_one_out,
          clip_gradient_norm, validation_length, validation_interval):
    print_config(ex.current_run)

    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)

    train_groups, validation_groups = ['train'], ['validation']

    if leave_one_out is not None:
        all_years = {'2004', '2006', '2008', '2009', '2011', '2013', '2014', '2015', '2017'}
        train_groups = list(all_years - {str(leave_one_out)})
        validation_groups = [str(leave_one_out)]

    dataset = MAESTRO(groups=train_groups, sequence_length=sequence_length)
    loader = DataLoader(dataset, batch_size, shuffle=True)

    validation_dataset = MAESTRO(groups=validation_groups, sequence_length=validation_length)

    if resume_iteration is None:
        model = OnsetsAndFrames(N_MELS, MAX_MIDI - MIN_MIDI + 1, model_complexity).to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:
        model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
        model = torch.load(model_path)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))

    summary(model)
    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)

    loop = tqdm(range(resume_iteration + 1, iterations + 1))
    for i, batch in zip(loop, cycle(loader)):
        scheduler.step()
        predictions, losses = model.run_on_batch(batch)

        loss = sum(losses.values())
        optimizer.zero_grad()
        loss.backward()
        # Clip gradients before the optimizer step so the clipping takes effect.
        if clip_gradient_norm:
            clip_grad_norm_(model.parameters(), clip_gradient_norm)
        optimizer.step()

        for key, value in {'loss': loss, **losses}.items():
            writer.add_scalar(key, value.item(), global_step=i)

        if i % validation_interval == 0:
            model.eval()
            with torch.no_grad():
                for key, value in evaluate(validation_dataset, model).items():
                    writer.add_scalar('validation/' + key.replace(' ', '_'), np.mean(value), global_step=i)
            model.train()

        if i % checkpoint_interval == 0:
            torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
            torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))
Example #8
def main(_run, seed, save_folder, config_filename):
    set_global_seeds(seed)
    logger.info("Run id: {}".format(_run._id))

    print_config(ex.current_run)

    # saving config
    save_config(ex.current_run.config, ex.logger,
                config_filename=save_folder + config_filename)
    train()
Example #9
def main(_run, seed, save_folder, config_filename):
    set_global_seeds(seed)
    logger.info("Run id: {}".format(_run._id))

    print_config(ex.current_run)

    # saving config
    save_config(ex.current_run.config,
                ex.logger,
                config_filename=save_folder + config_filename)
    train()
Example #10
def main(_run, out):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    data = []

    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):
        for prediction in [
                'source',
                'early_0',
                'early_1',
                'image_0',
                'image_1',
                'image_0_noise',
                'image_1_noise',
        ]:
            for source in [
                    'source',
                    'early_0',
                    'early_1',
                    'image_0',
                    'image_1',
                    'image_0_noise',
                    'image_1_noise',
            ]:
                scores = get_scores(ex, prediction=prediction, source=source)
                for score_name, score_value in scores.items():
                    data.append(
                        dict(
                            score_name=score_name,
                            prediction=prediction,
                            source=source,
                            example_id=ex['example_id'],
                            value=score_value,
                        ))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [entry for worker_data in data for entry in worker_data]

        if out is not None:
            assert isinstance(out, str), out
            assert out.endswith('.json'), out
            print(f'Write details to {out}.')
            dump_json(data, out)

        summary(data)
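The `split_managed`/`gather` combination above distributes examples dynamically over the MPI workers and collects one result list per rank on the master, which then flattens them. A reduced sketch of that pattern (the `score_example` helper is a placeholder, not part of the project above):

import dlp_mpi


def score_example(example):
    # placeholder for the per-example work each worker performs
    return {'example_id': example['example_id']}


def run_distributed(dataset):
    results = []
    # The master hands out examples; each worker collects its own results.
    for example in dlp_mpi.split_managed(dataset, allow_single_worker=True):
        results.append(score_example(example))
    results = dlp_mpi.gather(results)
    if dlp_mpi.IS_MASTER:
        # On the master, gather() yields one sub-list per rank; flatten them.
        return [entry for worker_results in results for entry in worker_results]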
Example #11
def main(_run, _log, trainer, database_json, dataset, batch_size):
    commands.print_config(_run)
    trainer = Trainer.from_config(trainer)
    storage_dir = Path(trainer.storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    commands.save_config(_run.config,
                         _log,
                         config_filename=str(storage_dir / 'config.json'))

    train_set, validate_set, _ = get_datasets(storage_dir, database_json,
                                              dataset, batch_size)

    # Stop early if the loss has not decreased over three consecutive
    # validation runs. Training typically ends after around 20k iterations
    # (13 epochs) with an accuracy >98% on the test set.
    trainer.register_validation_hook(validate_set, early_stopping_patience=3)
    trainer.test_run(train_set, validate_set)
    trainer.train(train_set)
Example #12
def train(
        _run, trainer, device,
):
    print_config(_run)
    trainer = Trainer.from_config(trainer)
    train_iter, validate_iter, batch_norm_tuning_iter = get_datasets()
    if validate_iter is not None:
        trainer.register_validation_hook(validate_iter)
    trainer.train(train_iter, device=device)

    # finalize
    if trainer.optimizer.swa_start is not None:
        trainer.optimizer.swap_swa_sgd()
    batch_norm_update(
        trainer.model, batch_norm_tuning_iter,
        feature_key='features', device=device
    )
    torch.save(
        trainer.model.state_dict(),
        storage_dir / 'checkpoints' / 'ckpt_final.pth'
    )
Example #13
def run(_run, storage_dir, job_id, number_of_jobs, session_id, test_run=False):
    if dlp_mpi.IS_MASTER:
        print_config(_run)

    assert job_id >= 1 and job_id <= number_of_jobs, (job_id, number_of_jobs)

    enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    if test_run is False:
        dataset_slice = slice(job_id - 1, None, number_of_jobs)
    else:
        dataset_slice = test_run

    if dlp_mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_id)

    if session_id is None:
        session_ids = sorted(get_sessions())
    elif isinstance(session_id, str):
        session_ids = [session_id]
    elif isinstance(session_id, (tuple, list)):
        session_ids = session_id
    else:
        raise TypeError(type(session_id), session_id)

    for session_id in session_ids:
        enhancer.enhance_session(
            session_id,
            Path(storage_dir) / 'audio',
            dataset_slice=dataset_slice,
            audio_dir_exist_ok=True,
            is_chime=False,
        )

    if dlp_mpi.IS_MASTER:
        print('Finished experiment dir:', storage_dir)
Example #14
def main(_run, exp_dir, storage_dir, database_json, test_set, max_examples,
         device):
    if IS_MASTER:
        commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)
    audio_dir = storage_dir / 'audio'
    audio_dir.mkdir(parents=True)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir, consider_mpi=True)
    model.to(device)
    model.eval()

    db = JsonDatabase(database_json)
    test_data = db.get_dataset(test_set)
    if max_examples is not None:
        test_data = test_data.shuffle(
            rng=np.random.RandomState(0))[:max_examples]
    test_data = prepare_dataset(test_data,
                                audio_reader=config['audio_reader'],
                                stft=config['stft'],
                                max_length=None,
                                batch_size=1,
                                shuffle=True)
    squared_err = list()
    with torch.no_grad():
        for example in split_managed(test_data,
                                     is_indexable=False,
                                     progress_bar=True,
                                     allow_single_worker=True):
            example = model.example_to_device(example, device)
            target = example['audio_data'].squeeze(1)
            x = model.feature_extraction(example['stft'], example['seq_len'])
            x = model.wavenet.infer(
                x.squeeze(1),
                chunk_length=80_000,
                chunk_overlap=16_000,
            )
            assert target.shape == x.shape, (target.shape, x.shape)
            se = ((x - target) ** 2).sum(1)
            squared_err.extend([
                (ex_id, mse.cpu().detach().numpy(), x.shape[1])
                for ex_id, mse in zip(example['example_id'], se)
            ])

    squared_err_list = COMM.gather(squared_err, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(squared_err_list): {len(squared_err_list)}')
        squared_err = []
        for i in range(len(squared_err_list)):
            squared_err.extend(squared_err_list[i])
        _, err, t = list(zip(*squared_err))
        print('rmse:', np.sqrt(np.sum(err) / np.sum(t)))
        rmse = sorted([(ex_id, np.sqrt(err / t))
                       for ex_id, err, t in squared_err],
                      key=lambda x: x[1])
        dump_json(rmse, storage_dir / 'rmse.json', indent=4, sort_keys=False)
        ex_ids_ordered = [x[0] for x in rmse]
        selected_ids = set(ex_ids_ordered[:10] + ex_ids_ordered[-10:])
        test_data = db.get_dataset('test_clean').shuffle(
            rng=np.random.RandomState(0))[:max_examples].filter(
                lambda x: x['example_id'] in selected_ids, lazy=False)
        test_data = prepare_dataset(test_data,
                                    audio_reader=config['audio_reader'],
                                    stft=config['stft'],
                                    max_length=10.,
                                    batch_size=1,
                                    shuffle=True)
        with torch.no_grad():
            for example in test_data:
                example = model.example_to_device(example, device)
                x = model.feature_extraction(example['stft'],
                                             example['seq_len'])
                x = model.wavenet.infer(
                    x.squeeze(1),
                    chunk_length=80_000,
                    chunk_overlap=16_000,
                )
                for i, audio in enumerate(x.cpu().detach().numpy()):
                    wavfile.write(
                        str(audio_dir / f'{example["example_id"][i]}.wav'),
                        model.sample_rate, audio)
Example #15
def main(run_dir, data_dir, nb_epoch, early_stopping_patience, desired_sample_rate, fragment_length, batch_size,
         fragment_stride, nb_output_bins, keras_verbose, _log, seed, _config, debug, learn_all_outputs,
         train_only_in_receptive_field, _run, use_ulaw, train_with_soft_target_stdev):
    if run_dir is None:
        if not os.path.exists("models"):
            os.mkdir("models")
        run_dir = os.path.join('models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))
        _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    checkpoint_dir = os.path.join(run_dir, 'checkpoints')

    if not debug:
        if not os.path.exists(run_dir):
            os.mkdir(run_dir)
            os.mkdir(checkpoint_dir)
            json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    callbacks = [
        ReduceLROnPlateau(patience=early_stopping_patience / 2, cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        callbacks.extend([
            ModelCheckpoint(os.path.join(checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                            save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])

    checkpoints = sorted(os.listdir(checkpoint_dir))
    if checkpoints:
        last_checkpoint = checkpoints[-1]
        _log.info('Loading existing weights %s',
                  os.path.join(checkpoint_dir, last_checkpoint))
        model.load_weights(os.path.join(checkpoint_dir, last_checkpoint))

    model.fit_generator(data_generators['train'],
                        len(data_generators['train']),
                        epochs=1,
                        validation_data=data_generators['test'],
                        validation_steps=5,
                        callbacks=callbacks,
                        use_multiprocessing=False,
                        verbose=keras_verbose)
Example #16
def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw):
    if run_dir is None:
        run_dir = os.path.join(
            'models',
            datetime.datetime.now().strftime('run_%Y-%m-%d_%H:%M:%S'))
        _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    if not debug:
        if os.path.exists(run_dir):
            raise EnvironmentError('Run with seed %d already exists' % seed)
        os.mkdir(run_dir)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = dataset.generators(
        data_dir, desired_sample_rate, fragment_length, batch_size,
        fragment_stride, nb_output_bins, learn_all_outputs, use_ulaw)

    _log.info('Building model...')
    model = build_model(fragment_length)
    _log.info(model.summary())

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy, metrics.categorical_mean_squared_error
    ]
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    callbacks = [
        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4,
                          verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        callbacks.extend([
            ModelCheckpoint(os.path.join(
                checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                            save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])

    if not debug:
        os.mkdir(checkpoint_dir)
        _log.info('Starting Training...')

    model.fit_generator(data_generators['train'],
                        nb_examples['train'],
                        nb_epoch=nb_epoch,
                        validation_data=data_generators['test'],
                        nb_val_samples=nb_examples['test'],
                        callbacks=callbacks,
                        verbose=keras_verbose)
Example #17
def train(
    _run,
    debug,
    data_provider,
    filter_desed_test_clips,
    trainer,
    lr_rampup_steps,
    back_off_patience,
    lr_decay_step,
    lr_decay_factor,
    init_ckpt_path,
    frozen_cnn_2d_layers,
    frozen_cnn_1d_layers,
    track_emissions,
    resume,
    delay,
    validation_set_name,
    validation_ground_truth_filepath,
    eval_set_name,
    eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DataProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    inverse_label_mapping = (
        data_provider.train_transform.label_encoder.inverse_label_mapping)
    for idx, label in sorted(inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))
    print('CNN Params', sum(p.numel() for p in trainer.model.cnn.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(torch.load(init_ckpt_path,
                                          map_location='cpu')['model'],
                               maxdepth=2)
        trainer.model.cnn.load_state_dict(flatten(state_dict['cnn']))
        trainer.model.rnn_fwd.rnn.load_state_dict(state_dict['rnn_fwd']['rnn'])
        trainer.model.rnn_bwd.rnn.load_state_dict(state_dict['rnn_bwd']['rnn'])
        # pop output layer from checkpoint
        param_keys = sorted(state_dict['rnn_fwd']['output_net'].keys())
        layer_indices = [key.split('.')[1] for key in param_keys]
        last_layer_idx = layer_indices[-1]
        for key, layer_idx in zip(param_keys, layer_indices):
            if layer_idx == last_layer_idx:
                state_dict['rnn_fwd']['output_net'].pop(key)
                state_dict['rnn_bwd']['output_net'].pop(key)
        trainer.model.rnn_fwd.output_net.load_state_dict(
            state_dict['rnn_fwd']['output_net'], strict=False)
        trainer.model.rnn_bwd.output_net.load_state_dict(
            state_dict['rnn_bwd']['output_net'], strict=False)
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    if filter_desed_test_clips:
        with (database_jsons_dir / 'desed.json').open() as fid:
            desed_json = json.load(fid)
        filter_example_ids = {
            clip_id.rsplit('_', maxsplit=2)[0][1:]
            for clip_id in (list(desed_json['datasets']['validation'].keys()) +
                            list(desed_json['datasets']['eval_public'].keys()))
        }
    else:
        filter_example_ids = None
    train_set = data_provider.get_train_set(
        filter_example_ids=filter_example_ids)
    validate_set = data_provider.get_validate_set()

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_weak',
            maximize=True,
            back_off_patience=back_off_patience,
            n_back_off=0 if back_off_patience is None else 1,
            lr_update_factor=lr_decay_factor,
            early_stopping_patience=back_off_patience,
        )

    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume, track_emissions=track_emissions)

    if validation_set_name is not None:
        tuning.run(
            config_updates={
                'debug': debug,
                'crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath':
                validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
                'data_provider': {
                    'test_fetcher': {
                        'batch_size': data_provider.train_fetcher.batch_size,
                    }
                },
            })
Example #18
def main(_run, storage_dir, debug, weak_label_crnn_hyper_params_dir,
         weak_label_crnn_dirs, weak_label_crnn_checkpoints,
         strong_label_crnn_dirs, strong_label_crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, medfilt_lengths, device):
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    storage_dir = Path(storage_dir)

    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints,
                          str), weak_label_crnn_checkpoints
        weak_label_crnn_checkpoints = len(weak_label_crnn_dirs) * [
            weak_label_crnn_checkpoints
        ]
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(weak_label_crnn_dirs,
                                             weak_label_crnn_checkpoints)
    ]
    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if validation_set_name == 'validation' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('validation')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'validation' / 'validation.tsv'
    elif validation_set_name == 'eval_public' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('eval_public')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'eval' / 'public.tsv'
    assert isinstance(
        validation_ground_truth_filepath,
        (str, Path)) and Path(validation_ground_truth_filepath).exists(
        ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    tags, tagging_scores, _ = tagging(
        weak_label_crnns,
        dataset,
        device,
        timestamps,
        event_classes,
        weak_label_crnn_hyper_params_dir,
        None,
        None,
    )

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }
    metrics = {
        'f':
        partial(
            base.f_collar,
            ground_truth=validation_ground_truth_filepath,
            return_onset_offset_bias=True,
            num_jobs=8,
            **collar_based_params,
        ),
        'auc1':
        partial(
            base.psd_auc,
            ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations,
            num_jobs=8,
            **psds_scenario_1,
        ),
        'auc2':
        partial(
            base.psd_auc,
            ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations,
            num_jobs=8,
            **psds_scenario_2,
        )
    }

    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints,
                          str), strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = len(strong_label_crnn_dirs) * [
            strong_label_crnn_checkpoints
        ]
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                           config_name='1/config.json',
                                           checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(strong_label_crnn_dirs,
                                             strong_label_crnn_checkpoints)
    ]

    def add_tag_condition(example):
        example["tag_condition"] = np.array(
            [tags[example_id] for example_id in example["example_id"]])
        return example

    timestamps = np.arange(0, 10000) * frame_shift
    leaderboard = strong_label.crnn.tune_sound_event_detection(
        strong_label_crnns,
        dataset.map(add_tag_condition),
        device,
        timestamps,
        event_classes,
        tags,
        metrics,
        tag_masking={
            'f': True,
            'auc1': '?',
            'auc2': '?'
        },
        medfilt_lengths=medfilt_lengths,
    )
    dump_json(leaderboard['f'][1], storage_dir / f'sed_hyper_params_f.json')
    f, p, r, thresholds, _ = collar_based.best_fscore(
        scores=leaderboard['auc1'][2],
        ground_truth=validation_ground_truth_filepath,
        **collar_based_params,
        num_jobs=8)
    for event_class in thresholds:
        leaderboard['auc1'][1][event_class]['threshold'] = thresholds[
            event_class]
    dump_json(leaderboard['auc1'][1],
              storage_dir / 'sed_hyper_params_psds1.json')
    f, p, r, thresholds, _ = collar_based.best_fscore(
        scores=leaderboard['auc2'][2],
        ground_truth=validation_ground_truth_filepath,
        **collar_based_params,
        num_jobs=8)
    for event_class in thresholds:
        leaderboard['auc2'][1][event_class]['threshold'] = thresholds[
            event_class]
    dump_json(leaderboard['auc2'][1],
              storage_dir / 'sed_hyper_params_psds2.json')
    for crnn_dir in strong_label_crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)
    print(storage_dir)

    if eval_set_name:
        evaluation.run(config_updates={
            'debug': debug,
            'strong_label_crnn_hyper_params_dir': str(storage_dir),
            'dataset_name': eval_set_name,
            'ground_truth_filepath': eval_ground_truth_filepath,
        }, )
Example #19
def main(
    _run,
    storage_dir,
    strong_label_crnn_hyper_params_dir,
    sed_hyper_params_name,
    strong_label_crnn_dirs,
    strong_label_crnn_checkpoints,
    weak_label_crnn_hyper_params_dir,
    weak_label_crnn_dirs,
    weak_label_crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labelled_dataset_name,
):
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints,
                          str), weak_label_crnn_checkpoints
        weak_label_crnn_checkpoints = len(weak_label_crnn_dirs) * [
            weak_label_crnn_checkpoints
        ]
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(weak_label_crnn_dirs,
                                             weak_label_crnn_checkpoints)
    ]
    print(
        'Weak Label CRNN Params',
        sum([
            p.numel() for crnn in weak_label_crnns for p in crnn.parameters()
        ]))
    print(
        'Weak Label CNN2d Params',
        sum([
            p.numel() for crnn in weak_label_crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))
    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints,
                          str), strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = len(strong_label_crnn_dirs) * [
            strong_label_crnn_checkpoints
        ]
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                           config_name='1/config.json',
                                           checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(strong_label_crnn_dirs,
                                             strong_label_crnn_checkpoints)
    ]
    print(
        'Strong Label CRNN Params',
        sum([
            p.numel() for crnn in strong_label_crnns
            for p in crnn.parameters()
        ]))
    print(
        'Strong Label CNN2d Params',
        sum([
            p.numel() for crnn in strong_label_crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))
    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labelled_dataset_name, list):
        pseudo_labelled_dataset_name = [pseudo_labelled_dataset_name]
    assert len(pseudo_labelled_dataset_name) == len(dataset_name)

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(
                    (2 + max_segment_length) * frame_shift,
                    audio_durations[audio_id],
                    (max_segment_length - segment_overlap) * frame_shift)
                timestamps[audio_id] = np.concatenate(
                    ([0.], ts - segment_overlap / 2 * frame_shift,
                     [audio_durations[audio_id]]))
        if max_segment_length is not None:
            dataset = dataset.map(
                partial(segment_batch,
                        max_length=max_segment_length,
                        overlap=segment_overlap)).unbatch()
        tags, tagging_scores, _ = tagging(
            weak_label_crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            weak_label_crnn_hyper_params_dir,
            None,
            None,
        )

        def add_tag_condition(example):
            example["tag_condition"] = np.array(
                [tags[example_id] for example_id in example["example_id"]])
            return example

        dataset = dataset.map(add_tag_condition)

        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        events, sed_results = sound_event_detection(
            strong_label_crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            strong_label_crnn_hyper_params_dir,
            sed_hyper_params_name,
            ground_truth_filepath[i],
            audio_durations,
            collar_based_params,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
            pseudo_widening=pseudo_widening,
            score_storage_dir=[
                score_storage_dir / name for name in sed_hyper_params_name
            ] if save_scores else None,
            detection_storage_dir=[
                detection_storage_dir / name for name in sed_hyper_params_name
            ] if save_detections else None,
        )
        for j, sed_results_j in enumerate(sed_results):
            if sed_results_j:
                dump_json(
                    sed_results_j, storage_dir /
                    f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                )
        if strong_pseudo_labeling[i]:
            database['datasets'][
                pseudo_labelled_dataset_name[i]] = base.pseudo_label(
                    database['datasets'][dataset_name[i]],
                    event_classes,
                    False,
                    False,
                    strong_pseudo_labeling[i],
                    None,
                    None,
                    events[0],
                )
            with (storage_dir /
                  f'{dataset_name[i]}_pseudo_labeled.tsv').open('w') as fid:
                fid.write('filename\tonset\toffset\tevent_label\n')
                for key, event_list in events[0].items():
                    if len(event_list) == 0:
                        fid.write(f'{key}.wav\t\t\t\n')
                    for t_on, t_off, event_label in event_list:
                        fid.write(
                            f'{key}.wav\t{t_on}\t{t_off}\t{event_label}\n')

    if any(strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    inference_dir = Path(strong_label_crnn_hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)
Example #20
def main(_run, model_path, load_ckpt, batch_size, device, store_misclassified):
    if IS_MASTER:
        commands.print_config(_run)

    model_path = Path(model_path)
    eval_dir = get_new_subdir(model_path / 'eval',
                              id_naming='time',
                              consider_mpi=True)
    # perform evaluation on a sub-set (10%) of the dataset used for training
    config = load_json(model_path / 'config.json')
    database_json = config['database_json']
    dataset = config['dataset']

    model = pt.Model.from_storage_dir(model_path,
                                      checkpoint_name=load_ckpt,
                                      consider_mpi=True)
    model.to(device)
    # Turn on evaluation mode for, e.g., BatchNorm and Dropout modules
    model.eval()

    _, _, test_set = get_datasets(model_path,
                                  database_json,
                                  dataset,
                                  batch_size,
                                  return_indexable=device == 'cpu')
    with torch.no_grad():
        summary = dict(misclassified_examples=dict(),
                       correct_classified_examples=dict(),
                       hits=list())
        for batch in split_managed(test_set,
                                   is_indexable=device == 'cpu',
                                   progress_bar=True,
                                   allow_single_worker=True):
            output = model(pt.data.example_to_device(batch, device))
            prediction = torch.argmax(output, dim=-1).cpu().numpy()
            confidence = torch.softmax(output, dim=-1).max(dim=-1).values.cpu()\
                .numpy()
            label = np.array(batch['speaker_id'])
            hits = (label == prediction).astype('bool')
            summary['hits'].extend(hits.tolist())
            summary['misclassified_examples'].update({
                k: {
                    'true_label': v1,
                    'predicted_label': v2,
                    'audio_path': v3,
                    'confidence': f'{v4:.2%}',
                }
                for k, v1, v2, v3, v4 in zip(
                    np.array(batch['example_id'])[~hits], label[~hits],
                    prediction[~hits],
                    np.array(batch['audio_path'])[~hits], confidence[~hits])
            })
            # for each correctly predicted label, collect the audio paths
            correct_classified = summary['correct_classified_examples']
            for k, v in zip(prediction[hits],
                            np.array(batch['audio_path'])[hits]):
                correct_classified[k] = correct_classified.get(k, []) + [v]

    summary_list = COMM.gather(summary, root=MASTER)

    if IS_MASTER:
        print(f'\nlen(summary_list): {len(summary_list)}')
        if len(summary_list) > 1:
            summary = dict(
                misclassified_examples=dict(),
                correct_classified_examples=dict(),
                hits=list(),
            )
            for partial_summary in summary_list:
                summary['hits'].extend(partial_summary['hits'])
                summary['misclassified_examples'].update(
                    partial_summary['misclassified_examples'])
                for label, audio_path_list in \
                        partial_summary['correct_classified_examples'].items():
                    correct = summary['correct_classified_examples']
                    correct[label] = correct.get(label, []) + audio_path_list
        hits = summary['hits']
        misclassified_examples = summary['misclassified_examples']
        correct_classified_examples = summary['correct_classified_examples']
        accuracy = np.array(hits).astype('float').mean()
        if store_misclassified:
            misclassified_dir = eval_dir / 'misclassified_examples'
            for example_id, v in misclassified_examples.items():
                label, prediction_label, audio_path, _ = v.values()
                try:
                    predicted_speaker_audio_path = \
                        correct_classified_examples[prediction_label][0]
                    example_dir = \
                        misclassified_dir / f'{example_id}_{label}_{prediction_label}'
                    example_dir.mkdir(parents=True)
                    os.symlink(audio_path, example_dir / 'example.wav')
                    os.symlink(predicted_speaker_audio_path,
                               example_dir / 'predicted_speaker_example.wav')
                except KeyError:
                    warnings.warn(
                        'There were no correctly predicted inputs from speaker '
                        f'with speaker label {prediction_label}')
        outputs = dict(
            accuracy=f'{accuracy:.2%} ({np.sum(hits)}/{len(hits)})',
            misclassifications=misclassified_examples,
        )
        print(f'Speaker classification accuracy on test set: {accuracy:.2%}')
        print(f'Wrote results to {eval_dir / "results.json"}')
        dump_json(outputs, eval_dir / 'results.json')
Example #21
def main(_run, storage_dir, debug, crnn_dirs, crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, boundaries_filter_lengths,
         tune_detection_scenario_1, detection_window_lengths_scenario_1,
         detection_window_shift_scenario_1,
         detection_medfilt_lengths_scenario_1, tune_detection_scenario_2,
         detection_window_lengths_scenario_2,
         detection_window_shift_scenario_2,
         detection_medfilt_lengths_scenario_2, device):
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundaries_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
        'min_precision': .8,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if validation_set_name == 'validation' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('validation')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'validation' / 'validation.tsv'
    elif validation_set_name == 'eval_public' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('eval_public')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'eval' / 'public.tsv'
    assert isinstance(
        validation_ground_truth_filepath,
        (str, Path)) and Path(validation_ground_truth_filepath).exists(
        ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    metrics = {
        'f':
        partial(base.f_tag,
                ground_truth=validation_ground_truth_filepath,
                num_jobs=8)
    }
    leaderboard = weak_label.crnn.tune_tagging(crnns,
                                               dataset,
                                               device,
                                               timestamps,
                                               event_classes,
                                               metrics,
                                               storage_dir=storage_dir)
    _, hyper_params, tagging_scores = leaderboard['f']
    tagging_thresholds = np.array([
        hyper_params[event_class]['threshold'] for event_class in event_classes
    ])
    tags = {
        audio_id:
        tagging_scores[audio_id][event_classes].to_numpy() > tagging_thresholds
        for audio_id in tagging_scores
    }

    boundaries_ground_truth = base.boundaries_from_events(
        validation_ground_truth_filepath)
    timestamps = np.arange(0, 10000) * frame_shift
    metrics = {
        'f':
        partial(
            base.f_collar,
            ground_truth=boundaries_ground_truth,
            return_onset_offset_bias=True,
            num_jobs=8,
            **boundaries_collar_based_params,
        ),
    }
    weak_label.crnn.tune_boundary_detection(
        crnns,
        dataset,
        device,
        timestamps,
        event_classes,
        tags,
        metrics,
        tag_masking=True,
        stepfilt_lengths=boundaries_filter_lengths,
        storage_dir=storage_dir)

    if tune_detection_scenario_1:
        metrics = {
            'f':
            partial(
                base.f_collar,
                ground_truth=validation_ground_truth_filepath,
                return_onset_offset_bias=True,
                num_jobs=8,
                **collar_based_params,
            ),
            'auc':
            partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_1,
            ),
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking={
                'f': True,
                'auc': '?'
            },
            window_lengths=detection_window_lengths_scenario_1,
            window_shift=detection_window_shift_scenario_1,
            medfilt_lengths=detection_medfilt_lengths_scenario_1,
        )
        dump_json(leaderboard['f'][1],
                  storage_dir / 'sed_hyper_params_f.json')
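        # Derive per-class decision thresholds from the PSDS-tuned scores via collar-based best F-score.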
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds1.json')
    if tune_detection_scenario_2:
        metrics = {
            'auc':
            partial(
                base.psd_auc,
                ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations,
                num_jobs=8,
                **psds_scenario_2,
            )
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking=False,
            window_lengths=detection_window_lengths_scenario_2,
            window_shift=detection_window_shift_scenario_2,
            medfilt_lengths=detection_medfilt_lengths_scenario_2,
        )
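        # As above: attach collar-based best-F-score thresholds to the PSDS-scenario-2 hyper-parameters.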
        f, p, r, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard['auc'][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard['auc'][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard['auc'][1],
                  storage_dir / 'sed_hyper_params_psds2.json')
    for crnn_dir in crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)

    if eval_set_name:
        if tune_detection_scenario_1:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
            }, )
        if tune_detection_scenario_2:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
                'sed_hyper_params_name': 'psds2',
            }, )
Ejemplo n.º 22
0
def train(logdir, device, iterations, resume_iteration, checkpoint_interval,
          train_on, batch_size, sequence_length, model_complexity,
          learning_rate, learning_rate_decay_steps, learning_rate_decay_rate,
          leave_one_out, clip_gradient_norm, validation_length,
          validation_interval):
    print_config(ex.current_run)

    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)

    train_groups, validation_groups = ['train'], ['validation']

    if leave_one_out is not None:
        all_years = {
            '2004', '2006', '2008', '2009', '2011', '2013', '2014', '2015',
            '2017'
        }
        train_groups = list(all_years - {str(leave_one_out)})
        validation_groups = [str(leave_one_out)]

    if train_on == 'MAESTRO':
        dataset = MAESTRO(groups=train_groups, sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=sequence_length)
    elif train_on == 'MAPS':
        dataset = MAPS(groups=[
            'AkPnBcht', 'AkPnBsdf', 'AkPnCGdD', 'AkPnStgb', 'SptkBGAm',
            'SptkBGCl', 'StbgTGd2'
        ],
                       sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=['ENSTDkAm', 'ENSTDkCl'],
                                           sequence_length=validation_length)
    elif train_on == 'WJD':
        dataset = WJD(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'DMJ':
        dataset = DMJ(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'MAPS_DMJ':
        dataset = MAPS_DMJ(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'MILLS':
        dataset = MILLS(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=[''],
                                           sequence_length=validation_length)
    elif train_on == 'MAESTRO_JAZZ':
        dataset = MAESTRO_JAZZ(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAESTRO_JAZZ_AUGMENTED':
        dataset = MAESTRO_JAZZ_AUGMENTED(groups=[''],
                                         sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'DMJ_DMJOFFSHIFT':
        dataset = DMJ_DMJOFFSHIFT(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAPS_DMJ_DMJOFFSHIFT':
        dataset = MAPS_DMJ_DMJOFFSHIFT(groups=[''],
                                       sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'JAZZ_TRAIN':
        dataset = JAZZ_TRAIN(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAESTRO_JAZZ_TRAIN':
        dataset = MAESTRO_JAZZ_TRAIN(groups=[''],
                                     sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    elif train_on == 'MAPS_JAZZ_TRAIN':
        dataset = MAPS_JAZZ_TRAIN(groups=[''], sequence_length=sequence_length)
        validation_dataset = JAZZ_EVALUATE(groups=validation_groups,
                                           sequence_length=validation_length)
    else:
        raise ValueError(f'Unknown train_on value: {train_on}')

    loader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)

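    # Either initialize a fresh model and optimizer or resume both from saved checkpoints in logdir.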
    if resume_iteration is None:
        model = OnsetsAndFrames(N_MELS, MAX_MIDI - MIN_MIDI + 1,
                                model_complexity).to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:
        model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
        model = torch.load(model_path)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(
            torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))

    # i = 0
    # for child in model.children():
    #     print("Child {}:".format(i))
    #     print(child)
    #     if i == 1:
    #         print("offset model")
    #         count_parameters(child)
    #     i += 1

    # Comment out the next line to train the full model from scratch; freezing the stacks turns training into fine-tuning.
    model.freeze_stacks()
    count_parameters(model)

    summary(model)
    scheduler = StepLR(optimizer,
                       step_size=learning_rate_decay_steps,
                       gamma=learning_rate_decay_rate)

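    # Iteration-based training loop: cycle() repeats the data loader indefinitely while tqdm tracks the global step.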
    loop = tqdm(range(resume_iteration + 1, iterations + 1))
    for i, batch in zip(loop, cycle(loader)):
        predictions, losses = model.run_on_batch(batch)

        loss = sum(losses.values())
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients before the optimizer step; clipping afterwards has no effect.
        if clip_gradient_norm:
            clip_grad_norm_(model.parameters(), clip_gradient_norm)

        optimizer.step()
        scheduler.step()

        for key, value in {'loss': loss, **losses}.items():
            writer.add_scalar(key, value.item(), global_step=i)

        if i % validation_interval == 0:
            model.eval()
            with torch.no_grad():
                for key, value in evaluate_old(validation_dataset,
                                               model).items():
                    writer.add_scalar('validation/' + key.replace(' ', '_'),
                                      np.mean(value),
                                      global_step=i)
            model.train()

        if i % checkpoint_interval == 0:
            torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
            torch.save(optimizer.state_dict(),
                       os.path.join(logdir, 'last-optimizer-state.pt'))
Ejemplo n.º 23
0
def train(
    _run,
    debug,
    data_provider,
    trainer,
    lr_rampup_steps,
    back_off_patience,
    lr_decay_step,
    lr_decay_factor,
    init_ckpt_path,
    frozen_cnn_2d_layers,
    frozen_cnn_1d_layers,
    resume,
    delay,
    validation_set_name,
    validation_ground_truth_filepath,
    weak_label_crnn_hyper_params_dir,
    eval_set_name,
    eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
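    # Populate label_mapping in index order, sanitizing label strings (commas, spaces, parentheses, quotes).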
    for idx, label in sorted(data_provider.train_transform.label_encoder.
                             inverse_label_mapping.items()):
        assert idx == len(
            trainer.model.label_mapping), (idx, label,
                                           len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ',
                                              '').replace('(', '_').replace(
                                                  ')', '_').replace("'", ''))
    print('Params', sum(p.numel() for p in trainer.model.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(torch.load(init_ckpt_path,
                                          map_location='cpu')['model'],
                               maxdepth=1)
        trainer.model.cnn.load_state_dict(state_dict['cnn'])
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    def add_tag_condition(example):
        example["tag_condition"] = example["weak_targets"]
        return example

    train_set = data_provider.get_train_set().map(add_tag_condition)
    validate_set = data_provider.get_validate_set().map(add_tag_condition)

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_strong',
            maximize=True,
        )

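    # Piecewise-linear LR schedule via LRAnnealingHook: optional ramp-up from 0 to 1 over lr_rampup_steps
    # and a drop to lr_decay_factor at lr_decay_step (breakpoint values are presumably LR scale factors).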
    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume)

    if validation_set_name:
        tuning.run(
            config_updates={
                'debug': debug,
                'weak_label_crnn_hyper_params_dir':
                weak_label_crnn_hyper_params_dir,
                'strong_label_crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath':
                validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
            })
Ejemplo n.º 24
0
def print_config_option(args, run):
    """Always print the configuration first."""
    print_config(run)
    print("-" * 79)
Ejemplo n.º 25
0
def main(run_dir, data_dir, nb_epoch, early_stopping_patience, desired_sample_rate, fragment_length, batch_size,
         fragment_stride, nb_output_bins, keras_verbose, _log, seed, _config, debug, learn_all_outputs,
         train_only_in_receptive_field, _run, use_ulaw, train_with_soft_target_stdev):
    if run_dir is None:
        if not os.path.exists("models"):
            os.mkdir("models")
        run_dir = os.path.join('models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))
        _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    if not debug:
        if os.path.exists(run_dir):
            raise EnvironmentError('Run with seed %d already exists' % seed)
        os.mkdir(run_dir)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    model.summary(print_fn=_log.info)  # log the summary lines instead of logging None

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    tictoc = strftime("%a_%d_%b_%Y_%H_%M_%S", gmtime())
    directory_name = tictoc
    log_dir = 'wavenet_' + directory_name
    os.mkdir(log_dir)
    tensorboard = TensorBoard(log_dir=log_dir)

    callbacks = [
        tensorboard,
        ReduceLROnPlateau(patience=early_stopping_patience / 2, cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        callbacks.extend([
            ModelCheckpoint(os.path.join(checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                            save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])

    if not debug:
        os.mkdir(checkpoint_dir)
        _log.info('Starting Training...')

    print("nb_examples['train'] {0}".format(nb_examples['train']))
    print("nb_examples['test'] {0}".format(nb_examples['test']))

    model.fit_generator(data_generators['train'],
                        steps_per_epoch=nb_examples['train'] // batch_size,
                        epochs=nb_epoch,
                        validation_data=data_generators['test'],
                        validation_steps=nb_examples['test'] // batch_size,
                        callbacks=callbacks,
                        verbose=keras_verbose)
Ejemplo n.º 26
0
def main(
    _run,
    out,
    mask_estimator,
    Observation,
    beamformer,
    postfilter,
    normalize_audio=True,
):
    if dlp_mpi.IS_MASTER:
        from sacred.commands import print_config
        print_config(_run)

    ds = get_dataset()

    data = []

    out = Path(out)

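    # Distribute the (sorted) examples over MPI workers; results are gathered on the master below.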
    for ex in dlp_mpi.split_managed(ds.sort(), allow_single_worker=True):

        if mask_estimator is None:
            mask = None
        elif mask_estimator == 'cacgmm':
            mask = get_mask_from_cacgmm(ex)
        else:
            mask = get_mask_from_oracle(ex, mask_estimator)

        metric, score = get_scores(
            ex,
            mask,
            Observation=Observation,
            beamformer=beamformer,
            postfilter=postfilter,
        )

        est0, est1 = metric.speech_prediction_selection
        dump_audio(est0,
                   out / ex['dataset'] / f"{ex['example_id']}_0.wav",
                   normalize=normalize_audio)
        dump_audio(est1,
                   out / ex['dataset'] / f"{ex['example_id']}_1.wav",
                   normalize=normalize_audio)

        data.append(
            dict(
                example_id=ex['example_id'],
                value=score,
                dataset=ex['dataset'],
            ))

        # print(score, repr(score))

    data = dlp_mpi.gather(data)

    if dlp_mpi.IS_MASTER:
        data = [entry for worker_data in data for entry in worker_data]

        data = {  # itertools.groupby expects its input to be ordered
            dataset: list(subset)
            for dataset, subset in from_list(data).groupby(
                lambda ex: ex['dataset']).items()
        }

        for dataset, sub_data in data.items():
            print(f'Write details to {out}.')
            dump_json(sub_data, out / f'{dataset}_scores.json')

        for dataset, sub_data in data.items():
            summary = {}
            for k in sub_data[0]['value'].keys():
                m = np.mean([d['value'][k] for d in sub_data])
                print(dataset, k, m)
                summary[k] = m
            dump_json(summary, out / f'{dataset}_summary.json')
Ejemplo n.º 27
0
def train(_run, model, device, lr, gradient_clipping, weight_decay, swa_start,
          swa_freq, swa_lr, summary_interval, validation_interval, max_steps):
    print_config(_run)
    os.makedirs(storage_dir / 'checkpoints', exist_ok=True)
    train_iter, validate_iter, batch_norm_tuning_iter = get_datasets()
    model = CRNN(
        cnn_2d=CNN2d(**model['cnn_2d']),
        cnn_1d=CNN1d(**model['cnn_1d']),
        enc=GRU(**model['enc']),
        fcn=fully_connected_stack(**model['fcn']),
        fcn_noisy=None if model['fcn_noisy'] is None else
        fully_connected_stack(**model['fcn_noisy']),
    )
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
    model = model.to(device)
    model.train()
    optimizer = Adam(tuple(model.parameters()),
                     lr=lr,
                     weight_decay=weight_decay)
    if swa_start is not None:
        optimizer = SWA(optimizer,
                        swa_start=swa_start,
                        swa_freq=swa_freq,
                        swa_lr=swa_lr)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    # Summary
    summary_writer = tensorboardX.SummaryWriter(str(storage_dir))

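    # Helper closures that accumulate review scalars/histograms/images between logging steps
    # and flush them to TensorBoard.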
    def get_empty_summary():
        return dict(
            scalars=defaultdict(list),
            histograms=defaultdict(list),
            images=dict(),
        )

    def update_summary(review, summary):
        review['scalars']['loss'] = review['loss'].detach()
        for key, value in review['scalars'].items():
            if torch.is_tensor(value):
                value = value.cpu().data.numpy()
            summary['scalars'][key].extend(np.array(value).flatten().tolist())
        for key, value in review['histograms'].items():
            if torch.is_tensor(value):
                value = value.cpu().data.numpy()
            summary['histograms'][key].extend(
                np.array(value).flatten().tolist())
        summary['images'] = review['images']

    def dump_summary(summary, prefix, iteration):
        summary = model.modify_summary(summary)

        # write summary
        for key, value in summary['scalars'].items():
            summary_writer.add_scalar(f'{prefix}/{key}', np.mean(value),
                                      iteration)
        for key, values in summary['histograms'].items():
            summary_writer.add_histogram(f'{prefix}/{key}', np.array(values),
                                         iteration)
        for key, image in summary['images'].items():
            summary_writer.add_image(f'{prefix}/{key}', image, iteration)
        return defaultdict(list)

    # Training loop
    train_summary = get_empty_summary()
    i = 0
    while i < max_steps:
        for batch in train_iter:
            optimizer.zero_grad()
            # forward
            batch = batch_to_device(batch, device=device)
            model_out = model(batch)

            # backward
            review = model.review(batch, model_out)
            review['loss'].backward()
            review['histograms']['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                tuple(model.parameters()), gradient_clipping)
            optimizer.step()

            # update summary
            update_summary(review, train_summary)

            i += 1
            if i % summary_interval == 0:
                dump_summary(train_summary, 'training', i)
                train_summary = get_empty_summary()
            if i % validation_interval == 0 and validate_iter is not None:
                print('Starting Validation')
                model.eval()
                validate_summary = get_empty_summary()
                with torch.no_grad():
                    for batch in validate_iter:
                        batch = batch_to_device(batch, device=device)
                        model_out = model(batch)
                        review = model.review(batch, model_out)
                        update_summary(review, validate_summary)
                dump_summary(validate_summary, 'validation', i)
                print('Finished Validation')
                model.train()
            if i >= max_steps:
                break

    # finalize
    if swa_start is not None:
        optimizer.swap_swa_sgd()
    batch_norm_update(model,
                      batch_norm_tuning_iter,
                      feature_key='features',
                      device=device)
    torch.save(model.state_dict(),
               storage_dir / 'checkpoints' / 'ckpt_final.pth')
Ejemplo n.º 28
0
def main(run_dir, data_dir, nb_epoch, early_stopping_patience, desired_sample_rate, fragment_length, batch_size,
         fragment_stride, nb_output_bins, keras_verbose, _log, seed, _config, debug, learn_all_outputs,
         train_only_in_receptive_field, _run, use_ulaw, train_with_soft_target_stdev):
    if run_dir is None:
        if not os.path.exists("models"):
            os.makedirs("models", exist_ok=True)
        run_dir = os.path.join('models', datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))
        _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    if not debug:
        if os.path.exists(run_dir):
            raise EnvironmentError('Run with seed %d already exists' % seed)
        os.makedirs(run_dir, exist_ok=True)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    model.summary(print_fn=_log.info)  # log the summary lines instead of logging None

    optim = make_optimizer()
    _log.info('Compiling Model...')

    loss = losses.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy,
        categorical_mean_squared_error
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    tictoc = strftime("%a_%d_%b_%Y_%H_%M_%S", gmtime())
    directory_name = tictoc
    log_dir = os.path.join('tensorboard', 'wavenet_' + directory_name)
    os.makedirs(log_dir, exist_ok=True)
    tensorboard = TensorBoard(log_dir=log_dir)

    callbacks = [
        tensorboard,
        ReduceLROnPlateau(patience=early_stopping_patience / 2, cooldown=early_stopping_patience / 4, verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug:
        os.makedirs(checkpoint_dir, exist_ok=True)
        callbacks.extend([
            ModelCheckpoint(os.path.join(checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                            save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])
        _log.info('Starting Training...')

    print("nb_examples['train'] {0}".format(nb_examples['train']))
    print("nb_examples['test'] {0}".format(nb_examples['test']))

    model.fit_generator(data_generators['train'],
                        steps_per_epoch=nb_examples['train'] // batch_size,
                        epochs=nb_epoch,
                        validation_data=data_generators['test'],
                        validation_steps=nb_examples['test'] // batch_size,
                        callbacks=callbacks,
                        verbose=keras_verbose)
Ejemplo n.º 29
0
def main(_run, exp_dir, storage_dir, database_json, ckpt_name, num_workers,
         batch_size, max_padding_rate, device):
    commands.print_config(_run)

    exp_dir = Path(exp_dir)
    storage_dir = Path(storage_dir)

    config = load_json(exp_dir / 'config.json')

    model = Model.from_storage_dir(exp_dir,
                                   consider_mpi=True,
                                   checkpoint_name=ckpt_name)
    model.to(device)
    model.eval()

    _, validation_data, test_data = get_datasets(
        database_json=database_json,
        min_signal_length=1.5,
        audio_reader=config['audio_reader'],
        stft=config['stft'],
        num_workers=num_workers,
        batch_size=batch_size,
        max_padding_rate=max_padding_rate,
        storage_dir=exp_dir,
    )

    outputs = []
    with torch.no_grad():
        for example in tqdm(validation_data):
            example = model.example_to_device(example, device)
            (y, seq_len), _ = model(example)
            y = Mean(axis=-1)(y, seq_len)
            outputs.append((
                y.cpu().detach().numpy(),
                example['events'].cpu().detach().numpy(),
            ))

    scores, targets = list(zip(*outputs))
    scores = np.concatenate(scores)
    targets = np.concatenate(targets)
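    # Tune per-class decision thresholds for F1 on the validation scores; they are reused for the test set below.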
    thresholds, f1 = instance_based.get_optimal_thresholds(targets,
                                                           scores,
                                                           metric='f1')
    decisions = scores > thresholds
    f1, p, r = instance_based.fscore(targets, decisions, event_wise=True)
    ap = metrics.average_precision_score(targets, scores, average=None)
    auc = metrics.roc_auc_score(targets, scores, average=None)
    pos_class_indices, precision_at_hits = instance_based.positive_class_precisions(
        targets, scores)
    lwlrap, per_class_lwlrap, weight_per_class = instance_based.lwlrap_from_precisions(
        precision_at_hits, pos_class_indices, num_classes=targets.shape[1])
    overall_results = {
        'validation': {
            'mF1': np.mean(f1),
            'mP': np.mean(p),
            'mR': np.mean(r),
            'mAP': np.mean(ap),
            'mAUC': np.mean(auc),
            'lwlrap': lwlrap,
        }
    }
    event_validation_results = {}
    labels = load_json(exp_dir / 'events.json')
    for i, label in enumerate(labels):
        event_validation_results[label] = {
            'F1': f1[i],
            'P': p[i],
            'R': r[i],
            'AP': ap[i],
            'AUC': auc[i],
            'lwlrap': per_class_lwlrap[i],
        }

    outputs = []
    with torch.no_grad():
        for example in tqdm(test_data):
            example = model.example_to_device(example, device)
            (y, seq_len), _ = model(example)
            y = Mean(axis=-1)(y, seq_len)
            outputs.append((
                example['example_id'],
                y.cpu().detach().numpy(),
                example['events'].cpu().detach().numpy(),
            ))

    example_ids, scores, targets = list(zip(*outputs))
    example_ids = np.concatenate(example_ids).tolist()
    scores = np.concatenate(scores)
    targets = np.concatenate(targets)
    decisions = scores > thresholds
    f1, p, r = instance_based.fscore(targets, decisions, event_wise=True)
    ap = metrics.average_precision_score(targets, scores, average=None)
    auc = metrics.roc_auc_score(targets, scores, average=None)
    pos_class_indices, precision_at_hits = instance_based.positive_class_precisions(
        targets, scores)
    lwlrap, per_class_lwlrap, weight_per_class = instance_based.lwlrap_from_precisions(
        precision_at_hits, pos_class_indices, num_classes=targets.shape[1])
    overall_results['test'] = {
        'mF1': np.mean(f1),
        'mP': np.mean(p),
        'mR': np.mean(r),
        'mAP': np.mean(ap),
        'mAUC': np.mean(auc),
        'lwlrap': lwlrap,
    }
    dump_json(overall_results,
              storage_dir / 'overall.json',
              indent=4,
              sort_keys=False)
    event_results = {}
    for i, label in sorted(enumerate(labels),
                           key=lambda x: ap[x[0]],
                           reverse=True):
        event_results[label] = {
            'validation': event_validation_results[label],
            'test': {
                'F1': f1[i],
                'P': p[i],
                'R': r[i],
                'AP': ap[i],
                'AUC': auc[i],
                'lwlrap': per_class_lwlrap[i],
            },
        }
    dump_json(event_results,
              storage_dir / 'event_wise.json',
              indent=4,
              sort_keys=False)
    fp = np.argwhere(decisions * (1 - targets))
    dump_json(sorted([(example_ids[n], labels[i]) for n, i in fp]),
              storage_dir / 'fp.json',
              indent=4,
              sort_keys=False)
    fn = np.argwhere((1 - decisions) * targets)
    dump_json(sorted([(example_ids[n], labels[i]) for n, i in fn]),
              storage_dir / 'fn.json',
              indent=4,
              sort_keys=False)
    pprint(overall_results)
Ejemplo n.º 30
0
def apply(cls, args, run):
    print_config(run)
    print('-' * 79)
Ejemplo n.º 31
0
def main(
    _run,
    storage_dir,
    hyper_params_dir,
    sed_hyper_params_name,
    crnn_dirs,
    crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    weak_pseudo_labeling,
    boundary_pseudo_labeling,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labeled_dataset_name,
):
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundary_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

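    # If a single checkpoint name is given, use it for every CRNN in the ensemble.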
    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        CRNN.from_storage_dir(storage_dir=crnn_dir,
                              config_name='1/config.json',
                              checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    print('Params',
          sum([p.numel() for crnn in crnns for p in crnn.parameters()]))
    print(
        'CNN2d Params',
        sum([
            p.numel() for crnn in crnns for p in crnn.cnn.cnn_2d.parameters()
        ]))
    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

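    # Normalize all per-dataset options to lists with one entry per dataset name.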
    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(weak_pseudo_labeling, list):
        weak_pseudo_labeling = len(dataset_name) * [weak_pseudo_labeling]
    assert len(weak_pseudo_labeling) == len(dataset_name)
    if not isinstance(boundary_pseudo_labeling, list):
        boundary_pseudo_labeling = len(dataset_name) * [
            boundary_pseudo_labeling
        ]
    assert len(boundary_pseudo_labeling) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    if not isinstance(pseudo_labeled_dataset_name, list):
        pseudo_labeled_dataset_name = [pseudo_labeled_dataset_name]
    assert len(pseudo_labeled_dataset_name) == len(dataset_name)

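    # Work on a copy of the database so pseudo-labeled datasets can be added without mutating the data provider.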
    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])

        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]
                ['audio_path']).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }

        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

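        # Tagging timestamps: either one whole-clip segment per audio file or overlapping segments
        # of max_segment_length frames.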
        if max_segment_length is None:
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(0, audio_durations[audio_id],
                               (max_segment_length - segment_overlap) *
                               frame_shift)
                timestamps[audio_id] = np.concatenate(
                    (ts, [audio_durations[audio_id]]))
        tags, tagging_scores, tagging_results = tagging(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            hyper_params_dir,
            ground_truth_filepath[i],
            audio_durations,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
        )
        if tagging_results:
            dump_json(tagging_results,
                      storage_dir / f'tagging_results_{dataset_name[i]}.json')

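        # Switch to a frame-level time grid (seconds) for boundary detection and sound event detection.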
        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        if ground_truth_filepath[i] is not None or boundary_pseudo_labeling[i]:
            boundaries, boundaries_detection_results = boundaries_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                hyper_params_dir,
                ground_truth_filepath[i],
                boundary_collar_based_params,
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
            )
            if boundaries_detection_results:
                dump_json(
                    boundaries_detection_results, storage_dir /
                    f'boundaries_detection_results_{dataset_name[i]}.json')
        else:
            boundaries = {}
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]
        if (ground_truth_filepath[i] is not None
            ) or strong_pseudo_labeling[i] or save_scores or save_detections:
            events, sed_results = sound_event_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                hyper_params_dir,
                sed_hyper_params_name,
                ground_truth_filepath[i],
                audio_durations,
                collar_based_params,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / name for name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_detections else None,
            )
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j, storage_dir /
                        f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                    )
        else:
            events = [{}]
        database['datasets'][
            pseudo_labeled_dataset_name[i]] = base.pseudo_label(
                database['datasets'][dataset_name[i]],
                event_classes,
                weak_pseudo_labeling[i],
                boundary_pseudo_labeling[i],
                strong_pseudo_labeling[i],
                tags,
                boundaries,
                events[0],
            )

    if any(weak_pseudo_labeling) or any(boundary_pseudo_labeling) or any(
            strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )
    inference_dir = Path(hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)
Ejemplo n.º 32
0
def apply(cls, args, run):
    print_config(run)
    print('-' * 79)
Ejemplo n.º 33
0
def main(run_dir, data_dir, nb_epoch, early_stopping_patience,
         desired_sample_rate, fragment_length, batch_size, fragment_stride,
         nb_output_bins, keras_verbose, _log, seed, _config, debug,
         learn_all_outputs, train_only_in_receptive_field, _run, use_ulaw,
         train_with_soft_target_stdev):

    dataset.set_multi_gpu(True)

    if run_dir is None:
        if not os.path.exists("models"):
            os.mkdir("models")
        run_dir = os.path.join(
            'models',
            datetime.datetime.now().strftime('run_%Y%m%d_%H%M%S'))
        _config['run_dir'] = run_dir

    print_config(_run)

    _log.info('Running with seed %d' % seed)

    if not debug:
        if not os.path.exists(run_dir):
            os.mkdir(run_dir)
        checkpoint_dir = os.path.join(run_dir, 'checkpoints')
        json.dump(_config, open(os.path.join(run_dir, 'config.json'), 'w'))

    _log.info('Loading data...')
    data_generators, nb_examples = get_generators()

    _log.info('Building model...')
    model = build_model(fragment_length)
    model.summary(print_fn=_log.info)  # log the summary lines instead of logging None

    optim = make_optimizer()
    optim = hvd.DistributedOptimizer(optim)
    _log.info('Compiling Model...')

    loss = objectives.categorical_crossentropy
    all_metrics = [
        metrics.categorical_accuracy, categorical_mean_squared_error
    ]
    if train_with_soft_target_stdev:
        loss = make_targets_soft(loss)
    if train_only_in_receptive_field:
        loss = skip_out_of_receptive_field(loss)
        all_metrics = [skip_out_of_receptive_field(m) for m in all_metrics]

    model.compile(optimizer=optim, loss=loss, metrics=all_metrics)
    # TODO: Consider gradient weighting making last outputs more important.

    callbacks = [
        # Broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),

        # Average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=1),
        ReduceLROnPlateau(patience=early_stopping_patience / 2,
                          cooldown=early_stopping_patience / 4,
                          verbose=1),
        EarlyStopping(patience=early_stopping_patience, verbose=1),
    ]
    if not debug and hvd.rank() == 0:
        callbacks.extend([
            ModelCheckpoint(os.path.join(
                checkpoint_dir, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5'),
                            save_best_only=True),
            CSVLogger(os.path.join(run_dir, 'history.csv')),
        ])

    if not debug:
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        _log.info('Starting Training...')

    model.fit_generator(data_generators['train'],
                        nb_examples['train'] // hvd.size(),
                        epochs=nb_epoch,
                        validation_data=data_generators['test'],
                        validation_steps=nb_examples['test'] // hvd.size(),
                        callbacks=callbacks,
                        verbose=keras_verbose if hvd.rank() == 0 else 0)