    def epoch_continuous_data(self, dat, epoch):
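        """Train for one epoch on a continuous token stream.

        The decoder state is initialized once for the whole stream and
        carried across BPTT sub-blocks (start=False throughout). Reports
        the running mean cross-entropy every self.interval batches and
        returns (val_loss, dummy_sqxent, epoch_time).
        """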
        self.s2s.train()
        epoch_start_time = time.time()
        log_start_time = time.time()
        sum_avgCEs = 0
        if self.nbatches == 0:
            self.nbatches = self.count_batches(dat.train)
        self.s2s.dec.init_state(batch_size=dat.train.size(1),
                                 encoder_final=None)

        itr = data.build_itr(dat.train, self.bptt)
        for batch_num_, (subblock, golds) in enumerate(itr):
            loss = self.step_on_batch(subblock, golds, src=None,
                                      lengths=None, start=False)
            sum_avgCEs += loss

            if (batch_num_ + 1) % self.interval == 0:
                self.report(epoch, batch_num_ + 1, sum_avgCEs / self.interval,
                            time.time() - log_start_time)
                sum_avgCEs = 0
                log_start_time = time.time()

        val_loss, dummy = self.evaluate_continuous_data(dat.valid)
        epoch_time = time.time() - epoch_start_time

        return val_loss, dummy, epoch_time

    def epoch_translation_data(self, dat, epoch, shuffle=False):
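        """Train for one epoch on (target_block, source_block, lengths) bundles.

        Each bundle's source is encoded at its first sub-block
        (start=True); subsequent sub-blocks reuse the carried decoder
        state. Returns (val_loss, val_sqxent, epoch_time).
        """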
        self.s2s.train()
        epoch_start_time = time.time()
        log_start_time = time.time()
        sum_avgCEs = 0
        if self.nbatches == 0:
            self.nbatches = self.count_batches(dat.train)
        batch_num = 0

        if shuffle:
            random.shuffle(dat.train)

        for block, block_src, src_lens in dat.train:
            itr = data.build_itr(block, self.bptt)
            starting = True
            for subblock, golds in itr:
                loss = self.step_on_batch(subblock, golds, src=block_src,
                                          lengths=src_lens, start=starting)
                starting = False
                sum_avgCEs += loss
                batch_num += 1

                if batch_num % self.interval == 0:
                    self.report(epoch, batch_num, sum_avgCEs / self.interval,
                                time.time() - log_start_time)
                    sum_avgCEs = 0
                    log_start_time = time.time()

        val_loss, val_sqxent = self.evaluate_translation_data(dat.valid)
        epoch_time = time.time() - epoch_start_time

        return val_loss, val_sqxent, epoch_time

    def count_batches(self, training_source):
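        """Count BPTT batches in the training source: arithmetically for a
        flat tensor, otherwise by walking each bundle's iterator."""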
        if isinstance(training_source, torch.Tensor):
            nbatches = (training_source.size(0) - 1) // self.bptt
            if (training_source.size(0) - 1) % self.bptt != 0:
                nbatches += 1
        else:
            nbatches = 0
            for block, _, _ in training_source:
                for _ in data.build_itr(block, self.bptt):
                    nbatches += 1

        return nbatches

    def evaluate_continuous_data(self, block):
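        """Return the per-token cross-entropy on a continuous block, plus
        -1.0 as a dummy second value so the signature matches
        evaluate_translation_data."""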
        self.s2s.eval()
        self.s2s.dec.init_state(batch_size=block.size(1), encoder_final=None)
        totalCE = 0.

        with torch.no_grad():
            itr = data.build_itr(block, self.bptt)
            for subblock, golds in itr:
                output, _ = self.s2s(subblock, src=None, lengths=None,
                                      start=False)
                totalCE += len(subblock) * self.avgCE(output, golds).item()

        # There's no padding symbol, so this division works.
        return (totalCE / (len(block) - 1),
                -1.0)  # dummy value for sqxent

    def evaluate_translation_data(self, bundles):
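        """Return (per-token CE, per-sequence CE) over validation bundles,
        counting only non-padding gold tokens in the per-token
        denominator."""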
        self.s2s.eval()
        totalCE = 0.
        ngolds = 0
        nseqs = 0

        with torch.no_grad():
            for block, block_src, src_lens in bundles:
                nseqs += block.size(1)
                itr = data.build_itr(block, self.bptt)
                starting = True
                for subblock, golds in itr:
                    output, _ = self.s2s(subblock, src=block_src,
                                          lengths=src_lens, start=starting)
                    starting = False
                    totalCE += self.sumCE(output, golds).item()
                    ngolds += len(golds.nonzero())  # Ignore padders!

        return (totalCE / ngolds if ngolds > 0 else float('inf'),
                totalCE / nseqs if nseqs > 0 else float('inf'))  # seq cross-ent

    def test_itr(self):
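        """build_itr with bptt=3 should slice the two-column toy corpus
        into exactly three (subblock, golds) pairs."""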
        dat = data.Data(self.dpath, 2, 'continuous', self.device)
        for batch_num, (subblock,
                        golds) in enumerate(data.build_itr(dat.train, 3)):
            self.assertLess(batch_num, 3)

            if batch_num == 0:
                self.assertEqual(verbalize_col(subblock, 0, dat.i2w), 'a b c')
                self.assertEqual(verbalize_col(subblock, 1, dat.i2w), 'h i j')
                self.assertEqual(verbalize_golds(golds, dat.i2w),
                                 'b i c j %s %s' % (dat.EOS, dat.EOS))
            if batch_num == 1:
                self.assertEqual(verbalize_col(subblock, 0, dat.i2w),
                                 '%s d e' % (dat.EOS))
                self.assertEqual(verbalize_col(subblock, 1, dat.i2w),
                                 '%s k l' % (dat.EOS))
                self.assertEqual(verbalize_golds(golds, dat.i2w),
                                 'd k e l f m')
            if batch_num == 2:
                self.assertEqual(verbalize_col(subblock, 0, dat.i2w), 'f g')
                self.assertEqual(verbalize_col(subblock, 1, dat.i2w), 'm n')
                self.assertEqual(verbalize_golds(golds, dat.i2w),
                                 'g n %s o' % (dat.EOS))
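

# A minimal usage sketch, not part of the test suite. It assumes a Trainer
# class exposing epoch_translation_data as defined above; the names
# `Trainer`, `n_epochs`, the constructor arguments, and the 'translation'
# mode string for data.Data are assumptions for illustration.
#
#   dat = data.Data(dpath, batch_size, 'translation', device)
#   trainer = Trainer(...)  # hypothetical constructor
#   for ep in range(1, n_epochs + 1):
#       val_loss, val_sqxent, secs = trainer.epoch_translation_data(
#           dat, ep, shuffle=True)
#       print('epoch %d: val_loss %.4f, val_sqxent %.4f (%.1fs)'
#             % (ep, val_loss, val_sqxent, secs))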