Example #1
def decode(sess, model, query, isTest):
    '''
    Greedily decode a single query with the seq2seq model.

    :param sess: active TensorFlow session
    :param model: seq2seq model exposing get_batch() and step()
    :param query: input sentence (raw string in test mode, token list otherwise)
    :param isTest: if False, the query must already carry conversation-mode tags
    :return: list of decoded output tokens
    '''
    # Load vocabularies.
    enc_vocab, rev_dec_vocab = load_vocab(
        "./final/Encoder_Dict_key2val_file_forGG",
        "./final/Decoder_Dict_val2key_file_forGG")

    if isTest:
        # Get token-ids for the input sentence.
        tokens = utils.query_preserver(query)
        token_ids = utils.sent2idx(tokens, enc_vocab)
    else:
        #print("Debug on starting conv", query)
        token_ids = utils.sent2idx(query, enc_vocab)

    # Which bucket does it belong to?
    bucket_id = len(buckets) - 1
    for i, bucket in enumerate(buckets):
        if bucket[0] >= len(token_ids):
            bucket_id = i
            break

    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
        {bucket_id: [(token_ids, [])]},
        bucket_id)  # TODO: load_data needs changes to work with get_batch

    # Get output logits for the sentence.
    _, _, output_logits = model.step(
        sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
        True)  # output_logits: (decoder_length, batch_size, dec_vocab_size)

    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    #outputs = [int(np.argmax(logits, axis=1)) for logits in output_logits] # because of batch-major
    # outputs = [np.argmax(np.sum(logits, axis=0)) for logits in output_logits]
    outputs = [np.argmax(logits, axis=-1) for logits in output_logits]
    outputs = np.transpose(outputs)
    output = list(outputs[0])
    print("Debug", output)

    # If there is an EOS symbol in outputs, cut them at that point.
    EOS_ID = 2  # dec_vocab["<EOS>"]
    PAD_ID = 0
    if EOS_ID in output:
        output = output[:output.index(EOS_ID)]  # keep outputs up to EOS only
    elif PAD_ID in output:
        output = output[:output.index(PAD_ID)]

    # Map output ids back to decoder-vocabulary tokens.
    answer = [rev_dec_vocab[str(token)] for token in output]

    return answer
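This example (and Example #2 below) relies on helpers such as load_vocab and utils.sent2idx that live elsewhere in the repository. The following is only a minimal sketch of the contract these snippets appear to assume, not the original implementations; the pickle storage format and the <UNK> id are assumptions:

import pickle

def load_vocab(enc_path, dec_path):
    # Assumed format: pickled dicts, token -> id for the encoder side
    # and id (as a string) -> token for the decoder side.
    with open(enc_path, 'rb') as f:
        enc_vocab = pickle.load(f)
    with open(dec_path, 'rb') as f:
        rev_dec_vocab = pickle.load(f)
    return enc_vocab, rev_dec_vocab

def sent2idx(tokens, vocab, unk_id=3):
    # Map each token to its id, falling back to a hypothetical <UNK> id.
    return [vocab.get(tok, unk_id) for tok in tokens]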
Example #2
def decode(sess, model, query):
    # Load vocabularies.
    enc_vocab, rev_dec_vocab = load_vocab(
        "./final/Encoder_Dict_key2val_file_forBE",
        "./final/Decoder_Dict_val2key_file_forBE")

    # Get token-ids for the input sentence.
    tokens = utils.query_disintegrator(query)
    token_ids = utils.sent2idx(tokens, enc_vocab)

    # Which bucket does it belong to?
    bucket_id = len(buckets) - 1
    for i, bucket in enumerate(buckets):
        if bucket[0] >= len(token_ids):
            bucket_id = i
            break

    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
        {bucket_id: [(token_ids, [])]},
        bucket_id)  # TODO: load_data needs changes to work with get_batch

    # Get output logits for the sentence.
    _, _, output_logits = model.step(
        sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
        True)  # output_logits: (decoder_length, batch_size, dec_vocab_size)

    #output_logits = np.array(output_logits)
    #print("debug", output_logits.shape)
    # This is a greedy decoder - outputs are just argmaxes of output_logits.

    #outputs = [int(np.argmax(logits, axis=1)) for logits in output_logits]
    outputs = [np.argmax(logits, axis=-1) for logits in output_logits]
    outputs = np.transpose(outputs)
    output = list(outputs[0])

    # If there is an EOS symbol in outputs, cut them at that point.
    EOS_ID = 2  # dec_vocab["<EOS>"]
    if EOS_ID in output:
        output = output[:output.index(EOS_ID)]
    else:
        # Fall back to a canned reply ("할 말이 없다": "I have nothing to say").
        sent = utils.tokenize("할 말이 없다")
        output = utils.sent2idx(sent, enc_vocab)

    # Map output ids back to decoder-vocabulary tokens.
    #answer = " ".join([rev_dec_vocab[str(token)] for token in output])
    answer = [rev_dec_vocab[str(token)] for token in output]

    return answer
Example #3
def test_input(text):
    x_input = sent2idx(text, wordtoix, opt)
    res = sess.run(res_,
                   feed_dict={
                       x_: x_input,
                       x_org_: x_batch_org
                   })
    # Drop padding ids (0) before mapping ids back to words.
    print("Reconstructed: " + " ".join(
        [ixtoword[x] for x in res['rec_sents'][0] if x != 0]))
Example #4
def collate_fn(batch):
    # Add ending token at the end
    text_idx = [sent2idx(b['text']) + [hps.vocab.find('E')] for b in batch]
    # (Earlier variant: sort by decreasing text length so that
    # torch.nn.utils.rnn.pack_padded_sequence could be used.)
    # text_len = [len(x) for x in text_idx]
    # idx = sorted(range(len(text_idx)), key=lambda i: -text_len[i])
    # text_idx = [text_idx[i] for i in idx]
    mel = [b['mel'] for b in batch]
    mag = [b['mag'] for b in batch]

    max_text_len = max([len(x) for x in text_idx])
    max_time_step = max([x.shape[0] for x in mel])
    # Pad the time axis up to a multiple of the reduction factor.
    remain = max_time_step % hps.reduction_factor
    if remain != 0:
        max_time_step += hps.reduction_factor - remain

    text_len = []
    mel_len = []

    # Padding
    for i, x in enumerate(text_idx):
        L = len(x)
        diff = max_text_len - L
        pad = [hps.vocab.find('P') for _ in range(diff)]
        text_idx[i] += pad
        text_len.append(L)

    for i, x in enumerate(mel):
        L = x.shape[0]
        diff = max_time_step - L
        pad = np.zeros([diff, x.shape[1]])
        mel[i] = np.concatenate([x, pad], axis=0)
        mel_len.append(L)

    for i, x in enumerate(mag):
        L = x.shape[0]
        diff = max_time_step - L
        pad = np.zeros([diff, x.shape[1]])
        mag[i] = np.concatenate([x, pad], axis=0)

    return {
        'text': torch.LongTensor(text_idx),
        'mel': torch.Tensor(mel),
        'mag': torch.Tensor(mag),
        'text_length': torch.LongTensor(text_len),
        'frame_length': torch.LongTensor(mel_len)
    }
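For context, this collate_fn is meant to be handed to a torch DataLoader, as Example #7 below also shows. A minimal hookup, assuming a dataset whose items are dicts with 'text', 'mel', and 'mag' keys:

from torch.utils.data import DataLoader

loader = DataLoader(dataset,
                    batch_size=hps.batch_size,
                    shuffle=True,
                    num_workers=4,
                    drop_last=True,
                    collate_fn=collate_fn)
for batch in loader:
    text, mel, mag = batch['text'], batch['mel'], batch['mag']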
Example #5
def evaluation(model, step, device, args):
    # Evaluation
    model.eval()
    with torch.no_grad():
        # Preprocessing eval texts
        print('Start generating evaluation speeches...')
        n_eval = len(hps.eval_texts)
        for i in range(n_eval):
            sys.stdout.write('\rProgress: {}/{}'.format(i + 1, n_eval))
            sys.stdout.flush()
            text = hps.eval_texts[i]
            text = text_normalize(text)

            txt_id = sent2idx(text) + [hps.vocab.find('E')]
            txt_len = len(txt_id)
            GO_frame = torch.zeros(1, 1, hps.n_mels)

            # Shape: (1, seq_length)
            txt = torch.LongTensor([txt_id])
            txt_len = torch.LongTensor([txt_len])
            if args.cuda:
                GO_frame = GO_frame.cuda()
                txt = txt.cuda()
                txt_len = txt_len.cuda()
            _batch = model(text=txt, frames=GO_frame, text_length=txt_len)
            mel = _batch['mel'][0]
            mag = _batch['mag'][0]
            attn = _batch['attn'][0]
            if args.cuda:
                mel = mel.cpu()
                mag = mag.cpu()
                attn = attn.cpu()
            mel = mel.numpy()
            mag = mag.numpy()
            attn = attn.numpy()

            wav = mag2wav(mag)
            save_alignment(attn, step, 'eval/plots/attn_{}.png'.format(text))
            save_spectrogram(mag,
                             'eval/plots/spectrogram_{}.png'.format(text))
            save_wav(wav, 'eval/results/wav_{}.wav'.format(text))
        sys.stdout.write('\n')
Example #6
def collate_fn(batch):
    #GO_frame = np.zeros([1, hps.n_mels])
    # Add ending token at the end
    idx = [sent2idx(b['text']) + [hps.char_set.find('E')] for b in batch]
    # Add GO frame at the beginning
    #mel = [np.concatenate([GO_frame, b['mel']], axis=0) for b in batch]
    mel = [b['mel'] for b in batch]
    mag = [b['mag'] for b in batch]

    max_text_len = max([len(x) for x in idx])
    max_time_step = max([x.shape[0] for x in mel])
    # Pad the time axis up to a multiple of the reduction factor.
    remain = max_time_step % hps.reduction_factor
    if remain != 0:
        max_time_step += hps.reduction_factor - remain

    # Padding
    for i, x in enumerate(idx):
        L = len(x)
        diff = max_text_len - L
        pad = [hps.char_set.find('P') for _ in range(diff)]
        idx[i] += pad

    for i, x in enumerate(mel):
        L = x.shape[0]
        diff = max_time_step - L
        pad = np.zeros([diff, x.shape[1]])
        mel[i] = np.concatenate([x, pad], axis=0)

    for i, x in enumerate(mag):
        L = x.shape[0]
        diff = max_time_step - L
        pad = np.zeros([diff, x.shape[1]])
        mag[i] = np.concatenate([x, pad], axis=0)

    return {
        'text': torch.LongTensor(idx),
        'mel': torch.Tensor(mel),
        'mag': torch.Tensor(mag)
    }
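Both Example #4 and Example #6 round max_time_step up to a multiple of hps.reduction_factor, since the decoder emits reduction_factor frames per step. A quick standalone check of that arithmetic (reduction_factor = 5 is an assumed value, not taken from the source):

def round_up(t, r=5):
    # Round t up to the next multiple of r; no change if already aligned.
    remain = t % r
    return t if remain == 0 else t + (r - remain)

assert round_up(100) == 100  # already a multiple
assert round_up(101) == 105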
Example #7
def run(args):
    # Check cuda device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Data
    if hps.bucket:
        dataset = LJSpeech_Dataset(
            meta_file=hps.meta_path,
            wav_dir=hps.wav_dir,
            batch_size=hps.batch_size,
            do_bucket=True,
            bucket_size=20)
        loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=True,
            num_workers=4)
    else:
        dataset = LJSpeech_Dataset(meta_file=hps.meta_path, wav_dir=hps.wav_dir)
        loader = DataLoader(
            dataset,
            batch_size=hps.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            collate_fn=collate_fn)

    # Network
    model = Tacotron()
    criterion = nn.L1Loss()
    if args.cuda:
        model = nn.DataParallel(model.to(device))
        criterion = criterion.to(device)
    # The learning-rate schedule from "Attention Is All You Need"
    # (see the standalone sketch after this example).
    lr_lambda = lambda step: hps.warmup_step ** 0.5 * min(
        (step + 1) * (hps.warmup_step ** -1.5), (step + 1) ** -0.5)
    optimizer = optim.Adam(model.parameters(), lr=hps.lr)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        
    step = 1
    epoch = 1
    # Load model
    if args.ckpt:
        ckpt = load(args.ckpt)
        step = ckpt['step']
        epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, 
            lr_lambda, 
            last_epoch=step)

    if args.eval:
        # Evaluation
        model.eval()
        with torch.no_grad():
            # Preprocessing eval texts
            print('Start generating evaluation speeches...')
            n_eval = len(hps.eval_texts)
            for i in range(n_eval):
                sys.stdout.write('\rProgress: {}/{}'.format(i+1, n_eval))
                sys.stdout.flush()
                text = hps.eval_texts[i]
                text = text_normalize(text)
                txt_id = sent2idx(text) + [hps.char_set.find('E')]
                GO_frame = torch.zeros(1, 1, hps.n_mels)

                # Shape: (1, seq_length)
                txt = torch.LongTensor(txt_id).unsqueeze(0)
                if args.cuda:
                    GO_frame = GO_frame.cuda()
                    txt = txt.cuda()
                _batch = model(text=txt, frames=GO_frame)
                mel = _batch['mel'][0]
                mag = _batch['mag'][0]
                attn = _batch['attn'][0]
                if args.cuda:
                    mel = mel.cpu()
                    mag = mag.cpu()
                    attn = attn.cpu()
                mel = mel.numpy()
                mag = mag.numpy()
                attn = attn.numpy()

                wav = mag2wav(mag)
                save_alignment(attn, step, 'eval/plots/attn_{}.png'.format(text))
                save_spectrogram(mag, 'eval/plots/spectrogram_{}.png'.format(text))
                save_wav(wav, 'eval/results/wav_{}.wav'.format(text))
            sys.stdout.write('\n')

    if args.train:
        before_load = time.time()
        # Start training
        model.train()
        while True:
            for batch in loader:
                # torch.LongTensor, (batch_size, seq_length)
                txt = batch['text']
                # torch.Tensor, (batch_size, max_time, hps.n_mels)
                mel = batch['mel']
                # torch.Tensor, (batch_size, max_time, hps.n_fft)
                mag = batch['mag']
                if hps.bucket:
                    # If bucketing, the shape will be (1, batch_size, ...)
                    txt = txt.squeeze(0)
                    mel = mel.squeeze(0)
                    mag = mag.squeeze(0)
                # GO frame
                GO_frame = torch.zeros(mel[:, :1, :].size())
                if args.cuda:
                    txt = txt.to(device)
                    mel = mel.to(device)
                    mag = mag.to(device)
                    GO_frame = GO_frame.to(device)

                # Model prediction
                decoder_input = torch.cat([GO_frame, mel[:, hps.reduction_factor::hps.reduction_factor, :]], dim=1)

                load_time = time.time() - before_load
                before_step = time.time()

                _batch = model(text=txt, frames=decoder_input)
                _mel = _batch['mel']
                _mag = _batch['mag']
                _attn = _batch['attn']

                # Optimization
                optimizer.zero_grad()
                loss_mel = criterion(_mel, mel)
                loss_mag = criterion(_mag, mag)
                loss = loss_mel + loss_mag
                loss.backward()
                # Gradient clipping
                total_norm = clip_grad_norm_(model.parameters(), max_norm=hps.clip_norm)
                # Apply gradient
                optimizer.step()
                # Adjust learning rate
                scheduler.step()
                process_time = time.time() - before_step 
                if step % hps.log_every_step == 0:
                    lr_curr = optimizer.param_groups[0]['lr']
                    log = '[{}-{}] loss: {:.3f}, grad: {:.3f}, lr: {:.3e}, time: {:.2f} + {:.2f} sec'.format(epoch, step, loss.item(), total_norm, lr_curr, load_time, process_time)
                    print(log)
                if step % hps.save_model_every_step == 0:
                    save(filepath='tmp/ckpt/ckpt_{}.pth.tar'.format(step),
                         model=model.state_dict(),
                         optimizer=optimizer.state_dict(),
                         step=step, 
                         epoch=epoch)

                if step % hps.save_result_every_step == 0:
                    sample_idx = random.randint(0, hps.batch_size-1)
                    attn_sample = _attn[sample_idx].detach().cpu().numpy()
                    mag_sample = _mag[sample_idx].detach().cpu().numpy()
                    wav_sample = mag2wav(mag_sample)
                    # Save results
                    save_alignment(attn_sample, step, 'tmp/plots/attn_{}.png'.format(step))
                    save_spectrogram(mag_sample, 'tmp/plots/spectrogram_{}.png'.format(step))
                    save_wav(wav_sample, 'tmp/results/wav_{}.wav'.format(step))
                before_load = time.time()
                step += 1
            epoch += 1
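The lr_lambda defined near the top of this example implements the warmup schedule from "Attention Is All You Need": the multiplier grows linearly for hps.warmup_step steps and then decays as the inverse square root of the step. A standalone sketch of the same formula (the warmup of 4000 steps is an assumed value):

def noam_multiplier(step, warmup=4000):
    # Linear warmup for `warmup` steps, then inverse-square-root decay;
    # mirrors the (step + 1) offset used by lr_lambda above.
    s = step + 1
    return warmup ** 0.5 * min(s * warmup ** -1.5, s ** -0.5)

# The multiplier ramps up to 1.0 at step == warmup - 1, then decays:
for s in (0, 999, 3999, 15999):
    print(s, round(noam_multiplier(s), 5))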