def random_pred(model_list=('PMTL',),
                n_samp=2,
                min_length=1.0,
                fld=test_fld,
                psongs=test_psongs):
    """Run each listed system on random dataset samples and save the results.

    For every sample yielded by ``DL.nus_samp`` (drawn with
    ``randomize=True``), the following files are written to
    ``output/random_predictions/`` (``<i>`` is the sample index):

    * ``inp_lgstft<i>.npy`` / ``out_lgstft<i>.npy``: ``data[0]`` and
      ``data[1]`` (input and target log-STFTs).
    * ``original_speech_<i>.wav``, ``stretched_speech_<i>.wav`` and
      ``true_singing_<i>.wav``: ``core.istft`` of ``data[4][0]``,
      ``data[5][0]`` and ``data[6][0]``.
    * For every model, ``pred_lgstft<i>.npy`` (clipped prediction) and
      ``predicted_singing_<i><model>.wav`` (reconstructed with
      ``utils.gl_rec``, presumably Griffin-Lim).

    Args:
        model_list: system names; each must be a key of ``suffix_dict``.
            'B1'/'B2' are matched case-insensitively. 'B2' uses
            ``defModel.net_base`` as network1, and every other name uses
            ``defModel.net_in_v2``. For 'B1', the ``exp_net`` weights are not
            loaded and its output is zeroed before being passed to network1.
        n_samp: number of samples requested from the sampler.
        min_length: minimum length of each sample (sampler ``min_len``).
        fld, psongs: forwarded to ``DL.nus_samp`` to select the data.
    """
    import os

    saving_dir = 'output/random_predictions/'
    # np.save / write_wav fail if the target directory does not exist.
    if not os.path.isdir(saving_dir):
        os.makedirs(saving_dir)

    nus_train_data = DL.NUS_48E(data_key, [sr, nfft, wlen, hop])
    sampler = DL.nus_samp(data_dir,
                          1,
                          n_samp,
                          fld,
                          psongs,
                          use_word=True,
                          randomize=True,
                          print_elem=True,
                          min_len=min_length)
    dataload = DataLoader(dataset=nus_train_data,
                          batch_sampler=sampler,
                          collate_fn=my_collate_e8)

    for samp_idx, data in enumerate(dataload):
        print('Processing sample %d' % samp_idx)
        tag = str(samp_idx)

        # Model-independent outputs: written once per sample rather than
        # once per (sample, model) pair.
        np.save(saving_dir + 'inp_lgstft' + tag, data[0].numpy())
        np.save(saving_dir + 'out_lgstft' + tag, data[1].numpy())
        librosa.output.write_wav(
            saving_dir + 'original_speech_' + tag + '.wav',
            core.istft(data[4][0], hop, wlen), sr)
        librosa.output.write_wav(
            saving_dir + 'stretched_speech_' + tag + '.wav',
            core.istft(data[5][0], hop, wlen), sr)
        librosa.output.write_wav(
            saving_dir + 'true_singing_' + tag + '.wav',
            core.istft(data[6][0], hop, wlen), sr)

        for cur_model in model_list:
            # Compare baseline names case-insensitively so that 'b1' behaves
            # like 'B1' everywhere. Previously 'b1' skipped loading net2's
            # weights but still fed its randomly initialised output to net1.
            model_key = cur_model.upper()
            suffix = suffix_dict[cur_model]
            network2 = defModel.exp_net(512, 512, freq=513).to(device)
            if model_key == 'B2':
                network1 = defModel.net_base(512, 512, freq=513).to(device)
            else:
                network1 = defModel.net_in_v2(512, 512, freq=513).to(device)

            if model_key != 'B1':
                network2.load_state_dict(
                    torch.load('output/models/net2_' + suffix + '.pt',
                               map_location=device))
            network1.load_state_dict(
                torch.load('output/models/net1_' + suffix + '.pt',
                           map_location=device))
            network1, network2 = network1.eval(), network2.eval()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                encode2 = int(model_key != 'B1') * network2(
                    Variable(data[3].to(device)))
                pred, _ = network1(Variable(data[0].to(device)), encode2)
            pred = pred.cpu().data.numpy()
            pred[pred < 0] = 0

            # NOTE(review): this file name does not include cur_model, so
            # with several models only the last model's log-STFT survives.
            np.save(saving_dir + 'pred_lgstft' + tag, pred)

            # exp(x) - 1 presumably inverts a log(1 + |S|) compression.
            # Zero-pad up to 513 frequency bins.
            stft_pred = np.zeros([513, pred.shape[2]])
            stft_pred[:pred.shape[1]] = np.exp(pred[0]) - 1
            time_pred = utils.gl_rec(stft_pred, hop, wlen,
                                     core.istft(stft_pred, hop, wlen))
            librosa.output.write_wav(
                saving_dir + 'predicted_singing_' + tag + cur_model + '.wav',
                time_pred, sr)

    return
def eval_sys(model_list=('PMTL', 'PMSE', 'B1', 'B2'),
             n_samp=30,
             min_length=1.0,
             random=True,
             fld=test_fld,
             psongs=test_psongs):
    """Evaluate systems on NUS-48E samples by log-spectral distance (LSD).

    By default the songs come from ``test_fld`` / ``test_psongs`` (the test
    set). For every sample and every system, the target (``istft`` of
    ``data[6][0]``) and the prediction are written to ``runtime_folder/``
    as WAV files and compared with ``utils.comp_lsd``. The mean LSD of each
    system is printed at the end.

    Args:
        model_list: system names (keys of ``suffix_dict``). 'B1'/'B2' are
            matched case-insensitively. 'B2' uses ``defModel.net_base`` as
            network1. For 'B1', the ``exp_net`` weights are not loaded and
            its output is zeroed.
        n_samp: number of samples requested from the sampler.
        min_length: minimum speech length of each sample (sampler
            ``min_len``).
        random: whether the sampler randomizes its choice of samples. This
            name shadows the ``random`` module inside this function; it is
            kept for keyword callers.
        fld, psongs: forwarded to ``DL.nus_samp``.

    Returns:
        list: every computed LSD, ordered sample by sample and, within each
        sample, in ``model_list`` order.
    """
    import os

    runtime_dir = 'runtime_folder/'
    # np.save / write_wav fail if the scratch directory does not exist.
    if not os.path.isdir(runtime_dir):
        os.makedirs(runtime_dir)
    true_file = runtime_dir + 'runtime_true.wav'
    pred_file = runtime_dir + 'runtime_pred.wav'

    nus_train_data = DL.NUS_48E(data_key, [sr, nfft, wlen, hop])
    sampler = DL.nus_samp(data_dir,
                          1,
                          n_samp,
                          fld,
                          psongs,
                          use_word=True,
                          randomize=random,
                          print_elem=False,
                          min_len=min_length)
    dataload = DataLoader(dataset=nus_train_data,
                          batch_sampler=sampler,
                          collate_fn=my_collate_e8)

    lsd = []
    for samp_idx, data in enumerate(dataload):
        # Same text as the former py2 `print 'Processing sample ', samp_idx`
        # (which emitted two spaces).
        print('Processing sample  %d' % samp_idx)
        # The target waveform does not depend on the model being evaluated.
        time_target_phase = core.istft(data[6][0], hop, wlen)

        for cur_model in model_list:
            model_key = cur_model.upper()
            suffix = suffix_dict[cur_model]
            network2 = defModel.exp_net(512, 512, freq=513).to(device)
            if model_key == 'B2':
                network1 = defModel.net_base(512, 512, freq=513).to(device)
            else:
                network1 = defModel.net_in_v2(512, 512, freq=513).to(device)

            if model_key != 'B1':
                network2.load_state_dict(
                    torch.load('output/models/net2_' + suffix + '.pt',
                               map_location=device))
            network1.load_state_dict(
                torch.load('output/models/net1_' + suffix + '.pt',
                           map_location=device))
            network1, network2 = network1.eval(), network2.eval()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                encode2 = int(model_key != 'B1') * network2(
                    Variable(data[3].to(device)))
                pred, _ = network1(Variable(data[0].to(device)), encode2)
            pred = pred.cpu().data.numpy()
            pred[pred < 0] = 0

            # Temporarily save log-STFTs of input, target and prediction.
            np.save(runtime_dir + 'inp_stft', data[0].numpy())
            np.save(runtime_dir + 'out_stft', data[1].numpy())
            np.save(runtime_dir + 'pred_stft', pred)

            # Undo the exp/log compression and zero-pad to 513 bins.
            stft_pred = np.zeros([513, pred.shape[2]])
            stft_pred[:pred.shape[1]] = np.exp(pred[0]) - 1
            time_pred = utils.gl_rec(stft_pred, hop, wlen,
                                     core.istft(stft_pred, hop, wlen))

            # comp_lsd compares files, so round-trip through the scratch dir.
            librosa.output.write_wav(true_file, time_target_phase, sr)
            librosa.output.write_wav(pred_file, time_pred, sr)
            lsd.append(utils.comp_lsd(true_file, pred_file))

    # lsd is sample-major. Reshaping to (samples, models) and transposing
    # gives one row per model. The old fixed (n_models, n_samp) table raised
    # IndexError whenever the sampler yielded fewer than n_samp samples, so
    # this uses the number of samples actually processed.
    n_models = len(model_list)
    if n_models:
        per_model = np.asarray(lsd, dtype=float).reshape(-1, n_models).T
        for name, scores in zip(model_list, per_model):
            print('%s (mean LSD): %s' % (name, np.mean(scores)))

    return lsd


if __name__ == '__main__':
    # Command-line arguments (collected here, not consumed below).
    args = sys.argv[1:]

    # Instantiate the PMTL networks (net1 first, then net2) and restore
    # their trained weights from output/models/, net2 before net1.
    suffix = suffix_dict['PMTL']
    net1, net2 = (defModel.net_in_v2(512, 512, freq=513).to(device),
                  defModel.exp_net(512, 512, freq=513).to(device))
    for net, prefix in ((net2, 'net2_'), (net1, 'net1_')):
        net.load_state_dict(
            torch.load('output/models/' + prefix + suffix + '.pt',
                       map_location=device))

    # Save predictions of PMTL and B2 on random samples, then run the LSD
    # evaluation of the default system list.
    random_pred(['pmtl', 'b2'])
    stats = eval_sys()