Example #1
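The snippets on this page are reproduced without their import headers. Besides the standard time, torch, numpy (as np) and progress.bar.Bar, they rely on project-local helpers (a ModleWithLoss / ModelWithLoss wrapper, a DataParallel variant with chunk_sizes support, AverageMeter, GenericLoss, generic_decode, generic_post_process, Debugger, and, in the later examples, fusion_decode, LossWithStrategy and get_loss_optimizer) whose import paths depend on the surrounding repository and are therefore not shown here.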
class Trainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModleWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

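        # Move any cached optimizer state (e.g. Adam moment buffers) onto the
        # target device so it matches the (possibly data-parallel) model.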
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        print("yesyes ")
        model_with_loss = self.model_with_loss
        if phase == "train":
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {
            l: AverageMeter()
            for l in self.loss_stats if l == "tot" or opt.weights[l] > 0
        }
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar("{}/{}".format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                if k != "meta":
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == "train":
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = "{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} ".format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td,
            )
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch["image"].size(0))
                Bar.suffix = Bar.suffix + "|{} {:.4f} ".format(
                    l, avg_loss_stats[l].avg)
            Bar.suffix = (
                Bar.suffix + "|Data {dt.val:.3f}s({dt.avg:.3f}s) "
                "|Net {bt.avg:.3f}s".format(dt=data_time, bt=batch_time))
            if opt.print_iter > 0:  # If not using progress bar
                if iter_id % opt.print_iter == 0:
                    print("{}/{}| {}".format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id, dataset=data_loader.dataset)

            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret["time"] = bar.elapsed_td.total_seconds() / 60.0
        return ret, results

    def _get_losses(self, opt):
        loss_order = [
            "hm",
            "wh",
            "reg",
            "ltrb",
            "hps",
            "hm_hp",
            "hp_offset",
            "dep",
            "dim",
            "rot",
            "amodel_offset",
            "ltrb_amodal",
            "tracking",
            "nuscenes_att",
            "velocity",
        ]
        loss_states = ["tot"] + [k for k in loss_order if k in opt.heads]
        loss = GenericLoss(opt)
        return loss_states, loss

    def debug(self, batch, output, iter_id, dataset):
        opt = self.opt
        if "pre_hm" in batch:
            output.update({"pre_hm": batch["pre_hm"]})
        dets = generic_decode(output, K=opt.K, opt=opt)
        for k in dets:
            dets[k] = dets[k].detach().cpu().numpy()
        dets_gt = batch["meta"]["gt_det"]
        for i in range(1):  # only visualize the first image of the batch
            debugger = Debugger(opt=opt, dataset=dataset)
            img = batch["image"][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip(((img * dataset.std + dataset.mean) * 255.0), 0,
                          255).astype(np.uint8)
            pred = debugger.gen_colormap(
                output["hm"][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch["hm"][i].detach().cpu().numpy())
            debugger.add_blend_img(img, pred, "pred_hm")
            debugger.add_blend_img(img, gt, "gt_hm")

            if "pre_img" in batch:
                pre_img = batch["pre_img"][i].detach().cpu().numpy().transpose(
                    1, 2, 0)
                pre_img = np.clip(
                    ((pre_img * dataset.std + dataset.mean) * 255), 0,
                    255).astype(np.uint8)
                debugger.add_img(pre_img, "pre_img_pred")
                debugger.add_img(pre_img, "pre_img_gt")
                if "pre_hm" in batch:
                    pre_hm = debugger.gen_colormap(
                        batch["pre_hm"][i].detach().cpu().numpy())
                    debugger.add_blend_img(pre_img, pre_hm, "pre_hm")

            debugger.add_img(img, img_id="out_pred")
            if "ltrb_amodal" in opt.heads:
                debugger.add_img(img, img_id="out_pred_amodal")
                debugger.add_img(img, img_id="out_gt_amodal")

            # Predictions
            for k in range(len(dets["scores"][i])):
                if dets["scores"][i, k] > opt.vis_thresh:
                    debugger.add_coco_bbox(
                        dets["bboxes"][i, k] * opt.down_ratio,
                        dets["clses"][i, k],
                        dets["scores"][i, k],
                        img_id="out_pred",
                    )

                    if "ltrb_amodal" in opt.heads:
                        debugger.add_coco_bbox(
                            dets["bboxes_amodal"][i, k] * opt.down_ratio,
                            dets["clses"][i, k],
                            dets["scores"][i, k],
                            img_id="out_pred_amodal",
                        )

                    if "hps" in opt.heads and int(dets["clses"][i, k]) == 0:
                        debugger.add_coco_hp(dets["hps"][i, k] *
                                             opt.down_ratio,
                                             img_id="out_pred")

                    if "tracking" in opt.heads:
                        debugger.add_arrow(
                            dets["cts"][i][k] * opt.down_ratio,
                            dets["tracking"][i][k] * opt.down_ratio,
                            img_id="out_pred",
                        )
                        debugger.add_arrow(
                            dets["cts"][i][k] * opt.down_ratio,
                            dets["tracking"][i][k] * opt.down_ratio,
                            img_id="pre_img_pred",
                        )

            # Ground truth
            debugger.add_img(img, img_id="out_gt")
            for k in range(len(dets_gt["scores"][i])):
                if dets_gt["scores"][i][k] > opt.vis_thresh:
                    debugger.add_coco_bbox(
                        dets_gt["bboxes"][i][k] * opt.down_ratio,
                        dets_gt["clses"][i][k],
                        dets_gt["scores"][i][k],
                        img_id="out_gt",
                    )

                    if "ltrb_amodal" in opt.heads:
                        debugger.add_coco_bbox(
                            dets_gt["bboxes_amodal"][i, k] * opt.down_ratio,
                            dets_gt["clses"][i, k],
                            dets_gt["scores"][i, k],
                            img_id="out_gt_amodal",
                        )

                    if "hps" in opt.heads and (int(dets["clses"][i, k]) == 0):
                        debugger.add_coco_hp(dets_gt["hps"][i][k] *
                                             opt.down_ratio,
                                             img_id="out_gt")

                    if "tracking" in opt.heads:
                        debugger.add_arrow(
                            dets_gt["cts"][i][k] * opt.down_ratio,
                            dets_gt["tracking"][i][k] * opt.down_ratio,
                            img_id="out_gt",
                        )
                        debugger.add_arrow(
                            dets_gt["cts"][i][k] * opt.down_ratio,
                            dets_gt["tracking"][i][k] * opt.down_ratio,
                            img_id="pre_img_gt",
                        )

            if "hm_hp" in opt.heads:
                pred = debugger.gen_colormap_hp(
                    output["hm_hp"][i].detach().cpu().numpy())
                gt = debugger.gen_colormap_hp(
                    batch["hm_hp"][i].detach().cpu().numpy())
                debugger.add_blend_img(img, pred, "pred_hmhp")
                debugger.add_blend_img(img, gt, "gt_hmhp")

            if "rot" in opt.heads and "dim" in opt.heads and "dep" in opt.heads:
                dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
                calib = (batch["meta"]["calib"].detach().numpy()
                         if "calib" in batch["meta"] else None)
                det_pred = generic_post_process(
                    opt,
                    dets,
                    batch["meta"]["c"].cpu().numpy(),
                    batch["meta"]["s"].cpu().numpy(),
                    output["hm"].shape[2],
                    output["hm"].shape[3],
                    self.opt.num_classes,
                    calib,
                )
                det_gt = generic_post_process(
                    opt,
                    dets_gt,
                    batch["meta"]["c"].cpu().numpy(),
                    batch["meta"]["s"].cpu().numpy(),
                    output["hm"].shape[2],
                    output["hm"].shape[3],
                    self.opt.num_classes,
                    calib,
                )

                debugger.add_3d_detection(
                    batch["meta"]["img_path"][i],
                    batch["meta"]["flipped"][i],
                    det_pred[i],
                    calib[i],
                    vis_thresh=opt.vis_thresh,
                    img_id="add_pred",
                )
                debugger.add_3d_detection(
                    batch["meta"]["img_path"][i],
                    batch["meta"]["flipped"][i],
                    det_gt[i],
                    calib[i],
                    vis_thresh=opt.vis_thresh,
                    img_id="add_gt",
                )
                debugger.add_bird_views(
                    det_pred[i],
                    det_gt[i],
                    vis_thresh=opt.vis_thresh,
                    img_id="bird_pred_gt",
                )

            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir,
                                       prefix="{}".format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def val(self, epoch, data_loader):
        return self.run_epoch("val", epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch("train", epoch, data_loader)
Example #2
class Trainer_Custom(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModleWithLoss(model, self.loss)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def _get_pre_fts(self, pre_imgs, pre_cts_fix, pre_len):
        """Gather per-object backbone features from the previous frames.

        Shapes:
            batch['image']:                      N*3*h*w
            pre_imgs    (batch['pre_imgs']):     N*T*3*H*W
            pre_cts_fix (batch['pre_cts_fix']):  N*T*Om*2
            pre_len     (batch['pre_len']):      N*T
            hidden_ft   (backbone output):       N*T*C*H*W
            return value:                        N*T*Om*C
        """

        assert len(pre_imgs.shape) == 5 \
               and len(pre_cts_fix.shape) == 4 \
               and len(pre_len.shape) == 2

        assert pre_imgs.shape[0] == pre_cts_fix.shape[0] \
               and pre_len.shape[0] == pre_cts_fix.shape[0]
        N = pre_imgs.shape[0]
        assert pre_imgs.shape[1] == pre_cts_fix.shape[1] and pre_len.shape[
            1] == pre_cts_fix.shape[1]
        T = pre_imgs.shape[1]
        assert pre_imgs.shape[2] == 3 \
               and pre_cts_fix.shape[3] == 2 and pre_cts_fix.shape[2] == self.opt.K
        Om = self.opt.K
        origin_H = pre_imgs.shape[3]
        origin_W = pre_imgs.shape[4]
        ft_H = int(origin_H / 4)
        ft_W = int(origin_W / 4)

        if self.opt.Freeze_ft:
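            # Frozen-backbone path: run the backbone on the previous frames
            # without tracking gradients and detach the gathered features.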
            with torch.no_grad():
                assert pre_imgs.is_contiguous()
                pre_imgs_viewed = pre_imgs.view(N * T, 3, origin_H, origin_W)
                hidden_ft = self.model_with_loss.model.img2feats(
                    pre_imgs_viewed)[-1]

                assert hidden_ft.shape[0] == N*T \
                       and hidden_ft.shape[2] == int(origin_H/4) \
                       and hidden_ft.shape[3] == int(origin_W/4)
                Channel_O = hidden_ft.shape[1]
                assert pre_cts_fix.is_contiguous()
                pre_cts_fix_viewed = pre_cts_fix.view(N * T, Om, 2)
                assert pre_len.is_contiguous()
                pre_len_viewed = pre_len.view(N * T)
                ret_ft = torch.zeros((N * T, Om, Channel_O))
                ret_ft = ret_ft.to(device=self.opt.device, non_blocking=True)
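                # For each frame in the flattened N*T batch, read the backbone
                # feature at every object's rounded (clamped) center location;
                # slots beyond that frame's pre_len are left as zeros.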
                for iterator_i in range(0, N * T):
                    this_len = pre_len_viewed[iterator_i].detach().cpu().numpy(
                    )
                    for iterator_o in range(0, this_len):
                        H_o = pre_cts_fix_viewed[iterator_i, iterator_o,
                                                 1].detach().cpu().numpy()
                        W_o = pre_cts_fix_viewed[iterator_i, iterator_o,
                                                 0].detach().cpu().numpy()

                        h_idx = min(int(np.rint(H_o)), ft_H - 1)
                        w_idx = min(int(np.rint(W_o)), ft_W - 1)
                        ft_o = hidden_ft[iterator_i, :, h_idx, w_idx]
                        ret_ft[iterator_i, iterator_o, :] = ft_o

            assert ret_ft.is_contiguous()
            ret = ret_ft.view(N, T, Om,
                              Channel_O).detach().requires_grad_(False)
            del ret_ft
            del pre_imgs
            del hidden_ft
            del pre_cts_fix_viewed
            del pre_imgs_viewed
            torch.cuda.empty_cache()
        else:
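            # Same per-center gather as above, but gradients are kept so the
            # backbone is trained through the previous-frame features as well.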
            assert pre_imgs.is_contiguous()
            pre_imgs_viewed = pre_imgs.view(N * T, 3, origin_H, origin_W)
            hidden_ft = self.model_with_loss.model.img2feats(
                pre_imgs_viewed)[-1]
            assert hidden_ft.shape[0] == N*T \
                   and hidden_ft.shape[2] == int(origin_H/4) \
                   and hidden_ft.shape[3] == int(origin_W/4)
            Channel_O = hidden_ft.shape[1]
            assert pre_cts_fix.is_contiguous()
            pre_cts_fix_viewed = pre_cts_fix.view(N * T, Om, 2)
            assert pre_len.is_contiguous()
            pre_len_viewed = pre_len.view(N * T)
            ret_ft = torch.zeros((N * T, Om, Channel_O))
            ret_ft = ret_ft.to(device=self.opt.device, non_blocking=True)
            for iterator_i in range(0, N * T):
                this_len = pre_len_viewed[iterator_i].detach().cpu().numpy()
                for iterator_o in range(0, this_len):
                    H_o = pre_cts_fix_viewed[iterator_i, iterator_o,
                                             1].detach().cpu().numpy()
                    W_o = pre_cts_fix_viewed[iterator_i, iterator_o,
                                             0].detach().cpu().numpy()

                    h_idx = min(int(np.rint(H_o)), ft_H - 1)
                    w_idx = min(int(np.rint(W_o)), ft_W - 1)
                    ft_o = hidden_ft[iterator_i, :, h_idx, w_idx]
                    ret_ft[iterator_i, iterator_o, :] = ft_o

            assert ret_ft.is_contiguous()
            ret = ret_ft.view(N, T, Om, Channel_O)

        return ret
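        # Note: the Python loops above gather one backbone feature vector per
        # object center. An equivalent vectorized form (a sketch, assuming the
        # centers are already given in feature-map coordinates) would be:
        #
        #   xs = pre_cts_fix_viewed[..., 0].round().long().clamp(max=ft_W - 1)
        #   ys = pre_cts_fix_viewed[..., 1].round().long().clamp(max=ft_H - 1)
        #   idx = (ys * ft_W + xs).unsqueeze(1).expand(-1, Channel_O, -1)
        #   feats = hidden_ft.flatten(2).gather(2, idx).permute(0, 2, 1)
        #   mask = (torch.arange(Om, device=feats.device)[None, :]
        #           < pre_len_viewed[:, None])
        #   ret_ft = feats * mask.unsqueeze(-1)          # N*T x Om x C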

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \
                          if l == 'tot' or opt.weights[l] > 0}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):

            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            assert batch['image'].shape[0] == batch['pre_imgs'].shape[0]

            # batch['image']: N*3*h*w
            # batch['pre_imgs']: N*T*3*H*W
            # batch['pre_cts_fix']:N*T*Om*2
            # batch['pre_len']: N*T
            # pr_object_ft: N*T*Om*C

            pr_object_ft = self._get_pre_fts(batch['pre_imgs'],
                                             batch['pre_cts_fix'],
                                             batch['pre_len'])

            output, loss, loss_stats = model_with_loss(batch, pr_object_ft)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['image'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
              '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:  # If not using progress bar
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id, dataset=data_loader.dataset)

            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def _get_losses(self, opt):
        loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', \
          'hp_offset', 'dep', 'dim', 'rot', 'amodel_offset', \
          'ltrb_amodal', 'tracking', 'nuscenes_att', 'velocity']
        loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
        loss = GenericLoss(opt)
        return loss_states, loss

    def debug(self, batch, output, iter_id, dataset):
        opt = self.opt
        if 'pre_hm' in batch:
            output.update({'pre_hm': batch['pre_hm']})
        dets = generic_decode(output, K=opt.K, opt=opt)
        for k in dets:
            dets[k] = dets[k].detach().cpu().numpy()
        dets_gt = batch['meta']['gt_det']
        for i in range(1):
            debugger = Debugger(opt=opt, dataset=dataset)
            img = batch['image'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip(((img * dataset.std + dataset.mean) * 255.), 0,
                          255).astype(np.uint8)
            pred = debugger.gen_colormap(
                output['hm'][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
            debugger.add_blend_img(img, pred, 'pred_hm')
            debugger.add_blend_img(img, gt, 'gt_hm')

            if 'pre_img' in batch:
                pre_img = batch['pre_img'][i].detach().cpu().numpy().transpose(
                    1, 2, 0)
                pre_img = np.clip(
                    ((pre_img * dataset.std + dataset.mean) * 255), 0,
                    255).astype(np.uint8)
                debugger.add_img(pre_img, 'pre_img_pred')
                debugger.add_img(pre_img, 'pre_img_gt')
                if 'pre_hm' in batch:
                    pre_hm = debugger.gen_colormap(
                        batch['pre_hm'][i].detach().cpu().numpy())
                    debugger.add_blend_img(pre_img, pre_hm, 'pre_hm')

            debugger.add_img(img, img_id='out_pred')
            if 'ltrb_amodal' in opt.heads:
                debugger.add_img(img, img_id='out_pred_amodal')
                debugger.add_img(img, img_id='out_gt_amodal')

            # Predictions
            for k in range(len(dets['scores'][i])):
                if dets['scores'][i, k] > opt.vis_thresh:
                    debugger.add_coco_bbox(dets['bboxes'][i, k] *
                                           opt.down_ratio,
                                           dets['clses'][i, k],
                                           dets['scores'][i, k],
                                           img_id='out_pred')

                    if 'ltrb_amodal' in opt.heads:
                        debugger.add_coco_bbox(dets['bboxes_amodal'][i, k] *
                                               opt.down_ratio,
                                               dets['clses'][i, k],
                                               dets['scores'][i, k],
                                               img_id='out_pred_amodal')

                    if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
                        debugger.add_coco_hp(dets['hps'][i, k] *
                                             opt.down_ratio,
                                             img_id='out_pred')

                    if 'tracking' in opt.heads:
                        debugger.add_arrow(dets['cts'][i][k] * opt.down_ratio,
                                           dets['tracking'][i][k] *
                                           opt.down_ratio,
                                           img_id='out_pred')
                        debugger.add_arrow(dets['cts'][i][k] * opt.down_ratio,
                                           dets['tracking'][i][k] *
                                           opt.down_ratio,
                                           img_id='pre_img_pred')

            # Ground truth
            debugger.add_img(img, img_id='out_gt')
            for k in range(len(dets_gt['scores'][i])):
                if dets_gt['scores'][i][k] > opt.vis_thresh:
                    debugger.add_coco_bbox(dets_gt['bboxes'][i][k] *
                                           opt.down_ratio,
                                           dets_gt['clses'][i][k],
                                           dets_gt['scores'][i][k],
                                           img_id='out_gt')

                    if 'ltrb_amodal' in opt.heads:
                        debugger.add_coco_bbox(dets_gt['bboxes_amodal'][i, k] *
                                               opt.down_ratio,
                                               dets_gt['clses'][i, k],
                                               dets_gt['scores'][i, k],
                                               img_id='out_gt_amodal')

                    if 'hps' in opt.heads and \
                      (int(dets_gt['clses'][i][k]) == 0):
                        debugger.add_coco_hp(dets_gt['hps'][i][k] *
                                             opt.down_ratio,
                                             img_id='out_gt')

                    if 'tracking' in opt.heads:
                        debugger.add_arrow(
                            dets_gt['cts'][i][k] * opt.down_ratio,
                            dets_gt['tracking'][i][k] * opt.down_ratio,
                            img_id='out_gt')
                        debugger.add_arrow(
                            dets_gt['cts'][i][k] * opt.down_ratio,
                            dets_gt['tracking'][i][k] * opt.down_ratio,
                            img_id='pre_img_gt')

            if 'hm_hp' in opt.heads:
                pred = debugger.gen_colormap_hp(
                    output['hm_hp'][i].detach().cpu().numpy())
                gt = debugger.gen_colormap_hp(
                    batch['hm_hp'][i].detach().cpu().numpy())
                debugger.add_blend_img(img, pred, 'pred_hmhp')
                debugger.add_blend_img(img, gt, 'gt_hmhp')

            if 'rot' in opt.heads and 'dim' in opt.heads and 'dep' in opt.heads:
                dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
                calib = batch['meta']['calib'].detach().numpy() \
                        if 'calib' in batch['meta'] else None
                det_pred = generic_post_process(
                    opt, dets, batch['meta']['c'].cpu().numpy(),
                    batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
                    output['hm'].shape[3], self.opt.num_classes, calib)
                det_gt = generic_post_process(opt, dets_gt,
                                              batch['meta']['c'].cpu().numpy(),
                                              batch['meta']['s'].cpu().numpy(),
                                              output['hm'].shape[2],
                                              output['hm'].shape[3],
                                              self.opt.num_classes, calib)

                debugger.add_3d_detection(batch['meta']['img_path'][i],
                                          batch['meta']['flipped'][i],
                                          det_pred[i],
                                          calib[i],
                                          vis_thresh=opt.vis_thresh,
                                          img_id='add_pred')
                debugger.add_3d_detection(batch['meta']['img_path'][i],
                                          batch['meta']['flipped'][i],
                                          det_gt[i],
                                          calib[i],
                                          vis_thresh=opt.vis_thresh,
                                          img_id='add_gt')
                debugger.add_bird_views(det_pred[i],
                                        det_gt[i],
                                        vis_thresh=opt.vis_thresh,
                                        img_id='bird_pred_gt')

            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir,
                                       prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
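
A minimal driver loop for these Trainer classes, sketched under the assumption of a CenterTrack-style opt object (fields such as lr, gpus, chunk_sizes, device, batch_size, num_workers, num_epochs and val_intervals are assumptions and are not defined in the snippets above):

import torch

def fit(opt, model, train_dataset, val_dataset):
    # Build the optimizer and hand both to the Trainer defined above.
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    trainer = Trainer(opt, model, optimizer)
    trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=1)

    for epoch in range(1, opt.num_epochs + 1):
        # run_epoch returns (averaged loss stats, per-image results).
        log_train, _ = trainer.train(epoch, train_loader)
        if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
            with torch.no_grad():
                log_val, results = trainer.val(epoch, val_loader)
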
Example #3
class Trainer(object):
    def __init__(self, opt, model, optimizer=None):
        self.opt = opt
        self.optimizer = optimizer
        self.loss_stats, self.loss = self._get_losses(opt)
        self.model_with_loss = ModelWithLoss(model, self.loss, opt)

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \
                          if l == 'tot' or opt.weights[l] > 0}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()

        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device,
                                           non_blocking=True)

            # run one iteration
            output, loss, loss_stats = model_with_loss(batch, phase)

            # backpropagate and step optimizer
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['image'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                      '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:  # If not using progress bar
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id, dataset=data_loader.dataset)

            # generate detections for evaluation
            if (phase == 'val' and (opt.run_dataset_eval or opt.eval)):
                meta = batch['meta']
                dets = fusion_decode(output, K=opt.K, opt=opt)

                for k in dets:
                    dets[k] = dets[k].detach().cpu().numpy()

                calib = (meta['calib'].detach().numpy()
                         if 'calib' in meta else None)
                dets = generic_post_process(opt, dets, meta['c'].cpu().numpy(),
                                            meta['s'].cpu().numpy(),
                                            output['hm'].shape[2],
                                            output['hm'].shape[3],
                                            self.opt.num_classes, calib)

                # keep detections above the output threshold with valid (positive) 3D dimensions
                result = []
                for i in range(len(dets[0])):
                    if dets[0][i]['score'] > self.opt.out_thresh and all(
                            dets[0][i]['dim'] > 0):
                        result.append(dets[0][i])

                img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]
                results[img_id] = result

            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results

    def _get_losses(self, opt):
        loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', \
                      'hp_offset', 'dep', 'dep_sec', 'dim', 'rot', 'rot_sec',
                      'amodel_offset', 'ltrb_amodal', 'tracking', 'nuscenes_att', 'velocity']
        loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
        loss = GenericLoss(opt)
        return loss_states, loss

    def debug(self, batch, output, iter_id, dataset):
        opt = self.opt
        if 'pre_hm' in batch:
            output.update({'pre_hm': batch['pre_hm']})
        dets = fusion_decode(output, K=opt.K, opt=opt)
        for k in dets:
            dets[k] = dets[k].detach().cpu().numpy()
        dets_gt = batch['meta']['gt_det']
        for i in range(1):
            debugger = Debugger(opt=opt, dataset=dataset)
            img = batch['image'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip(((img * dataset.std + dataset.mean) * 255.), 0,
                          255).astype(np.uint8)
            pred = debugger.gen_colormap(
                output['hm'][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
            debugger.add_blend_img(img,
                                   pred,
                                   'pred_hm',
                                   trans=self.opt.hm_transparency)
            debugger.add_blend_img(img,
                                   gt,
                                   'gt_hm',
                                   trans=self.opt.hm_transparency)

            debugger.add_img(img, img_id='img')

            # show point clouds
            if opt.pointcloud:
                pc_2d = batch['pc_2d'][i].detach().cpu().numpy()
                pc_3d = None
                pc_N = batch['pc_N'][i].detach().cpu().numpy()
                debugger.add_img(img, img_id='pc')
                debugger.add_pointcloud(pc_2d, pc_N, img_id='pc')

                if 'pc_hm' in opt.pc_feat_lvl:
                    channel = opt.pc_feat_channels['pc_hm']
                    pc_hm = debugger.gen_colormap(
                        batch['pc_hm'][i][channel].unsqueeze(
                            0).detach().cpu().numpy())
                    debugger.add_blend_img(img,
                                           pc_hm,
                                           'pc_hm',
                                           trans=self.opt.hm_transparency)
                if 'pc_dep' in opt.pc_feat_lvl:
                    channel = opt.pc_feat_channels['pc_dep']
                    pc_hm = batch['pc_hm'][i][channel].unsqueeze(
                        0).detach().cpu().numpy()
                    pc_dep = debugger.add_overlay_img(img, pc_hm, 'pc_dep')

            if 'pre_img' in batch:
                pre_img = batch['pre_img'][i].detach().cpu().numpy().transpose(
                    1, 2, 0)
                pre_img = np.clip(
                    ((pre_img * dataset.std + dataset.mean) * 255), 0,
                    255).astype(np.uint8)
                debugger.add_img(pre_img, 'pre_img_pred')
                debugger.add_img(pre_img, 'pre_img_gt')
                if 'pre_hm' in batch:
                    pre_hm = debugger.gen_colormap(
                        batch['pre_hm'][i].detach().cpu().numpy())
                    debugger.add_blend_img(pre_img,
                                           pre_hm,
                                           'pre_hm',
                                           trans=self.opt.hm_transparency)

            debugger.add_img(img, img_id='out_pred')
            if 'ltrb_amodal' in opt.heads:
                debugger.add_img(img, img_id='out_pred_amodal')
                debugger.add_img(img, img_id='out_gt_amodal')

            # Predictions
            for k in range(len(dets['scores'][i])):
                if dets['scores'][i, k] > opt.vis_thresh:
                    debugger.add_coco_bbox(dets['bboxes'][i, k] *
                                           opt.down_ratio,
                                           dets['clses'][i, k],
                                           dets['scores'][i, k],
                                           img_id='out_pred')

                    if 'ltrb_amodal' in opt.heads:
                        debugger.add_coco_bbox(dets['bboxes_amodal'][i, k] *
                                               opt.down_ratio,
                                               dets['clses'][i, k],
                                               dets['scores'][i, k],
                                               img_id='out_pred_amodal')

                    if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
                        debugger.add_coco_hp(dets['hps'][i, k] *
                                             opt.down_ratio,
                                             img_id='out_pred')

                    if 'tracking' in opt.heads:
                        debugger.add_arrow(dets['cts'][i][k] * opt.down_ratio,
                                           dets['tracking'][i][k] *
                                           opt.down_ratio,
                                           img_id='out_pred')
                        debugger.add_arrow(dets['cts'][i][k] * opt.down_ratio,
                                           dets['tracking'][i][k] *
                                           opt.down_ratio,
                                           img_id='pre_img_pred')

            # Ground truth
            debugger.add_img(img, img_id='out_gt')
            for k in range(len(dets_gt['scores'][i])):
                if dets_gt['scores'][i][k] > opt.vis_thresh:
                    if 'dep' in dets_gt.keys():
                        dist = dets_gt['dep'][i][k]
                        if len(dist) > 1:
                            dist = dist[0]
                    else:
                        dist = -1
                    debugger.add_coco_bbox(dets_gt['bboxes'][i][k] *
                                           opt.down_ratio,
                                           dets_gt['clses'][i][k],
                                           dets_gt['scores'][i][k],
                                           img_id='out_gt',
                                           dist=dist)

                    if 'ltrb_amodal' in opt.heads:
                        debugger.add_coco_bbox(dets_gt['bboxes_amodal'][i, k] *
                                               opt.down_ratio,
                                               dets_gt['clses'][i, k],
                                               dets_gt['scores'][i, k],
                                               img_id='out_gt_amodal')

                    if 'hps' in opt.heads and \
                            (int(dets_gt['clses'][i][k]) == 0):
                        debugger.add_coco_hp(dets_gt['hps'][i][k] *
                                             opt.down_ratio,
                                             img_id='out_gt')

                    if 'tracking' in opt.heads:
                        debugger.add_arrow(
                            dets_gt['cts'][i][k] * opt.down_ratio,
                            dets_gt['tracking'][i][k] * opt.down_ratio,
                            img_id='out_gt')
                        debugger.add_arrow(
                            dets_gt['cts'][i][k] * opt.down_ratio,
                            dets_gt['tracking'][i][k] * opt.down_ratio,
                            img_id='pre_img_gt')

            if 'hm_hp' in opt.heads:
                pred = debugger.gen_colormap_hp(
                    output['hm_hp'][i].detach().cpu().numpy())
                gt = debugger.gen_colormap_hp(
                    batch['hm_hp'][i].detach().cpu().numpy())
                debugger.add_blend_img(img,
                                       pred,
                                       'pred_hmhp',
                                       trans=self.opt.hm_transparency)
                debugger.add_blend_img(img,
                                       gt,
                                       'gt_hmhp',
                                       trans=self.opt.hm_transparency)

            if 'rot' in opt.heads and 'dim' in opt.heads and 'dep' in opt.heads:
                dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
                calib = batch['meta']['calib'].detach().numpy() \
                    if 'calib' in batch['meta'] else None
                det_pred = generic_post_process(
                    opt, dets, batch['meta']['c'].cpu().numpy(),
                    batch['meta']['s'].cpu().numpy(), output['hm'].shape[2],
                    output['hm'].shape[3], self.opt.num_classes, calib)
                det_gt = generic_post_process(opt,
                                              dets_gt,
                                              batch['meta']['c'].cpu().numpy(),
                                              batch['meta']['s'].cpu().numpy(),
                                              output['hm'].shape[2],
                                              output['hm'].shape[3],
                                              self.opt.num_classes,
                                              calib,
                                              is_gt=True)

                debugger.add_3d_detection(batch['meta']['img_path'][i],
                                          batch['meta']['flipped'][i],
                                          det_pred[i],
                                          calib[i],
                                          vis_thresh=opt.vis_thresh,
                                          img_id='add_pred')
                debugger.add_3d_detection(batch['meta']['img_path'][i],
                                          batch['meta']['flipped'][i],
                                          det_gt[i],
                                          calib[i],
                                          vis_thresh=opt.vis_thresh,
                                          img_id='add_gt')

                pc_3d = None
                if opt.pointcloud:
                    pc_3d = batch['pc_3d'].cpu().numpy()

                debugger.add_bird_views(det_pred[i],
                                        det_gt[i],
                                        vis_thresh=opt.vis_thresh,
                                        img_id='bird_pred_gt',
                                        pc_3d=pc_3d,
                                        show_velocity=opt.show_velocity)
                debugger.add_bird_views([],
                                        det_gt[i],
                                        vis_thresh=opt.vis_thresh,
                                        img_id='bird_gt',
                                        pc_3d=pc_3d,
                                        show_velocity=opt.show_velocity)

            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir,
                                       prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)
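
In this variant, the val phase also fills the results dict with post-processed, score-thresholded detections keyed by image id. In CenterTrack/CenterFusion-style training scripts such a dict is typically handed to the dataset's evaluation helper, for example (the exact run_eval signature is an assumption):

    ret, results = trainer.val(epoch, val_loader)
    val_loader.dataset.run_eval(results, opt.save_dir)
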
Example #4
class Trainer(object):
    def __init__(
            self, opt, model, optimizer=None, logger=None):
        self.opt = opt
        self.optimizer = optimizer
        param = list(model.neck.parameters())[-2]
        self.loss_stats, self.loss = self._get_losses(opt, logger, param)
        if opt.pad_net:
            len_stat = len(self.loss_stats)
            for state_n in range(len_stat):
                self.loss_stats.append(self.loss_stats[state_n] + "_inter")
            self.loss_stats.append('tot_2nd')
        self.model_with_loss = ModleWithLoss(model, self.loss, opt)
        self.old_norm = 0

    def set_device(self, gpus, chunk_sizes, device):
        if len(gpus) > 1:
            self.model_with_loss = DataParallel(
                self.model_with_loss, device_ids=gpus,
                chunk_sizes=chunk_sizes).to(device)
        else:
            self.model_with_loss = self.model_with_loss.to(device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
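        # At the start of every epoch, refresh the loss-weighting strategy:
        # rebuild the GradNorm optimizer if that strategy is active and let the
        # loss module update its per-task weights.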
        if len(self.opt.gpus) > 1:
            if self.opt.weight_strategy == 'GRADNORM':
                model_with_loss.module.loss.optimizer = get_loss_optimizer(model=model_with_loss.module.loss.loss_model,
                                                                           opt=self.opt)
            model_with_loss.module.loss.update_weight(epoch)
        else:
            if self.opt.weight_strategy == 'GRADNORM':
                model_with_loss.loss.optimizer = get_loss_optimizer(model=model_with_loss.loss.loss_model, opt=self.opt)
            model_with_loss.loss.update_weight(epoch)
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        # Note: unlike the other variants, this keeps a meter for every loss
        # stat; the usual 'tot' / positive-weight filter is not applied here.
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=opt.device, non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch, epoch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                if opt.weight_strategy == 'GRADNORM':
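                    # GradNorm: keep the graph from the first backward pass so
                    # a second gradient (of the 'update' stat w.r.t. the
                    # per-task loss weights) can be taken; that gradient is
                    # zeroed when its L1 norm exceeds opt.gradnorm_thred and is
                    # then applied by the loss optimizer.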
                    loss.backward(retain_graph=True)
                    if len(self.opt.gpus) > 1:
                        # model_with_loss.module.loss.loss_model.update_weight(model_with_loss.module.model, model_with_loss.module.loss.optimizer, loss_stats)
                        # torch.sum(loss_stats['update']).backward()
                        model_with_loss.module.loss.optimizer.zero_grad()
                        temp_grad = torch.autograd.grad(torch.sum(loss_stats['update']),
                                                        model_with_loss.module.loss.loss_model.weight)[0]
                        grad_norm = torch.norm(temp_grad.data, 1)
                        print(grad_norm)
                        if grad_norm > opt.gradnorm_thred:
                            temp_grad = torch.zeros_like(temp_grad)
                        model_with_loss.module.loss.loss_model.weight.grad = temp_grad
                        model_with_loss.module.loss.optimizer.step()
                    else:
                        # model_with_loss.loss.loss_model.update_weight(model_with_loss.model,
                        #                                               model_with_loss.loss.optimizer, loss_stats)
                        model_with_loss.loss.optimizer.zero_grad()
                        temp_grad = torch.autograd.grad(
                            loss_stats['update'],
                            model_with_loss.loss.loss_model.weight)[0]
                        grad_norm = torch.norm(temp_grad.data, 1)
                        if grad_norm > opt.gradnorm_thred:
                            temp_grad = torch.zeros_like(temp_grad)
                        model_with_loss.loss.loss_model.weight.grad = temp_grad
                        model_with_loss.loss.optimizer.step()
                else:
                    # torch.autograd.grad(loss, model_with_loss.model.padnet.parameters())
                    loss.backward()
                self.optimizer.step()
                if opt.weight_strategy == 'UNCER':
                    if len(self.opt.gpus) > 1:
                        model_with_loss.module.loss.optimizer.step()
                        model_with_loss.module.loss.optimizer.zero_grad()
                        print(model_with_loss.module.loss.group_weight)
                    else:
                        model_with_loss.loss.optimizer.step()
                        model_with_loss.loss.optimizer.zero_grad()
                        print(model_with_loss.loss.group_weight)
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch, iter_id, num_iters, phase=phase,
                total=bar.elapsed_td, eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(
                    loss_stats[l].mean().item(), batch['image'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
            Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
                                      '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
            if opt.print_iter > 0:  # If not using progress bar
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()

            if opt.debug > 0:
                self.debug(batch, output, iter_id, dataset=data_loader.dataset)

            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        if len(self.opt.gpus) > 1:
            model_with_loss.module.loss.update_loss(epoch, ret)
        else:
            model_with_loss.loss.update_loss(epoch, ret)
        return ret, results

    def _get_losses(self, opt, logger=None, param=None):
        loss_order = ['hm', 'wh', 'reg', 'ltrb', 'hps', 'hm_hp', \
                      'hp_offset', 'dep', 'dim', 'rot', 'amodel_offset', \
                      'ltrb_amodal', 'tracking', 'nuscenes_att', 'velocity']
        loss_states = ['tot'] + [k for k in loss_order if k in opt.heads]
        # loss = GenericLoss(opt)
        loss = LossWithStrategy(opt, logger, param)
        return loss_states, loss

    def debug(self, batch, output, iter_id, dataset):
        opt = self.opt
        if 'pre_hm' in batch:
            output.update({'pre_hm': batch['pre_hm']})
        dets = generic_decode(output, K=opt.K, opt=opt)
        for k in dets:
            dets[k] = dets[k].detach().cpu().numpy()
        dets_gt = batch['meta']['gt_det']
        for i in range(1):
            debugger = Debugger(opt=opt, dataset=dataset)
            img = batch['image'][i].detach().cpu().numpy().transpose(1, 2, 0)
            img = np.clip(((img * dataset.std + dataset.mean) * 255.), 0,
                          255).astype(np.uint8)
            pred = debugger.gen_colormap(output['hm'][i].detach().cpu().numpy())
            gt = debugger.gen_colormap(batch['hm'][i].detach().cpu().numpy())
            debugger.add_blend_img(img, pred, 'pred_hm')
            debugger.add_blend_img(img, gt, 'gt_hm')

            if 'pre_img' in batch:
                pre_img = batch['pre_img'][i].detach().cpu().numpy().transpose(1, 2, 0)
                pre_img = np.clip(((pre_img * dataset.std + dataset.mean) * 255),
                                  0, 255).astype(np.uint8)
                debugger.add_img(pre_img, 'pre_img_pred')
                debugger.add_img(pre_img, 'pre_img_gt')
                if 'pre_hm' in batch:
                    pre_hm = debugger.gen_colormap(
                        batch['pre_hm'][i].detach().cpu().numpy())
                    debugger.add_blend_img(pre_img, pre_hm, 'pre_hm')

            debugger.add_img(img, img_id='out_pred')
            if 'ltrb_amodal' in opt.heads:
                debugger.add_img(img, img_id='out_pred_amodal')
                debugger.add_img(img, img_id='out_gt_amodal')

            # Predictions
            for k in range(len(dets['scores'][i])):
                if dets['scores'][i, k] > opt.vis_thresh:
                    debugger.add_coco_bbox(
                        dets['bboxes'][i, k] * opt.down_ratio, dets['clses'][i, k],
                        dets['scores'][i, k], img_id='out_pred')

                    if 'ltrb_amodal' in opt.heads:
                        debugger.add_coco_bbox(
                            dets['bboxes_amodal'][i, k] * opt.down_ratio, dets['clses'][i, k],
                            dets['scores'][i, k], img_id='out_pred_amodal')

                    if 'hps' in opt.heads and int(dets['clses'][i, k]) == 0:
                        debugger.add_coco_hp(
                            dets['hps'][i, k] * opt.down_ratio, img_id='out_pred')

                    if 'tracking' in opt.heads:
                        debugger.add_arrow(
                            dets['cts'][i][k] * opt.down_ratio,
                            dets['tracking'][i][k] * opt.down_ratio, img_id='out_pred')
                        debugger.add_arrow(
                            dets['cts'][i][k] * opt.down_ratio,
                            dets['tracking'][i][k] * opt.down_ratio, img_id='pre_img_pred')

            # Ground truth
            debugger.add_img(img, img_id='out_gt')
            for k in range(len(dets_gt['scores'][i])):
                if dets_gt['scores'][i][k] > opt.vis_thresh:
                    debugger.add_coco_bbox(
                        dets_gt['bboxes'][i][k] * opt.down_ratio, dets_gt['clses'][i][k],
                        dets_gt['scores'][i][k], img_id='out_gt')

                    if 'ltrb_amodal' in opt.heads:
                        debugger.add_coco_bbox(
                            dets_gt['bboxes_amodal'][i, k] * opt.down_ratio,
                            dets_gt['clses'][i, k],
                            dets_gt['scores'][i, k], img_id='out_gt_amodal')

                    if 'hps' in opt.heads and \
                            (int(dets_gt['clses'][i][k]) == 0):
                        debugger.add_coco_hp(
                            dets_gt['hps'][i][k] * opt.down_ratio, img_id='out_gt')

                    if 'tracking' in opt.heads:
                        debugger.add_arrow(
                            dets_gt['cts'][i][k] * opt.down_ratio,
                            dets_gt['tracking'][i][k] * opt.down_ratio, img_id='out_gt')
                        debugger.add_arrow(
                            dets_gt['cts'][i][k] * opt.down_ratio,
                            dets_gt['tracking'][i][k] * opt.down_ratio, img_id='pre_img_gt')

            if 'hm_hp' in opt.heads:
                pred = debugger.gen_colormap_hp(
                    output['hm_hp'][i].detach().cpu().numpy())
                gt = debugger.gen_colormap_hp(batch['hm_hp'][i].detach().cpu().numpy())
                debugger.add_blend_img(img, pred, 'pred_hmhp')
                debugger.add_blend_img(img, gt, 'gt_hmhp')

            if 'rot' in opt.heads and 'dim' in opt.heads and 'dep' in opt.heads:
                dets_gt = {k: dets_gt[k].cpu().numpy() for k in dets_gt}
                calib = batch['meta']['calib'].detach().numpy() \
                    if 'calib' in batch['meta'] else None
                det_pred = generic_post_process(opt, dets,
                                                batch['meta']['c'].cpu().numpy(), batch['meta']['s'].cpu().numpy(),
                                                output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,
                                                calib)
                det_gt = generic_post_process(opt, dets_gt,
                                              batch['meta']['c'].cpu().numpy(), batch['meta']['s'].cpu().numpy(),
                                              output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,
                                              calib)

                debugger.add_3d_detection(
                    batch['meta']['img_path'][i], batch['meta']['flipped'][i],
                    det_pred[i], calib[i],
                    vis_thresh=opt.vis_thresh, img_id='add_pred')
                debugger.add_3d_detection(
                    batch['meta']['img_path'][i], batch['meta']['flipped'][i],
                    det_gt[i], calib[i],
                    vis_thresh=opt.vis_thresh, img_id='add_gt')
                debugger.add_bird_views(det_pred[i], det_gt[i],
                                        vis_thresh=opt.vis_thresh, img_id='bird_pred_gt')

            if opt.debug == 4:
                debugger.save_all_imgs(opt.debug_dir, prefix='{}'.format(iter_id))
            else:
                debugger.show_all_imgs(pause=True)

    def val(self, epoch, data_loader):
        return self.run_epoch('val', epoch, data_loader)

    def train(self, epoch, data_loader):
        return self.run_epoch('train', epoch, data_loader)