Example #1
    def warm_up(self, data_loader):
        model = self.model
        model.train()
        step_losses = {}
        num_iters = self.cfg.schedule.warmup.steps
        cur_iter = 0
        while cur_iter < num_iters:
            for iter_id, batch in enumerate(data_loader):
                cur_iter += 1
                if cur_iter >= num_iters:
                    break
                lr = self.get_warmup_lr(cur_iter)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
                batch['img'] = batch['img'].to(device=torch.device('cuda'), non_blocking=True)
                output, loss, loss_stats = self.run_step(model, batch)

                # TODO: simplify code
                for k in loss_stats:
                    if k not in step_losses:
                        step_losses[k] = MovingAverage(
                            loss_stats[k].mean().item(),
                            window_size=self.cfg.log.interval)
                    else:
                        step_losses[k].push(loss_stats[k].mean().item())
                if iter_id % self.cfg.log.interval == 0:
                    log_msg = '{}|Iter({}/{})| lr:{:.2e}| '.format('warmup', cur_iter, num_iters, self.optimizer.param_groups[0]['lr'])
                    for l in step_losses:
                        log_msg += '{}:{:.4f}| '.format(l, step_losses[l].avg())
                    self.logger.log(log_msg)
                del output, loss, loss_stats
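
Example #1 relies on two helpers that are not shown: a MovingAverage tracker with push()/avg() used for windowed loss smoothing, and self.get_warmup_lr() returning the warm-up learning rate. Below is a minimal sketch of both, assuming a plain sliding-window mean and a linear ramp from base_lr * warmup_ratio up to base_lr; the ramp shape and parameter names are assumptions, not taken from the example.

from collections import deque

class MovingAverage:
    # sliding-window mean; push() adds a value, avg() returns the mean of the window
    def __init__(self, val, window_size=50):
        self.window = deque([val], maxlen=window_size)

    def push(self, val):
        self.window.append(val)

    def avg(self):
        return sum(self.window) / len(self.window)


def linear_warmup_lr(cur_iter, warmup_steps, base_lr, warmup_ratio=0.1):
    # hypothetical linear ramp: lr goes from base_lr * warmup_ratio at step 0 to base_lr at warmup_steps
    k = (1 - cur_iter / warmup_steps) * (1 - warmup_ratio)
    return base_lr * (1 - k)
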
Example #2
    def run_epoch(self, epoch, data_loader, mode):
        """
        Train or validate the model for one epoch.
        :param epoch: current epoch number
        :param data_loader: dataloader for the train or test dataset
        :param mode: 'train', 'val' or 'test'
        :return: outputs and a dict of epoch-average losses
        """
        model = self.model
        if mode == 'train':
            model.train()
            if self.rank > -1:  # Using distributed training, need to set epoch for sampler
                self.logger.log(
                    "distributed sampler set epoch at {}".format(epoch))
                data_loader.sampler.set_epoch(epoch)  # reset the sampler at the start of each epoch
        else:
            model.eval()
            torch.cuda.empty_cache()
        results = {}
        epoch_losses = {}
        step_losses = {}
        num_iters = len(data_loader)
        for iter_id, meta in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            meta['img'] = meta['img'].to(device=torch.device('cuda'),
                                         non_blocking=True)
            output, loss, loss_stats = self.run_step(
                model, meta, mode)  # call run_step (shown above) to run one training/eval step
            if mode == 'val' or mode == 'test':
                batch_dets = model.module.head.post_process(output, meta)
                results.update(batch_dets)

            for k in loss_stats:  # track epoch-average losses (epoch_losses) and sliding-window losses (step_losses)
                if k not in epoch_losses:
                    epoch_losses[k] = AverageMeter(loss_stats[k].mean().item())
                    step_losses[k] = MovingAverage(
                        loss_stats[k].mean().item(),
                        window_size=self.cfg.log.interval)
                else:
                    epoch_losses[k].update(loss_stats[k].mean().item())
                    step_losses[k].push(loss_stats[k].mean().item())

            if iter_id % self.cfg.log.interval == 0:
                log_msg = '{}|Epoch{}/{}|Iter{}({}/{})| lr:{:.2e}| '.format(
                    mode, epoch, self.cfg.schedule.total_epochs, self._iter,
                    iter_id, num_iters, self.optimizer.param_groups[0]['lr'])
                for l in step_losses:
                    log_msg += '{}:{:.4f}| '.format(l, step_losses[l].avg())
                    if mode == 'train' and self.rank < 1:  # write the tracked losses to TensorBoard
                        self.logger.scalar_summary('Train_loss/' + l, mode,
                                                   step_losses[l].avg(),
                                                   self._iter)
                self.logger.log(log_msg)  # write the log line to the text log file
            if mode == 'train':
                self._iter += 1
            del output, loss, loss_stats  # drop references after each iter so the loss tensors can be freed
        epoch_loss_dict = {k: v.avg
                           for k, v in epoch_losses.items()}  # average losses over the epoch
        return results, epoch_loss_dict
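
Both run_epoch variants also assume an AverageMeter that keeps a running mean over the epoch: the code constructs it with an initial value, calls update(), and reads avg as an attribute (v.avg), while MovingAverage.avg() is a method. A minimal sketch consistent with that usage, not the project's actual implementation:

class AverageMeter:
    # running mean over the whole epoch; avg is read as an attribute (v.avg)
    def __init__(self, val=0.0):
        self.sum = val
        self.count = 1

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count
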
Example #3
    def run_epoch(self, epoch, data_loader, mode):
        """
        Train or validate the model for one epoch.
        :param epoch: current epoch number
        :param data_loader: dataloader for the train or test dataset
        :param mode: 'train', 'val' or 'test'
        :return: outputs and a dict of epoch-average losses
        """
        model = self.model
        if mode == 'train':
            model.train()  # network in train mode
            if self.rank > -1:  # Using distributed training, need to set epoch for sampler
                self.logger.log(
                    "distributed sampler set epoch at {}".format(epoch))
                data_loader.sampler.set_epoch(epoch)
        else:
            model.eval()
            torch.cuda.empty_cache()

        ret_dict = {}
        epoch_losses = {}
        step_losses = {}
        num_iters = len(data_loader)

        for iter_id, meta in enumerate(data_loader):
            if iter_id >= num_iters:
                break

            meta['img'] = meta['img'].to(device=torch.device('cuda'),
                                         non_blocking=True)
            output, loss, loss_stats = self.run_step(model, meta, mode)

            if mode == 'val':  # TODO: eval
                dets_dict = model.module.head.post_process(output, meta)
                ret_dict[meta['img_info']['id'].cpu().numpy()[0]] = dets_dict

            for k in loss_stats:
                if k not in epoch_losses:
                    epoch_losses[k] = AverageMeter(loss_stats[k].mean().item())
                    step_losses[k] = MovingAverage(
                        loss_stats[k].mean().item(),
                        window_size=self.cfg.log.interval)
                else:
                    epoch_losses[k].update(loss_stats[k].mean().item())
                    step_losses[k].push(loss_stats[k].mean().item())

            if iter_id % self.cfg.log.interval == 0:
                log_msg = '{}|Epoch{}/{}|Iter{}({}/{})| lr:{:.3e}| '.format(
                    mode, epoch, self.cfg.schedule.total_epochs, self._iter,
                    iter_id, num_iters, self.optimizer.param_groups[0]['lr'])
                for l in step_losses:
                    log_msg += '{}:{:.4f}| '.format(l, step_losses[l].avg())
                    if mode == 'train' and self.rank < 1:
                        self.logger.scalar_summary('Train_loss/' + l, mode,
                                                   step_losses[l].avg(),
                                                   self._iter)
                self.logger.log(log_msg)

            # ----- periodically save a checkpoint within the epoch
            if iter_id % self.cfg.schedule.save_interval == 0:
                save_model(
                    self.rank, self.model,
                    os.path.join(
                        self.cfg.save_dir,
                        'epoch{:d}_iter{:d}.pth'.format(epoch, iter_id)),
                    epoch, self._iter, self.optimizer)

            if mode == 'train':
                self._iter += 1

            del output, loss, loss_stats

        epoch_loss_dict = {k: v.avg for k, v in epoch_losses.items()}

        return ret_dict, epoch_loss_dict
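
Example #3 additionally calls save_model(rank, model, path, epoch, iter, optimizer), which is not defined here. The following is a hedged sketch of what such a checkpoint saver might look like; only the call signature is taken from the example, the checkpoint layout and the rank check are assumptions.

import torch

def save_model(rank, model, path, epoch, iter_num, optimizer=None):
    # sketch: in distributed training, let only the main process write checkpoints
    if rank > 0:
        return
    # unwrap DataParallel / DistributedDataParallel before saving
    state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    data = {'epoch': epoch, 'iter': iter_num, 'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)
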