def training_process():
    """The training process.

    Builds the masker (RNN encoder, RNN decoder, FNN masker), the denoiser
    and the TwinNet regularization modules (twin RNN decoder, twin FNN
    masker, affine transform), trains all of them jointly with a single
    Adam optimizer on the batches produced by ``data_feeder_training``,
    prints the mean losses of every epoch and finally saves the masker and
    denoiser state dicts to ``output_states_path``. The TwinNet modules are
    used only during training and are not saved.

    Reads the module-level settings ``debug``, ``hyper_parameters``,
    ``training_constants``, ``output_states_path`` and
    ``training_output_string``.
    """
    # Same device policy as the testing process: use the GPU only when not
    # debugging and a CUDA device is actually available. (``torch.has_cudnn``
    # only reports build-time cuDNN support and is deprecated.)
    device = 'cuda' if not debug and torch.cuda.is_available() else 'cpu'

    print('\n-- Starting training process. Debug mode: {}'.format(debug))
    print('-- Setting up modules... ', end='')
    # Masker modules
    rnn_enc = RNNEnc(hyper_parameters['reduced_dim'],
                     hyper_parameters['context_length'], debug).to(device)
    rnn_dec = RNNDec(hyper_parameters['rnn_enc_output_dim'], debug).to(device)
    fnn = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                    hyper_parameters['original_input_dim'],
                    hyper_parameters['context_length']).to(device)

    # Denoiser modules
    # NOTE(review): the inference functions in this file load the saved
    # 'denoiser' state into an FNNDenoiser, and the regularizer below reads
    # ``denoiser.fnn_dec`` -- confirm CNNDenoiser is compatible with both.
    denoiser = CNNDenoiser(hyper_parameters['original_input_dim']).to(device)

    # TwinNet regularization modules
    twin_net_rnn_dec = TwinRNNDec(hyper_parameters['rnn_enc_output_dim'],
                                  debug).to(device)
    twin_net_fnn_masker = FNNMasker(
        hyper_parameters['rnn_enc_output_dim'],
        hyper_parameters['original_input_dim'],
        hyper_parameters['context_length']).to(device)
    affine_transform = AffineTransform(
        hyper_parameters['rnn_enc_output_dim']).to(device)

    print('done.')
    print('-- Setting up optimizes and losses... ', end='')

    # Objectives and penalties
    loss_masker = kl
    loss_denoiser = kl
    loss_twin = kl
    reg_twin = l2_loss
    reg_fnn_masker = sparsity_penalty
    reg_fnn_dec = l2_reg_squared

    # One flat parameter list, shared by the optimizer and by gradient
    # clipping, so both always act on exactly the same parameters.
    parameters = [
        param
        for module in (rnn_enc, rnn_dec, fnn, denoiser, twin_net_rnn_dec,
                       twin_net_fnn_masker, affine_transform)
        for param in module.parameters()]

    # Optimizer
    optimizer = optim.Adam(parameters, lr=hyper_parameters['learning_rate'])

    print('done.')

    # Initializing data feeder
    epoch_it = data_feeder_training(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=training_constants['batch_size'],
        files_per_pass=training_constants['files_per_pass'],
        debug=debug)

    print('-- Training starts\n')

    # Training loop starts
    for epoch in range(training_constants['epochs']):
        epoch_l_m = []
        epoch_l_d = []
        epoch_l_tw = []
        epoch_l_twin = []

        time_start = time.time()

        # Epoch loop
        for data in epoch_it():
            # data[0] feeds the masker; data[1] is the target of every loss.
            v_in = torch.from_numpy(data[0]).to(device)
            v_j = torch.from_numpy(data[1]).to(device)

            # Masker pass
            h_enc = rnn_enc(v_in)
            h_dec = rnn_dec(h_enc)
            v_j_filt_prime = fnn(h_dec, v_in)

            # TwinNet pass
            h_dec_twin = twin_net_rnn_dec(h_enc)
            v_j_filt_prime_twin = twin_net_fnn_masker(h_dec_twin, v_in)

            # Twin net regularization
            affine_output = affine_transform(h_dec)

            # Denoiser pass
            v_j_filt = denoiser(v_j_filt_prime)

            optimizer.zero_grad()

            # Calculate losses
            l_m = loss_masker(v_j_filt_prime, v_j)
            l_d = loss_denoiser(v_j_filt, v_j)
            l_tw = loss_twin(v_j_filt_prime_twin, v_j)
            # The twin hidden states are detached, so this term only pulls
            # the forward decoder (through the affine map) towards them.
            l_twin = reg_twin(affine_output, h_dec_twin.detach())

            # Make MaD TwinNet objective
            loss = l_m + l_d + l_tw + \
                (hyper_parameters['lambda_l_twin'] * l_twin) + \
                (hyper_parameters['lambda_1'] *
                 reg_fnn_masker(fnn.linear_layer.weight)) + \
                (hyper_parameters['lambda_2'] *
                 reg_fnn_dec(denoiser.fnn_dec.weight))

            # Backward pass
            loss.backward()

            # Gradient norm clipping (in-place variant; the non-underscore
            # name is deprecated).
            torch.nn.utils.clip_grad_norm_(
                parameters,
                max_norm=hyper_parameters['max_grad_norm'],
                norm_type=2)

            # Optimize
            optimizer.step()

            # Log losses. ``.item()`` returns a Python float; indexing a
            # 0-dim loss tensor with ``.data[0]`` fails on current PyTorch.
            epoch_l_m.append(l_m.item())
            epoch_l_d.append(l_d.item())
            epoch_l_tw.append(l_tw.item())
            epoch_l_twin.append(l_twin.item())

        time_end = time.time()

        # Tell us what happened
        print(
            training_output_string.format(
                ep=epoch,
                l_m=torch.mean(torch.FloatTensor(epoch_l_m)),
                l_d=torch.mean(torch.FloatTensor(epoch_l_d)),
                l_tw=torch.mean(torch.FloatTensor(epoch_l_tw)),
                l_twin=torch.mean(torch.FloatTensor(epoch_l_twin)),
                t=time_end - time_start))

    # Kindly end and save the model
    print('\n-- Training done.')
    print('-- Saving model.. ', end='')
    torch.save(rnn_enc.state_dict(), output_states_path['rnn_enc'])
    torch.save(rnn_dec.state_dict(), output_states_path['rnn_dec'])
    torch.save(fnn.state_dict(), output_states_path['fnn'])
    torch.save(denoiser.state_dict(), output_states_path['denoiser'])
    print('done.')
    print('-- That\'s all folks!')
# Example #2
# 0
def use_me_process(sources_list, output_file_names):
    """The usage process.

    Loads the trained masker (RNN encoder, RNN decoder, FNN masker) and the
    FNN denoiser from ``output_states_path``, predicts the voice magnitude
    for every file yielded by ``data_feeder_testing`` for ``sources_list``,
    and hands each prediction to ``data_process_results_testing`` together
    with the matching entry of ``output_file_names``. Exits with status -1
    when the module-level ``debug`` flag is set.

    :param sources_list: The file names to be used.
    :type sources_list: list[str]
    :param output_file_names: The output file names to be used.
    :type output_file_names: list[list[str]]
    """

    print('\n-- Welcome to MaD TwinNet.')
    if debug:
        print(
            '\n-- Cannot proceed in debug mode. Please set debug=False at the settings file.'
        )
        print('-- Exiting.')
        exit(-1)
    print(
        '-- Now I will extract the voice and the background music from the provided files'
    )

    # Use the GPU only when a CUDA device is actually available
    # (``torch.has_cudnn`` only reports build-time cuDNN support).
    device = 'cuda' if not debug and torch.cuda.is_available() else 'cpu'

    # Masker modules
    rnn_enc = RNNEnc(hyper_parameters['reduced_dim'],
                     hyper_parameters['context_length'], debug)
    rnn_dec = RNNDec(hyper_parameters['rnn_enc_output_dim'], debug)
    fnn = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                    hyper_parameters['original_input_dim'],
                    hyper_parameters['context_length'])

    # Denoiser modules
    denoiser = FNNDenoiser(hyper_parameters['original_input_dim'])

    # ``map_location`` lets weights saved from a GPU run be restored on a
    # CPU-only machine; the modules are then moved to the chosen device.
    for module, state_key in ((rnn_enc, 'rnn_enc'), (rnn_dec, 'rnn_dec'),
                              (fnn, 'fnn'), (denoiser, 'denoiser')):
        module.load_state_dict(
            torch.load(output_states_path[state_key], map_location=device))
        module.to(device)

    testing_it = data_feeder_testing(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=1,
        debug=debug,
        sources_list=sources_list)

    print('-- Let\'s go!\n')
    total_time = 0
    batch_size = training_constants['batch_size']

    for index, data in enumerate(testing_it()):

        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data
        n_sequences = mix_magnitude.shape[0]

        voice_predicted = np.zeros(
            (n_sequences, hyper_parameters['seq_length'] -
             hyper_parameters['context_length'] * 2,
             hyper_parameters['window_size']),
            dtype=np.float32)

        # Inference only, so no autograd graph is built. Stepping by
        # ``batch_size`` also processes a trailing partial batch, so no
        # sequence is left as zeros in ``voice_predicted``.
        with torch.no_grad():
            for b_start in range(0, n_sequences, batch_size):
                b_end = b_start + batch_size  # slicing clamps to n_sequences

                v_in = torch.from_numpy(
                    mix_magnitude[b_start:b_end, :, :]).to(device)

                tmp_voice_predicted = rnn_enc(v_in)
                tmp_voice_predicted = rnn_dec(tmp_voice_predicted)
                tmp_voice_predicted = fnn(tmp_voice_predicted, v_in)
                tmp_voice_predicted = denoiser(tmp_voice_predicted)

                voice_predicted[b_start:b_end, :, :] = \
                    tmp_voice_predicted.cpu().numpy()

        data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=hyper_parameters['window_size'],
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hyper_parameters['hop_size'],
            context_length=hyper_parameters['context_length'],
            output_file_name=output_file_names[index])

        e_time = time.time()

        print(
            usage_output_string_per_example.format(f=sources_list[index],
                                                   t=e_time - s_time))

        total_time += e_time - s_time

    print('\n-- Testing finished\n')
    print(usage_output_string_total.format(t=total_time))
    print('-- That\'s all folks!')
# Example #3
# 0
def testing_process():
    """The testing process.

    Loads the trained masker (RNN encoder, RNN decoder, FNN masker) and the
    FNN denoiser from ``output_states_path``, predicts the voice for every
    example yielded by ``data_feeder_testing``, evaluates each prediction
    with ``data_process_results_testing``, and prints per-example and
    overall median SDR/SIR, skipping NaN entries. The collected SDR and SIR
    results are pickled to ``metrics_paths['sdr']`` and
    ``metrics_paths['sir']``.
    """

    device = 'cuda' if not debug and torch.cuda.is_available() else 'cpu'

    print('\n-- Starting testing process. Debug mode: {}'.format(debug))
    print('-- Process on: {}'.format(device), end='\n\n')
    print('-- Setting up modules... ', end='')

    # Masker modules
    rnn_enc = RNNEnc(hyper_parameters['reduced_dim'],
                     hyper_parameters['context_length'], debug)
    rnn_dec = RNNDec(hyper_parameters['rnn_enc_output_dim'], debug)
    fnn = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                    hyper_parameters['original_input_dim'],
                    hyper_parameters['context_length'])

    # Denoiser modules
    denoiser = FNNDenoiser(hyper_parameters['original_input_dim'])

    # ``load_state_dict`` does not return the module, so moving to the
    # device must be a separate call. ``map_location`` lets weights saved
    # from a GPU run be restored on a CPU-only machine.
    for module, state_key in ((rnn_enc, 'rnn_enc'), (rnn_dec, 'rnn_dec'),
                              (fnn, 'fnn'), (denoiser, 'denoiser')):
        module.load_state_dict(
            torch.load(output_states_path[state_key], map_location=device))
        module.to(device)

    # NOTE(review): the modules stay in their default (training) mode. If
    # any of them uses dropout or batch norm, call ``.eval()`` here --
    # confirm against the module definitions.

    print('done.')

    testing_it = data_feeder_testing(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=1,
        debug=debug)

    print('-- Testing starts\n')

    sdr = []
    sir = []
    total_time = 0
    batch_size = training_constants['batch_size']

    for index, data in enumerate(testing_it()):

        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data
        n_sequences = mix_magnitude.shape[0]

        voice_predicted = np.zeros(
            (n_sequences, hyper_parameters['seq_length'] -
             hyper_parameters['context_length'] * 2,
             hyper_parameters['window_size']),
            dtype=np.float32)

        # Inference only, so no autograd graph is built. Stepping by
        # ``batch_size`` also processes a trailing partial batch, so no
        # sequence is left as zeros in ``voice_predicted``.
        with torch.no_grad():
            for b_start in range(0, n_sequences, batch_size):
                b_end = b_start + batch_size  # slicing clamps to n_sequences

                v_in = torch.from_numpy(
                    mix_magnitude[b_start:b_end, :, :]).to(device)

                tmp_voice_predicted = rnn_enc(v_in)
                tmp_voice_predicted = rnn_dec(tmp_voice_predicted)
                tmp_voice_predicted = fnn(tmp_voice_predicted, v_in)
                tmp_voice_predicted = denoiser(tmp_voice_predicted)

                voice_predicted[b_start:b_end, :, :] = \
                    tmp_voice_predicted.cpu().numpy()

        tmp_sdr, tmp_sir = data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=hyper_parameters['window_size'],
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hyper_parameters['hop_size'],
            context_length=hyper_parameters['context_length'])

        e_time = time.time()

        print(
            testing_output_string_per_example.format(
                e=index,
                sdr=np.median([i for i in tmp_sdr[0] if not np.isnan(i)]),
                sir=np.median([i for i in tmp_sir[0] if not np.isnan(i)]),
                t=e_time - s_time))

        total_time += e_time - s_time

        sdr.append(tmp_sdr)
        sir.append(tmp_sir)

    print('\n-- Testing finished\n')
    print(
        testing_output_string_all.format(
            sdr=np.median([ii for i in sdr for ii in i[0]
                           if not np.isnan(ii)]),
            sir=np.median([ii for i in sir for ii in i[0]
                           if not np.isnan(ii)]),
            t=total_time))

    print('\n-- Saving results... ', end='')

    # protocol=2 keeps the pickles readable from Python 2 as well.
    with open(metrics_paths['sdr'], 'wb') as f:
        pickle.dump(sdr, f, protocol=2)

    with open(metrics_paths['sir'], 'wb') as f:
        pickle.dump(sir, f, protocol=2)

    print('done!')
    print('-- That\'s all folks!')