Ejemplo n.º 1
0
def testing_process():
    """Run the testing process for a trained MaDConv model.

    Reads the model configuration from the module-level `args`
    (`layers`, `channels`, `do`, `residual`), builds a `MaDConv`, restores
    its weights from a checkpoint whose file name is derived from that
    configuration, and evaluates it on every example yielded by the testing
    data feeder. The median SDR/SIR/SAR over all non-NaN values and the
    total processing time are printed, and the per-example SDR, SIR and SAR
    results are pickled to the paths in `metrics_paths`.

    :raises ValueError: If the testing data feeder yields no examples.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'
    torch.backends.cudnn.benchmark = True

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    latent_n = args.layers
    channels = args.channels
    # `args.do` holds the dropout as a percentage; the model takes a fraction.
    dropout = args.do / 100
    residual_tag = "res" if args.residual else ""

    # Printed both before and after testing so the configuration is visible
    # next to the results in long logs.
    config_summary = (
        "Using " + str(latent_n) + " latent layers between encoder and decoder.",
        "Using " + str(channels) + " cnn channels.",
        "Using " + str(dropout) + " dropout.",
    )
    for line in config_summary:
        print(line, flush=True)
    print("Using residual connections" if args.residual
          else "Not using residual connections")

    outputPath = (str(latent_n) + str(channels) + str(args.do)
                  + residual_tag + _dataset_parent_dir[7:])
    print("Output path: " + outputPath)

    # The checkpoint name embeds the configuration and a hard-coded
    # "90epochs" suffix.
    pt_path = ("./outputs/states/mad.pt" + str(latent_n) + "latents"
               + str(channels) + "features" + str(dropout) + "dropout"
               + residual_tag + str(90) + "epochs")
    print("Weights path: " + pt_path, flush=True)

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaDConv(
            cnn_channels=channels,
            inner_kernel_size=1,
            inner_padding=0,
            cnn_dropout=dropout,
            original_input_dim=hyper_parameters['original_input_dim'],
            context_length=hyper_parameters['context_length'],
            latent_n=latent_n,
            residual=args.residual
        )

    with printing.InformAboutProcess('Loading states'):
        # NOTE(review): assumes `load` is `torch.load`. `map_location` lets a
        # checkpoint saved on a GPU be restored when running on the CPU.
        mad.load_state_dict(load(pt_path, map_location=device))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1, debug=debug)

    p_testing = partial(
        _testing_process, mad=mad, device=device,
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        window_size=hyper_parameters['window_size'],
        batch_size=4,  # training_constants['batch_size'],
        hop_size=hyper_parameters['hop_size'], outputPath=outputPath)

    printing.print_msg('Testing starts', end='\n\n')

    # One (sdr, sir, sar, time) tuple per testing example.
    results = [p_testing(data, index)
               for index, data in enumerate(testing_it())]
    if not results:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError('The testing data feeder yielded no examples.')

    sdr, sir, sar, total_time = zip(*results)
    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')
    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        sar=np.median([ii for i in sar for ii in i[0] if not np.isnan(ii)]),
        t=total_time), end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        for metric_name, metric_values in (('sdr', sdr),
                                           ('sir', sir),
                                           ('sar', sar)):
            with open(metrics_paths[metric_name], 'wb') as f:
                pickle.dump(metric_values, f, protocol=2)

    for line in config_summary:
        print(line, flush=True)

    printing.print_msg('That\'s all folks!')
Ejemplo n.º 2
0
def testing_process():
    """Run the testing process with the separately stored modules.

    Builds the RNN encoder, RNN decoder, FNN masker and FNN denoiser,
    restores each one from its file in `output_states_path`, and passes
    every example yielded by the testing data feeder through the chain
    encoder -> decoder -> masker -> denoiser. The median SDR/SIR (NaN
    values ignored) is printed per example and over all examples, and the
    collected SDR and SIR results are pickled to the paths in
    `metrics_paths`.
    """

    device = 'cuda' if not debug and torch.cuda.is_available() else 'cpu'

    print('\n-- Starting testing process. Debug mode: {}'.format(debug))
    print('-- Process on: {}'.format(device), end='\n\n')
    print('-- Setting up modules... ', end='')

    # Masker modules
    rnn_enc = RNNEnc(hyper_parameters['reduced_dim'],
                     hyper_parameters['context_length'], debug)
    rnn_dec = RNNDec(hyper_parameters['rnn_enc_output_dim'], debug)
    fnn = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                    hyper_parameters['original_input_dim'],
                    hyper_parameters['context_length'])

    # Denoiser modules
    denoiser = FNNDenoiser(hyper_parameters['original_input_dim'])

    # `load_state_dict` does not return the module, so the device transfer
    # must be called on the module itself rather than chained onto the
    # loading call.
    for module, state_key in ((rnn_enc, 'rnn_enc'),
                              (rnn_dec, 'rnn_dec'),
                              (fnn, 'fnn'),
                              (denoiser, 'denoiser')):
        module.load_state_dict(torch.load(
            output_states_path[state_key], map_location=device))
        module.to(device)
        # Inference only: switch off training-time behaviour.
        module.eval()

    print('done.')

    testing_it = data_feeder_testing(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=1,
        debug=debug)

    print('-- Testing starts\n')

    batch_size = training_constants['batch_size']
    # Number of frames left once the context on both sides is removed.
    predicted_length = (hyper_parameters['seq_length'] -
                        hyper_parameters['context_length'] * 2)

    sdr = []
    sir = []
    total_time = 0

    for index, data in enumerate(testing_it()):

        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data

        voice_predicted = np.zeros(
            (mix_magnitude.shape[0], predicted_length,
             hyper_parameters['window_size']),
            dtype=np.float32)

        # NOTE(review): only full batches are processed; sequences beyond
        # the last full batch keep their zero initialisation. Confirm this
        # is intended (e.g. that the feeder pads to a multiple of the
        # batch size).
        n_batches = mix_magnitude.shape[0] // batch_size

        # No gradients are needed at test time.
        with torch.no_grad():
            for batch in range(n_batches):
                b_start = batch * batch_size
                b_end = (batch + 1) * batch_size

                v_in = torch.from_numpy(
                    mix_magnitude[b_start:b_end, :, :]).to(device)

                tmp_voice_predicted = rnn_enc(v_in)
                tmp_voice_predicted = rnn_dec(tmp_voice_predicted)
                tmp_voice_predicted = fnn(tmp_voice_predicted, v_in)
                tmp_voice_predicted = denoiser(tmp_voice_predicted)

                voice_predicted[b_start:b_end, :, :] = \
                    tmp_voice_predicted.detach().cpu().numpy()

        tmp_sdr, tmp_sir = data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=hyper_parameters['window_size'],
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hyper_parameters['hop_size'],
            context_length=hyper_parameters['context_length'])

        e_time = time.time()

        print(
            testing_output_string_per_example.format(
                e=index,
                sdr=np.median([i for i in tmp_sdr[0] if not np.isnan(i)]),
                sir=np.median([i for i in tmp_sir[0] if not np.isnan(i)]),
                t=e_time - s_time))

        total_time += e_time - s_time

        sdr.append(tmp_sdr)
        sir.append(tmp_sir)

    print('\n-- Testing finished\n')
    print(
        testing_output_string_all.format(
            sdr=np.median([ii for i in sdr for ii in i[0]
                           if not np.isnan(ii)]),
            sir=np.median([ii for i in sir for ii in i[0]
                           if not np.isnan(ii)]),
            t=total_time))

    print('\n-- Saving results... ', end='')

    with open(metrics_paths['sdr'], 'wb') as f:
        pickle.dump(sdr, f, protocol=2)

    with open(metrics_paths['sir'], 'wb') as f:
        pickle.dump(sir, f, protocol=2)

    print('done!')
    print('-- That\'s all folks!')
Ejemplo n.º 3
0
def testing_process():
    """Run the testing process for a trained MaD model.

    Builds a `MaD` model from `hyper_parameters`, restores its weights from
    `output_states_path['mad']`, and evaluates it on every example yielded
    by the testing data feeder. The median SDR/SIR over all non-NaN values
    and the total processing time are printed, and the per-example SDR and
    SIR results are pickled to the paths in `metrics_paths`.

    :raises ValueError: If the testing data feeder yields no examples.
    """
    # Check what device we'll be using
    device = 'cuda' if not debug and cuda.is_available() else 'cpu'

    # Inform about the device and time and date
    printing.print_intro_messages(device)
    printing.print_msg('Starting testing process. '
                       'Debug mode: {}'.format(debug))

    # Set up MaD TwinNet
    with printing.InformAboutProcess('Setting up MaD TwinNet'):
        mad = MaD(rnn_enc_input_dim=hyper_parameters['reduced_dim'],
                  rnn_dec_input_dim=hyper_parameters['rnn_enc_output_dim'],
                  original_input_dim=hyper_parameters['original_input_dim'],
                  context_length=hyper_parameters['context_length'])

    with printing.InformAboutProcess('Loading states'):
        # NOTE(review): assumes `load` is `torch.load`. `map_location` lets a
        # checkpoint saved on a GPU be restored when running on the CPU.
        mad.load_state_dict(
            load(output_states_path['mad'], map_location=device))
        mad = mad.to(device).eval()

    with printing.InformAboutProcess('Initializing data feeder'):
        testing_it = data_feeder.data_feeder_testing(
            window_size=hyper_parameters['window_size'],
            fft_size=hyper_parameters['fft_size'],
            hop_size=hyper_parameters['hop_size'],
            seq_length=hyper_parameters['seq_length'],
            context_length=hyper_parameters['context_length'],
            batch_size=1,
            debug=debug)

    p_testing = partial(_testing_process,
                        mad=mad,
                        device=device,
                        seq_length=hyper_parameters['seq_length'],
                        context_length=hyper_parameters['context_length'],
                        window_size=hyper_parameters['window_size'],
                        batch_size=training_constants['batch_size'],
                        hop_size=hyper_parameters['hop_size'])

    printing.print_msg('Testing starts', end='\n\n')

    # One (sdr, sir, time) tuple per testing example.
    results = [p_testing(data, index)
               for index, data in enumerate(testing_it())]
    if not results:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError('The testing data feeder yielded no examples.')

    sdr, sir, total_time = zip(*results)
    total_time = sum(total_time)

    printing.print_msg('Testing finished', start='\n-- ', end='\n\n')
    printing.print_msg(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        t=total_time),
                       end='\n\n')

    with printing.InformAboutProcess('Saving results... '):
        for metric_name, metric_values in (('sdr', sdr), ('sir', sir)):
            with open(metrics_paths[metric_name], 'wb') as f:
                pickle.dump(metric_values, f, protocol=2)

    printing.print_msg('That\'s all folks!')