Example #1
def run_validation(epoch, dataset_name: str):
    dataset = load_data(dataset_name)
    print("Number of %s instances: %d" % (dataset_name, len(dataset)))

    model = Transformer(
        i2w=i2w, use_knowledge=args.use_knowledge, args=args, test=True
    ).cuda()
    model.load("{0}model_{1}.bin".format(args.save_path, epoch))
    model.transformer.eval()
    # Iterate over batches
    num_batches = math.ceil(len(dataset) / args.batch_size)
    cum_loss = 0
    cum_words = 0
    predicted_sentences = []
    indices = list(range(len(dataset)))
    for batch in tqdm(range(num_batches)):
        # Prepare batch
        batch_indices = indices[batch * args.batch_size : (batch + 1) * args.batch_size]
        batch_rows = [dataset[i] for i in batch_indices]

        # Encode batch. If facts are being used, they'll be prepended to the input
        input_seq, input_lens, target_seq, target_lens = model.prep_batch(batch_rows)

        # Decode batch
        predicted_sentences += model.decode(input_seq, input_lens)

        # Evaluate batch
        cum_loss += model.eval_ppl(input_seq, input_lens, target_seq, target_lens)
        cum_words += (target_seq != w2i["_pad"]).sum().item()

    # Log epoch-level perplexity once all batches have been processed
    ppl = math.exp(cum_loss / cum_words)
    print("{} Epoch: {} PPL: {}".format(dataset_name, epoch, ppl))
    # Save predictions
    with open(
        "{0}{1}_epoch_{2}.pred".format(args.save_path, dataset_name, str(epoch)), "w"
    ) as f:
        f.writelines([l + "\n" for l in predicted_sentences])
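
A minimal calling sketch for this example. run_validation relies on module-level globals (args, i2w, w2i) and helpers such as load_data that are defined elsewhere in the project; the configuration below is a placeholder, and eval_epoch is an assumed field, not part of the original snippet.

import argparse

if __name__ == "__main__":
    # Hypothetical configuration; the field names mirror the attributes used above.
    args = argparse.Namespace(save_path="runs/", batch_size=32,
                              use_knowledge=True, eval_epoch=10)
    # i2w / w2i would come from the training vocabulary, built elsewhere.
    for split in ("valid", "test"):
        run_validation(epoch=args.eval_epoch, dataset_name=split)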
Example #2
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)
    net = Transformer(vsize_src, vsize_tar)

    if not args.test:

        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

        best_valid_loss = float('inf')  # track the best validation loss seen so far
        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))
            net.train()
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))
            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))
            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        net = torch.load('data/ckpt/best_model')
        net.to(device)
        net.eval()

        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        for src_batch, tgt_batch in test_loader:
            source, src_mask = make_tensor(src_batch)
            source = source.to(device)
            src_mask = src_mask.to(device)
            res = net.decode(source, src_mask)
            pred_batch = res.tolist()
            # every sentence in pred_batch should start with the <sos> token (index 0) and end with the <eos> token (index 1).
            # every <pad> token (index 2) should appear after the <eos> token.
            # example of pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)

        with open('data/results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok'
        )
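
A possible command-line entry point for main(); the original example does not show how args is constructed, so the argument names below simply mirror the attributes used above and every default value is a placeholder.

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default='data/multi30k')  # data directory (assumed default)
    parser.add_argument('--batch_size', type=int, default=128)        # placeholder default
    parser.add_argument('--lr', type=float, default=1e-4)             # placeholder default
    parser.add_argument('--epochs', type=int, default=10)             # placeholder default
    parser.add_argument('--test', action='store_true')                # run evaluation instead of training
    main(parser.parse_args())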
Example #3
class Prediction:
    def __init__(self, args):
        """
        :param model_dir: model dir path
        :param vocab_file: vocab file path
        """
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)

        self._add_placeholder()
        self._init_graph()

    def _init_graph(self):
        """
        init graph
        """
        self.ys = (self.input_y, None, None)
        self.xs = (self.input_x, None)
        self.memory = self.model.encode(self.xs, False)[0]
        self.logits = self.model.decode(self.xs, self.ys, self.memory, False)[0]

        ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]

        graph = self.logits.graph
        sess_config = self.tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        saver = self.tf.train.Saver()
        self.sess = self.tf.Session(config=sess_config, graph=graph)

        self.sess.run(self.tf.global_variables_initializer())
        self.tf.reset_default_graph()
        saver.restore(self.sess, ckpt)

        self.bs = BeamSearch(self.model,
                             self.hp.beam_size,
                             list(self.idx2token.keys())[2],  # id at vocab index 2 (presumably the start token)
                             list(self.idx2token.keys())[3],  # id at vocab index 3 (presumably the end token)
                             self.idx2token,
                             self.hp.maxlen2,
                             self.input_x,
                             self.input_y,
                             self.logits)

    def predict(self, content):
        """
        abstract prediction by beam search
        :param content: article content
        :return: prediction result
        """
        input_x = content.split()
        while len(input_x) < self.args.maxlen1:
            input_x.append('<pad>')
        input_x = input_x[:self.args.maxlen1]

        input_x = [self.token2idx.get(s, self.token2idx['<unk>']) for s in input_x]

        memory = self.sess.run(self.memory, feed_dict={self.input_x: [input_x]})

        return self.bs.search(self.sess, input_x, memory[0])

    def _add_placeholder(self):
        """
        add tensorflow placeholder
        """
        self.input_x = self.tf.placeholder(dtype=self.tf.int32, shape=[None, self.args.maxlen1], name='input_x')
        self.input_y = self.tf.placeholder(dtype=self.tf.int32, shape=[None, None], name='input_y')
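
An illustrative way to instantiate and call Prediction; the argument object and its field values below are placeholders, and the checkpoint/vocab paths are assumptions rather than paths from the original project.

if __name__ == '__main__':
    from types import SimpleNamespace

    # Hypothetical argument object with the fields Prediction reads.
    args = SimpleNamespace(logdir='log/1', vocab='data/vocab.txt', maxlen1=100)

    predictor = Prediction(args)
    summary = predictor.predict('article content to be summarized ...')
    print(summary)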
Example #4
class Transformer_pl(pl.LightningModule):
    def __init__(self, hparams, **kwargs):
        super(Transformer_pl, self).__init__()
        self.hparams = hparams
        self.transformer = Transformer(self.hparams)

        self.sp_kor = korean_tokenizer_load()
        self.sp_eng = english_tokenizer_load()

    def forward(self, enc_inputs, dec_inputs):
        output_logits, *_ = self.transformer(enc_inputs, dec_inputs)
        return output_logits

    def cal_loss(self, tgt_hat, tgt_label):
        loss = F.cross_entropy(tgt_hat,
                               tgt_label.contiguous().view(-1),
                               ignore_index=self.hparams['padding_idx'])
        return loss

    def translate(self, input_sentence):
        self.eval()
        input_ids = self.sp_kor.EncodeAsIds(input_sentence)
        # pad or truncate to exactly max_seq_length tokens
        max_len = self.hparams['max_seq_length']
        if len(input_ids) < max_len:
            input_ids = input_ids + [self.hparams['padding_idx']] * (max_len - len(input_ids))
        else:
            input_ids = input_ids[:max_len]
        input_ids = torch.tensor([input_ids])

        enc_outputs, _ = self.transformer.encode(input_ids)
        target_ids = torch.zeros(1, self.hparams['max_seq_length']).type_as(
            input_ids.data)
        next_token = self.sp_eng.bos_id()

        for i in range(0, self.hparams['max_seq_length']):
            target_ids[0][i] = next_token
            decoder_output, *_ = self.transformer.decode(
                target_ids, input_ids, enc_outputs)
            # greedy decoding: take the most likely token id at position i
            pred_ids = decoder_output.squeeze(0).max(dim=-1, keepdim=False)[1]
            next_token = pred_ids.data[i].item()
            if next_token == self.sp_eng.eos_id():
                break

        output_sent = self.sp_eng.DecodeIds(target_ids[0].tolist())
        return output_sent

    # ---------------------
    # TRAINING AND EVALUATION
    # ---------------------
    def training_step(self, batch, batch_idx):
        src, tgt = batch.kor, batch.eng
        tgt_label = tgt[:, 1:]
        tgt_hat = self(src, tgt[:, :-1])
        loss = self.cal_loss(tgt_hat, tgt_label)
        train_ppl = math.exp(loss)
        tensorboard_logs = {'train_loss': loss, 'train_ppl': train_ppl}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        src, tgt = batch.kor, batch.eng
        tgt_label = tgt[:, 1:]
        tgt_hat = self(src, tgt[:, :-1])
        val_loss = self.cal_loss(tgt_hat, tgt_label)
        return {'val_loss': val_loss}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_ppl = math.exp(val_loss)
        tensorboard_logs = {'val_loss': val_loss, 'val_ppl': val_ppl}
        print("")
        print("=" * 30)
        print(f"val_loss:{val_loss}")
        print("=" * 30)
        return {
            'val_loss': val_loss,
            'val_ppl': val_ppl,
            'log': tensorboard_logs
        }

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.transformer.parameters(),
                                      lr=self.hparams['lr'])
        return [optimizer]

    def make_field(self):
        KOR = torchtext.data.Field(use_vocab=False,
                                   tokenize=self.sp_kor.EncodeAsIds,
                                   batch_first=True,
                                   fix_length=self.hparams['max_seq_length'],
                                   pad_token=self.sp_kor.pad_id())

        ENG = torchtext.data.Field(
            use_vocab=False,
            tokenize=self.sp_eng.EncodeAsIds,
            batch_first=True,
            fix_length=self.hparams['max_seq_length'] + 1,  # +1 for the BOS token prepended to the target labels
            init_token=self.sp_eng.bos_id(),
            eos_token=self.sp_eng.eos_id(),
            pad_token=self.sp_eng.pad_id())
        return KOR, ENG

    def train_dataloader(self):
        KOR, ENG = self.make_field()
        train_data = TabularDataset(path="./data/train.tsv",
                                    format='tsv',
                                    skip_header=True,
                                    fields=[('kor', KOR), ('eng', ENG)])

        train_iter = BucketIterator(train_data,
                                    batch_size=self.hparams['batch_size'],
                                    sort_key=lambda x: len(x.kor),
                                    sort_within_batch=False)
        return train_iter

    def val_dataloader(self):
        KOR, ENG = self.make_field()
        valid_data = TabularDataset(path="./data/valid.tsv",
                                    format='tsv',
                                    skip_header=True,
                                    fields=[('kor', KOR), ('eng', ENG)])

        val_iter = BucketIterator(valid_data,
                                  batch_size=self.hparams['batch_size'],
                                  sort_key=lambda x: len(x.kor),
                                  sort_within_batch=False)
        return val_iter
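
A rough training sketch for Transformer_pl, assuming an hparams dict with the keys referenced above; the values and the Trainer settings are placeholders, and the gpus argument matches the older Lightning API that the hooks in this example target.

if __name__ == '__main__':
    hparams = {
        'padding_idx': 0,        # placeholder; must match the tokenizer's pad id
        'max_seq_length': 128,   # placeholder
        'lr': 1e-4,              # placeholder
        'batch_size': 64,        # placeholder
        # ...plus whatever keys the underlying Transformer expects
    }
    model = Transformer_pl(hparams)
    trainer = pl.Trainer(max_epochs=10, gpus=1)
    trainer.fit(model)
    print(model.translate('안녕하세요'))  # translate a Korean sentence after training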