def epoch_continuous_data(self, dat, epoch):
    """One training epoch over a continuous [T, B] token block."""
    self.s2s.train()
    epoch_start_time = time.time()
    log_start_time = time.time()
    sum_avgCEs = 0
    if self.nbatches == 0:
        self.nbatches = self.count_batches(dat.train)
    # No encoder for continuous (language-model) data; just reset the decoder
    # state once and carry it across the whole epoch (start=False below).
    self.s2s.dec.init_state(batch_size=dat.train.size(1), encoder_final=None)
    itr = data.build_itr(dat.train, self.bptt)
    for batch_num_, (subblock, golds) in enumerate(itr):
        loss = self.step_on_batch(subblock, golds, src=None, lengths=None,
                                  start=False)
        sum_avgCEs += loss
        if (batch_num_ + 1) % self.interval == 0:
            self.report(epoch, batch_num_ + 1, sum_avgCEs / self.interval,
                        time.time() - log_start_time)
            sum_avgCEs = 0
            log_start_time = time.time()
    val_loss, dummy = self.evaluate_continuous_data(dat.valid)
    epoch_time = time.time() - epoch_start_time
    return val_loss, dummy, epoch_time
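# `data.build_itr` is defined elsewhere; the sketch below is a hypothetical
# stand-in (name and details assumed) for the contract the loop above and the
# test at the end of this section rely on: bptt-row input windows paired with
# next-token targets flattened row-major.
def build_itr_sketch(block, bptt):
    """Yield (subblock, golds) pairs over a [T, B] token block."""
    for i in range(0, block.size(0) - 1, bptt):
        seq_len = min(bptt, block.size(0) - 1 - i)
        subblock = block[i:i + seq_len]                   # inputs  [seq_len, B]
        golds = block[i + 1:i + 1 + seq_len].reshape(-1)  # targets [seq_len * B]
        yield subblock, golds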
def epoch_translation_data(self, dat, epoch, shuffle=False):
    """One training epoch over (block, block_src, src_lens) translation bundles."""
    self.s2s.train()
    epoch_start_time = time.time()
    log_start_time = time.time()
    sum_avgCEs = 0
    if self.nbatches == 0:
        self.nbatches = self.count_batches(dat.train)
    batch_num = 0
    if shuffle:
        random.shuffle(dat.train)
    for block, block_src, src_lens in dat.train:
        itr = data.build_itr(block, self.bptt)
        # start=True on the first subblock of each bundle so the encoder runs
        # and the decoder state is re-initialized; later subblocks carry state.
        starting = True
        for subblock, golds in itr:
            loss = self.step_on_batch(subblock, golds, src=block_src,
                                      lengths=src_lens, start=starting)
            starting = False
            sum_avgCEs += loss
            batch_num += 1
            if batch_num % self.interval == 0:
                self.report(epoch, batch_num, sum_avgCEs / self.interval,
                            time.time() - log_start_time)
                sum_avgCEs = 0
                log_start_time = time.time()
    val_loss, val_sqxent = self.evaluate_translation_data(dat.valid)
    epoch_time = time.time() - epoch_start_time
    return val_loss, val_sqxent, epoch_time
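# A minimal driver sketch for the method above; `trainer`, `dat`, `nepochs`,
# and the checkpoint path are assumed names, not part of this module.
best_val = float('inf')
for epoch in range(1, nepochs + 1):
    val_loss, val_sqxent, epoch_time = trainer.epoch_translation_data(
        dat, epoch, shuffle=True)
    print('epoch %d | %.1fs | valid CE %.4f | valid seq CE %.2f'
          % (epoch, epoch_time, val_loss, val_sqxent))
    if val_loss < best_val:  # keep the checkpoint with the best validation CE
        best_val = val_loss
        torch.save(trainer.s2s.state_dict(), 'best.pt')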
def count_batches(self, training_source):
    """Number of BPTT batches: ceil((T - 1) / bptt) for a continuous tensor,
    otherwise one count per subblock across all translation bundles."""
    if isinstance(training_source, torch.Tensor):
        nbatches = (training_source.size(0) - 1) // self.bptt
        if (training_source.size(0) - 1) % self.bptt != 0:
            nbatches += 1
    else:
        nbatches = 0
        for block, _, _ in training_source:
            for _ in data.build_itr(block, self.bptt):
                nbatches += 1
    return nbatches
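# The tensor branch above is a ceiling division; a standalone equivalent to
# make the arithmetic explicit (count_continuous_batches is a hypothetical
# name, not part of the class):
import math

def count_continuous_batches(nrows, bptt):
    # ceil((nrows - 1) / bptt): the last, possibly short, window still counts.
    return math.ceil((nrows - 1) / bptt)

assert count_continuous_batches(10, 4) == 3  # windows of 4, 4, 1 rows
assert count_continuous_batches(9, 4) == 2   # windows of 4, 4 rows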
def evaluate_continuous_data(self, block):
    """Average per-token cross-entropy over a held-out continuous block."""
    self.s2s.eval()
    self.s2s.dec.init_state(batch_size=block.size(1), encoder_final=None)
    totalCE = 0.
    with torch.no_grad():
        itr = data.build_itr(block, self.bptt)
        for subblock, golds in itr:
            output, _ = self.s2s(subblock, src=None, lengths=None, start=False)
            # Weight each subblock's mean CE by its length so a short final
            # subblock doesn't skew the overall average.
            totalCE += len(subblock) * self.avgCE(output, golds).item()
    # There's no padding symbol, so this division works.
    return (totalCE / (len(block) - 1), -1.0)  # -1.0 is a dummy value for sqxent
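# Why weight by len(subblock): with bptt=4 over a 10-row block there are 9
# target rows split 4 + 4 + 1, and a plain mean of the per-subblock averages
# would over-weight the short tail. (Illustrative numbers only.)
avgs = [2.0, 2.0, 5.0]  # per-subblock mean CE
lens = [4, 4, 1]        # rows per subblock
weighted = sum(a * n for a, n in zip(avgs, lens)) / sum(lens)
naive = sum(avgs) / len(avgs)
print(weighted, naive)  # 2.333... vs. 3.0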
def evaluate_translation_data(self, bundles):
    """Per-token and per-sequence cross-entropy on held-out translation bundles."""
    self.s2s.eval()
    totalCE = 0.
    ngolds = 0
    nseqs = 0
    with torch.no_grad():
        for block, block_src, src_lens in bundles:
            nseqs += block.size(1)
            itr = data.build_itr(block, self.bptt)
            starting = True
            for subblock, golds in itr:
                output, _ = self.s2s(subblock, src=block_src, lengths=src_lens,
                                     start=starting)
                starting = False
                totalCE += self.sumCE(output, golds).item()
                ngolds += len(golds.nonzero())  # ignore padders
    # Per-token CE and per-sequence CE; guard against empty input.
    return (totalCE / ngolds if ngolds > 0 else float('inf'),
            totalCE / nseqs if nseqs > 0 else float('inf'))
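# avgCE and sumCE are not defined in this section; a minimal sketch consistent
# with the bookkeeping above, assuming padding index 0 (implied by the
# golds.nonzero() count). For continuous data, which has no pad symbol,
# ignore_index is harmless as long as no real token uses id 0.
import torch.nn as nn

avgCE = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')  # batch-mean CE
sumCE = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')   # exact totals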
def test_itr(self):
    dat = data.Data(self.dpath, 2, 'continuous', self.device)
    for batch_num, (subblock, golds) in enumerate(data.build_itr(dat.train, 3)):
        self.assertLess(batch_num, 3)
        if batch_num == 0:
            self.assertEqual(verbalize_col(subblock, 0, dat.i2w), 'a b c')
            self.assertEqual(verbalize_col(subblock, 1, dat.i2w), 'h i j')
            self.assertEqual(verbalize_golds(golds, dat.i2w),
                             'b i c j %s %s' % (dat.EOS, dat.EOS))
        if batch_num == 1:
            self.assertEqual(verbalize_col(subblock, 0, dat.i2w), '%s d e' % dat.EOS)
            self.assertEqual(verbalize_col(subblock, 1, dat.i2w), '%s k l' % dat.EOS)
            self.assertEqual(verbalize_golds(golds, dat.i2w), 'd k e l f m')
        if batch_num == 2:
            self.assertEqual(verbalize_col(subblock, 0, dat.i2w), 'f g')
            self.assertEqual(verbalize_col(subblock, 1, dat.i2w), 'm n')
            self.assertEqual(verbalize_golds(golds, dat.i2w), 'g n %s o' % dat.EOS)
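# The verbalize_* helpers are assumed by this test but not shown here; a
# plausible minimal version matching the assertions (hypothetical code):
def verbalize_col(subblock, col, i2w):
    # Read one batch column of a [seq_len, batch] block back into words.
    return ' '.join(i2w[int(tok)] for tok in subblock[:, col])

def verbalize_golds(golds, i2w):
    # Golds arrive flattened row-major, interleaving the batch columns.
    return ' '.join(i2w[int(tok)] for tok in golds)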