Example #1
import random
import sys

import torch
from torch.autograd import Variable

# NMT, LMProb, read_corpus, to_input_variable and change() are helpers from the
# surrounding project and are assumed to be importable here.


def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    lms = {}

    # load model params & training data
    for i in range(2):
        model_id = (['A', 'B'])[i]
        print('loading pieces, part {:s}'.format(model_id))

        print('  load model{:s}     from [{:s}]'.format(model_id, args.nmt[i]), file=sys.stderr)
        params = torch.load(args.nmt[i], map_location=lambda storage, loc: storage)  # load model onto CPU
        vocabs[model_id] = params['vocab']
        opts[model_id] = params['args']
        state_dicts[model_id] = params['state_dict']

        print('  load train_src{:s} from [{:s}]'.format(model_id, args.src[i]), file=sys.stderr)
        train_srcs[model_id] = read_corpus(args.src[i], source='src')

        print('  load lm{:s}        from [{:s}]'.format(model_id, args.lm[i]), file=sys.stderr)
        lms[model_id] = LMProb(args.lm[i], args.dict[i])

    models = {}
    optimizers = {}

    for m in ['A', 'B']:
        # build model
        opts[m].cuda = args.cuda

        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()

        if args.cuda:
            models[m] = models[m].cuda()

        random.shuffle(train_srcs[m])

        # optimizer
        # optimizers[m] = torch.optim.Adam(models[m].parameters())
        optimizers[m] = torch.optim.SGD(models[m].parameters(), lr=1e-3, momentum=0.9)

    # loss function
    loss_nll = torch.nn.NLLLoss()
    loss_ce = torch.nn.CrossEntropyLoss()

    epoch = 0
    start = args.start_iter

    while True:
        epoch += 1
        print('\nstart of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = iter(train_srcs['A'])
        data['B'] = iter(train_srcs['B'])

        start += (epoch - 1) * len(train_srcs['A']) + 1

        for t in range(start, start + len(train_srcs['A'])):
            show_log = False
            if t % args.log_every == 0:
                show_log = True

            if show_log:
                print('\nstep', t)

            for m in ['A', 'B']:
                lm_probs = []

                NLL_losses = []
                CE_losses = []

                modelA = models[m]
                modelB = models[change(m)]
                lmB = lms[change(m)]
                optimizerA = optimizers[m]
                optimizerB = optimizers[change(m)]
                vocabB = vocabs[change(m)]
                s = next(data[m])

                if show_log:
                    print('\n{:s} -> {:s}'.format(m, change(m)))
                    print('[s]', ' '.join(s))

                hyps = modelA.beam(s, beam_size=5)

                for ids, smid, dist in hyps:
                    if show_log:
                        print('[smid]', ' '.join(smid))

                    var_ids = Variable(torch.LongTensor(ids[1:]), requires_grad=False)
                    NLL_losses.append(loss_nll(dist, var_ids).cpu())

                    lm_probs.append(lmB.get_prob(smid))

                    src_sent_var = to_input_variable([smid], vocabB.src, cuda=args.cuda)
                    tgt_sent_var = to_input_variable([['<s>'] + s + ['</s>']], vocabB.tgt, cuda=args.cuda)
                    src_sent_len = [len(smid)]

                    score = modelB(src_sent_var, src_sent_len, tgt_sent_var[:-1]).squeeze(1)

                    CE_losses.append(loss_ce(score, tgt_sent_var[1:].view(-1)).cpu())

                # losses on target language
                fw_losses = torch.cat(NLL_losses)

                # losses on reconstruction
                bw_losses = torch.cat(CE_losses)

                # r1, language model reward
                r1s = Variable(torch.FloatTensor(lm_probs), requires_grad=False)
                r1s = (r1s - torch.mean(r1s)) / torch.std(r1s)

                # r2, communication reward
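                # reconstruction losses are flipped in sign below (a lower loss means a
                # better round trip, hence a larger reward) before being standardized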
                r2s = Variable(bw_losses.data, requires_grad=False)
                r2s = (torch.mean(r2s) - r2s) / torch.std(r2s)

                # rk = alpha * r1 + (1 - alpha) * r2
                rks = r1s * args.alpha + r2s * (1 - args.alpha)

                # averaging loss over samples
                A_loss = torch.mean(fw_losses * rks)
                B_loss = torch.mean(bw_losses * (1 - args.alpha))

                if show_log:
                    for r1, r2, rk, fw_loss, bw_loss in zip(r1s.data.numpy(), r2s.data.numpy(), rks.data.numpy(), fw_losses.data.numpy(), bw_losses.data.numpy()):
                        print('r1={:7.4f}\t r2={:7.4f}\t rk={:7.4f}\t fw_loss={:7.4f}\t bw_loss={:7.4f}'.format(r1, r2, rk, fw_loss, bw_loss))
                    print('A loss = {:.7f} \t B loss = {:.7f}'.format(A_loss.data.numpy().item(), B_loss.data.numpy().item()))

                optimizerA.zero_grad()
                optimizerB.zero_grad()

                A_loss.backward()
                B_loss.backward()

                optimizerA.step()
                optimizerB.step()

            if t % args.save_n_iter == 0:
                print('\nsaving model')
                models['A'].save('{}.iter{}.bin'.format(args.model[0], t))
                models['B'].save('{}.iter{}.bin'.format(args.model[1], t))
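
Both this example and Example #3 rely on a small change() helper (not shown on this page) to flip between the two translation directions. A minimal sketch, assuming it simply maps 'A' to 'B' and back:

def change(m):
    """Return the id of the opposite model/direction ('A' <-> 'B')."""
    return 'B' if m == 'A' else 'A'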
Example #2
def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    train_tgt = {}
    lm_scores = {}
    dev_data = {}

    # load model params & training data
    for i in range(2):
        model_id = (['A', 'B'])[i]
        print('loading pieces, part {:s}'.format(model_id))
        print('  load model{:s}     from [{:s}]'.format(model_id, args.nmt[i]))
        params = torch.load(
            args.nmt[i],
            map_location=lambda storage, loc: storage)  # load model onto CPU
        vocabs[model_id] = params['vocab']
        print('==' * 10)
        print(vocabs[model_id])
        opts[model_id] = params['args']
        state_dicts[model_id] = params['state_dict']
        print('done')

    for i in range(2):
        model_id = (['A', 'B'])[i]
        print('  load train_src{:s} from [{:s}]'.format(model_id, args.src[i]))
        train_srcs[model_id], lm_scores[model_id] = read_corpus_for_dsl(
            args.src[i], source='src')
        train_tgt[model_id], _ = read_corpus_for_dsl(args.src[(i + 1) % 2],
                                                     source='tgt')

    dev_data_src1 = read_corpus(args.val[0], source='src')
    dev_data_tgt1 = read_corpus(args.val[1], source='tgt')
    dev_data['A'] = list(zip(dev_data_src1, dev_data_tgt1))
    dev_data_src2 = read_corpus(args.val[1], source='src')
    dev_data_tgt2 = read_corpus(args.val[0], source='tgt')
    dev_data['B'] = list(zip(dev_data_src2, dev_data_tgt2))

    models = {}
    optimizers = {}
    nll_loss = {}
    cross_entropy_loss = {}

    for m in ['A', 'B']:
        # build model
        opts[m].cuda = args.cuda

        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()

        if args.cuda:
            models[m] = models[m].cuda()

        optimizers[m] = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                               models[m].parameters()),
                                        lr=args.lr)
    for m in ['A', 'B']:
        vocab_mask = torch.ones(len(vocabs[m].tgt))
        vocab_mask[vocabs[m].tgt['<pad>']] = 0
        nll_loss[m] = torch.nn.NLLLoss(weight=vocab_mask, size_average=False)
        cross_entropy_loss[m] = torch.nn.CrossEntropyLoss(weight=vocab_mask,
                                                          reduce=False,
                                                          size_average=False)
        models[m].eval()
        if args.cuda:
            nll_loss[m] = nll_loss[m].cuda()
            cross_entropy_loss[m] = cross_entropy_loss[m].cuda()
    epoch = 0

    train_data = list(
        zip(train_srcs['A'], train_tgt['A'], lm_scores['A'], lm_scores['B']))
    cum_lossA = cum_lossB = 0
    att_loss = 0
    ce_lossA_log = 0
    ce_lossB_log = 0
    t = 0
    hist_valid_scores = {}
    hist_valid_scores['A'] = []
    hist_valid_scores['B'] = []

    patience = {}
    patience['A'] = patience['B'] = 0
    decay = {}
    decay['A'] = 0
    decay['B'] = 0
    while True:
        epoch += 1
        print('\nstart of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = data_iter_for_dual(train_data,
                                       batch_size=args.batch_size,
                                       shuffle=False)

        for batchA in data['A']:
            src_sentsA, tgt_sentsA, src_scoresA, src_scoresB = batchA[
                0], batchA[1], batchA[2], batchA[3]
            tgt_sents_forA = [['<s>'] + sent + ['</s>'] for sent in tgt_sentsA]

            src_sents_varA, masksA = to_input_variable(src_sentsA,
                                                       vocabs['A'].src,
                                                       cuda=args.cuda)
            tgt_sents_varA, _ = to_input_variable(tgt_sents_forA,
                                                  vocabs['A'].tgt,
                                                  cuda=args.cuda)
            src_scores_varA = Variable(torch.FloatTensor(src_scoresA),
                                       requires_grad=False)

            src_sents_len_A = [len(s) for s in src_sentsA]
            # print(src_sents_varA, src_sents_len_A, tgt_sents_varA[:-1], masksA)
            scoresA, feature_A, att_sim_A = models['A'](src_sents_varA,
                                                        src_sents_len_A,
                                                        tgt_sents_varA[:-1],
                                                        masksA)

            ce_lossA = cross_entropy_loss['A'](scoresA.view(
                -1, scoresA.size(2)), tgt_sents_varA[1:].view(-1)).cpu()

            batch_data = (src_sentsA, tgt_sentsA, src_scoresA, src_scoresB)
            src_sentsA, tgt_sentsA, src_scoresA, src_scoresB = get_new_batch(
                batch_data)
            tgt_sents_forB = [['<s>'] + sent + ['</s>'] for sent in src_sentsA]

            src_sents_varB, masksB = to_input_variable(tgt_sentsA,
                                                       vocabs['B'].src,
                                                       cuda=args.cuda)
            tgt_sents_varB, _ = to_input_variable(tgt_sents_forB,
                                                  vocabs['B'].tgt,
                                                  cuda=args.cuda)
            src_scores_varB = Variable(torch.FloatTensor(src_scoresB),
                                       requires_grad=False)

            src_sents_len = [len(s) for s in tgt_sentsA]
            scoresB, feature_B, att_sim_B = models['B'](src_sents_varB,
                                                        src_sents_len,
                                                        tgt_sents_varB[:-1],
                                                        masksB)

            ce_lossB = cross_entropy_loss['B'](scoresB.view(
                -1, scoresB.size(2)), tgt_sents_varB[1:].view(-1)).cpu()

            optimizerA = optimizers['A']
            optimizerB = optimizers['B']

            optimizerA.zero_grad()
            optimizerB.zero_grad()
            # print (ce_lossA.size(), src_scores_varA.size(), tgt_sents_varA[1:].size(0))
            ce_lossA = ce_lossA.view(tgt_sents_varA[1:].size(0),
                                     tgt_sents_varA[1:].size(1)).mean(0)
            ce_lossB = ce_lossB.view(tgt_sents_varB[1:].size(0),
                                     tgt_sents_varB[1:].size(1)).mean(0)

            att_sim_A = torch.cat(att_sim_A, 1)

            masksA = masksA.transpose(1, 0).unsqueeze(1)
            masksA = masksA.expand(masksA.size(0), att_sim_A.size(1),
                                   masksA.size(2))
            assert att_sim_A.size() == masksA.size(), '{} {}'.format(
                att_sim_A.size(), masksA.size())
            att_sim_B = torch.cat(att_sim_B, 1)
            masksB = masksB.transpose(1, 0).unsqueeze(1)
            masksB = masksB.expand(masksB.size(0), att_sim_B.size(1),
                                   masksB.size(2))
            assert att_sim_B.size() == masksB.size(), '{} {}'.format(
                att_sim_B.size(), masksB.size())
            att_sim_B = att_sim_B.transpose(2, 1)
            loss_att_A = loss_att(att_sim_A, att_sim_B, masksB.transpose(1, 0),
                                  src_sents_len)
            loss_att_B = loss_att(att_sim_A.transpose(2, 1),
                                  att_sim_B.transpose(2, 1), masksB,
                                  src_sents_len_A)

            dual_loss = (src_scores_varA - ce_lossA - src_scores_varB +
                         ce_lossB)**2
            att_loss_ = (loss_att_A + loss_att_B)

            lossA = ce_lossA + args.beta1 * dual_loss + args.beta3 * att_loss_
            lossB = ce_lossB + args.beta2 * dual_loss + args.beta4 * att_loss_

            lossA = torch.mean(lossA)
            lossB = torch.mean(lossB)

            cum_lossA += lossA.data[0]
            cum_lossB += lossB.data[0]

            ce_lossA_log += torch.mean(ce_lossA).data[0]
            ce_lossB_log += torch.mean(ce_lossB).data[0]
            att_loss += (torch.mean(loss_att_A).data[0] +
                         torch.mean(loss_att_B).data[0])

            optimizerA.zero_grad()
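            # lossA and lossB share the dual and attention terms in the graph, so the
            # first backward() must retain the graph for the second one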
            lossA.backward(retain_graph=True)
            grad_normA = torch.nn.utils.clip_grad_norm(
                models['A'].parameters(), args.clip_grad)
            optimizerA.step()
            optimizerB.zero_grad()
            lossB.backward()
            grad_normB = torch.nn.utils.clip_grad_norm(
                models['B'].parameters(), args.clip_grad)
            optimizerB.step()
            if t % args.log_n_iter == 0 and t != 0:
                print(
                    'epoch %d, avg. loss A %.3f, avg. word loss A %.3f, avg. loss B %.3f, avg. word loss B %.3f, avg. att loss %.3f'
                    % (epoch, cum_lossA / args.log_n_iter, ce_lossA_log /
                       args.log_n_iter, cum_lossB / args.log_n_iter,
                       ce_lossB_log / args.log_n_iter,
                       att_loss / args.log_n_iter))
                cum_lossA = 0
                cum_lossB = 0
                att_loss = 0
                ce_lossA_log = 0
                ce_lossB_log = 0
            if t % args.val_n_iter == 0 and t != 0:
                print('Validation begins ...')
                for i, model_id in enumerate(['A', 'B']):
                    models[model_id].eval()

                    tmp_dev_data = dev_data[model_id]
                    dev_hyps = decode(models[model_id], tmp_dev_data)
                    dev_hyps = [hyps[0] for hyps in dev_hyps]
                    valid_metric = get_bleu([tgt for src, tgt in tmp_dev_data],
                                            dev_hyps, 'test')
                    models[model_id].train()
                    hist_scores = hist_valid_scores[model_id]
                    print('Model_id {} Sentence bleu : {}'.format(
                        model_id, valid_metric))

                    is_better = len(
                        hist_scores) == 0 or valid_metric > max(hist_scores)
                    hist_scores.append(valid_metric)

                    if not is_better:
                        patience[model_id] += 1
                        print('hit patience %d' % patience[model_id])
                        if patience[model_id] > 0:
                            if abs(optimizers[model_id].param_groups[0]
                                   ['lr']) < 1e-8:
                                exit(0)
                            if decay[model_id] < 1:
                                lr = optimizers[model_id].param_groups[0][
                                    'lr'] * 0.5
                                print('Decay learning rate to %f' % lr)
                                optimizers[model_id].param_groups[0]['lr'] = lr
                                patience[model_id] = 0
                                decay[model_id] += 1
                            else:
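                                # freeze the vocabulary-sized parameters (the 50000 and
                                # 27202 rows appear to be embedding/output layers for the
                                # two vocabularies) and keep decaying the learning rate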
                                for param in models[model_id].parameters():
                                    if param.size()[0] == 50000 or param.size(
                                    )[0] == 27202:
                                        param.requires_grad = False

                                lr = optimizers[model_id].param_groups[0][
                                    'lr'] * 0.95
                                print('Decay learning rate to %f' % lr)
                                optimizers[model_id].param_groups[0]['lr'] = lr
                                decay[model_id] += 1

                    else:
                        patience[model_id] = 0
                        if model_id == 'A':
                            np.save('{}.iter{}'.format(args.model[i], t),
                                    att_sim_A[0].cpu().data.numpy())
                        if model_id == 'B':
                            np.save('{}.iter{}'.format(args.model[i], t),
                                    att_sim_B[0].cpu().data.numpy())
                        models[model_id].save('{}.iter{}.bin'.format(
                            args.model[i], t))

            t += 1
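
The dual_loss term above appears to implement the probabilistic-duality regularizer from dual supervised learning, assuming src_scoresA/src_scoresB hold pre-computed language-model log-probabilities and the cross-entropy losses are per-sentence negative log-likelihoods. A minimal sketch of the term for a single sentence pair, with hypothetical argument names:

def duality_gap(lm_x, nll_y_given_x, lm_y, nll_x_given_y):
    # (log P(x) + log P(y|x) - log P(y) - log P(x|y)) ** 2
    return (lm_x - nll_y_given_x - lm_y + nll_x_given_y) ** 2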
Example #3
def dual(args):
    vocabs = {}
    opts = {}
    state_dicts = {}
    train_srcs = {}
    lms = {}

    # load model params & training data
    for i in range(2):
        model_id = (['A', 'B'])[i]
        print('loading pieces, part {:s}'.format(model_id))

        print('  load model{:s}     from [{:s}]'.format(model_id, args.nmt[i]),
              file=sys.stderr)
        params = torch.load(
            args.nmt[i],
            map_location=lambda storage, loc: storage)  # load model onto CPU
        vocabs[model_id] = params['vocab']
        opts[model_id] = params['args']
        state_dicts[model_id] = params['state_dict']

        print('  load train_src{:s} from [{:s}]'.format(model_id, args.src[i]),
              file=sys.stderr)
        train_srcs[model_id] = read_corpus(args.src[i], source='src')

        print('  load lm{:s}        from [{:s}]'.format(model_id, args.lm[i]),
              file=sys.stderr)
        lms[model_id] = LMProb(args.lm[i], args.dict[i])

    models = {}
    optimizers = {}

    for m in ['A', 'B']:
        # build model
        opts[m].cuda = args.cuda

        models[m] = NMT(opts[m], vocabs[m])
        models[m].load_state_dict(state_dicts[m])
        models[m].train()

        if args.cuda:
            models[m] = models[m].cuda()

        random.shuffle(train_srcs[m])

        # optimizer
        # optimizers[m] = torch.optim.Adam(models[m].parameters())
        optimizers[m] = torch.optim.SGD(models[m].parameters(),
                                        lr=1e-3,
                                        momentum=0.9)

    # loss function
    loss_nll = torch.nn.NLLLoss()
    loss_ce = torch.nn.CrossEntropyLoss()
    f_lossA = open(args.model[0] + ".losses", "w")
    f_lossB = open(args.model[1] + ".losses", "w")

    epoch = 0
    start = args.start_iter

    while True:
        epoch += 1
        print('\nstart of epoch {:d}'.format(epoch))

        data = {}
        data['A'] = iter(train_srcs['A'])
        data['B'] = iter(train_srcs['B'])

        start += (epoch - 1) * len(train_srcs['A']) + 1

        for t in range(start, start + len(train_srcs['A'])):
            show_log = False
            if t % args.log_every == 0:
                show_log = True

            if show_log:
                print('\nstep', t)

            for m in ['A', 'B']:
                lm_probsA = []
                lm_probsB = []

                NLL_lossesA = []
                NLL_lossesB = []

                modelA = models[m]
                modelB = models[change(m)]
                lmA = lms[m]
                lmB = lms[change(m)]
                optimizerA = optimizers[m]
                optimizerB = optimizers[change(m)]
                vocabA = vocabs[m]
                vocabB = vocabs[change(m)]
                s = next(data[m])

                if show_log:
                    print('\n{:s} -> {:s}'.format(m, change(m)))
                    print('[s]', ' '.join(s))

                hyps = modelA.beam(s, beam_size=5)

                src_sents_var = to_input_variable([s],
                                                  modelA.vocab.src,
                                                  cuda=args.cuda,
                                                  is_test=True)
                src_encoding, _ = modelA.encode(src_sents_var, [len(s)])
                src_encoding = src_encoding.squeeze(1)
                src_encoding = torch.mean(src_encoding, dim=0)

                tb_encodings = []

                for ids, smid, dist in hyps:
                    if show_log:
                        print('[smid]', ' '.join(smid))

                    var_ids = torch.LongTensor(ids[1:]).detach()
                    NLL_lossesB.append(
                        loss_nll(dist, var_ids).unsqueeze(0).cpu())
                    lm_probsB.append(lmB.get_prob(smid))

                    idback, sback, distback = modelB.beam(smid, beam_size=1)[0]
                    var_idback = torch.LongTensor(idback[1:]).detach()
                    NLL_lossesA.append(
                        loss_nll(distback, var_idback).unsqueeze(0).cpu())
                    lm_probsA.append(lmA.get_prob(sback))

                    tb_sents_var = to_input_variable([sback],
                                                     modelA.vocab.src,
                                                     cuda=args.cuda,
                                                     is_test=True)
                    tb_encoding, _ = modelA.encode(tb_sents_var, [len(sback)])
                    tb_encoding = tb_encoding.squeeze(1)
                    tb_encoding = torch.mean(tb_encoding, dim=0, keepdim=True)
                    tb_encodings.append(tb_encoding)

                # losses on target language
                fw_losses = torch.cat(NLL_lossesB)

                # losses on reconstruction
                bw_losses = torch.cat(NLL_lossesA)

                # r1, language model reward
                r1s = torch.FloatTensor(lm_probsB).detach()
                r1s = (r1s - torch.mean(r1s)) / torch.std(r1s)

                # r2, communication reward
                r2s = torch.FloatTensor(lm_probsA).detach()
                r2s = (r2s - torch.mean(r2s)) / torch.std(r2s)

                tb_encodings = torch.cat(tb_encodings).detach()
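                # similarity between each reconstruction's mean encoding and the source
                # encoding; the vectors are not length-normalized, so this is a raw dot
                # product squashed by a sigmoid rather than a true cosine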
                cossim = torch.matmul(tb_encodings, src_encoding)
                cossim = 1 - torch.nn.Sigmoid()(torch.mean(cossim)).item()

                # rab = alpha * cossim + (1 - alpha) * r1
                # rba = beta  * cossim + (1 - beta ) * r2
                rkab = cossim * args.alpha + r1s * (1 - args.alpha)
                rkba = cossim * args.beta + r2s * (1 - args.beta)

                # averaging loss over samples
                A_loss = torch.mean(fw_losses * rkab)
                B_loss = torch.mean(bw_losses * rkba)

                if show_log:
                    for r1, r2, rab, rba, fw_loss, bw_loss in zip(
                            r1s.data.numpy(), r2s.data.numpy(),
                            rkab.data.numpy(), rkba.data.numpy(),
                            fw_losses.data.numpy(), bw_losses.data.numpy()):
                        print(
                            'r1={:7.4f}\t r2={:7.4f}\t rab={:7.4f}\t rba={:7.4f}\t fw_loss={:7.4f}\t bw_loss={:7.4f}'
                            .format(r1, r2, rab, rba, fw_loss, bw_loss))
                    print('A loss = {:.7f} \t B loss = {:.7f}'.format(
                        A_loss.data.numpy().item(),
                        B_loss.data.numpy().item()))
                    f_lossA.write(
                        str(t) +
                        ' {:.7f}\n'.format(A_loss.data.numpy().item()))
                    f_lossB.write(
                        str(t) +
                        ' {:.7f}\n'.format(B_loss.data.numpy().item()))

                optimizerA.zero_grad()
                optimizerB.zero_grad()

                A_loss.backward()
                B_loss.backward()

                optimizerA.step()
                optimizerB.step()

            if t % args.save_n_iter == 0:
                print('\nsaving model')
                models['A'].save('{}.iter{}.bin'.format(args.model[0], t))
                models['B'].save('{}.iter{}.bin'.format(args.model[1], t))
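    # note: the while True loop above never breaks, so these closes are not
    # reached while training is running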
    f_lossA.close()
    f_lossB.close()
Example #4
# Python 2 script: visualize the attention weights saved alongside a decoded test set.
# read_corpus is a helper from the surrounding project and is assumed importable here.
import numpy as np
import matplotlib.pyplot as plt


def visualize(idx, src_sent, tgt_sent, alpha, filename):
    alpha_m = np.stack(alpha, axis=1)
    plt.imshow(alpha_m, cmap='gray', interpolation='nearest')
    print 'type(src_sent) = ', type(src_sent), type(src_sent[0])
    print 'Process {} sents'.format(idx)
    print 'Src = ', ' '.join(src_sent)
    print 'Tgt = ', ' '.join(tgt_sent)
    src_sent = [unicode(s, errors='ignore') for s in src_sent]
    tgt_sent = [unicode(s, errors='ignore') for s in tgt_sent]
    plt.xticks(range(0, len(tgt_sent) - 1), tgt_sent[1:], rotation='vertical')
    plt.yticks(range(0, len(src_sent) - 1), src_sent[1:])
    plt.savefig(filename, bbox_inches='tight')
    plt.close()


filename = 'model.de-en.de3w.en2w.h_to_embed_space.affine_trans.dropout0.5.bin_alpha.npz'
d = np.load(open(filename, 'r'))
alphas = d['alpha']

src_sents = read_corpus('en-de/test.en-de.low.de')
tgt_sents = read_corpus(
    'model.de-en.de3w.en2w.h_to_embed_space.affine_trans.dropout0.5.decode')

assert len(src_sents) == len(tgt_sents), 'src={}, tgt={}'.format(
    len(src_sents), len(tgt_sents))
assert len(alphas) == len(src_sents), 'a={}, src={}'.format(
    len(alphas), len(src_sents))

for idx, a in enumerate(alphas):
    visualize(idx, src_sents[idx], tgt_sents[idx], a,
              'alignment/' + str(idx) + '.png')