def perform_frontend_training():
    print('ID: ' + exp_settings['exp_id'])
    # Instantiating data handler
    io_dealer = helpers.DataIO(exp_settings)

    # Number of file sets
    num_of_sets = exp_settings['num_of_multitracks'] // exp_settings['set_size']

    # Initialize modules
    if exp_settings['visualize']:
        win_viz, win_viz_b = visualize.init_visdom()  # Web loss plotting
    analysis, synthesis = nn_loaders.build_frontend_model(
        flag='training', exp_settings=exp_settings)

    # Expected shapes
    data_shape = (exp_settings['batch_size'],
                  exp_settings['d_p_length'] * exp_settings['fs'])
    noise_sampler = Normal(
        torch.zeros(data_shape),
        torch.ones(data_shape) * exp_settings['noise_scalar'])

    # Initialize optimizer and add the parameters that will be updated
    parameters_list = list(analysis.parameters()) + list(
        synthesis.parameters())
    optimizer = torch.optim.Adam(parameters_list,
                                 lr=exp_settings['learning_rate'])
    # Start of the training
    batch_indx = 0
    for epoch in range(1, exp_settings['epochs'] + 1):
        for file_set in range(1, num_of_sets + 1):
            # Load a sub-set of the recordings
            _, vox, bkg = io_dealer.get_data(file_set,
                                             exp_settings['set_size'],
                                             monaural=exp_settings['monaural'])

            # Create batches
            vox = io_dealer.gimme_batches(vox)
            bkg = io_dealer.gimme_batches(bkg)

            # Compute the total number of batches contained in this sub-set
            num_batches = vox.shape[0] // exp_settings['batch_size']

            # Compute permutations for random shuffling
            perm_in_vox = np.random.permutation(vox.shape[0])
            perm_in_bkg = np.random.permutation(bkg.shape[0])
            for batch in range(num_batches):
                shuf_ind_vox = perm_in_vox[batch * exp_settings['batch_size']:
                                           (batch + 1) * exp_settings['batch_size']]
                shuf_ind_bkg = perm_in_bkg[batch * exp_settings['batch_size']:
                                           (batch + 1) * exp_settings['batch_size']]
                vox_tr_batch = io_dealer.batches_from_numpy(
                    vox[shuf_ind_vox, :])
                bkg_tr_batch = io_dealer.batches_from_numpy(
                    bkg[shuf_ind_bkg, :])

                vox_var = torch.autograd.Variable(vox_tr_batch,
                                                  requires_grad=False)
                mix_var = torch.autograd.Variable(vox_tr_batch + bkg_tr_batch,
                                                  requires_grad=False)

                # Sample noise
                noise = torch.autograd.Variable(
                    noise_sampler.sample().cuda().float(), requires_grad=False)

                # Zero-mean the signals
                vox_var -= vox_var.mean()
                mix_var -= mix_var.mean()

                # Target source forward pass
                vox_coeff = analysis.forward(vox_var + noise)
                waveform = synthesis.forward(
                    vox_coeff, use_sorting=exp_settings['dict_sorting'])

                # Mixture signal forward pass
                mix_coeff = analysis.forward(mix_var)

                # Loss functions
                rec_loss = losses.neg_snr(vox_var, waveform)
                smt_loss = exp_settings[
                    'lambda_reg'] * losses.tot_variation_2d(mix_coeff)

                loss = rec_loss + smt_loss

                # Optimize for reconstruction & smoothness
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if exp_settings['visualize']:
                    # Visualization
                    win_viz = visualize.viz.line(
                        X=np.arange(batch_indx, batch_indx + 1),
                        Y=np.reshape(rec_loss.item(), (1, )),
                        win=win_viz,
                        update='append')
                    win_viz_b = visualize.viz.line(
                        X=np.arange(batch_indx, batch_indx + 1),
                        Y=np.reshape(smt_loss.item(), (1, )),
                        win=win_viz_b,
                        update='append')
                    batch_indx += 1

        # Check for numerical stability; there was some trouble with sinc-net producing NaNs
        if not torch.isnan(loss):
            print('--- Saving Model ---')
            torch.save(
                analysis.state_dict(),
                'results/analysis_' + exp_settings['exp_id'] + '.pytorch')
            torch.save(
                synthesis.state_dict(),
                'results/synthesis_' + exp_settings['exp_id'] + '.pytorch')
        else:
            break

    return None
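# The `losses` module used above is not shown in this listing. Below is a
# minimal sketch of what the two objectives could look like; the exact
# definitions in the original repository may differ, and the epsilon values,
# the mean reduction, and the assumed (batch, time, component) coefficient
# layout are assumptions made for illustration only.
import torch


def neg_snr(target, estimate, eps=1e-8):
    """Negative signal-to-noise ratio in dB, averaged over the batch.

    Minimizing this quantity maximizes the SNR of the reconstruction.
    """
    residual = target - estimate
    snr = 10. * torch.log10((target.pow(2).sum(dim=-1) + eps) /
                            (residual.pow(2).sum(dim=-1) + eps))
    return -snr.mean()


def tot_variation_2d(coeff):
    """Total-variation penalty over both non-batch axes of the analysis
    coefficients, encouraging smooth representations."""
    tv_rows = (coeff[:, 1:, :] - coeff[:, :-1, :]).abs().mean()
    tv_cols = (coeff[:, :, 1:] - coeff[:, :, :-1]).abs().mean()
    return tv_rows + tv_cols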
Example #2
def perform_training():
    # Check if saving path exists
    if not os.path.isdir(os.path.join('results', exp_settings['split_name'])):
        print('Saving directory was not found... '
              'Creating a new folder to store the results!')
        os.makedirs(os.path.join('results', exp_settings['split_name']))

    # Get data dictionary
    data_dict = helpers.csv_to_dict(training=True)
    training_keys = sorted(list(
        data_dict.keys()))[0:exp_settings['split_training_indx']]
    print('Training on: ' + " ".join(training_keys))

    # Get data
    x, y, _ = helpers.fetch_data(data_dict, training_keys)
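    # Peak-normalize the input signal to a maximum absolute amplitude of 0.99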
    x *= 0.99 / np.max(np.abs(x))

    # Data-point length in samples
    d_p_length_samples = exp_settings['d_p_length'] * exp_settings['fs']

    # Initialize NN modules
    dropout = torch.nn.Dropout(exp_settings['drp_rate']).cuda()
    win_viz, _ = visualize.init_visdom()  # Web loss plotting
    dft_analysis, mel_analysis, pcen, gru_enc, gru_dec, fc_layer, label_smoother = build_model(
        flag='training')

    # Criterion
    bce_func = torch.nn.BCEWithLogitsLoss(reduction='mean')

    # Initialize optimizer and add the parameters that will be updated
    if exp_settings['end2end']:
        parameters_list = list(dft_analysis.parameters()) + list(mel_analysis.parameters()) + list(pcen.parameters())\
                          + list(gru_enc.parameters()) + list(gru_dec.parameters()) + list(fc_layer.parameters())
    else:
        parameters_list = list(pcen.parameters()) + list(gru_enc.parameters())\
                          + list(gru_dec.parameters()) + list(fc_layer.parameters())

    optimizer = torch.optim.Adam(parameters_list, lr=1e-4)
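    # Two step-wise schedulers share the same optimizer: scheduler_n shrinks the
    # learning rate when validation worsens, scheduler_p grows it when validation
    # improves (see the validation block at the top of the epoch loop).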
    scheduler_n = StepLR(optimizer,
                         1,
                         gamma=exp_settings['learning_rate_drop'])
    scheduler_p = StepLR(optimizer,
                         1,
                         gamma=exp_settings['learning_date_incr'])

    # Start of the training
    batch_indx = 0
    number_of_data_points = len(x) // d_p_length_samples
    prv_cls_error = 100.
    best_error = 50.
    for epoch in range(1, exp_settings['epochs'] + 1):
        # Validation
        if epoch != 1:
            cls_err = perform_validation([
                dft_analysis, mel_analysis, pcen, gru_enc, gru_dec, fc_layer,
                label_smoother
            ])

            if prv_cls_error - cls_err > 0:
                # Increase learning rate
                scheduler_p.step()
                if cls_err < best_error:
                    # Update best error
                    best_error = cls_err
                    print('--- Saving Model ---')
                    torch.save(
                        pcen.state_dict(),
                        os.path.join('results', exp_settings['split_name'],
                                     'en_pcen_bs_drp.pytorch'))
                    torch.save(
                        gru_enc.state_dict(),
                        os.path.join('results', exp_settings['split_name'],
                                     'en_gru_enc_bs_drp.pytorch'))
                    torch.save(
                        gru_dec.state_dict(),
                        os.path.join('results', exp_settings['split_name'],
                                     'en_gru_dec_bs_drp.pytorch'))
                    torch.save(
                        fc_layer.state_dict(),
                        os.path.join('results', exp_settings['split_name'],
                                     'en_cls_bs_drp.pytorch'))
            else:
                # Decrease learning rate
                scheduler_n.step()

            # Update classification error
            prv_cls_error = cls_err

        # Shuffle between sequences
        shuffled_data_points = np.random.permutation(
            np.arange(0, number_of_data_points))

        # Constructing batches
        available_batches = len(
            shuffled_data_points) // exp_settings['batch_size']
        for batch in tqdm(range(available_batches)):
            x_d_p, y_d_p = helpers.gimme_batches(batch, shuffled_data_points,
                                                 x, y)
            x_cuda = torch.autograd.Variable(
                torch.from_numpy(x_d_p).cuda(),
                requires_grad=False).float().detach()
            y_cuda = torch.autograd.Variable(
                torch.from_numpy(y_d_p).cuda(),
                requires_grad=False).float().detach()

            # Forward analysis pass: Input data
            x_real, x_imag = dft_analysis.forward(x_cuda)
            # Magnitude computation
            mag = torch.sqrt(x_real.pow(2) + x_imag.pow(2))

            # Mel analysis
            mel_mag = torch.autograd.Variable(mel_analysis.forward(mag).data,
                                              requires_grad=True).cuda()

            # Learned normalization
            mel_mag_pr = pcen.forward(mel_mag)

            # GRUs
            dr_mel_p = dropout(mel_mag_pr)
            h_enc = gru_enc.forward(dr_mel_p)
            h_dec = gru_dec.forward(h_enc)
            # Classifier
            _, vad_prob = fc_layer.forward(h_dec, mel_mag_pr)

            # Target data preparation
            y_true = label_smoother.forward(y_cuda).detach()
            vad_true = torch.autograd.Variable(y_true.data,
                                               requires_grad=True).cuda()

            # Loss
            loss = bce_func(vad_prob, vad_true)

            # Optimization
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters_list,
                                           max_norm=60,
                                           norm_type=2)
            optimizer.step()

            # Update graph
            win_viz = visualize.viz.line(X=np.arange(batch_indx, batch_indx + 1),
                                         Y=np.reshape(loss.item(), (1, )),
                                         win=win_viz,
                                         update='append')
            batch_indx += 1

    return None
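# The module-level `exp_settings` dictionary is defined elsewhere in the
# repository. The sketch below lists only the keys referenced by
# `perform_training`; every value is an illustrative placeholder, not the
# configuration actually used in the experiments.
exp_settings = {
    'split_name': 'split_a',        # sub-folder created under results/
    'split_training_indx': 10,      # number of CSV keys used for training
    'fs': 16000,                    # sampling rate in Hz
    'd_p_length': 2,                # data-point length in seconds
    'batch_size': 16,
    'epochs': 50,
    'drp_rate': 0.25,               # dropout probability
    'end2end': False,               # also update the DFT/mel front-end?
    'learning_rate_drop': 0.5,      # LR decay factor for scheduler_n
    'learning_date_incr': 1.1,      # LR growth factor for scheduler_p (key name as in the source)
}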
Example #3
def perform_frontend_training():
    print('ID: ' + exp_settings['exp_id'])
    # Instantiating data handler
    io_dealer = helpers.DataIO(exp_settings=exp_settings)

    # Number of file sets
    num_of_sets = exp_settings['num_of_multitracks'] // exp_settings['set_size']

    # Initialize modules
    if exp_settings['visualize']:
        win_viz, win_viz_b = visualize.init_visdom()  # Web loss plotting
    analysis, synthesis = nn_loaders.build_frontend_model(
        flag='training', exp_settings=exp_settings)
    disc = nn_loaders.build_discriminator(flag='training',
                                          exp_settings=exp_settings)
    sigmoid = torch.nn.Sigmoid()

    # Expected shapes
    data_shape = (exp_settings['batch_size'],
                  exp_settings['d_p_length'] * exp_settings['fs'])
    noise_sampler = Normal(
        torch.zeros(data_shape),
        torch.ones(data_shape) * exp_settings['noise_scalar'])

    # Initialize optimizer and add the parameters that will be updated
    parameters_list = list(analysis.parameters()) + list(
        synthesis.parameters()) + list(disc.parameters())

    optimizer = torch.optim.Adam(parameters_list,
                                 lr=exp_settings['learning_rate'])

    # Start of the training
    batch_indx = 0
    for epoch in range(1, exp_settings['epochs'] + 1):
        for file_set in range(1, num_of_sets + 1):
            # Load a sub-set of the recordings
            _, vox, bkg = io_dealer.get_data(file_set,
                                             exp_settings['set_size'],
                                             monaural=exp_settings['monaural'])

            # Create batches
            vox = io_dealer.gimme_batches(vox)
            bkg = io_dealer.gimme_batches(bkg)

            # Compute the total number of batches contained in this sub-set
            num_batches = vox.shape[0] // exp_settings['batch_size']

            # Compute permutations for random shuffling
            perm_in_vox = np.random.permutation(vox.shape[0])
            perm_in_bkg = np.random.permutation(bkg.shape[0])
            for batch in range(num_batches):
                shuf_ind_vox = perm_in_vox[batch * exp_settings['batch_size']:
                                           (batch + 1) * exp_settings['batch_size']]
                shuf_ind_bkg = perm_in_bkg[batch * exp_settings['batch_size']:
                                           (batch + 1) * exp_settings['batch_size']]
                vox_tr_batch = io_dealer.batches_from_numpy(
                    vox[shuf_ind_vox, :])
                bkg_tr_batch = io_dealer.batches_from_numpy(
                    bkg[shuf_ind_bkg, :])

                vox_var = torch.autograd.Variable(vox_tr_batch,
                                                  requires_grad=False)
                bkg_var = torch.autograd.Variable(bkg_tr_batch,
                                                  requires_grad=False)
                mix_var = torch.autograd.Variable(vox_tr_batch + bkg_tr_batch,
                                                  requires_grad=False)

                # Sample noise
                noise = torch.autograd.Variable(
                    noise_sampler.sample().cuda().float(), requires_grad=False)

                # Zero-mean the signals
                vox_var -= vox_var.mean()
                bkg_var -= bkg_var.mean()
                mix_var -= mix_var.mean()

                # Target source forward pass
                vox_coeff = analysis.forward(vox_var + noise)
                waveform = synthesis.forward(
                    vox_coeff, use_sorting=exp_settings['dict_sorting'])

                # Mixture and Background signals forward pass
                mix_coeff = analysis.forward(mix_var)
                bkg_coeff = analysis.forward(bkg_var)

                # Loss functions
                rec_loss = losses.neg_snr(vox_var, waveform)
                smt_loss = exp_settings[
                    'lambda_reg'] * losses.tot_variation_2d(mix_coeff)

                loss = rec_loss + smt_loss

                # Optimize for reconstruction & smoothness
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

                # Optimize with discriminator
                # Remove silent frames
                c_loud_x = (10. * (vox_tr_batch.norm(
                    2., dim=1, keepdim=True).log10())).data.cpu().numpy()
                # Keep only the segments above the loudness threshold
                loud_locs = np.where(
                    c_loud_x > exp_settings['loudness_threshold'])[0]
                vox_coeff = vox_coeff[loud_locs]
                if vox_coeff.size(0) > 2:
                    # Make sure we are getting unmatched pairs
                    bkg_coeff = bkg_coeff[loud_locs]
                    vox_coeff_shf = vox_coeff[np.random.permutation(
                        vox_coeff.size(0))]

                    # Sample from discriminator
                    y_neg = sigmoid(disc.forward(vox_coeff, bkg_coeff))
                    y_pos = sigmoid(disc.forward(vox_coeff, vox_coeff_shf))

                    # Compute discriminator loss
                    disc_loss = losses.bce(y_pos, y_neg)

                    # Optimize the discriminator
                    optimizer.zero_grad()
                    disc_loss.backward()
                    optimizer.step()


                if exp_settings['visualize']:
                    # Visualization
                    win_viz = visualize.viz.line(
                        X=np.arange(batch_indx, batch_indx + 1),
                        Y=np.reshape(rec_loss.item(), (1, )),
                        win=win_viz,
                        update='append')
                    win_viz_b = visualize.viz.line(
                        X=np.arange(batch_indx, batch_indx + 1),
                        Y=np.reshape(disc_loss.item(), (1, )),
                        win=win_viz_b,
                        update='append')
                    batch_indx += 1

        if not torch.isnan(loss) and not torch.isnan(disc_loss):
            print('--- Saving Model ---')
            torch.save(
                analysis.state_dict(),
                'results/analysis_' + exp_settings['exp_id'] + '.pytorch')
            torch.save(
                synthesis.state_dict(),
                'results/synthesis_' + exp_settings['exp_id'] + '.pytorch')
            torch.save(disc.state_dict(),
                       'results/disc_' + exp_settings['exp_id'] + '.pytorch')
        else:
            break

    return None
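# `losses.bce` above is not shown in this listing. Since both discriminator
# outputs are passed through a Sigmoid before the call, a plausible sketch is
# a plain binary cross-entropy that pushes the matched pairs (y_pos) towards 1
# and the unmatched pairs (y_neg) towards 0; the clamping epsilon and the mean
# reduction are assumptions made for illustration.
import torch


def bce(y_pos, y_neg, eps=1e-7):
    """Binary cross-entropy over already-sigmoided discriminator outputs."""
    y_pos = y_pos.clamp(eps, 1. - eps)
    y_neg = y_neg.clamp(eps, 1. - eps)
    return -(torch.log(y_pos).mean() + torch.log(1. - y_neg).mean())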