コード例 #1
0
ファイル: base_trainer.py プロジェクト: laclouis5/CenterNet
 def __init__(self, opt, model, optimizer=None):
   """Keep options/optimizer and build the loss-wrapped model.

   Loss-stat names and the loss module are supplied by the subclass
   via ``_get_losses``.
   """
   self.optimizer = optimizer
   self.opt = opt
   self.loss_stats, self.loss = self._get_losses(opt)
   self.model_with_loss = ModelWithLoss(model, self.loss)
コード例 #2
0
class BaseTrainer(object):
    """Drives train/test epochs over a model wrapped together with its loss.

    Subclasses must implement ``_get_losses``, ``_get_evals`` and
    ``_get_result``; ``run_epoch`` handles device placement, sub-batch
    gradient accumulation and progress reporting.
    """

    def __init__(self, opt, model, optimizer=None):
        """Store options/optimizer and build the loss-wrapped model."""
        self.opt = opt
        self.optimizer = optimizer
        # Loss-stat names and the loss module come from the subclass.
        self.loss_stats, self.loss = self._get_losses(opt)
        self.eval_stats = self._get_evals(opt)
        self.model = model
        # NOTE: 'ModleWithLoss' (sic) is the wrapper's actual name in this project.
        self.model_with_loss = ModleWithLoss(self.model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        """Move the (possibly data-parallel) model and optimizer state to *device*."""
        if len(gpus) > 1:
            # Project-local DataParallel variant that supports per-GPU chunk sizes.
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # Optimizer state tensors (e.g. momentum buffers) must follow the model.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        """Run one epoch in 'train' or test mode.

        Returns (ret, results): ret maps loss-stat names (plus 'time' in
        minutes, and eval stats in test mode) to epoch averages; results is
        always the empty dict here.
        """

        model_with_loss = self.model_with_loss

        if phase == 'train':
            model_with_loss.train()
        else:
            # Unwrap DataParallel for single-stream evaluation.
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        avg_eval_stats = {l: AverageMeter() for l in self.eval_stats}

        # Train counts optimizer steps (batches); test iterates per sample.
        if phase == 'train':
            num_iters = len(data_loader.dataset) // opt.batch_size
        else:
            num_iters = len(data_loader.dataset)

        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)

        # Manual iterator so it can be restarted mid-epoch (see StopIteration below).
        dataiterator = iter(data_loader)
        end = time.time()

        # NOTE(review): range(num_iters + 1) runs one extra iteration beyond
        # num_iters — looks intentional here but worth confirming.
        for iter_id in range(num_iters + 1):
            # if iter_id%100 == 0 and  phase == 'train':
            # save_model(os.path.join(opt.save_dir, 'model_epoch_{}_iter_{}.pth'.format(epoch-1,iter_id)),epoch-1, self.model, self.optimizer)

            self.optimizer.zero_grad()
            batch_loss_stats = {l: 0 for l in self.loss_stats}
            batch_eval_stats = {l: 0 for l in self.eval_stats}
            # Gradient accumulation: 'subdivision' sub-batches per optimizer step.
            if phase == 'train':
                subdivision = opt.subdivision
            else:
                subdivision = 1

            for sub_iter in range(subdivision):
                try:
                    batch = next(dataiterator)
                except StopIteration:
                    # Loader exhausted mid-epoch: restart it and keep going.
                    dataiterator = iter(data_loader)
                    batch = next(dataiterator)
                data_time.update(time.time() - end)

                for k in batch:
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
                output, loss, loss_stats = model_with_loss(batch)
                batch_time.update(time.time() - end)
                end = time.time()

                # Reported stats are averaged over sub-batches...
                for l in batch_loss_stats:
                    batch_loss_stats[l] += loss_stats[l].item() / subdivision

                if phase == 'train':
                    # ...but gradients are summed (loss is not divided by
                    # subdivision before backward).
                    loss.backward()
                else:
                    # Evaluation path: compute per-sample eval stats instead.
                    test_stats = self._get_result(batch, output)
                    for l in batch_eval_stats:
                        batch_eval_stats[l] += test_stats[l] / subdivision

            if phase == 'train':
                self.optimizer.step()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            # Test iterates sample-by-sample, so meters weight by 1 there.
            if phase == 'test':
                batch_size = 1
            else:
                batch_size = opt.batch_size

            for l in avg_loss_stats:
                avg_loss_stats[l].update(batch_loss_stats[l], batch_size)
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)

            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)

            if phase == 'test':
                for l in avg_eval_stats:
                    avg_eval_stats[l].update(batch_eval_stats[l], batch_size)
                    Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                        l, avg_eval_stats[l].avg)

            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        # Wall-clock epoch time in minutes.
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        if phase == 'test':
            for l in avg_eval_stats:
                ret[l] = avg_eval_stats[l].avg
        return ret, results

    def _get_losses(self, opt):
        """Return (loss_stat_names, loss_module). Subclass responsibility."""
        raise NotImplementedError

    def _get_result(self, batch, output):
        """Return per-sample eval stats for test mode. Subclass responsibility."""
        raise NotImplementedError

    def _get_evals(self, opt):
        """Return the eval-stat names tracked in test mode. Subclass responsibility."""
        raise NotImplementedError

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)

    def test(self, epoch, data_loader):
        return self.run_epoch('test', epoch, data_loader)
コード例 #3
0
class CtdetTrainer(object):
    """Trainer for CenterNet detection (ctdet): epoch loop, debug rendering
    and result post-processing."""

    def __init__(self, opt, model, optimizer=None):
        """Store options/optimizer, build the CtdetLoss and wrap the model."""
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        # NOTE: 'ModleWithLoss' (sic) is the wrapper's actual name in this project.
        self.model_with_loss = ModleWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        """Move the (possibly data-parallel) model and optimizer state to *device*."""
        if len(gpus) > 1:
            # Project-local DataParallel variant that supports per-GPU chunk sizes.
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # Optimizer state tensors must follow the model to the target device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        """Run one 'train' or 'val' epoch; returns (avg_loss_dict, results)."""
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            # Unwrap DataParallel and free cached GPU memory before eval.
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        # opt.num_iters < 0 means "full loader"; otherwise truncate the epoch.
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            # 'meta' holds non-tensor bookkeeping and stays on CPU.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            # Mean over the (possibly multi-GPU) per-replica losses.
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                # Weight the running average by the actual batch size.
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id)

            if opt.test:
                self.save_result(output, batch, results)
            # Drop references promptly to keep GPU memory bounded.
            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        # Wall-clock epoch time in minutes.
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        """Render predicted vs. ground-truth heatmaps/boxes for inspection."""
        opt = self.opt
        reg = output['reg'] if opt.reg_offset else None
        dets = ctdet_decode(output['hm'],
                            output['wh'],
                            reg=reg,
                            cat_spec_wh=opt.cat_spec_wh,
                            K=opt.K)
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
        # Detections are in feature-map coordinates; scale back to input size.
        dets[:, :, :4] *= opt.down_ratio
        dets_gt = batch['meta']['gt_det'].numpy().reshape(1, -1, dets.shape[2])
        dets_gt[:, :, :4] *= opt.down_ratio
        # Only the first image of the batch is visualized.
        for i in range(1):
            debugger = Debugger(dataset=opt.dataset,
                                ipynb=(opt.debug == 3),
                                theme=opt.debugger_theme)
            # Undo normalization to recover a displayable uint8 image.
            img = batch['input'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip(((img * opt.std + opt.mean) * 255.), 0,
                          255).astype(np.uint8)
            pred = debugger.gen_colormap(
                output['hm'][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
            debugger.add_blend_img(img, pred, 'pred_hm')
            debugger.add_blend_img(img, gt, 'gt_hm')
            debugger.add_img(img, img_id='out_pred')
            # dets[i, k] layout: [x1, y1, x2, y2, score, ..., class].
            for k in range(len(dets[i])):
                if dets[i, k, 4] > opt.center_thresh:
                    debugger.add_coco_bbox(dets[i, k, :4],
                                           dets[i, k, -1],
                                           dets[i, k, 4],
                                           img_id='out_pred')

            debugger.add_img(img, img_id='out_gt')
            for k in range(len(dets_gt[i])):
                if dets_gt[i, k, 4] > opt.center_thresh:
                    debugger.add_coco_bbox(dets_gt[i, k, :4],
                                           dets_gt[i, k, -1],
                                           dets_gt[i, k, 4],
                                           img_id='out_gt')

            # debug == 4 saves to disk; other modes display interactively.
            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir,
                                       prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def save_result(self, output, batch, results):
        """Decode and post-process detections for one batch into *results*,
        keyed by image id."""
        reg = output['reg'] if self.opt.reg_offset else None
        dets = ctdet_decode(output['hm'],
                            output['wh'],
                            reg=reg,
                            cat_spec_wh=self.opt.cat_spec_wh,
                            K=self.opt.K)
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
        # Map detections back to original-image coordinates.
        dets_out = ctdet_post_process(dets.copy(),
                                      batch['meta']['c'].cpu().numpy(),
                                      batch['meta']['s'].cpu().numpy(),
                                      output['hm'].shape[2],
                                      output['hm'].shape[3],
                                      output['hm'].shape[1])
        results[batch['meta']['img_id'].cpu().numpy()[0]] = dets_out[0]

    def _get_losses(self, opt):
        """Return the tracked loss-stat names and the CtdetLoss module."""
        loss_states = ['loss', 'hm_loss', 'wh_loss', 'off_loss']
        loss = CtdetLoss(opt)
        return loss_states, loss

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
コード例 #4
0
class Trainer:
    """CenterNet-style trainer with separate model and loss modules, tqdm
    progress and CSV-style logging to a file handle."""

    def __init__(self, cfg, model, optimizer):
        """Store config/model/optimizer and build the CenterLoss."""
        self.cfg = cfg
        self.optimizer = optimizer
        self.model = model
        # NOTE(review): these lists are never appended to anywhere in this
        # class — they only document which stats exist; confirm before relying
        # on them.
        if cfg.USE_OFFSET:
            self.loss_stats = {
                'total_loss': [],
                'hm_loss': [],
                'wh_loss': [],
                'offset_loss': []
            }
        else:
            self.loss_stats = {'total_loss': [], 'hm_loss': [], 'wh_loss': []}
        self.loss = CenterLoss(cfg)
        # for teacherloss
        # self.loss = TeacherLoss(cfg)

    def set_device(self, gpus, device):
        """Move model, loss module and optimizer state to *device* (with
        DataParallel when multiple GPUs are given)."""
        if len(gpus) > 1:
            self.model = DataParallel(self.model, device_ids=gpus).to(device)
            # The loss is a module here, so it is parallelized as well.
            self.loss = DataParallel(self.loss, device_ids=gpus).to(device)
        else:
            self.model = self.model.to(device)
            self.loss = self.loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, epoch, data_loader, log_file):
        """Run one training epoch and append averaged losses to *log_file*.

        Returns None; progress is shown via tqdm and losses are written as a
        comma-separated line.
        """
        epoch_total_loss = 0.0
        epoch_hm_loss = 0.0
        epoch_wh_loss = 0.0
        epoch_offset_loss = 0.0

        self.model.train()
        self.loss.train()

        data_process = tqdm(data_loader)
        for batch_item in data_process:
            batch_img, batch_label = batch_item
            batch_img = batch_img.to(device=self.cfg.DEVICE)
            for k in batch_label:
                batch_label[k] = batch_label[k].to(device=self.cfg.DEVICE)
            # The model also returns a feature map (unused here).
            batch_output, feature = self.model(batch_img)
            # for teacherloss
            # batch_output = self.model(batch_img)

            loss, loss_stats = self.loss(batch_output, batch_label)
            loss = loss.mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            batch_loss = [
                loss_stats['total_loss'], loss_stats['hm_loss'],
                loss_stats['wh_loss']
            ]

            epoch_total_loss += batch_loss[0]
            epoch_hm_loss += batch_loss[1]
            epoch_wh_loss += batch_loss[2]

            loss_str = "total_loss: {},hm_loss: {:.6f},wh_loss: {:.6f}".format(
                batch_loss[0], batch_loss[1], batch_loss[2])
            # Offset loss is only present when the loss module produces it.
            if 'offset_loss' in loss_stats:
                batch_loss.append(loss_stats['offset_loss'])
                epoch_offset_loss += batch_loss[3]
                loss_str += ",offset_loss: {:.6f}".format(batch_loss[3])

            data_process.set_description_str("epoch:{}".format(epoch))
            data_process.set_postfix_str(loss_str)

        # One CSV line per epoch: epoch, mean total/hm/wh (and offset) losses.
        log_str = "{},{:.6f},{:.6f},{:.6f}".format(
            epoch, epoch_total_loss / len(data_loader),
            epoch_hm_loss / len(data_loader), epoch_wh_loss / len(data_loader))
        if self.cfg.USE_OFFSET:
            log_str += ",{:.6f}\n".format(epoch_offset_loss / len(data_loader))
        else:
            log_str += "\n"
        log_file.write(log_str)
        log_file.flush()

    def train(self, epoch, data_loader, train_log):
        return self.run_epoch(epoch, data_loader, train_log)

    def val(self, epoch, model_path, val_loader, val_log, cfg):
        """Evaluate a saved checkpoint: run the detector over the validation
        set and log mean precision/recall."""
        detecter = Detector(model_path, cfg)
        mean_precision = 0
        mean_recall = 0
        sample_count = val_loader.num_samples
        # NOTE(review): divides by sample_count below — raises
        # ZeroDivisionError if the loader is empty; samples without ground
        # truth still count in the denominator. Confirm this is intended.
        for i in range(sample_count):
            image_path, gt_bboxes = val_loader.getitem(i)
            results = detecter.run(image_path)
            pre_bboxes = results[1]
            if len(gt_bboxes) > 0:
                precision, recall = compute_metrics(pre_bboxes, gt_bboxes)
                mean_precision += precision
                mean_recall += recall

        log_str = "{},{:.6f},{:.6f}\n".format(epoch,
                                              mean_precision / sample_count,
                                              mean_recall / sample_count)
        val_log.write(log_str)
        val_log.flush()
class BaseTrainer(object):
    """Trainer variant supporting apex mixed precision, gradient accumulation
    and manual SWA (stochastic weight averaging) updates."""

    def __init__(self, opt, model, optimizer=None):
        """Store options/optimizer and build the loss-wrapped model."""
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        """Move the (possibly data-parallel) model and optimizer state to *device*."""
        if len(gpus) > 1:
            # Project-local DataParallel variant that supports per-GPU chunk sizes.
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # Optimizer state tensors must follow the model to the target device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        """Run one 'train' or 'val' epoch; returns (avg_loss_dict, results).

        Training supports gradient accumulation over ``opt.num_grad_accum``
        iterations, apex AMP loss scaling, and manual SWA updates.
        """
        opt = self.opt
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        # apex is an optional dependency; import lazily only when enabled.
        if opt.mixed_precision:
            from apex import amp

        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        # Fall back to no accumulation when the option is unset/zero.
        num_accum = opt.num_grad_accum or 1
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            # Zero gradients at the START of each accumulation window:
            # [0] 1 [2] 3 [4] 5 [6] 7 [8]
            # [0] 1 2 [3] 4 5 [6] 7 8 [9]
            if not iter_id % num_accum:
                self.optimizer.zero_grad()

            # 'meta' holds non-tensor bookkeeping and stays on CPU.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            # Scale by the accumulation factor so summed gradients average out.
            loss = loss.mean() / num_accum
            if phase == 'train':
                if opt.mixed_precision:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                # Step at the END of each accumulation window:
                # 0 [1] 2 [3] 4 [5] 6 [7] 8
                # 0 1 [2] 3 4 [5] 6 7 [8] 9
                if not (iter_id + 1) % num_accum:
                    self.optimizer.step()
                if opt.use_swa and opt.swa_manual:
                    # epochs count starts from 1
                    global_iter = num_iters * max(0, epoch - 1) + iter_id
                    # Switch to the SWA learning rate exactly at swa_start.
                    if opt.swa_lr is not None and global_iter == opt.swa_start:
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = opt.swa_lr
                    # After swa_start, fold weights into the SWA average
                    # every swa_freq iterations.
                    if global_iter > opt.swa_start:
                        if not (iter_id + 1) % opt.swa_freq:
                            self.optimizer.update_swa()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id)

            if opt.test:
                self.save_result(output, batch, results)
            # Drop references promptly to keep GPU memory bounded.
            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        # Wall-clock epoch time in minutes.
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def bn_update(self, data_loader):
        """Refresh batch-norm statistics (train-mode forward passes, no
        gradients) — used after SWA weight swapping."""
        opt = self.opt
        model_with_loss = self.model_with_loss
        model_with_loss.train()

        for iter_id, batch in enumerate(data_loader):
            if iter_id >= opt.swa_bn_upd_iters:
                break
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            with torch.no_grad():
                output, loss, loss_stats = model_with_loss(batch)
            del output, loss, loss_stats
        model_with_loss.eval()

    def debug(self, batch, output, iter_id):
        """Visualize a batch. Subclass responsibility."""
        raise NotImplementedError

    def save_result(self, output, batch, results):
        """Store decoded results for one batch. Subclass responsibility."""
        raise NotImplementedError

    def _get_losses(self, opt):
        """Return (loss_stat_names, loss_module). Subclass responsibility."""
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
コード例 #6
0
class BaseTrainer(object):
  """3D-detection trainer variant with epoch-ramped loss weights and hooks
  for semi-supervised (unlabeled) data loaders."""

  def __init__(
    self, opt, model, optimizer=None):
    """Store options/optimizer, build the wrapped model and ramp-up schedules."""
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModelWithLoss(model, self.loss)
    # Exponential ramp-up schedules (100-epoch horizon) for loss weighting.
    self.rampup = exp_rampup(100)
    self.rampup_prob = exp_rampup(100)
    self.rampup_coor = exp_rampup(100)
  def set_device(self, gpus, chunk_sizes, device, distribute = False):
    """Move model and optimizer state to *device*; skip DataParallel when
    *distribute* is set (distributed training wraps the model elsewhere)."""
    # if len(gpus) > 1:
    #   self.model_with_loss = DataParallel(
    #     self.model_with_loss, device_ids=gpus,
    #     chunk_sizes=chunk_sizes).to(device)
    # else:
    #   self.model_with_loss = self.model_with_loss.to(device)
    #
    # for state in self.optimizer.state.values():
    #   for k, v in state.items():
    #     if isinstance(v, torch.Tensor):
    #       state[k] = v.to(device=device, non_blocking=True)
    if len(gpus) > 1 and not distribute:
      self.model_with_loss = DataParallel(
          self.model_with_loss, device_ids=gpus,
          chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader,unlabel_loader1=None,unlabel_loader2=None,unlabel_set=None,iter_num=None,data_loder2=None):
    """Run one epoch; the unlabeled-loader arguments are accepted but unused
    in this implementation. Returns (avg_loss_dict, results)."""
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      # if len(self.opt.gpus) > 1:
      #   model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format("3D detection", opt.exp_id), max=num_iters)
    end = time.time()

    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      # 'meta' holds non-tensor bookkeeping and stays on CPU.
      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)
      # Ramped coordinate-loss weight, zeroed below a threshold so the term
      # only kicks in once the ramp is high enough.
      coor_weight=self.rampup_coor(epoch)
      if coor_weight< self.opt.coor_thresh:
          coor_weight=0
      # Here the model returns a dict of per-term losses; recombine them
      # manually with the configured weights (prob/coor terms are ramped).
      output, loss, loss_stats = model_with_loss(batch,phase=phase)
      loss = loss['hm_loss'].mean() * opt.hm_weight + \
             loss['wh_loss'].mean() * opt.wh_weight + \
             loss['off_loss'].mean() * opt.off_weight + \
             loss['hp_loss'].mean() * opt.hp_weight + \
             loss['hp_offset_loss'].mean() * opt.off_weight + \
             loss['hm_hp_loss'].mean() * opt.hm_hp_weight + \
             loss['dim_loss'].mean() *opt.dim_weight + \
             loss['rot_loss'].mean() * opt.rot_weight + \
             loss['prob_loss'].mean()*self.rampup_prob(epoch)+\
             loss['coor_loss'].mean()*coor_weight
    # loss = loss['hm_loss'].mean() * opt.hm_weight + \
      #        loss['hp_loss'].mean() * opt.hp_weight + \
      #        loss['dim_loss'].mean() * opt.dim_weight + \
      #        loss['rot_loss'].mean() * opt.rot_weight + \
      #        loss['prob_loss'].mean() * self.rampup_prob(epoch) + \
      #        loss['coor_loss'].mean() * coor_weight
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()

      if opt.debug > 0:
        self.debug(batch, output, iter_id)

      if opt.test:
        self.save_result(output, batch, results)
      # Drop references promptly to keep GPU memory bounded.
      del output, loss, loss_stats

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    # Wall-clock epoch time in minutes.
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results

  def debug(self, batch, output, iter_id):
    """Visualize a batch. Subclass responsibility."""
    raise NotImplementedError

  def save_result(self, output, batch, results):
    """Store decoded results for one batch. Subclass responsibility."""
    raise NotImplementedError

  def _get_losses(self, opt):
    """Return (loss_stat_names, loss_module). Subclass responsibility."""
    raise NotImplementedError

  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader,unlabel_loader1=None,unlabel_loader2=None,unlabel_set=None,iter_num=None,uncert=None):
    # NOTE(review): 'uncert' is forwarded into run_epoch's 'data_loder2'
    # positional slot — looks like a mismatched parameter name; harmless only
    # because run_epoch ignores it. Confirm.
    return self.run_epoch('train', epoch, data_loader,unlabel_loader1,unlabel_loader2,unlabel_set,iter_num,uncert)
コード例 #7
0
 def __init__(self, opt, model, optimizer=None):
     """Build the loss-wrapped model and register the loss's own learnable
     parameters with the optimizer."""
     self.optimizer = optimizer
     self.opt = opt
     self.loss_stats, self.loss = self._get_losses(opt)
     self.model_with_loss = ModleWithLoss(model, self.loss)
     # The loss module has trainable parameters of its own; optimize them too.
     self.optimizer.add_param_group({'params': self.loss.parameters()})
コード例 #8
0
ファイル: base_trainer.py プロジェクト: naomori/CenterNet.org
 def __init__(self, opt, model, optimizer=None):
     """Build the loss-wrapped model and open a TensorBoard writer at ./tbX."""
     self.optimizer = optimizer
     self.opt = opt
     self.loss_stats, self.loss = self._get_losses(opt)
     self.model_with_loss = ModleWithLoss(model, self.loss)
     self.summary_writer = SummaryWriter("./tbX")
コード例 #9
0
ファイル: base_trainer.py プロジェクト: pc2005/CenterNet
class BaseTrainer(object):
  """Base trainer: wraps model+loss and runs unified train/val epochs.

  Subclasses provide _get_losses, debug and save_result.
  """

  def __init__(
    self, opt, model, optimizer=None):
    """Store options/optimizer and build the loss-wrapped model."""
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModelWithLoss(model, self.loss)

  def set_device(self, gpus, chunk_sizes, device):
    """Move the (possibly data-parallel) model and optimizer state to *device*."""
    if len(gpus) > 1:
      # Project-local DataParallel variant that supports per-GPU chunk sizes.
      self.model_with_loss = DataParallel(
        self.model_with_loss, device_ids=gpus, 
        chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    
    # Optimizer state tensors must follow the model to the target device.
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader):
    """ Unified run epoch for train & evaluation

    Arguments:
        phase {str} -- 'train' or 'val'
        epoch {int} -- epoch index
        data_loader {object} -- data loader object

    Returns:
        ret [dict] -- 'time'
        results [dict] -- test results?
    """
    model_with_loss = self.model_with_loss
    
    if phase == 'train':
      model_with_loss.train() # train model
    else:
      # Unwrap DataParallel and free cached GPU memory before eval.
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    # opt.num_iters < 0 means "full loader"; otherwise truncate the epoch.
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    
    for iter_id, batch in enumerate(data_loader):
      
      if iter_id >= num_iters:      # ! stop early for debug purpose
        break
      
      data_time.update(time.time() - end)

      # 'meta' holds non-tensor bookkeeping and stays on CPU.
      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)    
      
      # feed-forward
      output, loss, loss_stats = model_with_loss(batch)

      # average batch loss
      loss = loss.mean()
      
      # back-propagation for training
      if phase == 'train':
        self.optimizer.zero_grad()  # clear old gradients
        loss.backward()             # back-propagation
        self.optimizer.step()       # optimizer step based on graident

      # update timer
      batch_time.update(time.time() - end)
      end = time.time()

      # visualization
      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
      
      for l in avg_loss_stats:
        # Weight the running average by the actual batch size.
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) 
      else:
        bar.next()
      
      # show debug
      if opt.debug > 0 and epoch > 3:
        self.debug(batch, output, iter_id)
      
      if opt.test:
        self.save_result(output, batch, results)

      # Drop references promptly so GPU memory stays bounded per iteration.
      del output, loss, loss_stats  #? delete variables?
    
    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    # Wall-clock epoch time in minutes.
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    
    return ret, results
  
  def debug(self, batch, output, iter_id):
    """Visualize a batch. Subclass responsibility."""
    raise NotImplementedError

  def save_result(self, output, batch, results):
    """Store decoded results for one batch. Subclass responsibility."""
    raise NotImplementedError

  def _get_losses(self, opt):
    """Return (loss_stat_names, loss_module). Subclass responsibility."""
    raise NotImplementedError
  
  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader):
    return self.run_epoch('train', epoch, data_loader)
コード例 #10
0
class BaseTrainerIter(object):
    """Iteration-driven trainer that can adapt the model online.

    In ``opt.adaptive`` mode the model is scored every ``delta`` iterations
    and updated (up to ``opt.umax`` times) until the score beats
    ``opt.acc_thresh``; ``delta`` is then doubled or halved within
    [``opt.delta_min``, ``opt.delta_max``].  Subclasses must implement
    ``run_model``/``update_model`` plus the usual ``debug``/``save_result``/
    ``_get_losses`` hooks.
    """

    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        # BUGFIX: the ``model`` argument used to be discarded even though
        # ``set_device`` and ``run_epoch`` both read ``self.model``.
        self.model = model
        # NOTE(review): ``run_epoch`` also reads ``self.loss_stats``; a
        # subclass is expected to provide it (the sibling trainers obtain it
        # from ``self._get_losses(opt)`` in __init__) -- confirm.
        self.accum_coco = AccumCOCO()
        self.accum_coco_det = AccumCOCODetResult()

    def set_device(self, gpus, chunk_sizes, device):
        """Move the model and the optimizer state onto *device*."""
        if len(gpus) > 1:
            self.model = DataParallel(self.model,
                                      device_ids=gpus,
                                      chunk_sizes=chunk_sizes).to(device)
        else:
            self.model = self.model.to(device)

        # Optimizer state tensors (momentum buffers etc.) must follow the
        # parameters onto the target device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader, no_aug_loader=None):
        """Stream batches from *data_loader*, optionally adapting the model.

        Returns ``(ret, results, save_dict)``: average loss values and
        timing, accumulated per-batch results, and a dict of update/accuracy
        traces plus full-resolution predictions.
        """
        model_with_loss = self.model
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()

        if opt.save_video:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            vid_pth = os.path.join(opt.save_dir, opt.exp_id + '_pred')
            gt_pth = os.path.join(opt.save_dir, opt.exp_id + '_gt')
            out_pred = cv2.VideoWriter('{}.mp4'.format(vid_pth), fourcc,
                                       opt.save_framerate,
                                       (opt.input_w, opt.input_h))
            out_gt = cv2.VideoWriter('{}.mp4'.format(gt_pth), fourcc,
                                     opt.save_framerate,
                                     (opt.input_w, opt.input_h))

        delta_max = opt.delta_max
        delta_min = opt.delta_min
        delta = delta_min
        umax = opt.umax
        a_thresh = opt.acc_thresh
        metric = get_metric(opt)
        iter_id = 0

        data_iter = iter(data_loader)
        update_lst = []
        acc_lst = []
        while True:
            load_time, total_model_time, model_time, update_time, tot_time, display_time = 0, 0, 0, 0, 0, 0
            start_time = time.time()
            # data loading
            try:
                batch = next(data_iter)
            except StopIteration:
                break

            # BUGFIX: the old check ``iter_id > opt.num_iters`` was always
            # true for the "run the whole loader" setting (opt.num_iters < 0,
            # cf. the num_iters computation above) and aborted the loop on
            # the first iteration.
            if opt.num_iters >= 0 and iter_id > opt.num_iters:
                break

            loaded_time = time.time()
            load_time += (loaded_time - start_time)

            if opt.adaptive:
                if iter_id % delta == 0:
                    u = 0
                    update = True
                    while (update):
                        output, tmp_model_time = self.run_model(batch)
                        total_model_time += tmp_model_time
                        # save the stuff every iteration
                        acc = metric.get_score(batch, output, u)
                        print(acc)
                        if u < umax and acc < a_thresh:
                            # NOTE(review): ``update_model`` is declared as
                            # taking a loss but is called with the batch --
                            # confirm against the subclass implementation.
                            update_time = self.update_model(batch)
                        else:
                            update = False
                        u += 1
                    # Adapt the probing interval: probe less often while
                    # accurate, more often when accuracy drops.
                    if acc > a_thresh:
                        delta = min(delta_max, 2 * delta)
                    else:
                        delta = max(delta_min, delta / 2)
                    output, _ = self.run_model(
                        batch)  # run model with new weights
                    model_time = total_model_time / u
                    update_lst += [(iter_id, u)]
                    acc_lst += [(iter_id, acc)]
                    self.accum_coco.store_metric_coco(iter_id, batch, output,
                                                      opt)
                else:
                    update_lst += [(iter_id, 0)]
                    output, model_time = self.run_model(batch)
                    if opt.acc_collect and (iter_id % opt.acc_interval == 0):
                        acc = metric.get_score(batch, output, 0)
                        print(acc)
                        acc_lst += [(iter_id, acc)]
                        self.accum_coco.store_metric_coco(
                            iter_id, batch, output, opt)
            else:
                # Baseline: plain inference, optionally collecting accuracy.
                output, model_time = self.run_model(batch)
                if opt.acc_collect:
                    acc = metric.get_score(batch, output, 0, is_baseline=True)
                    print(acc)
                    acc_lst += [(iter_id, acc)]
                    self.accum_coco.store_metric_coco(iter_id,
                                                      batch,
                                                      output,
                                                      opt,
                                                      is_baseline=True)

            display_start = time.time()

            if opt.tracking:
                # NOTE(review): ``out_pred`` only exists when opt.save_video
                # is set -- tracking without save_video raises NameError.
                trackers, viz_pred = self.tracking(
                    batch, output,
                    iter_id)  # TODO: factor this into the other class
                out_pred.write(viz_pred)
            elif opt.save_video:
                pred, gt = self.debug(batch, output, iter_id)
                out_pred.write(pred)
                out_gt.write(gt)
            if opt.debug > 1:
                self.debug(batch, output, iter_id)

            display_end = time.time()
            display_time = (display_end - display_start)
            end_time = time.time()
            tot_time = (end_time - start_time)

            # add a bunch of stuff to the bar to print
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)  # add to the progress bar
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.display_timing:
                time_str = 'total {:.3f}s| load {:.3f}s | model_time {:.3f}s | update_time {:.3f}s | display {:.3f}s'.format(
                    tot_time, load_time, model_time, update_time, display_time)
                print(time_str)
            self.save_result(output, batch, results)
            del output
            iter_id += 1

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        save_dict = {}
        if opt.adaptive:
            plt.scatter(*zip(*update_lst))
            plt.xlabel('iteration')
            plt.ylabel('number of updates')
            plt.savefig(opt.save_dir + '/update_frequency.png')
            save_dict['updates'] = update_lst
            plt.clf()
        if opt.acc_collect:
            plt.scatter(*zip(*acc_lst))
            plt.xlabel('iteration')
            plt.ylabel('mAP')
            plt.savefig(opt.save_dir + '/acc_figure.png')
            save_dict['acc'] = acc_lst
        if opt.adaptive and opt.acc_collect:
            # Highlight the iterations where at least one update happened.
            x, y = zip(*filter(lambda x: x[1] > 0, update_lst))
            plt.scatter(x, y, c='r', marker='o')
            plt.xlabel('iteration')

        # save dict
        # gt_dict = self.accum_coco.get_gt()
        # dt_dict = self.accum_coco.get_dt()
        dt_dict = self.accum_coco_det.get_dt()

        # save_dict['gt_dict'] = gt_dict
        # save_dict['dt_dict'] = dt_dict
        save_dict['full_res_pred'] = dt_dict
        return ret, results, save_dict

    def debug(self, batch, output, iter_id):
        """Visualize a batch/prediction pair; task-specific."""
        raise NotImplementedError

    def save_result(self, output, batch, results):
        """Store post-processed predictions for *batch* in *results*."""
        raise NotImplementedError

    def _get_losses(self, opt):
        """Return ``(loss_stats, loss)`` for the task."""
        raise NotImplementedError

    def train_model(self, batch):
        raise NotImplementedError

    def run_model(self, batch):
        """Return ``(output, model_time)`` for one forward pass."""
        raise NotImplementedError

    def update_model(self, loss):
        """Perform one online update; returns the time spent."""
        raise NotImplementedError

    def train(self, epoch, data_loader, aug_loader=None):
        return self.run_epoch('train', epoch, data_loader, aug_loader)
コード例 #11
0
class BaseTrainer(object):
    """Trainer with per-step LR scheduling and optional scalar logging."""

    def __init__(self,
                 opt,
                 model,
                 optimizer=None,
                 logger=None,
                 lr_scheduler=None):
        self.opt = opt
        self.optimizer = optimizer
        # Per-loss-term names and the combined loss module (task-specific).
        self.loss_stats, self.loss = self._get_losses(opt)
        # ``ModleWithLoss`` [sic] is the project's wrapper that runs the
        # forward pass and the loss together.
        self.model_with_loss = ModleWithLoss(model, self.loss)

        # FIXME: vis interval
        self.original_debug = opt.debug
        self.vis_interval = opt.show_intervals
        # BUGFIX: always assign the attribute -- run_epoch tests
        # ``self.logger is not None``, which raised AttributeError when the
        # trainer was constructed without a logger.
        self.logger = logger
        self.lr_scheduler = lr_scheduler

    def set_device(self, gpus, chunk_sizes, device):
        """Move the model (and optimizer state) onto *device*."""
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # Optimizer state tensors must follow the parameters onto the device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        """Run one train/val epoch; returns (average-loss dict, results)."""
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            # if len(self.opt.gpus) > 1:
            #   model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            # torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):

            # Global step across epochs drives the scheduler and logging.
            cur_step = (epoch - 1) * num_iters + iter_id
            # BUGFIX: the scheduler defaults to None; only step it when set.
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(cur_step)

            # Current learning rate (first param group) for logging.
            cur_lr = None
            for param_group in self.optimizer.param_groups:
                cur_lr = param_group['lr']
                break

            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                # FIXME:
                # key changed from 'meta' to 'meta_c' ...
                # if k != 'meta':
                if 'meta' not in k:
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            if self.logger is not None:
                # NOTE(review): the meters are updated further below, so
                # these summaries lag the current batch by one iteration.
                for l in avg_loss_stats:
                    if avg_loss_stats[l].val > 10:
                        self.logger.scalar_summary(
                            '{}_loss_large/{}'.format(phase, l),
                            avg_loss_stats[l].val, cur_step)
                    else:
                        self.logger.scalar_summary(
                            '{}_loss/{}'.format(phase, l),
                            avg_loss_stats[l].val, cur_step)
                if phase == 'train':
                    self.logger.scalar_summary('lr', cur_lr, cur_step)

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f}({:.4f}) '.format(
                    l, avg_loss_stats[l].val, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                if self.vis_interval != 0:
                    if iter_id % self.vis_interval == 0:
                        self.debug(batch, output, cur_step)

            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        """Visualize predictions; task-specific."""
        raise NotImplementedError

    def save_result(self, output, batch, results):
        """Store post-processed predictions in *results*; task-specific."""
        raise NotImplementedError

    def _get_losses(self, opt):
        """Return ``(loss_stats, loss)`` for the task."""
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
コード例 #12
0
class BaseTrainer(object):
    """Trainer whose run_epoch streams per-loss scalars to a TensorBoard writer."""

    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        # Per-loss-term names and the combined loss module (task-specific).
        self.loss_stats, self.loss = self._get_losses(opt)
        # ``ModleWithLoss`` [sic] is the project's wrapper that runs the
        # forward pass and the loss together.
        self.model_with_loss = ModleWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        """Move the model (and optimizer state) onto *device*."""
        # Wrap in (the project's chunked) DataParallel for multi-GPU runs.
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # Optimizer state tensors must follow the parameters onto the device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader, **kwargs):
        """Run one train/val epoch; returns (average-loss dict, results).

        Required kwarg: ``tb_writer`` (a TensorBoard writer, or None to
        disable per-iteration scalar logging).
        """

        logger = kwargs.get('logger')  # NOTE(review): retrieved but unused.
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt

        tb_writer = kwargs['tb_writer']
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        # Skip bookkeeping entries that are not scalar loss terms.
        avg_loss_stats = {
            l: AverageMeter()
            for l in self.loss_stats
            if not l.startswith('img_id') and not l.startswith('batch_')
        }
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        print('num_iters:', num_iters)

        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            # Global step across epochs, used for TensorBoard x-axis.
            global_step = num_iters * (epoch - 1) + iter_id
            data_time.update(time.time() - end)

            # Everything except the metadata entry goes to the device.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch, global_step,
                                                       tb_writer)

            tb_time = time.time()
            if opt.equal_loss:
                if phase == 'train':
                    # Re-weight the two DataParallel replicas so the master
                    # and secondary batches contribute proportionally.
                    loss = (loss[0] * opt.master_batch_size + loss[1] *
                            (opt.batch_size - opt.master_batch_size)
                            ) / opt.batch_size
                else:
                    loss = loss.mean()
            else:
                loss = loss.mean()

            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)

            for l in avg_loss_stats:
                __stats = loss_stats
                if tb_writer is not None:
                    tb_writer.add_scalar(l,
                                         __stats[l].mean().item(),
                                         global_step=global_step,
                                         walltime=tb_time)
                _n = batch['input'].size(0)
                avg_loss_stats[l].update(__stats[l].mean().item(), _n)
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)

            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            # Debug levels 1-4 draw visualizations; 5+ is not visualized here.
            if opt.debug > 0 and opt.debug < 5:
                self.debug(batch, output, iter_id)

            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats

        # NOTE(review): if the loader yielded no batches, ``global_step`` is
        # unbound here and this block raises NameError.
        if tb_writer is not None:
            for l in avg_loss_stats:
                tb_writer.add_scalar('epoch_avg/{}'.format(l),
                                     avg_loss_stats[l].avg,
                                     global_step=global_step,
                                     walltime=time.time())
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        """Visualize predictions; task-specific."""
        raise NotImplementedError

    def save_result(self, output, batch, results):
        """Store post-processed predictions in *results*; task-specific."""
        raise NotImplementedError

    def _get_losses(self, opt):
        """Return ``(loss_stats, loss)`` for the task."""
        raise NotImplementedError

    def val(self, epoch, data_loader, **kwargs):
        """Run one validation epoch (kwargs forwarded to run_epoch)."""
        return self.run_epoch('val', epoch, data_loader, **kwargs)

    def train(self, epoch, data_loader, **kwargs):
        """Run one training epoch (kwargs forwarded to run_epoch)."""
        return self.run_epoch('train', epoch, data_loader, **kwargs)

    def wta_stat(self, epoch, data_loader):
        """Run forward passes to gather winner-take-all statistics.

        NOTE(review): ``stats`` is never populated -- this currently returns
        an empty list and only exercises the model/loss forward pass.
        """

        stats = []
        model_with_loss = self.model_with_loss
        if len(self.opt.gpus) > 1:
            model_with_loss = self.model_with_loss.module
        model_with_loss.eval()
        torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            global_step = num_iters * (epoch - 1) + iter_id
            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            _, _, loss_stats = model_with_loss(batch, global_step, None)
            batch_time.update(time.time() - end)
            end = time.time()

        return stats
コード例 #13
0
class BaseTrainer(object):
  """Trainer variant with an optional learning-rate range test (``find_lr``)."""

  def __init__(
    self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    # Per-loss-term names and the combined loss module (task-specific).
    self.loss_stats, self.loss = self._get_losses(opt)
    # ``ModleWithLoss`` [sic] wraps the forward pass and loss together.
    self.model_with_loss = ModleWithLoss(model, self.loss)

  def set_device(self, gpus, chunk_sizes, device):
    """Move the model (and optimizer state) onto *device*."""
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
        self.model_with_loss, device_ids=gpus, 
        chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)

    # Optimizer state tensors must follow the parameters onto the device.
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)


  def run_epoch(self, phase, epoch, data_loader):
    """Run one train/val epoch; returns (average-loss dict, results)."""
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      # Everything except the metadata entry goes to the device.
      for k in batch:
        if k != 'meta':
          # print("k=",k)
          # print("batch[k].size=",batch[k].size())
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)
      output, loss, loss_stats = model_with_loss(batch, fsm=self.opt.fsm)
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
      # Some loss terms may be absent from this batch's stats; skip them.
      for l in avg_loss_stats:
        if l not in loss_stats:
          continue
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)

      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) 
      else:
        bar.next()
      
      if opt.debug > 0:
        self.debug(batch, output, iter_id)
      
      if opt.test:
        self.save_result(output, batch, results)
      del output, loss, loss_stats
    
    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results


  def find_lr(self, epoch, data_loader, lr_start=1e-5, lr_end=1e0, beta=0.98):
    """Learning-rate range test: sweep the LR exponentially from *lr_start*
    to *lr_end*, tracking an exponentially smoothed loss, then save a
    loss-vs-LR plot and terminate the process via sys.exit().

    The sweep aborts early once the smoothed loss exceeds 4x its best value.
    ``beta`` is the smoothing factor for the running loss average.
    """
    model_with_loss = self.model_with_loss
    model_with_loss.train()

    opt = self.opt
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()

    # Per-iteration multiplicative LR step so the sweep ends at lr_end.
    lr_mul = (lr_end / lr_start) ** (1. / (num_iters-1))
    lrs = []
    losses = []
    avg_loss = 0
    best_loss = 1e9
    for param_group in self.optimizer.param_groups:
      param_group['lr'] = lr_start
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      for k in batch:
        if k != 'meta':
          # print("k=",k)
          # print("batch[k].size=",batch[k].size())
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)
      e2e = True if self.opt.e2e else False
      output, loss, loss_stats = model_with_loss(batch, e2e=e2e)
      loss = loss.mean()
      # Exponential moving average of the loss, with bias correction below.
      avg_loss = beta * avg_loss + (1 - beta) * loss.data
      smoothed_loss = avg_loss / (1 - beta ** (iter_id+1))
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()
      if smoothed_loss < best_loss:
        best_loss = smoothed_loss
      lrs.append(self.optimizer.param_groups[0]['lr'])
      losses.append(float(smoothed_loss))

      # Stop early once the loss clearly diverges.
      if smoothed_loss > 4 * best_loss and iter_id > 0:
        break

      for param_group in self.optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * lr_mul

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase='train',
        total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        if l not in loss_stats:
          continue
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      Bar.suffix = Bar.suffix + '|lr {:.6f}'.format(lrs[-1])
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()

      if opt.debug > 0:
        self.debug(batch, output, iter_id)
      del output, loss, loss_stats
    # Plot smoothed loss against log(LR) and write it next to the cwd.
    plt.figure()
    plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1e0]),(1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1e0))
    plt.xlabel('learning rate')
    plt.ylabel('loss')
    plt.plot(np.log(lrs), losses)
    # plt.show()
    plt.savefig(opt.exp_id+'_lr_loss.png')
    # Deliberate: the range test is a one-off diagnostic run.
    sys.exit()

  def debug(self, batch, output, iter_id):
    """Visualize predictions; task-specific."""
    raise NotImplementedError

  def save_result(self, output, batch, results):
    """Store post-processed predictions in *results*; task-specific."""
    raise NotImplementedError

  def _get_losses(self, opt):
    """Return ``(loss_stats, loss)`` for the task."""
    raise NotImplementedError
  
  def val(self, epoch, data_loader):
    """Run one validation epoch."""
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader, find_lr=False):
    """Run one training epoch; with find_lr=True run the LR range test
    first (which exits the process)."""
    if find_lr:
      self.find_lr(epoch, data_loader)
    return self.run_epoch('train', epoch, data_loader)
コード例 #14
0
class BaseTrainer(object):
  def __init__(
    self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModleWithLoss(model, self.loss)
    self.reconstruct_img = False

  def set_device(self, gpus, chunk_sizes, device):
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
        self.model_with_loss, device_ids=gpus, 
        chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def save_tensor_to_img(self, tensors, filenames, path):
    batch_size, channel, w, h = tensors.size()
    reconstruct_imgs = tensors.detach().cpu().numpy().transpose(0,2,3,1)
    for ind in range(batch_size):
      save_path = os.path.join(path, filenames[ind])
      cv2.imwrite(save_path, reconstruct_imgs[ind])
    pass

  def run_epoch(self, phase, epoch, data_loader, logger=None):
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)    
      output, loss, loss_stats = model_with_loss(batch)
      if self.reconstruct_img:
        file_path = '/data/mry/code/CenterNet/debug_conflict_bt_class_recon'
        self.save_tensor_to_img(output['reconstruct_img'], batch['meta']['file_name'], file_path)
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        if l == 'KL_loss':
          if loss_stats[l] is not None:
            avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['input'].size(0))
            Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l])
          else:
            avg_loss_stats[l].update(0, batch['input'].size(0))
          continue
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)


      if logger and iter_id % opt.logger_iteration == 0:
        logger.write_iteration(
          '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td))
        for l in avg_loss_stats:
          if loss_stats[l] is None:
            continue
          avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['input'].size(0))
          logger.write_iteration('|{} {:.4f} '.format(l, avg_loss_stats[l].avg))
          logger.scalar_summary('train_iteration_{}'.format(l), avg_loss_stats[l].avg, (epoch-1)*num_iters+iter_id)
        logger.write_iteration('\n')

      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) 
      else:
        bar.next()
      
      if opt.debug > 0:
        self.debug(batch, output, iter_id)
      
      if opt.test:
        self.save_result(output, batch, results)
      del output, loss, loss_stats
    
    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results
  
  def debug(self, batch, output, iter_id):
    raise NotImplementedError

  def save_result(self, output, batch, results):
    raise NotImplementedError

  def _get_losses(self, opt):
    raise NotImplementedError
  
  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader, logger):
    return self.run_epoch('train', epoch, data_loader, logger=logger)
コード例 #15
0
ファイル: base_trainer.py プロジェクト: nemonameless/FairMOT
class BaseTrainer(object):
  """Generic training/evaluation loop driver.

  Wraps a model together with its loss (`ModleWithLoss`), moves model and
  optimizer state onto the target device, and runs one epoch at a time.
  Subclasses must implement `_get_losses` (and `debug`/`save_result` if
  those features are used).
  """

  def __init__(self, opt, model, optimizer=None):
    """Build the model+loss wrapper.

    Args:
      opt: experiment options namespace (device, gpus, num_iters, ...).
      model: the network to train.
      optimizer: a torch optimizer, or None for inference-only use.
    """
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModleWithLoss(model, self.loss)
    # Fix: the original called add_param_group unconditionally, which
    # crashes (AttributeError) with the documented default optimizer=None.
    if self.optimizer is not None:
      # The loss module itself holds learnable parameters (e.g. task
      # uncertainty weights), so they must be optimized with the model.
      self.optimizer.add_param_group({'params': self.loss.parameters()})

  def set_device(self, gpus, chunk_sizes, device):
    """Move the model(+loss) and any optimizer state tensors to `device`."""
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
        self.model_with_loss, device_ids=gpus,
        chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)

    # Fix: skip optimizer-state migration when no optimizer was supplied
    # (optimizer defaults to None in __init__).
    if self.optimizer is not None:
      # Optimizer state (e.g. Adam moment tensors) must live on the same
      # device as the parameters it updates.
      for state in self.optimizer.state.values():
        for k, v in state.items():
          if isinstance(v, torch.Tensor):
            state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader):
    """Run one epoch.

    Args:
      phase: 'train' (backprop + optimizer steps) or anything else (eval).
      epoch: epoch index, used only for progress display.
      data_loader: iterable yielding dict batches; the 'input' key must be
        a tensor whose dim 0 is the batch size, 'meta' (if present) stays
        on the host.

    Returns:
      (dict of averaged loss stats plus elapsed 'time' in minutes,
       results dict filled by save_result when opt.test is set).
    """
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        # Evaluate on the bare module rather than the DataParallel wrapper.
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    # opt.num_iters < 0 means "use the full loader"; otherwise cap the epoch.
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      # 'meta' carries non-tensor bookkeeping; everything else goes to device.
      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)

      output, loss, loss_stats = model_with_loss(batch)
      loss = loss.mean()  # reduce across DataParallel replicas
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net time: {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()

      if opt.test:
        self.save_result(output, batch, results)
      # Free per-iteration tensors promptly to cap peak memory.
      del output, loss, loss_stats, batch

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results

  def debug(self, batch, output, iter_id):
    """Subclass hook for visualizing a batch; not used by this base loop."""
    raise NotImplementedError

  def save_result(self, output, batch, results):
    """Subclass hook collecting predictions into `results` (opt.test)."""
    raise NotImplementedError

  def _get_losses(self, opt):
    """Subclass hook returning (loss_stat_names, loss_module)."""
    raise NotImplementedError

  def val(self, epoch, data_loader):
    """Run one validation epoch."""
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader):
    """Run one training epoch."""
    return self.run_epoch('train', epoch, data_loader)
コード例 #16
0
class CtTrainer(object):
    """CenterNet detection trainer driven by a plugin/event system.

    Plugins register for 'iteration', 'epoch', 'batch' or 'update' events.
    Each queue is a min-heap of (next_trigger_time, insertion_order, plugin)
    tuples; a plugin fires when the current tick reaches its trigger time
    and is then re-queued one interval later.
    """

    def __init__(self, opt, model, optimizer=None, dataloader=None):
        self.opt = opt
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader
        self.iterations = 0  # total batches processed across all epochs
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)
        self.num_iters = len(self.dataloader)
        self.stats = {}
        self.phase = "train"

        # One event queue per trigger type.
        self.plugin_queues = {
            'iteration': [],
            'epoch': [],
            'batch': [],
            'update': [],
        }

    def _get_losses(self, opt):
        """Return the tracked loss-stat names and the CenterNet loss."""
        loss_states = ['loss', 'hm_loss', 'wh_loss', 'off_loss']
        loss = CtdetLoss(opt)
        return loss_states, loss

    def set_device(self, gpus, chunk_sizes, device):
        """Move the model(+loss) and optimizer state tensors to `device`."""
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def register_plugin(self, plugin):
        """Register `plugin` and enqueue it on each queue it subscribes to.

        `plugin.trigger_interval` is a (duration, unit) pair or a list of
        them, e.g. [(1, 'iteration'), (1, 'epoch')]. `duration` is the first
        trigger time; len(queue) records insertion order, which breaks ties
        between entries with the same trigger time (earlier-registered wins).
        """
        plugin.register(self)

        intervals = plugin.trigger_interval
        if not isinstance(intervals, list):
            intervals = [intervals]
        for duration, unit in intervals:
            queue = self.plugin_queues[unit]
            queue.append((duration, len(queue), plugin))

    def call_plugins(self, queue_name, time, *args):
        """Fire every plugin on `queue_name` whose trigger time has arrived.

        `time` is the current tick (iteration or epoch count). Plugins
        subscribed to a queue must expose a method with the queue's name
        (iteration/batch/epoch/update); it is called as method(time, *args).
        """
        args = (time, ) + args
        queue = self.plugin_queues[queue_name]
        if len(queue) == 0:
            return
        while queue[0][0] <= time:
            plugin = queue[0][2]
            getattr(plugin, queue_name)(*args)
            for trigger in plugin.trigger_interval:
                if trigger[1] == queue_name:
                    interval = trigger[0]
            # NOTE(review): `interval` is unbound if no trigger matches
            # queue_name — assumed to always match; confirm plugin contract.
            # Replace the fired entry with its next occurrence and restore
            # the heap invariant.
            new_item = (time + interval, queue[0][1], plugin)
            heapq.heappushpop(queue, new_item)

    def run(self, epochs=1):
        """Train for `epochs` epochs, firing 'epoch' plugins after each."""
        for q in self.plugin_queues.values():
            # Establish the min-heap invariant on all four event queues.
            heapq.heapify(q)

        for i in range(1, epochs + 1):
            self.train()
            self.call_plugins('epoch', i)

    def train(self):
        """Run one pass over the dataloader, dispatching plugin events."""
        model_with_loss = self.model_with_loss
        if self.phase == 'train':
            model_with_loss.train()
        # Fix: the original ended with `self.iterations += i`, but `i` was
        # never bound in this scope (NameError after the loop); count the
        # processed batches explicitly instead.
        iterations_run = 0
        for iter_id, batch in enumerate(self.dataloader):
            batch_input = batch['input']
            batch_hm = batch['hm']
            # TODO(review): debug print dumps a full tensor per batch; move
            # to a logger or remove once training is stable.
            print("batch hm is: ", batch_hm)
            batch_reg_mask, batch_ind, batch_wh, batch_reg = batch[
                'reg_mask'], batch['ind'], batch['wh'], batch['reg']
            self.call_plugins('batch', iter_id, batch_input, batch_hm)
            if iter_id >= self.num_iters:
                break
            # Cache of (heatmap output, loss) handed to 'iteration' plugins.
            plugin_data = [None, None]
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=self.opt.device,
                                           non_blocking=True)
            batch_output, loss, loss_stats = self.model_with_loss(batch)
            loss = loss.mean()
            print('loss is,', loss)

            if plugin_data[0] is None:
                plugin_data[0] = batch_output['hm'].data
                plugin_data[1] = loss.data
            if self.phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            print("current loss is: ", loss)
            # Fix: the original passed `batch_hm * plugin_data` (a tensor
            # multiplied by a Python list); the cached plugin data must be
            # unpacked as extra positional arguments.
            self.call_plugins('iteration', iter_id, batch_input,
                              batch_hm, *plugin_data)
            self.call_plugins('update', iter_id, self.model)
            iterations_run += 1

        self.iterations += iterations_run
コード例 #17
0
ファイル: base_trainer.py プロジェクト: XrosLiang/RTS3D
class BaseTrainer(object):
    """Two-stage 3D-detection trainer (image backbone + point-cloud head).

    Holds separate optimizers for the image and point models; the combined
    `ModelWithLoss` wrapper (stored as `model_point`) runs the full forward
    pass. Subclasses must implement `_get_losses`.
    """

    def __init__(self, opt, model_image, model_point, optimizer_image,
                 optimizer_point):
        self.opt = opt
        self.optimizer_image = optimizer_image
        self.optimizer_point = optimizer_point

        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_image = model_image
        # NOTE(review): despite the name, this wraps BOTH models with the
        # loss; run_epoch calls it for the full forward pass.
        self.model_point = ModelWithLoss(model_image, model_point,
                                         self.loss)

    def set_device(self, gpus, chunk_sizes, device, model_select):
        """Move the selected model ('image' or 'point') and its optimizer
        state tensors onto `device`."""
        if len(gpus) > 1:
            if model_select == 'image':
                self.model_image = DataParallel(
                    self.model_image, device_ids=gpus,
                    chunk_sizes=chunk_sizes).to(device)
            elif model_select == 'point':
                self.model_point = DataParallel(
                    self.model_point, device_ids=gpus,
                    chunk_sizes=chunk_sizes).to(device)
            else:
                print("no support model")
        else:
            if model_select == 'image':
                self.model_image = self.model_image.to(device)
            elif model_select == 'point':
                self.model_point = self.model_point.to(device)
            else:
                print("no support model")

        if model_select == 'image':
            for state in self.optimizer_image.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=device, non_blocking=True)
        elif model_select == 'point':
            for state in self.optimizer_point.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=device, non_blocking=True)
        else:
            print("no support model")

    def run_epoch(self, phase, epoch, data_loader):
        """Run one epoch.

        In the 'train' phase each batch is optimized TWICE: a first pass
        produces a pose/dimension/orientation estimate, which is fed back
        into the batch for a refined second pass whose loss is doubled.

        Returns (dict of averaged loss stats plus 'time' in minutes,
        results dict — currently always empty).
        """
        model_image = self.model_image
        model_point = self.model_point
        if phase == 'train':
            model_image.train()
            model_point.train()
        else:
            model_image.eval()
            model_point.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        # opt.num_iters < 0 means "use the full loader".
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format("3D detection", opt.exp_id), max=num_iters)
        end = time.time()

        for iter_id, batch in enumerate(data_loader):

            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            # 'meta' carries host-side bookkeeping; everything else moves.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            # (Removed dead leftover debugger hook `if iter_id == 8: a = 1`.)
            # First pass: estimate box parameters per object.
            loss, loss_stats, next_est = model_point(
                batch, phase, epoch=epoch, opt=self.opt)
            # Reshape to (batch, max_objects, params) and detach so the
            # second pass doesn't backprop through the first estimate.
            next_est = next_est.view(batch['reg_mask'].size(0),
                                     batch['reg_mask'].size(1), -1).detach()
            loss = loss.mean()
            if phase == 'train':
                self.optimizer_image.zero_grad()
                self.optimizer_point.zero_grad()
                loss.backward()
                self.optimizer_image.step()
                self.optimizer_point.step()

                # Feed the first-pass estimate back: position (3),
                # dimensions (3), and yaw angle -> rotation matrix.
                batch['pos_est'] = next_est[:, :, :3]
                batch['dim_est'] = next_est[:, :, 3:6]
                ry = next_est[:, :, 6]
                R_yaw = batch['pos_est'].new_zeros(next_est.size(0),
                                                   next_est.size(1), 3, 3)
                R_yaw[:, :, 0, 0] = torch.cos(ry)
                R_yaw[:, :, 0, 2] = torch.sin(ry)
                R_yaw[:, :, 1, 1] = 1
                R_yaw[:, :, 2, 0] = -torch.sin(ry)
                R_yaw[:, :, 2, 2] = torch.cos(ry)
                batch['ori_est'] = R_yaw
                batch['ori_est_scalar'] = ry
                # Second (refinement) pass with doubled loss weight.
                loss, loss_stats, next_est = model_point(batch,
                                                         phase,
                                                         epoch=epoch,
                                                         opt=self.opt)
                loss = 2 * loss.mean()
                self.optimizer_image.zero_grad()
                self.optimizer_point.zero_grad()
                loss.backward()
                self.optimizer_image.step()
                self.optimizer_point.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                  '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            # Free per-iteration tensors promptly to cap peak memory.
            del loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        """Subclass hook for visualizing a batch."""
        raise NotImplementedError

    def save_result(self, output, batch, results):
        """Subclass hook collecting predictions into `results`."""
        raise NotImplementedError

    def _get_losses(self, opt):
        """Subclass hook returning (loss_stat_names, loss_module)."""
        raise NotImplementedError

    def val(self, epoch, data_loader):
        """Run one validation epoch."""
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        """Run one training epoch."""
        return self.run_epoch('train', epoch, data_loader)
コード例 #18
0
ファイル: trainers.py プロジェクト: zzq-oss/PPDM
class Hoidet(object):
    """Trainer for human-object interaction detection (PPDM-style).

    Wraps the model with `HoidetLoss` and drives one epoch of training or
    validation at a time via `run_epoch`.
    """

    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        loss = HoidetLoss(opt)
        # Per-iteration loss components tracked and averaged over the epoch.
        self.loss_states = [
            'loss', 'hm_loss', 'wh_loss', 'off_loss', 'hm_rel_loss',
            'sub_offset_loss', 'obj_offset_loss'
        ]
        self.model_with_loss = ModelWithLoss(model, loss)

    def set_device(self, gpus, chunk_sizes, device):
        """Move the model(+loss) and optimizer state tensors to `device`."""
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        # Optimizer state (e.g. Adam moments) must follow the parameters.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, model_with_loss, epoch, data_loader, phase='train'):
        """Run one epoch with the given (already train()/eval()-ed) model.

        Returns (dict of averaged loss states plus elapsed 'time' in
        minutes, results dict — currently always empty).
        """
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_states = {l: AverageMeter() for l in self.loss_states}
        # opt.num_iters < 0 means "use the full loader".
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            # 'meta' carries host-side bookkeeping; everything else moves.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            output, loss, loss_states = model_with_loss(batch)
            loss = loss.mean()  # reduce across DataParallel replicas
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_states:
                avg_loss_states[l].update(loss_states[l].mean().item(),
                                          batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_states[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            # Free per-iteration tensors promptly to cap peak memory.
            del output, loss, loss_states

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_states.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def train(self, epoch, data_loader):
        """Run one training epoch."""
        model_with_loss = self.model_with_loss
        model_with_loss.train()
        ret, results = self.run_epoch(model_with_loss, epoch, data_loader)
        return ret, results

    def val(self, epoch, data_loader):
        """Run one validation epoch under no_grad."""
        model_with_loss = self.model_with_loss
        model_with_loss.eval()
        torch.cuda.empty_cache()
        # Fix: the original used `with torch.no_grad:` (missing call
        # parentheses), which raises at runtime — torch.no_grad is a class
        # and must be instantiated to act as a context manager.
        with torch.no_grad():
            ret, results = self.run_epoch(model_with_loss,
                                          epoch,
                                          data_loader,
                                          phase='val')
        return ret, results