コード例 #1 (Code example #1)
    def valid_on_epoch(self, epoch, log_interval=None):
        self.eval()
        self.set_loader(self.v_dset, self.cfg.valid_batch, shuffle=False)
        if log_interval is None:
            log_interval = min(len(self.loader) // 10, self.cfg.log_interval)
            log_interval = max(log_interval, 1)

        count = 1
        str_time = time.time()
        logging.info(f'Valid ==> Validation...')

        total_loss = AverageMeter()
        recon_loss = AverageMeter()
        kldiv_loss = AverageMeter()
        batch_time = AverageMeter()

        with torch.no_grad():
            for batch, data in enumerate(self.loader):
                batch_size = data['inputs'].size(0)
                preds, losses = self.step(data, True)

                batch_time.update(time.time() - str_time)
                total_loss.update(losses['loss'].item())
                recon_loss.update(losses['Reconstruction_Loss'].item())
                kldiv_loss.update(losses['KLD'].item())

                if self.cfg.valid_visualize and batch % log_interval == 0:
                    valid_visualize(epoch, count, data, preds,
                                    self.cfg.valid_visualization,
                                    self.cfg.valid_visualize_num)
                    count += batch_size

                if batch % log_interval == 0 or batch == len(self.loader) - 1:
                    template = f'Valid ==> Epoch [{str(epoch+1).zfill(len(str(self.cfg.epochs)))} | {str(self.cfg.epochs)}] \
    ->  Batch [{str(batch+1).zfill(len(str(len(self.loader))))} | {str(len(self.loader))}] \
    ->  Loss: {total_loss.avg:.6f} (Reconstruction -> {recon_loss.avg:12.6f}  |  KL-Div -> {kldiv_loss.avg:12.6f}) \
    ->  Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s) \
    ->  Speed: {batch_size/batch_time.val:.1f} samples/s \
    ->  Left: {batch_time.avg*(len(self.loader)-1-batch):.3f}s'

                    logging.info(template)
                str_time = time.time()
        logging.info("Valid ==> Done.\n")
        return total_loss, recon_loss, kldiv_loss
コード例 #2 (Code example #2)
    def test(self):
        """Evaluate the model on the test set; log metrics and checkpoint.

        Appends one ``[top1, top5, loss]`` row to the checkpoint log,
        writes a summary line, and (unless ``args.test_only``) saves a
        checkpoint, flagging it as best when this epoch's top-1 is the
        best seen so far.
        """
        top1 = AverageMeter()
        top5 = AverageMeter()
        loss = AverageMeter()
        # New log row for this evaluation: [top1, top5, loss].
        self.ckpt.add_log(torch.zeros(1, 3))

        epoch = self.scheduler.last_epoch + 1
        self.ckpt.write_log('\n[INFO] Test:')
        self.model.eval()

        # Fix: run inference under no_grad — the original built autograd
        # graphs for every test batch, wasting memory for no benefit.
        with torch.no_grad():
            # Fix: dropped unused enumerate() index.
            for inputs, targets in self.test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                # Last target column is the class label; the rest are ROIs.
                rois = targets[:, :-1]
                labels = targets[:, -1]

                # compute outputs
                outputs = self.model(inputs, rois)
                loss_tmp = self.loss(outputs, labels)
                outputs = nn.functional.softmax(outputs, dim=1)
                # measure accuracy and record loss
                prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))

                top1.update(prec1.item(), inputs.size(0))
                top5.update(prec5.item(), inputs.size(0))
                loss.update(loss_tmp.item(), inputs.size(0))

        self.ckpt.log[-1, 0] = top1.avg
        self.ckpt.log[-1, 1] = top5.avg
        self.ckpt.log[-1, 2] = loss.avg
        # Column-wise (values, indices) max over all logged epochs.
        # NOTE(review): only the top-1 column is consulted below; the max of
        # the loss column would be meaningless but is never read.
        bests = self.ckpt.log.max(0)
        self.ckpt.write_log(
            '[INFO] top1: {:.4f} top5: {:.4f} loss: {:.4f} (Best: {:.4f} @epoch {})'.format(
            top1.avg,
            top5.avg,
            loss.avg,
            bests[0][0],
            (bests[1][0] + 1)*self.args.test_every
            )
        )
        if not self.args.test_only:
            self.ckpt.save(self, epoch, is_best=((bests[1][0] + 1)*self.args.test_every == epoch))
コード例 #3 (Code example #3)
    def train_on_epoch(self, epoch):
        """Run one training epoch over the training dataset.

        Args:
            epoch: Zero-based epoch index (displayed 1-based in log lines).

        Side effects:
            Sets ``self.C_max`` on the configured device and logs progress
            roughly every ``log_interval`` batches via ``logging``.
        """
        self.train()
        self.set_loader(self.t_dset, self.cfg.train_batch, shuffle=True)
        # Capacity constant moved to the target device; presumably consumed
        # by self.step() as part of the KL term — TODO confirm.
        self.C_max = torch.FloatTensor([self.cfg.C_max]).to(self.cfg.device)

        # Log ~10 times per epoch, capped at cfg.log_interval, floored at 1
        # so the modulo below never divides by zero.
        log_interval = min(len(self.loader) // 10, self.cfg.log_interval)
        log_interval = max(log_interval, 1)

        str_time = time.time()
        logging.info(f'Train ==> Training...')

        total_loss = AverageMeter()
        kldiv_loss = AverageMeter()
        recon_loss = AverageMeter()
        batch_time = AverageMeter()
        for batch, data in enumerate(self.loader):
            batch_size = data['inputs'].size(0)
            # One optimization step; returns a dict of scalar loss tensors.
            losses = self.step(data)

            batch_time.update(time.time() - str_time)
            total_loss.update(losses['loss'].item())
            recon_loss.update(losses['Reconstruction_Loss'].item())
            kldiv_loss.update(losses['KLD'].item())

            # Log on interval boundaries and always on the final batch.
            if batch % log_interval == 0 or batch == len(self.loader) - 1:
                # Continuation lines intentionally start at column 0 so no
                # extra whitespace is embedded in the logged message.
                template = f'Train ==> Epoch [{str(epoch+1).zfill(len(str(self.cfg.epochs)))} | {str(self.cfg.epochs)}] \
->  Batch [{str(batch+1).zfill(len(str(len(self.loader))))} | {str(len(self.loader))}] \
->  Loss: {total_loss.avg:.6f} (Reconstruction -> {recon_loss.avg:12.6f}  |  KL-Div -> {kldiv_loss.avg:12.6f}) \
->  Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s) \
->  Speed: {batch_size/batch_time.val:.1f} samples/s \
->  Left: {batch_time.avg*(len(self.loader)-1-batch):.3f}s'

                logging.info(template)

            # Restart the per-batch timer so batch_time measures a single
            # loader+step iteration.
            str_time = time.time()
        logging.info("Train ==> Done.\n")