示例#1
0
def test_vis():
    """Step through a registered dataset and display its first annotation per image.

    The dataset name comes from ``sys.argv[1]`` and must be registered in
    ``DatasetCatalog``. For every sample a 2x3 grid is shown containing the
    image, the mask + bbox overlay, the projected 3D box, the depth map (raw
    values divided by 1000) and a visualization of the XYZ map whose crop is
    loaded from the annotation's ``xyz_path``.

    Only ``annotations[0]`` of each sample is used; further instances are ignored.
    """
    name = sys.argv[1]
    assert name in DatasetCatalog.list()

    metadata = MetadataCatalog.get(name)
    dprint("MetadataCatalog: ", metadata)
    class_names = metadata.objs

    load_start = time.perf_counter()
    samples = DatasetCatalog.get(name)
    logger.info(
        "Done loading {} samples with {:.3f}s.".format(
            len(samples), time.perf_counter() - load_start
        )
    )

    # The directory is created but nothing below writes into it.
    out_dir = "output/{}-data-vis".format(name)
    os.makedirs(out_dir, exist_ok=True)

    to_rgb = [2, 1, 0]  # BGR -> RGB channel flip for display
    for sample in samples:
        bgr = read_image_cv2(sample["file_name"], format="BGR")
        depth = mmcv.imread(sample["depth_file"], "unchanged") / 1000.0

        inst = sample["annotations"][0]  # a single instance per image is assumed
        height, width = bgr.shape[:2]
        inst_mask = cocosegm2mask(inst["segmentation"], height, width)
        box_xyxy = np.array(
            BoxMode.convert(inst["bbox"], inst["bbox_mode"], BoxMode.XYXY_ABS)
        )

        corners_3d = inst["bbox3d_and_center"]
        quat = inst["quat"]
        trans = inst["trans"]
        rot = quat2mat(quat)
        cls_idx = inst["category_id"]  # 0-based label
        cam_K = sample["cam"]
        corners_2d = misc.project_pts(corners_3d, cam_K, rot, trans)

        # TODO: visualize pose and keypoints
        overlay = vis_image_mask_bbox_cv2(
            bgr, [inst_mask], bboxes=[box_xyxy], labels=[class_names[cls_idx]]
        )
        box3d_vis = misc.draw_projected_box3d(
            bgr.copy(), corners_2d, middle_color=None, bottom_color=(128, 128, 128)
        )

        # Paste the stored XYZ crop back into a full-size, zero-filled canvas.
        xyz_info = mmcv.load(inst["xyz_path"])
        xyz_full = np.zeros((height, width, 3), dtype=np.float32)
        crop = xyz_info["xyz_crop"].astype(np.float32)
        x1, y1, x2, y2 = xyz_info["xyxy"]
        xyz_full[y1 : y2 + 1, x1 : x2 + 1, :] = crop
        xyz_vis = get_emb_show(xyz_full)

        panels = [
            (bgr[:, :, to_rgb], "img"),
            (overlay[:, :, to_rgb], "vis_img"),
            (box3d_vis[:, :, to_rgb], "img_vis_kpts2d"),
            (depth, "depth"),
            (xyz_vis, "emb_show"),
        ]
        images, titles = zip(*panels)
        grid_show(list(images), list(titles), row=2, col=3)
示例#2
0
def test_vis():
    """Interactively visualize every annotated instance of a registered dataset.

    The dataset name is read from ``sys.argv[1]`` and must be registered in
    ``DatasetCatalog``. For each image and each of its annotations, this shows
    the mask/bbox overlay and the projected 3D box. For datasets whose name
    does not contain ``"test"``, it additionally loads the per-instance XYZ
    crop from ``xyz_path`` and shows:

    - the XYZ crop,
    - the XYZ map overlaid on the image,
    - the region where the XYZ foreground disagrees with the annotated mask.
    """
    dset_name = sys.argv[1]
    assert dset_name in DatasetCatalog.list()

    meta = MetadataCatalog.get(dset_name)
    dprint("MetadataCatalog: ", meta)
    objs = meta.objs

    t_start = time.perf_counter()
    dicts = DatasetCatalog.get(dset_name)
    logger.info("Done loading {} samples with {:.3f}s.".format(
        len(dicts),
        time.perf_counter() - t_start))

    # The directory is created but nothing below writes into it.
    dirname = "output/{}-data-vis".format(dset_name)
    os.makedirs(dirname, exist_ok=True)
    bgr2rgb = [2, 1, 0]  # channel flip applied to every BGR color panel
    for d in dicts:
        img = read_image_cv2(d["file_name"], format="BGR")
        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0

        imH, imW = img.shape[:2]
        annos = d["annotations"]
        masks = [
            cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos
        ]
        bboxes_xyxy = np.array([
            BoxMode.convert(anno["bbox"], anno["bbox_mode"], BoxMode.XYXY_ABS)
            for anno in annos
        ])
        K = d["cam"]
        kpts_2d = [
            misc.project_pts(anno["bbox3d_and_center"], K,
                             quat2mat(anno["quat"]), anno["trans"])
            for anno in annos
        ]
        # category_id is a 0-based label
        labels = [objs[anno["category_id"]] for anno in annos]
        # TODO: visualize pose and keypoints
        for _i, anno in enumerate(annos):
            img_vis = vis_image_mask_bbox_cv2(img,
                                              masks[_i:_i + 1],
                                              bboxes=bboxes_xyxy[_i:_i + 1],
                                              labels=labels[_i:_i + 1])
            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(),
                                                       kpts_2d[_i])
            if "test" in dset_name:
                # Presumably test splits carry no "xyz_path"; show only the basics.
                grid_show(
                    [
                        img[:, :, bgr2rgb], img_vis[:, :, bgr2rgb],
                        img_vis_kpts2d[:, :, bgr2rgb], depth
                    ],
                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
                    row=2,
                    col=2,
                )
                continue

            xyz_info = mmcv.load(anno["xyz_path"])
            x1, y1, x2, y2 = xyz_info["xyxy"]
            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
            # Paste the crop back into a full-size, zero-filled XYZ canvas.
            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
            xyz[y1:y2 + 1, x1:x2 + 1, :] = xyz_crop
            xyz_show = get_emb_show(xyz)
            xyz_crop_show = get_emb_show(xyz_crop)
            img_xyz = img.copy() / 255.0
            # Foreground of the XYZ map = pixels where any coordinate is non-zero.
            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) |
                        (xyz[:, :, 2] != 0)).astype("uint8")
            fg_idx = np.where(mask_xyz != 0)
            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0],
                                                        fg_idx[1], :3]
            img_xyz_crop = img_xyz[y1:y2 + 1, x1:x2 + 1, :]
            img_vis_crop = img_vis[y1:y2 + 1, x1:x2 + 1, :]
            # diff mask. mask_xyz is uint8, so subtracting in an unsigned dtype
            # would wrap 0 - 1 to 255; compute |a - b| in float instead.
            diff_mask_xyz = np.abs(
                masks[_i].astype(np.float32) - mask_xyz.astype(np.float32)
            )[y1:y2 + 1, x1:x2 + 1]

            grid_show(
                [
                    img[:, :, bgr2rgb],
                    img_vis[:, :, bgr2rgb],
                    img_vis_kpts2d[:, :, bgr2rgb],
                    depth,
                    diff_mask_xyz,
                    xyz_crop_show,
                    img_xyz[:, :, bgr2rgb],
                    img_xyz_crop[:, :, bgr2rgb],
                    # img_vis is BGR like the other color panels; flip it too.
                    img_vis_crop[:, :, bgr2rgb],
                ],
                [
                    "img",
                    "vis_img",
                    "img_vis_kpts2d",
                    "depth",
                    "diff_mask_xyz",
                    "xyz_crop_show",
                    "img_xyz",
                    "img_xyz_crop",
                    "img_vis_crop",
                ],
                row=3,
                col=3,
            )
示例#3
0
文件: engine.py 项目: hz-ants/GDR-Net
def do_train(cfg, args, model, optimizer, resume=False):
    """Run the GDR-Net training loop.

    Builds the primary (and optional secondary) training data loaders, the LR
    scheduler, the checkpointer and the metric writers. It then iterates from
    the resumed iteration up to ``cfg.SOLVER.TOTAL_EPOCHS * iters_per_epoch``.

    Args:
        cfg: config node; reads DATASETS, SOLVER, MODEL, TEST, TRAIN and OUTPUT_DIR.
        args: parsed CLI arguments. Only ``fp16_allreduce`` and ``use_adasum``
            are read, and only when Horovod is active (``comm._USE_HVD``).
        model: network called with ROI batches and ``do_loss=True``; it returns
            ``(out_dict, loss_dict)``.
        optimizer: torch optimizer. Under Horovod it is wrapped in
            ``hvd.DistributedOptimizer``.
        resume: forwarded to ``checkpointer.resume_or_load``.
    """
    model.train()

    # some basic settings =========================
    dataset_meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
    # NOTE: data_ref / obj_names are not used below; the ref lookup still fails
    # fast on an unknown ref_key, so it is kept.
    data_ref = ref.__dict__[dataset_meta.ref_key]
    obj_names = dataset_meta.objs

    # load data ===================================
    train_dset_names = cfg.DATASETS.TRAIN
    data_loader = build_gdrn_train_loader(cfg, train_dset_names)
    data_loader_iter = iter(data_loader)

    # load 2nd train dataloader if needed
    train_2_dset_names = cfg.DATASETS.get("TRAIN2", ())
    train_2_ratio = cfg.DATASETS.get("TRAIN2_RATIO", 0.0)
    if train_2_ratio > 0.0 and len(train_2_dset_names) > 0:
        data_loader_2 = build_gdrn_train_loader(cfg, train_2_dset_names)
        data_loader_2_iter = iter(data_loader_2)
    else:
        data_loader_2 = None
        data_loader_2_iter = None

    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    # With AspectRatioGroupedDataset the sized dataset is one level deeper.
    if isinstance(data_loader, AspectRatioGroupedDataset):
        dataset_len = len(data_loader.dataset.dataset)
        if data_loader_2 is not None:
            dataset_len += len(data_loader_2.dataset.dataset)
    else:
        dataset_len = len(data_loader.dataset)
        if data_loader_2 is not None:
            dataset_len += len(data_loader_2.dataset)
    iters_per_epoch = dataset_len // images_per_batch
    max_iter = cfg.SOLVER.TOTAL_EPOCHS * iters_per_epoch
    dprint("images_per_batch: ", images_per_batch)
    dprint("dataset length: ", dataset_len)
    dprint("iters per epoch: ", iters_per_epoch)
    dprint("total iters: ", max_iter)
    scheduler = solver_utils.build_lr_scheduler(cfg,
                                                optimizer,
                                                total_iters=max_iter)

    AMP_ON = cfg.SOLVER.AMP.ENABLED
    logger.info(f"AMP enabled: {AMP_ON}")
    grad_scaler = GradScaler()

    # resume or load model ===================================
    checkpointer = MyCheckpointer(
        model,
        cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=scheduler,
        gradscaler=grad_scaler,
        save_to_disk=comm.is_main_process(),
    )
    start_iter = checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

    if comm._USE_HVD:  # hvd may be not available, so do not use the one in args
        # not needed
        # start_iter = hvd.broadcast(torch.tensor(start_iter), root_rank=0, name="start_iter").item()

        # Horovod: broadcast parameters & optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        # Horovod: (optional) compression algorithm.
        compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
        optimizer = hvd.DistributedOptimizer(
            optimizer,
            named_parameters=model.named_parameters(),
            op=hvd.Adasum if args.use_adasum else hvd.Average,
            compression=compression,
        )  # device_dense='/cpu:0'

    if cfg.SOLVER.CHECKPOINT_BY_EPOCH:
        ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD * iters_per_epoch
    else:
        ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        ckpt_period,
        max_iter=max_iter,
        max_to_keep=cfg.SOLVER.MAX_TO_KEEP)

    # build writers ==============================================
    tbx_event_writer = get_tbx_event_writer(
        cfg.OUTPUT_DIR, backup=not cfg.get("RESUME", False))
    tbx_writer = tbx_event_writer._writer  # NOTE: we want to write some non-scalar data
    writers = ([
        MyCommonMetricPrinter(max_iter),
        MyJSONWriter(osp.join(cfg.OUTPUT_DIR, "metrics.json")),
        tbx_event_writer
    ] if comm.is_main_process() else [])

    # compared to "train_net.py", we do not support accurate timing and
    # precise BN here, because they are not trivial to implement
    logger.info("Starting training from iteration {}".format(start_iter))
    iter_time = None
    with EventStorage(start_iter) as storage:
        # for data, iteration in zip(data_loader, range(start_iter, max_iter)):
        for iteration in range(start_iter, max_iter):
            storage.iter = iteration
            # `iteration` counts batches, so the epoch index comes from
            # iters_per_epoch, not from the sample count (dataset_len).
            # iters_per_epoch > 0 here: otherwise max_iter == 0 and the loop is empty.
            epoch = iteration // iters_per_epoch + 1

            # Draw from the 2nd loader with probability train_2_ratio. The random
            # draw still happens every iteration (keeping the NumPy RNG stream
            # unchanged), but a missing 2nd loader now falls back to the 1st one
            # instead of calling next(None).
            if np.random.rand() < train_2_ratio and data_loader_2_iter is not None:
                data = next(data_loader_2_iter)
            else:
                data = next(data_loader_iter)

            if iter_time is not None:
                storage.put_scalar("time", time.perf_counter() - iter_time)
            iter_time = time.perf_counter()

            # forward ============================================================
            batch = batch_data(cfg, data)
            with autocast(enabled=AMP_ON):
                out_dict, loss_dict = model(
                    batch["roi_img"],
                    gt_xyz=batch.get("roi_xyz", None),
                    gt_xyz_bin=batch.get("roi_xyz_bin", None),
                    gt_mask_trunc=batch["roi_mask_trunc"],
                    gt_mask_visib=batch["roi_mask_visib"],
                    gt_mask_obj=batch["roi_mask_obj"],
                    gt_region=batch.get("roi_region", None),
                    gt_allo_quat=batch.get("allo_quat", None),
                    gt_ego_quat=batch.get("ego_quat", None),
                    gt_allo_rot6d=batch.get("allo_rot6d", None),
                    gt_ego_rot6d=batch.get("ego_rot6d", None),
                    gt_ego_rot=batch.get("ego_rot", None),
                    gt_trans=batch.get("trans", None),
                    gt_trans_ratio=batch["roi_trans_ratio"],
                    gt_points=batch.get("roi_points", None),
                    sym_infos=batch.get("sym_info", None),
                    roi_classes=batch["roi_cls"],
                    roi_cams=batch["roi_cam"],
                    roi_whs=batch["roi_wh"],
                    roi_centers=batch["roi_center"],
                    resize_ratios=batch["resize_ratio"],
                    roi_coord_2d=batch.get("roi_coord_2d", None),
                    roi_extents=batch.get("roi_extent", None),
                    do_loss=True,
                )
                losses = sum(loss_dict.values())
                assert torch.isfinite(losses).all(), loss_dict

            # Reduced (cross-process) losses are for logging only.
            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            optimizer.zero_grad()
            if AMP_ON:
                grad_scaler.scale(losses).backward()

                # # Unscales the gradients of optimizer's assigned params in-place
                # grad_scaler.unscale_(optimizer)
                # # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                if comm._USE_HVD:
                    # Horovod: allreduce first, then step without a second sync.
                    optimizer.synchronize()
                    with optimizer.skip_synchronize():
                        grad_scaler.step(optimizer)
                        grad_scaler.update()
                else:
                    grad_scaler.step(optimizer)
                    grad_scaler.update()
            else:
                losses.backward()
                optimizer.step()

            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            if cfg.TEST.EVAL_PERIOD > 0 and (
                    iteration + 1
            ) % cfg.TEST.EVAL_PERIOD == 0 and iteration != max_iter - 1:
                do_test(cfg, model, epoch=epoch, iteration=iteration)
                # Compared to "train_net.py", the test results are not dumped to EventStorage
                comm.synchronize()

            if iteration - start_iter > 5 and (
                (iteration + 1) % cfg.TRAIN.PRINT_FREQ == 0
                    or iteration == max_iter - 1 or iteration < 100):
                for writer in writers:
                    writer.write()
                # visualize some images ========================================
                # NOTE(review): several arrays below are HWC or HW, while
                # SummaryWriter.add_image defaults to dataformats="CHW". Also,
                # "roi_mask"/"roi_xyz" are indexed directly although the forward
                # pass reads them optionally. Confirm before enabling VIS_IMG.
                if cfg.TRAIN.VIS_IMG:
                    with torch.no_grad():
                        vis_i = 0
                        roi_img_vis = batch["roi_img"][vis_i].cpu().numpy()
                        roi_img_vis = denormalize_image(roi_img_vis,
                                                        cfg).transpose(
                                                            1, 2,
                                                            0).astype("uint8")
                        tbx_writer.add_image("input_image", roi_img_vis,
                                             iteration)

                        out_coor_x = out_dict["coor_x"].detach()
                        out_coor_y = out_dict["coor_y"].detach()
                        out_coor_z = out_dict["coor_z"].detach()
                        out_xyz = get_out_coor(cfg, out_coor_x, out_coor_y,
                                               out_coor_z)

                        out_xyz_vis = out_xyz[vis_i].cpu().numpy().transpose(
                            1, 2, 0)
                        out_xyz_vis = get_emb_show(out_xyz_vis)
                        tbx_writer.add_image("out_xyz", out_xyz_vis, iteration)

                        gt_xyz_vis = batch["roi_xyz"][vis_i].cpu().numpy(
                        ).transpose(1, 2, 0)
                        gt_xyz_vis = get_emb_show(gt_xyz_vis)
                        tbx_writer.add_image("gt_xyz", gt_xyz_vis, iteration)

                        out_mask = out_dict["mask"].detach()
                        out_mask = get_out_mask(cfg, out_mask)
                        out_mask_vis = out_mask[vis_i, 0].cpu().numpy()
                        tbx_writer.add_image("out_mask", out_mask_vis,
                                             iteration)

                        gt_mask_vis = batch["roi_mask"][vis_i].detach().cpu(
                        ).numpy()
                        tbx_writer.add_image("gt_mask", gt_mask_vis, iteration)
            periodic_checkpointer.step(iteration, epoch=epoch)
示例#4
0
    def do_train(self, cfg, args, model, optimizer, resume=False):
        """Run the GDR-Net training loop using this object's distributed helpers.

        Uses ``self.setup_dataloaders``, ``self.backward``, ``self.barrier``,
        ``self.is_global_zero`` and ``self._precision_plugin``. These look like
        the LightningLite API; confirm against the enclosing class.

        Args:
            cfg: config node; reads DATASETS, SOLVER, MODEL, TEST, TRAIN and OUTPUT_DIR.
            args: parsed CLI arguments (not read in this method).
            model: network called with ROI batches and ``do_loss=True``; it
                returns ``(out_dict, loss_dict)``.
            optimizer: torch optimizer stepped once per iteration.
            resume: forwarded to ``checkpointer.resume_or_load``.
        """
        model.train()

        # some basic settings =========================
        dataset_meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
        # NOTE: data_ref / obj_names are not used below; the ref lookup still
        # fails fast on an unknown ref_key, so it is kept.
        data_ref = ref.__dict__[dataset_meta.ref_key]
        obj_names = dataset_meta.objs

        # load data ===================================
        train_dset_names = cfg.DATASETS.TRAIN
        data_loader = build_gdrn_train_loader(cfg, train_dset_names)

        # load 2nd train dataloader if needed
        train_2_dset_names = cfg.DATASETS.get("TRAIN2", ())
        train_2_ratio = cfg.DATASETS.get("TRAIN2_RATIO", 0.0)
        if train_2_ratio > 0.0 and len(train_2_dset_names) > 0:
            data_loader_2 = build_gdrn_train_loader(cfg, train_2_dset_names)
        else:
            data_loader_2 = None

        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
        # Must inspect the raw loaders (before setup_dataloaders wraps them).
        # With AspectRatioGroupedDataset the sized dataset is one level deeper.
        if isinstance(data_loader, AspectRatioGroupedDataset):
            dataset_len = len(data_loader.dataset.dataset)
            if data_loader_2 is not None:
                dataset_len += len(data_loader_2.dataset.dataset)
        else:
            dataset_len = len(data_loader.dataset)
            if data_loader_2 is not None:
                dataset_len += len(data_loader_2.dataset)
        iters_per_epoch = dataset_len // images_per_batch
        max_iter = cfg.SOLVER.TOTAL_EPOCHS * iters_per_epoch
        dprint("images_per_batch: ", images_per_batch)
        dprint("dataset length: ", dataset_len)
        dprint("iters per epoch: ", iters_per_epoch)
        dprint("total iters: ", max_iter)

        # Create the iterators only after setup_dataloaders so training consumes
        # the set-up loaders. Previously the iterators were taken from the raw
        # loaders, so the set-up loaders were never iterated.
        data_loader = self.setup_dataloaders(data_loader,
                                             replace_sampler=False,
                                             move_to_device=False)
        data_loader_iter = iter(data_loader)
        data_loader_2_iter = None
        if data_loader_2 is not None:
            data_loader_2 = self.setup_dataloaders(data_loader_2,
                                                   replace_sampler=False,
                                                   move_to_device=False)
            data_loader_2_iter = iter(data_loader_2)

        scheduler = solver_utils.build_lr_scheduler(cfg,
                                                    optimizer,
                                                    total_iters=max_iter)

        # resume or load model ===================================
        extra_ckpt_dict = dict(
            optimizer=optimizer,
            scheduler=scheduler,
        )
        # Only checkpoint a grad scaler when the precision plugin provides one.
        if hasattr(self._precision_plugin, "scaler"):
            extra_ckpt_dict["gradscaler"] = self._precision_plugin.scaler

        checkpointer = MyCheckpointer(
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=self.is_global_zero,
            **extra_ckpt_dict,
        )
        start_iter = checkpointer.resume_or_load(
            cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

        if cfg.SOLVER.CHECKPOINT_BY_EPOCH:
            ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD * iters_per_epoch
        else:
            ckpt_period = cfg.SOLVER.CHECKPOINT_PERIOD
        periodic_checkpointer = PeriodicCheckpointer(
            checkpointer,
            ckpt_period,
            max_iter=max_iter,
            max_to_keep=cfg.SOLVER.MAX_TO_KEEP)

        # build writers ==============================================
        tbx_event_writer = self.get_tbx_event_writer(
            cfg.OUTPUT_DIR, backup=not cfg.get("RESUME", False))
        tbx_writer = tbx_event_writer._writer  # NOTE: we want to write some non-scalar data
        writers = ([
            MyCommonMetricPrinter(max_iter),
            MyJSONWriter(osp.join(cfg.OUTPUT_DIR, "metrics.json")),
            tbx_event_writer
        ] if self.is_global_zero else [])

        # compared to "train_net.py", we do not support accurate timing and
        # precise BN here, because they are not trivial to implement
        logger.info("Starting training from iteration {}".format(start_iter))
        iter_time = None
        with EventStorage(start_iter) as storage:
            for iteration in range(start_iter, max_iter):
                storage.iter = iteration
                # `iteration` counts batches, so the epoch index comes from
                # iters_per_epoch, not from the sample count (dataset_len).
                # iters_per_epoch > 0 here: otherwise max_iter == 0 and the
                # loop body never runs.
                epoch = iteration // iters_per_epoch + 1

                # Draw from the 2nd loader with probability train_2_ratio. The
                # random draw still happens every iteration (keeping the NumPy
                # RNG stream unchanged), but a missing 2nd loader now falls back
                # to the 1st one instead of calling next(None).
                if np.random.rand() < train_2_ratio and data_loader_2_iter is not None:
                    data = next(data_loader_2_iter)
                else:
                    data = next(data_loader_iter)

                if iter_time is not None:
                    storage.put_scalar("time", time.perf_counter() - iter_time)
                iter_time = time.perf_counter()

                # forward ============================================================
                batch = batch_data(cfg, data)

                out_dict, loss_dict = model(
                    batch["roi_img"],
                    gt_xyz=batch.get("roi_xyz", None),
                    gt_xyz_bin=batch.get("roi_xyz_bin", None),
                    gt_mask_trunc=batch["roi_mask_trunc"],
                    gt_mask_visib=batch["roi_mask_visib"],
                    gt_mask_obj=batch["roi_mask_obj"],
                    gt_region=batch.get("roi_region", None),
                    gt_allo_quat=batch.get("allo_quat", None),
                    gt_ego_quat=batch.get("ego_quat", None),
                    gt_allo_rot6d=batch.get("allo_rot6d", None),
                    gt_ego_rot6d=batch.get("ego_rot6d", None),
                    gt_ego_rot=batch.get("ego_rot", None),
                    gt_trans=batch.get("trans", None),
                    gt_trans_ratio=batch["roi_trans_ratio"],
                    gt_points=batch.get("roi_points", None),
                    sym_infos=batch.get("sym_info", None),
                    roi_classes=batch["roi_cls"],
                    roi_cams=batch["roi_cam"],
                    roi_whs=batch["roi_wh"],
                    roi_centers=batch["roi_center"],
                    resize_ratios=batch["resize_ratio"],
                    roi_coord_2d=batch.get("roi_coord_2d", None),
                    roi_extents=batch.get("roi_extent", None),
                    do_loss=True,
                )
                losses = sum(loss_dict.values())
                assert torch.isfinite(losses).all(), loss_dict

                # Reduced (cross-process) losses are for logging only.
                loss_dict_reduced = {
                    k: v.item()
                    for k, v in comm.reduce_dict(loss_dict).items()
                }
                losses_reduced = sum(loss
                                     for loss in loss_dict_reduced.values())
                if self.is_global_zero:
                    storage.put_scalars(total_loss=losses_reduced,
                                        **loss_dict_reduced)

                optimizer.zero_grad(set_to_none=True)
                self.backward(losses)
                optimizer.step()

                storage.put_scalar("lr",
                                   optimizer.param_groups[0]["lr"],
                                   smoothing_hint=False)
                scheduler.step()

                if (cfg.TEST.EVAL_PERIOD > 0
                        and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                        and iteration != max_iter - 1):
                    self.do_test(cfg, model, epoch=epoch, iteration=iteration)
                    # Compared to "train_net.py", the test results are not dumped to EventStorage
                    self.barrier()

                if iteration - start_iter > 5 and (
                    (iteration + 1) % cfg.TRAIN.PRINT_FREQ == 0
                        or iteration == max_iter - 1 or iteration < 100):
                    for writer in writers:
                        writer.write()
                    # visualize some images ========================================
                    # NOTE(review): several arrays below are HWC or HW, while
                    # SummaryWriter.add_image defaults to dataformats="CHW".
                    # Also, "roi_mask"/"roi_xyz" are indexed directly although
                    # the forward pass reads them optionally. Confirm before
                    # enabling VIS_IMG.
                    if cfg.TRAIN.VIS_IMG:
                        with torch.no_grad():
                            vis_i = 0
                            roi_img_vis = batch["roi_img"][vis_i].cpu().numpy()
                            roi_img_vis = denormalize_image(
                                roi_img_vis, cfg).transpose(1, 2,
                                                            0).astype("uint8")
                            tbx_writer.add_image("input_image", roi_img_vis,
                                                 iteration)

                            out_coor_x = out_dict["coor_x"].detach()
                            out_coor_y = out_dict["coor_y"].detach()
                            out_coor_z = out_dict["coor_z"].detach()
                            out_xyz = get_out_coor(cfg, out_coor_x, out_coor_y,
                                                   out_coor_z)

                            out_xyz_vis = out_xyz[vis_i].cpu().numpy(
                            ).transpose(1, 2, 0)
                            out_xyz_vis = get_emb_show(out_xyz_vis)
                            tbx_writer.add_image("out_xyz", out_xyz_vis,
                                                 iteration)

                            gt_xyz_vis = batch["roi_xyz"][vis_i].cpu().numpy(
                            ).transpose(1, 2, 0)
                            gt_xyz_vis = get_emb_show(gt_xyz_vis)
                            tbx_writer.add_image("gt_xyz", gt_xyz_vis,
                                                 iteration)

                            out_mask = out_dict["mask"].detach()
                            out_mask = get_out_mask(cfg, out_mask)
                            out_mask_vis = out_mask[vis_i, 0].cpu().numpy()
                            tbx_writer.add_image("out_mask", out_mask_vis,
                                                 iteration)

                            gt_mask_vis = batch["roi_mask"][vis_i].detach(
                            ).cpu().numpy()
                            tbx_writer.add_image("gt_mask", gt_mask_vis,
                                                 iteration)

                # Sharded optimizers must consolidate state before a checkpoint
                # is written. The condition is meant to mirror the iterations on
                # which periodic_checkpointer saves (presumably; confirm against
                # PeriodicCheckpointer.step).
                if (iteration + 1) % periodic_checkpointer.period == 0 or (
                        periodic_checkpointer.max_iter is not None and
                    (iteration + 1) >= periodic_checkpointer.max_iter):
                    if hasattr(optimizer,
                               "consolidate_state_dict"):  # for ddp_sharded
                        optimizer.consolidate_state_dict()
                periodic_checkpointer.step(iteration, epoch=epoch)