def main_worker_eval(worker_id, args):
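    """Single-GPU evaluation worker.

    Builds the test loader, restores the best model weights from
    cfg.MODEL.CHECKPOINT, and runs evaluation (Pixel2Mesh-style metrics if
    args.eval_p2m is set, otherwise the standard test evaluation). Note that
    `dataset` and `logger` are assumed to be defined at module level.
    """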

    device = torch.device("cuda:%d" % worker_id)
    cfg = setup(args)

    # build test set
    test_loader = build_data_loader(cfg,
                                    dataset,
                                    "test",
                                    multigpu=False,
                                    num_workers=8)
    logger.info("test - %d" % len(test_loader))

    # load checkpoint and build model
    if cfg.MODEL.CHECKPOINT == "":
        raise ValueError("Invalid checkpoint provided")
    logger.info("Loading model from checkpoint: %s" % (cfg.MODEL.CHECKPOINT))
    cp = torch.load(PathManager.get_local_path(cfg.MODEL.CHECKPOINT))
    state_dict = clean_state_dict(cp["best_states"]["model"])
    model = build_model(cfg)
    model.load_state_dict(state_dict)
    logger.info("Model loaded")
    model.to(device)

    wandb.init(project='MeshRCNN', config=cfg, name='meshrcnn-eval')
    if args.eval_p2m:
        evaluate_test_p2m(model, test_loader)
    else:
        evaluate_test(model, test_loader)
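
# A minimal sketch of how a per-GPU worker like `main_worker_eval` is
# typically launched (illustrative only, not taken from this codebase):
#
#   import torch.multiprocessing as mp
#   mp.spawn(main_worker_eval, args=(args,), nprocs=torch.cuda.device_count())
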
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):
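    """Training loop with silhouette-based next-best-view supervision.

    On top of the standard losses, this variant renders silhouettes of the
    predicted and ground-truth meshes from every candidate viewpoint,
    accumulates an MSE silhouette loss, and softmaxes the per-view errors
    into a probability map that is used both to pick the next-best view and
    to supervise `LossPredictionModule`.
    """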

    if comm.is_main_process():
        wandb.init(project='MeshRCNN', config=cfg, name='meshrcnn')

    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
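    # The module is trained below on rendered silhouettes of the predicted
    # mesh, with the per-view silhouette-error map as its regression target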
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.SGD(loss_predictor.parameters(),
                                      lr=1e-3,
                                      momentum=0.9)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
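        # Homogeneous 4x4 transform: 90-degree rotation about the y-axis,
        # composed with RT before inversion (needed for all classes except
        # the vehicle class, see below)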
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _imgs, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            # NOTE: _imgs contains all of the other images belonging to this model
            # We have to select the next-best-view from that list of images

            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(
                    p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return

            model_kwargs = {}
            if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                model_kwargs["voxel_only"] = True
            with Timer("Forward"):
                voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            num_infinite = 0
            for cur_meshes in meshes_pred:
                cur_verts = cur_meshes.verts_packed()
                num_infinite += (torch.isfinite(cur_verts) == 0).sum().item()
            if num_infinite > 0:
                logger.info("ERROR: Got %d non-finite verts" % num_infinite)
                return

            total_silh_loss = torch.tensor(
                0., device=device)  # Total silhouette loss, to be added to "loss" below
            # Voxel only training for first few iterations
            if meshes_gt is not None and not model_kwargs.get(
                    "voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from predicted mesh for each view
                # GT probability map to supervise prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones(
                    (B, 24), device=device)  # batch size x 24
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                        zip(meshes_gt, _meshes_pred)):
                    # Maybe computationally expensive, but need to transform back to world space based on rendered image viewpoint
                    RT = RTs[b]
                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert

                    cur_pred_mesh._verts_list[0] = project_verts(
                        cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):

                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device,
                                                           R=R,
                                                           T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(
                                cameras=cameras,
                                raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(
                                blend_params=blend_params))

                        ref_image = silhouette_renderer(
                            meshes_world=cur_gt_mesh, R=R, T=T)
                        image = silhouette_renderer(meshes_world=cur_pred_mesh,
                                                    R=R,
                                                    T=T)
                        '''
                        import matplotlib.pyplot as plt
                        plt.subplot(1,2,1)
                        plt.imshow(ref_image[0,:,:,3].detach().cpu().numpy())
                        plt.subplot(1,2,2)
                        plt.imshow(image[0,:,:,3].detach().cpu().numpy())
                        plt.show()
                        '''

                        # MSE Loss between both silhouettes
                        silh_loss = torch.sum(
                            (image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
                        probability_map[b, iid] = silh_loss.detach()

                        total_silh_loss += silh_loss

                probability_map = torch.nn.functional.softmax(
                    probability_map, dim=1)  # Softmax across images
                nbv_idx = torch.argmax(probability_map,
                                       dim=1)  # Next-best view indices
                nbv_imgs = _imgs[torch.arange(B),
                                 nbv_idx]  # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for multi-view reconstruction
                # The input should be the first image and the next-best view
                #voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

            loss, losses = None, {}
            if num_infinite == 0:
                loss, losses = loss_fn(voxel_scores, meshes_pred, voxels_gt,
                                       (points_gt, normals_gt))
            skip = loss is None
            if loss is None or (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %s" % loss)
                skip = True

            # Add silhouette loss to total loss
            silh_weight = 1.0  # TODO: Add a weight for the silhouette loss?
            if loss is not None:
                loss = loss + total_silh_loss * silh_weight
            losses['silhouette'] = total_silh_loss

            if model_kwargs.get("voxel_only", False):
                for k, v in losses.items():
                    if k != "voxel":
                        losses[k] = 0.0 * v

            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += "  %s loss: %.4f," % (k, v.item())
                    str_out += "  total loss: %.4f," % loss.item()

                    # memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    if len(meshes_pred) > 0:
                        mean_V = meshes_pred[-1].num_verts_per_mesh().float().mean().item()
                        mean_F = meshes_pred[-1].num_faces_per_mesh().float().mean().item()
                        str_out += ", mesh size = (%d, %d)" % (mean_V, mean_F)
                    logger.info(str_out)

                    # Log with Weights & Biases, comment out if not installed.
                    # Logged only on the main process, matching wandb.init above.
                    wandb.log(losses)

            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item() > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" %
                            (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)

            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()

            # Zhengyuan step loss_prediction
            # Only step the prediction module when the silhouette block above
            # ran (otherwise `image` and `probability_map` are undefined)
            if meshes_gt is not None and not model_kwargs.get("voxel_only", False):
                loss_predictor.train_batch(image, probability_map,
                                           loss_pred_optim)

            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                if p.grad is not None:
                    num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):
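    """Baseline single-view training loop: forward pass, loss computation,
    loss skipping based on a moving average, and periodic checkpointing."""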
    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)
    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()
            batch = loaders["train"].postprocess(batch, device)
            imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(
                    p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return

            model_kwargs = {}
            if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                model_kwargs["voxel_only"] = True
            with Timer("Forward"):
                voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            num_infinite = 0
            for cur_meshes in meshes_pred:
                cur_verts = cur_meshes.verts_packed()
                num_infinite += (torch.isfinite(cur_verts) == 0).sum().item()
            if num_infinite > 0:
                logger.info("ERROR: Got %d non-finite verts" % num_infinite)
                return

            loss, losses = None, {}
            if num_infinite == 0:
                loss, losses = loss_fn(voxel_scores, meshes_pred, voxels_gt,
                                       (points_gt, normals_gt))
            skip = loss is None
            if loss is None or (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %f" % loss)
                skip = True

            if model_kwargs.get("voxel_only", False):
                for k, v in losses.items():
                    if k != "voxel":
                        losses[k] = 0.0 * v

            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += "  %s loss: %.4f," % (k, v.item())
                    str_out += "  total loss: %.4f," % loss.item()

                    # memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    if len(meshes_pred) > 0:
                        mean_V = meshes_pred[-1].num_verts_per_mesh().float().mean().item()
                        mean_F = meshes_pred[-1].num_faces_per_mesh().float().mean().item()
                        str_out += ", mesh size = (%d, %d)" % (mean_V, mean_F)
                    logger.info(str_out)

            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item() > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" %
                            (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)

            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()

            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                if p.grad is not None:
                    num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, "MeshVox", "test", multigpu=False)
        evaluate_test(model, test_loader)
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device,
                  loss_fn):
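    """Training loop for the loss prediction module only.

    The reconstruction model runs under `inference_context` (frozen); the
    predicted mesh is rendered from all 24 candidate viewpoints into a view
    grid, and `LossPredictionModule` is trained to regress the normalized
    per-view silhouette-error map from that grid.
    """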

    #if comm.is_main_process():
    #    wandb.init(project='MeshRCNN', config=cfg, name='prediction_module')

    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)

    # Zhengyuan modification
    loss_predictor = LossPredictionModule().to(device)
    loss_pred_optim = torch.optim.Adam(loss_predictor.parameters(), lr=1e-5)

    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" %
                        (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        # Config settings for renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).float().to(device)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()

            batch = loaders["train"].postprocess(batch, device)
            if dataset == 'MeshVoxMulti':
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt, id_strs, _, render_RTs, RTs = batch
            else:
                imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

            with inference_context(model):
                # NOTE: _imgs contains all of the other images belonging to this model
                # We have to select the next-best-view from that list of images

                model_kwargs = {}
                if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                    model_kwargs["voxel_only"] = True
                with Timer("Forward"):
                    voxel_scores, meshes_pred = model(imgs, **model_kwargs)

            total_silh_loss = torch.tensor(
                0., device=device)  # Total silhouette loss, accumulated for monitoring
            # Voxel only training for first few iterations
            if meshes_gt is not None and not model_kwargs.get(
                    "voxel_only", False):
                _meshes_pred = meshes_pred[-1].clone()
                _meshes_gt = meshes_gt[-1].clone()

                # Render masks from predicted mesh for each view
                # GT probability map to supervise prediction module
                B = len(meshes_gt)
                probability_map = 0.01 * torch.ones(
                    (B, 24)).to(device)  # batch size x 24
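                # Silhouette renderings of the predicted mesh from each of the
                # 24 viewpoints; used as input to the loss prediction module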
                viewgrid = torch.zeros(
                    (B, 24, render_image_size,
                     render_image_size)).to(device)  # batch size x 24 x H x W
                for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(
                        zip(meshes_gt, _meshes_pred)):
                    # Maybe computationally expensive, but need to transform back to world space based on rendered image viewpoint
                    RT = RTs[b]
                    # Rotate 90 degrees about y-axis and invert
                    invRT = torch.inverse(RT.mm(rot_y_90))
                    invRT_no_rot = torch.inverse(RT)  # Just invert

                    cur_pred_mesh._verts_list[0] = project_verts(
                        cur_pred_mesh._verts_list[0], invRT)
                    sid = id_strs[b].split('-')[0]

                    # For some strange reason all classes (except the vehicle class) require a 90 degree rotation about the y-axis
                    if sid == '02958343':
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT_no_rot)
                    else:
                        cur_gt_mesh._verts_list[0] = project_verts(
                            cur_gt_mesh._verts_list[0], invRT)

                    for iid in range(len(render_RTs[b])):

                        R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                        T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                        cameras = OpenGLPerspectiveCameras(device=device,
                                                           R=R,
                                                           T=T)
                        silhouette_renderer = MeshRenderer(
                            rasterizer=MeshRasterizer(
                                cameras=cameras,
                                raster_settings=raster_settings),
                            shader=SoftSilhouetteShader(
                                blend_params=blend_params))

                        ref_image = (silhouette_renderer(
                            meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                        image = (silhouette_renderer(
                            meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                        #Add image silhouette to viewgrid
                        viewgrid[b, iid] = image[..., -1]
                        '''
                        import matplotlib.pyplot as plt
                        plt.subplot(1,2,1)
                        plt.imshow(ref_image[0,:,:,3].detach().cpu().numpy())
                        plt.subplot(1,2,2)
                        plt.imshow(image[0,:,:,3].detach().cpu().numpy())
                        plt.show()
                        '''

                        # MSE Loss between both silhouettes
                        silh_loss = torch.sum(
                            (image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
                        probability_map[b, iid] = silh_loss.detach()

                        total_silh_loss += silh_loss

                probability_map = probability_map / (torch.max(
                    probability_map, dim=1)[0].unsqueeze(1))  # Normalize

                probability_map = torch.nn.functional.softmax(
                    probability_map, dim=1).to(device)  # Softmax across images
                #nbv_idx = torch.argmax(probability_map, dim=1)  # Next-best view indices
                #nbv_imgs = _imgs[torch.arange(B), nbv_idx]  # Next-best view images

                # NOTE: Do a second forward pass through the model? This time for multi-view reconstruction
                # The input should be the first image and the next-best view
                #voxel_scores, meshes_pred = model(nbv_imgs, **model_kwargs)

                # Zhengyuan step loss_prediction
                predictor_loss = loss_predictor.train_batch(
                    viewgrid, probability_map, loss_pred_optim)
                if comm.is_main_process():
                    #wandb.log({'prediction module loss':predictor_loss})

                    if cp.t % 50 == 0:
                        print('{} predictor_loss: {}'.format(
                            cp.t, predictor_loss))

                    #Save checkpoint every t iteration
                    if cp.t % 500 == 0:
                        print(
                            'Saving loss prediction module at iter {}'.format(
                                cp.t))
                        os.makedirs('./output_prediction_module',
                                    exist_ok=True)
                        torch.save(
                            loss_predictor.state_dict(),
                            './output_prediction_module/prediction_module_' +
                            str(cp.t) + '.pth')

            cp.step()

            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, dataset, "test", multigpu=False)
        evaluate_test(model, test_loader)
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device, loss_fn):
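    """Depth-prediction training loop: batches are dictionaries of tensors,
    the model consumes images plus camera extrinsics, and the loss is a
    masked loss on the predicted depth maps."""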
    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)
    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" % (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)

        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()
            batch = {
                k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                for k, v in batch.items()
            }

            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return

            model_kwargs = {}
            model_kwargs["extrinsics"] = batch["extrinsics"]

            with Timer("Forward"):
                model_outputs = model(batch["imgs"], **model_kwargs)
                pred_depths = model_outputs["depths"]

            loss = loss_fn(
                batch["depths"], pred_depths, batch["masks"]
            )
            losses = {"pred_depth": loss}

            skip = loss is None
            if loss is None or (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %f" % loss)
                skip = True

            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += "  %s loss: %.7f," % (k, v.item())
                    str_out += "  total loss: %.7f," % loss.item()

                    # memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    logger.info(str_out)

            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item() > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" % (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)

            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()

            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                if p.grad is not None:
                    num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()

        eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()

    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(
            cfg, get_dataset_name(cfg), "test", multigpu=False
        )
        test_loader.dataset.set_depth_only(True)
        test_loader.collate_fn = default_collate
        evaluate_test(model, test_loader)