示例#1
0
class Translator(object):
    """Beam-decode every sentence of an input file with a saved PyTorch model.

    Creating an instance performs the whole job: the constructor validates
    the paths, builds a ``DataManager`` and then calls :meth:`translate`,
    which writes ``<input_file>.best_trans`` and ``<input_file>.beam_trans``.
    """

    def __init__(self, args):
        """Read the configuration, validate the paths and start translating.

        Args:
            args: object exposing ``proto`` (name of a factory looked up in
                ``configurations``), ``input_file`` and ``model_file``.

        Raises:
            ValueError: if either path is ``None`` or does not exist.
        """
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file

        # The None test comes first so os.path.exists never receives None.
        for path in (self.input_file, self.model_file):
            if path is None or not os.path.exists(path):
                raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        self.translate()

    def ids_to_trans(self, trans_ids):
        """Turn a sequence of target ids into text, stopping at ``ac.EOS_ID``.

        Args:
            trans_ids: iterable of target-vocabulary ids.

        Returns:
            ``(text, ids)``: the space-joined words and the ids that were
            kept. The EOS id itself and everything after it are dropped.
        """
        words = []
        word_ids = []
        for word_idx in trans_ids:
            if word_idx == ac.EOS_ID:
                break
            words.append(self.data_manager.trg_ivocab[word_idx])
            word_ids.append(word_idx)

        return u' '.join(words), word_ids

    def get_trans(self, probs, scores, symbols):
        """Format all beam hypotheses of one sentence, best score first.

        Args:
            probs: per-hypothesis values, indexable by hypothesis row.
            scores: per-hypothesis scores; rows are ranked by descending score.
            symbols: per-hypothesis id sequences (``symbols[r]`` is one row).

        Returns:
            ``(best_trans, best_tran_ids, beam_text)`` where ``beam_text``
            holds one ``"<translation> <score> <prob>"`` line per hypothesis.
            ``best_trans`` and ``best_tran_ids`` are ``None`` when there are
            no hypotheses.
        """
        sorted_rows = numpy.argsort(scores)[::-1]
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for rank, r in enumerate(sorted_rows):
            trans_out, trans_out_ids = self.ids_to_trans(symbols[r])
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(
                trans_out, scores[r], probs[r]))
            if rank == 0:  # highest-scoring hypothesis
                best_trans = trans_out
                best_tran_ids = trans_out_ids

        return best_trans, best_tran_ids, u'\n'.join(beam_trans)

    def plot_head_map(self, mma, target_labels, target_ids, source_labels,
                      source_ids, filename):
        """Save a heat map of the matrix ``mma`` to ``filename``.

        Adapted from
        https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in the ``family`` argument below. If the system font
        is not used, delete the matplotlib font cache, see
        https://github.com/matplotlib/matplotlib/issues/3590

        Args:
            mma: 2-D array; rows are labelled with ``target_labels`` and
                columns with ``source_labels``.
            target_labels: row tick labels.
            target_ids: ids matching ``target_labels``; a label whose id is
                ``ac.UNK_ID`` is drawn in blue.
            source_labels: column tick labels.
            source_ids: ids matching ``source_labels``; same colouring rule.
            filename: path of the image file to write.
        """
        _, ax = plt.subplots()
        ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels,
                           minor=False,
                           family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels
        ax.set_yticklabels(target_labels,
                           minor=False,
                           family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def translate(self):
        """Restore the model, decode the input file and write both outputs.

        ``<input_file>.best_trans`` receives one best translation per line and
        ``<input_file>.beam_trans`` receives the full beam of every sentence,
        sentences separated by a blank line. Only non-blank input lines are
        counted, so the outputs have one slot per non-blank line.
        """
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device_name)
        model = Model(self.config).to(device)
        self.logger.info('Restore model from {}'.format(self.model_file))
        # map_location remaps the stored tensors onto the device chosen above,
        # so a checkpoint saved on a GPU machine can still be restored when
        # only the CPU is available.
        state_dict = torch.load(self.model_file, map_location=device_name)
        model.load_state_dict(state_dict)
        model.eval()

        best_trans_file = self.input_file + '.best_trans'
        beam_trans_file = self.input_file + '.beam_trans'
        # Truncate up front: stale results are cleared and an unwritable path
        # fails before any decoding time is spent.
        open(best_trans_file, 'w').close()
        open(beam_trans_file, 'w').close()

        with open(self.input_file, 'r') as f:
            num_sents = sum(1 for line in f if line.strip())
        all_best_trans = [''] * num_sents
        all_beam_trans = [''] * num_sents

        with torch.no_grad():
            self.logger.info('Start translating {}'.format(self.input_file))
            start = time.time()
            count = 0
            for (src_toks, original_idxs) in self.data_manager.get_trans_input(
                    self.input_file):
                rets = model.beam_decode(src_toks.to(device))

                for i, ret in enumerate(rets):
                    probs = ret['probs'].cpu().detach().numpy().reshape([-1])
                    scores = ret['scores'].cpu().detach().numpy().reshape([-1])
                    symbols = ret['symbols'].cpu().detach().numpy()

                    best_trans, _, beam_trans = self.get_trans(
                        probs, scores, symbols)
                    # original_idxs[i] is the output slot of the i-th result,
                    # so results land in the position the data manager
                    # reports regardless of the order they are decoded in.
                    slot = original_idxs[i]
                    all_best_trans[slot] = best_trans + '\n'
                    all_beam_trans[slot] = beam_trans + '\n\n'

                    count += 1
                    if count % 100 == 0:
                        self.logger.info(
                            '  Translating line {}, average {} seconds/sent'.
                            format(count, (time.time() - start) / count))

        model.train()

        with open(best_trans_file, 'w') as ftrans, open(beam_trans_file,
                                                        'w') as btrans:
            ftrans.write(''.join(all_best_trans))
            btrans.write(''.join(all_beam_trans))

        self.logger.info('Done translating {}, it takes {} minutes'.format(
            self.input_file,
            float(time.time() - start) / 60.0))
class Translator(object):
    """Beam-decode an input file with a saved TensorFlow model.

    Creating an instance performs the whole job: the constructor validates
    the paths, loads the vocabularies and then calls :meth:`translate`, which
    writes ``<input_file>.best_trans`` and ``<input_file>.beam_trans`` and,
    when ``plot_align`` is set, one alignment image per input.
    """

    def __init__(self, args):
        """Read the configuration, validate the paths and start translating.

        Args:
            args: object exposing ``proto`` (name of a factory looked up in
                ``configurations``), ``input_file``, ``model_file``,
                ``plot_align`` and ``unk_repl``.

        Raises:
            ValueError: if a path is ``None``, the input file is missing, or
                ``<model_file>.meta`` is missing.
        """
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align
        self.unk_repl = args.unk_repl

        if (self.input_file is None or self.model_file is None
                or not os.path.exists(self.input_file)
                or not os.path.exists(self.model_file + '.meta')):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()

    def ids_to_trans(self, trans_ids, trans_alignments, no_unk_src_toks):
        """Turn target ids into text, optionally replacing unknown words.

        Args:
            trans_ids: sequence of target-vocabulary ids.
            trans_alignments: one attention vector per position of
                ``trans_ids``; only read when ``self.unk_repl`` is true.
            no_unk_src_toks: source tokens indexed by attention position.

        Returns:
            ``(text, ids)``: the space-joined words and the ids that were
            kept. Decoding stops at ``ac.EOS_ID``, which is not included.
        """
        words = []
        word_ids = []
        for idx, word_idx in enumerate(trans_ids):
            if word_idx == ac.EOS_ID:
                break

            if self.unk_repl and word_idx == ac.UNK_ID:
                # Replace UNK with the source token that received the
                # highest attention at this target position.
                highest_att_src_tok_pos = numpy.argmax(trans_alignments[idx])
                words.append(no_unk_src_toks[highest_att_src_tok_pos])
            else:
                words.append(self.trg_ivocab[word_idx])
            word_ids.append(word_idx)

        return u' '.join(words), word_ids

    def get_model(self, mode):
        """Build a ``Model`` for ``mode`` inside the configured variable scope.

        Variables are reused for every mode other than ``ac.TRAINING``, so
        the training-mode model has to be built first.
        """
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'], reuse=reuse, initializer=initializer):
            return Model(self.config, mode)

    def get_trans(self, probs, scores, symbols, parents, alignments, no_unk_src_toks):
        """Back-trace every beam row and format the hypotheses, best first.

        Rows are ranked by the score in their last column. Each hypothesis is
        rebuilt by walking from the last column to column 0, following
        ``parents[row, col]`` to find the row to read in the previous column.

        Args:
            probs, scores, symbols, parents: 2-D arrays indexed
                ``[row, column]``.
            alignments: 3-D array indexed ``[row, column, source position]``.
            no_unk_src_toks: source tokens used for UNK replacement.

        Returns:
            ``(best_trans, best_tran_ids, beam_text, best_trans_alignments)``.
            ``beam_text`` has one ``"<translation> <score> <prob>"`` line per
            hypothesis. The alignments of the best hypothesis are in target
            order and each vector is flipped when ``self.reverse`` is set.
        """
        last_col = scores.shape[1] - 1
        sorted_rows = numpy.argsort(scores[:, -1])[::-1]
        best_trans_alignments = []
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for rank, r in enumerate(sorted_rows):
            row_idx = r
            trans_ids = []
            trans_alignments = []
            for col_idx in range(last_col, -1, -1):
                trans_ids.append(symbols[row_idx, col_idx])
                trans_alignments.append(alignments[row_idx, col_idx, :])
                row_idx = parents[row_idx, col_idx]

            # The walk above runs backwards; restore left-to-right order.
            trans_ids.reverse()
            trans_alignments.reverse()
            trans_out, trans_out_ids = self.ids_to_trans(trans_ids, trans_alignments, no_unk_src_toks)
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(trans_out, scores[r, -1], probs[r, -1]))
            if rank == 0:  # highest-scoring hypothesis
                best_trans = trans_out
                best_tran_ids = trans_out_ids
                best_trans_alignments = [
                    align[::-1] if self.reverse else align
                    for align in trans_alignments
                ]

        return best_trans, best_tran_ids, u'\n'.join(beam_trans), best_trans_alignments

    def plot_head_map(self, mma, target_labels, target_ids, source_labels, source_ids, filename):
        """Save a heat map of the matrix ``mma`` to ``filename``.

        Adapted from
        https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in the ``family`` argument below. If the system font
        is not used, delete the matplotlib font cache, see
        https://github.com/matplotlib/matplotlib/issues/3590

        Args:
            mma: 2-D array; rows are labelled with ``target_labels`` and
                columns with ``source_labels``.
            target_labels: row tick labels.
            target_ids: ids matching ``target_labels``; a label whose id is
                ``ac.UNK_ID`` is drawn in blue.
            source_labels: column tick labels.
            source_ids: ids matching ``source_labels``; same colouring rule.
            filename: path of the image file to write.
        """
        _, ax = plt.subplots()
        ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels, minor=False, family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels
        ax.set_yticklabels(target_labels, minor=False, family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def _run_beam_search(self, sess, model, src_input, src_seq_len):
        """Run the decoder on one input and return arrays laid out for get_trans."""
        feed = {
            model.src_inputs: src_input,
            model.src_seq_lengths: src_seq_len
        }
        probs, scores, symbols, parents, alignments = sess.run(
            [model.probs, model.scores, model.symbols, model.parents, model.alignments],
            feed_dict=feed)
        # Swap the first two axes of every output so that get_trans can index
        # them as [row, column].
        alignments = numpy.transpose(alignments, axes=(1, 0, 2))
        probs = numpy.transpose(numpy.array(probs))
        scores = numpy.transpose(numpy.array(scores))
        symbols = numpy.transpose(numpy.array(symbols))
        parents = numpy.transpose(numpy.array(parents))
        return probs, scores, symbols, parents, alignments

    def _plot_best_alignment(self, src_input, no_unk_src_toks, best_trans,
                             best_trans_ids, best_trans_alignments, count):
        """Write the alignment heat map of the best translation of input ``count``."""
        src_input = numpy.reshape(src_input, [-1])
        if self.reverse:
            # get_trans already flipped the attention vectors, so flip the
            # source side as well to keep columns and labels in step.
            src_input = src_input[::-1]
            no_unk_src_toks = no_unk_src_toks[::-1]
        trans_toks = best_trans.split()
        # Keep only the rows that have a visible target token.
        best_trans_alignments = numpy.array(best_trans_alignments)[:len(trans_toks)]
        filename = '{}_{}.png'.format(self.input_file, count)

        self.plot_head_map(best_trans_alignments, trans_toks, best_trans_ids, no_unk_src_toks, src_input, filename)

    def translate(self):
        """Restore the model, decode the input file and write both outputs.

        ``<input_file>.best_trans`` receives one best translation per line and
        ``<input_file>.beam_trans`` receives the full beam of every input,
        inputs separated by a blank line. Both files are UTF-8 encoded.
        """
        # Imported here so the (path, mode, encoding) call below is valid
        # whether or not the module rebinds the built-in ``open``.
        import codecs

        with tf.Graph().as_default():
            # Build the TRAINING-mode model first: get_model only creates
            # variables for that mode; the VALIDATING model reuses them.
            self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'

                # The with-statement closes both files even if decoding fails.
                with codecs.open(best_trans_file, 'w', 'utf-8') as ftrans, \
                        codecs.open(beam_trans_file, 'w', 'utf-8') as btrans:
                    self.logger.info('Start translating {}'.format(self.input_file))
                    start = time.time()
                    count = 0
                    for (src_input, src_seq_len, no_unk_src_toks) in self.data_manager.get_trans_input(self.input_file):
                        probs, scores, symbols, parents, alignments = self._run_beam_search(
                            sess, model, src_input, src_seq_len)

                        best_trans, best_trans_ids, beam_trans, best_trans_alignments = self.get_trans(
                            probs, scores, symbols, parents, alignments, no_unk_src_toks)
                        ftrans.write(best_trans + '\n')
                        btrans.write(beam_trans + '\n\n')

                        if self.plot_align:
                            self._plot_best_alignment(
                                src_input, no_unk_src_toks, best_trans,
                                best_trans_ids, best_trans_alignments, count)

                        count += 1
                        if count % 100 == 0:
                            self.logger.info('  Translating line {}, average {} seconds/sent'.format(count, (time.time() - start) / count))

                self.logger.info('Done translating {}, it takes {} minutes'.format(self.input_file, float(time.time() - start) / 60.0))
示例#3
0
class Translator(object):
    """Beam-decode an input file with a saved TensorFlow model.

    Creating an instance performs the whole job: the constructor validates
    the paths, loads the vocabularies and then calls :meth:`translate`, which
    writes ``<input_file>.best_trans`` and ``<input_file>.beam_trans`` and,
    when ``plot_align`` is set, one alignment image per input.
    """

    def __init__(self, args):
        """Read the configuration, validate the paths and start translating.

        Args:
            args: object exposing ``proto`` (name of a factory looked up in
                ``configurations``), ``input_file``, ``model_file`` and
                ``plot_align``.

        Raises:
            ValueError: if a path is ``None``, the input file is missing, or
                ``<model_file>.meta`` is missing.
        """
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.unk_repl = self.config['unk_repl']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align

        if (self.input_file is None or self.model_file is None
                or not os.path.exists(self.input_file)
                or not os.path.exists(self.model_file + '.meta')):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(
            self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(
            self.data_manager.trg_lang)
        self.translate()

    def get_model(self, mode):
        """Build a ``Model`` for ``mode`` inside the configured variable scope.

        Variables are reused for every mode other than ``ac.TRAINING``, so
        the training-mode model has to be built first.
        """
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'],
                               reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def _run_beam_search(self, sess, model, src_input, src_seq_len):
        """Run the decoder on one input and return the transposed outputs."""
        feed = {
            model.src_inputs: src_input,
            model.src_seq_lengths: src_seq_len
        }
        probs, scores, symbols, parents, alignments = sess.run(
            [
                model.probs, model.scores, model.symbols,
                model.parents, model.alignments
            ],
            feed_dict=feed)
        # Swap the first two axes of every output before handing them to
        # ut.get_trans.
        alignments = numpy.transpose(alignments, axes=(1, 0, 2))
        probs = numpy.transpose(numpy.array(probs))
        scores = numpy.transpose(numpy.array(scores))
        symbols = numpy.transpose(numpy.array(symbols))
        parents = numpy.transpose(numpy.array(parents))
        return probs, scores, symbols, parents, alignments

    def _plot_best_alignment(self, src_input, no_unk_src_toks, best_trans,
                             best_trans_ids, best_trans_alignments, count):
        """Write the alignment heat map of the best translation of input ``count``."""
        src_input = numpy.reshape(src_input, [-1])
        if self.reverse:
            # Flip the source side back when the input was fed reversed.
            # NOTE(review): assumes ut.get_trans flips the attention vectors
            # itself when reverse=True - confirm.
            src_input = src_input[::-1]
            no_unk_src_toks = no_unk_src_toks[::-1]
        trans_toks = best_trans.split()
        # Keep only the rows that have a visible target token.
        best_trans_alignments = numpy.array(
            best_trans_alignments)[:len(trans_toks)]
        filename = '{}_{}.png'.format(self.input_file, count)

        ut.plot_head_map(best_trans_alignments, trans_toks,
                         best_trans_ids, no_unk_src_toks,
                         src_input, filename)

    def translate(self):
        """Restore the model, decode the input file and write both outputs.

        ``<input_file>.best_trans`` receives one best translation per line,
        with its last whitespace-separated token removed, and
        ``<input_file>.beam_trans`` receives the full beam of every input,
        inputs separated by a blank line. Both files are UTF-8 encoded.
        """
        # Imported here so the (path, mode, encoding) call below is valid
        # whether or not the module rebinds the built-in ``open``.
        import codecs

        with tf.Graph().as_default():
            # Build the TRAINING-mode model first: get_model only creates
            # variables for that mode; the VALIDATING model reuses them.
            self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(
                    self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'

                # The with-statement closes both files even if decoding fails.
                with codecs.open(best_trans_file, 'w', 'utf-8') as ftrans, \
                        codecs.open(beam_trans_file, 'w', 'utf-8') as btrans:
                    self.logger.info('Start translating {}'.format(
                        self.input_file))
                    start = time.time()
                    count = 0
                    for (src_input, src_seq_len,
                         no_unk_src_toks) in self.data_manager.get_trans_input(
                             self.input_file):
                        probs, scores, symbols, parents, alignments = \
                            self._run_beam_search(
                                sess, model, src_input, src_seq_len)

                        best_trans, best_trans_ids, beam_trans, best_trans_alignments = ut.get_trans(
                            probs,
                            scores,
                            symbols,
                            parents,
                            alignments,
                            no_unk_src_toks,
                            self.trg_ivocab,
                            reverse=self.reverse,
                            unk_repl=self.unk_repl)
                        # Drop the last token before writing; presumably it
                        # is the end-of-sentence marker - verify against
                        # ut.get_trans.
                        best_trans_wo_eos = u' '.join(best_trans.split()[:-1])
                        ftrans.write(best_trans_wo_eos + '\n')
                        btrans.write(beam_trans + '\n\n')

                        if self.plot_align:
                            # The ids come from best_trans_ids as returned
                            # above; the name previously passed here was
                            # never bound and raised NameError.
                            self._plot_best_alignment(
                                src_input, no_unk_src_toks, best_trans,
                                best_trans_ids, best_trans_alignments, count)

                        count += 1
                        if count % 100 == 0:
                            self.logger.info(
                                '  Translating line {}, average {} seconds/sent'.
                                format(count, (time.time() - start) / count))

                self.logger.info(
                    'Done translating {}, it takes {} minutes'.format(
                        self.input_file,
                        float(time.time() - start) / 60.0))