Example #1
def evaluate_loss(model, data, crit):
    model.eval()
    cum_loss = 0.
    cum_tgt_words = 0.
    for src_sents, tgt_sents in data_iter(data,
                                          batch_size=args.batch_size,
                                          shuffle=False):
        pred_tgt_word_num = sum(len(s[1:])
                                for s in tgt_sents)  # omitting leading `<s>`
        src_sents_len = [len(s) for s in src_sents]

        src_sents_var = to_input_variable(src_sents,
                                          model.vocab.src,
                                          cuda=args.cuda,
                                          is_test=True)
        tgt_sents_var = to_input_variable(tgt_sents,
                                          model.vocab.tgt,
                                          cuda=args.cuda,
                                          is_test=True)

        # (tgt_sent_len, batch_size, tgt_vocab_size)
        scores = model(src_sents_var, src_sents_len, tgt_sents_var[:-1])
        loss = crit(scores.view(-1, scores.size(2)),
                    tgt_sents_var[1:].view(-1))

        cum_loss += loss.item()
        cum_tgt_words += pred_tgt_word_num

    loss = cum_loss / cum_tgt_words
    return loss
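
evaluate_loss divides an accumulated, summed loss by the number of predicted target words, so the criterion passed in as crit has to return a token-level sum and ignore padding. A minimal sketch of a compatible criterion and of turning the returned loss into perplexity; the '<pad>' vocabulary entry and the variables model, vocab and dev_data are assumptions borrowed from the surrounding examples, not definitions from this code:

import numpy as np
import torch.nn as nn

# hypothetical setup: a summed cross-entropy that ignores padding, so that
# cum_loss / cum_tgt_words in evaluate_loss is a per-word negative log-likelihood
cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=vocab.tgt['<pad>'],
                                         reduction='sum')

dev_loss = evaluate_loss(model, dev_data, cross_entropy_loss)
dev_ppl = np.exp(dev_loss)  # corpus-level perplexity, as used in the training loops below
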
Example #2
def train_raml(args):
    tau = args.temp

    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')
    train_data = list(zip(train_data_src, train_data_tgt))

    dev_data_src = read_corpus(args.dev_src, source='src')
    dev_data_tgt = read_corpus(args.dev_tgt, source='tgt')
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    dev_data = dev_data[:args.dev_limit]

    vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)

    if args.raml_sample_mode == 'pre_sample':
        # dict of (src, [tgt: (sent, prob)])
        print('read in raml training data...', file=sys.stderr, end='')
        begin_time = time.time()
        raml_samples = read_raml_train_data(args.raml_sample_file, temp=tau)
        print('done [%d s].' % (time.time() - begin_time), file=sys.stderr)
    elif args.raml_sample_mode.startswith('hamming_distance'):
        print('sample from hamming distance payoff distribution')
        payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(
            max(len(sent) for sent in train_data_tgt),
            vocab_size=len(vocab.tgt) - 3,
            tau=tau)

    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    report_weighted_loss = cum_weighted_loss = 0
    cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    _info = f"""
        begin RAML training
        - training: {len(train_data)} pairs, {args.train_log_file}
        - test: {len(dev_data)} pairs, every {args.valid_niter} iters, {args.validation_log_file}
        - batch size: {args.batch_size}
        - 1 epoch = {len(train_data)} pairs = {int(len(train_data) / args.batch_size)} iters
        """
    print(_info)

    log_data = {'args': args}

    if args.notify_slack:
        slack.post(f"""
        {_info}
        {args}
        """)

    # smoothing function for BLEU
    sm_func = None
    if args.smooth_bleu:
        sm_func = SmoothingFunction().method3

    with open(args.train_log_file,
              "w") as train_output, open(args.validation_log_file,
                                         "w") as validation_output:

        while True:
            epoch += 1
            for src_sents, tgt_sents in data_iter(train_data,
                                                  batch_size=args.batch_size):
                train_iter += 1

                raml_src_sents = []
                raml_tgt_sents = []
                raml_tgt_weights = []

                if args.raml_sample_mode == 'pre_sample':
                    for src_sent in src_sents:
                        sent = ' '.join(src_sent)
                        tgt_samples_all = raml_samples[sent]
                        # print(f'src_sent: "{sent}", target_samples_all: {len(list(tgt_samples_all))}')
                        if args.sample_size >= len(tgt_samples_all):
                            tgt_samples = tgt_samples_all
                        else:
                            # make sure the ground truth y* is in the samples
                            tgt_samples_id = np.random.choice(
                                range(1, len(tgt_samples_all)),
                                size=args.sample_size - 1,
                                replace=False)
                            tgt_samples = [tgt_samples_all[0]] + [
                                tgt_samples_all[i] for i in tgt_samples_id
                            ]

                        raml_src_sents.extend([src_sent] * len(tgt_samples))
                        raml_tgt_sents.extend(
                            [['<s>'] + sent.split(' ') + ['</s>']
                             for sent, weight in tgt_samples])
                        raml_tgt_weights.extend(
                            [weight for sent, weight in tgt_samples])
                elif args.raml_sample_mode in [
                        'hamming_distance', 'hamming_distance_impt_sample'
                ]:
                    for src_sent, tgt_sent in zip(src_sents, tgt_sents):
                        tgt_samples = []  # make sure the ground truth y* is in the samples
                        # exclude <s>, </s> and the ending period from editable positions
                        tgt_sent_len = len(tgt_sent) - 3
                        tgt_ref_tokens = tgt_sent[1:-1]
                        bleu_scores = []
                        # print('y*: %s' % ' '.join(tgt_sent))
                        # sample edit distances
                        e_samples = np.random.choice(
                            range(tgt_sent_len + 1),
                            p=payoff_prob[tgt_sent_len],
                            size=args.sample_size,
                            replace=True)

                        # make sure the ground truth y* is in the samples
                        if args.raml_bias_groundtruth and 0 not in e_samples:
                            e_samples[0] = 0

                        for i, e in enumerate(e_samples):
                            if e > 0:
                                # sample a new tgt_sent $y$
                                old_word_pos = np.random.choice(
                                    range(1, tgt_sent_len + 1),
                                    size=e,
                                    replace=False)
                                new_words = [
                                    vocab.tgt.id2word[wid]
                                    for wid in np.random.randint(
                                        3, len(vocab.tgt), size=e)
                                ]
                                new_tgt_sent = list(tgt_sent)
                                for pos, word in zip(old_word_pos, new_words):
                                    new_tgt_sent[pos] = word
                            else:
                                new_tgt_sent = list(tgt_sent)

                            # if enable importance sampling, compute bleu score
                            if args.raml_sample_mode == 'hamming_distance_impt_sample':
                                if e > 0:
                                    # remove <s> and </s>
                                    bleu_score = sentence_bleu(
                                        [tgt_ref_tokens],
                                        new_tgt_sent[1:-1],
                                        smoothing_function=sm_func)
                                    bleu_scores.append(bleu_score)
                                else:
                                    bleu_scores.append(1.)

                            # print('y: %s' % ' '.join(new_tgt_sent))
                            tgt_samples.append(new_tgt_sent)

                        # if enable importance sampling, compute importance weight
                        if args.raml_sample_mode == 'hamming_distance_impt_sample':
                            tgt_sample_weights = [
                                math.exp(bleu_score / tau) / math.exp(-e / tau)
                                for e, bleu_score in zip(
                                    e_samples, bleu_scores)
                            ]
                            normalizer = sum(tgt_sample_weights)
                            tgt_sample_weights = [
                                w / normalizer for w in tgt_sample_weights
                            ]
                        else:
                            tgt_sample_weights = [1.] * args.sample_size

                        raml_src_sents.extend([src_sent] * len(tgt_samples))
                        raml_tgt_sents.extend(tgt_samples)
                        raml_tgt_weights.extend(tgt_sample_weights)

                        if args.debug:
                            print('*' * 30)
                            print('Target: %s' % ' '.join(tgt_sent))
                            for tgt_sample, e, bleu_score, weight in zip(
                                    tgt_samples, e_samples, bleu_scores,
                                    tgt_sample_weights):
                                print(
                                    'Sample: %s ||| e: %d ||| bleu: %f ||| weight: %f'
                                    % (' '.join(tgt_sample), e, bleu_score,
                                       weight))
                            print()
                            break

                src_sents_var = to_input_variable(raml_src_sents,
                                                  vocab.src,
                                                  cuda=args.cuda)
                tgt_sents_var = to_input_variable(raml_tgt_sents,
                                                  vocab.tgt,
                                                  cuda=args.cuda)
                weights_var = Variable(torch.FloatTensor(raml_tgt_weights),
                                       requires_grad=False)
                if args.cuda:
                    weights_var = weights_var.cuda()

                # batch_size = args.batch_size * args.sample_size
                batch_size = len(raml_src_sents)
                src_sents_len = [len(s) for s in raml_src_sents]
                pred_tgt_word_num = sum(
                    len(s[1:]) for s in raml_tgt_sents)  # omitting leading `<s>`
                optimizer.zero_grad()

                # (tgt_sent_len, batch_size, tgt_vocab_size)
                scores = model(src_sents_var, src_sents_len,
                               tgt_sents_var[:-1])
                # (tgt_sent_len * batch_size, tgt_vocab_size)
                log_scores = F.log_softmax(
                    scores.view(-1, scores.size(2)), dim=-1)
                # remove leading <s> in tgt sent, which is not used as the target
                flattened_tgt_sents = tgt_sents_var[1:].view(-1)

                # batch_size * tgt_sent_len
                tgt_log_scores = torch.gather(
                    log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1)
                unweighted_loss = -tgt_log_scores * (
                    1. - torch.eq(flattened_tgt_sents, 0).float())
                weighted_loss = unweighted_loss * weights_var.repeat(
                    scores.size(0))
                weighted_loss = weighted_loss.sum()
                weighted_loss_val = weighted_loss.item()
                nll_loss_val = unweighted_loss.sum().item()
                # weighted_log_scores = log_scores * weights.view(-1, scores.size(2))
                # weighted_loss = nll_loss(weighted_log_scores, flattened_tgt_sents)

                loss = weighted_loss / batch_size
                # nll_loss_val = nll_loss(log_scores, flattened_tgt_sents).item()

                loss.backward()
                # clip gradient
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_grad)
                optimizer.step()

                report_weighted_loss += weighted_loss_val
                cum_weighted_loss += weighted_loss_val
                report_loss += nll_loss_val
                cum_loss += nll_loss_val
                report_tgt_words += pred_tgt_word_num
                cum_tgt_words += pred_tgt_word_num
                report_examples += batch_size
                cum_examples += batch_size
                cum_batches += batch_size

                if train_iter % args.log_every == 0 or train_iter % args.notify_slack_every == 0:
                    _log = 'epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (
                        epoch, train_iter,
                        report_weighted_loss / report_examples,
                        np.exp(report_loss / report_tgt_words), cum_examples,
                        report_tgt_words /
                        (time.time() - train_time), time.time() - begin_time)
                    print(_log)
                    print(_log, file=train_output)

                    _list_dict_update(
                        log_data, {
                            'epoch': epoch,
                            'train_iter': train_iter,
                            'loss': report_weighted_loss / report_examples,
                            'ppl': np.exp(report_loss / report_tgt_words),
                            'examples': cum_examples,
                            'speed': report_tgt_words /
                            (time.time() - train_time),
                            'elapsed': time.time() - begin_time
                        }, 'train')

                    train_time = time.time()
                    report_loss = report_weighted_loss = report_tgt_words = report_examples = 0.
                    if train_iter % args.notify_slack_every == 0 and args.notify_slack:
                        print('post slack')
                        slack.post(_log)

                # perform validation
                if train_iter % args.valid_niter == 0:
                    print('epoch %d, iter %d, cum. loss %.2f, '
                          'cum. ppl %.2f cum. examples %d' %
                          (epoch, train_iter, cum_weighted_loss / cum_batches,
                           np.exp(cum_loss / cum_tgt_words), cum_examples),
                          file=sys.stderr)

                    cum_loss = cum_weighted_loss = cum_batches = cum_tgt_words = 0.
                    valid_num += 1

                    print('begin validation ...')
                    model.eval()

                    # compute dev. ppl and bleu

                    dev_loss = evaluate_loss(model, dev_data,
                                             cross_entropy_loss)
                    dev_ppl = np.exp(dev_loss)

                    if args.valid_metric in ['bleu', 'word_acc', 'sent_acc']:
                        dev_hyps = decode(model,
                                          dev_data,
                                          f=validation_output,
                                          verbose=False)
                        dev_hyps = [hyps[0] for hyps in dev_hyps]
                        if args.valid_metric == 'bleu':
                            valid_metric = get_bleu(
                                [tgt for src, tgt in dev_data], dev_hyps)
                        else:
                            valid_metric = get_acc(
                                [tgt for src, tgt in dev_data],
                                dev_hyps,
                                acc_type=args.valid_metric)
                        _log = 'validation: iter %d, dev. ppl %f, dev. %s %f' % (
                            train_iter, dev_ppl, args.valid_metric,
                            valid_metric)
                        print(_log)
                        print(_log, file=validation_output)
                        if args.notify_slack:
                            slack.post(_log)

                    else:
                        valid_metric = -dev_ppl
                        dev_hyps = []  # no decoding when validating on ppl alone
                        print('validation: iter %d, dev. ppl %f' %
                              (train_iter, dev_ppl),
                              file=sys.stderr)

                    if 'dev_data' not in log_data:
                        log_data['dev_data'] = dev_data

                    _list_dict_update(log_data, {
                        'epoch': epoch,
                        'train_iter': train_iter,
                        'loss': dev_loss,
                        'ppl': dev_ppl,
                        args.valid_metric: valid_metric,
                        'hyps': dev_hyps,
                    },
                                      'validation',
                                      is_save=True)

                    model.train()

                    is_better = len(hist_valid_scores) == 0 or \
                        valid_metric > max(hist_valid_scores)
                    is_better_than_last = len(hist_valid_scores) == 0 or \
                        valid_metric > hist_valid_scores[-1]
                    hist_valid_scores.append(valid_metric)

                    if valid_num > args.save_model_after:
                        model_file = args.save_to + '.iter%d.bin' % train_iter
                        print('save model to [%s]' % model_file)
                        model.save(model_file)

                    if (not is_better_than_last) and args.lr_decay:
                        lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                        print('decay learning rate to %f' % lr)
                        optimizer.param_groups[0]['lr'] = lr

                    if is_better:
                        patience = 0
                        best_model_iter = train_iter

                        if valid_num > args.save_model_after:
                            print('save the current best model ..')
                            model_file_abs_path = os.path.abspath(model_file)
                            symlink_file_abs_path = os.path.abspath(
                                args.save_to + '.bin')
                            os.system(
                                'ln -sf %s %s' %
                                (model_file_abs_path, symlink_file_abs_path))
                    else:
                        patience += 1
                        print('hit patience %d' % patience)
                        if patience == args.patience:
                            _log = f"""
                            {'hit patience %d' % patience}
                            early stop!
                            {'the best model is from iteration [%d]' % best_model_iter}
                            """
                            print(_log)
                            if args.notify_slack:
                                slack.post(_log)
                            exit(0)

                if args.debug:
                    print(f'debug epoch:{epoch} exit')
                    model_file = args.save_to + '.bin'
                    print('save model to [%s]' % model_file)
                    model.save(model_file)
                    exit(0)
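
train_raml samples edit distances from payoff_prob, which comes from generate_hamming_distance_payoff_distribution; that helper is not part of this excerpt. A minimal sketch of one plausible implementation, assuming the RAML payoff over Hamming distances q(e) proportional to C(m, e) * (vocab_size - 1)**e * exp(-e / tau) for a sentence with m editable tokens (the exact weighting used by the original helper may differ). It returns, per sentence length, a probability list of length m + 1, matching how payoff_prob[tgt_sent_len] is indexed above:

import math

def generate_hamming_distance_payoff_distribution(max_sent_len, vocab_size, tau=1.):
    """Hypothetical sketch: distribution over Hamming distances e in {0, ..., m}
    for every sentence length m up to max_sent_len."""
    payoff_prob, Z_qs = dict(), dict()
    for m in range(1, max_sent_len + 1):
        # log C(m, e) + e * log(vocab_size - 1) - e / tau, computed in log space
        # to avoid overflow for long sentences
        log_counts = [
            math.lgamma(m + 1) - math.lgamma(e + 1) - math.lgamma(m - e + 1) +
            e * math.log(vocab_size - 1) - e / tau for e in range(m + 1)
        ]
        max_log = max(log_counts)
        counts = [math.exp(lc - max_log) for lc in log_counts]
        Z = sum(counts)
        payoff_prob[m] = [c / Z for c in counts]
        Z_qs[m] = Z * math.exp(max_log)
    return payoff_prob, Z_qs
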
Example #3
def train(args):
    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')

    dev_data_src = read_corpus(args.dev_src, source='src')
    dev_data_tgt = read_corpus(args.dev_tgt, source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    dev_data = dev_data[:args.dev_limit]

    vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)

    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    _info = f"""
        begin Maximum Likelihood training
        - training: {len(train_data)} pairs, {args.train_log_file}
        - test: {len(dev_data)} pairs, every {args.valid_niter} iters, {args.validation_log_file}
        - batch size: {args.batch_size}
        - 1 epoch = {len(train_data)} pairs = {int(len(train_data) / args.batch_size)} iters
        """
    print(_info)

    if args.notify_slack:
        slack.post(f"""
        {_info}
        {args}
        """)

    with open(args.train_log_file,
              "w") as train_output, open(args.validation_log_file,
                                         "w") as validation_output:

        while True:
            epoch += 1
            for src_sents, tgt_sents in data_iter(train_data,
                                                  batch_size=args.batch_size):
                train_iter += 1

                src_sents_var = to_input_variable(src_sents,
                                                  vocab.src,
                                                  cuda=args.cuda)
                tgt_sents_var = to_input_variable(tgt_sents,
                                                  vocab.tgt,
                                                  cuda=args.cuda)

                batch_size = len(src_sents)
                src_sents_len = [len(s) for s in src_sents]
                pred_tgt_word_num = sum(
                    len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`

                optimizer.zero_grad()

                # (tgt_sent_len, batch_size, tgt_vocab_size)
                scores = model(src_sents_var, src_sents_len,
                               tgt_sents_var[:-1])

                word_loss = cross_entropy_loss(scores.view(-1, scores.size(2)),
                                               tgt_sents_var[1:].view(-1))
                loss = word_loss / batch_size
                word_loss_val = word_loss.item()
                loss_val = loss.item()

                loss.backward()
                # clip gradient
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_grad)
                optimizer.step()

                report_loss += word_loss_val
                cum_loss += word_loss_val
                report_tgt_words += pred_tgt_word_num
                cum_tgt_words += pred_tgt_word_num
                report_examples += batch_size
                cum_examples += batch_size
                cum_batches += batch_size

                if train_iter % args.log_every == 0:
                    _log = 'epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \
                           'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (
                               epoch, train_iter,
                               report_loss / report_examples,
                               np.exp(report_loss / report_tgt_words),
                               cum_examples,
                               report_tgt_words / (time.time() - train_time),
                               time.time() - begin_time)
                    print(_log)
                    print(_log, file=train_output)

                    train_time = time.time()
                    report_loss = report_tgt_words = report_examples = 0.

                # perform validation
                if train_iter % args.valid_niter == 0:
                    print(
                        'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                        % (epoch, train_iter, cum_loss / cum_batches,
                           np.exp(cum_loss / cum_tgt_words), cum_examples),
                        file=sys.stderr)

                    cum_loss = cum_batches = cum_tgt_words = 0.
                    valid_num += 1

                    print('begin validation ...', file=sys.stderr)
                    model.eval()

                    # compute dev. ppl and bleu

                    dev_loss = evaluate_loss(model, dev_data,
                                             cross_entropy_loss)
                    dev_ppl = np.exp(dev_loss)

                    if args.valid_metric in ['bleu', 'word_acc', 'sent_acc']:
                        dev_hyps = decode(model, dev_data)
                        dev_hyps = [hyps[0] for hyps in dev_hyps]
                        if args.valid_metric == 'bleu':
                            valid_metric = get_bleu(
                                [tgt for src, tgt in dev_data], dev_hyps)
                        else:
                            valid_metric = get_acc(
                                [tgt for src, tgt in dev_data],
                                dev_hyps,
                                acc_type=args.valid_metric)
                        _log = 'validation: iter %d, dev. ppl %f, dev. %s %f' % (
                            train_iter, dev_ppl, args.valid_metric,
                            valid_metric)
                        print(_log, file=sys.stderr)
                        print(_log, file=validation_output)
                        if args.notify_slack:
                            slack.post(_log)

                    else:
                        valid_metric = -dev_ppl
                        print('validation: iter %d, dev. ppl %f' %
                              (train_iter, dev_ppl),
                              file=sys.stderr)

                    model.train()

                    is_better = len(hist_valid_scores) == 0 or \
                        valid_metric > max(hist_valid_scores)
                    is_better_than_last = len(hist_valid_scores) == 0 or \
                        valid_metric > hist_valid_scores[-1]
                    hist_valid_scores.append(valid_metric)

                    if valid_num > args.save_model_after:
                        model_file = args.save_to + '.iter%d.bin' % train_iter
                        print('save model to [%s]' % model_file,
                              file=sys.stderr)
                        model.save(model_file)

                    if (not is_better_than_last) and args.lr_decay:
                        lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                        print('decay learning rate to %f' % lr,
                              file=sys.stderr)
                        optimizer.param_groups[0]['lr'] = lr

                    if is_better:
                        patience = 0
                        best_model_iter = train_iter

                        if valid_num > args.save_model_after:
                            print('save the current best model ..',
                                  file=sys.stderr)
                            model_file_abs_path = os.path.abspath(model_file)
                            symlink_file_abs_path = os.path.abspath(
                                args.save_to + '.bin')
                            os.system(
                                'ln -sf %s %s' %
                                (model_file_abs_path, symlink_file_abs_path))
                    else:
                        patience += 1
                        print('hit patience %d' % patience, file=sys.stderr)
                        if patience == args.patience:
                            print('early stop!', file=sys.stderr)
                            print('the best model is from iteration [%d]' %
                                  best_model_iter,
                                  file=sys.stderr)
                            exit(0)
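
The training and evaluation loops above all consume batches from data_iter, which is not included in this excerpt. A minimal sketch under the assumption that it shuffles the corpus and sorts each batch by source length in descending order (a common requirement when the encoder packs padded sequences); treat the sorting detail as a guess rather than a description of the original helper:

import numpy as np

def data_iter(data, batch_size, shuffle=True):
    """Hypothetical sketch: yield (src_sents, tgt_sents) batches of (src, tgt) pairs."""
    index_array = list(range(len(data)))
    if shuffle:
        np.random.shuffle(index_array)
    for i in range(0, len(data), batch_size):
        batch = [data[idx] for idx in index_array[i:i + batch_size]]
        batch.sort(key=lambda pair: len(pair[0]), reverse=True)
        src_sents = [pair[0] for pair in batch]
        tgt_sents = [pair[1] for pair in batch]
        yield src_sents, tgt_sents
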
Example #4
def compute_lm_prob(args):
    """
    given source-target sentence pairs, compute ppl and log-likelihood
    """
    test_data_src = read_corpus(args.test_src, source='src')
    test_data_tgt = read_corpus(args.test_tgt, source='tgt')
    test_data = zip(test_data_src, test_data_tgt)

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model,
                            map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        saved_args = params['args']
        state_dict = params['state_dict']

        model = NMT(saved_args, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    f = open(args.save_to_file, 'w')
    for src_sent, tgt_sent in test_data:
        src_sents = [src_sent]
        tgt_sents = [tgt_sent]

        batch_size = len(src_sents)
        src_sents_len = [len(s) for s in src_sents]
        pred_tgt_word_nums = [len(s[1:])
                              for s in tgt_sents]  # omitting leading `<s>`

        # (sent_len, batch_size)
        src_sents_var = to_input_variable(src_sents,
                                          model.vocab.src,
                                          cuda=args.cuda,
                                          is_test=True)
        tgt_sents_var = to_input_variable(tgt_sents,
                                          model.vocab.tgt,
                                          cuda=args.cuda,
                                          is_test=True)

        # (tgt_sent_len, batch_size, tgt_vocab_size)
        scores = model(src_sents_var, src_sents_len, tgt_sents_var[:-1])
        # (tgt_sent_len * batch_size, tgt_vocab_size)
        log_scores = F.log_softmax(scores.view(-1, scores.size(2)), dim=-1)
        # remove leading <s> in tgt sent, which is not used as the target
        # (batch_size * tgt_sent_len)
        flattened_tgt_sents = tgt_sents_var[1:].view(-1)
        # (batch_size * tgt_sent_len)
        tgt_log_scores = torch.gather(
            log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1)
        # 0-index is the <pad> symbol
        tgt_log_scores = tgt_log_scores * (
            1. - torch.eq(flattened_tgt_sents, 0).float())
        # (tgt_sent_len, batch_size)
        tgt_log_scores = tgt_log_scores.view(-1, batch_size)  # .permute(1, 0)
        # (batch_size,)
        tgt_sent_scores = tgt_log_scores.sum(dim=0)  # keep 1-D so indexing below works
        tgt_sent_word_scores = [
            tgt_sent_scores[i].item() / pred_tgt_word_nums[i]
            for i in range(batch_size)
        ]

        for src_sent, tgt_sent, score in zip(src_sents, tgt_sents,
                                             tgt_sent_word_scores):
            f.write('%s ||| %s ||| %f\n' %
                    (' '.join(src_sent), ' '.join(tgt_sent), score))

    f.close()
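
Every example above converts token lists to tensors with to_input_variable, which is also not shown. A minimal sketch, assuming the vocabulary maps '<pad>' to id 0 (the masking above multiplies by 1 - eq(target, 0)) and that the result uses the (sent_len, batch_size) layout the shape comments describe; is_test is kept only for call compatibility:

import torch
from torch.autograd import Variable

def to_input_variable(sents, vocab, cuda=False, is_test=False):
    """Hypothetical sketch: map tokens to ids, right-pad with vocab['<pad>'],
    and return a (sent_len, batch_size) LongTensor wrapped in a Variable."""
    pad_id = vocab['<pad>']
    word_ids = [[vocab[w] for w in sent] for sent in sents]
    max_len = max(len(s) for s in word_ids)
    padded = [s + [pad_id] * (max_len - len(s)) for s in word_ids]
    tensor = torch.LongTensor(padded).t().contiguous()  # (sent_len, batch_size)
    if cuda:
        tensor = tensor.cuda()
    return Variable(tensor, requires_grad=False)
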
Example #5
    def sample(self, src_sents, sample_size=None, to_word=False):
        if not isinstance(src_sents[0], list):
            src_sents = [src_sents]
        if not sample_size:
            sample_size = self.args.sample_size

        src_sents_num = len(src_sents)
        batch_size = src_sents_num * sample_size

        src_sents_var = to_input_variable(src_sents,
                                          self.vocab.src,
                                          cuda=self.args.cuda,
                                          is_test=True)
        src_encoding, (dec_init_state, dec_init_cell) = self.encode(
            src_sents_var, [len(s) for s in src_sents])

        dec_init_state = dec_init_state.repeat(sample_size, 1)
        dec_init_cell = dec_init_cell.repeat(sample_size, 1)
        hidden = (dec_init_state, dec_init_cell)

        # tile everything
        # if args.sample_method == 'expand':
        #     # src_enc: (src_sent_len, sample_size, enc_size)
        #     # cat result: (src_sent_len, batch_size * sample_size, enc_size)
        #     src_encoding = torch.cat([src_enc.expand(src_enc.size(0), sample_size, src_enc.size(2)) for src_enc in src_encoding.split(1, dim=1)], 1)
        #     dec_init_state = torch.cat([x.expand(sample_size, x.size(1)) for x in dec_init_state.split(1, dim=0)], 0)
        #     dec_init_cell = torch.cat([x.expand(sample_size, x.size(1)) for x in dec_init_cell.split(1, dim=0)], 0)
        # elif args.sample_method == 'repeat':

        src_encoding = src_encoding.repeat(1, sample_size, 1)
        src_encoding_att_linear = tensor_transform(self.att_src_linear,
                                                   src_encoding)
        src_encoding = src_encoding.permute(1, 0, 2)
        src_encoding_att_linear = src_encoding_att_linear.permute(1, 0, 2)

        new_tensor = dec_init_state.data.new
        att_tm1 = Variable(new_tensor(batch_size,
                                      self.args.hidden_size).zero_(),
                           volatile=True)
        y_0 = Variable(torch.LongTensor(
            [self.vocab.tgt['<s>'] for _ in range(batch_size)]),
                       volatile=True)

        eos = self.vocab.tgt['</s>']
        # eos_batch = torch.LongTensor([eos] * batch_size)
        sample_ends = torch.ByteTensor([0] * batch_size)
        all_ones = torch.ByteTensor([1] * batch_size)
        if self.args.cuda:
            y_0 = y_0.cuda()
            sample_ends = sample_ends.cuda()
            all_ones = all_ones.cuda()

        samples = [y_0]

        t = 0
        while t < self.args.decode_max_time_step:
            t += 1

            # (sample_size)
            y_tm1 = samples[-1]

            y_tm1_embed = self.tgt_embed(y_tm1)

            x = torch.cat([y_tm1_embed, att_tm1], 1)

            # h_t: (batch_size, hidden_size)
            h_t, cell_t = self.decoder_lstm(x, hidden)
            h_t = self.dropout(h_t)

            ctx_t, alpha_t = self.dot_prod_attention(h_t, src_encoding,
                                                     src_encoding_att_linear)

            att_t = torch.tanh(self.att_vec_linear(torch.cat([h_t, ctx_t],
                                                             1)))  # E.q. (5)
            att_t = self.dropout(att_t)

            score_t = self.readout(att_t)  # E.q. (6)
            p_t = F.softmax(score_t, dim=-1)

            if self.args.sample_method == 'random':
                y_t = torch.multinomial(p_t, num_samples=1).squeeze(1)
            elif self.args.sample_method == 'greedy':
                _, y_t = torch.topk(p_t, k=1, dim=1)
                y_t = y_t.squeeze(1)

            samples.append(y_t)

            sample_ends |= torch.eq(y_t, eos).byte().data
            if torch.equal(sample_ends, all_ones):
                break

            # if torch.equal(y_t.data, eos_batch):
            #     break

            att_tm1 = att_t
            hidden = h_t, cell_t

        # post-processing
        completed_samples = [
            [[] for _ in range(sample_size)] for _ in range(src_sents_num)
        ]
        for y_t in samples:
            for i, sampled_word in enumerate(y_t.cpu().data):
                src_sent_id = i % src_sents_num
                sample_id = i // src_sents_num
                if len(
                        completed_samples[src_sent_id][sample_id]
                ) == 0 or completed_samples[src_sent_id][sample_id][-1] != eos:
                    completed_samples[src_sent_id][sample_id].append(
                        sampled_word)

        if to_word:
            for i, src_sent_samples in enumerate(completed_samples):
                completed_samples[i] = word2id(src_sent_samples,
                                               self.vocab.tgt.id2word)

        return completed_samples
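
A hedged usage sketch for sample (the source tokens below are made up): for each source sentence it returns sample_size decoded sequences, converted to words when to_word=True; each sequence starts with the initial <s> and stops after the first </s>.

# hypothetical usage: draw 5 random decodes for one tokenized source sentence
src_sent = ['we', 'have', 'a', 'small', 'example', '.']
samples = model.sample(src_sent, sample_size=5, to_word=True)
for words in samples[0]:  # samples[0] holds all samples for the single source sentence
    print(' '.join(words[1:]))  # drop the leading <s>
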
Example #6
    def translate(self, src_sents, beam_size=None, to_word=True):
        """
        perform beam search
        TODO: batched beam search
        """
        if not isinstance(src_sents[0], list):
            src_sents = [src_sents]
        if not beam_size:
            beam_size = self.args.beam_size

        src_sents_var = to_input_variable(src_sents,
                                          self.vocab.src,
                                          cuda=self.args.cuda,
                                          is_test=True)

        src_encoding, dec_init_vec = self.encode(src_sents_var,
                                                 [len(src_sents[0])])
        src_encoding_att_linear = tensor_transform(self.att_src_linear,
                                                   src_encoding)

        init_state = dec_init_vec[0]
        init_cell = dec_init_vec[1]
        hidden = (init_state, init_cell)

        att_tm1 = Variable(torch.zeros(1, self.args.hidden_size),
                           volatile=True)
        hyp_scores = Variable(torch.zeros(1), volatile=True)
        if self.args.cuda:
            att_tm1 = att_tm1.cuda()
            hyp_scores = hyp_scores.cuda()

        eos_id = self.vocab.tgt['</s>']
        bos_id = self.vocab.tgt['<s>']
        tgt_vocab_size = len(self.vocab.tgt)

        hypotheses = [[bos_id]]
        completed_hypotheses = []
        completed_hypothesis_scores = []

        t = 0
        while len(completed_hypotheses) < beam_size and \
                t < self.args.decode_max_time_step:
            t += 1
            hyp_num = len(hypotheses)

            expanded_src_encoding = src_encoding.expand(
                src_encoding.size(0), hyp_num, src_encoding.size(2))
            expanded_src_encoding_att_linear = src_encoding_att_linear.expand(
                src_encoding_att_linear.size(0), hyp_num,
                src_encoding_att_linear.size(2))

            y_tm1 = Variable(torch.LongTensor([hyp[-1] for hyp in hypotheses]),
                             volatile=True)
            if self.args.cuda:
                y_tm1 = y_tm1.cuda()

            y_tm1_embed = self.tgt_embed(y_tm1)

            x = torch.cat([y_tm1_embed, att_tm1], 1)

            # h_t: (hyp_num, hidden_size)
            h_t, cell_t = self.decoder_lstm(x, hidden)
            h_t = self.dropout(h_t)

            ctx_t, alpha_t = self.dot_prod_attention(
                h_t, expanded_src_encoding.permute(1, 0, 2),
                expanded_src_encoding_att_linear.permute(1, 0, 2))

            att_t = torch.tanh(self.att_vec_linear(torch.cat([h_t, ctx_t], 1)))
            att_t = self.dropout(att_t)

            score_t = self.readout(att_t)
            p_t = F.log_softmax(score_t, dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) +
                              p_t).view(-1)
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                             k=live_hyp_num)
            prev_hyp_ids = top_new_hyp_pos // tgt_vocab_size
            word_ids = top_new_hyp_pos % tgt_vocab_size
            # new_hyp_scores = new_hyp_scores[top_new_hyp_pos.data]

            new_hypotheses = []

            live_hyp_ids = []
            new_hyp_scores = []
            for prev_hyp_id, word_id, new_hyp_score in zip(
                    prev_hyp_ids.cpu().data,
                    word_ids.cpu().data,
                    top_new_hyp_scores.cpu().data):
                hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id]
                if word_id == eos_id:
                    completed_hypotheses.append(hyp_tgt_words)
                    completed_hypothesis_scores.append(new_hyp_score)
                else:
                    new_hypotheses.append(hyp_tgt_words)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = torch.LongTensor(live_hyp_ids)
            if self.args.cuda:
                live_hyp_ids = live_hyp_ids.cuda()

            hidden = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hyp_scores = Variable(
                torch.FloatTensor(new_hyp_scores),
                volatile=True)  # new_hyp_scores[live_hyp_ids]
            if self.args.cuda:
                hyp_scores = hyp_scores.cuda()
            hypotheses = new_hypotheses

        if len(completed_hypotheses) == 0:
            completed_hypotheses = [hypotheses[0]]
            completed_hypothesis_scores = [0.0]

        if to_word:
            for i, hyp in enumerate(completed_hypotheses):
                completed_hypotheses[i] = [
                    self.vocab.tgt.id2word[int(w)] for w in hyp
                ]

        ranked_hypotheses = sorted(zip(completed_hypotheses,
                                       completed_hypothesis_scores),
                                   key=lambda x: x[1],
                                   reverse=True)

        return [hyp for hyp, score in ranked_hypotheses]
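
A hedged usage sketch for translate (hypothetical tokens): it beam-searches a single tokenized source sentence and returns hypotheses sorted by score, each normally wrapped in <s> ... </s>.

# hypothetical usage: beam-search translation of one tokenized source sentence
hyps = model.translate(['a', 'tiny', 'example', '.'], beam_size=5, to_word=True)
best = ' '.join(hyps[0][1:-1])  # best-scoring hypothesis, without <s> and </s>
print(best)
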