def run_validation(epoch, dataset_name: str):
    dataset = load_data(dataset_name)
    print("Number of %s instances: %d" % (dataset_name, len(dataset)))

    model = Transformer(
        i2w=i2w, use_knowledge=args.use_knowledge, args=args, test=True
    ).cuda()
    model.load("{0}model_{1}.bin".format(args.save_path, epoch))
    model.transformer.eval()

    # Iterate over batches
    num_batches = math.ceil(len(dataset) / args.batch_size)
    cum_loss = 0
    cum_words = 0
    predicted_sentences = []
    indices = list(range(len(dataset)))
    for batch in tqdm(range(num_batches)):
        # Prepare batch
        batch_indices = indices[batch * args.batch_size:(batch + 1) * args.batch_size]
        batch_rows = [dataset[i] for i in batch_indices]

        # Encode batch. If facts are being used, they'll be prepended to the input
        input_seq, input_lens, target_seq, target_lens = model.prep_batch(batch_rows)

        # Decode batch
        predicted_sentences += model.decode(input_seq, input_lens)

        # Evaluate batch
        cum_loss += model.eval_ppl(input_seq, input_lens, target_seq, target_lens)
        cum_words += (target_seq != w2i["_pad"]).sum().item()

    # Log epoch perplexity: average per-token NLL, exponentiated
    ppl = math.exp(cum_loss / cum_words)
    print("{} Epoch: {} PPL: {}".format(dataset_name, epoch, ppl))

    # Save predictions (context manager guarantees the file is flushed and closed)
    with open(
        "{0}{1}_epoch_{2}.pred".format(args.save_path, dataset_name, str(epoch)), "w+"
    ) as f:
        f.writelines([l + "\n" for l in predicted_sentences])
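# The cum_loss / cum_words bookkeeping above only yields a valid perplexity if
# eval_ppl returns the *summed* (not averaged) token negative log-likelihood for
# the batch. A minimal, self-contained sketch of such a function, assuming
# logits of shape (batch, seq_len, vocab); the names here are illustrative and
# not part of the snippet above.
import math

import torch
import torch.nn.functional as F


def summed_token_nll(logits, targets, pad_idx):
    # reduction="sum" with ignore_index skips pad positions, matching the
    # (target_seq != w2i["_pad"]).sum() word count used above.
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        ignore_index=pad_idx,
        reduction="sum",
    )


# Hypothetical usage mirroring the loop above:
#   cum_loss += summed_token_nll(logits, target_seq, w2i["_pad"]).item()
#   cum_words += (target_seq != w2i["_pad"]).sum().item()
#   ppl = math.exp(cum_loss / cum_words)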
def main(args):
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>', eos_token='<eos>',
                      pad_token='<pad>', unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    vsize_src = len(src_vocab)
    vsize_tar = len(tgt_vocab)

    net = Transformer(vsize_src, vsize_tar)

    if not args.test:
        train_loader = get_loader(src['train'], tgt['train'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size, shuffle=True)
        valid_loader = get_loader(src['valid'], tgt['valid'], src_vocab, tgt_vocab,
                                  batch_size=args.batch_size)

        net.to(device)
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

        best_valid_loss = 10.0
        for epoch in range(args.epochs):
            print("Epoch {0}".format(epoch))

            net.train()
            train_loss = run_epoch(net, train_loader, optimizer)
            print("train loss: {0}".format(train_loss))

            net.eval()
            valid_loss = run_epoch(net, valid_loader, None)
            print("valid loss: {0}".format(valid_loss))

            # Always keep the latest checkpoint; track the best one separately
            torch.save(net, 'data/ckpt/last_model')
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(net, 'data/ckpt/best_model')
    else:
        # test
        net = torch.load('data/ckpt/best_model')
        net.to(device)
        net.eval()

        test_loader = get_loader(src['test'], tgt['test'], src_vocab, tgt_vocab,
                                 batch_size=args.batch_size)

        pred = []
        iter_cnt = 0
        for src_batch, tgt_batch in test_loader:
            source, src_mask = make_tensor(src_batch)
            source = source.to(device)
            src_mask = src_mask.to(device)

            res = net.decode(source, src_mask)
            pred_batch = res.tolist()

            # Every sentence in pred_batch should start with the <sos> token (index 0)
            # and end with the <eos> token (index 1).
            # Every <pad> token (index 2) should come after the <eos> token.
            # Example pred_batch:
            # [[0, 5, 6, 7, 1],
            #  [0, 4, 9, 1, 2],
            #  [0, 6, 1, 2, 2]]
            pred += seq2sen(pred_batch, tgt_vocab)
            iter_cnt += 1

        with open('data/results/pred.txt', 'w') as f:
            for line in pred:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh data/results/pred.txt data/multi30k/test.de.atok'
        )
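# seq2sen is defined elsewhere in this project. A minimal sketch of the
# conversion it is expected to perform, assuming the index conventions
# documented above (<sos>=0, <eos>=1, <pad>=2); the vocab lookup used here
# (itos) is a hypothetical stand-in for whatever the real Vocab exposes.
def seq2sen_sketch(pred_batch, vocab):
    sentences = []
    for seq in pred_batch:
        words = []
        for idx in seq[1:]:        # drop the leading <sos>
            if idx == 1:           # stop at <eos>
                break
            if idx == 2:           # ignore <pad> (should only follow <eos>)
                continue
            words.append(vocab.itos(idx))  # hypothetical index-to-token lookup
        sentences.append(' '.join(words))
    return sentences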
class Prediction:
    def __init__(self, args):
        """
        :param args: parsed arguments providing the model dir (logdir),
                     vocab file path (vocab) and max source length (maxlen1)
        """
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)
        self._add_placeholder()
        self._init_graph()

    def _init_graph(self):
        """
        Build the encode/decode graph and restore the latest checkpoint.
        """
        self.ys = (self.input_y, None, None)
        self.xs = (self.input_x, None)
        self.memory = self.model.encode(self.xs, False)[0]
        self.logits = self.model.decode(self.xs, self.ys, self.memory, False)[0]

        ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]

        graph = self.logits.graph
        sess_config = self.tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True
        saver = self.tf.train.Saver()
        self.sess = self.tf.Session(config=sess_config, graph=graph)
        self.sess.run(self.tf.global_variables_initializer())
        self.tf.reset_default_graph()

        saver.restore(self.sess, ckpt)

        self.bs = BeamSearch(self.model,
                             self.hp.beam_size,
                             list(self.idx2token.keys())[2],
                             list(self.idx2token.keys())[3],
                             self.idx2token,
                             self.hp.maxlen2,
                             self.input_x,
                             self.input_y,
                             self.logits)

    def predict(self, content):
        """
        Abstract (summary) prediction by beam search.
        :param content: article content
        :return: prediction result
        """
        # Pad or truncate the source tokens to maxlen1, then map to indices
        input_x = content.split()
        while len(input_x) < self.args.maxlen1:
            input_x.append('<pad>')
        input_x = input_x[:self.args.maxlen1]

        input_x = [self.token2idx.get(s, self.token2idx['<unk>']) for s in input_x]

        memory = self.sess.run(self.memory, feed_dict={self.input_x: [input_x]})

        return self.bs.search(self.sess, input_x, memory[0])

    def _add_placeholder(self):
        """
        Add TensorFlow placeholders for the source and target sequences.
        """
        self.input_x = self.tf.placeholder(dtype=self.tf.int32,
                                           shape=[None, self.args.maxlen1],
                                           name='input_x')
        self.input_y = self.tf.placeholder(dtype=self.tf.int32,
                                           shape=[None, None],
                                           name='input_y')
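# A hedged usage sketch for the Prediction class above. Since __init__ also
# parses Hparams internally, the simplest assumption is that the same parsed
# namespace (which must provide logdir, vocab and maxlen1) is passed in as
# args; the input string is a placeholder, not real project data.
if __name__ == '__main__':
    args = Hparams().parser.parse_args()
    predictor = Prediction(args)
    print(predictor.predict('article content to summarize'))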
class Transformer_pl(pl.LightningModule):
    def __init__(self, hparams, **kwargs):
        super(Transformer_pl, self).__init__()
        self.hparams = hparams
        self.transformer = Transformer(self.hparams)
        self.sp_kor = korean_tokenizer_load()
        self.sp_eng = english_tokenizer_load()

    def forward(self, enc_inputs, dec_inputs):
        output_logits, *_ = self.transformer(enc_inputs, dec_inputs)
        return output_logits

    def cal_loss(self, tgt_hat, tgt_label):
        loss = F.cross_entropy(tgt_hat,
                               tgt_label.contiguous().view(-1),
                               ignore_index=self.hparams['padding_idx'])
        return loss

    def translate(self, input_sentence):
        self.eval()
        input_ids = self.sp_kor.EncodeAsIds(input_sentence)

        # Pad or truncate the source to max_seq_length
        if len(input_ids) <= self.hparams['max_seq_length']:
            input_ids = input_ids + [self.hparams['padding_idx']] * (
                self.hparams['max_seq_length'] - len(input_ids))
        else:
            input_ids = input_ids[:self.hparams['max_seq_length']]
        input_ids = torch.tensor([input_ids])

        enc_outputs, _ = self.transformer.encode(input_ids)
        target_ids = torch.zeros(1, self.hparams['max_seq_length']).type_as(input_ids.data)

        # Greedy decoding: write the previous prediction into the target buffer
        # and take the argmax at the current position until <eos> is produced
        next_token = self.sp_eng.bos_id()
        for i in range(0, self.hparams['max_seq_length']):
            target_ids[0][i] = next_token
            decoder_output, *_ = self.transformer.decode(target_ids, input_ids, enc_outputs)
            prob = decoder_output.squeeze(0).max(dim=-1, keepdim=False)[1]
            next_token = prob.data[i].item()
            if next_token == self.sp_eng.eos_id():
                break

        output_sent = self.sp_eng.DecodeIds(target_ids[0].tolist())
        return output_sent

    # ---------------------
    # TRAINING AND EVALUATION
    # ---------------------
    def training_step(self, batch, batch_idx):
        src, tgt = batch.kor, batch.eng
        tgt_label = tgt[:, 1:]
        tgt_hat = self(src, tgt[:, :-1])
        loss = self.cal_loss(tgt_hat, tgt_label)
        train_ppl = math.exp(loss)

        tensorboard_logs = {'train_loss': loss, 'train_ppl': train_ppl}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        src, tgt = batch.kor, batch.eng
        tgt_label = tgt[:, 1:]
        tgt_hat = self(src, tgt[:, :-1])
        val_loss = self.cal_loss(tgt_hat, tgt_label)
        return {'val_loss': val_loss}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_ppl = math.exp(val_loss)
        tensorboard_logs = {'val_loss': val_loss, 'val_ppl': val_ppl}

        print("")
        print("=" * 30)
        print(f"val_loss:{val_loss}")
        print("=" * 30)

        return {'val_loss': val_loss, 'val_ppl': val_ppl, 'log': tensorboard_logs}

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.transformer.parameters(), lr=self.hparams['lr'])
        return [optimizer]

    def make_field(self):
        KOR = torchtext.data.Field(use_vocab=False,
                                   tokenize=self.sp_kor.EncodeAsIds,
                                   batch_first=True,
                                   fix_length=self.hparams['max_seq_length'],
                                   pad_token=self.sp_kor.pad_id())
        ENG = torchtext.data.Field(
            use_vocab=False,
            tokenize=self.sp_eng.EncodeAsIds,
            batch_first=True,
            # +1 because the target is wrapped with a <bos> token for the label shift
            fix_length=self.hparams['max_seq_length'] + 1,
            init_token=self.sp_eng.bos_id(),
            eos_token=self.sp_eng.eos_id(),
            pad_token=self.sp_eng.pad_id())
        return KOR, ENG

    def train_dataloader(self):
        KOR, ENG = self.make_field()
        train_data = TabularDataset(path="./data/train.tsv",
                                    format='tsv',
                                    skip_header=True,
                                    fields=[('kor', KOR), ('eng', ENG)])
        train_iter = BucketIterator(train_data,
                                    batch_size=self.hparams['batch_size'],
                                    sort_key=lambda x: len(x.kor),
                                    sort_within_batch=False)
        return train_iter

    def val_dataloader(self):
        KOR, ENG = self.make_field()
        valid_data = TabularDataset(path="./data/valid.tsv",
                                    format='tsv',
                                    skip_header=True,
                                    fields=[('kor', KOR), ('eng', ENG)])
        val_iter = BucketIterator(valid_data,
                                  batch_size=self.hparams['batch_size'],
                                  sort_key=lambda x: len(x.kor),
                                  sort_within_batch=False)
        return val_iter
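# A hedged usage sketch for the LightningModule above. The hparams keys mirror
# the ones the class reads ('lr', 'batch_size', 'max_seq_length',
# 'padding_idx'); the values and the Trainer arguments are illustrative only.
import pytorch_lightning as pl

if __name__ == '__main__':
    hparams = {
        'lr': 1e-4,
        'batch_size': 32,
        'max_seq_length': 64,
        'padding_idx': 0,
    }
    model = Transformer_pl(hparams)
    trainer = pl.Trainer(max_epochs=10, gpus=1)  # older PL API, matching validation_epoch_end above
    trainer.fit(model)
    print(model.translate("안녕하세요"))  # Korean source -> English output via greedy decoding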