Example #1
def train(opt):
    '''
    TODO: the data loading below could be consolidated into one dictionary
    with 'train'/'val'/'test' keys.
    '''
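    # A sketch of the layout the docstring suggests (hypothetical keys, not
    # what the code below does yet):
    #   data = {'train': train_prompts, 'val': val_prompts, 'test': test_prompts}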

    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_lw' + str(int(opt.load_wts)) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)

    logging.basicConfig(filename=opt.log_folder + full_model_name + '.log',
                        filemode='w',
                        level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    r = opt.r

    fid = open(opt.feats_dir + 'train_list.txt')
    train_list = fid.read().splitlines()
    fid.close()

    fid = open(opt.feats_dir + 'val_list.txt')
    val_list = fid.read().splitlines()
    fid.close()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                         train_list)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    print(len(file_list), len(train_list))

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy')
    so1 = np.load(opt.stats_dir + 'so.npy')
    mo1 = mo1.astype('float32')
    so1 = so1.astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy')
    mo2 = mo2.astype('float32')
    so2 = so2.astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
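    # size_average=False sums the L1 error over all elements
    # (the equivalent of reduction='sum' in newer PyTorch)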
    criterion = torch.nn.L1Loss(size_average=False)

    if opt.load_wts:
        load_model_name_pfx = (
            '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss_'
            '_hs1250_hs2500_pfnet_hs1250_r' + str(opt.r) +
            '_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_lw0_ef1_rf1_')
        load_model_name_sfx = '.pth'

        # load model
        enc_state_dict = torch.load(load_model_name_pfx + 'enc' +
                                    load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        encoder.load_state_dict(enc_state_dict)

        dec_state_dict = torch.load(load_model_name_pfx + 'dec' +
                                    load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        decoder.load_state_dict(dec_state_dict)

        pfnet_state_dict = torch.load(
            load_model_name_pfx + 'pfnet' + load_model_name_sfx,
            map_location=lambda storage, loc: storage)
        pfnet.load_state_dict(pfnet_state_dict)

    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    pfnet_optimizer = optim.Adam(pfnet.parameters(),
                                 lr=opt.lr,
                                 betas=(opt.b1, opt.b2))

    start = time.time()
    print_loss_total = 0  # Reset every print_every
    best_val_loss = sys.maxsize

    for iter in range(1, opt.niter + 1):

        if iter == 3:
            opt.lr = opt.lr / 10

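            # Note: recreating the optimizers resets Adam's running moment
            # estimates; only the lowered learning rate carries over.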
            encoder_optimizer = optim.Adam(encoder.parameters(),
                                           lr=opt.lr,
                                           betas=(opt.b1, opt.b2))
            decoder_optimizer = optim.Adam(decoder.parameters(),
                                           lr=opt.lr,
                                           betas=(opt.b1, opt.b2))
            pfnet_optimizer = optim.Adam(pfnet.parameters(),
                                         lr=opt.lr,
                                         betas=(opt.b1, opt.b2))

        for j, k in enumerate(train_prompts):

            [input_variable, input_length] = get_x(train_prompts, k, phn2id,
                                                   use_cuda)

            train_targets, train_seq_len = load_targets(
                opt.feats_dir + '/fb/', [k], '.npy', dtype, mo1, so1, nml_vec1)

            [target_variable, target_variable2,
             target_length] = get_y(train_seq_len, 0, train_targets, use_cuda,
                                    r)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_outputs = Variable(
                torch.zeros(input_length, encoder.hidden_size2))
            encoder_outputs = encoder_outputs.cuda(
            ) if use_cuda else encoder_outputs

            encoder_output, (encoder_hn,
                             encoder_cn) = encoder(input_variable,
                                                   (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1, op_dim))  # all-zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input
            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()
            decoder_output_half = Variable(torch.zeros(target_length, r * op_dim))
            decoder_output_half = decoder_output_half.cuda() if use_cuda else decoder_output_half
            decoder_output_full = Variable(torch.zeros(r * target_length, op_dim))
            decoder_output_full = decoder_output_full.cuda() if use_cuda else decoder_output_full

            # Teacher forcing: Feed the target as the next input
            for di in range(target_length):
                decoder_output1, decoder_output2, decoder_h1, decoder_c1, decoder_h2, decoder_c2, decoder_h3, decoder_c3, decoder_attention = decoder(
                    decoder_input, decoder_h1, decoder_c1, decoder_h2,
                    decoder_c2, decoder_h3, decoder_c3, encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(
                    0)  # Teacher forcing
                decoder_output_half[di] = decoder_output1

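            # retain_graph=True keeps the decoder graph alive so that
            # loss_pfnet.backward() below can still propagate through the
            # decoder outputs feeding the post-filtering net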
            loss.backward(retain_graph=True)
            encoder_optimizer.step()
            decoder_optimizer.step()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Start Post-Filtering Net
            pfnet_optimizer.zero_grad()

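            # De-interleave: each decoder step emits r frames stacked along
            # the feature axis; the strided writes below restore frame order.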
            for ix in range(r):
                decoder_output_full[ix::r, :] = \
                    decoder_output_half[:, ix * op_dim:(ix + 1) * op_dim]

            s1 = r * target_length

            train_targets_pfnet, train_seq_len_pfnet = load_targets(
                opt.feats_dir + '/sp/', [k], '.npy', dtype, mo2, so2, nml_vec2)
            targets_pfnet = Variable(train_targets_pfnet).cuda(
            ) if use_cuda else Variable(train_targets_pfnet)
            s2 = targets_pfnet.size()[0]
            if (s2 % r) > 0:
                targets_pfnet = targets_pfnet[:-(s2 % r), :]

            pfnet_h0 = pfnet.initHidden()
            pfnet_c0 = pfnet.initHidden()
            pfnet_outputs = Variable(
                torch.zeros(targets_pfnet.size()[0], pfnet.output_size))
            pfnet_outputs = pfnet_outputs.cuda() if use_cuda else pfnet_outputs

            pfnet_output = pfnet(decoder_output_full, (pfnet_h0, pfnet_c0))
            pfnet_outputs = pfnet_output
            loss_pfnet = criterion(pfnet_outputs, targets_pfnet)

            loss_pfnet.backward()
            pfnet_optimizer.step()
            loss_total = loss + loss_pfnet

            print_loss_total += (loss_total.data[0] / (r * target_length))

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0

                updates_done = iter * len(train_prompts) - len(train_prompts) + j
                total_updates = (opt.niter + 1) * len(train_prompts)
                print('%s (%d %d%%) %.4f' %
                      (timeSince(start, updates_done / total_updates), iter,
                       iter / opt.niter * 100, print_loss_avg))

                tf = True  # teacher forcing
                avg_total_val_loss_tf, avg_dec_val_loss_tf, decoder_attentions_tf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
                    tf, opt, mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' %
                      (iter, avg_total_val_loss_tf, avg_dec_val_loss_tf))

                tf = False  # always sampling
                avg_total_val_loss_as, avg_dec_val_loss_as, decoder_attentions_pf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
                    tf, opt, mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' %
                      (iter, avg_total_val_loss_as, avg_dec_val_loss_as))
                logging.debug(
                    'Epoch: ' + str(iter) + ' Update: ' +
                    str(iter * len(train_prompts) - len(train_prompts) + j) +
                    ' Avg Total Val Loss TF: ' + str(avg_total_val_loss_tf) +
                    ' Avg Total Val Loss AS: ' + str(avg_total_val_loss_as) +
                    ' Avg Dec Val Loss TF: ' + str(avg_dec_val_loss_tf) +
                    ' Avg Dec Val Loss AS: ' + str(avg_dec_val_loss_as))

                if avg_total_val_loss_tf < best_val_loss:
                    best_val_loss = avg_total_val_loss_tf
                    torch.save(
                        encoder.state_dict(),
                        '%s/%s_enc.pth' % (opt.model_folder, full_model_name))
                    torch.save(
                        decoder.state_dict(),
                        '%s/%s_dec.pth' % (opt.model_folder, full_model_name))
                    torch.save(
                        pfnet.state_dict(), '%s/%s_pfnet.pth' %
                        (opt.model_folder, full_model_name))

                encoder.train()
                decoder.train()
                pfnet.train()

            # if (j+1) % showatt_every == 0:

            #    plt.figure(1, figsize=(12, 12))
            #    plt.imshow(decoder_attentions_tf.numpy())
            #    plt.colorbar()
            #    pylab.savefig(opt.plot_folder + full_model_name + '_' + str(j) + '_' + str(iter) + '.png', bbox_inches='tight')
            #    plt.close()

            #    plt.figure(1, figsize=(12, 12))
            #    plt.imshow(decoder_attentions_pf.numpy())
            #    plt.colorbar()
            #    pylab.savefig(opt.plot_folder + full_model_name + '_' + str(j) + '_' + str(iter) + '.png', bbox_inches='tight')
            #    plt.close()

            # if (j+1) % plot_every == 0:
            #    plot_loss_avg = plot_loss_total / plot_every
            #    plot_losses.append(plot_loss_avg)
            #    plot_loss_total = 0

        gc.collect()
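
A minimal, self-contained sketch of the reduction-factor de-interleave used
above, on toy tensors (the shapes are illustrative, not the training ones):

import torch

r, op_dim, target_length = 3, 60, 5
# one row per decoder step, r frames stacked along the feature axis
half = torch.arange(target_length * r * op_dim, dtype=torch.float32)
half = half.view(target_length, r * op_dim)
full = torch.zeros(r * target_length, op_dim)
for ix in range(r):
    full[ix::r, :] = half[:, ix * op_dim:(ix + 1) * op_dim]
# frame t of decoder step d lands at row d * r + t
assert torch.equal(full[0], half[0, :op_dim])
assert torch.equal(full[1], half[0, op_dim:2 * op_dim])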
Example #2
def test(opt):


    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_lw' + str(int(opt.load_wts)) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)
    opt.full_model_name = full_model_name

    try:
        os.makedirs(opt.synth_folder + opt.full_model_name)
        os.makedirs(opt.plot_folder + opt.full_model_name)
    except OSError:
        pass

    fid = open(opt.feats_dir + 'train_list.txt')
    train_list = fid.read().splitlines()
    fid.close()

    fid = open(opt.feats_dir + 'val_list.txt')
    val_list = fid.read().splitlines()
    fid.close()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    print(vocab)

    fid = open(opt.feats_dir + 'test_list.txt')
    val_list = fid.read().splitlines()
    val_list = val_list[:10]  # evaluate on the first 10 test utterances
    fid.close()

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy')
    so1 = np.load(opt.stats_dir + 'so.npy')
    mo1 = mo1.astype('float32')
    so1 = so1.astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy')
    mo2 = mo2.astype('float32')
    so2 = so2.astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    load_model_name_pfx = '../../wt/' + opt.full_model_name + '_'
    load_model_name_sfx = '.pth'

    # load model
    enc_state_dict = torch.load(load_model_name_pfx + 'enc' +
                                load_model_name_sfx,
                                map_location=lambda storage, loc: storage)
    encoder.load_state_dict(enc_state_dict)

    dec_state_dict = torch.load(load_model_name_pfx + 'dec' +
                                load_model_name_sfx,
                                map_location=lambda storage, loc: storage)
    decoder.load_state_dict(dec_state_dict)

    pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' +
                                  load_model_name_sfx,
                                  map_location=lambda storage, loc: storage)
    pfnet.load_state_dict(pfnet_state_dict)

    tf = True  # teacher forcing
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt, mo1, so1,
        nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))

    tf = False  # free running (always sampling, no teacher forcing)
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt, mo1, so1,
        nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))
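
For reference, the map_location lambda used when loading checkpoints above
maps every tensor to CPU regardless of the device it was saved from; in
current PyTorch the string form is equivalent (a sketch with a placeholder
path, not one of the checkpoints named above):

import torch

state_dict = torch.load('checkpoint_enc.pth', map_location='cpu')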
Example #3
def train(opt):
    '''
    TODO: the data loading below could be consolidated into one dictionary
    with 'train'/'val'/'test' keys.
    '''

    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1)\
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)

    logging.basicConfig(filename=opt.log_folder + full_model_name + '.log',
                        filemode='w',
                        level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    # Load training data
    phase = 'train'
    train_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    vocab = make_vocab(train_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    save_stats(opt.feats_dir + phase + '/audio_feats/', file_list,
               opt.audio_feats_ext, dtype, opt.stats_dir)
    # save_stats(opt.feats_dir + phase + '/log_mag_spec/',
    # file_list, opt.pfnet_audio_feats_ext, dtype, opt.pfnet_stats_dir)
    train_targets, train_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Load validation data
    phase = 'val'
    val_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    file_list = val_prompts.keys()
    val_targets, val_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2, op_dim,
                                                       1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    pfnet_optimizer = optim.Adam(pfnet.parameters(),
                                 lr=opt.lr,
                                 betas=(opt.b1, opt.b2))

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    print_loss_total2 = 0  # Reset every print_every
    plot_loss_total2 = 0  # Reset every plot_every
    best_val_loss = 1000000

    for iter in range(1, opt.niter + 1):

        for j, k in enumerate(train_prompts):

            [input_variable,
             input_length] = get_x_1hot(train_prompts, k, phn2id, vocab_size,
                                        use_cuda)
            [target_variable, target_variable2,
             target_length] = get_y(train_seq_len, j, train_targets, use_cuda)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_outputs = Variable(
                torch.zeros(input_length, encoder.hidden_size))
            encoder_outputs = encoder_outputs.cuda(
            ) if use_cuda else encoder_outputs

            encoder_output, (encoder_hn,
                             encoder_cn) = encoder(input_variable,
                                                   (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1, op_dim))  # all-zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input
            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()
            decoder_output_half = Variable(torch.zeros(target_length, 2 * op_dim))
            decoder_output_half = decoder_output_half.cuda() if use_cuda else decoder_output_half
            decoder_output_full = Variable(torch.zeros(2 * target_length, op_dim))
            decoder_output_full = decoder_output_full.cuda() if use_cuda else decoder_output_full

            # Teacher forcing: Feed the target as the next input
            for di in range(target_length):
                decoder_output1, decoder_output2, decoder_h1, decoder_c1, decoder_h2, decoder_c2, decoder_h3, decoder_c3, decoder_attention = decoder(
                    decoder_input, decoder_h1, decoder_c1, decoder_h2,
                    decoder_c2, decoder_h3, decoder_c3, encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(
                    0)  # Teacher forcing
                decoder_output_half[di] = decoder_output1

            loss.backward(retain_graph=True)
            encoder_optimizer.step()
            decoder_optimizer.step()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Start Post-Filtering Net
            pfnet_optimizer.zero_grad()
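            # De-interleave the two frames predicted per decoder step
            # (r = 2 hard-coded here; Example #1 generalises this to any r)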
            decoder_output_full[0::2, :] = decoder_output_half[:, :60]
            decoder_output_full[1::2, :] = decoder_output_half[:, 60:]
            s1 = 2 * target_length

            train_targets_pfnet, train_seq_len_pfnet = load_targets(
                opt.feats_dir + 'train' + '/log_mag_spec/', [k],
                opt.pfnet_audio_feats_ext, dtype, opt.pfnet_stats_dir)
            targets_pfnet = Variable(train_targets_pfnet).cuda(
            ) if use_cuda else Variable(train_targets_pfnet)
            s2 = targets_pfnet.size()[0]
            if s2 > s1:
                targets_pfnet = targets_pfnet[:s1, :]  # align with decoder output length

            pfnet_h0 = pfnet.initHidden()
            pfnet_c0 = pfnet.initHidden()
            pfnet_outputs = Variable(
                torch.zeros(targets_pfnet.size()[0], pfnet.output_size))
            pfnet_outputs = pfnet_outputs.cuda() if use_cuda else pfnet_outputs

            pfnet_output = pfnet(decoder_output_full, (pfnet_h0, pfnet_c0))
            pfnet_outputs = pfnet_output
            loss1 = criterion(pfnet_outputs, targets_pfnet)

            loss2 = loss + loss1
            loss1.backward()
            #torch.nn.utils.clip_grad_norm(encoder.parameters(), 1)
            #torch.nn.utils.clip_grad_norm(decoder.parameters(), 1)
            #encoder_optimizer.step()
            #decoder_optimizer.step()
            pfnet_optimizer.step()

            print_loss_total += (loss.data[0] / target_length)
            plot_loss_total += (loss.data[0] / target_length)
            print_loss_total2 += (loss2.data[0] / (2 * target_length))
            plot_loss_total2 += (loss2.data[0] / (2 * target_length))

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_avg2 = print_loss_total2 / print_every

                print_loss_total = 0
                print_loss_total2 = 0
                updates_done = iter * len(train_prompts) - len(train_prompts) + j
                total_updates = (opt.niter + 1) * len(train_prompts)
                print('%s (%d %d%%) %.4f %.4f' %
                      (timeSince(start, updates_done / total_updates), iter,
                       iter / opt.niter * 100, print_loss_avg,
                       print_loss_avg2))

                tf = True  # teacher forcing
                avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    val_targets, val_seq_len, phn2id, id2phn, vocab_size,
                    use_cuda, criterion, op_dim, tf, opt)
                print('%d %0.4f %0.4f' %
                      (iter, avg_val_loss_tf1, avg_val_loss_tf2))

                tf = False  # free running (always sampling, no teacher forcing)
                avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    val_targets, val_seq_len, phn2id, id2phn, vocab_size,
                    use_cuda, criterion, op_dim, tf, opt)
                print('%d %0.4f %0.4f' %
                      (iter, avg_val_loss_pf1, avg_val_loss_pf2))
                logging.debug('Epoch: ' + str(iter) + ' Update: ' +
                              str(iter * len(train_prompts) -
                                  len(train_prompts) + j) +
                              ' Avg Val Loss TF1: ' + str(avg_val_loss_tf1) +
                              ' Avg Val Loss PF1: ' + str(avg_val_loss_pf1))

                if avg_val_loss_tf1 < best_val_loss:
                    best_val_loss = avg_val_loss_tf1
                    torch.save(
                        encoder.state_dict(), '%s/%s_enc_epoch_%d_%d.pth' %
                        (opt.model_folder, full_model_name, j, iter))
                    torch.save(
                        decoder.state_dict(), '%s/%s_dec_epoch_%d_%d.pth' %
                        (opt.model_folder, full_model_name, j, iter))
                    torch.save(
                        pfnet.state_dict(), '%s/%s_pfnet_epoch_%d_%d.pth' %
                        (opt.model_folder, full_model_name, j, iter))

                encoder.train()
                decoder.train()
                pfnet.train()

            if (j + 1) % showatt_every == 0:

                plt.figure(1, figsize=(12, 12))
                plt.imshow(decoder_attentions_tf.numpy())
                plt.colorbar()
                pylab.savefig(opt.plot_folder + full_model_name + '_' +
                              str(j) + '_' + str(iter) + '.png',
                              bbox_inches='tight')
                plt.close()

                plt.figure(1, figsize=(12, 12))
                plt.imshow(decoder_attentions_pf.numpy())
                plt.colorbar()
                pylab.savefig(opt.plot_folder + full_model_name + '_' +
                              str(j) + '_' + str(iter) + '.png',
                              bbox_inches='tight')
                plt.close()

            if (j + 1) % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

        gc.collect()
    showPlot(plot_losses)
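
The running averages above divide a summed L1 loss by a frame count; a
minimal sketch of that normalisation (size_average=False in the old API sums
over every element, like reduction='sum' today):

import torch

op_dim, target_length = 60, 7
criterion = torch.nn.L1Loss(reduction='sum')
pred = torch.zeros(target_length, op_dim)
target = torch.ones(target_length, op_dim)
loss = criterion(pred, target)           # |0 - 1| summed over 7 * 60 = 420.0
per_frame = loss.item() / target_length  # 60.0: per-frame L1, summed over dims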
Example #4
def test(opt):
    
    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_lw' + str(int(opt.load_wts)) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)
    opt.full_model_name = full_model_name
    
    r = opt.r
 
    fid = open(opt.feats_dir + 'train_list.txt')
    train_list = fid.read().splitlines()
    fid.close()

    fid = open(opt.feats_dir + 'val_list.txt')
    val_list = fid.read().splitlines()
    fid.close()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data', train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    fid = open(opt.feats_dir + 'test_list.txt')
    val_list = fid.read().splitlines()
    val_list = val_list[:10]  # evaluate on the first 10 test utterances
    fid.close()

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data', train_list)
    #vocab = make_vocab(train_prompts)
    #print(vocab)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    print(len(file_list), len(train_list))
    #save_stats_suffstats(opt.feats_dir + '/fb/', file_list, '.npy', dtype, opt.stats_dir)
    #save_stats_suffstats(opt.feats_dir + '/sp/', file_list, '.npy', dtype, opt.pfnet_stats_dir)

    # save_stats(opt.feats_dir + '/fb/', file_list, '.npy', dtype, opt.stats_dir)
    # save_stats(opt.feats_dir + phase + '/log_mag_spec/',
    # file_list, opt.pfnet_audio_feats_ext, dtype, opt.pfnet_stats_dir)
    # exit()

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy')
    so1 = np.load(opt.stats_dir + 'so.npy')
    mo1 = mo1.astype('float32')
    so1 = so1.astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy')
    mo2 = mo2.astype('float32')
    so2 = so2.astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    #train_targets, train_seq_len = load_targets(opt.feats_dir + phase
    #                                            + '/audio_feats/', file_list,
    #                                            opt.audio_feats_ext,
    #                                            dtype, mo1, so1, nml_vec1)

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data', val_list)
    #file_list = val_prompts.keys()
    #val_targets, val_seq_len = load_targets(opt.feats_dir + '/fb/', val_list, '.npy', dtype, mo1, so1, nml_vec1)
    #print(val_seq_len)
    #print(val_targets.shape)
    #print(val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1, opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    if opt.load_wts:
        load_model_name_pfx = '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_r3_lr3e-05_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_lw1_ef0_rf1_'
        load_model_name_sfx = '_epoch_2999_5.pth'

        # load model
        enc_state_dict = torch.load(load_model_name_pfx + 'enc' + load_model_name_sfx, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(enc_state_dict)

        dec_state_dict = torch.load(load_model_name_pfx + 'dec' + load_model_name_sfx, map_location=lambda storage, loc: storage)
        decoder.load_state_dict(dec_state_dict)

        pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' + load_model_name_sfx, map_location=lambda storage, loc: storage)
        pfnet.load_state_dict(pfnet_state_dict)
    
    
    tf = True  # teacher forcing
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt, mo1, so1,
        nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))
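
The repeated open/read/close pattern for the utterance lists could be wrapped
in a small helper; a sketch (the helper name is mine, not from the code):

def read_list(path):
    # one utterance id per line, as in train_list.txt / val_list.txt above
    with open(path) as fid:
        return fid.read().splitlines()

# e.g. train_list = read_list(opt.feats_dir + 'train_list.txt')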
Example #5
def test(opt):

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1)\
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)
    opt.full_model_name = full_model_name

    # Load training data
    phase = 'train'
    train_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    vocab = make_vocab(train_prompts)
    phn2id, id2phn = phn2id2phn(vocab)

    # Load validation data
    phase = 'test'
    val_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    file_list = list(val_prompts.keys())[:5]  # first 5 test utterances
    val_targets, val_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2, op_dim,
                                                       1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    # load model
    enc_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_enc_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    encoder.load_state_dict(enc_state_dict)

    dec_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_dec_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    decoder.load_state_dict(dec_state_dict)

    pfnet_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_pfnet_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    pfnet.load_state_dict(pfnet_state_dict)

    #tf = True # teacher forcing
    #avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, val_targets, val_seq_len, phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt)
    #print('%d %0.4f %0.4f' % (iter, avg_val_loss_tf1, avg_val_loss_tf2))

    tf = False  # free running (always sampling, no teacher forcing)
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, val_targets,
        val_seq_len, phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
        tf, opt)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))