def main(_config, seed):
    global EPOCH, weights
    if _config['weight'] is not None:
        weights = _config['weight']

    dataset_class = DatasetVisibilityKittiSingle
    img_shape = (384, 1280)

    split = 'test'
    if _config['random_initial_pose']:
        split = 'test_random'
    maps_folder = 'local_maps'
    if _config['maps_folder'] is not None:
        maps_folder = _config['maps_folder']
    if _config['test_sequence'] is None:
        raise TypeError('test_sequence cannot be None')
    else:
        if isinstance(_config['test_sequence'], int):
            _config['test_sequence'] = f"{_config['test_sequence']:02d}"
        dataset_val = dataset_class(_config['data_folder'],
                                    max_r=_config['max_r'],
                                    max_t=_config['max_t'],
                                    split=split,
                                    use_reflectance=_config['use_reflectance'],
                                    maps_folder=maps_folder,
                                    test_sequence=_config['test_sequence'])

    np.random.seed(seed)
    torch.random.manual_seed(seed)

    def init_fn(x):
        return _init_fn(x, seed)

    num_worker = 6
    batch_size = 1
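    # batch_size is fixed at 1: frames are evaluated in order, presumably so
    # that use_prev_output can seed each frame with the previous frame's pose.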

    TestImgLoader = torch.utils.data.DataLoader(dataset=dataset_val,
                                                shuffle=False,
                                                batch_size=batch_size,
                                                num_workers=num_worker,
                                                worker_init_fn=init_fn,
                                                collate_fn=merge_inputs,
                                                drop_last=False,
                                                pin_memory=False)

    print(len(TestImgLoader))

    models = []
    for i in range(len(weights)):
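        # Parse the network string, e.g. 'PWC_f1_md4' -> use_feat_from=1, md=4
        # (md is presumably the max displacement of the PWC correlation layer).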
        if _config['network'].startswith('PWC'):
            feat = 1
            md = 4
            split = _config['network'].split('_')
            for item in split[1:]:
                if item.startswith('f'):
                    feat = int(item[-1])
                elif item.startswith('md'):
                    md = int(item[2:])
            assert 0 < feat < 7, "Feature number from PWC has to be between 1 and 6"
            assert 0 < md, "md must be positive"
            model = CMRNet(img_shape,
                           use_feat_from=feat,
                           md=md,
                           use_reflectance=_config['use_reflectance'])
        else:
            raise TypeError("Network unknown")
        checkpoint = torch.load(weights[i], map_location='cpu')
        saved_state_dict = checkpoint['state_dict']
        model.load_state_dict(saved_state_dict)
        model = model.to(device)
        model.eval()
        models.append(model)
        if i == 0:
            _config['occlusion_threshold'] = checkpoint['config'][
                'occlusion_threshold']
            _config['occlusion_kernel'] = checkpoint['config'][
                'occlusion_kernel']
        else:
            assert _config['occlusion_threshold'] == checkpoint['config'][
                'occlusion_threshold']
            assert _config['occlusion_kernel'] == checkpoint['config'][
                'occlusion_kernel']

    if _config['save_log']:
        log_file = open(f'./results_for_paper/log_seq{_config["test_sequence"]}.csv', 'w')
        log_writer = csv.writer(log_file)
        header = ['frame']
        for i in range(len(weights) + 1):
            header += [
                f'iter{i}_error_t', f'iter{i}_error_r', f'iter{i}_error_x',
                f'iter{i}_error_y', f'iter{i}_error_z', f'iter{i}_error_roll',
                f'iter{i}_error_pitch', f'iter{i}_error_yaw'
            ]
        log_writer.writerow(header)

    show = _config['show']
    errors_r = []
    errors_t = []
    errors_t2 = []
    errors_rpy = []
    all_RTs = []

    prev_tr_error = None
    prev_rot_error = None

    for i in range(len(weights) + 1):
        errors_r.append([])
        errors_t.append([])
        errors_t2.append([])
        errors_rpy.append([])

    for batch_idx, sample in enumerate(TestImgLoader):

        log_string = [str(batch_idx)]

        lidar_input = []
        rgb_input = []
        shape_pad = [0, 0, 0, 0]

        if batch_idx == 0 or not _config['use_prev_output']:
            # Here, set the input pose of the current frame relative to the GT
            sample['tr_error'] = sample['tr_error'].cuda()
            sample['rot_error'] = sample['rot_error'].cuda()
        else:
            sample['tr_error'] = prev_tr_error
            sample['rot_error'] = prev_rot_error

        for idx in range(len(sample['rgb'])):

            real_shape = [
                sample['rgb'][idx].shape[1], sample['rgb'][idx].shape[2],
                sample['rgb'][idx].shape[0]
            ]

            # ProjectPointCloud in RT-pose
            sample['point_cloud'][idx] = sample['point_cloud'][idx].cuda()
            pc_rotated = sample['point_cloud'][idx].clone()
            reflectance = None
            if _config['use_reflectance']:
                reflectance = sample['reflectance'][idx].cuda()

            R = mathutils.Quaternion(sample['rot_error'][idx])
            T = mathutils.Vector(sample['tr_error'][idx])

            pc_rotated = rotate_back(pc_rotated, R, T)
            cam_params = sample['calib'][idx].cuda()
            cam_model = CameraModel()
            cam_model.focal_length = cam_params[:2]
            cam_model.principal_point = cam_params[2:]
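            # Project the miscalibrated cloud into the image and build a
            # z-buffer: pixels start at a 1000 m "far" sentinel, keep the
            # closest depth per pixel, and untouched pixels are reset to 0.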
            uv, depth, points, refl = cam_model.project_pytorch(
                pc_rotated, real_shape, reflectance)
            uv = uv.t().int()
            depth_img = torch.zeros(real_shape[:2],
                                    device='cuda',
                                    dtype=torch.float)
            depth_img += 1000.
            depth_img = visibility.depth_image(uv, depth, depth_img,
                                               uv.shape[0], real_shape[1],
                                               real_shape[0])
            depth_img[depth_img == 1000.] = 0.
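            # visibility2 masks points that should be occluded by closer
            # surfaces, using the occlusion threshold and kernel stored in the
            # checkpoint config.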

            projected_points = torch.zeros_like(depth_img, device='cuda')
            projected_points = visibility.visibility2(
                depth_img, cam_params, projected_points, depth_img.shape[1],
                depth_img.shape[0], _config['occlusion_threshold'],
                _config['occlusion_kernel'])

            if _config['use_reflectance']:
                uv = uv.long()
                indexes = projected_points[uv[:, 1], uv[:, 0]] == depth
                refl_img = torch.zeros(real_shape[:2],
                                       device='cuda',
                                       dtype=torch.float)
                refl_img[uv[indexes, 1], uv[indexes, 0]] = refl[0, indexes]

            projected_points /= 100.
            if not _config['use_reflectance']:
                projected_points = projected_points.unsqueeze(0)
            else:
                projected_points = torch.stack((projected_points, refl_img))

            rgb = sample['rgb'][idx].cuda()

            shape_pad[3] = (img_shape[0] - rgb.shape[1])
            shape_pad[1] = (img_shape[1] - rgb.shape[2])
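            # F.pad takes (pad_left, pad_right, pad_top, pad_bottom) for the
            # last two dims, so images are padded only on the right and bottom,
            # up to the fixed network input size img_shape.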

            rgb = F.pad(rgb, shape_pad)
            projected_points = F.pad(projected_points, shape_pad)

            rgb_input.append(rgb)
            lidar_input.append(projected_points)

        lidar_input = torch.stack(lidar_input)
        rgb_input = torch.stack(rgb_input)
        if show:
            out0 = overlay_imgs(rgb, lidar_input)

            cv2.imshow("INPUT", out0[:, :, [2, 1, 0]])
            cv2.waitKey(1)

            pc_GT = sample['point_cloud'][idx].clone()

            uv, depth, _, refl = cam_model.project_pytorch(pc_GT, real_shape)
            uv = uv.t().int()
            depth_img = torch.zeros(real_shape[:2],
                                    device='cuda',
                                    dtype=torch.float)
            depth_img += 1000.
            depth_img = visibility.depth_image(uv, depth, depth_img,
                                               uv.shape[0], real_shape[1],
                                               real_shape[0])
            depth_img[depth_img == 1000.] = 0.

            projected_points = torch.zeros_like(depth_img, device='cuda')
            projected_points = visibility.visibility2(
                depth_img, cam_params, projected_points, depth_img.shape[1],
                depth_img.shape[0], _config['occlusion_threshold'],
                _config['occlusion_kernel'])
            projected_points /= 100.

            projected_points = F.pad(projected_points, shape_pad)

            lidar_GT = projected_points.unsqueeze(0).unsqueeze(0)
            out1 = overlay_imgs(rgb_input[0], lidar_GT)
            cv2.imshow("GT", out1[:, :, [2, 1, 0]])
            # plt.figure()
            # plt.imshow(out1)
            # if batch_idx == 0:
            #     # import ipdb; ipdb.set_trace()
            #     out2 = overlay_imgs(sample['rgb'][0], lidar_input[:,:,:,1241])
            #     plt.figure()
            #     plt.imshow(out2)
            #     io.imshow(lidar_input[0][0].cpu().numpy(), cmap='jet')
            #     io.show()
        rgb = rgb_input.to(device)
        lidar = lidar_input.to(device)
        target_transl = sample['tr_error'].to(device)
        target_rot = sample['rot_error'].to(device)

        point_cloud = sample['point_cloud'][0].to(device)
        reflectance = None
        if _config['use_reflectance']:
            reflectance = sample['reflectance'][0].to(device)
        camera_model = cam_model

        R = quat2mat(target_rot[0])
        T = tvector2mat(target_transl[0])
        RT1_inv = torch.mm(T, R)
        RT1 = RT1_inv.clone().inverse()
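        # quat2mat/tvector2mat lift the sampled error quaternion and translation
        # to 4x4 homogeneous transforms (torch.mm(T, R): rotate first, then
        # translate). RT1, its inverse, is the initial pose error that the
        # network cascade should drive back to the identity.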

        rotated_point_cloud = rotate_forward(point_cloud, RT1)
        RTs = [RT1]

        T_composed = RT1[:3, 3]
        R_composed = quaternion_from_matrix(RT1)
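        # Iteration-0 errors, i.e. the initial miscalibration: errors_t is the
        # L2 norm of the residual translation, errors_r the quaternion angular
        # distance from the identity [1., 0., 0., 0.], and errors_rpy the
        # residual roll/pitch/yaw converted to degrees.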
        errors_t[0].append(T_composed.norm().item())
        errors_t2[0].append(T_composed)
        errors_r[0].append(
            quaternion_distance(
                R_composed.unsqueeze(0),
                torch.tensor([1., 0., 0., 0.],
                             device=R_composed.device).unsqueeze(0),
                R_composed.device))
        # rpy_error = quaternion_to_tait_bryan(R_composed)
        rpy_error = mat2xyzrpy(RT1)[3:]

        rpy_error *= (180.0 / np.pi)
        errors_rpy[0].append(rpy_error)
        log_string += [
            str(errors_t[0][-1]),
            str(errors_r[0][-1]),
            str(errors_t2[0][-1][0].item()),
            str(errors_t2[0][-1][1].item()),
            str(errors_t2[0][-1][2].item()),
            str(errors_rpy[0][-1][0].item()),
            str(errors_rpy[0][-1][1].item()),
            str(errors_rpy[0][-1][2].item())
        ]

        if batch_idx == 0:
            print(f'Initial T_error: {errors_t[0]}')
            print(f'Initial R_error: {errors_r[0]}')
        start = 0

        # Run model
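        # Each of the len(weights) networks predicts a residual rototranslation
        # on top of the current pose; corrections are composed by matrix
        # multiplication into RTs[iteration + 1].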
        with torch.no_grad():
            for iteration in range(start, len(weights)):
                # Run the i-th network
                T_predicted, R_predicted = models[iteration](rgb, lidar)
                if _config['rot_transl_separated'] and iteration == 0:
                    T_predicted = torch.tensor([[0., 0., 0.]], device='cuda')
                if _config['rot_transl_separated'] and iteration == 1:
                    R_predicted = torch.tensor([[1., 0., 0., 0.]],
                                               device='cuda')

                # Project the points in the new pose predicted by the i-th network
                R_predicted = quat2mat(R_predicted[0])
                T_predicted = tvector2mat(T_predicted[0])
                RT_predicted = torch.mm(T_predicted, R_predicted)
                RTs.append(torch.mm(RTs[iteration], RT_predicted))

                rotated_point_cloud = rotate_forward(rotated_point_cloud,
                                                     RT_predicted)
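                # Re-project the cloud under the corrected pose so the next
                # (finer) network sees a partially aligned depth image.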

                uv2, depth2, _, refl = camera_model.project_pytorch(
                    rotated_point_cloud, real_shape, reflectance)
                uv2 = uv2.t().int()
                depth_img2 = torch.zeros(real_shape[:2], device=device)
                depth_img2 += 1000.
                depth_img2 = visibility.depth_image(uv2, depth2, depth_img2,
                                                    uv2.shape[0],
                                                    real_shape[1],
                                                    real_shape[0])
                depth_img2[depth_img2 == 1000.] = 0.

                out_cuda2 = torch.zeros_like(depth_img2, device=device)
                out_cuda2 = visibility.visibility2(
                    depth_img2, cam_params, out_cuda2, depth_img2.shape[1],
                    depth_img2.shape[0], _config['occlusion_threshold'],
                    _config['occlusion_kernel'])

                if _config['use_reflectance']:
                    # Use the re-projected uv2/depth2/out_cuda2 here (the
                    # original used the stale uv/depth/projected_points from
                    # the preprocessing loop).
                    uv2 = uv2.long()
                    indexes = out_cuda2[uv2[:, 1], uv2[:, 0]] == depth2
                    refl_img = torch.zeros(real_shape[:2],
                                           device='cuda',
                                           dtype=torch.float)
                    refl_img[uv2[indexes, 1], uv2[indexes, 0]] = refl[0, indexes]
                    refl_img = F.pad(refl_img, shape_pad)

                out_cuda2 = F.pad(out_cuda2, shape_pad)

                lidar = out_cuda2.clone()
                lidar /= 100.
                if not _config['use_reflectance']:
                    lidar = lidar.unsqueeze(0)
                else:
                    lidar = torch.stack((lidar, refl_img))
                lidar = lidar.unsqueeze(0)
                if show:
                    out3 = overlay_imgs(rgb[0], lidar, idx=batch_idx)
                    cv2.imshow(f'Iter_{iteration}', out3[:, :, [2, 1, 0]])
                    cv2.waitKey(1)
                    # if iter == 1:
                    # plt.figure()
                    # plt.imshow(out3)
                    # io.imshow(lidar.cpu().numpy()[0,0], cmap='jet')
                    # io.show()

                T_composed = RTs[iteration + 1][:3, 3]
                R_composed = quaternion_from_matrix(RTs[iteration + 1])
                errors_t[iteration + 1].append(T_composed.norm().item())
                errors_t2[iteration + 1].append(T_composed)
                errors_r[iteration + 1].append(
                    quaternion_distance(
                        R_composed.unsqueeze(0),
                        torch.tensor([1., 0., 0., 0.],
                                     device=R_composed.device).unsqueeze(0),
                        R_composed.device))

                # rpy_error = quaternion_to_tait_bryan(R_composed)
                rpy_error = mat2xyzrpy(RTs[iteration + 1])[3:]
                rpy_error *= (180.0 / np.pi)
                errors_rpy[iteration + 1].append(rpy_error)
                log_string += [
                    str(errors_t[iteration + 1][-1]),
                    str(errors_r[iteration + 1][-1]),
                    str(errors_t2[iteration + 1][-1][0].item()),
                    str(errors_t2[iteration + 1][-1][1].item()),
                    str(errors_t2[iteration + 1][-1][2].item()),
                    str(errors_rpy[iteration + 1][-1][0].item()),
                    str(errors_rpy[iteration + 1][-1][1].item()),
                    str(errors_rpy[iteration + 1][-1][2].item())
                ]

        all_RTs.append(RTs[-1])
        prev_RT = RTs[-1].inverse()
        prev_tr_error = prev_RT[:3, 3].unsqueeze(0)
        prev_rot_error = quaternion_from_matrix(prev_RT).unsqueeze(0)
        # Here prev_RT is how far the network output deviates from the GT

        if _config['save_log']:
            log_writer.writerow(log_string)

    if _config['save_log']:
        log_file.close()
    print("Iterative refinement: ")
    for i in range(len(weights) + 1):
        errors_r[i] = torch.tensor(errors_r[i]) * (180.0 / np.pi)
        errors_t[i] = torch.tensor(errors_t[i]) * 100
        print(
            f"Iteration {i}: \tMean Translation Error: {errors_t[i].mean():.4f} cm "
            f"     Mean Rotation Error: {errors_r[i].mean():.4f} °")
        print(
            f"Iteration {i}: \tMedian Translation Error: {errors_t[i].median():.4f} cm "
            f"     Median Rotation Error: {errors_r[i].median():.4f} °\n")
    print("-------------------------------------------------------")
    print("Timings:")
    for i in range(len(errors_t2)):
        errors_t2[i] = torch.stack(errors_t2[i])
        errors_rpy[i] = torch.stack(errors_rpy[i])
    plt.plot(errors_t2[-1][:, 0].cpu().numpy())
    plt.show()
    plt.plot(errors_t2[-1][:, 1].cpu().numpy())
    plt.show()
    plt.plot(errors_t2[-1][:, 2].cpu().numpy())
    plt.show()

    if _config["save_name"] is not None:
        torch.save(
            torch.stack(errors_t).cpu().numpy(),
            f'./results_for_paper/{_config["save_name"]}_errors_t')
        torch.save(
            torch.stack(errors_r).cpu().numpy(),
            f'./results_for_paper/{_config["save_name"]}_errors_r')
        torch.save(
            torch.stack(errors_t2).cpu().numpy(),
            f'./results_for_paper/{_config["save_name"]}_errors_t2')
        torch.save(
            torch.stack(errors_rpy).cpu().numpy(),
            f'./results_for_paper/{_config["save_name"]}_errors_rpy')

    print("End!")


# Code example #2

def main(_config, _run, seed):
    global EPOCH
    print(_config['loss'])

    if _config['test_sequence'] is None:
        raise TypeError('test_sequence cannot be None')
    else:
        _config['test_sequence'] = f"{_config['test_sequence']:02d}"
        print("Test Sequence: ", _config['test_sequence'])
        dataset_class = DatasetVisibilityKittiSingle
    occlusion_threshold = _config['occlusion_threshold']
    img_shape = (384, 1280)
    _config["savemodel"] = os.path.join(_config["savemodel"], _config['dataset'])

    maps_folder = 'local_maps'
    if _config['maps_folder'] is not None:
        maps_folder = _config['maps_folder']
    dataset = dataset_class(_config['data_folder'], max_r=_config['max_r'], max_t=_config['max_t'],
                            split='train', use_reflectance=_config['use_reflectance'], maps_folder=maps_folder,
                            test_sequence=_config['test_sequence'])
    dataset_val = dataset_class(_config['data_folder'], max_r=_config['max_r'], max_t=_config['max_t'],
                                split='test', use_reflectance=_config['use_reflectance'], maps_folder=maps_folder,
                                test_sequence=_config['test_sequence'])
    _config["savemodel"] = os.path.join(_config["savemodel"], _config['test_sequence'])
    if not os.path.exists(_config["savemodel"]):
        os.mkdir(_config["savemodel"])

    np.random.seed(seed)
    torch.random.manual_seed(seed)

    def init_fn(x): return _init_fn(x, seed)

    dataset_size = len(dataset)

    # Training and test set creation
    num_worker = _config['num_worker']
    batch_size = _config['batch_size']
    TrainImgLoader = torch.utils.data.DataLoader(dataset=dataset,
                                                 shuffle=True,
                                                 batch_size=batch_size,
                                                 num_workers=num_worker,
                                                 worker_init_fn=init_fn,
                                                 collate_fn=merge_inputs,
                                                 drop_last=False,
                                                 pin_memory=True)

    TestImgLoader = torch.utils.data.DataLoader(dataset=dataset_val,
                                                shuffle=False,
                                                batch_size=batch_size,
                                                num_workers=num_worker,
                                                worker_init_fn=init_fn,
                                                collate_fn=merge_inputs,
                                                drop_last=False,
                                                pin_memory=True)

    print(len(TrainImgLoader))
    print(len(TestImgLoader))

    if _config['loss'] == 'simple':
        loss_fn = ProposedLoss(_config['rescale_transl'], _config['rescale_rot'])
    elif _config['loss'] == 'geometric':
        loss_fn = GeometricLoss()
        loss_fn = loss_fn.to(device)
    elif _config['loss'] == 'points_distance':
        loss_fn = DistancePoints3D()
    elif _config['loss'] == 'L1':
        loss_fn = L1Loss(_config['rescale_transl'], _config['rescale_rot'])
    else:
        raise ValueError("Unknown Loss Function")

    #runs = datetime.now().strftime('%b%d_%H-%M-%S') + "/"
    #train_writer = SummaryWriter('./logs/' + runs)
    #ex.info["tensorflow"] = {}
    #ex.info["tensorflow"]["logdirs"] = ['./logs/' + runs]

    if _config['network'].startswith('PWC'):
        feat = 1
        md = 4
        split = _config['network'].split('_')
        for item in split[1:]:
            if item.startswith('f'):
                feat = int(item[-1])
            elif item.startswith('md'):
                md = int(item[2:])
        assert 0 < feat < 7, "Feature number from PWC has to be between 1 and 6"
        assert 0 < md, "md must be positive"
        model = CMRNet(img_shape, use_feat_from=feat, md=md,
                       use_reflectance=_config['use_reflectance'], dropout=_config['dropout'])
    else:
        raise TypeError("Network unknown")
    if _config['weights'] is not None:
        print(f"Loading weights from {_config['weights']}")
        checkpoint = torch.load(_config['weights'], map_location='cpu')
        saved_state_dict = checkpoint['state_dict']
        model.load_state_dict(saved_state_dict)
    model = model.to(device)

    print(dataset_size)
    print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    if _config['loss'] == 'geometric':
        parameters += list(loss_fn.parameters())
    if _config['optimizer'] == 'adam':
        optimizer = optim.Adam(parameters, lr=_config['BASE_LEARNING_RATE'], weight_decay=5e-6)
        # Probably this scheduler is not used
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 50, 70], gamma=0.5)
    else:
        optimizer = optim.SGD(parameters, lr=_config['BASE_LEARNING_RATE'], momentum=0.9,
                              weight_decay=5e-6, nesterov=True)

    starting_epoch = 0
    if _config['weights'] is not None and _config['resume']:
        checkpoint = torch.load(_config['weights'], map_location='cpu')
        opt_state_dict = checkpoint['optimizer']
        optimizer.load_state_dict(opt_state_dict)
        starting_epoch = checkpoint['epoch']

    # Allow mixed-precision if needed
    # model, optimizer = apex.amp.initialize(model, optimizer, opt_level=_config["precision"])

    start_full_time = time.time()
    BEST_VAL_LOSS = 10000.
    old_save_filename = None

    total_iter = 0
    for epoch in range(starting_epoch, _config['epochs'] + 1):
        EPOCH = epoch
        print('This is the %d-th epoch' % epoch)
        epoch_start_time = time.time()
        total_train_loss = 0
        local_loss = 0.
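        # Non-Adam branch: exponential LR decay, lr = BASE * exp((1 - epoch) * 4e-2),
        # i.e. roughly a 4% multiplicative decay (~0.96x) per epoch.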
        if _config['optimizer'] != 'adam':
            _run.log_scalar("LR", _config['BASE_LEARNING_RATE'] *
                            math.exp((1 - epoch) * 4e-2), epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = _config['BASE_LEARNING_RATE'] * \
                                    math.exp((1 - epoch) * 4e-2)
        else:
            #scheduler.step(epoch%100)
            _run.log_scalar("LR", scheduler.get_lr()[0])


        ## Training ##
        time_for_50ep = time.time()
        for batch_idx, sample in enumerate(TrainImgLoader):

            #print(f'batch {batch_idx+1}/{len(TrainImgLoader)}', end='\r')
            start_time = time.time()
            lidar_input = []
            rgb_input = []

            sample['tr_error'] = sample['tr_error'].cuda()
            sample['rot_error'] = sample['rot_error'].cuda()

            start_preprocess = time.time()
            for idx in range(len(sample['rgb'])):
                # ProjectPointCloud in RT-pose

                real_shape = [sample['rgb'][idx].shape[1], sample['rgb'][idx].shape[2], sample['rgb'][idx].shape[0]]

                sample['point_cloud'][idx] = sample['point_cloud'][idx].cuda()
                pc_rotated = sample['point_cloud'][idx].clone()
                reflectance = None
                if _config['use_reflectance']:
                    reflectance = sample['reflectance'][idx].cuda()

                R = mathutils.Quaternion(sample['rot_error'][idx]).to_matrix()
                R.resize_4x4()
                T = mathutils.Matrix.Translation(sample['tr_error'][idx])
                RT = T * R
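                # Note: `T * R` is the pre-2.80 mathutils operator for matrix
                # multiplication; Blender >= 2.80 uses `T @ R` instead.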

                pc_rotated = rotate_back(pc_rotated, RT)

                if _config['max_depth'] < 100.:
                    pc_rotated = pc_rotated[:, pc_rotated[0, :] < _config['max_depth']].clone()
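                    # Points beyond max_depth along the first coordinate
                    # (presumably the forward axis) are dropped to bound the
                    # projected depth range.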

                cam_params = sample['calib'][idx].cuda()
                cam_model = CameraModel()
                cam_model.focal_length = cam_params[:2]
                cam_model.principal_point = cam_params[2:]
                uv, depth, _, refl = cam_model.project_pytorch(pc_rotated, real_shape, reflectance)
                uv = uv.t().int()
                depth_img = torch.zeros(real_shape[:2], device='cuda', dtype=torch.float)
                depth_img += 1000.
                depth_img = visibility.depth_image(uv, depth, depth_img, uv.shape[0], real_shape[1], real_shape[0])
                depth_img[depth_img == 1000.] = 0.

                depth_img_no_occlusion = torch.zeros_like(depth_img, device='cuda')
                depth_img_no_occlusion = visibility.visibility2(depth_img, cam_params, depth_img_no_occlusion,
                                                                depth_img.shape[1], depth_img.shape[0],
                                                                occlusion_threshold, _config['occlusion_kernel'])

                if _config['use_reflectance']:
                    # This needs to be checked
                    uv = uv.long()
                    indexes = depth_img_no_occlusion[uv[:,1], uv[:,0]] == depth
                    refl_img = torch.zeros(real_shape[:2], device='cuda', dtype=torch.float)
                    refl_img[uv[indexes,1], uv[indexes,0]] = refl[0, indexes]

                depth_img_no_occlusion /= _config['max_depth']
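                # Depth is normalized to [0, 1] so the network input range does
                # not depend on the chosen max_depth.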
                if not _config['use_reflectance']:
                    depth_img_no_occlusion = depth_img_no_occlusion.unsqueeze(0)
                else:
                    depth_img_no_occlusion = torch.stack((depth_img_no_occlusion, refl_img))

                # PAD ONLY ON RIGHT AND BOTTOM SIDE
                rgb = sample['rgb'][idx].cuda()
                shape_pad = [0, 0, 0, 0]

                shape_pad[3] = (img_shape[0] - rgb.shape[1])  # // 2
                shape_pad[1] = (img_shape[1] - rgb.shape[2])  # // 2 + 1

                rgb = F.pad(rgb, shape_pad)
                depth_img_no_occlusion = F.pad(depth_img_no_occlusion, shape_pad)

                rgb_input.append(rgb)
                lidar_input.append(depth_img_no_occlusion)

            lidar_input = torch.stack(lidar_input)
            rgb_input = torch.stack(rgb_input)
            end_preprocess = time.time()
            loss = train(model, optimizer, rgb_input, lidar_input, sample['tr_error'],
                         sample['rot_error'], loss_fn, sample['point_cloud'])

            if loss != loss:
                raise ValueError("Loss is NaN")

            #train_writer.add_scalar("Loss", loss, total_iter)
            local_loss += loss

            if batch_idx % 50 == 0 and batch_idx != 0:

                print(f'Iter {batch_idx}/{len(TrainImgLoader)} training loss = {local_loss/50:.3f}, '
                      f'time = {(time.time() - start_time)/lidar_input.shape[0]:.4f}, '
                      #f'time_preprocess = {(end_preprocess-start_preprocess)/lidar_input.shape[0]:.4f}, '
                      f'time for 50 iter: {time.time()-time_for_50ep:.4f}')
                time_for_50ep = time.time()
                _run.log_scalar("Loss", local_loss/50, total_iter)
                local_loss = 0.
            total_train_loss += loss * len(sample['rgb'])
            total_iter += len(sample['rgb'])

        print("------------------------------------")
        print('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(dataset)))
        print('Total epoch time = %.2f' % (time.time() - epoch_start_time))
        print("------------------------------------")
        _run.log_scalar("Total training loss", total_train_loss / len(dataset), epoch)

        ## Test ##
        total_test_loss = 0.
        total_test_t = 0.
        total_test_r = 0.

        local_loss = 0.0
        for batch_idx, sample in enumerate(TestImgLoader):
            # print(f'batch {batch_idx + 1}/{len(TestImgLoader)}', end='\r')
            start_time = time.time()
            lidar_input = []
            rgb_input = []

            #sample['rgb'] = sample['rgb'].cuda()
            sample['tr_error'] = sample['tr_error'].cuda()
            sample['rot_error'] = sample['rot_error'].cuda()

            for idx in range(len(sample['rgb'])):
                # ProjectPointCloud in RT-pose
                real_shape = [sample['rgb'][idx].shape[1], sample['rgb'][idx].shape[2], sample['rgb'][idx].shape[0]]

                sample['point_cloud'][idx] = sample['point_cloud'][idx].cuda()
                pc_rotated = sample['point_cloud'][idx].clone()
                reflectance = None
                if _config['use_reflectance']:
                    reflectance = sample['reflectance'][idx].cuda()

                R = mathutils.Quaternion(sample['rot_error'][idx]).to_matrix()
                R.resize_4x4()
                T = mathutils.Matrix.Translation(sample['tr_error'][idx])
                RT = T * R

                pc_rotated = rotate_back(pc_rotated, RT)

                if _config['max_depth'] < 100.:
                    pc_rotated = pc_rotated[:, pc_rotated[0, :] < _config['max_depth']].clone()

                cam_params = sample['calib'][idx].cuda()
                cam_model = CameraModel()
                cam_model.focal_length = cam_params[:2]
                cam_model.principal_point = cam_params[2:]
                uv, depth, _, refl = cam_model.project_pytorch(pc_rotated, real_shape, reflectance)
                uv = uv.t().int()
                depth_img = torch.zeros(real_shape[:2], device='cuda', dtype=torch.float)
                depth_img += 1000.
                depth_img = visibility.depth_image(uv, depth, depth_img, uv.shape[0], real_shape[1], real_shape[0])
                depth_img[depth_img == 1000.] = 0.

                depth_img_no_occlusion = torch.zeros_like(depth_img, device='cuda')
                depth_img_no_occlusion = visibility.visibility2(depth_img,
                                                                cam_params, depth_img_no_occlusion,
                                                                depth_img.shape[1], depth_img.shape[0],
                                                                occlusion_threshold, _config['occlusion_kernel'])

                if _config['use_reflectance']:
                    uv = uv.long()
                    indexes = depth_img_no_occlusion[uv[:,1], uv[:,0]] == depth
                    refl_img = torch.zeros(real_shape[:2], device='cuda', dtype=torch.float)
                    refl_img[uv[indexes,1], uv[indexes,0]] = refl[0, indexes]

                depth_img_no_occlusion /= _config['max_depth']
                if not _config['use_reflectance']:
                    depth_img_no_occlusion = depth_img_no_occlusion.unsqueeze(0)
                else:
                    depth_img_no_occlusion = torch.stack((depth_img_no_occlusion, refl_img))

                rgb = sample['rgb'][idx].cuda()
                shape_pad = [0, 0, 0, 0]

                shape_pad[3] = (img_shape[0] - rgb.shape[1])
                shape_pad[1] = (img_shape[1] - rgb.shape[2])

                rgb = F.pad(rgb, shape_pad)
                depth_img_no_occlusion = F.pad(depth_img_no_occlusion, shape_pad)

                rgb_input.append(rgb)
                lidar_input.append(depth_img_no_occlusion)

            lidar_input = torch.stack(lidar_input)
            rgb_input = torch.stack(rgb_input)

            loss, trasl_e, rot_e = test(model, rgb_input, lidar_input, sample['tr_error'],
                                        sample['rot_error'], loss_fn, dataset_val.model, sample['point_cloud'])

            if loss != loss:
                raise ValueError("Loss is NaN")

            total_test_t += trasl_e
            total_test_r += rot_e
            local_loss += loss

            if batch_idx % 50 == 0 and batch_idx != 0:
                print('Iter %d test loss = %.3f , time = %.2f' % (batch_idx, local_loss/50.,
                                                                  (time.time() - start_time)/lidar_input.shape[0]))
                local_loss = 0.0
            total_test_loss += loss * len(sample['rgb'])

        print("------------------------------------")
        print('total test loss = %.3f' % (total_test_loss / len(dataset_val)))
        print(f'total translation error: {total_test_t / len(dataset_val)} cm')
        print(f'total rotation error: {total_test_r / len(dataset_val)} °')
        print("------------------------------------")

        #train_writer.add_scalar("Val_Loss", total_test_loss / len(dataset_val), epoch)
        #train_writer.add_scalar("Val_t_error", total_test_t / len(dataset_val), epoch)
        #train_writer.add_scalar("Val_r_error", total_test_r / len(dataset_val), epoch)
        _run.log_scalar("Val_Loss", total_test_loss / len(dataset_val), epoch)
        _run.log_scalar("Val_t_error", total_test_t / len(dataset_val), epoch)
        _run.log_scalar("Val_r_error", total_test_r / len(dataset_val), epoch)

        # SAVE
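        # Keep only the best checkpoint: save when the validation loss improves
        # and delete the previously saved file.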
        val_loss = total_test_loss / len(dataset_val)
        if val_loss < BEST_VAL_LOSS:
            BEST_VAL_LOSS = val_loss
            #_run.result = BEST_VAL_LOSS
            if _config['rescale_transl'] > 0:
                _run.result = total_test_t / len(dataset_val)
            else:
                _run.result = total_test_r / len(dataset_val)
            savefilename = f'{_config["savemodel"]}/checkpoint_r{_config["max_r"]:.2f}_t{_config["max_t"]:.2f}_e{epoch}_{val_loss:.3f}.tar'
            torch.save({
                'config': _config,
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_loss': total_train_loss / len(dataset),
                'test_loss': total_test_loss / len(dataset_val),
            }, savefilename)
            print(f'Model saved as {savefilename}')
            if old_save_filename is not None:
                if os.path.exists(old_save_filename):
                    os.remove(old_save_filename)
            old_save_filename = savefilename

    print('full training time = %.2f hours' % ((time.time() - start_full_time) / 3600))
    return _run.result