Exemplo n.º 1
0
def run_test(config):
    """ Define our model and test it.

    Loads a trained AECNN generator from config.gcheckpoints, enhances a
    single hard-coded example (index 361) of the 'et' evaluation set, and
    writes the noisy input and the enhanced output to wav files.
    """

    generator = AECNN(
        channel_counts=config.gchan,
        kernel_size=config.gkernel,
        block_size=config.gblocksize,
        dropout=config.gdrop,
    ).cuda().eval()  # eval() disables dropout at inference time

    generator.load_state_dict(torch.load(config.gcheckpoints))

    # Initialize datasets
    ev_dataset = wav_dataset(config, 'et')

    example = ev_dataset[361]
    print(example['id'])

    # no_grad: skip building the autograd graph for pure inference
    with torch.no_grad():
        data = np.squeeze(
            generator(example['noisy'].cuda()).cpu().numpy())
    noisy = np.squeeze(example['noisy'].numpy())

    # Write both the unprocessed and enhanced signals at 16 kHz mono.
    with sf.SoundFile('noisy.wav', 'w', 16000, 1) as w:
        w.write(noisy)
    with sf.SoundFile('test.wav', 'w', 16000, 1) as w:
        w.write(data)
Exemplo n.º 2
0
def run_test(config):
    """ Define our model and test it """

    # Build the generator in inference mode and restore trained weights.
    generator = AECNN(
        channel_counts=config.gchan,
        kernel_size=config.gkernel,
        block_size=config.gblocksize,
        dropout=config.gdrop,
    ).cuda().eval()

    generator.load_state_dict(torch.load(config.gcheckpoints))

    # Initialize datasets
    #for phase in ['tr', 'dt', 'et']:

    # 'tr' phase iterates channels 0-5; every other phase uses channel 0 only.
    channels = 6 if config.phase == 'tr' else 1

    processed = 0
    for channel in range(channels):
        dataset = wav_dataset(config, config.phase, channel)

        with torch.no_grad():
            for sample in dataset:
                # Enhance, then move the result to host memory as a 1-D array.
                enhanced = generator(sample['noisy'].cuda())
                enhanced = np.squeeze(enhanced.cpu().detach().numpy())

                out_path = make_filename(config, channel, sample['id'])
                with sf.SoundFile(out_path, 'w', 16000, 1) as out_file:
                    out_file.write(enhanced)

                # Progress report every 1000 files.
                if processed % 1000 == 0:
                    print("finished #%d" % processed)
                processed += 1
Exemplo n.º 3
0
def run_test(config):
    """ Define our model and test it.

    Evaluates a trained AECNN generator over the 'et' set (channel 4) and
    prints the average STOI, ESTOI, and SI-SDR scores.
    """

    generator = AECNN(
        channel_counts=config.gchan,
        kernel_size=config.gkernel,
        block_size=config.gblocksize,
        dropout=config.gdrop,
    ).cuda().eval()  # eval() disables dropout at inference time

    generator.load_state_dict(torch.load(config.gcheckpoints))

    # Initialize datasets
    ev_dataset = wav_dataset(config, 'et', 4)

    count = 0
    score = {'stoi': 0, 'estoi': 0, 'sdr': 0}
    # no_grad: evaluation only, no autograd graph needed
    with torch.no_grad():
        for example in ev_dataset:
            data = np.squeeze(
                generator(example['noisy'].cuda()).cpu().numpy())
            clean = np.squeeze(example['clean'].numpy())
            # Accumulate per-utterance scores; averaged after the loop.
            score['stoi'] += stoi(clean, data, 16000, extended=False)
            score['estoi'] += stoi(clean, data, 16000, extended=True)
            score['sdr'] += si_sdr(data, clean)
            count += 1

    print('stoi: %f' % (score['stoi'] / count))
    print('estoi: %f' % (score['estoi'] / count))
    print('sdr: %f' % (score['sdr'] / count))
Exemplo n.º 4
0
def run_training(config):
    """ Define our model and train it.

    Builds the generator, mimic (acoustic model), optional frozen teacher,
    and optional discriminator according to config, then trains for
    config.epochs epochs, checkpointing whichever models are being trained
    whenever the dev loss improves.
    """
    import copy

    # A model is loaded when a pretrain path is given, and trained when a
    # checkpoint directory is given.
    load_generator = config.gpretrain is not None
    train_generator = config.gcheckpoints is not None

    load_mimic = config.mpretrain is not None
    train_mimic = config.mcheckpoints is not None

    if torch.cuda.is_available():
        config.device = torch.device('cuda')
    else:
        config.device = torch.device('cpu')

    models = {}

    # Build enhancement model
    if load_generator or train_generator:
        models['generator'] = AECNN(
            channel_counts=config.gchan,
            kernel_size=config.gkernel,
            block_size=config.gblocksize,
            dropout=config.gdrop,
            training=train_generator,
        ).to(config.device)

        # Freeze/unfreeze parameters. Assigning `.requires_grad` on a
        # Module is a no-op attribute; requires_grad_() actually toggles
        # every parameter.
        models['generator'].requires_grad_(train_generator)

        if load_generator:
            models['generator'].load_state_dict(
                torch.load(config.gpretrain, map_location=config.device))

    # Build acoustic model
    if load_mimic or train_mimic:

        if config.mact == 'rrelu':
            activation = lambda x: torch.nn.functional.rrelu(
                x, training=train_mimic)
        else:
            activation = lambda x: torch.nn.functional.leaky_relu(
                x, negative_slope=0.3)

        models['mimic'] = ResNet(
            input_dim=256,
            output_dim=config.moutdim,
            channel_counts=config.mchan,
            dropout=config.mdrop,
            training=train_mimic,
            activation=activation,
        ).to(config.device)

        models['mimic'].requires_grad_(train_mimic)

        if load_mimic:
            models['mimic'].load_state_dict(
                torch.load(config.mpretrain, map_location=config.device))

            # Frozen copy of the pretrained mimic used as a target network.
            # NOTE(review): `and` binds tighter than `or`, so this reads as
            # `mimic_weight > 0 or (any(texture_weights) and train_mimic)` —
            # confirm that grouping is intended before adding parentheses.
            if config.mimic_weight > 0 or any(
                    config.texture_weights) and train_mimic:
                models['teacher'] = ResNet(
                    input_dim=256,
                    output_dim=config.moutdim,
                    channel_counts=config.mchan,
                    dropout=0,
                    training=False,
                ).to(config.device)

                models['teacher'].requires_grad_(False)
                models['teacher'].load_state_dict(
                    torch.load(config.mpretrain, map_location=config.device))

    if config.gan_weight > 0:
        models['discriminator'] = Discriminator(
            channel_counts=config.gchan,
            kernel_size=config.gkernel,
            block_size=config.gblocksize,
            dropout=config.gdrop,
            training=True,
        ).to(config.device)

    # Initialize datasets
    tr_dataset = wav_dataset(config, 'tr')
    dt_dataset = wav_dataset(config, 'dt', 4)

    if config.soft_senone_weight > 0:
        print("Pretraining senone embeddings")
        models['embedding'] = Embedding(config.moutdim).to(config.device)
        models['embedding'].pretrain(tr_dataset, models['mimic'],
                                     config.device)
        print("Completed embedding pretraining")

    if config.real_senone_file:
        # Shallow-copy so the real-data overrides don't mutate the config
        # object shared with Trainer and the datasets created above.
        real_config = copy.copy(config)
        real_config.senone_file = config.real_senone_file
        real_config.noisy_flist = config.real_flist
        real_config.noise_flist = None
        real_config.clean_flist = None
        # NOTE(review): tr_real_dataset is unused below — confirm Trainer
        # or later code is supposed to consume it.
        tr_real_dataset = wav_dataset(real_config, 'tr')

    trainer = Trainer(config, models)

    # Run the training
    best_dev_loss = float('inf')
    for epoch in range(config.epochs):
        print("Starting epoch %d" % epoch)

        # Train for one epoch
        start_time = time.time()
        trainer.run_epoch(tr_dataset, training=True)
        total_time = time.time() - start_time

        print("Completed epoch %d in %d seconds" % (epoch, int(total_time)))

        dev_loss, dev_losses = trainer.run_epoch(dt_dataset, training=False)

        print("Dev loss: %f" % dev_loss)
        for key in dev_losses:
            print("%s loss: %f" % (key, dev_losses[key]))

        # Save our model whenever the dev loss improves.
        if dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            if train_mimic:
                mfile = os.path.join(config.mcheckpoints, config.mfile)
                torch.save(models['mimic'].state_dict(), mfile)
            if train_generator:
                gfile = os.path.join(config.gcheckpoints, config.gfile)
                torch.save(models['generator'].state_dict(), gfile)
Exemplo n.º 5
0
def run_training(config):
    """ Define our model and train it.

    Builds the generator, mimic (acoustic model), and optional frozen
    teacher according to config, then trains for config.epochs epochs,
    checkpointing whichever models are being trained whenever the dev
    loss improves.
    """

    # A model is loaded when a pretrain path is given, and trained when a
    # checkpoint directory is given.
    load_generator = config.gpretrain is not None
    train_generator = config.gcheckpoints is not None

    load_mimic = config.mpretrain is not None
    train_mimic = config.mcheckpoints is not None

    models = {}

    # Build enhancement model
    if load_generator or train_generator:
        models['generator'] = AECNN(
            channel_counts=config.gchan,
            kernel_size=config.gkernel,
            block_size=config.gblocksize,
            dropout=config.gdrop,
            training=train_generator,
        ).cuda()

        # Freeze/unfreeze parameters. Assigning `.requires_grad` on a
        # Module is a no-op attribute; requires_grad_() actually toggles
        # every parameter.
        models['generator'].requires_grad_(train_generator)

        if load_generator:
            models['generator'].load_state_dict(torch.load(config.gpretrain))

    # Build acoustic model
    if load_mimic or train_mimic:

        models['mimic'] = ResNet(
            input_dim=256,
            output_dim=config.moutdim,
            channel_counts=config.mchan,
            dropout=config.mdrop,
            training=train_mimic,
        ).cuda()

        models['mimic'].requires_grad_(train_mimic)

        if load_mimic:
            models['mimic'].load_state_dict(torch.load(config.mpretrain))

            # Frozen copy of the pretrained mimic used as a target network.
            # NOTE(review): `and` binds tighter than `or`, so this reads as
            # `mimic_weight > 0 or (any(texture_weights) and train_mimic)` —
            # confirm that grouping is intended before adding parentheses.
            if config.mimic_weight > 0 or any(
                    config.texture_weights) and train_mimic:
                models['teacher'] = ResNet(
                    input_dim=256,
                    output_dim=config.moutdim,
                    channel_counts=config.mchan,
                    dropout=0,
                    training=False,
                ).cuda()

                models['teacher'].requires_grad_(False)
                models['teacher'].load_state_dict(torch.load(config.mpretrain))

    # Initialize datasets
    tr_dataset = wav_dataset(config, 'tr')
    dt_dataset = wav_dataset(config, 'dt', 4)

    trainer = Trainer(config, models)

    # Run the training
    best_dev_loss = float('inf')
    for epoch in range(config.epochs):
        print("Starting epoch %d" % epoch)

        # Train for one epoch
        start_time = time.time()
        trainer.run_epoch(tr_dataset, training=True)
        total_time = time.time() - start_time

        print("Completed epoch %d in %d seconds" % (epoch, int(total_time)))

        dev_loss = trainer.run_epoch(dt_dataset, training=False)

        print("Dev loss: %f" % dev_loss)

        # Save our model whenever the dev loss improves.
        if dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            if train_mimic:
                mfile = os.path.join(config.mcheckpoints, config.mfile)
                torch.save(models['mimic'].state_dict(), mfile)
            if train_generator:
                gfile = os.path.join(config.gcheckpoints, config.gfile)
                torch.save(models['generator'].state_dict(), gfile)