Exemplo n.º 1
0
def greedy_test(args):
    """Greedy (argmax) decoding over a test corpus.

    Loads the vocabulary and a trained Transformer checkpoint, decodes each
    source sentence one token at a time by always taking the highest-scoring
    word, prints each pair, and writes predictions to
    ``args.decode_output_file``.

    :param args: namespace with at least ``vocab``, ``decode_model_path``,
        ``cuda``, ``decode_from_file``, ``decode_max_steps`` and
        ``decode_output_file`` attributes.
    """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model and disable dropout / training-mode layers
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")

    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test in test_data:
        # BUG FIX: the original pre-built ``pred_data = len(test_data) * [[...]]``,
        # which replicates *references* to a single shared list, so every
        # sentence mutated the same buffer. Build a fresh decoder-input
        # buffer per sentence: ['<BOS>', '<PAD>', '<PAD>', ...]
        pred = [
            constants.PAD_WORD if i else constants.BOS_WORD
            for i in range(args.decode_max_steps)
        ]
        pred_output = [constants.PAD_WORD] * args.decode_max_steps
        test_var = to_input_variable([test], vocab.src, cuda=args.cuda)

        # the source sentence only needs to be encoded once
        enc_output = translator.encode(test_var[0], test_var[1])
        for i in range(args.decode_max_steps):
            pred_var = to_input_variable([pred[:i + 1]],
                                         vocab.tgt,
                                         cuda=args.cuda)

            scores = translator.translate(enc_output, test_var[0], pred_var)

            # greedy choice: argmax over the vocabulary at the last step
            _, argmax_idxs = torch.max(scores, dim=-1)
            one_step_idx = argmax_idxs[-1].item()

            pred_output[i] = vocab.tgt.id2word[one_step_idx]
            if (one_step_idx
                    == constants.EOS) or (i == args.decode_max_steps - 1):
                # BUG FIX: the original always emitted ``pred_output[:i]``,
                # which silently dropped the last *real* word whenever the
                # step limit was reached without producing <EOS>.  Exclude
                # position i only when it holds the EOS token.
                end = i if one_step_idx == constants.EOS else i + 1
                print("[Source] %s" % " ".join(test))
                print("[Predict] %s" % " ".join(pred_output[:end]))
                print()

                output_file.write(" ".join(pred_output[:end]) + "\n")
                output_file.flush()
                break
            pred[i + 1] = vocab.tgt.id2word[one_step_idx]

    output_file.close()
Exemplo n.º 2
0
def test(args):
    """ Decode with beam search """

    # vocabulary
    vocab = torch.load(args.vocab)

    # instantiate the model in evaluation mode
    translator = Transformer(args, vocab)
    translator.eval()

    # restore trained weights
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    source_sents = read_corpus(args.decode_from_file, source="src")
    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for sent in source_sents:
        src_seq, src_pos = to_input_variable([sent],
                                             vocab.src,
                                             cuda=args.cuda)
        # replicate the source across the beam dimension
        src_seq_beam = src_seq.expand(args.decode_beam_size, src_seq.size(1))

        enc_out = translator.encode(src_seq, src_pos)
        enc_out_beam = enc_out.expand(args.decode_beam_size,
                                      enc_out.size(1),
                                      enc_out.size(2))

        beam = Beam_Search_V2(beam_size=args.decode_beam_size,
                              tgt_vocab=vocab.tgt,
                              length_alpha=args.decode_alpha)
        for step in range(args.decode_max_steps):
            if step == 0:
                # only the single <BOS> hypothesis exists at step 0
                pred_var = to_input_variable(beam.candidates[:1],
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_out, src_seq, pred_var)
            else:
                # afterwards every beam candidate is expanded in parallel
                pred_var = to_input_variable(beam.candidates,
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_out_beam, src_seq_beam,
                                              pred_var)

            # keep only the log-probabilities of the newest position
            log_probs = F.log_softmax(scores, dim=-1)
            log_probs = log_probs.view(pred_var[0].size(0), -1,
                                       log_probs.size(-1))[:, -1, :]

            done = beam.advance(log_probs)
            beam.update_status()

            if done:
                break

        print("[Source] %s" % " ".join(sent))
        print("[Predict] %s" % beam.get_best_candidate())
        print()

        output_file.write(beam.get_best_candidate() + "\n")
        output_file.flush()

    output_file.close()
Exemplo n.º 3
0
class Prediction:
    """Restores a trained TF1 Transformer and decodes summaries via beam search."""

    def __init__(self, args):
        """
        Build the inference graph and restore the newest checkpoint.

        :param args: namespace providing ``logdir`` (model dir path),
            ``vocab`` (vocab file path) and ``maxlen1`` (max article length)
        """
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        # token2idx: str -> int, idx2token: int -> str
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        # hyper-parameters come from the parser's command-line defaults
        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)

        self._add_placeholder()
        self._init_graph()

    def _init_graph(self):
        """
        init graph: wire encoder/decoder ops, open a session, restore the
        newest checkpoint, and build the beam-search helper.
        """
        # the extra Nones fill the (seq, ...) tuple shape the model expects
        self.ys = (self.input_y, None, None)
        self.xs = (self.input_x, None)
        self.memory = self.model.encode(self.xs, False)[0]
        self.logits = self.model.decode(self.xs, self.ys, self.memory, False)[0]

        # last entry of the checkpoint state == most recent checkpoint path
        ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]

        graph = self.logits.graph
        sess_config = self.tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        saver = self.tf.train.Saver()
        self.sess = self.tf.Session(config=sess_config, graph=graph)

        # initialize first, then overwrite variables from the checkpoint;
        # the intermediate reset_default_graph() order is deliberate here
        self.sess.run(self.tf.global_variables_initializer())
        self.tf.reset_default_graph()
        saver.restore(self.sess, ckpt)

        # NOTE(review): vocab indices 2 and 3 are presumably the start/end
        # tokens -- verify against the vocab file ordering.
        self.bs = BeamSearch(self.model,
                             self.hp.beam_size,
                             list(self.idx2token.keys())[2],
                             list(self.idx2token.keys())[3],
                             self.idx2token,
                             self.hp.maxlen2,
                             self.input_x,
                             self.input_y,
                             self.logits)

    def predict(self, content):
        """
        abstract prediction by beam search
        :param content: article content (whitespace-tokenized string)
        :return: prediction result
        """
        # pad/truncate the article to exactly maxlen1 tokens
        input_x = content.split()
        while len(input_x) < self.args.maxlen1: input_x.append('<pad>')
        input_x = input_x[:self.args.maxlen1]

        # map tokens to ids, falling back to <unk> for OOV words
        input_x = [self.token2idx.get(s, self.token2idx['<unk>']) for s in input_x]

        # encode once, then beam-search over the decoder
        memory = self.sess.run(self.memory, feed_dict={self.input_x: [input_x]})

        return self.bs.search(self.sess, input_x, memory[0])

    def _add_placeholder(self):
        """
        add tensorflow placeholder for the source (fixed length maxlen1)
        and target (variable length) id sequences
        """
        self.input_x = self.tf.placeholder(dtype=self.tf.int32, shape=[None, self.args.maxlen1], name='input_x')
        self.input_y = self.tf.placeholder(dtype=self.tf.int32, shape=[None, None], name='input_y')
Exemplo n.º 4
0
# Load hyper-parameters
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()

load_hparams(hp, ckpt_dir)

with tf.Session() as sess:

    # placeholders for the token ids and the true sequence lengths
    input_ids_p = tf.placeholder(tf.int32, [None, None], name="input_ids")
    input_len_p = tf.placeholder(tf.int32, [None], name="input_len")

    m = Transformer(hp)
    # tf.constant(1) is a dummy third tuple element; it is unused
    xs = (input_ids_p, input_len_p, tf.constant(1))
    memory, _, _ = m.encode(xs, False)

    # sentence vector: mean-pool the encoder memory over the time axis
    vector = tf.reduce_mean(memory, axis=1, name='avg_vector')

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

    graph_def = tf.get_default_graph().as_graph_def()
    # encoder/num_blocks_0/positionwise_feedforward/ln/add_1 is memory
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph_def, [
            'input_ids', 'input_len',
            'encoder/num_blocks_0/positionwise_feedforward/ln/add_1',
            'avg_vector'
        ])
    # BUG FIX: the original `with` statement had no body (syntax error) and
    # the frozen graph was never written.  Serialize the constant graph.
    with tf.gfile.FastGFile('tsf.pb', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
Exemplo n.º 5
0
class Transformer_pl(pl.LightningModule):
    """PyTorch-Lightning wrapper around a Korean-to-English Transformer.

    Bundles the model, sentencepiece tokenizers, loss, greedy decoding and
    the torchtext data pipeline into a single LightningModule.
    """

    def __init__(self, hparams, **kwargs):
        super(Transformer_pl, self).__init__()
        # hparams is indexed like a dict below ('lr', 'batch_size', ...)
        self.hparams = hparams
        self.transformer = Transformer(self.hparams)

        # sentencepiece tokenizers: source (Korean) and target (English)
        self.sp_kor = korean_tokenizer_load()
        self.sp_eng = english_tokenizer_load()

    def forward(self, enc_inputs, dec_inputs):
        """Run the Transformer and return only its output logits."""
        output_logits, *_ = self.transformer(enc_inputs, dec_inputs)
        return output_logits

    def cal_loss(self, tgt_hat, tgt_label):
        """Cross-entropy over flattened targets, ignoring padding positions."""
        loss = F.cross_entropy(tgt_hat,
                               tgt_label.contiguous().view(-1),
                               ignore_index=self.hparams['padding_idx'])
        return loss

    def translate(self, input_sentence):
        """Greedy-decode a single Korean sentence into English.

        :param input_sentence: raw Korean text
        :return: decoded English sentence
        """
        self.eval()
        input_ids = self.sp_kor.EncodeAsIds(input_sentence)
        # pad up to, then truncate at, the fixed sequence length
        if len(input_ids) <= self.hparams['max_seq_length']:
            input_ids = input_ids + [self.hparams['padding_idx']] * (
                self.hparams['max_seq_length'] - len(input_ids))
        if len(input_ids) > self.hparams['max_seq_length']:
            input_ids = input_ids[:self.hparams['max_seq_length']]
        input_ids = torch.tensor([input_ids])

        enc_outputs, _ = self.transformer.encode(input_ids)
        target_ids = torch.zeros(1, self.hparams['max_seq_length']).type_as(
            input_ids.data)
        next_token = self.sp_eng.bos_id()

        # feed step i's argmax token back in as the input for step i+1
        for i in range(0, self.hparams['max_seq_length']):
            target_ids[0][i] = next_token
            decoder_output, *_ = self.transformer.decode(
                target_ids, input_ids, enc_outputs)
            prob = decoder_output.squeeze(0).max(dim=-1, keepdim=False)[1]
            next_token = prob.data[i].item()
            if next_token == self.sp_eng.eos_id():
                break

        # NOTE(review): target_ids keeps trailing zeros after early stop;
        # this relies on DecodeIds ignoring id 0 -- confirm 0 is pad/unused.
        output_sent = self.sp_eng.DecodeIds(target_ids[0].tolist())
        return output_sent

    # ---------------------
    # TRAINING AND EVALUATION
    # ---------------------
    def training_step(self, batch, batch_idx):
        """One training step; logs loss and perplexity to tensorboard."""
        src, tgt = batch.kor, batch.eng
        # teacher forcing: inputs are tgt[:, :-1], labels are shifted by one
        tgt_label = tgt[:, 1:]
        tgt_hat = self(src, tgt[:, :-1])
        loss = self.cal_loss(tgt_hat, tgt_label)
        train_ppl = math.exp(loss)
        tensorboard_logs = {'train_loss': loss, 'train_ppl': train_ppl}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch."""
        src, tgt = batch.kor, batch.eng
        tgt_label = tgt[:, 1:]
        tgt_hat = self(src, tgt[:, :-1])
        val_loss = self.cal_loss(tgt_hat, tgt_label)
        return {'val_loss': val_loss}

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch losses into epoch-level loss and perplexity."""
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_ppl = math.exp(val_loss)
        tensorboard_logs = {'val_loss': val_loss, 'val_ppl': val_ppl}
        print("")
        print("=" * 30)
        print(f"val_loss:{val_loss}")
        print("=" * 30)
        return {
            'val_loss': val_loss,
            'val_ppl': val_ppl,
            'log': tensorboard_logs
        }

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        """AdamW over the Transformer's parameters."""
        optimizer = torch.optim.AdamW(self.transformer.parameters(),
                                      lr=self.hparams['lr'])
        return [optimizer]

    def make_field(self):
        """Build torchtext Fields for the Korean source and English target."""
        KOR = torchtext.data.Field(use_vocab=False,
                                   tokenize=self.sp_kor.EncodeAsIds,
                                   batch_first=True,
                                   fix_length=self.hparams['max_seq_length'],
                                   pad_token=self.sp_kor.pad_id())

        ENG = torchtext.data.Field(
            use_vocab=False,
            tokenize=self.sp_eng.EncodeAsIds,
            batch_first=True,
            fix_length=self.hparams['max_seq_length'] +
            1,  # should +1 because of bos token for tgt label
            init_token=self.sp_eng.bos_id(),
            eos_token=self.sp_eng.eos_id(),
            pad_token=self.sp_eng.pad_id())
        return KOR, ENG

    def train_dataloader(self):
        """Length-bucketed iterator over the training TSV."""
        KOR, ENG = self.make_field()
        train_data = TabularDataset(path="./data/train.tsv",
                                    format='tsv',
                                    skip_header=True,
                                    fields=[('kor', KOR), ('eng', ENG)])

        train_iter = BucketIterator(train_data,
                                    batch_size=self.hparams['batch_size'],
                                    sort_key=lambda x: len(x.kor),
                                    sort_within_batch=False)
        return train_iter

    def val_dataloader(self):
        """Length-bucketed iterator over the validation TSV."""
        KOR, ENG = self.make_field()
        valid_data = TabularDataset(path="./data/valid.tsv",
                                    format='tsv',
                                    skip_header=True,
                                    fields=[('kor', KOR), ('eng', ENG)])

        val_iter = BucketIterator(valid_data,
                                  batch_size=self.hparams['batch_size'],
                                  sort_key=lambda x: len(x.kor),
                                  sort_within_batch=False)
        return val_iter