class Trainer(object):
  def __init__(self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModleWithLoss(model, self.loss)

  def set_device(self, gpus, chunk_sizes, device):
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
          self.model_with_loss, device_ids=gpus,
          chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader):
    model_with_loss = self.model_with_loss
    if phase == "train":
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {
        l: AverageMeter() for l in self.loss_stats
        if l == "tot" or opt.weights[l] > 0}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar("{}/{}".format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      for k in batch:
        if k != "meta":
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)

      output, loss, loss_stats = model_with_loss(batch)
      loss = loss.mean()
      if phase == "train":
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = "{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} ".format(
          epoch, iter_id, num_iters, phase=phase,
          total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
            loss_stats[l].mean().item(), batch["image"].size(0))
        Bar.suffix = Bar.suffix + "|{} {:.4f} ".format(l, avg_loss_stats[l].avg)
      Bar.suffix = Bar.suffix + (
          "|Data {dt.val:.3f}s({dt.avg:.3f}s) "
          "|Net {bt.avg:.3f}s".format(dt=data_time, bt=batch_time))
      if opt.print_iter > 0:  # If not using progress bar
        if iter_id % opt.print_iter == 0:
          print("{}/{}| {}".format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()
      if opt.debug > 0:
        self.debug(batch, output, iter_id, dataset=data_loader.dataset)
      del output, loss, loss_stats

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret["time"] = bar.elapsed_td.total_seconds() / 60.0
    return ret, results

  def _get_losses(self, opt):
    loss_order = [
        "hm", "wh", "reg", "ltrb", "hps", "hm_hp", "hp_offset", "dep",
        "dim", "rot", "amodel_offset", "ltrb_amodal", "tracking",
        "nuscenes_att", "velocity"]
    loss_states = ["tot"] + [k for k in loss_order if k in opt.heads]
    loss = GenericLoss(opt)
    return loss_states, loss

  def debug(self, batch, output, iter_id, dataset):
    opt = self.opt
    if "pre_hm" in batch:
      output.update({"pre_hm": batch["pre_hm"]})
    dets = generic_decode(output, K=opt.K, opt=opt)
    for k in dets:
      dets[k] = dets[k].detach().cpu().numpy()
    dets_gt = batch["meta"]["gt_det"]
    for i in range(1):
      debugger = Debugger(opt=opt, dataset=dataset)
      img = batch["image"][i].detach().cpu().numpy().transpose(1, 2, 0)
      img = np.clip(
          ((img * dataset.std + dataset.mean) * 255.0), 0, 255).astype(np.uint8)
      pred = debugger.gen_colormap(output["hm"][i].detach().cpu().numpy())
      gt = debugger.gen_colormap(batch["hm"][i].detach().cpu().numpy())
      debugger.add_blend_img(img, pred, "pred_hm")
      debugger.add_blend_img(img, gt, "gt_hm")

      if "pre_img" in batch:
        pre_img = batch["pre_img"][i].detach().cpu().numpy().transpose(1, 2, 0)
        pre_img = np.clip(
            ((pre_img * dataset.std + dataset.mean) * 255), 0, 255).astype(np.uint8)
        debugger.add_img(pre_img, "pre_img_pred")
        debugger.add_img(pre_img, "pre_img_gt")
        if "pre_hm" in batch:
          pre_hm = debugger.gen_colormap(
              batch["pre_hm"][i].detach().cpu().numpy())
          debugger.add_blend_img(pre_img, pre_hm, "pre_hm")

      debugger.add_img(img, img_id="out_pred")
      if "ltrb_amodal" in opt.heads:
        debugger.add_img(img, img_id="out_pred_amodal")
        debugger.add_img(img, img_id="out_gt_amodal")

      # Predictions
      for k in range(len(dets["scores"][i])):
        if dets["scores"][i, k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets["bboxes"][i, k] * opt.down_ratio, dets["clses"][i, k],
              dets["scores"][i, k], img_id="out_pred")
          if "ltrb_amodal" in opt.heads:
            debugger.add_coco_bbox(
                dets["bboxes_amodal"][i, k] * opt.down_ratio,
                dets["clses"][i, k], dets["scores"][i, k],
                img_id="out_pred_amodal")
          if "hps" in opt.heads and int(dets["clses"][i, k]) == 0:
            debugger.add_coco_hp(
                dets["hps"][i, k] * opt.down_ratio, img_id="out_pred")
          if "tracking" in opt.heads:
            debugger.add_arrow(
                dets["cts"][i][k] * opt.down_ratio,
                dets["tracking"][i][k] * opt.down_ratio, img_id="out_pred")
            debugger.add_arrow(
                dets["cts"][i][k] * opt.down_ratio,
                dets["tracking"][i][k] * opt.down_ratio, img_id="pre_img_pred")

      # Ground truth
      debugger.add_img(img, img_id="out_gt")
      for k in range(len(dets_gt["scores"][i])):
        if dets_gt["scores"][i][k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets_gt["bboxes"][i][k] * opt.down_ratio, dets_gt["clses"][i][k],
              dets_gt["scores"][i][k], img_id="out_gt")
          if "ltrb_amodal" in opt.heads:
            debugger.add_coco_bbox(
                dets_gt["bboxes_amodal"][i, k] * opt.down_ratio,
                dets_gt["clses"][i, k], dets_gt["scores"][i, k],
                img_id="out_gt_amodal")
          if "hps" in opt.heads and int(dets["clses"][i, k]) == 0:
            debugger.add_coco_hp(
                dets_gt["hps"][i][k] * opt.down_ratio, img_id="out_gt")
          if "tracking" in opt.heads:
            debugger.add_arrow(
                dets_gt["cts"][i][k] * opt.down_ratio,
                dets_gt["tracking"][i][k] * opt.down_ratio, img_id="out_gt")
            debugger.add_arrow(
                dets_gt["cts"][i][k] * opt.down_ratio,
                dets_gt["tracking"][i][k] * opt.down_ratio, img_id="pre_img_gt")

      if "hm_hp" in opt.heads:
        pred = debugger.gen_colormap_hp(
            output["hm_hp"][i].detach().cpu().numpy())
        gt = debugger.gen_colormap_hp(batch["hm_hp"][i].detach().cpu().numpy())
        debugger.add_blend_img(img, pred, "pred_hmhp")
        debugger.add_blend_img(img, gt, "gt_hmhp")

      if "rot" in opt.heads and "dim" in opt.heads and "dep" in opt.heads:
        dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
        calib = (batch["meta"]["calib"].detach().numpy()
                 if "calib" in batch["meta"] else None)
        det_pred = generic_post_process(
            opt, dets, batch["meta"]["c"].cpu().numpy(),
            batch["meta"]["s"].cpu().numpy(), output["hm"].shape[2],
            output["hm"].shape[3], self.opt.num_classes, calib)
        det_gt = generic_post_process(
            opt, dets_gt, batch["meta"]["c"].cpu().numpy(),
            batch["meta"]["s"].cpu().numpy(), output["hm"].shape[2],
            output["hm"].shape[3], self.opt.num_classes, calib)
        debugger.add_3d_detection(
            batch["meta"]["img_path"][i], batch["meta"]["flipped"][i],
            det_pred[i], calib[i], vis_thresh=opt.vis_thresh, img_id="add_pred")
        debugger.add_3d_detection(
            batch["meta"]["img_path"][i], batch["meta"]["flipped"][i],
            det_gt[i], calib[i], vis_thresh=opt.vis_thresh, img_id="add_gt")
        debugger.add_bird_views(
            det_pred[i], det_gt[i], vis_thresh=opt.vis_thresh,
            img_id="bird_pred_gt")

      if opt.debug == 4:
        debugger.save_all_imgs(opt.debug_dir, prefix="{}".format(iter_id))
      else:
        debugger.show_all_imgs(pause=True)

  def val(self, epoch, data_loader):
    return self.run_epoch("val", epoch, data_loader)

  def train(self, epoch, data_loader):
    return self.run_epoch("train", epoch, data_loader)
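

# --- Illustrative usage (not part of the original trainers) -----------------
# A minimal sketch of how a training script typically drives a Trainer like
# the one above. The option fields used here (num_epochs, val_intervals,
# chunk_sizes) are assumptions standing in for whatever the surrounding
# project's opts actually provide.
def _example_training_loop(opt, model, optimizer, train_loader, val_loader):
  trainer = Trainer(opt, model, optimizer)
  trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
  for epoch in range(1, opt.num_epochs + 1):
    # run_epoch returns (averaged loss stats per head, per-image results)
    log_train, _ = trainer.train(epoch, train_loader)
    if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
      with torch.no_grad():
        log_val, _ = trainer.val(epoch, val_loader)
  return trainer
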
class Trainer_Custom(object):
  def __init__(self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModleWithLoss(model, self.loss)

  def set_device(self, gpus, chunk_sizes, device):
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
          self.model_with_loss, device_ids=gpus,
          chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def _get_pre_fts(self, pre_imgs, pre_cts_fix, pre_len):
    # Shapes:
    #   batch['image']:       N*3*h*w
    #   batch['pre_imgs']:    N*T*3*H*W
    #   batch['pre_cts_fix']: N*T*Om*2
    #   batch['pre_len']:     N*T
    #   hidden_ft:            (N*T)*C*(H/4)*(W/4)
    #   returned features:    N*T*Om*C
    assert len(pre_imgs.shape) == 5 \
        and len(pre_cts_fix.shape) == 4 \
        and len(pre_len.shape) == 2
    assert pre_imgs.shape[0] == pre_cts_fix.shape[0] \
        and pre_len.shape[0] == pre_cts_fix.shape[0]
    N = pre_imgs.shape[0]
    assert pre_imgs.shape[1] == pre_cts_fix.shape[1] \
        and pre_len.shape[1] == pre_cts_fix.shape[1]
    T = pre_imgs.shape[1]
    assert pre_imgs.shape[2] == 3 \
        and pre_cts_fix.shape[3] == 2 and pre_cts_fix.shape[2] == self.opt.K
    Om = self.opt.K
    origin_H = pre_imgs.shape[3]
    origin_W = pre_imgs.shape[4]
    ft_H = int(origin_H / 4)
    ft_W = int(origin_W / 4)

    if self.opt.Freeze_ft:
      # Frozen backbone: extract features without building a graph.
      with torch.no_grad():
        assert pre_imgs.is_contiguous()
        pre_imgs_viewed = pre_imgs.view(N * T, 3, origin_H, origin_W)
        hidden_ft = self.model_with_loss.model.img2feats(pre_imgs_viewed)[-1]
        assert hidden_ft.shape[0] == N * T \
            and hidden_ft.shape[2] == int(origin_H / 4) \
            and hidden_ft.shape[3] == int(origin_W / 4)
        Channel_O = hidden_ft.shape[1]
        assert pre_cts_fix.is_contiguous()
        pre_cts_fix_viewed = pre_cts_fix.view(N * T, Om, 2)
        assert pre_len.is_contiguous()
        pre_len_viewed = pre_len.view(N * T)
        ret_ft = torch.zeros((N * T, Om, Channel_O))
        ret_ft = ret_ft.to(device=self.opt.device, non_blocking=True)
        for iterator_i in range(0, N * T):
          this_len = pre_len_viewed[iterator_i].detach().cpu().numpy()
          for iterator_o in range(0, this_len):
            H_o = pre_cts_fix_viewed[iterator_i, iterator_o, 1].detach().cpu().numpy()
            W_o = pre_cts_fix_viewed[iterator_i, iterator_o, 0].detach().cpu().numpy()
            ft_o = hidden_ft[
                iterator_i, :,
                int(np.rint(H_o)) if int(np.rint(H_o)) < ft_H else ft_H - 1,
                int(np.rint(W_o)) if int(np.rint(W_o)) < ft_W else ft_W - 1]
            ret_ft[iterator_i, iterator_o, :] = ft_o
        assert ret_ft.is_contiguous()
        ret = ret_ft.view(N, T, Om, Channel_O).detach().requires_grad_(False)
        del ret_ft
        del pre_imgs
        del hidden_ft
        del pre_cts_fix_viewed
        del pre_imgs_viewed
        torch.cuda.empty_cache()
    else:
      assert pre_imgs.is_contiguous()
      pre_imgs_viewed = pre_imgs.view(N * T, 3, origin_H, origin_W)
      hidden_ft = self.model_with_loss.model.img2feats(pre_imgs_viewed)[-1]
      assert hidden_ft.shape[0] == N * T \
          and hidden_ft.shape[2] == int(origin_H / 4) \
          and hidden_ft.shape[3] == int(origin_W / 4)
      Channel_O = hidden_ft.shape[1]
      assert pre_cts_fix.is_contiguous()
      pre_cts_fix_viewed = pre_cts_fix.view(N * T, Om, 2)
      assert pre_len.is_contiguous()
      pre_len_viewed = pre_len.view(N * T)
      ret_ft = torch.zeros((N * T, Om, Channel_O))
      ret_ft = ret_ft.to(device=self.opt.device, non_blocking=True)
      for iterator_i in range(0, N * T):
        this_len = pre_len_viewed[iterator_i].detach().cpu().numpy()
        for iterator_o in range(0, this_len):
          H_o = pre_cts_fix_viewed[iterator_i, iterator_o, 1].detach().cpu().numpy()
          W_o = pre_cts_fix_viewed[iterator_i, iterator_o, 0].detach().cpu().numpy()
          ft_o = hidden_ft[
              iterator_i, :,
              int(np.rint(H_o)) if int(np.rint(H_o)) < ft_H else ft_H - 1,
              int(np.rint(W_o)) if int(np.rint(W_o)) < ft_W else ft_W - 1]
          ret_ft[iterator_i, iterator_o, :] = ft_o
      assert ret_ft.is_contiguous()
      ret = ret_ft.view(N, T, Om, Channel_O)
    return ret

  def run_epoch(self, phase, epoch, data_loader):
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats
                      if l == 'tot' or opt.weights[l] > 0}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)

      assert batch['image'].shape[0] == batch['pre_imgs'].shape[0]
      # batch['image']: N*3*h*w, batch['pre_imgs']: N*T*3*H*W,
      # batch['pre_cts_fix']: N*T*Om*2, batch['pre_len']: N*T,
      # pr_object_ft: N*T*Om*C
      pr_object_ft = self._get_pre_fts(
          batch['pre_imgs'], batch['pre_cts_fix'], batch['pre_len'])

      output, loss, loss_stats = model_with_loss(batch, pr_object_ft)
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
          epoch, iter_id, num_iters, phase=phase,
          total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
            loss_stats[l].mean().item(), batch['image'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:  # If not using progress bar
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()
      if opt.debug > 0:
        self.debug(batch, output, iter_id, dataset=data_loader.dataset)
      del output, loss, loss_stats

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results

  def _get_losses(self, opt):
    loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', 'hp_offset',
                  'dep', 'dim', 'rot', 'amodel_offset', 'ltrb_amodal',
                  'tracking', 'nuscenes_att', 'velocity']
    loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
    loss = GenericLoss(opt)
    return loss_states, loss

  def debug(self, batch, output, iter_id, dataset):
    opt = self.opt
    if 'pre_hm' in batch:
      output.update({'pre_hm': batch['pre_hm']})
    dets = generic_decode(output, K=opt.K, opt=opt)
    for k in dets:
      dets[k] = dets[k].detach().cpu().numpy()
    dets_gt = batch['meta']['gt_det']
    for i in range(1):
      debugger = Debugger(opt=opt, dataset=dataset)
      img = batch['image'][i].detach().cpu().numpy().transpose(1, 2, 0)
      img = np.clip(
          ((img * dataset.std + dataset.mean) * 255.), 0, 255).astype(np.uint8)
      pred = debugger.gen_colormap(output['hm'][i].detach().cpu().numpy())
      gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
      debugger.add_blend_img(img, pred, 'pred_hm')
      debugger.add_blend_img(img, gt, 'gt_hm')

      if 'pre_img' in batch:
        pre_img = batch['pre_img'][i].detach().cpu().numpy().transpose(1, 2, 0)
        pre_img = np.clip(
            ((pre_img * dataset.std + dataset.mean) * 255), 0, 255).astype(np.uint8)
        debugger.add_img(pre_img, 'pre_img_pred')
        debugger.add_img(pre_img, 'pre_img_gt')
        if 'pre_hm' in batch:
          pre_hm = debugger.gen_colormap(
              batch['pre_hm'][i].detach().cpu().numpy())
          debugger.add_blend_img(pre_img, pre_hm, 'pre_hm')

      debugger.add_img(img, img_id='out_pred')
      if 'ltrb_amodal' in opt.heads:
        debugger.add_img(img, img_id='out_pred_amodal')
        debugger.add_img(img, img_id='out_gt_amodal')

      # Predictions
      for k in range(len(dets['scores'][i])):
        if dets['scores'][i, k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets['bboxes'][i, k] * opt.down_ratio, dets['clses'][i, k],
              dets['scores'][i, k], img_id='out_pred')
          if 'ltrb_amodal' in opt.heads:
            debugger.add_coco_bbox(
                dets['bboxes_amodal'][i, k] * opt.down_ratio,
                dets['clses'][i, k], dets['scores'][i, k],
                img_id='out_pred_amodal')
          if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
            debugger.add_coco_hp(
                dets['hps'][i, k] * opt.down_ratio, img_id='out_pred')
          if 'tracking' in opt.heads:
            debugger.add_arrow(
                dets['cts'][i][k] * opt.down_ratio,
                dets['tracking'][i][k] * opt.down_ratio, img_id='out_pred')
            debugger.add_arrow(
                dets['cts'][i][k] * opt.down_ratio,
                dets['tracking'][i][k] * opt.down_ratio, img_id='pre_img_pred')

      # Ground truth
      debugger.add_img(img, img_id='out_gt')
      for k in range(len(dets_gt['scores'][i])):
        if dets_gt['scores'][i][k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets_gt['bboxes'][i][k] * opt.down_ratio, dets_gt['clses'][i][k],
              dets_gt['scores'][i][k], img_id='out_gt')
          if 'ltrb_amodal' in opt.heads:
            debugger.add_coco_bbox(
                dets_gt['bboxes_amodal'][i, k] * opt.down_ratio,
                dets_gt['clses'][i, k], dets_gt['scores'][i, k],
                img_id='out_gt_amodal')
          if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
            debugger.add_coco_hp(
                dets_gt['hps'][i][k] * opt.down_ratio, img_id='out_gt')
          if 'tracking' in opt.heads:
            debugger.add_arrow(
                dets_gt['cts'][i][k] * opt.down_ratio,
                dets_gt['tracking'][i][k] * opt.down_ratio, img_id='out_gt')
            debugger.add_arrow(
                dets_gt['cts'][i][k] * opt.down_ratio,
                dets_gt['tracking'][i][k] * opt.down_ratio, img_id='pre_img_gt')

      if 'hm_hp' in opt.heads:
        pred = debugger.gen_colormap_hp(
            output['hm_hp'][i].detach().cpu().numpy())
        gt = debugger.gen_colormap_hp(batch['hm_hp'][i].detach().cpu().numpy())
        debugger.add_blend_img(img, pred, 'pred_hmhp')
        debugger.add_blend_img(img, gt, 'gt_hmhp')

      if 'rot' in opt.heads and 'dim' in opt.heads and 'dep' in opt.heads:
        dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
        calib = batch['meta']['calib'].detach().numpy() \
            if 'calib' in batch['meta'] else None
        det_pred = generic_post_process(
            opt, dets, batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
            output['hm'].shape[3], self.opt.num_classes, calib)
        det_gt = generic_post_process(
            opt, dets_gt, batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
            output['hm'].shape[3], self.opt.num_classes, calib)
        debugger.add_3d_detection(
            batch['meta']['img_path'][i], batch['meta']['flipped'][i],
            det_pred[i], calib[i], vis_thresh=opt.vis_thresh, img_id='add_pred')
        debugger.add_3d_detection(
            batch['meta']['img_path'][i], batch['meta']['flipped'][i],
            det_gt[i], calib[i], vis_thresh=opt.vis_thresh, img_id='add_gt')
        debugger.add_bird_views(
            det_pred[i], det_gt[i], vis_thresh=opt.vis_thresh,
            img_id='bird_pred_gt')

      if opt.debug == 4:
        debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
      else:
        debugger.show_all_imgs(pause=True)

  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader):
    return self.run_epoch('train', epoch, data_loader)
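

# --- Illustrative sketch (not part of the original file) --------------------
# Trainer_Custom._get_pre_fts above reads one feature vector per object center
# with two nested Python loops. The helper below sketches a vectorized
# equivalent using advanced indexing; it is an assumption-level illustration,
# not the code the class actually runs, and it relies on the module-level
# torch import already used by the trainers.
def _example_gather_center_features(hidden_ft, pre_cts, pre_len):
  # hidden_ft: (B, C, H, W) backbone features, with B = N * T
  # pre_cts:   (B, Om, 2) object centers as (x, y) in feature-map coordinates
  # pre_len:   (B,) number of valid objects per sample
  # returns:   (B, Om, C) per-object features, zero-filled for padded slots
  B, C, H, W = hidden_ft.shape
  Om = pre_cts.shape[1]
  # round centers and clamp them into the feature map (the original loop only
  # clamps at the upper edge; clamping both edges here is a simplification)
  xs = pre_cts[..., 0].round().long().clamp(0, W - 1)    # (B, Om)
  ys = pre_cts[..., 1].round().long().clamp(0, H - 1)    # (B, Om)
  batch_idx = torch.arange(B, device=hidden_ft.device).unsqueeze(1).expand(B, Om)
  feats = hidden_ft[batch_idx, :, ys, xs]                 # (B, Om, C)
  # mask out padded object slots beyond each sample's pre_len
  valid = (torch.arange(Om, device=hidden_ft.device).unsqueeze(0)
           < pre_len.unsqueeze(1)).unsqueeze(-1)          # (B, Om, 1)
  return feats * valid.to(feats.dtype)
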
class Trainer(object):
  def __init__(self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    self.model_with_loss = ModelWithLoss(model, self.loss, opt)

  def set_device(self, gpus, chunk_sizes, device):
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
          self.model_with_loss, device_ids=gpus,
          chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader):
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats
                      if l == 'tot' or opt.weights[l] > 0}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)

      # run one iteration
      output, loss, loss_stats = model_with_loss(batch, phase)

      # backpropagate and step optimizer
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
          epoch, iter_id, num_iters, phase=phase,
          total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
            loss_stats[l].mean().item(), batch['image'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:  # If not using progress bar
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()

      if opt.debug > 0:
        self.debug(batch, output, iter_id, dataset=data_loader.dataset)

      # generate detections for evaluation
      if phase == 'val' and (opt.run_dataset_eval or opt.eval):
        meta = batch['meta']
        dets = fusion_decode(output, K=opt.K, opt=opt)
        for k in dets:
          dets[k] = dets[k].detach().cpu().numpy()
        calib = meta['calib'].detach().numpy() if 'calib' in meta else None
        dets = generic_post_process(
            opt, dets, meta['c'].cpu().numpy(), meta['s'].cpu().numpy(),
            output['hm'].shape[2], output['hm'].shape[3],
            self.opt.num_classes, calib)

        # merge results
        result = []
        for i in range(len(dets[0])):
          if dets[0][i]['score'] > self.opt.out_thresh and \
              all(dets[0][i]['dim'] > 0):
            result.append(dets[0][i])
        img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]
        results[img_id] = result

      del output, loss, loss_stats

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results

  def _get_losses(self, opt):
    loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', 'hp_offset',
                  'dep', 'dep_sec', 'dim', 'rot', 'rot_sec', 'amodel_offset',
                  'ltrb_amodal', 'tracking', 'nuscenes_att', 'velocity']
    loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
    loss = GenericLoss(opt)
    return loss_states, loss

  def debug(self, batch, output, iter_id, dataset):
    opt = self.opt
    if 'pre_hm' in batch:
      output.update({'pre_hm': batch['pre_hm']})
    dets = fusion_decode(output, K=opt.K, opt=opt)
    for k in dets:
      dets[k] = dets[k].detach().cpu().numpy()
    dets_gt = batch['meta']['gt_det']
    for i in range(1):
      debugger = Debugger(opt=opt, dataset=dataset)
      img = batch['image'][i].detach().cpu().numpy().transpose(1, 2, 0)
      img = np.clip(
          ((img * dataset.std + dataset.mean) * 255.), 0, 255).astype(np.uint8)
      pred = debugger.gen_colormap(output['hm'][i].detach().cpu().numpy())
      gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
      debugger.add_blend_img(img, pred, 'pred_hm',
                             trans=self.opt.hm_transparency)
      debugger.add_blend_img(img, gt, 'gt_hm',
                             trans=self.opt.hm_transparency)
      debugger.add_img(img, img_id='img')

      # show point clouds
      if opt.pointcloud:
        pc_2d = batch['pc_2d'][i].detach().cpu().numpy()
        pc_3d = None
        pc_N = batch['pc_N'][i].detach().cpu().numpy()
        debugger.add_img(img, img_id='pc')
        debugger.add_pointcloud(pc_2d, pc_N, img_id='pc')

        if 'pc_hm' in opt.pc_feat_lvl:
          channel = opt.pc_feat_channels['pc_hm']
          pc_hm = debugger.gen_colormap(
              batch['pc_hm'][i][channel].unsqueeze(0).detach().cpu().numpy())
          debugger.add_blend_img(img, pc_hm, 'pc_hm',
                                 trans=self.opt.hm_transparency)
        if 'pc_dep' in opt.pc_feat_lvl:
          channel = opt.pc_feat_channels['pc_dep']
          pc_hm = batch['pc_hm'][i][channel].unsqueeze(0).detach().cpu().numpy()
          pc_dep = debugger.add_overlay_img(img, pc_hm, 'pc_dep')

      if 'pre_img' in batch:
        pre_img = batch['pre_img'][i].detach().cpu().numpy().transpose(1, 2, 0)
        pre_img = np.clip(
            ((pre_img * dataset.std + dataset.mean) * 255), 0, 255).astype(np.uint8)
        debugger.add_img(pre_img, 'pre_img_pred')
        debugger.add_img(pre_img, 'pre_img_gt')
        if 'pre_hm' in batch:
          pre_hm = debugger.gen_colormap(
              batch['pre_hm'][i].detach().cpu().numpy())
          debugger.add_blend_img(pre_img, pre_hm, 'pre_hm',
                                 trans=self.opt.hm_transparency)

      debugger.add_img(img, img_id='out_pred')
      if 'ltrb_amodal' in opt.heads:
        debugger.add_img(img, img_id='out_pred_amodal')
        debugger.add_img(img, img_id='out_gt_amodal')

      # Predictions
      for k in range(len(dets['scores'][i])):
        if dets['scores'][i, k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets['bboxes'][i, k] * opt.down_ratio, dets['clses'][i, k],
              dets['scores'][i, k], img_id='out_pred')
          if 'ltrb_amodal' in opt.heads:
            debugger.add_coco_bbox(
                dets['bboxes_amodal'][i, k] * opt.down_ratio,
                dets['clses'][i, k], dets['scores'][i, k],
                img_id='out_pred_amodal')
          if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
            debugger.add_coco_hp(
                dets['hps'][i, k] * opt.down_ratio, img_id='out_pred')
          if 'tracking' in opt.heads:
            debugger.add_arrow(
                dets['cts'][i][k] * opt.down_ratio,
                dets['tracking'][i][k] * opt.down_ratio, img_id='out_pred')
            debugger.add_arrow(
                dets['cts'][i][k] * opt.down_ratio,
                dets['tracking'][i][k] * opt.down_ratio, img_id='pre_img_pred')

      # Ground truth
      debugger.add_img(img, img_id='out_gt')
      for k in range(len(dets_gt['scores'][i])):
        if dets_gt['scores'][i][k] > opt.vis_thresh:
          if 'dep' in dets_gt.keys():
            dist = dets_gt['dep'][i][k]
            if len(dist) > 1:
              dist = dist[0]
          else:
            dist = -1
          debugger.add_coco_bbox(
              dets_gt['bboxes'][i][k] * opt.down_ratio, dets_gt['clses'][i][k],
              dets_gt['scores'][i][k], img_id='out_gt', dist=dist)
          if 'ltrb_amodal' in opt.heads:
            debugger.add_coco_bbox(
                dets_gt['bboxes_amodal'][i, k] * opt.down_ratio,
                dets_gt['clses'][i, k], dets_gt['scores'][i, k],
                img_id='out_gt_amodal')
          if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
            debugger.add_coco_hp(
                dets_gt['hps'][i][k] * opt.down_ratio, img_id='out_gt')
          if 'tracking' in opt.heads:
            debugger.add_arrow(
                dets_gt['cts'][i][k] * opt.down_ratio,
                dets_gt['tracking'][i][k] * opt.down_ratio, img_id='out_gt')
            debugger.add_arrow(
                dets_gt['cts'][i][k] * opt.down_ratio,
                dets_gt['tracking'][i][k] * opt.down_ratio, img_id='pre_img_gt')

      if 'hm_hp' in opt.heads:
        pred = debugger.gen_colormap_hp(
            output['hm_hp'][i].detach().cpu().numpy())
        gt = debugger.gen_colormap_hp(batch['hm_hp'][i].detach().cpu().numpy())
        debugger.add_blend_img(img, pred, 'pred_hmhp',
                               trans=self.opt.hm_transparency)
        debugger.add_blend_img(img, gt, 'gt_hmhp',
                               trans=self.opt.hm_transparency)

      if 'rot' in opt.heads and 'dim' in opt.heads and 'dep' in opt.heads:
        dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
        calib = batch['meta']['calib'].detach().numpy() \
            if 'calib' in batch['meta'] else None
        det_pred = generic_post_process(
            opt, dets, batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
            output['hm'].shape[3], self.opt.num_classes, calib)
        det_gt = generic_post_process(
            opt, dets_gt, batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
            output['hm'].shape[3], self.opt.num_classes, calib, is_gt=True)
        debugger.add_3d_detection(
            batch['meta']['img_path'][i], batch['meta']['flipped'][i],
            det_pred[i], calib[i], vis_thresh=opt.vis_thresh, img_id='add_pred')
        debugger.add_3d_detection(
            batch['meta']['img_path'][i], batch['meta']['flipped'][i],
            det_gt[i], calib[i], vis_thresh=opt.vis_thresh, img_id='add_gt')
        pc_3d = None
        if opt.pointcloud:
          pc_3d = batch['pc_3d'].cpu().numpy()
        debugger.add_bird_views(
            det_pred[i], det_gt[i], vis_thresh=opt.vis_thresh,
            img_id='bird_pred_gt', pc_3d=pc_3d,
            show_velocity=opt.show_velocity)
        debugger.add_bird_views(
            [], det_gt[i], vis_thresh=opt.vis_thresh, img_id='bird_gt',
            pc_3d=pc_3d, show_velocity=opt.show_velocity)

      if opt.debug == 4:
        debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
      else:
        debugger.show_all_imgs(pause=True)

  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader):
    return self.run_epoch('train', epoch, data_loader)
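

# --- Illustrative helper (not part of the original trainer) -----------------
# The val branch of run_epoch above keeps a detection only when its score
# beats opt.out_thresh and all of its 3D dimensions are positive. A minimal
# standalone sketch of that filter (the trainer inlines the logic):
def _example_filter_detections(dets_for_image, out_thresh):
  kept = []
  for det in dets_for_image:
    if det['score'] > out_thresh and np.all(np.asarray(det['dim']) > 0):
      kept.append(det)
  return kept
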
class Trainer(object):
  def __init__(self, opt, model, optimizer=None, logger=None):
    self.opt = opt
    self.optimizer = optimizer
    param = list(model.neck.parameters())[-2]
    self.loss_stats, self.loss = self._get_losses(opt, logger, param)
    if opt.pad_net:
      len_stat = len(self.loss_stats)
      for state_n in range(len_stat):
        self.loss_stats.append(self.loss_stats[state_n] + '_inter')
      self.loss_stats.append('tot_2nd')
    self.model_with_loss = ModleWithLoss(model, self.loss, opt)
    self.old_norm = 0

  def set_device(self, gpus, chunk_sizes, device):
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
          self.model_with_loss, device_ids=gpus,
          chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader):
    model_with_loss = self.model_with_loss
    if len(self.opt.gpus) > 1:
      if self.opt.weight_strategy == 'GRADNORM':
        model_with_loss.module.loss.optimizer = get_loss_optimizer(
            model=model_with_loss.module.loss.loss_model, opt=self.opt)
      model_with_loss.module.loss.update_weight(epoch)
    else:
      if self.opt.weight_strategy == 'GRADNORM':
        model_with_loss.loss.optimizer = get_loss_optimizer(
            model=model_with_loss.loss.loss_model, opt=self.opt)
      model_with_loss.loss.update_weight(epoch)

    if phase == 'train':
      model_with_loss.train()
    else:
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    # unlike the other trainers, every loss stat is tracked here
    # (no `opt.weights[l] > 0` filter)
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)

      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)

      output, loss, loss_stats = model_with_loss(batch, epoch)
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        if opt.weight_strategy == 'GRADNORM':
          loss.backward(retain_graph=True)
          if len(self.opt.gpus) > 1:
            model_with_loss.module.loss.optimizer.zero_grad()
            temp_grad = torch.autograd.grad(
                torch.sum(loss_stats['update']),
                model_with_loss.module.loss.loss_model.weight)[0]
            grad_norm = torch.norm(temp_grad.data, 1)
            print(grad_norm)
            if grad_norm > opt.gradnorm_thred:
              temp_grad = torch.zeros_like(temp_grad)
            model_with_loss.module.loss.loss_model.weight.grad = temp_grad
            model_with_loss.module.loss.optimizer.step()
          else:
            model_with_loss.loss.optimizer.zero_grad()
            temp_grad = torch.autograd.grad(
                loss_stats['update'],
                model_with_loss.loss.loss_model.weight)[0]
            grad_norm = torch.norm(temp_grad.data, 1)
            if grad_norm > opt.gradnorm_thred:
              temp_grad = torch.zeros_like(temp_grad)
            model_with_loss.loss.loss_model.weight.grad = temp_grad
            model_with_loss.loss.optimizer.step()
        else:
          loss.backward()
        self.optimizer.step()
        if opt.weight_strategy == 'UNCER':
          if len(self.opt.gpus) > 1:
            model_with_loss.module.loss.optimizer.step()
            model_with_loss.module.loss.optimizer.zero_grad()
            print(model_with_loss.module.loss.group_weight)
          else:
            model_with_loss.loss.optimizer.step()
            model_with_loss.loss.optimizer.zero_grad()
            print(model_with_loss.loss.group_weight)
      batch_time.update(time.time() - end)
      end = time.time()

      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
          epoch, iter_id, num_iters, phase=phase,
          total=bar.elapsed_td, eta=bar.eta_td)
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
            loss_stats[l].mean().item(), batch['image'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:  # If not using progress bar
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()
      if opt.debug > 0:
        self.debug(batch, output, iter_id, dataset=data_loader.dataset)
      del output, loss, loss_stats

    bar.finish()
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    if len(self.opt.gpus) > 1:
      model_with_loss.module.loss.update_loss(epoch, ret)
    else:
      model_with_loss.loss.update_loss(epoch, ret)
    return ret, results

  def _get_losses(self, opt, logger=None, param=None):
    loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', 'hp_offset',
                  'dep', 'dim', 'rot', 'amodel_offset', 'ltrb_amodal',
                  'tracking', 'nuscenes_att', 'velocity']
    loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
    loss = LossWithStrategy(opt, logger, param)
    return loss_states, loss

  def debug(self, batch, output, iter_id, dataset):
    opt = self.opt
    if 'pre_hm' in batch:
      output.update({'pre_hm': batch['pre_hm']})
    dets = generic_decode(output, K=opt.K, opt=opt)
    for k in dets:
      dets[k] = dets[k].detach().cpu().numpy()
    dets_gt = batch['meta']['gt_det']
    for i in range(1):
      debugger = Debugger(opt=opt, dataset=dataset)
      img = batch['image'][i].detach().cpu().numpy().transpose(1, 2, 0)
      img = np.clip(
          ((img * dataset.std + dataset.mean) * 255.), 0, 255).astype(np.uint8)
      pred = debugger.gen_colormap(output['hm'][i].detach().cpu().numpy())
      gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
      debugger.add_blend_img(img, pred, 'pred_hm')
      debugger.add_blend_img(img, gt, 'gt_hm')

      if 'pre_img' in batch:
        pre_img = batch['pre_img'][i].detach().cpu().numpy().transpose(1, 2, 0)
        pre_img = np.clip(
            ((pre_img * dataset.std + dataset.mean) * 255), 0, 255).astype(np.uint8)
        debugger.add_img(pre_img, 'pre_img_pred')
        debugger.add_img(pre_img, 'pre_img_gt')
        if 'pre_hm' in batch:
          pre_hm = debugger.gen_colormap(
              batch['pre_hm'][i].detach().cpu().numpy())
          debugger.add_blend_img(pre_img, pre_hm, 'pre_hm')

      debugger.add_img(img, img_id='out_pred')
      if 'ltrb_amodal' in opt.heads:
        debugger.add_img(img, img_id='out_pred_amodal')
        debugger.add_img(img, img_id='out_gt_amodal')

      # Predictions
      for k in range(len(dets['scores'][i])):
        if dets['scores'][i, k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets['bboxes'][i, k] * opt.down_ratio, dets['clses'][i, k],
              dets['scores'][i, k], img_id='out_pred')
          if 'ltrb_amodal' in opt.heads:
            debugger.add_coco_bbox(
                dets['bboxes_amodal'][i, k] * opt.down_ratio,
                dets['clses'][i, k], dets['scores'][i, k],
                img_id='out_pred_amodal')
          if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
            debugger.add_coco_hp(
                dets['hps'][i, k] * opt.down_ratio, img_id='out_pred')
          if 'tracking' in opt.heads:
            debugger.add_arrow(
                dets['cts'][i][k] * opt.down_ratio,
                dets['tracking'][i][k] * opt.down_ratio, img_id='out_pred')
            debugger.add_arrow(
                dets['cts'][i][k] * opt.down_ratio,
                dets['tracking'][i][k] * opt.down_ratio, img_id='pre_img_pred')

      # Ground truth
      debugger.add_img(img, img_id='out_gt')
      for k in range(len(dets_gt['scores'][i])):
        if dets_gt['scores'][i][k] > opt.vis_thresh:
          debugger.add_coco_bbox(
              dets_gt['bboxes'][i][k] * opt.down_ratio, dets_gt['clses'][i][k],
              dets_gt['scores'][i][k], img_id='out_gt')
          if 'ltrb_amodal' in opt.heads:
            debugger.add_coco_bbox(
                dets_gt['bboxes_amodal'][i, k] * opt.down_ratio,
                dets_gt['clses'][i, k], dets_gt['scores'][i, k],
                img_id='out_gt_amodal')
          if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
            debugger.add_coco_hp(
                dets_gt['hps'][i][k] * opt.down_ratio, img_id='out_gt')
          if 'tracking' in opt.heads:
            debugger.add_arrow(
                dets_gt['cts'][i][k] * opt.down_ratio,
                dets_gt['tracking'][i][k] * opt.down_ratio, img_id='out_gt')
            debugger.add_arrow(
                dets_gt['cts'][i][k] * opt.down_ratio,
                dets_gt['tracking'][i][k] * opt.down_ratio, img_id='pre_img_gt')

      if 'hm_hp' in opt.heads:
        pred = debugger.gen_colormap_hp(
            output['hm_hp'][i].detach().cpu().numpy())
        gt = debugger.gen_colormap_hp(batch['hm_hp'][i].detach().cpu().numpy())
        debugger.add_blend_img(img, pred, 'pred_hmhp')
        debugger.add_blend_img(img, gt, 'gt_hmhp')

      if 'rot' in opt.heads and 'dim' in opt.heads and 'dep' in opt.heads:
        dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
        calib = batch['meta']['calib'].detach().numpy() \
            if 'calib' in batch['meta'] else None
        det_pred = generic_post_process(
            opt, dets, batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
            output['hm'].shape[3], self.opt.num_classes, calib)
        det_gt = generic_post_process(
            opt, dets_gt, batch['meta']['c'].cpu().numpy(),
            batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
            output['hm'].shape[3], self.opt.num_classes, calib)
        debugger.add_3d_detection(
            batch['meta']['img_path'][i], batch['meta']['flipped'][i],
            det_pred[i], calib[i], vis_thresh=opt.vis_thresh, img_id='add_pred')
        debugger.add_3d_detection(
            batch['meta']['img_path'][i], batch['meta']['flipped'][i],
            det_gt[i], calib[i], vis_thresh=opt.vis_thresh, img_id='add_gt')
        debugger.add_bird_views(
            det_pred[i], det_gt[i], vis_thresh=opt.vis_thresh,
            img_id='bird_pred_gt')

      if opt.debug == 4:
        debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
      else:
        debugger.show_all_imgs(pause=True)

  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader):
    return self.run_epoch('train', epoch, data_loader)
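

# --- Illustrative sketch (not part of the original file) --------------------
# The GRADNORM branch in run_epoch above updates the per-task loss weights
# with a clipped step: the gradient of the balancing loss w.r.t. the weight
# vector is zeroed out entirely whenever its L1 norm exceeds a threshold, so
# one noisy batch cannot swing the task weighting. A condensed standalone
# version of that update; argument names are assumptions (the trainer reaches
# the same tensors through model_with_loss.loss.loss_model.weight and
# model_with_loss.loss.optimizer).
def _example_clipped_weight_update(update_loss, weight, weight_optimizer,
                                   gradnorm_threshold):
  weight_optimizer.zero_grad()
  # gradient of the GradNorm balancing loss w.r.t. the task-weight parameter
  grad = torch.autograd.grad(update_loss.sum(), weight)[0]
  if torch.norm(grad.detach(), 1) > gradnorm_threshold:
    grad = torch.zeros_like(grad)  # drop the whole step for outlier batches
  weight.grad = grad
  weight_optimizer.step()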