def testIteratorBatchType(self):
    current_cfg = self.data_cfg.copy()
    current_cfg["level"] = "word"
    current_cfg["lowercase"] = False

    # load toy data
    data = load_data(current_cfg)
    train_data = data["train_data"]
    dev_data = data["dev_data"]
    test_data = data["test_data"]
    vocabs = data["vocabs"]
    src_vocab = vocabs["src"]
    trg_vocab = vocabs["trg"]

    # make batches by number of sentences
    train_iter = iter(make_data_iter(
        train_data, batch_size=10, batch_type="sentence"))
    batch = next(train_iter)

    self.assertEqual(batch.src[0].shape[0], 10)
    self.assertEqual(batch.trg[0].shape[0], 10)

    # make batches by number of tokens
    train_iter = iter(make_data_iter(
        train_data, batch_size=100, batch_type="token"))
    _ = next(train_iter)  # skip a batch
    _ = next(train_iter)  # skip another batch
    batch = next(train_iter)

    self.assertEqual(batch.src[0].shape[0], 8)
    self.assertEqual(np.prod(batch.src[0].shape), 88)
    self.assertLessEqual(np.prod(batch.src[0].shape), 100)
def testIteratorBatchType(self):
    current_cfg = self.data_cfg.copy()

    # load toy data
    train_data, dev_data, test_data, src_vocab, trg_vocab = \
        load_data(current_cfg)

    # make batches by number of sentences
    train_iter = iter(
        make_data_iter(train_data, batch_size=10, batch_type="sentence"))
    batch = next(train_iter)

    self.assertEqual(batch.src[0].shape[0], 10)
    self.assertEqual(batch.trg[0].shape[0], 10)

    # make batches by number of tokens
    train_iter = iter(
        make_data_iter(train_data, batch_size=100, batch_type="token"))
    _ = next(train_iter)  # skip a batch
    _ = next(train_iter)  # skip another batch
    batch = next(train_iter)

    self.assertEqual(batch.src[0].shape[0], 8)
    self.assertEqual(np.prod(batch.src[0].shape), 88)
    self.assertLessEqual(np.prod(batch.src[0].shape), 100)
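# Worked example of the token-batching arithmetic asserted in the two tests
# above (a sketch using the tests' own numbers, not part of the test suite):
# with batch_type="token" the iterator counts padded source positions, so a
# batch of 8 sentences padded to length 11 occupies 8 * 11 = 88 <= 100 tokens,
# whereas batch_type="sentence" simply counts rows against the batch_size.
rows, padded_len, token_budget = 8, 11, 100
assert rows * padded_len == 88 <= token_budget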
def fast_adapt(self, task, learner, valid=False):
    loss = 0.0
    batch_size = self.batch_size
    steps = self.adaptation_steps
    if valid:
        batch_size = self.valid_batch_size
        steps = 1  # take only one batch during validation

    task_iter = make_data_iter(task,
                               batch_size=batch_size,
                               batch_type=self.batch_type,
                               train=True,
                               shuffle=self.shuffle)
    for i in range(steps):
        batch = next(iter(task_iter))
        train_batch = self.batch_class(batch, self.model.pad_index,
                                       use_cuda=self.use_cuda)
        batch_loss, _, _, _ = learner(return_type="loss",
                                      **vars(train_batch))
        # adapt learner after every batch
        learner.adapt(batch_loss, allow_nograd=True, allow_unused=True)
        loss += batch_loss

    loss /= (batch_size * steps)

    # alternative: adapt once after every task
    # learner.adapt(loss, allow_nograd=True, allow_unused=True)

    print("Loss")
    print(loss)
    return loss
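# Hedged sketch of the meta-training loop that fast_adapt() above plugs into.
# The learner.adapt(...) call matches learn2learn's MAML API; the toy model,
# data, and optimizer below are illustrative assumptions, not this repo's
# actual setup.
import torch
import learn2learn as l2l

model = torch.nn.Linear(4, 2)
maml = l2l.algorithms.MAML(model, lr=0.01, allow_nograd=True)
meta_opt = torch.optim.Adam(maml.parameters(), lr=0.001)

for step in range(3):
    meta_opt.zero_grad()
    learner = maml.clone()  # differentiable copy, like the learner passed in
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    inner_loss = torch.nn.functional.mse_loss(learner(x), y)
    learner.adapt(inner_loss)  # inner-loop update, as in fast_adapt()
    outer_loss = torch.nn.functional.mse_loss(learner(x), y)
    outer_loss.backward()  # meta-gradients flow back to maml's parameters
    meta_opt.step()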
def testBatchTrainIterator(self):
    batch_size = 4
    self.assertEqual(len(self.train_data), 27)

    # make data iterator
    # *note*: BucketIterator is replaced with Iterator
    train_iter = make_data_iter(self.train_data,
                                train=True,
                                shuffle=True,
                                batch_size=batch_size)
    self.assertEqual(train_iter.batch_size, batch_size)
    self.assertTrue(train_iter.shuffle)
    self.assertTrue(train_iter.train)
    self.assertEqual(train_iter.epoch, 0)
    self.assertEqual(train_iter.iterations, 0)

    expected_src0 = torch.Tensor([
        [18, 8, 6, 26, 5, 4, 10, 6, 28, 8, 17, 11, 22, 5, 19, 14, 4, 12,
         25, 3],
        [19, 11, 30, 5, 18, 23, 13, 4, 12, 5, 21, 4, 12, 7, 23, 17, 11, 9,
         3, 1],
        [19, 11, 22, 5, 8, 11, 5, 29, 8, 22, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [14, 8, 6, 15, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    ]).long()
    expected_src0_len = torch.Tensor([20, 19, 11, 7]).long()
    expected_trg0 = torch.Tensor([
        [14, 8, 21, 12, 4, 11, 6, 12, 13, 22, 4, 14, 12, 10, 21, 8, 4, 14,
         8, 23, 3],
        [5, 7, 30, 4, 20, 5, 5, 19, 4, 20, 5, 14, 10, 20, 9, 3, 1, 1, 1,
         1, 1],
        [5, 7, 22, 4, 7, 6, 7, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [8, 7, 6, 10, 17, 4, 13, 5, 15, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    ]).long()
    expected_trg0_len = torch.Tensor([22, 17, 10, 12]).long()

    total_samples = 0
    for b in iter(train_iter):
        b = Batch(torch_batch=b, pad_index=self.pad_index)
        if total_samples == 0:
            self.assertTensorEqual(b.src, expected_src0)
            self.assertTensorEqual(b.src_length, expected_src0_len)
            self.assertTensorEqual(b.trg, expected_trg0)
            self.assertTensorEqual(b.trg_length, expected_trg0_len)
        total_samples += b.nseqs
        self.assertLessEqual(b.nseqs, batch_size)
    self.assertEqual(total_samples, len(self.train_data))
def testBatchDevIterator(self):
    batch_size = 3
    self.assertEqual(len(self.dev_data), 20)

    # make data iterator
    dev_iter = make_data_iter(self.dev_data,
                              train=False,
                              shuffle=False,
                              batch_size=batch_size)
    self.assertEqual(dev_iter.batch_size, batch_size)
    self.assertFalse(dev_iter.shuffle)
    self.assertFalse(dev_iter.train)
    self.assertEqual(dev_iter.epoch, 0)
    self.assertEqual(dev_iter.iterations, 0)

    expected_src0 = torch.Tensor(
        [[29, 8, 5, 22, 5, 8, 16, 7, 19, 5, 22, 5, 24, 8, 7, 5, 7, 19, 16,
          16, 5, 31, 10, 19, 11, 8, 17, 15, 10, 6, 18, 5, 7, 4, 10, 6, 5,
          25, 3],
         [10, 17, 11, 5, 28, 12, 4, 23, 4, 5, 0, 10, 17, 11, 5, 22, 5, 14,
          8, 7, 7, 5, 10, 17, 11, 5, 14, 8, 5, 31, 10, 6, 5, 9, 3, 1, 1,
          1, 1],
         [29, 8, 5, 22, 5, 18, 23, 13, 4, 6, 5, 13, 8, 18, 5, 9, 3, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1]]).long()
    expected_src0_len = torch.Tensor([39, 35, 17]).long()
    expected_trg0 = torch.Tensor(
        [[13, 11, 12, 4, 22, 4, 12, 5, 4, 22, 4, 25, 7, 6, 8, 4, 14, 12,
          4, 24, 14, 5, 7, 6, 26, 17, 14, 10, 20, 4, 23, 3],
         [14, 0, 28, 4, 7, 6, 18, 18, 13, 4, 8, 5, 4, 24, 11, 4, 7, 11,
          16, 11, 4, 9, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [13, 11, 12, 4, 22, 4, 7, 11, 27, 27, 5, 4, 9, 3, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).long()
    expected_trg0_len = torch.Tensor([33, 24, 15]).long()

    total_samples = 0
    for b in iter(dev_iter):
        self.assertEqual(type(b), TorchTBatch)
        b = Batch(b, pad_index=self.pad_index)

        # test the sorting by src length
        self.assertEqual(type(b), Batch)
        before_sort = b.src_lengths
        b.sort_by_src_lengths()
        after_sort = b.src_lengths
        self.assertTensorEqual(torch.sort(before_sort, descending=True)[0],
                               after_sort)
        self.assertEqual(type(b), Batch)

        if total_samples == 0:
            self.assertTensorEqual(b.src, expected_src0)
            self.assertTensorEqual(b.src_lengths, expected_src0_len)
            self.assertTensorEqual(b.trg, expected_trg0)
            self.assertTensorEqual(b.trg_lengths, expected_trg0_len)
        total_samples += b.nseqs
        self.assertLessEqual(b.nseqs, batch_size)
    self.assertEqual(total_samples, len(self.dev_data))
def testBatchTrainIterator(self):
    batch_size = 4
    self.assertEqual(len(self.train_data), 27)

    # make data iterator
    train_iter = make_data_iter(self.train_data,
                                train=True,
                                shuffle=True,
                                batch_size=batch_size)
    self.assertEqual(train_iter.batch_size, batch_size)
    self.assertTrue(train_iter.shuffle)
    self.assertTrue(train_iter.train)
    self.assertEqual(train_iter.epoch, 0)
    self.assertEqual(train_iter.iterations, 0)

    expected_src0 = torch.Tensor(
        [[21, 10, 4, 16, 4, 5, 21, 4, 12, 33, 6, 14, 4, 12, 23, 6, 18, 4,
          6, 9, 3],
         [20, 28, 4, 10, 28, 4, 6, 5, 14, 8, 6, 15, 4, 5, 7, 17, 11, 27,
          6, 9, 3],
         [24, 8, 7, 5, 24, 10, 12, 14, 5, 18, 4, 7, 17, 11, 4, 11, 4, 6,
          25, 3, 1]]).long()
    expected_src0_len = torch.Tensor([21, 21, 20]).long()
    expected_trg0 = torch.Tensor(
        [[6, 4, 27, 5, 8, 4, 5, 31, 4, 26, 7, 6, 10, 20, 11, 9, 3],
         [8, 7, 6, 10, 17, 4, 13, 5, 15, 9, 3, 1, 1, 1, 1, 1, 1],
         [12, 5, 4, 25, 7, 6, 8, 4, 7, 6, 18, 18, 11, 10, 12, 23,
          3]]).long()
    expected_trg0_len = torch.Tensor([18, 12, 18]).long()

    total_samples = 0
    for b in iter(train_iter):
        b = Batch(torch_batch=b, pad_index=self.pad_index)
        if total_samples == 0:
            src, src_lengths, _ = b["src"]
            trg, trg_lengths, _ = b["trg"]
            self.assertTensorEqual(src, expected_src0)
            self.assertTensorEqual(src_lengths, expected_src0_len)
            self.assertTensorEqual(trg[:, 1:], expected_trg0)
            self.assertTensorEqual(trg_lengths, expected_trg0_len)
        total_samples += b.nseqs
        self.assertLessEqual(b.nseqs, batch_size)
    self.assertEqual(total_samples, len(self.train_data))
def compute_task_loss(self, train_task, learner):
    loss = 0.0
    task_train_iter = make_data_iter(train_task,
                                     batch_size=self.batch_size,
                                     batch_type=self.batch_type,
                                     train=True,
                                     shuffle=self.shuffle)
    for i, batch in enumerate(iter(task_train_iter)):
        train_batch = self.batch_class(batch, self.model.pad_index,
                                       use_cuda=self.use_cuda)
        batch_loss, _, _, _ = learner(return_type="loss",
                                      **vars(train_batch))
        loss += batch_loss
    loss /= len(train_task)
    print("Task Loss", loss)
    return loss
def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
        -> None:
    """
    Train the model and validate it from time to time on the
    validation set.

    :param train_data: training data
    :param valid_data: validation data
    """
    train_iter = make_data_iter(train_data,
                                batch_size=self.batch_size,
                                batch_type=self.batch_type,
                                train=True,
                                shuffle=self.shuffle)

    # For the last batch in an epoch, batch_multiplier needs to be adjusted
    # to fit the number of leftover training examples.
    leftover_batch_size = len(train_data) % (self.batch_multiplier *
                                             self.batch_size)

    for epoch_no in range(self.epochs):
        self.logger.info("EPOCH %d", epoch_no + 1)

        if self.scheduler is not None and self.scheduler_step_at == "epoch":
            self.scheduler.step(epoch=epoch_no)

        self.model.train()

        # Reset statistics for each epoch.
        start = time.time()
        total_valid_duration = 0
        start_tokens = self.total_tokens
        self.current_batch_multiplier = self.batch_multiplier
        self.optimizer.zero_grad()
        count = self.current_batch_multiplier - 1
        epoch_loss = 0

        for i, batch in enumerate(iter(train_iter)):
            # reactivate training
            self.model.train()
            # create a Batch object from torchtext batch
            batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

            # only update every batch_multiplier batches
            # see https://medium.com/@davidlmorton/
            # increasing-mini-batch-size-without-increasing-
            # memory-6794e10db672

            # Set current_batch_multiplier to fit the number of leftover
            # examples for the last batch in the epoch.
            # Only works if batch_type == "sentence".
            if self.batch_type == "sentence":
                if self.batch_multiplier > 1 and i == len(train_iter) - \
                        math.ceil(leftover_batch_size / self.batch_size):
                    self.current_batch_multiplier = math.ceil(
                        leftover_batch_size / self.batch_size)
                    count = self.current_batch_multiplier - 1

            update = count == 0
            # print(count, update, self.steps)
            batch_loss = self._train_batch(batch, update=update,
                                           count=count)

            # Only save the finally computed batch_loss of the full batch.
            if update:
                self.tb_writer.add_scalar("train/train_batch_loss",
                                          batch_loss, self.steps)

            count = self.batch_multiplier if update else count
            count -= 1

            # Only add the complete batch_loss of a full mini-batch
            # to epoch_loss.
            if update:
                epoch_loss += batch_loss.detach().cpu().numpy()

            if self.scheduler is not None and \
                    self.scheduler_step_at == "step" and update:
                self.scheduler.step()

            # log learning progress
            if self.steps % self.logging_freq == 0 and update:
                elapsed = time.time() - start - total_valid_duration
                elapsed_tokens = self.total_tokens - start_tokens
                self.logger.info(
                    "Epoch %3d Step: %8d Batch Loss: %12.6f "
                    "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                    self.steps, batch_loss, elapsed_tokens / elapsed,
                    self.optimizer.param_groups[0]["lr"])
                start = time.time()
                total_valid_duration = 0
                start_tokens = self.total_tokens

            # validate on the entire dev set
            if self.steps % self.validation_freq == 0 and update:
                valid_start_time = time.time()

                valid_score, valid_loss, valid_ppl, valid_sources, \
                    valid_sources_raw, valid_references, valid_hypotheses, \
                    valid_hypotheses_raw, valid_attention_scores = \
                    validate_on_data(
                        logger=self.logger,
                        batch_size=self.eval_batch_size,
                        data=valid_data,
                        eval_metric=self.eval_metric,
                        level=self.level,
                        model=self.model,
                        use_cuda=self.use_cuda,
                        max_output_length=self.max_output_length,
                        loss_function=self.loss,
                        beam_size=1,  # greedy validations
                        batch_type=self.eval_batch_type,
                        postprocess=True  # always remove BPE for validation
                    )

                self.tb_writer.add_scalar("valid/valid_loss",
                                          valid_loss, self.steps)
                self.tb_writer.add_scalar("valid/valid_score",
                                          valid_score, self.steps)
                self.tb_writer.add_scalar("valid/valid_ppl",
                                          valid_ppl, self.steps)

                if self.early_stopping_metric == "loss":
                    ckpt_score = valid_loss
                elif self.early_stopping_metric in ["ppl", "perplexity"]:
                    ckpt_score = valid_ppl
                else:
                    ckpt_score = valid_score

                new_best = False
                if self.is_best(ckpt_score):
                    self.best_ckpt_score = ckpt_score
                    self.best_ckpt_iteration = self.steps
                    self.logger.info(
                        'Hooray! New best validation result [%s]!',
                        self.early_stopping_metric)
                    if self.ckpt_queue.maxsize > 0:
                        self.logger.info("Saving new checkpoint.")
                        new_best = True
                        self._save_checkpoint()

                if self.scheduler is not None \
                        and self.scheduler_step_at == "validation":
                    self.scheduler.step(ckpt_score)

                # append to validation report
                self._add_report(valid_score=valid_score,
                                 valid_loss=valid_loss,
                                 valid_ppl=valid_ppl,
                                 eval_metric=self.eval_metric,
                                 new_best=new_best)

                self._log_examples(
                    sources_raw=[v for v in valid_sources_raw],
                    sources=valid_sources,
                    hypotheses_raw=valid_hypotheses_raw,
                    hypotheses=valid_hypotheses,
                    references=valid_references)

                valid_duration = time.time() - valid_start_time
                total_valid_duration += valid_duration
                self.logger.info(
                    'Validation result (greedy) at epoch %3d, '
                    'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                    'duration: %.4fs', epoch_no + 1, self.steps,
                    self.eval_metric, valid_score, valid_loss,
                    valid_ppl, valid_duration)

                # store validation set outputs
                self._store_outputs(valid_hypotheses)

                # store attention plots for selected valid sentences
                if valid_attention_scores:
                    store_attention_plots(
                        attentions=valid_attention_scores,
                        targets=valid_hypotheses_raw,
                        sources=[s for s in valid_data.src],
                        indices=self.log_valid_sents,
                        output_prefix="{}/att.{}".format(
                            self.model_dir, self.steps),
                        tb_writer=self.tb_writer,
                        steps=self.steps)

            if self.stop:
                break

        if self.stop:
            self.logger.info(
                'Training ended since minimum lr %f was reached.',
                self.learning_rate_min)
            break

        self.logger.info('Epoch %3d: total training loss %.2f',
                         epoch_no + 1, epoch_loss)
    else:
        self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
    self.logger.info('Best validation result (greedy) at step '
                     '%8d: %6.2f %s.', self.best_ckpt_iteration,
                     self.best_ckpt_score, self.early_stopping_metric)

    self.tb_writer.close()  # close Tensorboard writer
def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
        -> None:
    """
    Train the model and validate it from time to time on the
    validation set.

    :param train_data: training data
    :param valid_data: validation data
    """
    self.train_iter = make_data_iter(train_data,
                                     batch_size=self.batch_size,
                                     batch_type=self.batch_type,
                                     train=True,
                                     shuffle=self.shuffle)

    if self.train_iter_state is not None:
        self.train_iter.load_state_dict(self.train_iter_state)

    #################################################################
    # simplify accumulation logic:
    #################################################################
    # for epoch in range(epochs):
    #     self.model.zero_grad()
    #     epoch_loss = 0.0
    #     batch_loss = 0.0
    #     for i, batch in enumerate(iter(self.train_iter)):
    #
    #         # gradient accumulation:
    #         # loss.backward() inside _train_step()
    #         batch_loss += self._train_step(inputs)
    #
    #         if (i + 1) % self.batch_multiplier == 0:
    #             self.optimizer.step()     # update!
    #             self.model.zero_grad()    # reset gradients
    #             self.steps += 1           # increment counter
    #
    #             epoch_loss += batch_loss  # accumulate batch loss
    #             batch_loss = 0            # reset batch loss
    #
    # # leftovers are just ignored.
    #################################################################

    logger.info(
        "Train stats:\n"
        "\tdevice: %s\n"
        "\tn_gpu: %d\n"
        "\t16-bits training: %r\n"
        "\tgradient accumulation: %d\n"
        "\tbatch size per device: %d\n"
        "\ttotal batch size (w. parallel & accumulation): %d", self.device,
        self.n_gpu, self.fp16, self.batch_multiplier,
        self.batch_size // self.n_gpu if self.n_gpu > 1 else self.batch_size,
        self.batch_size * self.batch_multiplier)

    for epoch_no in range(self.epochs):
        logger.info("EPOCH %d", epoch_no + 1)

        if self.scheduler is not None and self.scheduler_step_at == "epoch":
            self.scheduler.step(epoch=epoch_no)

        self.model.train()

        # Reset statistics for each epoch.
        start = time.time()
        total_valid_duration = 0
        start_tokens = self.stats.total_tokens
        self.model.zero_grad()
        epoch_loss = 0
        batch_loss = 0

        for i, batch in enumerate(iter(self.train_iter)):
            # create a Batch object from torchtext batch
            batch = self.batch_class(batch, self.model.pad_index,
                                     use_cuda=self.use_cuda)

            # get batch loss
            batch_loss += self._train_step(batch)

            # update!
            if (i + 1) % self.batch_multiplier == 0:
                # clip gradients (in-place)
                if self.clip_grad_fun is not None:
                    if self.fp16:
                        self.clip_grad_fun(
                            params=amp.master_params(self.optimizer))
                    else:
                        self.clip_grad_fun(params=self.model.parameters())

                # make gradient step
                self.optimizer.step()

                # decay lr
                if self.scheduler is not None \
                        and self.scheduler_step_at == "step":
                    self.scheduler.step()

                # reset gradients
                self.model.zero_grad()

                # increment step counter
                self.stats.steps += 1

                # log learning progress
                if self.stats.steps % self.logging_freq == 0:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.stats.steps)
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.stats.total_tokens - start_tokens
                    logger.info(
                        "Epoch %3d, Step: %8d, Batch Loss: %12.6f, "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.stats.steps, batch_loss,
                        elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.stats.total_tokens

                # Only add complete loss of full mini-batch to epoch_loss
                epoch_loss += batch_loss  # accumulate epoch_loss
                batch_loss = 0  # reset batch_loss

                # validate on the entire dev set
                if self.stats.steps % self.validation_freq == 0:
                    valid_duration = self._validate(valid_data, epoch_no)
                    total_valid_duration += valid_duration

            if self.stats.stop:
                break

        if self.stats.stop:
            logger.info('Training ended since minimum lr %f was reached.',
                        self.learning_rate_min)
            break

        logger.info('Epoch %3d: total training loss %.2f', epoch_no + 1,
                    epoch_loss)
    else:
        logger.info('Training ended after %3d epochs.', epoch_no + 1)
    logger.info('Best validation result (greedy) at step %8d: %6.2f %s.',
                self.stats.best_ckpt_iter, self.stats.best_ckpt_score,
                self.early_stopping_metric)

    self.tb_writer.close()  # close Tensorboard writer
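# Self-contained sketch of the accumulation scheme outlined in the
# "simplify accumulation logic" comment block above (toy model and data;
# the division by batch_multiplier is an assumption that makes the
# accumulated gradient match one large batch -- in the code above,
# normalization happens inside _train_step()).
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
batch_multiplier = 4
batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(12)]

model.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = loss_fn(model(x), y) / batch_multiplier
    loss.backward()  # gradients accumulate in .grad across small batches
    if (i + 1) % batch_multiplier == 0:
        optimizer.step()   # one update per virtual batch
        model.zero_grad()  # reset gradients for the next virtual batch
# leftovers (none here, since 12 % 4 == 0) are just ignored, as noted above.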
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     level: str,
                     eval_metric: Optional[str],
                     n_gpu: int,
                     batch_class: Batch = Batch,
                     compute_loss: bool = False,
                     beam_size: int = 1,
                     beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = None,
                     n_best: int = 1) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param batch_class: class type of batch
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options
    :param n_best: amount of candidates to return

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be bigger than n_gpu."
    if sacrebleu is None:  # assign default value
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")

    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)
            batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            reverse_index = batch.sort_by_src_length()
            sort_reverse_index = expand_reverse_index(reverse_index, n_best)

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(return_type="loss",
                                            **vars(batch))
                if n_gpu > 1:
                    batch_loss = batch_loss.mean()  # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model,
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length,
                n_best=n_best)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data) * n_best

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

    # decode back to symbols
    decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                        cut_at_eos=True)

    # evaluate with metric on full dataset
    join_char = " " if level in ["word", "bpe"] else ""
    valid_sources = [join_char.join(s) for s in data.src]
    valid_references = [join_char.join(t) for t in data.trg]
    valid_hypotheses = [join_char.join(t) for t in decoded_valid]

    # post-process
    if level == "bpe" and postprocess:
        valid_sources = [
            bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources
        ]
        valid_references = [
            bpe_postprocess(v, bpe_type=bpe_type) for v in valid_references
        ]
        valid_hypotheses = [
            bpe_postprocess(v, bpe_type=bpe_type) for v in valid_hypotheses
        ]

    # if references are given, evaluate against them
    if valid_references:
        assert len(valid_hypotheses) == len(valid_references)

        current_valid_score = 0
        if eval_metric.lower() == 'bleu':
            # this version does not use any tokenization
            current_valid_score = bleu(valid_hypotheses,
                                       valid_references,
                                       tokenize=sacrebleu["tokenize"])
        elif eval_metric.lower() == 'chrf':
            current_valid_score = chrf(
                valid_hypotheses,
                valid_references,
                remove_whitespace=sacrebleu["remove_whitespace"])
        elif eval_metric.lower() == 'token_accuracy':
            current_valid_score = token_accuracy(
                # supply List[List[str]]
                list(decoded_valid), list(data.trg))
        elif eval_metric.lower() == 'sequence_accuracy':
            current_valid_score = sequence_accuracy(valid_hypotheses,
                                                    valid_references)
    else:
        current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
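# Sketch of the sort-and-restore pattern used above: the batch is sorted by
# source length before decoding, and the reverse index maps outputs back to
# the original order (toy tensors; expand_reverse_index additionally repeats
# each index n_best times when several hypotheses are returned per sentence).
import torch

lengths = torch.tensor([3, 7, 5])
sorted_lengths, perm = lengths.sort(descending=True)  # original -> sorted
rev_index = perm.argsort()                            # sorted -> original
outputs_sorted = sorted_lengths * 10                  # stand-in for decoder output
assert torch.equal(outputs_sorted[rev_index], lengths * 10)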
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     force_prune_size: int = 5,
                     beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     validate_by_label: bool = False,
                     forced_sparsity: bool = False,
                     method=None,
                     max_hyps=1,
                     break_at_p: float = 1.0,
                     break_at_argmax: bool = False,
                     short_depth: int = 0):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model:
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda:
    :param max_output_length: maximum length for generated hypotheses
    :param trg_level: target segmentation level
    :param eval_metrics:
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation (default 0 is greedy)
    :param beam_alpha: beam search alpha for length penalty (default 0)
    :param batch_type: validation batch type (sentence or token)
    :return:
        - current_valid_scores: current validation score [eval_metric],
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if beam_size > 0:
        force_prune_size = beam_size

    if validate_by_label:
        assert isinstance(data, TSVDataset) and data.label_columns

    valid_scores = defaultdict(float)  # container for scores
    stats = defaultdict(float)

    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False,
                                use_cuda=use_cuda)

    pad_index = model.trg_vocab.stoi[PAD_TOKEN]
    model.eval()  # disable dropout

    force_objectives = loss_function is not None or forced_sparsity

    # possible tasks are: force w/ gold, force w/ empty, search
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None

    confidences = []
    corrects = []
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        for valid_batch in iter(valid_iter):
            batch = Batch(valid_batch, pad_index)
            rev_index = batch.sort_by_src_lengths()

            encoder_output, _ = model.encode(batch)

            empty_probs = None
            if force_objectives and not isinstance(model, EnsembleModel):
                # compute all the logits.
                logits = model.force_decode(batch, encoder_output)[0]
                bsz, gold_len, vocab_size = logits.size()
                gold, gold_lengths, _ = batch["trg"]
                prediction_steps = gold_lengths.sum().item() - bsz
                assert gold.size(0) == bsz

                if loss_function is not None:
                    gold_pred = gold[:, 1:].contiguous().view(-1)
                    batch_loss = loss_function(
                        logits.view(-1, logits.size(-1)), gold_pred)
                    valid_scores["loss"] += batch_loss

                if forced_sparsity:
                    # compute probabilities
                    out = logits.view(-1, vocab_size)
                    if isinstance(model, EnsembleModel):
                        probs = out
                    else:
                        probs = model.decoder.gen_func(out, dim=-1)

                    # Compute numbers derived from the distributions.
                    # This includes support size, entropy, and calibration.
                    non_pad = (gold[:, 1:] != pad_index).view(-1)
                    real_probs = probs[non_pad]
                    n_supported = real_probs.gt(0).sum().item()
                    pred_ps, pred_ix = real_probs.max(dim=-1)
                    real_gold = gold[:, 1:].contiguous().view(-1)[non_pad]
                    real_correct = pred_ix.eq(real_gold)
                    corrects.append(real_correct)
                    confidences.append(pred_ps)

                    beam_probs, _ = real_probs.topk(force_prune_size, dim=-1)
                    pruned_mass = 1 - beam_probs.sum(dim=-1)
                    stats["force_pruned_mass"] += pruned_mass.sum().item()

                    # compute stuff with the empty sequence
                    empty_probs = probs.view(
                        bsz, gold_len, vocab_size)[:, 0, model.eos_index]
                    assert empty_probs.size() == gold_lengths.size()
                    empty_possible = empty_probs.gt(0).sum().item()
                    empty_mass = empty_probs.sum().item()

                    stats["eos_supported"] += empty_possible
                    stats["eos_mass"] += empty_mass
                    stats["n_supp"] += n_supported
                    stats["n_pred"] += prediction_steps

            short_scores = None
            if short_depth > 0:
                # we call run_batch again with the short depth. We don't
                # really care what the hypotheses are, we only want the
                # scores
                _, _, short_scores = model.run_batch(
                    batch=batch,
                    beam_size=beam_size,  # can this be removed?
                    scorer=scorer,  # should be none
                    max_output_length=short_depth,
                    method="dfs",
                    max_hyps=max_hyps,
                    encoder_output=encoder_output,
                    return_scores=True)

            # run as during inference to produce translations
            # todo: return_scores for greedy
            output, attention_scores, beam_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                scorer=scorer,
                max_output_length=max_output_length,
                method=method,
                max_hyps=max_hyps,
                encoder_output=encoder_output,
                return_scores=True,
                break_at_argmax=break_at_argmax,
                break_at_p=break_at_p)

            stats["hyp_length"] += output.ne(model.pad_index).sum().item()

            if beam_scores is not None and empty_probs is not None:
                # I need to expand this to handle stuff up to length m.
                # note that although you can compute the probability of the
                # empty sequence without any extra computation, you *do* need
                # to do extra decoding if you want to get the most likely
                # sequence with length <= m.
                empty_better = empty_probs.log().gt(beam_scores).sum().item()
                stats["empty_better"] += empty_better

                if short_scores is not None:
                    short_better = short_scores.gt(beam_scores).sum().item()
                    stats["short_better"] += short_better

            # sort outputs back to original order
            all_outputs.extend(output[rev_index])

            if save_attention and attention_scores is not None:
                # beam search currently does not support attention logging
                for k, v in attention_scores.items():
                    valid_attention_scores[k].extend(v[rev_index])

        assert len(all_outputs) == len(data)

    ref_length = sum(len(d.trg) for d in data)
    valid_scores["length_ratio"] = stats["hyp_length"] / ref_length

    assert len(corrects) == len(confidences)
    if corrects:
        valid_scores["ece"] = expected_calibration_error(corrects,
                                                         confidences)

    if stats["n_pred"] > 0:
        valid_scores["ppl"] = math.exp(valid_scores["loss"] / stats["n_pred"])

    if forced_sparsity and stats["n_pred"] > 0:
        valid_scores["support"] = stats["n_supp"] / stats["n_pred"]
        valid_scores["empty_possible"] = \
            stats["eos_supported"] / len(all_outputs)
        valid_scores["empty_prob"] = stats["eos_mass"] / len(all_outputs)
        valid_scores["force_pruned_mass"] = \
            stats["force_pruned_mass"] / stats["n_pred"]
        if beam_size > 0:
            valid_scores["empty_better"] = \
                stats["empty_better"] / len(all_outputs)
            if short_depth > 0:
                score_name = "depth_{}_better".format(short_depth)
                valid_scores[score_name] = \
                    stats["short_better"] / len(all_outputs)

    # postprocess
    raw_hyps = model.trg_vocab.arrays_to_sentences(all_outputs)
    valid_hyps = postprocess(raw_hyps, trg_level)
    valid_refs = postprocess(data.trg, trg_level)

    # evaluate
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": word_error_rate,
        "cer": partial(character_error_rate, level=trg_level),
        "levenshtein_distance": partial(levenshtein_distance,
                                        level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}
    decoding_scores, scores_by_label = evaluate_decoding(
        data, valid_refs, valid_hyps, selected_eval_metrics,
        validate_by_label)
    valid_scores.update(decoding_scores)

    return valid_scores, valid_refs, valid_hyps, \
        raw_hyps, valid_attention_scores, scores_by_label
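# Minimal sketch of what expected_calibration_error() above might compute
# from the collected `corrects` and `confidences`: a standard binned ECE.
# The 10-bin choice and exact binning are assumptions, not necessarily the
# repo's implementation.
import torch

def binned_ece(corrects: torch.Tensor, confidences: torch.Tensor,
               n_bins: int = 10) -> float:
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = corrects[in_bin].float().mean()  # accuracy in the bin
            conf = confidences[in_bin].mean()      # mean confidence in the bin
            ece += in_bin.float().mean().item() * (acc - conf).abs().item()
    return ece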
def validate_on_data(model: Model,
                     data: Dataset,
                     logger: Logger,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     level: str,
                     eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1,
                     beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param logger: logger
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")

    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)
            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

    # decode back to symbols
    decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                        cut_at_eos=True)

    # evaluate with metric on full dataset
    join_char = " " if level in ["word", "bpe"] else ""
    valid_sources = [join_char.join(s) for s in data.src]
    valid_references = [join_char.join(t) for t in data.trg]
    valid_hypotheses = [join_char.join(t) for t in decoded_valid]

    # post-process
    if level == "bpe" and postprocess:
        valid_sources = [bpe_postprocess(s) for s in valid_sources]
        valid_references = [bpe_postprocess(v) for v in valid_references]
        valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

    # if references are given, evaluate against them
    if valid_references:
        assert len(valid_hypotheses) == len(valid_references)

        current_valid_score = 0
        if eval_metric.lower() == 'bleu':
            # this version does not use any tokenization
            current_valid_score = bleu(valid_hypotheses, valid_references)
        elif eval_metric.lower() == 'chrf':
            current_valid_score = chrf(valid_hypotheses, valid_references)
        elif eval_metric.lower() == 'token_accuracy':
            current_valid_score = token_accuracy(valid_hypotheses,
                                                 valid_references,
                                                 level=level)
        elif eval_metric.lower() == 'sequence_accuracy':
            current_valid_score = sequence_accuracy(valid_hypotheses,
                                                    valid_references)
    else:
        current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
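# The perplexity computed above is the exponentiated per-token negative
# log-likelihood; a tiny numeric illustration (toy numbers):
import torch

total_loss, total_ntokens = torch.tensor(230.0), 100
valid_ppl = torch.exp(total_loss / total_ntokens)  # exp(2.3) ~= 9.97
assert 9.9 < valid_ppl.item() < 10.0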
def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
        -> None:
    """
    Train the model and validate it from time to time on the
    validation set.

    :param train_data: training data
    :param valid_data: validation data
    """
    train_iter = make_data_iter(train_data,
                                batch_size=self.batch_size,
                                train=True,
                                shuffle=self.shuffle)

    for epoch_no in range(self.epochs):
        self.logger.info("EPOCH %d", epoch_no + 1)

        if self.scheduler is not None and self.scheduler_step_at == "epoch":
            self.scheduler.step(epoch=epoch_no)

        self.model.train()
        start = time.time()
        total_valid_duration = 0
        processed_tokens = self.total_tokens
        count = 0
        epoch_loss = 0

        for batch in iter(train_iter):
            # reactivate training
            self.model.train()
            # create a Batch object from torchtext batch
            batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

            # only update every batch_multiplier batches
            # see https://medium.com/@davidlmorton/
            # increasing-mini-batch-size-without-increasing-
            # memory-6794e10db672
            update = count == 0
            # print(count, update, self.steps)
            batch_loss = self._train_batch(batch, update=update)
            self.tb_writer.add_scalar("train/train_batch_loss", batch_loss,
                                      self.steps)
            count = self.batch_multiplier if update else count
            count -= 1
            epoch_loss += batch_loss.detach().cpu().numpy()

            # log learning progress
            if self.steps % self.logging_freq == 0 and update:
                elapsed = time.time() - start - total_valid_duration
                elapsed_tokens = self.total_tokens - processed_tokens
                self.logger.info(
                    "Epoch %d Step: %d Batch Loss: %f Tokens per Sec: %f",
                    epoch_no + 1, self.steps, batch_loss,
                    elapsed_tokens / elapsed)
                start = time.time()
                total_valid_duration = 0

            # validate on the entire dev set
            if valid_data is not None and \
                    self.steps % self.validation_freq == 0 and update:
                valid_start_time = time.time()

                valid_score, valid_loss, valid_ppl, valid_sources, \
                    valid_sources_raw, valid_references, valid_hypotheses, \
                    valid_hypotheses_raw, valid_attention_scores, \
                    valid_logps = validate_on_data(
                        batch_size=self.batch_size,
                        data=valid_data,
                        eval_metric=self.eval_metric,
                        level=self.level,
                        model=self.model,
                        use_cuda=self.use_cuda,
                        max_output_length=self.max_output_length,
                        loss_function=self.loss,
                        return_logp=self.return_logp)

                self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                          self.steps)
                self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                          self.steps)
                self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                          self.steps)

                if self.early_stopping_metric == "loss":
                    ckpt_score = valid_loss
                elif self.early_stopping_metric in ["ppl", "perplexity"]:
                    ckpt_score = valid_ppl
                else:
                    ckpt_score = valid_score

                new_best = False
                if self.is_best(ckpt_score):
                    self.best_ckpt_score = ckpt_score
                    self.best_ckpt_iteration = self.steps
                    self.logger.info(
                        'Hooray! New best validation result [%s]!',
                        self.early_stopping_metric)
                    if self.ckpt_queue.maxsize > 0:
                        self.logger.info("Saving new checkpoint.")
                        new_best = True
                        self._save_checkpoint()

                if self.scheduler is not None \
                        and self.scheduler_step_at == "validation":
                    self.scheduler.step(ckpt_score)

                # append to validation report
                self._add_report(valid_score=valid_score,
                                 valid_loss=valid_loss,
                                 valid_ppl=valid_ppl,
                                 eval_metric=self.eval_metric,
                                 new_best=new_best)

                self._log_examples(sources_raw=valid_sources_raw,
                                   sources=valid_sources,
                                   hypotheses_raw=valid_hypotheses_raw,
                                   hypotheses=valid_hypotheses,
                                   references=valid_references)

                valid_duration = time.time() - valid_start_time
                total_valid_duration += valid_duration
                self.logger.info(
                    'Validation result at epoch %d, step %d: %s: %f, '
                    'loss: %f, ppl: %f, duration: %.4fs', epoch_no + 1,
                    self.steps, self.eval_metric, valid_score, valid_loss,
                    valid_ppl, valid_duration)

                # store validation set outputs
                self._store_outputs(
                    valid_hypotheses if self.post_process else
                    [" ".join(v) for v in valid_hypotheses_raw],
                    valid_logps if self.return_logp else None)

                # store attention plots for selected valid sentences
                store_attention_plots(attentions=valid_attention_scores,
                                      targets=valid_hypotheses_raw,
                                      sources=[s for s in valid_data.src],
                                      indices=self.log_valid_sents,
                                      output_prefix="{}/att.{}".format(
                                          self.model_dir, self.steps),
                                      tb_writer=self.tb_writer,
                                      steps=self.steps)

            if self.save_freq > 0 and self.steps % self.save_freq == 0:
                # Drop a checkpoint by number of batches.
                # Take the batch multiplier into account in the description.
                self.logger.info("Saving new checkpoint! "
                                 "Batches passed: {} "
                                 "Number of updates: {}".format(
                                     self.batch_multiplier * self.steps,
                                     self.steps))
                self._save_checkpoint()

            if self.stop:
                break

        if self.stop:
            self.logger.info(
                'Training ended since minimum lr %f was reached.',
                self.learning_rate_min)
            break

        self.logger.info('Epoch %d: total training loss %.2f', epoch_no + 1,
                         epoch_loss)
    else:
        self.logger.info('Training ended after %d epochs.', epoch_no + 1)
    if valid_data is not None:
        self.logger.info('Best validation result at step %d: %f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.early_stopping_metric)
def train_and_validate(self, train_data: Dataset, valid_data: Dataset):
    """
    Train the model and validate it on the validation set.

    :param train_data: training data
    :param valid_data: validation data
    """
    train_iter = make_data_iter(train_data,
                                batch_size=self.batch_size,
                                batch_type=self.batch_type,
                                train=True,
                                shuffle=self.shuffle)
    for epoch_no in range(1, self.epochs + 1):
        self.logger.info("EPOCH %d", epoch_no)

        if self.scheduler is not None and self.scheduler_step_at == "epoch":
            self.scheduler.step(epoch=epoch_no - 1)  # 0-based indexing

        self.model.train()
        start = time.time()
        total_valid_duration = 0
        processed_tokens = self.total_tokens
        epoch_loss = 0

        for i, batch in enumerate(iter(train_iter), 1):
            # reactivate training
            self.model.train()
            # create a Batch object from torchtext batch
            batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

            # only update every batch_multiplier batches
            # see https://medium.com/@davidlmorton/
            # increasing-mini-batch-size-without-increasing-
            # memory-6794e10db672
            update = i % self.batch_multiplier == 0

            batch_loss = self._train_batch(batch, update=update)
            self.log_tensorboard("train", batch_loss=batch_loss)
            epoch_loss += batch_loss.detach().cpu().numpy()

            if self.scheduler is not None and \
                    self.scheduler_step_at == "step" and update:
                self.scheduler.step()

            # log learning progress
            if self.steps % self.logging_freq == 0 and update:
                elapsed = time.time() - start - total_valid_duration
                elapsed_tokens = self.total_tokens - processed_tokens
                self.logger.info(
                    "Epoch %3d Step: %8d Batch Loss: %12.6f "
                    "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no,
                    self.steps, batch_loss, elapsed_tokens / elapsed,
                    self.optimizer.param_groups[0]["lr"])
                start = time.time()
                total_valid_duration = 0
                processed_tokens = self.total_tokens

            # validate on the entire dev set
            if self.steps % self.validation_freq == 0 and update:
                valid_start_time = time.time()
                # it would be nice to include loss and ppl in valid_scores
                valid_scores, valid_sources, valid_sources_raw, \
                    valid_references, valid_hypotheses, \
                    valid_hypotheses_raw, valid_attention_scores, \
                    scores_by_lang, by_lang = validate_on_data(
                        batch_size=self.eval_batch_size,
                        data=valid_data,
                        eval_metrics=self.eval_metrics,
                        attn_metrics=self.attn_metrics,
                        src_level=self.src_level,
                        trg_level=self.trg_level,
                        model=self.model,
                        use_cuda=self.use_cuda,
                        max_output_length=self.max_output_length,
                        loss_function=self.loss,
                        beam_size=0,  # greedy validations
                        batch_type=self.eval_batch_type,
                        save_attention=self.plot_attention,
                        log_sparsity=self.log_sparsity,
                        apply_mask=self.valid_apply_mask)

                ckpt_score = valid_scores[self.early_stopping_metric]
                self.log_tensorboard("valid", **valid_scores)

                new_best = False
                if self.is_best(ckpt_score):
                    self.best_ckpt_score = ckpt_score
                    self.best_ckpt_iteration = self.steps
                    self.logger.info(
                        'Hooray! New best validation result [%s]!',
                        self.early_stopping_metric)
                    if self.ckpt_queue.maxsize > 0:
                        self.logger.info("Saving new checkpoint.")
                        new_best = True
                        self._save_checkpoint()

                if self.scheduler is not None \
                        and self.scheduler_step_at == "validation":
                    self.scheduler.step(ckpt_score)

                # append to validation report
                self._add_report(valid_scores=valid_scores,
                                 eval_metrics=self.eval_metrics,
                                 new_best=new_best)

                self._log_examples(sources_raw=valid_sources_raw,
                                   sources=valid_sources,
                                   hypotheses_raw=valid_hypotheses_raw,
                                   hypotheses=valid_hypotheses,
                                   references=valid_references)

                labeled_scores = sorted(valid_scores.items())
                eval_report = ", ".join("{}: {:.5f}".format(n, v)
                                        for n, v in labeled_scores)

                valid_duration = time.time() - valid_start_time
                total_valid_duration += valid_duration
                self.logger.info(
                    'Validation result at epoch %3d, step %8d: %s, '
                    'duration: %.4fs', epoch_no, self.steps, eval_report,
                    valid_duration)

                if scores_by_lang is not None:
                    for metric, scores in scores_by_lang.items():
                        # make a report
                        lang_report = [metric]
                        numbers = sorted(scores.items())
                        lang_report.extend(["{}: {:.5f}".format(k, v)
                                            for k, v in numbers])
                        self.logger.info("\n\t".join(lang_report))

                # store validation set outputs
                self._store_outputs(valid_hypotheses)

                # store attention plots for selected valid sentences
                if valid_attention_scores and self.plot_attention:
                    store_attention_plots(
                        attentions=valid_attention_scores,
                        sources=[s for s in valid_data.src],
                        targets=valid_hypotheses_raw,
                        indices=self.log_valid_sents,
                        model_dir=self.model_dir,
                        tb_writer=self.tb_writer,
                        steps=self.steps)

            if self.stop:
                break

        if self.stop:
            self.logger.info(
                'Training ended since minimum lr %f was reached.',
                self.learning_rate_min)
            break

        self.logger.info('Epoch %3d: total training loss %.2f', epoch_no,
                         epoch_loss)
    else:
        self.logger.info('Training ended after %3d epochs.', epoch_no)
    self.logger.info('Best validation result at step %8d: %6.2f %s.',
                     self.best_ckpt_iteration, self.best_ckpt_score,
                     self.early_stopping_metric)

    self.tb_writer.close()  # close Tensorboard writer
def train_and_validate(self, train_data: Dataset, valid_data: Dataset,
                       kb_task=None,
                       train_kb: TranslationDataset = None,
                       train_kb_lkp: list = [],
                       train_kb_lens: list = [],
                       train_kb_truvals: TranslationDataset = None,
                       valid_kb: Tuple = None,
                       valid_kb_lkp: list = [],
                       valid_kb_lens: list = [],
                       valid_kb_truvals: list = [],
                       valid_data_canon: list = []) \
        -> None:
    """
    Train the model and validate it from time to time on the
    validation set.

    :param train_data: training data
    :param valid_data: validation data
    :param kb_task: is not None if kb_task should be executed
    :param train_kb: TranslationDataset holding the loaded train kb data
    :param train_kb_lkp: list mapping train example indices to
        corresponding kb indices
    :param train_kb_lens: list with the number of triples per kb
    :param valid_kb: TranslationDataset holding the loaded valid kb data
    :param valid_kb_lkp: list mapping valid example indices to
        corresponding kb indices
    :param valid_kb_lens: list with the number of triples per kb
    :param valid_kb_truvals: FIXME TODO
    :param valid_data_canon: required to report loss
    """
    if kb_task:
        train_iter = make_data_iter_kb(train_data,
                                       train_kb,
                                       train_kb_lkp,
                                       train_kb_lens,
                                       train_kb_truvals,
                                       batch_size=self.batch_size,
                                       batch_type=self.batch_type,
                                       train=True,
                                       shuffle=self.shuffle,
                                       canonize=self.model.canonize)
    else:
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)

    with torch.autograd.set_detect_anomaly(True):
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None \
                    and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()
            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            count = self.batch_multiplier - 1
            epoch_loss = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()

                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index,
                              use_cuda=self.use_cuda) if not kb_task else \
                    Batch_with_KB(batch, self.pad_index,
                                  use_cuda=self.use_cuda)

                if kb_task:
                    assert hasattr(batch, "kbsrc"), dir(batch)
                    assert hasattr(batch, "kbtrg"), dir(batch)
                    assert hasattr(batch, "kbtrv"), dir(batch)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
                update = count == 0

                batch_loss = self._train_batch(batch, update=update)
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)
                count = self.batch_multiplier if update else count
                count -= 1
                epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    if self.manage_decoder_timer:
                        self._log_decoder_timer_stats("train")
                        self.decoder_timer.reset()

                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, \
                        valid_hypotheses, valid_hypotheses_raw, \
                        valid_attention_scores, valid_kb_att_scores, \
                        valid_ent_f1, valid_ent_mcc = \
                        validate_on_data(
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level,
                            model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=0,  # greedy validations
                            # FIXME XXX NOTE TODO BUG set to 0 again!
                            batch_type=self.eval_batch_type,
                            kb_task=kb_task,
                            valid_kb=valid_kb,
                            valid_kb_lkp=valid_kb_lkp,
                            valid_kb_lens=valid_kb_lens,
                            valid_kb_truvals=valid_kb_truvals,
                            valid_data_canon=valid_data_canon,
                            report_on_canonicals=
                            self.report_entf1_on_canonicals)

                    if self.manage_decoder_timer:
                        self._log_decoder_timer_stats("valid")
                        self.decoder_timer.reset()

                    self.tb_writer.add_scalar("valid/valid_loss",
                                              valid_loss, self.steps)
                    self.tb_writer.add_scalar("valid/valid_score",
                                              valid_score, self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl",
                                              valid_ppl, self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     valid_ent_f1=valid_ent_f1,
                                     valid_ent_mcc=valid_ent_mcc,
                                     new_best=new_best)

                    # pylint: disable=unnecessary-comprehension
                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch %3d, step %8d: %s: '
                        '%6.2f, loss: %8.4f, ppl: %8.4f, duration: %.4fs',
                        epoch_no + 1, self.steps, self.eval_metric,
                        valid_score, valid_loss, valid_ppl, valid_duration)

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    valid_src = list(valid_data.src)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        plot_success_ratio = store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=valid_src,
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)
                        self.logger.info(
                            f"stored {plot_success_ratio} valid att scores!")

                    if valid_kb_att_scores:
                        plot_success_ratio = store_attention_plots(
                            attentions=valid_kb_att_scores,
                            targets=valid_hypotheses_raw,
                            sources=list(valid_kb.kbsrc),
                            indices=self.log_valid_sents,
                            output_prefix="{}/kbatt.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps,
                            kb_info=(valid_kb_lkp, valid_kb_lens,
                                     valid_kb_truvals),
                            on_the_fly_info=(valid_src, valid_kb,
                                             self.model.canonize,
                                             self.model.trg_vocab))
                        self.logger.info(
                            f"stored {plot_success_ratio} "
                            f"valid kb att scores!")
                    else:
                        self.logger.info(
                            "there are no valid kb att scores ...")

                if self.stop:
                    break

            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.',
                             epoch_no + 1)
        self.logger.info('Best validation result at step %8d: %6.2f %s.',
                         self.best_ckpt_iteration, self.best_ckpt_score,
                         self.early_stopping_metric)

    self.tb_writer.close()  # close Tensorboard writer
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     src_level: str, trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     attn_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     log_sparsity: bool = False,
                     apply_mask: bool = True  # hmm
                     ) \
        -> (dict, List[str], List[List[str]], List[str], List[str],
            List[List[str]], dict, dict, dict):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param src_level: source segmentation level, one of "char", "bpe", "word"
    :param trg_level: target segmentation level, one of "char", "bpe", "word"
    :param eval_metrics: evaluation metrics, e.g. ["bleu"]
    :param attn_metrics: attention metrics, e.g. ["support"]
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to 0 (default).
    :param batch_type: validation batch type (sentence or token)
    :param save_attention: if True, store attention scores for each hypothesis
    :param log_sparsity: if True, log the average support of the output
        distributions during greedy decoding
    :param apply_mask: passed through to model.run_batch

    :return:
        - valid_scores: dictionary of validation scores
          (loss, ppl, and one entry per evaluation metric),
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses,
        - scores_by_lang: per-language scores for each evaluation metric
          (None if the dataset has no language labels),
        - by_language: (reference, hypothesis) pairs grouped by language
    """
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": wer,
        "cer": partial(character_error_rate, level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}

    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = [s for s in data.src]
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        total_attended = defaultdict(int)
        greedy_steps = 0
        greedy_supported = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)
            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, probs = model.run_batch(
                batch=batch, beam_size=beam_size, scorer=scorer,
                max_output_length=max_output_length,
                log_sparsity=log_sparsity, apply_mask=apply_mask)
            if log_sparsity:
                lengths = torch.LongTensor(
                    (output == model.trg_vocab.stoi[EOS_TOKEN])
                    .argmax(axis=1)).unsqueeze(1)
                greedy_steps += lengths.sum().item()
                ix = torch.arange(output.shape[1]).unsqueeze(0) \
                    .expand(output.shape[0], -1)
                mask = ix <= lengths
                supp = probs.exp().gt(0).sum(dim=-1).cpu()  # batch x len
                supp = torch.where(mask, supp, torch.tensor(0)).sum()
                greedy_supported += supp.float().item()

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])

            if attention_scores is not None:
                # is attention_scores ever None?
                if save_attention:
                    # beam search currently does not support attention logging
                    for k, v in attention_scores.items():
                        valid_attention_scores[k].extend(
                            v[sort_reverse_index])
                if attn_metrics:
                    # add to total_attended
                    for k, v in attention_scores.items():
                        total_attended[k] += (v > 0).sum()

        assert len(all_outputs) == len(data)

        if log_sparsity:
            print(greedy_supported / greedy_steps)

        valid_scores = dict()
        if loss_function is not None and total_ntokens > 0:
            # total validation loss and token-level perplexity
            valid_scores["loss"] = total_loss
            valid_scores["ppl"] = torch.exp(total_loss / total_ntokens)

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(
            arrays=all_outputs, cut_at_eos=True)

        # evaluate with metric on full dataset
        src_join_char = " " if src_level in ["word", "bpe"] else ""
        trg_join_char = " " if trg_level in ["word", "bpe"] else ""
        valid_sources = [src_join_char.join(s) for s in data.src]
        valid_references = [trg_join_char.join(t) for t in data.trg]
        valid_hypotheses = [trg_join_char.join(t) for t in decoded_valid]

        if attn_metrics:
            decoded_ntokens = sum(len(t) for t in decoded_valid)
            for attn_metric in attn_metrics:
                assert attn_metric == "support"
                for attn_name, tot_attended in total_attended.items():
                    score_name = attn_name + "_" + attn_metric
                    # this is not the right denominator
                    valid_scores[score_name] = tot_attended / decoded_ntokens

        # post-process
        if src_level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
        if trg_level == "bpe":
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        languages = [language for language in data.language]
        by_language = defaultdict(list)
        seqs = zip(valid_references, valid_hypotheses) \
            if valid_references else valid_hypotheses
        if languages:
            examples = zip(languages, seqs)
            for lang, seq in examples:
                by_language[lang].append(seq)
        else:
            by_language[None].extend(seqs)

        # if references are given, evaluate against them
        # incorrect if-condition?
        # scores_by_lang = {name: dict() for name in selected_eval_metrics}
        scores_by_lang = dict()
        if valid_references and eval_metrics is not None:
            assert len(valid_hypotheses) == len(valid_references)

            for eval_metric, eval_func in selected_eval_metrics.items():
                score_by_lang = dict()
                for lang, pairs in by_language.items():
                    # each pair is a (reference, hypothesis) tuple
                    lang_refs, lang_hyps = zip(*pairs)
                    lang_score = eval_func(lang_hyps, lang_refs)
                    score_by_lang[lang] = lang_score

                # macro-average: each language contributes equally
                score = sum(score_by_lang.values()) / len(score_by_lang)
                valid_scores[eval_metric] = score
                scores_by_lang[eval_metric] = score_by_lang

    if not languages:
        scores_by_lang = None
    return valid_scores, valid_sources, valid_sources_raw, \
        valid_references, valid_hypotheses, decoded_valid, \
        valid_attention_scores, scores_by_lang, by_language
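# A minimal sketch (toy, hypothetical numbers, not from the original file)
# of the per-language macro-averaging performed above: each language's score
# is computed independently and the reported metric is their unweighted mean,
# so low-resource languages weigh as much as high-resource ones.

score_by_lang_example = {"deu": 30.0, "fra": 24.0, "lit": 12.0}  # hypothetical BLEU
macro_avg = sum(score_by_lang_example.values()) / len(score_by_lang_example)
assert macro_avg == 22.0  # (30 + 24 + 12) / 3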
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     kb_task=None,
                     valid_kb: Dataset = None,
                     valid_kb_lkp: list = [],
                     valid_kb_lens: list = [],
                     valid_kb_truvals: Dataset = None,
                     valid_data_canon: Dataset = None,
                     report_on_canonicals: bool = False,
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.ndarray], List[np.ndarray],
            float, float):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param kb_task: is not None if kb_task should be executed
    :param valid_kb: MonoDataset holding the loaded valid kb data
    :param valid_kb_lkp: list mapping valid example indices to the
        corresponding kb indices
    :param valid_kb_lens: list with the number of triples per kb
    :param valid_kb_truvals: dataset holding the true (uncanonized) kb values
    :param valid_data_canon: TranslationDataset of valid data but with
        canonized target data (for loss reporting)
    :param report_on_canonicals: if True, report entity scores on
        canonical forms

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses,
        - valid_kb_att_scores: kb attention scores for validation hypotheses,
        - valid_ent_f1: TODO FIXME
        - valid_ent_mcc: TODO FIXME
    """
    print(f"\n{'-'*10} ENTER VALIDATION {'-'*10}\n")
    print(f"\n{'-'*10} VALIDATION DEBUG {'-'*10}\n")
    print("---data---")
    print(dir(data[0]))
    print([[getattr(example, attr) for attr in dir(example)
            if hasattr(getattr(example, attr), "__iter__")
            and ("kb" in attr or "src" in attr or "trg" in attr)]
           for example in data[:3]])
    print(batch_size)
    print(use_cuda)
    print(max_output_length)
    print(level)
    print(eval_metric)
    print(loss_function)
    print(beam_size)
    print(beam_alpha)
    print(batch_type)
    print(kb_task)
    print("---valid_kb---")
    print(dir(valid_kb[0]))
    print([[getattr(example, attr) for attr in dir(example)
            if hasattr(getattr(example, attr), "__iter__")
            and ("kb" in attr or "src" in attr or "trg" in attr)]
           for example in valid_kb[:3]])
    print(len(valid_kb_lkp), valid_kb_lkp[-5:])
    print(len(valid_kb_lens), valid_kb_lens[-5:])
    print("---valid_kb_truvals---")
    print(len(valid_kb_truvals), valid_kb_lens[-5:])
    print([[getattr(example, attr) for attr in dir(example)
            if hasattr(getattr(example, attr), "__iter__")
            and ("kb" in attr or "src" in attr or "trg" in attr
                 or "trv" in attr)]
           for example in valid_kb_truvals[:3]])
    print("---valid_data_canon---")
    print(len(valid_data_canon), valid_data_canon[-5:])
    print([[getattr(example, attr) for attr in dir(example)
            if hasattr(getattr(example, attr), "__iter__")
            and ("kb" in attr or "src" in attr or "trg" in attr
                 or "trv" in attr or "can" in attr)]
           for example in valid_data_canon[:3]])
    print(report_on_canonicals)
    print(f"\n{'-'*10} END VALIDATION DEBUG {'-'*10}\n")

    if not kb_task:
        valid_iter = make_data_iter(
            dataset=data, batch_size=batch_size, batch_type=batch_type,
            shuffle=False, train=False)
    else:
        # knowledgebase version of make_data_iter,
        # which also provides canonized target data:
        # data: for bleu/ent f1; canon_data: for loss
        valid_iter = make_data_iter_kb(
            data, valid_kb, valid_kb_lkp, valid_kb_lens, valid_kb_truvals,
            batch_size=batch_size, batch_type=batch_type,
            shuffle=False, train=False,
            canonize=model.canonize, canon_data=valid_data_canon)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        valid_kb_att_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)
            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) \
                if not kb_task else \
                Batch_with_KB(valid_batch, pad_index, use_cuda=use_cuda)

            assert hasattr(batch, "kbsrc") == bool(kb_task)

            # sort batch now by src length and keep track of order
            if not kb_task:
                sort_reverse_index = batch.sort_by_src_lengths()
            else:
                sort_reverse_index = list(range(batch.src.shape[0]))

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                ntokens = batch.ntokens
                if hasattr(batch, "trgcanon") and batch.trgcanon is not None:
                    # normalize loss with num canonical tokens for perplexity
                    ntokens = batch.ntokenscanon
                # do a loss calculation without grad updates just to report
                # the valid loss; we can only do this when batch.trg exists,
                # so not during actual translation/deployment
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                # keep track of metrics for reporting
                total_loss += batch_loss
                total_ntokens += ntokens  # gold target tokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, kb_att_scores = model.run_batch(
                batch=batch, beam_size=beam_size, beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])
            valid_kb_att_scores.extend(
                kb_att_scores[sort_reverse_index]
                if kb_att_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log likelihood;
            # can be seen as 2^(cross_entropy of model on valid set),
            # normalized by num tokens;
            # see https://en.wikipedia.org/wiki/Perplexity#Perplexity_per_word
            valid_ppl = torch.exp(valid_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoding_vocab = model.trg_vocab if not kb_task else model.trv_vocab
        decoded_valid = decoding_vocab.arrays_to_sentences(
            arrays=all_outputs, cut_at_eos=True)
        print(f"decoding_vocab.itos: {decoding_vocab.itos}")
        print(decoded_valid)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        # TODO replace valid_references with uncanonicalized dev.car data
        # ... requires writing new Dataset in data.py
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)
            print(list(zip(valid_sources, valid_references,
                           valid_hypotheses)))

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(
                    valid_hypotheses, valid_references, level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)

            if kb_task:
                valid_ent_f1, valid_ent_mcc = calc_ent_f1_and_ent_mcc(
                    valid_hypotheses, valid_references,
                    vocab=model.trv_vocab, c_fun=model.canonize,
                    report_on_canonicals=report_on_canonicals)
            else:
                valid_ent_f1, valid_ent_mcc = -1, -1
        else:
            current_valid_score = -1
            valid_ent_f1, valid_ent_mcc = -1, -1

    print(f"\n{'-'*10} EXIT VALIDATION {'-'*10}\n")
    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, valid_kb_att_scores, \
        valid_ent_f1, valid_ent_mcc
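# A small numeric check (made-up values, not from the original file) of the
# perplexity identity used above: valid_ppl = exp(total_loss / total_ntokens),
# where total_loss is the summed token-level negative log-likelihood.

import torch

total_loss_example = torch.tensor(230.0)  # hypothetical summed NLL on dev set
total_ntokens_example = 100               # hypothetical gold target token count
valid_ppl_example = torch.exp(total_loss_example / total_ntokens_example)
assert torch.isclose(valid_ppl_example,
                     torch.exp(torch.tensor(2.3)))  # ppl ~ 9.97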
def dev_network(self):
    """
    Show the current performance on the dev dataset, by means of the
    total reward and the BLEU score.

    :return: current BLEU score
    """
    freeze_model(self.eval_net)
    for data_set_name, data_set in self.data_to_dev.items():
        #print(data_set_name)
        valid_iter = make_data_iter(
            dataset=data_set, batch_size=1, batch_type=self.batch_type,
            shuffle=False, train=False)
        valid_sources_raw = data_set.src

        # don't track gradients during validation
        r_total = 0
        roptimal_total = 0
        all_outputs = []
        i_sample = 0

        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)
            batch = Batch(valid_batch, self.pad_index,
                          use_cuda=self.use_cuda)

            encoder_output, encoder_hidden = self.model.encode(
                batch.src, batch.src_lengths, batch.src_mask)

            # if maximum output length is not globally specified,
            # adapt to src len
            if self.max_output_length is None:
                self.max_output_length = int(
                    max(batch.src_lengths.cpu().numpy()) * 1.5)

            batch_size = batch.src_mask.size(0)
            prev_y = batch.src_mask.new_full(
                size=[batch_size, 1], fill_value=self.bos_index,
                dtype=torch.long)
            output = []
            hidden = self.model.decoder._init_hidden(encoder_hidden)
            prev_att_vector = None
            finished = batch.src_mask.new_zeros((batch_size, 1)).byte()

            # pylint: disable=unused-variable
            for t in range(self.max_output_length):
                # if i_sample == 0 or i_sample == 3 or i_sample == 6:
                #     print("state on t = ", t, " : ", state)

                # decode one single step
                logits, hidden, att_probs, prev_att_vector = \
                    self.model.decoder(
                        encoder_output=encoder_output,
                        encoder_hidden=encoder_hidden,
                        src_mask=batch.src_mask,
                        trg_embed=self.model.trg_embed(prev_y),
                        hidden=hidden,
                        prev_att_vector=prev_att_vector,
                        unroll_steps=1)

                # build the state for the eval_net, then do greedy decoding:
                # choose the arg max of its outputs over the vocabulary
                if self.state_type == 'hidden':
                    state = torch.cat(hidden,
                                      dim=2).squeeze(1).detach().cpu()[0]
                else:
                    state = torch.FloatTensor(
                        prev_att_vector.squeeze(1).detach().cpu().numpy()[0])

                logits = self.eval_net(state)
                logits = logits.reshape([1, 1, -1])
                #print(type(logits), logits.shape, logits)
                next_word = torch.argmax(logits, dim=-1)
                output.append(next_word.squeeze(1).detach().cpu().numpy())
                prev_y = next_word

                # check if previous symbol was <eos>
                is_eos = torch.eq(next_word, self.eos_index)
                finished += is_eos
                # stop predicting if <eos> reached for all elements in batch
                if (finished >= 1).sum() == batch_size:
                    break

            stacked_output = np.stack(output, axis=1)  # batch, time

            # decode back to symbols
            decoded_valid_in = self.model.trg_vocab.arrays_to_sentences(
                arrays=batch.src, cut_at_eos=True)
            decoded_valid_out_trg = \
                self.model.trg_vocab.arrays_to_sentences(
                    arrays=batch.trg, cut_at_eos=True)
            decoded_valid_out = self.model.trg_vocab.arrays_to_sentences(
                arrays=stacked_output, cut_at_eos=True)

            hyp = stacked_output
            r = self.Reward(batch.trg, hyp, show=False)

            if i_sample == 0 or i_sample == 3 or i_sample == 6:
                print("\n Sample ", i_sample,
                      "-------------Target vs Eval_net prediction:"
                      "--Raw---and---Decoded-----")
                print("Target: ", batch.trg, decoded_valid_out_trg)
                print("Eval  : ", stacked_output, decoded_valid_out, "\n")
                print("Reward: ", r)

            #r = self.Reward1(batch.trg, hyp, show=False)
            # accumulate only the positive rewards
            r_total += sum(r[np.where(r > 0)])
            if i_sample == 0:
                roptimal = self.Reward(batch.trg, batch.trg, show=False)
                roptimal_total += sum(roptimal[np.where(roptimal > 0)])

            all_outputs.extend(stacked_output)
            i_sample += 1

        assert len(all_outputs) == len(data_set)

        # decode back to symbols
        decoded_valid = self.model.trg_vocab.arrays_to_sentences(
            arrays=all_outputs, cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if self.level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data_set.src]
        valid_references = [join_char.join(t) for t in data_set.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if self.level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v)
                                for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v)
                                for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)
            current_valid_score = 0
            if self.eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses,
                                           valid_references)
            elif self.eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses,
                                           valid_references)
            elif self.eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(
                    valid_hypotheses, valid_references, level=self.level)
            elif self.eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

        self.dev_network_count += 1
        self.tb_writer.add_scalar("dev/dev_reward",
                                  r_total, self.dev_network_count)
        self.tb_writer.add_scalar("dev/dev_bleu",
                                  current_valid_score,
                                  self.dev_network_count)

        print(self.dev_network_count, ' r_total and score: ',
              r_total, current_valid_score)

    unfreeze_model(self.eval_net)
    return current_valid_score
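# A self-contained sketch (hypothetical shapes, not the original classes) of
# the two state_type variants used in dev_network above: 'hidden' concatenates
# the decoder's (h, c) pair along the feature axis, while the alternative
# reuses prev_att_vector; either way a flat per-example vector feeds eval_net.

import torch

hidden_size, batch = 8, 1
h = torch.zeros(1, batch, hidden_size)  # hypothetical decoder hidden state
c = torch.ones(1, batch, hidden_size)   # hypothetical decoder cell state
prev_att_vector = torch.rand(batch, 1, hidden_size)

hidden_state = torch.cat((h, c), dim=2).squeeze(1).detach().cpu()[0]
att_state = prev_att_vector.squeeze(1).detach().cpu()[0]
assert hidden_state.shape == (2 * hidden_size,)  # 'hidden' state is 2x wider
assert att_state.shape == (hidden_size,)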
def Collecting_experiences(self) -> None:
    """
    Main function. Runs the full experience-collection process.

    :param exp_list: list of experiences,
        tuples (memory_counter, state, a, state_, is_eos[0,0])
    :param rew: rewards for every experience,
        of the length of the hypothesis
    """
    for epoch_no in range(self.epochs):
        print("EPOCH %d" % (epoch_no + 1))

        #beam_dqn = self.beam_min + int(self.beam_max * epoch_no/self.epochs)
        #egreed = self.egreed_max*(1 - epoch_no/(1.1*self.epochs))
        #self.gamma = self.gamma_max*(1 - epoch_no/(2*self.epochs))
        beam_dqn = 1
        egreed = 0.5
        #self.gamma = self.gamma_max
        self.gamma = 0.6

        self.tb_writer.add_scalar("parameters/beam_dqn",
                                  beam_dqn, epoch_no)
        self.tb_writer.add_scalar("parameters/egreed", egreed, epoch_no)
        self.tb_writer.add_scalar("parameters/gamma", self.gamma, epoch_no)

        if beam_dqn > self.actions_size:
            print("beam_dqn cannot exceed the action size!")
            print("setting beam_dqn = action size")
            beam_dqn = self.actions_size

        print(' beam_dqn, egreed, gamma: ', beam_dqn, egreed, self.gamma)
        for _, data_set in self.data_to_train_dqn.items():
            valid_iter = make_data_iter(
                dataset=data_set, batch_size=1,
                batch_type=self.batch_type, shuffle=False, train=False)
            #valid_sources_raw = data_set.src
            # disable dropout
            #self.model.eval()

            i_sample = 0
            for valid_batch in iter(valid_iter):
                freeze_model(self.model)
                batch = Batch(valid_batch, self.pad_index,
                              use_cuda=self.use_cuda)

                encoder_output, encoder_hidden = self.model.encode(
                    batch.src, batch.src_lengths, batch.src_mask)

                # if maximum output length is not globally specified,
                # adapt to src len
                if self.max_output_length is None:
                    self.max_output_length = int(
                        max(batch.src_lengths.cpu().numpy()) * 1.5)

                batch_size = batch.src_mask.size(0)
                prev_y = batch.src_mask.new_full(
                    size=[batch_size, 1], fill_value=self.bos_index,
                    dtype=torch.long)
                output = []
                hidden = self.model.decoder._init_hidden(encoder_hidden)
                prev_att_vector = None
                finished = batch.src_mask.new_zeros((batch_size, 1)).byte()

                # print("Source_raw: ", batch.src)
                # print("Target_raw: ", batch.trg_input)
                # print("y0: ", prev_y)

                exp_list = []
                # pylint: disable=unused-variable
                for t in range(self.max_output_length):
                    # state before the decoder step
                    # (undefined at t == 0, where no action was taken yet)
                    if t != 0:
                        if self.state_type == 'hidden':
                            state = torch.cat(hidden, dim=2) \
                                .squeeze(1).detach().cpu().numpy()[0]
                        else:
                            state = prev_att_vector \
                                .squeeze(1).detach().cpu().numpy()[0]

                    # decode one single step
                    logits, hidden, att_probs, prev_att_vector = \
                        self.model.decoder(
                            encoder_output=encoder_output,
                            encoder_hidden=encoder_hidden,
                            src_mask=batch.src_mask,
                            trg_embed=self.model.trg_embed(prev_y),
                            hidden=hidden,
                            prev_att_vector=prev_att_vector,
                            unroll_steps=1)
                    # logits: batch x time=1 x vocab (logits)

                    # state after the decoder step
                    if t != 0:
                        if self.state_type == 'hidden':
                            state_ = torch.cat(hidden, dim=2) \
                                .squeeze(1).detach().cpu().numpy()[0]
                        else:
                            state_ = prev_att_vector \
                                .squeeze(1).detach().cpu().numpy()[0]

                    # epsilon-greedy decoding: with probability egreed pick
                    # one of the beam_dqn best words at random, otherwise
                    # choose the arg max over the vocabulary
                    if random.uniform(0, 1) < egreed:
                        i_ran = random.randint(0, beam_dqn - 1)
                        next_word = torch.argsort(
                            logits, descending=True)[:, :, i_ran]
                    else:
                        next_word = torch.argmax(logits, dim=-1)
                        # batch x time=1

                    a = prev_y.squeeze(1).detach().cpu().numpy()[0]
                    #a = next_word.squeeze(1).detach().cpu().numpy()[0]

                    # print("state  ", t, " : ", state)
                    # print("state_ ", t, " : ", state_)
                    # print("action ", t, " : ", a)
                    # print("__________________________________________")

                    output.append(
                        next_word.squeeze(1).detach().cpu().numpy())
                    #tup = (self.memory_counter, state, a, state_)
                    prev_y = next_word

                    # check if previous symbol was <eos>
                    is_eos = torch.eq(next_word, self.eos_index)
                    finished += is_eos

                    if t != 0:
                        self.memory_counter += 1
                        tup = (self.memory_counter, state, a, state_, 1)
                        exp_list.append(tup)

                    # stop predicting if <eos> reached for all elements
                    # in batch
                    if (finished >= 1).sum() == batch_size:
                        a = next_word.squeeze(1).detach().cpu().numpy()[0]
                        self.memory_counter += 1
                        #tup = (self.memory_counter, state_, a,
                        #       np.zeros([self.state_size]), is_eos[0, 0])
                        tup = (self.memory_counter, state_, a,
                               np.zeros([self.state_size]), 0)
                        exp_list.append(tup)
                        break

                    if t == self.max_output_length - 1:
                        #print("reached the max output length")
                        a = 0
                        self.memory_counter += 1
                        #tup = (self.memory_counter, state_, a,
                        #       np.zeros([self.state_size]), is_eos[0, 0])
                        tup = (self.memory_counter, state_, a,
                               -1 * np.ones([self.state_size]), 1)
                        exp_list.append(tup)

                # collecting rewards
                hyp = np.stack(output, axis=1)  # batch, time
                show = epoch_no == 0 and i_sample in (0, 3, 6)
                r = self.Reward(batch.trg, hyp, show=show)  # 1 x time-1

                # if i_sample == 0 or i_sample == 3 or i_sample == 6:
                #     print("\n Sample Collected: ", i_sample,
                #           "-------------Target vs Eval_net prediction:"
                #           "--Raw---and---Decoded-----")
                #     print("Target: ", batch.trg, decoded_valid_out_trg)
                #     print("Eval  : ", stacked_output, decoded_valid_out)
                #     print("Reward: ", r, "\n")

                i_sample += 1
                self.store_transition(exp_list, r)

                # learning step
                if self.memory_counter > \
                        self.mem_cap - self.max_output_length:
                    self.learn()

    self.tb_writer.close()
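# A minimal, runnable sketch (hypothetical sizes, not the original class) of
# the epsilon-greedy choice used in the collection loop above: with
# probability egreed, one of the beam_dqn highest-scoring words is sampled
# uniformly (exploration); otherwise the greedy arg max is taken.

import random
import torch

egreed_example, beam_dqn_example = 0.5, 3
logits_example = torch.randn(1, 1, 10)  # batch x time=1 x vocab, hypothetical

if random.uniform(0, 1) < egreed_example:
    # explore: pick one of the beam_dqn best actions at random
    i_ran = random.randint(0, beam_dqn_example - 1)
    next_word = torch.argsort(logits_example, descending=True)[:, :, i_ran]
else:
    # exploit: greedy arg max over the vocabulary
    next_word = torch.argmax(logits_example, dim=-1)
assert next_word.shape == (1, 1)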