Exemple #1
0
def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
    """Confidence-weighted dense pose loss plus the inputs for a refinement step.

    Every one of the ``num_p`` predictions carries a quaternion, a translation
    offset (added to its entry of ``points``) and a confidence. The model
    points are transformed by each predicted pose and compared with
    ``target``; the mean point distance of each prediction is weighted by its
    confidence and regularised by ``-w * log(confidence)``.

    Args:
        pred_r: (bs, num_p, 4) quaternions; index 0 is used as the scalar part
            by the conversion below. Normalised inside this function.
        pred_t: translation offsets, reshaped to (bs * num_p, 1, 3).
        pred_c: (bs, num_p, 1) confidences; must be positive for ``log``.
        target: target point cloud, reshaped to (bs, 1, num_point_mesh, 3).
        model_points: model point cloud, same shape as ``target``.
        idx: tensor whose first element is the object index checked against
            ``sym_list``.
        points: input points, reshaped to (bs * num_p, 1, 3).
        w: weight of the confidence regulariser.
        refine: when truthy, the nearest-neighbour matching for symmetric
            objects is skipped.
        num_point_mesh: number of points in ``model_points`` / ``target``.
        sym_list: object indices that use nearest-neighbour matching.

    Returns:
        Tuple ``(loss, dis, new_points, new_target)``: the scalar loss, the
        mean distance of the most confident prediction of batch element 0,
        and ``points`` / ``target`` (of batch element 0) expressed in the
        frame of that prediction, both detached.
    """
    bs, num_p, _ = pred_c.size()

    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))

    # Quaternion -> rotation matrix, one row-major 3x3 matrix per prediction.
    q0, q1, q2, q3 = pred_r[:, :, 0], pred_r[:, :, 1], pred_r[:, :, 2], pred_r[:, :, 3]
    base = torch.stack((
        1.0 - 2.0 * (q2 ** 2 + q3 ** 2),
        2.0 * q1 * q2 - 2.0 * q0 * q3,
        2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q1 * q2 + 2.0 * q3 * q0,
        1.0 - 2.0 * (q1 ** 2 + q3 ** 2),
        -2.0 * q0 * q1 + 2.0 * q2 * q3,
        -2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q0 * q1 + 2.0 * q2 * q3,
        1.0 - 2.0 * (q1 ** 2 + q2 ** 2),
    ), dim=2).view(bs * num_p, 3, 3)

    ori_base = base
    # Points are row vectors, so they are multiplied by the transposed matrix.
    base = base.transpose(2, 1).contiguous()
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    target = target.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    ori_target = target
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    ori_t = pred_t
    points = points.contiguous().view(bs * num_p, 1, 3)
    pred_c = pred_c.contiguous().view(bs * num_p)

    pred = torch.add(torch.bmm(model_points, base), points + pred_t)

    if not refine and idx[0].item() in sym_list:
        # The KNN helper is only needed here, so it is only built here.
        knn = KNearestNeighbor(1)
        target = target[0].transpose(1, 0).contiguous().view(3, -1)
        pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        # "- 1": the indices returned by knn are treated as 1-based.
        target = torch.index_select(target, 1, inds.view(-1).detach() - 1)
        target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
        pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)
    loss = torch.mean((dis * pred_c - w * torch.log(pred_c)), dim=0)

    pred_c = pred_c.view(bs, num_p)
    _, which_max = torch.max(pred_c, 1)
    dis = dis.view(bs, num_p)
    best = which_max[0]  # most confident prediction of batch element 0

    t = ori_t[best] + points[best]
    points = points.view(1, bs * num_p, 3)

    # Express the input points and the target in the frame of the best pose.
    ori_base = ori_base[best].view(1, 3, 3).contiguous()
    ori_t = t.repeat(bs * num_p, 1).contiguous().view(1, bs * num_p, 3)
    new_points = torch.bmm((points - ori_t), ori_base).contiguous()

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    ori_t = t.repeat(num_point_mesh, 1).contiguous().view(1, num_point_mesh, 3)
    new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()

    return loss, dis[0][best], new_points.detach(), new_target.detach()
Exemple #2
0
def loss_conf(pred_loss, pred_r, pred_t, pred_c, target, model_points, idx,
              points, w, refine, num_point_mesh, sym_list):
    """Regress a predicted per-point loss towards the actual pose distance.

    The model points are transformed by every predicted pose and the mean
    point distance to ``target`` is computed per prediction. The returned
    loss is the mean squared difference between ``pred_loss`` and that
    distance.

    Args:
        pred_loss: predicted distances, reshaped to (bs, num_p).
        pred_r: (bs, num_p, 4) quaternions; index 0 is used as the scalar part
            by the conversion below. Normalised inside this function.
        pred_t: translation offsets, reshaped to (bs * num_p, 1, 3).
        pred_c: only its size ``(bs, num_p, _)`` is read.
        target: target point cloud, reshaped to (bs, 1, num_point_mesh, 3).
        model_points: model point cloud, same shape as ``target``.
        idx: tensor whose first element is the object index checked against
            ``sym_list``.
        points: input points, reshaped to (bs * num_p, 1, 3).
        w: unused; kept for signature compatibility.
        refine: when truthy, the nearest-neighbour matching for symmetric
            objects is skipped.
        num_point_mesh: number of points in ``model_points`` / ``target``.
        sym_list: object indices that use nearest-neighbour matching.

    Returns:
        Tuple ``(loss, dis)`` where ``dis`` is the actual distance of the
        prediction of batch element 0 with the smallest ``pred_loss``.
    """
    bs, num_p, _ = pred_c.size()

    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))

    # Quaternion -> rotation matrix, one row-major 3x3 matrix per prediction.
    q0, q1, q2, q3 = pred_r[:, :, 0], pred_r[:, :, 1], pred_r[:, :, 2], pred_r[:, :, 3]
    base = torch.stack((
        1.0 - 2.0 * (q2 ** 2 + q3 ** 2),
        2.0 * q1 * q2 - 2.0 * q0 * q3,
        2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q1 * q2 + 2.0 * q3 * q0,
        1.0 - 2.0 * (q1 ** 2 + q3 ** 2),
        -2.0 * q0 * q1 + 2.0 * q2 * q3,
        -2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q0 * q1 + 2.0 * q2 * q3,
        1.0 - 2.0 * (q1 ** 2 + q2 ** 2),
    ), dim=2).view(bs * num_p, 3, 3)

    # Points are row vectors, so they are multiplied by the transposed matrix.
    base = base.transpose(2, 1).contiguous()
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(
        1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    target = target.view(bs, 1, num_point_mesh, 3).repeat(
        1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    points = points.contiguous().view(bs * num_p, 1, 3)

    pred = torch.add(torch.bmm(model_points, base), points + pred_t)

    if not refine and idx[0].item() in sym_list:
        # The KNN helper is only needed here, so it is only built here.
        knn = KNearestNeighbor(1)
        target = target[0].transpose(1, 0).contiguous().view(3, -1)
        pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        # "- 1": the indices returned by knn are treated as 1-based.
        target = torch.index_select(target, 1, inds.view(-1).detach() - 1)
        target = target.view(3, bs * num_p,
                             num_point_mesh).permute(1, 2, 0).contiguous()
        pred = pred.view(3, bs * num_p,
                         num_point_mesh).permute(1, 2, 0).contiguous()

    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1).view(bs, num_p)
    pred_loss = pred_loss.view(bs, num_p)

    loss = torch.mean((pred_loss - dis) ** 2)

    _, which_min = torch.min(pred_loss, 1)

    return loss, dis[0][which_min[0]]
Exemple #3
0
def loss_calculation(pred_r, pred_t, pred_c, target_r, target_t, model_points, idx, obj_diameter, rot_anchors, sym_list):
    """Combined rotation / translation / anchor-regularisation loss.

    Args:
        pred_r: (bs, num_rot, 4) quaternions; index 0 is used as the scalar
            part by the conversion below. They are NOT normalised here.
        pred_t: (bs, num_p, 3) predicted translations.
        pred_c: rotation confidences with ``bs * num_rot`` elements; must be
            positive for ``log``.
        target_r: target point cloud. The symmetric branch reads
            ``target_r[0]`` as (num_point_mesh, 3); otherwise it is broadcast
            against the (bs * num_rot, num_point_mesh, 3) rotated model.
        target_t: target translations, same shape as ``pred_t``.
        model_points: (bs, num_point_mesh, 3) model point cloud.
        idx: tensor whose first element is the object index checked against
            ``sym_list``.
        obj_diameter: divisor used to normalise the point distance.
        rot_anchors: NumPy array of shape (num_rot, 4); it is moved to the
            device of ``pred_r``.
        sym_list: object indices that use nearest-neighbour matching.

    Returns:
        Tuple ``(loss, loss_r, loss_t, loss_reg)`` with
        ``loss = loss_r + 2.0 * loss_reg + 5.0 * loss_t``.
    """
    bs, num_p, _ = pred_t.size()
    num_rot = pred_r.size()[1]
    num_point_mesh = model_points.size()[1]

    # Regularisation loss: each prediction should be closest (in dot product)
    # to its own anchor. Differences of at most 0.001 are zeroed.
    # Follow the device of the predictions instead of hard-coding CUDA.
    rot_anchors = torch.from_numpy(rot_anchors).float().to(pred_r.device)
    rot_anchors = rot_anchors.unsqueeze(0).repeat(bs, 1, 1).permute(0, 2, 1)
    cos_dist = torch.bmm(pred_r, rot_anchors)   # bs x num_rot x num_rot
    loss_reg = F.threshold((torch.max(cos_dist, 2)[0] - torch.diagonal(cos_dist, dim1=1, dim2=2)), 0.001, 0)
    loss_reg = torch.mean(loss_reg)

    # Rotation loss. Quaternion -> row-major 3x3 matrix per prediction.
    q0, q1, q2, q3 = pred_r[:, :, 0], pred_r[:, :, 1], pred_r[:, :, 2], pred_r[:, :, 3]
    rotations = torch.stack((
        1.0 - 2.0 * (q2 ** 2 + q3 ** 2),
        2.0 * q1 * q2 - 2.0 * q0 * q3,
        2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q1 * q2 + 2.0 * q3 * q0,
        1.0 - 2.0 * (q1 ** 2 + q3 ** 2),
        -2.0 * q0 * q1 + 2.0 * q2 * q3,
        -2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q0 * q1 + 2.0 * q2 * q3,
        1.0 - 2.0 * (q1 ** 2 + q2 ** 2),
    ), dim=2).view(bs * num_rot, 3, 3)
    # Points are row vectors, so they are multiplied by the transposed matrix.
    rotations = rotations.transpose(2, 1).contiguous()
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_rot, 1, 1).view(bs*num_rot, num_point_mesh, 3)
    pred_r = torch.bmm(model_points, rotations)
    if idx[0].item() in sym_list:
        # The KNN helper is only needed here, so it is only built here.
        knn = KNearestNeighbor(1)
        target_r = target_r[0].transpose(1, 0).contiguous().view(3, -1)
        pred_r = pred_r.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target_r.unsqueeze(0), pred_r.unsqueeze(0))
        # "- 1": the indices returned by knn are treated as 1-based.
        target_r = torch.index_select(target_r, 1, inds.view(-1).detach() - 1)
        target_r = target_r.view(3, bs*num_rot, num_point_mesh).permute(1, 2, 0).contiguous()
        pred_r = pred_r.view(3, bs*num_rot, num_point_mesh).permute(1, 2, 0).contiguous()
    dis = torch.mean(torch.norm((pred_r - target_r), dim=2), dim=1)
    dis = dis / obj_diameter   # normalize by diameter
    pred_c = pred_c.contiguous().view(bs*num_rot)
    loss_r = torch.mean(dis / pred_c + torch.log(pred_c), dim=0)

    # Translation loss.
    loss_t = F.smooth_l1_loss(pred_t, target_t, reduction='mean')

    # Total loss.
    loss = loss_r + 2.0 * loss_reg + 5.0 * loss_t
    return loss, loss_r, loss_t, loss_reg
Exemple #4
0
def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
    """Pose distance for a single refined pose, plus inputs for the next step.

    ``target`` is expected to be the "new target" produced by the previous
    step. A single quaternion / translation pair is applied to the model
    points and the mean point distance to ``target`` is returned.

    Args:
        pred_r: quaternion with 4 elements (reshaped to (1, 1, 4)); index 0 is
            used as the scalar part by the conversion below. Normalised here.
        pred_t: translation with 3 elements (reshaped to (1, 1, 3)). It is
            applied directly, without adding any input point.
        target: target point cloud, reshaped to (1, 1, num_point_mesh, 3).
        model_points: model point cloud, same shape as ``target``.
        idx: tensor whose first element is the object index checked against
            ``sym_list``.
        points: input points; ``len(points[0])`` points of 3 coordinates.
        num_point_mesh: number of points in ``model_points`` / ``target``.
        sym_list: object indices that use nearest-neighbour matching.

    Returns:
        Tuple ``(dis, new_points, new_target)``: the mean distance (one
        element) and ``points`` / ``target`` expressed in the frame of the
        predicted pose, both detached.
    """
    pred_r = pred_r.view(1, 1, -1)
    pred_t = pred_t.view(1, 1, -1)
    bs, num_p, _ = pred_r.size()
    num_input_points = len(points[0])

    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))

    # Quaternion -> rotation matrix, row-major 3x3.
    q0, q1, q2, q3 = pred_r[:, :, 0], pred_r[:, :, 1], pred_r[:, :, 2], pred_r[:, :, 3]
    base = torch.stack((
        1.0 - 2.0 * (q2 ** 2 + q3 ** 2),
        2.0 * q1 * q2 - 2.0 * q0 * q3,
        2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q1 * q2 + 2.0 * q3 * q0,
        1.0 - 2.0 * (q1 ** 2 + q3 ** 2),
        -2.0 * q0 * q1 + 2.0 * q2 * q3,
        -2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q0 * q1 + 2.0 * q2 * q3,
        1.0 - 2.0 * (q1 ** 2 + q2 ** 2),
    ), dim=2).view(bs * num_p, 3, 3)

    ori_base = base
    # Points are row vectors, so they are multiplied by the transposed matrix.
    base = base.transpose(2, 1).contiguous()
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    target = target.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    ori_target = target
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    ori_t = pred_t
    # Residual pose: the translation is applied as-is (no input point added).
    pred = torch.add(torch.bmm(model_points, base), pred_t)

    if idx[0].item() in sym_list:
        # The KNN helper is only needed here, so it is only built here.
        knn = KNearestNeighbor(1)
        target = target[0].transpose(1, 0).contiguous().view(3, -1)
        pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        # "- 1": the indices returned by knn are treated as 1-based.
        target = torch.index_select(target, 1, inds.view(-1).detach() - 1)
        target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
        pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)

    t = ori_t[0]
    points = points.view(1, num_input_points, 3)

    # Express the input points and the target in the frame of this pose.
    ori_base = ori_base[0].view(1, 3, 3).contiguous()
    ori_t = t.repeat(bs * num_input_points, 1).contiguous().view(1, bs * num_input_points, 3)
    new_points = torch.bmm((points - ori_t), ori_base).contiguous()

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    ori_t = t.repeat(num_point_mesh, 1).contiguous().view(1, num_point_mesh, 3)
    new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()

    return dis, new_points.detach(), new_target.detach()
Exemple #5
0
# Command-line option holding the PoseRefineNet checkpoint path (empty by
# default). `parser` is created before this excerpt.
parser.add_argument('--refine_model',
                    type=str,
                    default='',
                    help='resume PoseRefineNet model')
opt = parser.parse_args()
import open3d as o3d
import cv2

# Evaluation constants.
num_objects = 13  # passed as num_obj to both networks below
objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]  # presumably LineMOD object ids -- TODO confirm
num_points = 500  # passed as num_points to both networks and to the dataset
iteration = 4  # presumably the number of refinement passes; not used in this excerpt
bs = 1
dataset_config_dir = 'datasets/linemod/dataset_config'
output_result_dir = 'experiments/eval_result/linemod'
knn = KNearestNeighbor(1)

# Build both networks on the GPU, restore their weights from the checkpoint
# paths given on the command line, and switch them to evaluation mode.
estimator = PoseNet(num_points=num_points, num_obj=num_objects)
estimator.cuda()
refiner = PoseRefineNet(num_points=num_points, num_obj=num_objects)
refiner.cuda()
estimator.load_state_dict(torch.load(opt.model))
refiner.load_state_dict(torch.load(opt.refine_model))
estimator.eval()
refiner.eval()

# NOTE(review): the meaning of the positional arguments (False, 0.0, True) is
# defined by PoseDataset_linemod, which is not visible here -- confirm there.
testdataset = PoseDataset_linemod('test', num_points, False, opt.dataset_root,
                                  0.0, True)
testdataloader = torch.utils.data.DataLoader(testdataset,
                                             batch_size=1,
                                             shuffle=False,
def loss_calculation(pred_r, pred_t, pred_c, dis_vector_last, target,
                     model_points, idx, points, w, refine, num_point_mesh,
                     sym_list, stable_alpha):
    """Dense pose loss with an extra temporal-stability penalty.

    Unlike the quaternion-based variants, ``pred_r`` is used directly as the
    per-prediction 3x3 matrix that right-multiplies the model points.

    Args:
        pred_r: (bs * num_p, 3, 3) matrices applied as ``model_points @ pred_r``.
        pred_t: translation offsets, reshaped to (bs * num_p, 1, 3).
        pred_c: (bs, num_p, 1) confidences; must be positive for ``log``.
        dis_vector_last: mean error vector returned by the previous call, or
            ``None`` (then the stability penalty is zero).
        target: target point cloud, reshaped to (bs, 1, num_point_mesh, 3).
        model_points: model point cloud, same shape as ``target``.
        idx: tensor whose first element is the object index checked against
            ``sym_list``.
        points: input points, reshaped to (bs * num_p, 1, 3).
        w: weight of the confidence regulariser.
        refine: when truthy, the nearest-neighbour matching for symmetric
            objects is skipped.
        num_point_mesh: number of points in ``model_points`` / ``target``.
        sym_list: object indices that use nearest-neighbour matching.
        stable_alpha: weight of the stability penalty.

    Returns:
        Tuple ``(loss_stable, dis, new_points, new_target, dis_vector)``: the
        penalised loss, the mean distance of the most confident prediction of
        batch element 0, ``points`` / ``target`` expressed in the frame of
        that prediction (detached), and the mean error vector of that
        prediction (to be fed back as ``dis_vector_last``).
    """
    bs, num_p, _ = pred_c.size()
    base = pred_r.contiguous()
    ori_base = pred_r.transpose(1, 2)

    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(
        1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    target = target.view(bs, 1, num_point_mesh, 3).repeat(
        1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    ori_target = target
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    ori_t = pred_t
    points = points.contiguous().view(bs * num_p, 1, 3)
    pred_c = pred_c.contiguous().view(bs * num_p)

    pred = torch.add(torch.bmm(model_points, base), points + pred_t)

    if not refine and idx[0].item() in sym_list:
        # The KNN helper is only needed here, so it is only built here.
        knn = KNearestNeighbor(1)
        target = target[0].transpose(1, 0).contiguous().view(3, -1)
        pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        # "- 1": the indices returned by knn are treated as 1-based.
        target = torch.index_select(target, 1, inds.view(-1).detach() - 1)
        target = target.view(3, bs * num_p,
                             num_point_mesh).permute(1, 2, 0).contiguous()
        pred = pred.view(3, bs * num_p,
                         num_point_mesh).permute(1, 2, 0).contiguous()

    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)
    loss = torch.mean((dis * pred_c - w * torch.log(pred_c)), dim=0)

    pred_c = pred_c.view(bs, num_p)
    _, which_max = torch.max(pred_c, 1)
    dis = dis.view(bs, num_p)
    best = which_max[0]  # most confident prediction of batch element 0

    # Stability term: penalise a change of the mean error vector between calls.
    dis_vector = torch.mean((pred - target)[best], dim=0)
    if dis_vector_last is None:
        dis_vector_last = dis_vector
    loss_stable = loss + stable_alpha * torch.norm(
        dis_vector - dis_vector_last, dim=0)

    t = ori_t[best] + points[best]
    points = points.view(1, bs * num_p, 3)

    # Express the input points and the target in the frame of the best pose.
    ori_base = ori_base[best].view(1, 3, 3).contiguous()
    ori_t = t.repeat(bs * num_p, 1).contiguous().view(1, bs * num_p, 3)
    new_points = torch.bmm((points - ori_t), ori_base).contiguous()

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    ori_t = t.repeat(num_point_mesh, 1).contiguous().view(1, num_point_mesh, 3)
    new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()

    return (loss_stable, dis[0][best], new_points.detach(),
            new_target.detach(), dis_vector)
Exemple #7
0
def main():
    """Run PoseNet + PoseRefineNet on LineMOD test samples and save pose overlays.

    For each evaluated sample the estimated pose is refined
    ``opt.iteration`` times, the mean model-point distance to the target is
    logged to ``t1_eval_result_logs.txt``, and two PNG files are written to
    ``verify_img``: one with the predicted pose and one with the ground-truth
    pose drawn over the RGB image.

    The checkpoints are hard-coded:
        pose model:   trained_checkpoints/linemod/pose_model_9_0.01310166542980859.pth
        refine model: trained_checkpoints/linemod/pose_refine_model_493_0.006761023565178073.pth

    Relies on the module-level ``opt`` namespace (reads ``opt.workers`` and
    ``opt.w``; overwrites several other fields) and requires CUDA.
    """
    # g13: parameter setting -------------------
    objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
    knn = KNearestNeighbor(1)
    opt.dataset = 'linemod'
    opt.dataset_root = './datasets/linemod/Linemod_preprocessed'
    estimator_path = 'trained_checkpoints/linemod/pose_model_9_0.01310166542980859.pth'
    refiner_path = 'trained_checkpoints/linemod/pose_refine_model_493_0.006761023565178073.pth'
    opt.model = estimator_path
    opt.refine_model = refiner_path
    dataset_config_dir = 'datasets/linemod/dataset_config'
    output_result_dir = 'experiments/eval_result/linemod'
    opt.refine_start = True
    bs = 1  # fixed because of the default setting in torch.utils.data.DataLoader
    opt.iteration = 2  # default is 4 in eval_linemod.py
    t1_start = True             # evaluate only the first t1_total_eval_num samples
    t1_total_eval_num = 3
    t2_start = False            # evaluate only the samples listed in t2_target_list
    t2_target_list = [22, 30, 172, 187, 267, 363, 410, 471, 472, 605, 644, 712, 1046, 1116, 1129, 1135, 1263]
    axis_range = 0.1   # the length of X, Y, and Z axis in 3D
    vimg_dir = 'verify_img'

    # Success threshold per object: 10% of the model diameter
    # (the / 1000.0 is presumably a mm -> m conversion -- TODO confirm).
    # safe_load: the file is plain data, and yaml.load() without an explicit
    # Loader is rejected by current PyYAML releases.
    with open('{0}/models_info.yml'.format(dataset_config_dir), 'r') as meta_file:
        meta_d = yaml.safe_load(meta_file)
    diameter = [meta_d[obj]['diameter'] / 1000.0 * 0.1 for obj in objlist]
    print(diameter)
    if not os.path.exists(vimg_dir):
        os.makedirs(vimg_dir)
    #-------------------------------------------

    if opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    estimator.load_state_dict(torch.load(estimator_path))
    refiner.load_state_dict(torch.load(refiner_path))
    opt.refine_start = True

    test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    opt.sym_list = test_dataset.get_sym_list()
    opt.num_points_mesh = test_dataset.get_num_points_mesh()

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\n\
        length of the testing set: {0}\nnumber of sample points on mesh: {1}\n\
        symmetry object list: {2}'\
        .format( len(test_dataset), opt.num_points_mesh, opt.sym_list))

    # load pytorch model
    estimator.eval()
    refiner.eval()
    criterion = Loss(opt.num_points_mesh, opt.sym_list)

    # Counters live outside the loop so they accumulate over all samples.
    success_count = [0] * opt.num_objects
    num_count = [0] * opt.num_objects

    # Pinhole intrinsic matrix; it is the same for every sample.
    cam_intrinsic = np.zeros((3, 3))
    cam_intrinsic[0, 0] = test_dataset.cam_fx
    cam_intrinsic[1, 1] = test_dataset.cam_fy
    cam_intrinsic[0, 2] = test_dataset.cam_cx
    cam_intrinsic[1, 2] = test_dataset.cam_cy
    cam_intrinsic[2, 2] = 1

    # Pose estimation
    with open('{0}/t1_eval_result_logs.txt'.format(output_result_dir), 'w') as fw:
        for j, data in enumerate(testdataloader, 0):
            # g13: modify this part for evaluation target--------------------
            if t1_start and j == t1_total_eval_num:
                break
            if t2_start and not (j in t2_target_list):
                continue
            #----------------------------------------------------------------
            points, choose, img, target, model_points, idx = data
            if len(points.size()) == 2:
                print('No.{0} NOT Pass! Lost detection!'.format(j))
                fw.write('No.{0} NOT Pass! Lost detection!\n'.format(j))
                continue
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            # NOTE(review): every value returned here is recomputed below before
            # being used; the call is kept to preserve the original execution.
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

            # Pick the most confident per-point prediction as the initial pose.
            pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, opt.num_points, 1)
            pred_c = pred_c.view(bs, opt.num_points)
            how_max, which_max = torch.max(pred_c, 1)
            pred_t = pred_t.view(bs * opt.num_points, 1, 3)

            my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
            my_t = (points.view(bs * opt.num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy()

            # Iterative refinement: each pass predicts a residual pose in the
            # frame of the current estimate and composes it with that estimate.
            for ite in range(0, opt.iteration):
                T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(opt.num_points, 1).contiguous().view(1, opt.num_points, 3)
                my_mat = quaternion_matrix(my_r)
                R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
                my_mat[0:3, 3] = my_t

                new_points = torch.bmm((points - T), R).contiguous()
                pred_r, pred_t = refiner(new_points, emb, idx)
                pred_r = pred_r.view(1, 1, -1)
                pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
                my_r_2 = pred_r.view(-1).cpu().data.numpy()
                my_t_2 = pred_t.view(-1).cpu().data.numpy()
                my_mat_2 = quaternion_matrix(my_r_2)
                my_mat_2[0:3, 3] = my_t_2

                my_mat_final = np.dot(my_mat, my_mat_2)
                my_r_final = copy.deepcopy(my_mat_final)
                my_r_final[0:3, 3] = 0
                my_r_final = quaternion_from_matrix(my_r_final, True)
                my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])

                my_r = my_r_final
                my_t = my_t_final
                # 'my_r' (quaternion) and 'my_t' (translation) hold the pose after this pass.

            # g13: checking the dis value
            model_points = model_points[0].cpu().detach().numpy()
            my_r = quaternion_matrix(my_r)[:3, :3]
            pred = np.dot(model_points, my_r.T) + my_t
            target = target[0].cpu().detach().numpy()

            if idx[0].item() in opt.sym_list:
                # Separate names: 'pred' must stay an (N, 3) NumPy array because
                # it is projected point by point further down.
                pred_sym = torch.from_numpy(pred.astype(np.float32)).cuda().transpose(1, 0).contiguous()
                target_sym = torch.from_numpy(target.astype(np.float32)).cuda().transpose(1, 0).contiguous()
                inds = knn(target_sym.unsqueeze(0), pred_sym.unsqueeze(0))
                target_sym = torch.index_select(target_sym, 1, inds.view(-1) - 1)
                dis = torch.mean(torch.norm((pred_sym.transpose(1, 0) - target_sym.transpose(1, 0)), dim=1), dim=0).item()
            else:
                dis = float(np.mean(np.linalg.norm(pred - target, axis=1)))

            if dis < diameter[idx[0].item()]:
                success_count[idx[0].item()] += 1
                print('No.{0} Pass! Distance: {1}'.format(j, dis))
                fw.write('No.{0} Pass! Distance: {1}\n'.format(j, dis))
            else:
                print('No.{0} NOT Pass! Distance: {1}'.format(j, dis))
                fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(j, dis))
            num_count[idx[0].item()] += 1

            # g13: start drawing pose on image------------------------------------
            # pick up image ('dis' is already a plain float in both branches)
            print('{0}:\nmy_r is {1}\nmy_t is {2}\ndis:{3}'.format(j, my_r, my_t, dis))
            print("index {0}: {1}".format(j, test_dataset.list_rgb[j]))
            img = Image.open(test_dataset.list_rgb[j])

            # pick up center position by bbox
            with open('{0}/data/{1}/gt.yml'.format(opt.dataset_root, '%02d' % test_dataset.list_obj[j]), 'r') as meta_file:
                meta = yaml.safe_load(meta_file)
            which_item = test_dataset.list_rank[j]
            which_obj = test_dataset.list_obj[j]
            # Find the ground-truth entry of this frame that belongs to the object.
            for k_idx, entry in enumerate(meta[which_item]):
                if entry['obj_id'] == which_obj:
                    break
            else:
                raise IndexError('object {0} not found in gt.yml entry {1}'.format(which_obj, which_item))

            bbx = meta[which_item][k_idx]['obj_bb']
            draw = ImageDraw.Draw(img)

            # draw box (ensure this is the right object)
            _draw_bbox(draw, bbx)

            # get center
            c_x = bbx[0]+int(bbx[2]/2)
            c_y = bbx[1]+int(bbx[3]/2)
            draw.point((c_x,c_y), fill=(255,255,0))
            print('center:({0},{1})'.format(c_x, c_y))

            # get the 3D position of center
            cam_extrinsic = my_mat_final[0:3, :]
            cam2d_3d = np.matmul(cam_intrinsic, cam_extrinsic)
            cen_3d = np.matmul(np.linalg.pinv(cam2d_3d), [[c_x],[c_y],[1]])

            # project the three axis end points and draw them
            _draw_axes(draw, c_x, c_y, cam2d_3d, cen_3d, axis_range)

            # g13: draw the estimate pred obj
            # NOTE(review): the projected coordinates are not divided by the
            # homogeneous (depth) component before drawing -- confirm intent.
            for pti in pred:
                pti_2d = np.matmul(cam_intrinsic, pti)
                draw.point([int(pti_2d[0]),int(pti_2d[1])], fill=(255,255,0))

            # save file under file
            img_file_name = '{0}/batch{1}_pred_obj{2}_pic{3}.png'.format(vimg_dir, j, test_dataset.list_obj[j], which_item)
            img.save( img_file_name, "PNG" )
            img.close()

            # plot ground true ----------------------------
            img = Image.open(test_dataset.list_rgb[j])
            draw = ImageDraw.Draw(img)
            _draw_bbox(draw, bbx)
            target_r = np.resize(np.array(meta[which_item][k_idx]['cam_R_m2c']), (3, 3))
            target_t = np.array(meta[which_item][k_idx]['cam_t_m2c'])
            target_t = target_t[np.newaxis, :]
            cam_extrinsic_GT = np.concatenate((target_r, target_t.T), axis=1)

            # get center 3D
            cam2d_3d_GT = np.matmul(cam_intrinsic, cam_extrinsic_GT)
            cen_3d_GT = np.matmul(np.linalg.pinv(cam2d_3d_GT), [[c_x],[c_y],[1]])

            # project the three axis end points and draw them
            _draw_axes(draw, c_x, c_y, cam2d_3d_GT, cen_3d_GT, axis_range)

            print('pred:\n{0}\nGT:\n{1}\n'.format(cam_extrinsic,cam_extrinsic_GT))
            print('pred 3D:{0}\nGT 3D:{1}\n'.format(cen_3d, cen_3d_GT))
            img_file_name = '{0}/batch{1}_pred_obj{2}_pic{3}_gt.png'.format(vimg_dir, j, test_dataset.list_obj[j], which_item)
            img.save( img_file_name, "PNG" )
            img.close()
    print('\nplot_result_img.py completed the task\n')


def _draw_bbox(draw, bbx):
    """Draw the red outline of the box ``bbx = (x, y, width, height)``."""
    draw.line((bbx[0],bbx[1], bbx[0], bbx[1]+bbx[3]), fill=(255,0,0), width=5)
    draw.line((bbx[0],bbx[1], bbx[0]+bbx[2], bbx[1]), fill=(255,0,0), width=5)
    draw.line((bbx[0],bbx[1]+bbx[3], bbx[0]+bbx[2], bbx[1]+bbx[3]), fill=(255,0,0), width=5)
    draw.line((bbx[0]+bbx[2],bbx[1], bbx[0]+bbx[2], bbx[1]+bbx[3]), fill=(255,0,0), width=5)


def _draw_axes(draw, c_x, c_y, cam2d_3d, cen_3d, axis_range):
    """Draw three lines from (c_x, c_y) to the projected X/Y/Z axis end points."""
    x_3d = cen_3d + [[axis_range],[0],[0],[0]]
    y_3d = cen_3d + [[0],[axis_range],[0],[0]]
    z_3d = cen_3d + [[0],[0],[axis_range],[0]]
    x_2d = np.matmul(cam2d_3d, x_3d)
    y_2d = np.matmul(cam2d_3d, y_3d)
    z_2d = np.matmul(cam2d_3d, z_3d)
    draw.line((c_x, c_y, x_2d[0], x_2d[1]), fill=(255,255,0), width=5)
    draw.line((c_x, c_y, y_2d[0], y_2d[1]), fill=(0,255,0), width=5)
    draw.line((c_x, c_y, z_2d[0], z_2d[1]), fill=(0,0,255), width=5)
Exemple #8
0
def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
                     w, refine, num_point_mesh, sym_list):
    """Compute the confidence-weighted point-distance loss for per-point pose predictions.

    Every one of the ``num_p`` predictions per batch item consists of a
    quaternion (``pred_r``), a translation offset (``pred_t``) and a
    confidence (``pred_c``). The model points are transformed with each
    predicted pose and compared with ``target``.

    Args:
        pred_r: Quaternions, indexed as ``[:, :, 0..3]``; normalised here.
        pred_t: Translation offsets, reshaped to ``(bs * num_p, 1, 3)``.
        pred_c: Confidences; its size is unpacked as ``(bs, num_p, _)``.
        target: Target points, reshaped to ``(bs, 1, num_point_mesh, 3)``.
        model_points: Model points, reshaped like ``target``.
        idx: Object index tensor; only ``idx[0]`` is read.
        points: Per-prediction points, reshaped to ``(bs * num_p, 1, 3)``.
            Presumably the observed RGB-D points in the camera frame --
            TODO confirm against the caller.
        w: Weight of the ``log(pred_c)`` regularisation term.
        refine: When truthy, the nearest-neighbour matching used for
            symmetric objects is skipped.
        num_point_mesh: Number of points per model / target cloud.
        sym_list: Collection of object indices treated as symmetric.

    Returns:
        A tuple ``(loss, dis, new_points, new_target)`` where ``loss`` is the
        mean of ``dis * pred_c - w * log(pred_c)`` over all predictions,
        ``dis`` is the distance of the most confident prediction of batch
        item 0, and ``new_points`` / ``new_target`` are the detached points
        and target with the most confident translation subtracted and then
        multiplied by the corresponding rotation matrix.
    """
    knn = KNearestNeighbor(1)
    bs, num_p, _ = pred_c.size()
    total = bs * num_p

    # Normalise the quaternions to unit length before building rotations.
    quat = pred_r / torch.norm(pred_r, dim=2).view(bs, num_p, 1)
    q0 = quat[:, :, 0]
    q1 = quat[:, :, 1]
    q2 = quat[:, :, 2]
    q3 = quat[:, :, 3]

    # Quaternion -> rotation matrix, entries listed in row-major order.
    rot_entries = (
        1.0 - 2.0 * (q2 ** 2 + q3 ** 2),
        2.0 * q1 * q2 - 2.0 * q0 * q3,
        2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q1 * q2 + 2.0 * q3 * q0,
        1.0 - 2.0 * (q1 ** 2 + q3 ** 2),
        -2.0 * q0 * q1 + 2.0 * q2 * q3,
        -2.0 * q0 * q2 + 2.0 * q1 * q3,
        2.0 * q0 * q1 + 2.0 * q2 * q3,
        1.0 - 2.0 * (q1 ** 2 + q2 ** 2),
    )
    rot = torch.stack(rot_entries, dim=2).view(total, 3, 3)
    # Points are row vectors, so they are right-multiplied by the transpose.
    rot_t = rot.transpose(2, 1).contiguous()

    # Replicate the model and target clouds so that each of the num_p
    # predictions gets its own copy (broadcasting might work too; untested).
    mesh = model_points.view(bs, 1, num_point_mesh, 3)
    mesh = mesh.repeat(1, num_p, 1, 1).view(total, num_point_mesh, 3)
    gt = target.view(bs, 1, num_point_mesh, 3)
    gt = gt.repeat(1, num_p, 1, 1).view(total, num_point_mesh, 3)

    trans = pred_t.contiguous().view(total, 1, 3)
    seeds = points.contiguous().view(total, 1, 3)
    conf = pred_c.contiguous().view(total)

    # Move the model points with each predicted pose; the translation is
    # the point itself plus the predicted offset.
    pred = torch.bmm(mesh, rot_t) + (seeds + trans)

    matched = gt
    if not refine and idx[0].item() in sym_list:
        # Symmetric object: pair every predicted point with its nearest
        # target point instead of relying on the fixed point ordering.
        flat_gt = gt[0].transpose(1, 0).contiguous().view(3, -1)
        flat_pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(flat_gt.unsqueeze(0), flat_pred.unsqueeze(0))
        # NOTE(review): the "- 1" suggests knn returns 1-based indices --
        # confirm against the KNearestNeighbor implementation.
        flat_gt = torch.index_select(flat_gt, 1, inds.view(-1).detach() - 1)
        matched = flat_gt.view(3, total, num_point_mesh)
        matched = matched.permute(1, 2, 0).contiguous()
        pred = flat_pred.view(3, total, num_point_mesh)
        pred = pred.permute(1, 2, 0).contiguous()

    # Mean point-to-point distance per prediction.
    dis = torch.mean(torch.norm(pred - matched, dim=2), dim=1)
    # Confidence-weighted distance with a log regulariser on the confidence.
    loss = torch.mean(dis * conf - w * torch.log(conf), dim=0)

    # Index of the most confident prediction of the first batch item.
    best = torch.max(conf.view(bs, num_p), 1)[1][0]
    dis = dis.view(bs, num_p)

    # Translation of the most confident prediction: offset + its point.
    t = trans[best] + seeds[best]
    best_rot = rot[best].view(1, 3, 3).contiguous()

    # Residual input points: subtract that translation, then multiply by
    # the selected rotation matrix.
    shift = t.repeat(total, 1).contiguous().view(1, total, 3)
    new_points = torch.bmm(seeds.view(1, total, 3) - shift,
                           best_rot).contiguous()

    # Residual target: every replicated row of the target is identical, so
    # row 0 is enough. The un-matched (original order) target is used.
    shift = t.repeat(num_point_mesh, 1).contiguous().view(
        1, num_point_mesh, 3)
    new_target = gt[0].view(1, num_point_mesh, 3).contiguous()
    new_target = torch.bmm(new_target - shift, best_rot).contiguous()

    del knn
    return loss, dis[0][best], new_points.detach(), new_target.detach()
Exemple #9
0
def main():
    """Train PoseNet on the dataset named by ``opt.dataset`` and test it after every epoch.

    Reads its whole configuration from the module-level ``opt`` namespace and
    writes several derived fields back to it (dataset paths, object/point
    counts, decay flags, learning rate, start epoch when resuming).

    Per epoch the function:
      * trains with gradient accumulation -- the loaders use ``batch_size=1``
        and the optimizer steps once every ``opt.batch_size`` frames;
      * evaluates on the test set, counting a frame as a success when the
        mean point distance is below 10% of the object's diameter;
      * logs to per-epoch log files and to TensorBoard;
      * saves a checkpoint named ``pose_model_<epoch>_<best_test>.pth``;
      * multiplies the learning rate by 0.6 / 0.5 the first time the best
        test distance drops below 0.016 / 0.013.

    Returns:
        None. Returns early (after printing a message) for an unknown dataset.
    """
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.dataset == 'ycb':
        opt.dataset_root = 'datasets/ycb/YCB_Video_Dataset'
        opt.num_objects = 21
        opt.num_points = 1000
        opt.result_dir = 'results/ycb'
        opt.repeat_epoch = 1
    elif opt.dataset == 'linemod':
        opt.dataset_root = 'datasets/linemod/Linemod_preprocessed'
        opt.num_objects = 13
        opt.num_points = 500
        opt.result_dir = 'results/linemod'
        opt.repeat_epoch = 1
    else:
        print('unknown dataset')
        return
    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans)
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans)
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0)
    # Frames are loaded one at a time; opt.batch_size is realised through
    # gradient accumulation in the training loop below.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)
    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()
    opt.diameters = dataset.get_diameter()
    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<')
    print('length of the training set: {0}'.format(len(dataset)))
    print('length of the testing set: {0}'.format(len(test_dataset)))
    print('number of sample points on mesh: {0}'.format(opt.num_points_mesh))
    print('symmetrical object list: {0}'.format(opt.sym_list))

    if not os.path.exists(opt.result_dir):
        os.makedirs(opt.result_dir)
    tb_writer = tf.summary.FileWriter(opt.result_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # network
    estimator = PoseNet(num_points=opt.num_points,
                        num_obj=opt.num_objects,
                        num_rot=opt.num_rot)
    estimator.cuda()
    # loss
    criterion = Loss(opt.sym_list, estimator.rot_anchors)
    knn = KNearestNeighbor(1)
    # learning rate decay
    # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
    best_test = np.inf
    opt.first_decay_start = False
    opt.second_decay_start = False
    # if resume training
    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load(opt.resume_posenet))
        # Checkpoints are written below as
        # '<result_dir>/pose_model_<epoch>_<best_test>.pth'. Drop the
        # directory and only the final extension before splitting: cutting
        # at the first '.' would truncate the distance to its integer part
        # (e.g. 0.012345 -> 0) and would break on paths that contain dots.
        ckpt_stem = os.path.splitext(
            os.path.basename(opt.resume_posenet))[0]
        model_name_parsing = ckpt_stem.split('_')
        best_test = float(model_name_parsing[-1])
        opt.start_epoch = int(model_name_parsing[-2]) + 1
        # Re-apply the decays the resumed run had already reached.
        if best_test < 0.016 and not opt.first_decay_start:
            opt.first_decay_start = True
            opt.lr *= 0.6
        if best_test < 0.013 and not opt.second_decay_start:
            opt.second_decay_start = True
            opt.lr *= 0.5
    # optimizer
    optimizer = torch.optim.Adam(estimator.parameters(), lr=opt.lr)
    # Number of optimizer steps taken by the epochs before start_epoch, so
    # TensorBoard curves continue where a resumed run left off.
    global_step = (len(dataset) //
                   opt.batch_size) * opt.repeat_epoch * (opt.start_epoch - 1)
    # train
    st_time = time.time()
    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%02d' % epoch,
            os.path.join(opt.result_dir, 'epoch_%02d_train_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_loss_avg = 0.0
        train_loss_r_avg = 0.0
        train_loss_t_avg = 0.0
        train_loss_reg_avg = 0.0
        estimator.train()
        optimizer.zero_grad()
        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target_t, target_r, model_points, idx, gt_t = data
                # Looked up with the CPU idx, before it is moved to the GPU.
                obj_diameter = opt.diameters[idx]
                points, choose, img, target_t, target_r, model_points, idx = Variable(points).cuda(), \
                                                                             Variable(choose).cuda(), \
                                                                             Variable(img).cuda(), \
                                                                             Variable(target_t).cuda(), \
                                                                             Variable(target_r).cuda(), \
                                                                             Variable(model_points).cuda(), \
                                                                             Variable(idx).cuda()
                pred_r, pred_t, pred_c = estimator(img, points, choose, idx)
                loss, loss_r, loss_t, loss_reg = criterion(
                    pred_r, pred_t, pred_c, target_r, target_t, model_points,
                    idx, obj_diameter)
                # Gradients accumulate across frames until the next step.
                loss.backward()
                train_loss_avg += loss.item()
                train_loss_r_avg += loss_r.item()
                train_loss_t_avg += loss_t.item()
                train_loss_reg_avg += loss_reg.item()
                train_count += 1
                if train_count % opt.batch_size == 0:
                    global_step += 1
                    lr = opt.lr
                    optimizer.step()
                    optimizer.zero_grad()
                    # write results to tensorboard
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='learning_rate', simple_value=lr),
                        tf.Summary.Value(tag='loss',
                                         simple_value=train_loss_avg /
                                         opt.batch_size),
                        tf.Summary.Value(tag='loss_r',
                                         simple_value=train_loss_r_avg /
                                         opt.batch_size),
                        tf.Summary.Value(tag='loss_t',
                                         simple_value=train_loss_t_avg /
                                         opt.batch_size),
                        tf.Summary.Value(tag='loss_reg',
                                         simple_value=train_loss_reg_avg /
                                         opt.batch_size)
                    ])
                    tb_writer.add_summary(summary, global_step)
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_loss:{4:f}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_loss_avg / opt.batch_size))
                    # Running sums cover one accumulated batch only.
                    train_loss_avg = 0.0
                    train_loss_r_avg = 0.0
                    train_loss_t_avg = 0.0
                    train_loss_reg_avg = 0.0

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%02d_test' % epoch,
            os.path.join(opt.result_dir, 'epoch_%02d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        save_model = False
        estimator.eval()
        success_count = [0] * opt.num_objects
        num_count = [0] * opt.num_objects

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target_t, target_r, model_points, idx, gt_t = data
            obj_diameter = opt.diameters[idx]
            points, choose, img, target_t, target_r, model_points, idx = Variable(points).cuda(), \
                                                                         Variable(choose).cuda(), \
                                                                         Variable(img).cuda(), \
                                                                         Variable(target_t).cuda(), \
                                                                         Variable(target_r).cuda(), \
                                                                         Variable(model_points).cuda(), \
                                                                         Variable(idx).cuda()
            pred_r, pred_t, pred_c = estimator(img, points, choose, idx)
            loss, _, _, _ = criterion(pred_r, pred_t, pred_c, target_r,
                                      target_t, model_points, idx,
                                      obj_diameter)
            test_count += 1
            # evaluation
            # The rotation with the smallest pred_c is selected.
            # NOTE(review): presumably pred_c is an uncertainty-like score
            # here (lower is better) -- confirm against PoseNet / Loss.
            how_min, which_min = torch.min(pred_c, 1)
            pred_r = pred_r[0][which_min[0]].view(-1).cpu().data.numpy()
            pred_r = quaternion_matrix(pred_r)[:3, :3]
            pred_t, pred_mask = ransac_voting_layer(points, pred_t)
            pred_t = pred_t.cpu().data.numpy()
            model_points = model_points[0].cpu().detach().numpy()
            # Model points under the predicted pose.
            pred = np.dot(model_points, pred_r.T) + pred_t
            # NOTE(review): target_r presumably holds the model points
            # rotated by the ground-truth rotation, to which the
            # ground-truth translation is added -- confirm with the dataset.
            target = target_r[0].cpu().detach().numpy() + gt_t[0].cpu(
            ).data.numpy()
            if idx[0].item() in opt.sym_list:
                # Symmetric object: compare each predicted point with its
                # nearest target point.
                pred = torch.from_numpy(pred.astype(
                    np.float32)).cuda().transpose(1, 0).contiguous()
                target = torch.from_numpy(target.astype(
                    np.float32)).cuda().transpose(1, 0).contiguous()
                inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
                target = torch.index_select(target, 1, inds.view(-1) - 1)
                dis = torch.mean(torch.norm(
                    (pred.transpose(1, 0) - target.transpose(1, 0)), dim=1),
                                 dim=0).item()
            else:
                dis = np.mean(np.linalg.norm(pred - target, axis=1))
            logger.info(
                'Test time {0} Test Frame No.{1} loss:{2:f} confidence:{3:f} distance:{4:f}'
                .format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - st_time)),
                    test_count, loss, how_min[0].item(), dis))
            # Success when the mean distance is under 10% of the diameter.
            if dis < 0.1 * opt.diameters[idx[0].item()]:
                success_count[idx[0].item()] += 1
            num_count[idx[0].item()] += 1
            test_dis += dis
        # compute accuracy
        # Objects without any test frame are skipped instead of dividing by
        # zero; the mean is taken over the objects that were evaluated, which
        # equals opt.num_objects whenever every object has test frames.
        accuracy = 0.0
        num_evaluated = 0
        for i in range(opt.num_objects):
            if num_count[i] == 0:
                logger.info('Object {0} has no test samples, skipped'.format(
                    test_dataset.objlist[i]))
                continue
            success_rate = float(success_count[i]) / num_count[i]
            accuracy += success_rate
            num_evaluated += 1
            logger.info('Object {0} success rate: {1}'.format(
                test_dataset.objlist[i], success_rate))
        accuracy = accuracy / num_evaluated if num_evaluated else 0.0
        test_dis = test_dis / test_count
        # log results
        logger.info(
            'Test time {0} Epoch {1} TEST FINISH Avg dis: {2:f}, Accuracy: {3:f}'
            .format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), epoch,
                test_dis, accuracy))
        # tensorboard
        summary = tf.Summary(value=[
            tf.Summary.Value(tag='accuracy', simple_value=accuracy),
            tf.Summary.Value(tag='test_dis', simple_value=test_dis)
        ])
        tb_writer.add_summary(summary, global_step)
        # save model
        # A checkpoint is written every epoch; its name carries the best
        # distance seen so far, which the resume logic parses back.
        if test_dis < best_test:
            best_test = test_dis
        torch.save(
            estimator.state_dict(),
            '{0}/pose_model_{1:02d}_{2:06f}.pth'.format(
                opt.result_dir, epoch, best_test))
        # adjust learning rate if necessary
        # Each decay fires once; the optimizer is rebuilt with the new rate.
        if best_test < 0.016 and not opt.first_decay_start:
            opt.first_decay_start = True
            opt.lr *= 0.6
            optimizer = torch.optim.Adam(estimator.parameters(), lr=opt.lr)
        if best_test < 0.013 and not opt.second_decay_start:
            opt.second_decay_start = True
            opt.lr *= 0.5
            optimizer = torch.optim.Adam(estimator.parameters(), lr=opt.lr)

        print(
            '>>>>>>>>----------epoch {0} test finish---------<<<<<<<<'.format(
                epoch))