class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.eval_stats = self._get_evals(opt)
        self.model = model
        self.model_with_loss = ModelWithLoss(self.model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        avg_eval_stats = {l: AverageMeter() for l in self.eval_stats}
        if phase == 'train':
            num_iters = len(data_loader.dataset) // opt.batch_size
        else:
            num_iters = len(data_loader.dataset)
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        dataiterator = iter(data_loader)
        end = time.time()
        for iter_id in range(num_iters + 1):
            # if iter_id % 100 == 0 and phase == 'train':
            #     save_model(os.path.join(opt.save_dir,
            #         'model_epoch_{}_iter_{}.pth'.format(epoch - 1, iter_id)),
            #         epoch - 1, self.model, self.optimizer)
            self.optimizer.zero_grad()
            batch_loss_stats = {l: 0 for l in self.loss_stats}
            batch_eval_stats = {l: 0 for l in self.eval_stats}
            # gradient accumulation: each optimizer step is split over
            # `subdivision` forward/backward passes during training
            if phase == 'train':
                subdivision = opt.subdivision
            else:
                subdivision = 1
            for sub_iter in range(subdivision):
                try:
                    batch = next(dataiterator)
                except StopIteration:
                    dataiterator = iter(data_loader)
                    batch = next(dataiterator)
                data_time.update(time.time() - end)
                for k in batch:
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
                output, loss, loss_stats = model_with_loss(batch)
                batch_time.update(time.time() - end)
                end = time.time()
                for l in batch_loss_stats:
                    batch_loss_stats[l] += loss_stats[l].item() / subdivision
                if phase == 'train':
                    loss.backward()
                else:
                    test_stats = self._get_result(batch, output)
                    for l in batch_eval_stats:
                        batch_eval_stats[l] += test_stats[l] / subdivision
            if phase == 'train':
                self.optimizer.step()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            if phase == 'test':
                batch_size = 1
            else:
                batch_size = opt.batch_size
            for l in avg_loss_stats:
                avg_loss_stats[l].update(batch_loss_stats[l], batch_size)
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if phase == 'test':
                for l in avg_eval_stats:
                    avg_eval_stats[l].update(batch_eval_stats[l], batch_size)
                    Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                        l, avg_eval_stats[l].avg)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        if phase == 'test':
            for l in avg_eval_stats:
                ret[l] = avg_eval_stats[l].avg
        return ret, results

    def _get_losses(self, opt):
        raise NotImplementedError

    def _get_result(self, batch, output):
        raise NotImplementedError

    def _get_evals(self, opt):
        raise NotImplementedError

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)

    def test(self, epoch, data_loader):
        return self.run_epoch('test', epoch, data_loader)
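# The trainers in this file wrap the network and criterion in a ModelWithLoss
# module so that DataParallel can scatter the loss computation across GPUs.
# For reference, a minimal sketch matching the upstream CenterNet
# implementation (some variants below use extended signatures):
import torch


class ModelWithLoss(torch.nn.Module):
    def __init__(self, model, loss):
        super(ModelWithLoss, self).__init__()
        self.model = model
        self.loss = loss

    def forward(self, batch):
        # forward the network on the image tensor, then evaluate the loss
        # against the full batch dict (heatmaps, regression targets, ...)
        outputs = self.model(batch['input'])
        loss, loss_stats = self.loss(outputs, batch)
        return outputs[-1], loss, loss_stats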
class CtdetTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        opt = self.opt
        reg = output['reg'] if opt.reg_offset else None
        dets = ctdet_decode(output['hm'], output['wh'], reg=reg,
                            cat_spec_wh=opt.cat_spec_wh, K=opt.K)
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
        dets[:, :, :4] *= opt.down_ratio
        dets_gt = batch['meta']['gt_det'].numpy().reshape(1, -1, dets.shape[2])
        dets_gt[:, :, :4] *= opt.down_ratio
        for i in range(1):
            debugger = Debugger(dataset=opt.dataset, ipynb=(opt.debug == 3),
                                theme=opt.debugger_theme)
            img = batch['input'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip(((img * opt.std + opt.mean) * 255.), 0,
                          255).astype(np.uint8)
            pred = debugger.gen_colormap(output['hm'][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
            debugger.add_blend_img(img, pred, 'pred_hm')
            debugger.add_blend_img(img, gt, 'gt_hm')
            debugger.add_img(img, img_id='out_pred')
            for k in range(len(dets[i])):
                if dets[i, k, 4] > opt.center_thresh:
                    debugger.add_coco_bbox(dets[i, k, :4], dets[i, k, -1],
                                           dets[i, k, 4], img_id='out_pred')
            debugger.add_img(img, img_id='out_gt')
            for k in range(len(dets_gt[i])):
                if dets_gt[i, k, 4] > opt.center_thresh:
                    debugger.add_coco_bbox(dets_gt[i, k, :4], dets_gt[i, k, -1],
                                           dets_gt[i, k, 4], img_id='out_gt')
            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def save_result(self, output, batch, results):
        reg = output['reg'] if self.opt.reg_offset else None
        dets = ctdet_decode(output['hm'], output['wh'], reg=reg,
                            cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)
        dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
        dets_out = ctdet_post_process(
            dets.copy(), batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(),
            output['hm'].shape[2], output['hm'].shape[3],
            output['hm'].shape[1])
        results[batch['meta']['img_id'].cpu().numpy()[0]] = dets_out[0]

    def _get_losses(self, opt):
        loss_states = ['loss', 'hm_loss', 'wh_loss', 'off_loss']
        loss = CtdetLoss(opt)
        return loss_states, loss

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
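# The AverageMeter used for running loss statistics throughout this file is
# not defined here; a minimal sketch of the conventional implementation
# (assumed to match what these trainers import):
class AverageMeter(object):
    """Computes and stores the current value and the running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # n is the number of samples the value was averaged over
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count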
class Trainer:
    def __init__(self, cfg, model, optimizer):
        self.cfg = cfg
        self.optimizer = optimizer
        self.model = model
        if cfg.USE_OFFSET:
            self.loss_stats = {
                'total_loss': [], 'hm_loss': [], 'wh_loss': [],
                'offset_loss': []
            }
        else:
            self.loss_stats = {'total_loss': [], 'hm_loss': [], 'wh_loss': []}
        self.loss = CenterLoss(cfg)
        # for teacher loss:
        # self.loss = TeacherLoss(cfg)

    def set_device(self, gpus, device):
        if len(gpus) > 1:
            self.model = DataParallel(self.model, device_ids=gpus).to(device)
            self.loss = DataParallel(self.loss, device_ids=gpus).to(device)
        else:
            self.model = self.model.to(device)
            self.loss = self.loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, epoch, data_loader, log_file):
        epoch_total_loss = 0.0
        epoch_hm_loss = 0.0
        epoch_wh_loss = 0.0
        epoch_offset_loss = 0.0
        self.model.train()
        self.loss.train()
        data_process = tqdm(data_loader)
        for batch_item in data_process:
            batch_img, batch_label = batch_item
            batch_img = batch_img.to(device=self.cfg.DEVICE)
            for k in batch_label:
                batch_label[k] = batch_label[k].to(device=self.cfg.DEVICE)
            batch_output, feature = self.model(batch_img)
            # for teacher loss:
            # batch_output = self.model(batch_img)
            loss, loss_stats = self.loss(batch_output, batch_label)
            loss = loss.mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            batch_loss = [
                loss_stats['total_loss'], loss_stats['hm_loss'],
                loss_stats['wh_loss']
            ]
            epoch_total_loss += batch_loss[0]
            epoch_hm_loss += batch_loss[1]
            epoch_wh_loss += batch_loss[2]
            loss_str = "total_loss: {},hm_loss: {:.6f},wh_loss: {:.6f}".format(
                batch_loss[0], batch_loss[1], batch_loss[2])
            if 'offset_loss' in loss_stats:
                batch_loss.append(loss_stats['offset_loss'])
                epoch_offset_loss += batch_loss[3]
                loss_str += ",offset_loss: {:.6f}".format(batch_loss[3])
            data_process.set_description_str("epoch:{}".format(epoch))
            data_process.set_postfix_str(loss_str)
        # write per-epoch mean losses as one CSV line
        log_str = "{},{:.6f},{:.6f},{:.6f}".format(
            epoch, epoch_total_loss / len(data_loader),
            epoch_hm_loss / len(data_loader),
            epoch_wh_loss / len(data_loader))
        if self.cfg.USE_OFFSET:
            log_str += ",{:.6f}\n".format(epoch_offset_loss / len(data_loader))
        else:
            log_str += "\n"
        log_file.write(log_str)
        log_file.flush()

    def train(self, epoch, data_loader, train_log):
        return self.run_epoch(epoch, data_loader, train_log)

    def val(self, epoch, model_path, val_loader, val_log, cfg):
        detector = Detector(model_path, cfg)
        mean_precision = 0
        mean_recall = 0
        sample_count = val_loader.num_samples
        for i in range(sample_count):
            image_path, gt_bboxes = val_loader.getitem(i)
            results = detector.run(image_path)
            pre_bboxes = results[1]
            if len(gt_bboxes) > 0:
                precision, recall = compute_metrics(pre_bboxes, gt_bboxes)
                mean_precision += precision
                mean_recall += recall
        log_str = "{},{:.6f},{:.6f}\n".format(
            epoch, mean_precision / sample_count, mean_recall / sample_count)
        val_log.write(log_str)
        val_log.flush()
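# compute_metrics above is assumed to return (precision, recall) for a single
# image from predicted and ground-truth boxes in [x1, y1, x2, y2] form; a
# hypothetical greedy IoU-matching sketch (0.5 threshold), not the source's
# actual implementation:
def compute_metrics(pred_bboxes, gt_bboxes, iou_thresh=0.5):
    def iou(a, b):
        x1, y1 = max(a[0], b[0]), max(a[1], b[1])
        x2, y2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0, x2 - x1) * max(0, y2 - y1)
        area_a = (a[2] - a[0]) * (a[3] - a[1])
        area_b = (b[2] - b[0]) * (b[3] - b[1])
        return inter / max(area_a + area_b - inter, 1e-9)

    matched, tp = set(), 0
    for p in pred_bboxes:
        # greedily match each prediction to at most one unmatched gt box
        for gi, g in enumerate(gt_bboxes):
            if gi not in matched and iou(p, g) >= iou_thresh:
                matched.add(gi)
                tp += 1
                break
    precision = tp / max(len(pred_bboxes), 1)
    recall = tp / max(len(gt_bboxes), 1)
    return precision, recall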
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        opt = self.opt
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        if opt.mixed_precision:
            from apex import amp
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        num_accum = opt.num_grad_accum or 1
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            # zero the gradients at the start of each accumulation window:
            # [0] 1 [2] 3 [4] 5 [6] 7 [8]
            # [0] 1 2 [3] 4 5 [6] 7 8 [9]
            if not iter_id % num_accum:
                self.optimizer.zero_grad()
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean() / num_accum
            if phase == 'train':
                if opt.mixed_precision:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                # step at the end of each accumulation window:
                # 0 [1] 2 [3] 4 [5] 6 [7] 8
                # 0 1 [2] 3 4 [5] 6 7 [8] 9
                if not (iter_id + 1) % num_accum:
                    self.optimizer.step()
                if opt.use_swa and opt.swa_manual:
                    # epoch count starts from 1
                    global_iter = num_iters * max(0, epoch - 1) + iter_id
                    if opt.swa_lr is not None and global_iter == opt.swa_start:
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = opt.swa_lr
                    if global_iter > opt.swa_start:
                        if not (iter_id + 1) % opt.swa_freq:
                            self.optimizer.update_swa()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def bn_update(self, data_loader):
        # forward passes in train() mode (under no_grad) to refresh BatchNorm
        # running statistics after weight averaging
        opt = self.opt
        model_with_loss = self.model_with_loss
        model_with_loss.train()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= opt.swa_bn_upd_iters:
                break
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            with torch.no_grad():
                output, loss, loss_stats = model_with_loss(batch)
            del output, loss, loss_stats
        model_with_loss.eval()

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
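# The optimizer.update_swa() call in run_epoch and the bn_update() pass above
# suggest the optimizer is wrapped for stochastic weight averaging, e.g.
# torchcontrib's SWA in manual mode (an assumption based on the method names):
#
#   from torchcontrib.optim import SWA
#   optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1.25e-4))
#   # ...train, calling optimizer.update_swa() every opt.swa_freq steps...
#   optimizer.swap_swa_sgd()         # copy the averaged weights into the model
#   trainer.bn_update(train_loader)  # then recompute BN running stats for them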
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)
        self.rampup = exp_rampup(100)
        self.rampup_prob = exp_rampup(100)
        self.rampup_coor = exp_rampup(100)

    def set_device(self, gpus, chunk_sizes, device, distribute=False):
        if len(gpus) > 1 and not distribute:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader, unlabel_loader1=None,
                  unlabel_loader2=None, unlabel_set=None, iter_num=None,
                  data_loader2=None):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            # if len(self.opt.gpus) > 1:
            #     model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format("3D detection", opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            # ramp the coordinate-loss weight up over epochs; gate it to zero
            # until it crosses the configured threshold
            coor_weight = self.rampup_coor(epoch)
            if coor_weight < self.opt.coor_thresh:
                coor_weight = 0
            # the model returns a dict of per-head losses here
            output, loss, loss_stats = model_with_loss(batch, phase=phase)
            loss = loss['hm_loss'].mean() * opt.hm_weight + \
                   loss['wh_loss'].mean() * opt.wh_weight + \
                   loss['off_loss'].mean() * opt.off_weight + \
                   loss['hp_loss'].mean() * opt.hp_weight + \
                   loss['hp_offset_loss'].mean() * opt.off_weight + \
                   loss['hm_hp_loss'].mean() * opt.hm_hp_weight + \
                   loss['dim_loss'].mean() * opt.dim_weight + \
                   loss['rot_loss'].mean() * opt.rot_weight + \
                   loss['prob_loss'].mean() * self.rampup_prob(epoch) + \
                   loss['coor_loss'].mean() * coor_weight
            # loss = loss['hm_loss'].mean() * opt.hm_weight + \
            #        loss['hp_loss'].mean() * opt.hp_weight + \
            #        loss['dim_loss'].mean() * opt.dim_weight + \
            #        loss['rot_loss'].mean() * opt.rot_weight + \
            #        loss['prob_loss'].mean() * self.rampup_prob(epoch) + \
            #        loss['coor_loss'].mean() * coor_weight
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader, unlabel_loader1=None,
              unlabel_loader2=None, unlabel_set=None, iter_num=None,
              uncert=None):
        # note: `uncert` is forwarded into run_epoch's last positional slot
        return self.run_epoch('train', epoch, data_loader, unlabel_loader1,
                              unlabel_loader2, unlabel_set, iter_num, uncert)
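# exp_rampup used by the trainer above (for the prob/coor loss weights) is not
# defined in this file; a common exponential rampup from semi-supervised
# learning, offered as an assumed stand-in:
import numpy as np


def exp_rampup(rampup_length):
    def wrapper(epoch):
        # weight climbs from exp(-5) to 1.0 over `rampup_length` epochs
        if epoch < rampup_length:
            p = 1.0 - float(epoch) / rampup_length
            return float(np.exp(-5.0 * p * p))
        return 1.0
    return wrapper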
# variant __init__ that also creates a TensorBoard writer:
def __init__(self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModelWithLoss(model, self.loss)
    self.summary_writer = SummaryWriter("./tbX")
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        """Unified run_epoch for training and evaluation.

        Arguments:
            phase {str} -- 'train' or 'val'
            epoch {int} -- epoch index
            data_loader {object} -- data loader object

        Returns:
            ret {dict} -- averaged loss stats plus elapsed 'time' in minutes
            results {dict} -- detection results, filled only when opt.test is set
        """
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:  # ! stop early for debug purposes
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            # feed-forward
            output, loss, loss_stats = model_with_loss(batch)
            # average the batch loss
            loss = loss.mean()
            # back-propagation for training
            if phase == 'train':
                self.optimizer.zero_grad()  # clear old gradients
                loss.backward()             # back-propagation
                self.optimizer.step()       # optimizer step on the gradients
            # update timers
            batch_time.update(time.time() - end)
            end = time.time()
            # visualization
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            # show debug output only after the first few epochs
            if opt.debug > 0 and epoch > 3:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            # free the graph references before the next iteration
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
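# Typical driver loop for these trainers (a sketch following the CenterNet
# main script; `opt`, `model`, `optimizer` and the loaders come from elsewhere):
#
#   trainer = CtdetTrainer(opt, model, optimizer)
#   trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
#   for epoch in range(start_epoch + 1, opt.num_epochs + 1):
#       log_dict_train, _ = trainer.train(epoch, train_loader)
#       with torch.no_grad():
#           log_dict_val, preds = trainer.val(epoch, val_loader)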
class BaseTrainerIter(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        # fix: loss_stats is read in run_epoch, so fetch it here as the other
        # trainers do
        self.loss_stats, self.loss = self._get_losses(opt)
        self.accum_coco = AccumCOCO()
        self.accum_coco_det = AccumCOCODetResult()

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model = DataParallel(self.model, device_ids=gpus,
                                      chunk_sizes=chunk_sizes).to(device)
        else:
            self.model = self.model.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader, no_aug_loader=None):
        model_with_loss = self.model
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        if opt.save_video:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            vid_pth = os.path.join(opt.save_dir, opt.exp_id + '_pred')
            gt_pth = os.path.join(opt.save_dir, opt.exp_id + '_gt')
            out_pred = cv2.VideoWriter('{}.mp4'.format(vid_pth), fourcc,
                                       opt.save_framerate,
                                       (opt.input_w, opt.input_h))
            out_gt = cv2.VideoWriter('{}.mp4'.format(gt_pth), fourcc,
                                     opt.save_framerate,
                                     (opt.input_w, opt.input_h))
        # adaptive online-update schedule: run up to `umax` weight updates on a
        # visited frame until accuracy exceeds a_thresh, and widen/narrow the
        # update interval `delta` between delta_min and delta_max accordingly
        delta_max = opt.delta_max
        delta_min = opt.delta_min
        delta = delta_min
        umax = opt.umax
        a_thresh = opt.acc_thresh
        metric = get_metric(opt)
        iter_id = 0
        data_iter = iter(data_loader)
        update_lst = []
        acc_lst = []
        coco_res_lst = []
        while True:
            load_time, total_model_time, model_time = 0, 0, 0
            update_time, tot_time, display_time = 0, 0, 0
            start_time = time.time()
            # data loading
            try:
                batch = next(data_iter)
            except StopIteration:
                break
            if iter_id > opt.num_iters:
                break
            loaded_time = time.time()
            load_time += (loaded_time - start_time)
            if opt.adaptive:
                if iter_id % delta == 0:
                    u = 0
                    update = True
                    while update:
                        output, tmp_model_time = self.run_model(batch)
                        total_model_time += tmp_model_time
                        # score the current predictions on every pass
                        acc = metric.get_score(batch, output, u)
                        print(acc)
                        if u < umax and acc < a_thresh:
                            update_time = self.update_model(batch)
                        else:
                            update = False
                        u += 1
                    if acc > a_thresh:
                        delta = min(delta_max, 2 * delta)
                    else:
                        delta = max(delta_min, delta / 2)
                    # run model with the new weights
                    output, _ = self.run_model(batch)
                    model_time = total_model_time / u
                    update_lst += [(iter_id, u)]
                    acc_lst += [(iter_id, acc)]
                    self.accum_coco.store_metric_coco(iter_id, batch, output, opt)
                else:
                    update_lst += [(iter_id, 0)]
                    output, model_time = self.run_model(batch)
                    if opt.acc_collect and (iter_id % opt.acc_interval == 0):
                        acc = metric.get_score(batch, output, 0)
                        print(acc)
                        acc_lst += [(iter_id, acc)]
                        self.accum_coco.store_metric_coco(iter_id, batch,
                                                          output, opt)
            else:
                output, model_time = self.run_model(batch)
                if opt.acc_collect:
                    acc = metric.get_score(batch, output, 0, is_baseline=True)
                    print(acc)
                    acc_lst += [(iter_id, acc)]
                    self.accum_coco.store_metric_coco(iter_id, batch, output,
                                                      opt, is_baseline=True)
            display_start = time.time()
            if opt.tracking:
                # TODO: factor this into the other class
                trackers, viz_pred = self.tracking(batch, output, iter_id)
                out_pred.write(viz_pred)
            elif opt.save_video:
                pred, gt = self.debug(batch, output, iter_id)
                out_pred.write(pred)
                out_gt.write(gt)
            if opt.debug > 1:
                self.debug(batch, output, iter_id)
            display_end = time.time()
            display_time = (display_end - display_start)
            end_time = time.time()
            tot_time = (end_time - start_time)
            # progress bar bookkeeping
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.display_timing:
                time_str = ('total {:.3f}s| load {:.3f}s | model_time {:.3f}s '
                            '| update_time {:.3f}s | display {:.3f}s').format(
                    tot_time, load_time, model_time, update_time, display_time)
                print(time_str)
            self.save_result(output, batch, results)
            del output
            iter_id += 1
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        save_dict = {}
        if opt.adaptive:
            plt.scatter(*zip(*update_lst))
            plt.xlabel('iteration')
            plt.ylabel('number of updates')
            plt.savefig(opt.save_dir + '/update_frequency.png')
            save_dict['updates'] = update_lst
            plt.clf()
        if opt.acc_collect:
            plt.scatter(*zip(*acc_lst))
            plt.xlabel('iteration')
            plt.ylabel('mAP')
            plt.savefig(opt.save_dir + '/acc_figure.png')
            save_dict['acc'] = acc_lst
        if opt.adaptive and opt.acc_collect:
            x, y = zip(*filter(lambda x: x[1] > 0, update_lst))
            plt.scatter(x, y, c='r', marker='o')
            plt.xlabel('iteration')
        # save dict
        # gt_dict = self.accum_coco.get_gt()
        # dt_dict = self.accum_coco.get_dt()
        dt_dict = self.accum_coco_det.get_dt()
        # save_dict['gt_dict'] = gt_dict
        # save_dict['dt_dict'] = dt_dict
        save_dict['full_res_pred'] = dt_dict
        return ret, results, save_dict

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def train_model(self, batch):
        raise NotImplementedError

    def run_model(self, batch):
        raise NotImplementedError

    def update_model(self, batch):
        raise NotImplementedError

    def train(self, epoch, data_loader, aug_loader=None):
        return self.run_epoch('train', epoch, data_loader, aug_loader)
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None, logger=None,
                 lr_scheduler=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)
        # FIXME: vis interval
        self.original_debug = opt.debug
        self.vis_interval = opt.show_intervals
        if logger is not None:
            self.logger = logger
        self.lr_scheduler = lr_scheduler

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            # if len(self.opt.gpus) > 1:
            #     model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            # torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            # step the scheduler on the global iteration index
            cur_step = (epoch - 1) * num_iters + iter_id
            self.lr_scheduler.step(cur_step)
            cur_lr = None
            for param_group in self.optimizer.param_groups:
                cur_lr = param_group['lr']
                break
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                # FIXME: key changed from 'meta' to 'meta_c' ...
                # if k != 'meta':
                if 'meta' not in k:
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f}({:.4f}) '.format(
                    l, avg_loss_stats[l].val, avg_loss_stats[l].avg)
            # log after the meters are updated so the current values are written
            if self.logger is not None:
                for l in avg_loss_stats:
                    if avg_loss_stats[l].val > 10:
                        self.logger.scalar_summary(
                            '{}_loss_large/{}'.format(phase, l),
                            avg_loss_stats[l].val, cur_step)
                    else:
                        self.logger.scalar_summary(
                            '{}_loss/{}'.format(phase, l),
                            avg_loss_stats[l].val, cur_step)
                if phase == 'train':
                    self.logger.scalar_summary('lr', cur_lr, cur_step)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                if self.vis_interval != 0:
                    if iter_id % self.vis_interval == 0:
                        self.debug(batch, output, cur_step)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
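# The variant above steps self.lr_scheduler once per iteration with a global
# step index. A minimal stand-in (an assumption, not the source's scheduler)
# that supports this calling convention; note that passing an explicit index
# to scheduler.step() only works on older PyTorch releases:
import torch


def make_warmup_scheduler(optimizer, warmup_steps=1000):
    # linear warmup to the base LR, then constant
    def lr_lambda(step):
        return min(1.0, (step + 1) / float(warmup_steps))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)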
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader, **kwargs):
        logger = kwargs.get('logger')
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        tb_writer = kwargs['tb_writer']
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {
            l: AverageMeter() for l in self.loss_stats
            if not l.startswith('img_id') and not l.startswith('batch_')
        }
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        print('num_iters:', num_iters)
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            global_step = num_iters * (epoch - 1) + iter_id
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch, global_step,
                                                       tb_writer)
            tb_time = time.time()
            if opt.equal_loss and phase == 'train':
                # re-weight the master/other GPU losses by their chunk sizes
                loss = (loss[0] * opt.master_batch_size +
                        loss[1] * (opt.batch_size - opt.master_batch_size)
                        ) / opt.batch_size
            else:
                loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                if tb_writer is not None:
                    tb_writer.add_scalar(l, loss_stats[l].mean().item(),
                                         global_step=global_step,
                                         walltime=tb_time)
                _n = batch['input'].size(0)
                avg_loss_stats[l].update(loss_stats[l].mean().item(), _n)
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0 and opt.debug < 5:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        if tb_writer is not None:
            for l in avg_loss_stats:
                tb_writer.add_scalar('epoch_avg/{}'.format(l),
                                     avg_loss_stats[l].avg,
                                     global_step=global_step,
                                     walltime=time.time())
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader, **kwargs):
        return self.run_epoch('val', epoch, data_loader, **kwargs)

    def train(self, epoch, data_loader, **kwargs):
        return self.run_epoch('train', epoch, data_loader, **kwargs)

    def wta_stat(self, epoch, data_loader):
        # NOTE: only runs forward passes; per-batch stats collection is left
        # unimplemented and an empty list is returned
        stats = []
        model_with_loss = self.model_with_loss
        if len(self.opt.gpus) > 1:
            model_with_loss = self.model_with_loss.module
        model_with_loss.eval()
        torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            global_step = num_iters * (epoch - 1) + iter_id
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            _, _, loss_stats = model_with_loss(batch, global_step, None)
            batch_time.update(time.time() - end)
            end = time.time()
        return stats
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch, fsm=self.opt.fsm)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                if l not in loss_stats:
                    continue
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def find_lr(self, epoch, data_loader, lr_start=1e-5, lr_end=1e0, beta=0.98):
        # LR range test: sweep the learning rate exponentially from lr_start
        # to lr_end over one pass and record the smoothed loss
        model_with_loss = self.model_with_loss
        model_with_loss.train()
        opt = self.opt
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        lr_mul = (lr_end / lr_start) ** (1. / (num_iters - 1))
        lrs = []
        losses = []
        avg_loss = 0
        best_loss = 1e9
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr_start
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            e2e = bool(self.opt.e2e)
            output, loss, loss_stats = model_with_loss(batch, e2e=e2e)
            loss = loss.mean()
            # exponentially smoothed, bias-corrected loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.data
            smoothed_loss = avg_loss / (1 - beta ** (iter_id + 1))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            if smoothed_loss < best_loss:
                best_loss = smoothed_loss
            lrs.append(self.optimizer.param_groups[0]['lr'])
            losses.append(float(smoothed_loss))
            # stop the sweep once the loss clearly diverges
            if smoothed_loss > 4 * best_loss and iter_id > 0:
                break
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * lr_mul
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase='train',
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                if l not in loss_stats:
                    continue
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            Bar.suffix = Bar.suffix + '|lr {:.6f}'.format(lrs[-1])
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                self.debug(batch, output, iter_id)
            del output, loss, loss_stats
        plt.figure()
        plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1e0]),
                   (1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1e0))
        plt.xlabel('learning rate')
        plt.ylabel('loss')
        plt.plot(np.log(lrs), losses)
        # plt.show()
        plt.savefig(opt.exp_id + '_lr_loss.png')
        sys.exit()

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader, find_lr=False):
        if find_lr:
            self.find_lr(epoch, data_loader)
        return self.run_epoch('train', epoch, data_loader)
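# Usage note for the LR range test above (Smith-style find-lr sweep): run one
# pass with exponentially growing LR and inspect the saved plot before real
# training; find_lr calls sys.exit() once the plot is written.
#
#   trainer = BaseTrainer(opt, model, optimizer)
#   trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
#   trainer.train(1, train_loader, find_lr=True)  # saves <exp_id>_lr_loss.png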
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)
        self.reconstruct_img = False

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def save_tensor_to_img(self, tensors, filenames, path):
        batch_size, channel, w, h = tensors.size()
        reconstruct_imgs = tensors.detach().cpu().numpy().transpose(0, 2, 3, 1)
        for ind in range(batch_size):
            save_path = os.path.join(path, filenames[ind])
            cv2.imwrite(save_path, reconstruct_imgs[ind])

    def run_epoch(self, phase, epoch, data_loader, logger=None):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            if self.reconstruct_img:
                file_path = '/data/mry/code/CenterNet/debug_conflict_bt_class_recon'
                self.save_tensor_to_img(output['reconstruct_img'],
                                        batch['meta']['file_name'], file_path)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                # KL_loss may be absent (None) in some iterations
                if l == 'KL_loss':
                    if loss_stats[l] is not None:
                        avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                                 batch['input'].size(0))
                        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                            l, avg_loss_stats[l].avg)
                    else:
                        avg_loss_stats[l].update(0, batch['input'].size(0))
                    continue
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if logger and iter_id % opt.logger_iteration == 0:
                logger.write_iteration(
                    '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                        epoch, iter_id, num_iters, phase=phase,
                        total=bar.elapsed_td, eta=bar.eta_td))
                for l in avg_loss_stats:
                    if loss_stats[l] is None:
                        continue
                    # the meters were already updated above; log the running average
                    logger.write_iteration('|{} {:.4f} '.format(
                        l, avg_loss_stats[l].avg))
                    logger.scalar_summary('train_iteration_{}'.format(l),
                                          avg_loss_stats[l].avg,
                                          (epoch - 1) * num_iters + iter_id)
                logger.write_iteration('\n')
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.debug > 0:
                self.debug(batch, output, iter_id)
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader, logger):
        return self.run_epoch('train', epoch, data_loader, logger=logger)
class BaseTrainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)
        # also optimize any learnable parameters of the loss itself
        self.optimizer.add_param_group({'params': self.loss.parameters()})

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net time: {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.test:
                self.save_result(output, batch, results)
            del output, loss, loss_stats, batch
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
class CtTrainer(object):
    def __init__(self, opt, model, optimizer=None, dataloader=None):
        # super(CtTrainer, self).__init__(opt, model, optimizer=optimizer,
        #                                 dataloader=dataloader)
        self.opt = opt
        self.model = model
        # self.criterion = criterion
        self.optimizer = optimizer
        self.dataloader = dataloader
        self.iterations = 0
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss)
        self.num_iters = len(self.dataloader)
        self.stats = {}
        self.phase = "train"
        self.plugin_queues = {
            'iteration': [],
            'epoch': [],
            'batch': [],
            'update': [],
        }

    def _get_losses(self, opt):
        loss_states = ['loss', 'hm_loss', 'wh_loss', 'off_loss']
        loss = CtdetLoss(opt)
        return loss_states, loss

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def register_plugin(self, plugin):
        # register a plugin; its trigger interval usually looks like
        # [(1, 'iteration'), (1, 'epoch')]
        plugin.register(self)
        intervals = plugin.trigger_interval
        if not isinstance(intervals, list):
            intervals = [intervals]
        for duration, unit in intervals:
            # `unit` is the event category that triggers the plugin
            queue = self.plugin_queues[unit]
            # Queue the event. `duration` is the trigger point (which
            # iteration or epoch fires the plugin) and is advanced each time
            # the plugin runs. `len(queue)` acts as a priority (smaller is
            # higher) that breaks ties between events with the same duration,
            # in insertion order.
            queue.append((duration, len(queue), plugin))

    def call_plugins(self, queue_name, time, *args):
        # dispatch plugins; `time` is a count such as the iteration or epoch index
        args = (time,) + args
        queue = self.plugin_queues[queue_name]
        if len(queue) == 0:
            return
        while queue[0][0] <= time:
            # fire every queued event whose trigger time has been reached
            plugin = queue[0][2]
            # Call the method named after the queue, so a plugin subclass must
            # implement at least one of iteration, batch, epoch or update,
            # with exactly that name.
            getattr(plugin, queue_name)(*args)
            for trigger in plugin.trigger_interval:
                if trigger[1] == queue_name:
                    interval = trigger[0]
                    # reschedule the event according to the plugin's trigger interval
                    new_item = (time + interval, queue[0][1], plugin)
                    # push the rescheduled event and pop the heap head;
                    # the min-heap reorders itself
                    heapq.heappushpop(queue, new_item)

    def run(self, epochs=1):
        # heapify each of the four event queues before running
        for q in self.plugin_queues.values():
            heapq.heapify(q)
        for i in range(1, epochs + 1):
            self.train()
            # per-epoch plugin update
            self.call_plugins('epoch', i)

    def train(self):
        model_with_loss = self.model_with_loss
        if self.phase == 'train':
            model_with_loss.train()
        for iter_id, batch in enumerate(self.dataloader):
            batch_input = batch['input']
            batch_hm = batch['hm']
            print("batch hm is: ", batch_hm)
            batch_reg_mask, batch_ind, batch_wh, batch_reg = batch[
                'reg_mask'], batch['ind'], batch['wh'], batch['reg']
            # self.call_plugins('batch', iter_id, batch_input, batch_hm,
            #                   batch_reg_mask, batch_ind, batch_wh, batch_reg)
            self.call_plugins('batch', iter_id, batch_input, batch_hm)
            if iter_id >= self.num_iters:
                break
            # data_time.update(time.time() - end)
            # cache data for downstream plugins: network output and loss
            # plugin_data = [None, None, None, None]
            plugin_data = [None, None]
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=self.opt.device,
                                           non_blocking=True)
            batch_output, loss, loss_stats = self.model_with_loss(batch)
            # print("batch_output is", batch_output)
            loss = loss.mean()
            print('loss is,', loss)
            if plugin_data[0] is None:
                plugin_data[0] = batch_output['hm'].data
                # plugin_data[1] = batch_output['wh'].data
                # plugin_data[2] = batch_output['reg'].data
                plugin_data[1] = loss.data
            if self.phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            print("current loss is: ", loss)
            # fixed: unpack the cached plugin data as extra arguments
            self.call_plugins('iteration', iter_id, batch_input, batch_hm,
                              *plugin_data)
            # self.call_plugins('iteration', iter_id, batch_input, batch_hm,
            #                   batch_reg_mask, batch_ind, batch_wh, batch_reg,
            #                   *plugin_data)
            self.call_plugins('update', iter_id, self.model)
        # Legacy closure-based loop kept for reference:
        # for i, data in enumerate(self.dataloader, self.iterations + 1):
        #     batch_input, batch_target = data
        #     # fire 'batch' plugins right after fetching the batch data
        #     self.call_plugins('batch', i, batch_input, batch_target)
        #     input_var = batch_input
        #     target_var = batch_target
        #     # cache data for downstream plugins: network output and loss
        #     plugin_data = [None, None]
        #
        #     def closure():
        #         batch_output = self.model(input_var)
        #         loss = self.criterion(batch_output, target_var)
        #         loss.backward()
        #         if plugin_data[0] is None:
        #             plugin_data[0] = batch_output.data
        #             plugin_data[1] = loss.data
        #         return loss
        #
        #     self.optimizer.zero_grad()
        #     self.optimizer.step(closure)
        #     self.call_plugins('iteration', i, batch_input, batch_target,
        #                       *plugin_data)
        #     self.call_plugins('update', i, self.model)
        # self.iterations += i
        # fixed: `i` came from the legacy loop above; advance the global
        # counter by the number of batches actually processed
        self.iterations += iter_id + 1
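# The plugin protocol above requires a trigger_interval attribute and a method
# named after each queue the plugin registers for. A hypothetical example
# plugin (not from the source) that logs the loss every 10 iterations:
class LossPrinter(object):
    trigger_interval = [(10, 'iteration')]  # fire every 10 iterations

    def register(self, trainer):
        self.trainer = trainer

    def iteration(self, time, batch_input, batch_hm, hm_pred, loss):
        # `time` is the iteration index forwarded by call_plugins
        print('iter {}: loss {:.4f}'.format(time, float(loss)))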
class BaseTrainer(object):
    def __init__(self, opt, model_image, model_point, optimizer_image,
                 optimizer_point):
        self.opt = opt
        self.optimizer_image = optimizer_image
        self.optimizer_point = optimizer_point
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_image = model_image
        self.model_point = ModelWithLoss(model_image, model_point,
                                         self.loss)  # model_point

    def set_device(self, gpus, chunk_sizes, device, model_select):
        if len(gpus) > 1:
            if model_select == 'image':
                self.model_image = DataParallel(
                    self.model_image, device_ids=gpus,
                    chunk_sizes=chunk_sizes).to(device)
            elif model_select == 'point':
                self.model_point = DataParallel(
                    self.model_point, device_ids=gpus,
                    chunk_sizes=chunk_sizes).to(device)
            else:
                print("unsupported model")
        else:
            if model_select == 'image':
                self.model_image = self.model_image.to(device)
            elif model_select == 'point':
                self.model_point = self.model_point.to(device)
            else:
                print("unsupported model")
        if model_select == 'image':
            for state in self.optimizer_image.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=device, non_blocking=True)
        elif model_select == 'point':
            for state in self.optimizer_point.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=device, non_blocking=True)
        else:
            print("unsupported model")

    def run_epoch(self, phase, epoch, data_loader):
        model_image = self.model_image
        model_point = self.model_point
        if phase == 'train':
            model_image.train()
            model_point.train()
        else:
            model_image.eval()
            model_point.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format("3D detection", opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            # first pass: joint forward with the current estimates
            loss, loss_stats, next_est = model_point(batch, phase, epoch=epoch,
                                                     opt=self.opt)
            next_est = next_est.view(batch['reg_mask'].size(0),
                                     batch['reg_mask'].size(1), -1).detach()
            loss = loss.mean()
            if phase == 'train':
                self.optimizer_image.zero_grad()
                self.optimizer_point.zero_grad()
                loss.backward()
                self.optimizer_image.step()
                self.optimizer_point.step()
            # second pass: feed the refined position/dimension/orientation
            # estimates back into the batch and optimize again
            batch['pos_est'] = next_est[:, :, :3]
            batch['dim_est'] = next_est[:, :, 3:6]
            ry = next_est[:, :, 6]
            # build yaw rotation matrices from the estimated heading angle
            R_yaw = batch['pos_est'].new_zeros(next_est.size(0),
                                               next_est.size(1), 3, 3)
            R_yaw[:, :, 0, 0] = torch.cos(ry)
            R_yaw[:, :, 0, 2] = torch.sin(ry)
            R_yaw[:, :, 1, 1] = 1
            R_yaw[:, :, 2, 0] = -torch.sin(ry)
            R_yaw[:, :, 2, 2] = torch.cos(ry)
            batch['ori_est'] = R_yaw
            batch['ori_est_scalar'] = ry
            loss, loss_stats, next_est = model_point(batch, phase, epoch=epoch,
                                                     opt=self.opt)
            loss = 2 * loss.mean()
            # guard added: the second update should also only run in training
            if phase == 'train':
                self.optimizer_image.zero_grad()
                self.optimizer_point.zero_grad()
                loss.backward()
                self.optimizer_image.step()
                self.optimizer_point.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            del loss, loss_stats
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def debug(self, batch, output, iter_id):
        raise NotImplementedError

    def save_result(self, output, batch, results):
        raise NotImplementedError

    def _get_losses(self, opt):
        raise NotImplementedError

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
class Hoidet(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        loss = HoidetLoss(opt)
        self.loss_states = [
            'loss', 'hm_loss', 'wh_loss', 'off_loss', 'hm_rel_loss',
            'sub_offset_loss', 'obj_offset_loss'
        ]
        self.model_with_loss = ModelWithLoss(model, loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, model_with_loss, epoch, data_loader, phase='train'):
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_states = {l: AverageMeter() for l in self.loss_states}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_states = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_states:
                avg_loss_states[l].update(loss_states[l].mean().item(),
                                          batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_states[l].avg)
            if not opt.hide_data_time:
                Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                    '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            del output, loss, loss_states
        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_states.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def train(self, epoch, data_loader):
        model_with_loss = self.model_with_loss
        model_with_loss.train()
        ret, results = self.run_epoch(model_with_loss, epoch, data_loader)
        return ret, results

    def val(self, epoch, data_loader):
        model_with_loss = self.model_with_loss
        model_with_loss.eval()
        torch.cuda.empty_cache()
        with torch.no_grad():  # fixed: torch.no_grad must be called
            ret, results = self.run_epoch(model_with_loss, epoch, data_loader,
                                          phase='val')
        return ret, results