Example No. 1
def main():
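    # Draw a random seed and seed both Python's and PyTorch's RNGs; keeping it
    # on opt lets the run be reproduced.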
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset == 'ycb':
        opt.num_objects = 21  #number of object classes in the dataset
        opt.num_points = 1000  #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  #folder to save logs
        opt.repeat_epoch = 1  #number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        # when refining, w is scaled by w_rate
        opt.w *= opt.w_rate
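        # Each frame now contributes opt.iteration refine backward passes, so
        # the gradient-accumulation batch size is divided by opt.iteration.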
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans,
                                      opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0,
                                           opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()
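    # sym_list and num_points_mesh feed the losses below, which apply the
    # symmetry-aware ADD-S distance to symmetric objects.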

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()
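        # The DataLoader yields single frames (batch_size=1); gradients are
        # accumulated and applied once every opt.batch_size frames.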

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = data
                points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                                 Variable(choose).cuda(), \
                                                                 Variable(img).cuda(), \
                                                                 Variable(target).cuda(), \
                                                                 Variable(model_points).cuda(), \
                                                                 Variable(idx).cuda()
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                # new_points / new_target: the point cloud and target expressed
                # relative to the predicted pose (inputs for the refiner)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        # The refiner consumes the per-point embedding (emb)
                        # rather than only a global feature, which differs
                        # from the paper's description.
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            # returns point_cloud, mask, image_norm, target, model_points, obj_index
            points, choose, img, target, model_points, idx = data
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), test_count,
                dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
            epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            if opt.dataset == 'ycb':
                dataset = PoseDataset_ycb('train', opt.num_points, True,
                                          opt.dataset_root, opt.noise_trans,
                                          opt.refine_start)
            elif opt.dataset == 'linemod':
                dataset = PoseDataset_linemod('train', opt.num_points, True,
                                              opt.dataset_root,
                                              opt.noise_trans,
                                              opt.refine_start)
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=opt.workers)
            if opt.dataset == 'ycb':
                test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                               opt.dataset_root, 0.0,
                                               opt.refine_start)
            elif opt.dataset == 'linemod':
                test_dataset = PoseDataset_linemod('test', opt.num_points,
                                                   False, opt.dataset_root,
                                                   0.0, opt.refine_start)
            testdataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=opt.workers)

            opt.sym_list = dataset.get_sym_list()
            opt.num_points_mesh = dataset.get_num_points_mesh()

            print(
                '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
                .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                        opt.sym_list))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
Example No. 2
def main():
    # opt.manualSeed = random.randint(1, 10000)
    # # opt.manualSeed = 1
    # random.seed(opt.manualSeed)
    # torch.manual_seed(opt.manualSeed)

    torch.set_printoptions(threshold=5000)
    # device_ids = [0,1]
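    # cudnn.benchmark lets cuDNN pick the fastest kernels; effective when
    # input sizes do not vary between iterations.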
    cudnn.benchmark = True
    if opt.dataset == 'ycb':
        opt.num_objects = 21  #number of object classes in the dataset
        opt.num_points = 1000  #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  #folder to save logs
        opt.repeat_epoch = 3  #number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)

    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    # refiner.cuda()
    # estimator = nn.DataParallel(estimator, device_ids=device_ids)

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
        print('LOADED!!')

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        print('no refinement')
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)
        # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)

    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, False,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
        # print(dataset.list)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans,
                                      opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0,
                                           opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    # print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    # criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.inf
    best_epoch = 0

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    count_gen = 0

    mode = 1
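    # mode 1 trains; any other value only loads a fixed checkpoint (see the
    # else branch at the end)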

    if mode == 1:

        for epoch in range(opt.start_epoch, opt.nepoch):
            logger = setup_logger(
                'epoch%d' % epoch,
                os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
            logger.info('Train time {0}'.format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() -
                                                         st_time)) + ', ' +
                'Training started'))
            train_count = 0
            train_dis_avg = 0.0
            if opt.refine_start:
                estimator.eval()
                refiner.train()
            else:
                estimator.train()
            optimizer.zero_grad()

            for rep in range(opt.repeat_epoch):
                for i, data in enumerate(dataloader, 0):
                    points, choose, img, target_sym, target_cen, idx, file_list_idx = data

                    # skip objects 9 and 16 ("is" compared identity, so the
                    # original check never matched)
                    if idx.item() in (9, 16):
                        continue

                    points, choose, img, target_sym, target_cen, idx = Variable(points).cuda(), \
                                                                     Variable(choose).cuda(), \
                                                                     Variable(img).cuda(), \
                                                                     Variable(target_sym).cuda(), \
                                                                     Variable(target_cen).cuda(), \
                                                                     Variable(idx).cuda()

                    pred_norm, pred_on_plane, emb = estimator(
                        img, points, choose, idx)
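                    # Unlike Example No. 1, this estimator outputs a symmetry-plane
                    # normal and per-point "on plane" scores instead of a pose.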

                    loss = criterion(pred_norm, pred_on_plane, target_sym,
                                     target_cen, idx, points, opt.w,
                                     opt.refine_start)

                    # scene_idx = dataset.list[file_list_idx]

                    loss.backward()

                    # train_dis_avg += dis.item()
                    train_count += 1

                    if train_count % opt.batch_size == 0:
                        logger.info(
                            'Train time {0} Epoch {1} Batch {2} Frame {3}'.
                            format(
                                time.strftime(
                                    "%Hh %Mm %Ss",
                                    time.gmtime(time.time() - st_time)), epoch,
                                int(train_count / opt.batch_size),
                                train_count))
                        optimizer.step()
                        # for param_lr in optimizer.module.param_groups:
                        #         param_lr['lr'] /= 2
                        optimizer.zero_grad()
                        train_dis_avg = 0

                    if train_count % 8 == 0:
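                        # periodic debug output: statistics of the on-plane scores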
                        print(pred_on_plane.max())
                        print(pred_on_plane.mean())
                        print(idx)

                    if train_count != 0 and train_count % 1000 == 0:
                        if opt.refine_start:
                            torch.save(
                                refiner.state_dict(),
                                '{0}/pose_refine_model_current.pth'.format(
                                    opt.outf))
                        else:
                            torch.save(
                                estimator.state_dict(),
                                '{0}/pose_model_current.pth'.format(opt.outf))

            print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.
                  format(epoch))

            logger = setup_logger(
                'epoch%d_test' % epoch,
                os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
            logger.info('Test time {0}'.format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() -
                                                         st_time)) + ', ' +
                'Testing started'))
            test_loss = 0.0
            test_count = 0
            estimator.eval()
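            # NOTE: this variant has no evaluation loop; test_loss stays 0.0 and
            # each epoch's checkpoint is saved as the new "best" below.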

            logger.info(
                'Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - st_time)), epoch,
                    test_loss))
            print(pred_on_plane.max())
            print(pred_on_plane.mean())
            bs, num_p, _ = pred_on_plane.size()
            # if epoch % 40 == 0:
            #     import pdb;pdb.set_trace()
            best_test = test_loss
            best_epoch = epoch
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_loss))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_loss))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

            if best_test < opt.decay_margin and not opt.decay_start:
                opt.decay_start = True
                opt.lr *= opt.lr_rate
                # opt.w *= opt.w_rate
                optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        estimator.load_state_dict(
            torch.load('{0}/pose_model_{1}_{2}.pth'.format(
                opt.outf, best_epoch, best_test)))
    else:
        estimator.load_state_dict(
            torch.load('{0}/pose_model_45_0.0.pth'.format(opt.outf),
                       map_location='cpu'))
Example No. 3
def main():
    if opt.dataset == 'linemod':
        opt.num_obj = 1
        opt.list_obj = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
        opt.occ_list_obj = [1, 5, 6, 8, 9, 10, 11, 12]
        opt.list_name = ['ape', 'benchvise', 'cam', 'can', 'cat', 'driller', 'duck', 'eggbox', 'glue', 'holepuncher', 'iron', 'lamp', 'phone']
        obj_name = opt.list_name[opt.list_obj.index(opt.obj_id)]
        opt.sym_list = [10, 11]
        opt.num_points = 500
        meta_file = open('{0}/models/models_info.yml'.format(opt.dataset_root), 'r')
        meta = yaml.load(meta_file, Loader=yaml.FullLoader)
        diameter = meta[opt.obj_id]['diameter'] / 1000.0 * 0.1
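        # 10% of the model diameter (converted from mm to m): the ADD success
        # threshold used in the test loops below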
        if opt.render:
            opt.repeat_num = 1
        elif opt.fuse:
            opt.repeat_num = 1
        else:
            opt.repeat_num = 5
        writer = SummaryWriter('experiments/runs/linemod/{}{}'.format(obj_name, opt.experiment_name))
        opt.outf = 'trained_models/linemod/{}{}'.format(obj_name, opt.experiment_name)
        opt.log_dir = 'experiments/logs/linemod/{}{}'.format(obj_name, opt.experiment_name)
        if not os.path.exists(opt.outf):
            os.makedirs(opt.outf)
        if not os.path.exists(opt.log_dir):
            os.makedirs(opt.log_dir)
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points = opt.num_points, num_vote = 9, num_obj = opt.num_obj)
    estimator.cuda()
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_obj)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.lr = opt.lr_refine
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataset = PoseDataset_linemod('train', opt.num_points, opt.dataset_root, opt.real, opt.render, opt.fuse, opt.obj_id)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    test_dataset = PoseDataset_linemod('test', opt.num_points, opt.dataset_root, True, False, False, opt.obj_id)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'.format(len(dataset), len(test_dataset), opt.num_points))
    if opt.obj_id in opt.occ_list_obj:
        occ_test_dataset = PoseDataset_occ('test', opt.num_points, opt.occ_dataset_root, opt.obj_id)
        occtestdataloader = torch.utils.data.DataLoader(occ_test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)
        print('length of the occ testing set: {}'.format(len(occ_test_dataset)))
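        # objects that also appear in Occlusion LINEMOD get an extra test pass
        # on that dataset after every epoch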

    criterion = Loss(opt.num_points, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points, opt.sym_list)
    best_test = np.inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()
    train_scalar = 0

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        train_count = 0
        train_loss_avg = 0.0
        train_loss = 0.0
        train_dis_avg = 0.0
        train_dis = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()
        for rep in range(opt.repeat_num):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
                if len(points.size()) == 2:
                    # detection failed for this frame; skip it
                    print('pass')
                    continue
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
                vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
                vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, opt.obj_id, target_r, target_t)
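                # combined objective: the vertex (keypoint-vote) loss is
                # weighted 10x relative to the pose loss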
                loss = 10 * vertex_loss + pose_loss
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, opt.obj_id)
                        dis.backward()
                else:
                    loss.backward()

                train_loss_avg += loss.item()
                train_loss += loss.item()
                train_dis_avg += dis.item()
                train_dis += dis.item()
                train_count += 1
                train_scalar += 1

                if train_count % opt.batch_size == 0:
                    logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_loss:{4} Avg_diss:{5}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, int(train_count / opt.batch_size), train_count, train_loss_avg / opt.batch_size, train_dis_avg / opt.batch_size))
                    writer.add_scalar('linemod training loss', train_loss_avg / opt.batch_size, train_scalar)
                    writer.add_scalar('linemod training dis', train_dis_avg / opt.batch_size, train_scalar)
                    optimizer.step()
                    optimizer.zero_grad()
                    train_loss_avg = 0
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf))
                    else:
                        torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))
        train_loss = train_loss / train_count
        train_dis = train_dis / train_count
        logger.info('Train time {0} Epoch {1} TRAIN FINISH Avg loss: {2} Avg dis: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, train_loss, train_dis))

        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        test_loss = 0.0
        test_vertex_loss = 0.0
        test_pose_loss = 0.0
        test_dis = 0.0
        test_count = 0
        success_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
            if len(points.size()) == 2:
                logger.info('Test time {0} Lost detection!'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))))
                continue
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
            vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
            vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, opt.obj_id, target_r, target_t)
            loss = 10 * vertex_loss + pose_loss
            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, opt.obj_id)

            test_loss += loss.item()
            test_vertex_loss += vertex_loss.item()
            test_pose_loss += pose_loss.item()
            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} loss:{2} dis:{3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count, loss, dis))
            if dis.item() < diameter:
                success_count += 1
            test_count += 1

        test_loss = test_loss / test_count
        test_vertex_loss = test_vertex_loss / test_count
        test_pose_loss = test_pose_loss / test_count
        test_dis = test_dis / test_count
        success_rate = float(success_count) / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg loss: {2} Avg dis: {3} Success rate: {4}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, test_loss, test_dis, success_rate))
        writer.add_scalar('linemod test loss', test_loss, epoch)
        writer.add_scalar('linemod test vertex loss', test_vertex_loss, epoch)
        writer.add_scalar('linemod test pose loss', test_pose_loss, epoch)
        writer.add_scalar('linemod test dis', test_dis, epoch)
        writer.add_scalar('linemod success rate', success_rate, epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        if test_dis <= best_test:
            best_test = test_dis
        if opt.refine_start:
            torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        else:
            torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        print(epoch, '>>>>>>>>----------MODEL SAVED---------<<<<<<<<')

        if opt.obj_id in opt.occ_list_obj:
            logger = setup_logger('epoch%d_occ_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_occ_test_log.txt' % epoch))
            logger.info('Occ test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
            occ_test_dis = 0.0
            occ_test_count = 0
            occ_success_count = 0
            estimator.eval()
            refiner.eval()

            for j, data in enumerate(occtestdataloader, 0):
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
                if len(points.size()) == 2:
                    logger.info('Occ test time {0} Lost detection!'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time))))
                    continue
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
                vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
                vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, opt.obj_id, target_r, target_t)
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, opt.obj_id)

                occ_test_dis += dis.item()
                logger.info('Occ test time {0} Test Frame No.{1} dis:{2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), occ_test_count, dis))
                if dis.item() < diameter:
                    occ_success_count += 1
                occ_test_count += 1

            occ_test_dis = occ_test_dis / occ_test_count
            occ_success_rate = float(occ_success_count) / occ_test_count
            logger.info('Occ test time {0} Epoch {1} TEST FINISH Avg dis: {2} Success rate: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, occ_test_dis, occ_success_rate))
            writer.add_scalar('occ test dis', occ_test_dis, epoch)
            writer.add_scalar('occ success rate', occ_success_rate, epoch)

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.lr = opt.lr_refine
            opt.batch_size = int(opt.batch_size / opt.iteration)
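            # from here on only the refiner is optimized; the estimator is kept
            # in eval mode during training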
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
            print('>>>>>>>>----------Refine started---------<<<<<<<<')

    writer.close()
Example No. 4
def main():
    # g13: parameter setting -------------------
    batch_id = 1
    
    opt.dataset ='linemod'
    opt.dataset_root = './datasets/linemod/Linemod_preprocessed'
    estimator_path = 'trained_checkpoints/linemod/pose_model_9_0.01310166542980859.pth'
    refiner_path = 'trained_checkpoints/linemod/pose_refine_model_493_0.006761023565178073.pth'
    opt.resume_posenet = estimator_path
    opt.resume_refinenet = refiner_path
    dataset_config_dir = 'datasets/linemod/dataset_config'
    output_result_dir = 'experiments/eval_result/linemod'
    bs = 1 #fixed because of the default setting in torch.utils.data.DataLoader
    opt.iteration = 2 #default is 4 in eval_linemod.py
    t1_idx = 0
    t1_total_eval_num = 3
    
    axis_range = 0.1   # the length of X, Y, and Z axis in 3D
    vimg_dir = 'verify_img'
    if not os.path.exists(vimg_dir):
        os.makedirs(vimg_dir)
    #-------------------------------------------
    
    if opt.dataset == 'ycb':
        opt.num_objects = 21 #number of object classes in the dataset
        opt.num_points = 1000 #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb' #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb' #folder to save logs
        opt.repeat_epoch = 1 #number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return
    
    estimator = PoseNet(num_points = opt.num_points, num_obj = opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load(estimator_path))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load(refiner_path))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)


    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)
    print('complete loading testing loader\n')
    opt.sym_list = test_dataset.get_sym_list()
    opt.num_points_mesh = test_dataset.get_num_points_mesh()

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\n\
        length of the testing set: {0}\nnumber of sample points on mesh: {1}\n\
        symmetry object list: {2}'\
        .format( len(test_dataset), opt.num_points_mesh, opt.sym_list))
    
    
    
    #load pytorch model
    estimator.eval()    
    refiner.eval()
    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
    fw = open('{0}/t1_eval_result_logs.txt'.format(output_result_dir), 'w')

    #Pose estimation
    for j, data in enumerate(testdataloader, 0):
        # g13: modify this part for evaluation target--------------------
        if j == t1_total_eval_num:
            break
        #----------------------------------------------------------------
        points, choose, img, target, model_points, idx = data
        if len(points.size()) == 2:
            print('No.{0} NOT Pass! Lost detection!'.format(j))
            fw.write('No.{0} NOT Pass! Lost detection!\n'.format(j))
            continue
        points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
        pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
        _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

        #if opt.refine_start: #iterative poserefinement
        #    for ite in range(0, opt.iteration):
        #        pred_r, pred_t = refiner(new_points, emb, idx)
        #        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)
        
        pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, opt.num_points, 1)
        pred_c = pred_c.view(bs, opt.num_points)
        how_max, which_max = torch.max(pred_c, 1)
        pred_t = pred_t.view(bs * opt.num_points, 1, 3)
    
        my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
        my_t = (points.view(bs * opt.num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy()
        my_pred = np.append(my_r, my_t)
    
        for ite in range(0, opt.iteration):
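            # Transform the cloud into the current pose estimate's frame, predict
            # a residual pose with the refiner, and compose it with the estimate.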
            T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(opt.num_points, 1).contiguous().view(1, opt.num_points, 3)
            my_mat = quaternion_matrix(my_r)
            R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
            my_mat[0:3, 3] = my_t
            
            new_points = torch.bmm((points - T), R).contiguous()
            pred_r, pred_t = refiner(new_points, emb, idx)
            pred_r = pred_r.view(1, 1, -1)
            pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
            my_r_2 = pred_r.view(-1).cpu().data.numpy()
            my_t_2 = pred_t.view(-1).cpu().data.numpy()
            my_mat_2 = quaternion_matrix(my_r_2)
            my_mat_2[0:3, 3] = my_t_2
    
            my_mat_final = np.dot(my_mat, my_mat_2)
            my_r_final = copy.deepcopy(my_mat_final)
            my_r_final[0:3, 3] = 0
            my_r_final = quaternion_from_matrix(my_r_final, True)
            my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])
    
            my_pred = np.append(my_r_final, my_t_final)
            my_r = my_r_final
            my_t = my_t_final

        # g13: start drawing pose on image------------------------------------
        # pick up image
        print("index {0}: {1}".format(j, test_dataset.list_rgb[j]))
        img = Image.open(test_dataset.list_rgb[j])
        
        # pick up center position by bbox
        meta_file = open('{0}/data/{1}/gt.yml'.format(opt.dataset_root, '%02d' % test_dataset.list_obj[j]), 'r')
        meta = yaml.load(meta_file, Loader=yaml.FullLoader)
        which_item = test_dataset.list_rank[j]
        bbx = meta[which_item][0]['obj_bb']
        draw = ImageDraw.Draw(img) 
        
        # draw box (ensure this is the right object)
        draw.line((bbx[0],bbx[1], bbx[0], bbx[1]+bbx[3]), fill=(255,0,0), width=5)
        draw.line((bbx[0],bbx[1], bbx[0]+bbx[2], bbx[1]), fill=(255,0,0), width=5)
        draw.line((bbx[0],bbx[1]+bbx[3], bbx[0]+bbx[2], bbx[1]+bbx[3]), fill=(255,0,0), width=5)
        draw.line((bbx[0]+bbx[2],bbx[1], bbx[0]+bbx[2], bbx[1]+bbx[3]), fill=(255,0,0), width=5)
        
        #get center
        c_x = bbx[0]+int(bbx[2]/2)
        c_y = bbx[1]+int(bbx[3]/2)
        draw.point((c_x,c_y), fill=(255,255,0))
        
        #get the 3D position of center
        cam_intrinsic = np.array([
            [test_dataset.cam_fx, 0.0, test_dataset.cam_cx],
            [0.0, test_dataset.cam_fy, test_dataset.cam_cy],
            [0.0, 0.0, 1.0]])
        cam_extrinsic = my_mat_final[0:3, :]
        cam2d_3d = np.matmul(cam_intrinsic, cam_extrinsic)
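        # back-project the 2D center through the pseudo-inverse of the
        # projection matrix to get a homogeneous 3D point on the viewing ray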
        cen_3d = np.matmul(np.linalg.pinv(cam2d_3d), [[c_x],[c_y],[1]])
        # replace img.show() with plt.imshow(img)
        
        #transpose three 3D axis point into 2D
        x_3d = cen_3d + [[axis_range],[0],[0],[0]]
        y_3d = cen_3d + [[0],[axis_range],[0],[0]]
        z_3d = cen_3d + [[0],[0],[axis_range],[0]]
        x_2d = np.matmul(cam2d_3d, x_3d)
        y_2d = np.matmul(cam2d_3d, y_3d)
        z_2d = np.matmul(cam2d_3d, z_3d)
        
        #draw the axes in 2D (divide by the homogeneous coordinate first)
        x_2d = (x_2d / x_2d[2]).flatten()
        y_2d = (y_2d / y_2d[2]).flatten()
        z_2d = (z_2d / z_2d[2]).flatten()
        draw.line((c_x, c_y, x_2d[0], x_2d[1]), fill=(255,255,0), width=5)
        draw.line((c_x, c_y, y_2d[0], y_2d[1]), fill=(0,255,0), width=5)
        draw.line((c_x, c_y, z_2d[0], z_2d[1]), fill=(0,0,255), width=5)

        #g13: show image
        #img.show()
        
        #save the annotated image under vimg_dir
        img_file_name = '{0}/pred_obj{1}_pic{2}.png'.format(vimg_dir, test_dataset.list_obj[j], which_item)
        img.save( img_file_name, "PNG" )
        img.close()

    fw.close()
Example No. 5
class DenseFusionLightning(LightningModule):
    def __init__(self, exp, env):
        super().__init__()
        self._mode = 'init'

        # check experiment cfg for errors
        check_exp(exp)

        # logging h-params
        exp_config_flatten = flatten_dict(copy.deepcopy(exp))
        for k in exp_config_flatten.keys():
            if exp_config_flatten[k] is None:
                exp_config_flatten[k] = 'is None'

        self.hparams = exp_config_flatten
        self.hparams['lr'] = exp['training']['lr']
        self.test_size = exp['training']['test_size']
        self.env, self.exp = env, exp

        # number of input points to the network
        num_points_small = exp['d_train']['num_pt_mesh_small']
        num_points_large = exp['d_train']['num_pt_mesh_large']
        num_obj = exp['d_train']['objects']

        self.df_pose_estimator = PoseNet(
            num_points=exp['d_test']['num_points'], num_obj=num_obj)

        self.df_refiner = PoseRefineNet(
            num_points=exp['d_test']['num_points'], num_obj=num_obj)

        if exp.get('model', {}).get('df_load', False):
            self.df_pose_estimator.load_state_dict(
                torch.load(exp['model']['df_pose_estimator']))
            if exp.get('model', {}).get('df_refine', False):
                self.df_refiner.load_state_dict(
                    torch.load(exp['model']['df_refiner']))

        sl = exp['d_train']['obj_list_sym']
        self.df_criterion = Loss( num_points_large, sl)
        self.df_criterion_refine = Loss_refine( num_points_large, sl)
        
        self.criterion_adds = LossAddS(sym_list=sl)
        
        self.visualizer = Visualizer(self.exp['model_path'] + '/visu/', None)
        
        self._dict_track = {}
        self.number_images_log_test = self.exp.get(
            'visu', {}).get('number_images_log_test', 1)
        self.counter_images_logged = 0
        self.init_train_vali_split = False

        mp = exp['model_path']
        fh = logging.FileHandler(f'{mp}/Live_Logger_Lightning.log')
        fh.setLevel(logging.DEBUG)
        logging.getLogger("lightning").addHandler(fh)
        
        self.start = time.time()

        self.best_val_loss = 999
        
        # optional, set the logging level
        if self.exp.get('visu', {}).get('log_to_file', False):
            console = logging.StreamHandler()
            console.setLevel(logging.DEBUG)
            logging.getLogger("lightning").addHandler(console)
            log = open(f'{mp}/Live_Logger_Lightning.log', "a")
            sys.stdout = log
            logging.info('Logging to File')

    def forward(self, batch):
        st = time.time()
        
        # unpack batch
        points, choose, img, target, model_points, idx = batch[0:6]
        log_scalars = {}
        bs = points.shape[0]

        tight_padded_img_batch = tight_image_batch(
                img, device=self.device)

        pred_r = torch.zeros((bs, 1000, 4), device=self.device)
        pred_t = torch.zeros((bs, 1000, 3), device=self.device)
        pred_c = torch.zeros((bs, 1000, 1), device=self.device)
        emb = torch.zeros((bs, 32, 1000), device=self.device)
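        # Outputs are pre-allocated and filled one sample at a time; the cropped
        # images appear to differ in size, so they cannot be batched directly.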
        
        for i in range(bs):
            pred_r[i], pred_t[i], pred_c[i], emb[i] = self.df_pose_estimator(
                ret_cropped_image(img[i])[None],
                points[i][None],
                choose[i][None],
                idx[i][None])

        refine = self.exp['model']['df_refine_iterations'] > 0

        loss, dis, new_points, new_target, pred_r_current, pred_t_current = self.df_criterion(
            pred_r, pred_t, pred_c,
            target, model_points, idx,
            points, self.exp['model']['df_w'], refine)

        for i in range(self.exp['model']['df_refine_iterations']):
            pred_r, pred_t = self.df_refiner(new_points, emb, idx)
            dis, new_points, new_target, pred_r_current, pred_t_current = self.df_criterion_refine(
                pred_r, pred_t, new_target, model_points, idx,
                new_points, pred_r_current, pred_t_current)

        return loss, dis, pred_r_current, pred_t_current, new_points, log_scalars

    def training_step(self, batch, batch_idx):
        self._mode = 'train'
        st = time.time()
        total_loss = 0
        total_dis = 0
        
        # forward
        loss, dis, pred_r_current, pred_t_current, new_points, log_scalars = self(batch[0])

        if self.counter_images_logged < self.exp.get('visu', {}).get('images_train', 1):
            # self.visu_batch(batch, pred_trans, pred_rot_wxyz, pred_points) TODO
            pass 
        # tensorboard logging
        loss = torch.mean(loss, dim=0)
        tensorboard_logs = {'train_loss': float(loss)}
        tensorboard_logs = {**tensorboard_logs, **log_scalars}
        self._dict_track = {**self._dict_track}
        # Lightning expects the key 'loss' in the dict returned by training_step.
        return {'loss': loss, 'log': tensorboard_logs,
                'progress_bar': {'Loss': loss, 'ADD-S': torch.mean(dis, dim=0)}}

    def validation_step(self, batch, batch_idx):
        self._mode = 'val'
        st = time.time()
        total_loss = 0
        total_dis = 0
        
        # forward
        loss, dis, pred_r_current, pred_t_current, new_points, log_scalars = self(batch[0])

        if self.counter_images_logged < self.exp.get('visu', {}).get('images_train', 1):
            self.visu_batch(batch[0], pred_r_current, pred_t_current, new_points) 
        # tensorboard logging
        loss = torch.mean(loss, dim=0)
        dis = torch.mean(dis, dim=0)
        tensorboard_logs = {'val_loss': float(loss), 'val_dis': float(dis), 'val_dis_float': float(dis)}

        tensorboard_logs = {**tensorboard_logs, **log_scalars}

        self._dict_track = {**self._dict_track, 'val_dis_float': float(dis), 'val_dis': float(dis), 'val_loss': float(loss)}

        return {'val_loss': loss, 'val_dis': dis, 'log': tensorboard_logs}
    
    def test_step(self, batch, batch_idx):
        self._mode = 'test'
        st = time.time()
        total_loss = 0
        total_dis = 0
        
        # forward
        loss, dis, pred_r_current, pred_t_current, new_points, log_scalars = self(batch[0])

        if self.counter_images_logged < self.exp.get('visu', {}).get('images_train', 1):
            # self.visu_batch(batch, pred_trans, pred_rot_wxyz, pred_points) TODO
            pass 
        # tensorboard logging
        
        tensorboard_logs = {'test_loss': float(dis)}
        tensorboard_logs = {**tensorboard_logs, **log_scalars}
        self._dict_track = {**self._dict_track}
        return {'loss': dis, 'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        avg_dict = {}
        self.counter_images_logged = 0  # reset image log counter

        # only keys that are logged in tensorboard are removed from log_scalars !
        for old_key in list(self._dict_track.keys()):
            if old_key.find('val') == -1:
                continue

            newk = 'avg_' + old_key
            avg_dict['avg_' +
                     old_key] = float(np.mean(np.array(self._dict_track[old_key])))

            p = old_key.find('adds_dis')
            if p != -1:
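                # keys tracking ADD-S distances additionally get an
                # area-under-the-curve score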
                auc = compute_auc(self._dict_track[old_key])
                avg_dict[old_key[:p] + 'auc [0 - 100]'] = auc

            self._dict_track.pop(old_key, None)

        df1 = dict_to_df(avg_dict)
        df2 = dict_to_df(get_df_dict(pre='val'))
        img = compare_df(df1, df2, key='auc [0 - 100]')
        tag = 'val_table_res_vs_df'
        img.save(self.exp['model_path'] +
                 f'/visu/{self.current_epoch}_{tag}.png')
        self.logger.experiment.add_image(tag, np.array(img).astype(
            np.uint8), global_step=self.current_epoch, dataformats='HWC')

        avg_val_dis_float = float(0)
        if  avg_dict.get( 'avg_val_loss',999) < self.best_val_loss:
            self.best_val_loss = avg_dict.get( 'avg_val_loss',999)

        return {'avg_val_dis_float': float(avg_dict.get( 'avg_val_loss',999)),
                'log': avg_dict}

    def train_epoch_end(self, outputs):
        self.counter_images_logged = 0  # reset image log counter
        avg_dict = {}
        for old_key in list(self._dict_track.keys()):
            if old_key.find('train') == -1:
                continue
            avg_dict['avg_' +
                     old_key] = float(np.mean(np.array(self._dict_track[old_key])))
            self._dict_track.pop(old_key, None)
        string = 'Time for one epoch: ' + str(time.time() - self.start)
        print(string)
        self.start = time.time()
        return {**avg_dict, 'log': avg_dict}

    def test_epoch_end(self, outputs):
        self.counter_images_logged = 0  # reset image log counter
        avg_dict = {}
        # only keys that are logged in tensorboard are removed from log_scalars !
        for old_key in list(self._dict_track.keys()):
            if old_key.find('test') == -1:
                continue

            newk = 'avg_' + old_key
            avg_dict['avg_' +
                     old_key] = float(np.mean(np.array(self._dict_track[old_key])))

            p = old_key.find('adds_dis')
            if p != -1:
                auc = compute_auc(self._dict_track[old_key])
                avg_dict[old_key[:p] + 'auc [0 - 100]'] = auc

            self._dict_track.pop(old_key, None)

        avg_test_dis_float = float(avg_dict['avg_test_loss  [+inf - 0]'])

        df1 = dict_to_df(avg_dict)
        df2 = dict_to_df(get_df_dict(pre='test'))
        img = compare_df(df1, df2, key='auc [0 - 100]')
        tag = 'test_table_res_vs_df'
        img.save(self.exp['model_path'] +
                 f'/visu/{self.current_epoch}_{tag}.png')
        self.logger.experiment.add_image(tag, np.array(img).astype(
            np.uint8), global_step=self.current_epoch, dataformats='HWC')

        return {'avg_test_dis_float': avg_test_dis_float,
                'avg_test_dis': avg_dict['avg_test_loss  [+inf - 0]'],
                'log': avg_dict}

    def visu_batch(self, batch, pred_r_current, pred_t_current, new_points):
        target = copy.deepcopy(batch[3][0].detach().cpu().numpy())
        mp = copy.deepcopy(batch[4][0].detach().cpu().numpy())
        gt_rot_wxyz, gt_trans, unique_desig = batch[10:13]
        img = batch[8].detach().cpu().numpy()[0]
        cam = batch[9][0]
        pre = f'%s_obj%d' % (str(unique_desig[0][0]).replace('/', "_"), int(unique_desig[1][0])) 
        store = self.exp['visu'].get('store', False)
        self.visualizer.plot_estimated_pose(tag=f'target_{pre}',
                                            epoch=self.current_epoch,
                                            img=img,
                                            points=target,
                                            cam_cx=float(cam[0]),
                                            cam_cy=float(cam[1]),
                                            cam_fx=float(cam[2]),
                                            cam_fy=float(cam[3]),
                                            store=store)

        self.visualizer.plot_estimated_pose(tag=f'new_points_{pre}',
                                            epoch=self.current_epoch,
                                            img=img,
                                            points=new_points[0].clone().detach().cpu().numpy(),
                                            cam_cx=float(cam[0]),
                                            cam_cy=float(cam[1]),
                                            cam_fx=float(cam[2]),
                                            cam_fy=float(cam[3]),
                                            store=store)
        t = pred_t_current.detach().cpu().numpy()[0, :][None, :]
        mat = quat_to_rot(pred_r_current).detach().cpu().numpy()[0]
        self.visualizer.plot_estimated_pose(tag=f'pred_{pre}',
                                            epoch=self.current_epoch,
                                            img=img,
                                            points=mp,
                                            trans=t,
                                            rot_mat=mat,
                                            cam_cx=float(cam[0]),
                                            cam_cy=float(cam[1]),
                                            cam_fx=float(cam[2]),
                                            cam_fy=float(cam[3]),
                                            store=store)

        #     self.visualizer.plot_contour(tag='gt_contour_%s_obj%d' % (str(unique_desig[0][0]).replace('/', "_"), int(unique_desig[1][0])),
        #                                  epoch=self.current_epoch,
        #                                  img=img,
        #                                  points=points,
        #                                  cam_cx=float(cam[0]),
        #                                  cam_cy=float(cam[1]),
        #                                  cam_fx=float(cam[2]),
        #                                  cam_fy=float(cam[3]),
        #                                  store=store)

        # t = pred_t.detach().cpu().numpy()
        # r = pred_r.detach().cpu().numpy()

        # rot = R.from_quat(re_quat(r, 'wxyz'))

        # self.visualizer.plot_estimated_pose(tag='pred_%s_obj%d' % (str(unique_desig[0][0]).replace('/', "_"), int(unique_desig[1][0])),
        #                                     epoch=self.current_epoch,
        #                                     img=img,
        #                                     points=copy.deepcopy(
        #     model_points[:, :].detach(
        #     ).cpu().numpy()),
        #     trans=t.reshape((1, 3)),
        #     rot_mat=rot.as_matrix(),
        #     cam_cx=float(cam[0]),
        #     cam_cy=float(cam[1]),
        #     cam_fx=float(cam[2]),
        #     cam_fy=float(cam[3]),
        #     store=store)

        # self.visualizer.plot_contour(tag='pred_contour_%s_obj%d' % (str(unique_desig[0][0]).replace('/', "_"), int(unique_desig[1][0])),
        #                              epoch=self.current_epoch,
        #                              img=img,
        #                              points=copy.deepcopy(
        #     model_points[:, :].detach(
        #     ).cpu().numpy()),
        #     trans=t.reshape((1, 3)),
        #     rot_mat=rot.as_matrix(),
        #     cam_cx=float(cam[0]),
        #     cam_cy=float(cam[1]),
        #     cam_fx=float(cam[2]),
        #     cam_fy=float(cam[3]),
        #     store=store)

        # render_img, depth, h_render = self.vm.get_closest_image_batch(
        #     i=idx.unsqueeze(0), rot=pred_r.unsqueeze(0), conv='wxyz')
        # # get the bounding box !
        # w = 640
        # h = 480

        # real_img = torch.zeros((1, 3, h, w), device=self.device)
        # # update the target to get new bb

        # base_inital = quat_to_rot(
        #     pred_r.unsqueeze(0), 'wxyz', device=self.device).squeeze(0)
        # base_new = base_inital.view(-1, 3, 3).permute(0, 2, 1)
        # pred_points = torch.add(
        #     torch.bmm(model_points.unsqueeze(0), base_inital.unsqueeze(0)), pred_t)
        # # torch.Size([16, 2000, 3]), torch.Size([16, 4]) , torch.Size([16, 3])
        # bb_ls = get_bb_real_target(
        #     pred_points, cam.unsqueeze(0))

        # for j, b in enumerate(bb_ls):
        #     if not b.check_min_size():
        #         pass
        #     c = cam.unsqueeze(0)
        #     center_real = backproject_points(
        #         pred_t.view(1, 3), fx=c[j, 2], fy=c[j, 3], cx=c[j, 0], cy=c[j, 1])
        #     center_real = center_real.squeeze()
        #     b.move(-center_real[0], -center_real[1])
        #     b.expand(1.1)
        #     b.expand_to_correct_ratio(w, h)
        #     b.move(center_real[0], center_real[1])
        #     crop_real = b.crop(img_orig).unsqueeze(0)
        #     up = torch.nn.UpsamplingBilinear2d(size=(h, w))
        #     crop_real = torch.transpose(crop_real, 1, 3)
        #     crop_real = torch.transpose(crop_real, 2, 3)
        #     real_img[j] = up(crop_real)
        # inp = real_img[0].unsqueeze(0)
        # inp = torch.transpose(inp, 1, 3)
        # inp = torch.transpose(inp, 1, 2)
        # data = torch.cat([inp, render_img], dim=3)
        # data = torch.transpose(data, 1, 3)
        # data = torch.transpose(data, 2, 3)
        # self.visualizer.visu_network_input(tag='render_real_comp_%s_obj%d' % (str(unique_desig[0][0]).replace('/', "_"), int(unique_desig[1][0])),
        #                                    epoch=self.current_epoch,
        #                                    data=data,
        #                                    max_images=1, store=store)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            [{'params': self.df_pose_estimator.parameters()}], lr=self.hparams['lr'])
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **self.exp['lr_cfg']['on_plateau_cfg']),
            **self.exp['lr_cfg']['scheduler']
        }
        return [optimizer], [scheduler]
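    # note: Lightning expects ReduceLROnPlateau wrapped in a scheduler dict;
    # the extra entries merged in from self.exp['lr_cfg']['scheduler'] are
    # assumed to supply the monitored metric and scheduling interval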

    def train_dataloader(self):
        self.visualizer.writer = self.logger.experiment
        dataset_train = GenericDataset(
            cfg_d=self.exp['d_train'],
            cfg_env=self.env)

        # initialize train and validation indices
        if not self.init_train_vali_split:
            self.init_train_vali_split = True
            self.indices_valid, self.indices_train = sklearn.model_selection.train_test_split(
                range(0, len(dataset_train)), test_size=self.test_size)

        # train only on the training split of the dataset
        dataset_subset = torch.utils.data.Subset(
            dataset_train, self.indices_train)

        dataloader_train = torch.utils.data.DataLoader(dataset_subset,
                                                       **self.exp['loader'])
        return dataloader_train

    def test_dataloader(self):
        self.visualizer.writer = self.logger.experiment
        dataset_test = GenericDataset(
            cfg_d=self.exp['d_test'],
            cfg_env=self.env)
        dataloader_test = torch.utils.data.DataLoader(dataset_test,
                                                      **self.exp['loader'])
        return dataloader_test

    def val_dataloader(self):
        self.visualizer.writer = self.logger.experiment
        dataset_val = GenericDataset(
            cfg_d=self.exp['d_train'],
            cfg_env=self.env)
        # initialize train and validation indices
        if not self.init_train_vali_split:
            self.init_train_vali_split = True
            self.indices_valid, self.indices_train = sklearn.model_selection.train_test_split(
                range(0, len(dataset_val)), test_size=self.test_size)

        # validate only on the held-out split of the dataset
        dataset_subset = torch.utils.data.Subset(
            dataset_val, self.indices_valid)
        dataloader_val = torch.utils.data.DataLoader(dataset_subset,
                                                     **self.exp['loader'])
        return dataloader_val
Example No. 6
def main():
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.dataset == 'ycb':
        opt.dataset_root = 'datasets/ycb/YCB_Video_Dataset'
        opt.num_objects = 21
        opt.num_points = 1000
        opt.result_dir = 'results/ycb'
        opt.repeat_epoch = 1
    elif opt.dataset == 'linemod':
        opt.dataset_root = 'datasets/linemod/Linemod_preprocessed'
        opt.num_objects = 13
        opt.num_points = 500
        opt.result_dir = 'results/linemod'
        opt.repeat_epoch = 1
    else:
        print('Unknown dataset')
        return
    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans)
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans)
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)
    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()
    opt.diameters = dataset.get_diameter()
    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<')
    print('length of the training set: {0}'.format(len(dataset)))
    print('length of the testing set: {0}'.format(len(test_dataset)))
    print('number of sample points on mesh: {0}'.format(opt.num_points_mesh))
    print('symmetrical object list: {0}'.format(opt.sym_list))

    if not os.path.exists(opt.result_dir):
        os.makedirs(opt.result_dir)
    tb_writer = tf.summary.FileWriter(opt.result_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # network
    estimator = PoseNet(num_points=opt.num_points,
                        num_obj=opt.num_objects,
                        num_rot=opt.num_rot)
    estimator.cuda()
    # loss
    criterion = Loss(opt.sym_list, estimator.rot_anchors)
    knn = KNearestNeighbor(1)
    # learning rate decay
    best_test = np.Inf
    opt.first_decay_start = False
    opt.second_decay_start = False
    # if resume training
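    # checkpoint filenames encode epoch and best test distance as
    # pose_model_{epoch:02d}_{best_test:06f}.pth, so both can be parsed back
    # from the name when resuming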
    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load(opt.resume_posenet))
        model_name_parsing = (opt.resume_posenet.split('.')[0]).split('_')
        best_test = float(model_name_parsing[-1])
        opt.start_epoch = int(model_name_parsing[-2]) + 1
        if best_test < 0.016 and not opt.first_decay_start:
            opt.first_decay_start = True
            opt.lr *= 0.6
        if best_test < 0.013 and not opt.second_decay_start:
            opt.second_decay_start = True
            opt.lr *= 0.5
    # optimizer
    optimizer = torch.optim.Adam(estimator.parameters(), lr=opt.lr)
    global_step = (len(dataset) //
                   opt.batch_size) * opt.repeat_epoch * (opt.start_epoch - 1)
    # train
    st_time = time.time()
    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%02d' % epoch,
            os.path.join(opt.result_dir, 'epoch_%02d_train_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_loss_avg = 0.0
        train_loss_r_avg = 0.0
        train_loss_t_avg = 0.0
        train_loss_reg_avg = 0.0
        estimator.train()
        optimizer.zero_grad()
        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target_t, target_r, model_points, idx, gt_t = data
                obj_diameter = opt.diameters[idx]
                points, choose, img, target_t, target_r, model_points, idx = Variable(points).cuda(), \
                                                                             Variable(choose).cuda(), \
                                                                             Variable(img).cuda(), \
                                                                             Variable(target_t).cuda(), \
                                                                             Variable(target_r).cuda(), \
                                                                             Variable(model_points).cuda(), \
                                                                             Variable(idx).cuda()
                pred_r, pred_t, pred_c = estimator(img, points, choose, idx)
                loss, loss_r, loss_t, loss_reg = criterion(
                    pred_r, pred_t, pred_c, target_r, target_t, model_points,
                    idx, obj_diameter)
                loss.backward()
                train_loss_avg += loss.item()
                train_loss_r_avg += loss_r.item()
                train_loss_t_avg += loss_t.item()
                train_loss_reg_avg += loss_reg.item()
                train_count += 1
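                # the dataloader yields single frames (batch_size=1); gradients
                # are accumulated over opt.batch_size frames before each
                # optimizer step (manual gradient accumulation)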
                if train_count % opt.batch_size == 0:
                    global_step += 1
                    lr = opt.lr
                    optimizer.step()
                    optimizer.zero_grad()
                    # write results to tensorboard
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='learning_rate', simple_value=lr),
                        tf.Summary.Value(tag='loss',
                                         simple_value=train_loss_avg /
                                         opt.batch_size),
                        tf.Summary.Value(tag='loss_r',
                                         simple_value=train_loss_r_avg /
                                         opt.batch_size),
                        tf.Summary.Value(tag='loss_t',
                                         simple_value=train_loss_t_avg /
                                         opt.batch_size),
                        tf.Summary.Value(tag='loss_reg',
                                         simple_value=train_loss_reg_avg /
                                         opt.batch_size)
                    ])
                    tb_writer.add_summary(summary, global_step)
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_loss:{4:f}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_loss_avg / opt.batch_size))
                    train_loss_avg = 0.0
                    train_loss_r_avg = 0.0
                    train_loss_t_avg = 0.0
                    train_loss_reg_avg = 0.0

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%02d_test' % epoch,
            os.path.join(opt.result_dir, 'epoch_%02d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        success_count = [0 for i in range(opt.num_objects)]
        num_count = [0 for i in range(opt.num_objects)]

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target_t, target_r, model_points, idx, gt_t = data
            obj_diameter = opt.diameters[idx]
            points, choose, img, target_t, target_r, model_points, idx = Variable(points).cuda(), \
                                                                         Variable(choose).cuda(), \
                                                                         Variable(img).cuda(), \
                                                                         Variable(target_t).cuda(), \
                                                                         Variable(target_r).cuda(), \
                                                                         Variable(model_points).cuda(), \
                                                                         Variable(idx).cuda()
            pred_r, pred_t, pred_c = estimator(img, points, choose, idx)
            loss, _, _, _ = criterion(pred_r, pred_t, pred_c, target_r,
                                      target_t, model_points, idx,
                                      obj_diameter)
            test_count += 1
            # evaluation
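            # pick the rotation anchor with the smallest predicted confidence
            # score, convert its quaternion to a rotation matrix, and recover
            # the translation by RANSAC voting over the point-wise predictions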
            how_min, which_min = torch.min(pred_c, 1)
            pred_r = pred_r[0][which_min[0]].view(-1).cpu().data.numpy()
            pred_r = quaternion_matrix(pred_r)[:3, :3]
            pred_t, pred_mask = ransac_voting_layer(points, pred_t)
            pred_t = pred_t.cpu().data.numpy()
            model_points = model_points[0].cpu().detach().numpy()
            pred = np.dot(model_points, pred_r.T) + pred_t
            target = target_r[0].cpu().detach().numpy() + gt_t[0].cpu().data.numpy()
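            # ADD(-S): for symmetric objects the error is the mean distance to
            # the closest model point (1-NN matching); otherwise the mean
            # distance between corresponding points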
            if idx[0].item() in opt.sym_list:
                pred = torch.from_numpy(pred.astype(
                    np.float32)).cuda().transpose(1, 0).contiguous()
                target = torch.from_numpy(target.astype(
                    np.float32)).cuda().transpose(1, 0).contiguous()
                inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
                target = torch.index_select(target, 1, inds.view(-1) - 1)
                dis = torch.mean(torch.norm(
                    (pred.transpose(1, 0) - target.transpose(1, 0)), dim=1),
                                 dim=0).item()
            else:
                dis = np.mean(np.linalg.norm(pred - target, axis=1))
            logger.info(
                'Test time {0} Test Frame No.{1} loss:{2:f} confidence:{3:f} distance:{4:f}'
                .format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - st_time)),
                    test_count, loss, how_min[0].item(), dis))
            if dis < 0.1 * opt.diameters[idx[0].item()]:
                success_count[idx[0].item()] += 1
            num_count[idx[0].item()] += 1
            test_dis += dis
        # compute accuracy
        accuracy = 0.0
        for i in range(opt.num_objects):
            accuracy += float(success_count[i]) / num_count[i]
            logger.info('Object {0} success rate: {1}'.format(
                test_dataset.objlist[i],
                float(success_count[i]) / num_count[i]))
        accuracy = accuracy / opt.num_objects
        test_dis = test_dis / test_count
        # log results
        logger.info(
            'Test time {0} Epoch {1} TEST FINISH Avg dis: {2:f}, Accuracy: {3:f}'
            .format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), epoch,
                test_dis, accuracy))
        # tensorboard
        summary = tf.Summary(value=[
            tf.Summary.Value(tag='accuracy', simple_value=accuracy),
            tf.Summary.Value(tag='test_dis', simple_value=test_dis)
        ])
        tb_writer.add_summary(summary, global_step)
        # save model
        if test_dis < best_test:
            best_test = test_dis
            torch.save(
                estimator.state_dict(),
                '{0}/pose_model_{1:02d}_{2:06f}.pth'.format(
                    opt.result_dir, epoch, best_test))
        # adjust learning rate if necessary
        if best_test < 0.016 and not opt.first_decay_start:
            opt.first_decay_start = True
            opt.lr *= 0.6
            optimizer = torch.optim.Adam(estimator.parameters(), lr=opt.lr)
        if best_test < 0.013 and not opt.second_decay_start:
            opt.second_decay_start = True
            opt.lr *= 0.5
            optimizer = torch.optim.Adam(estimator.parameters(), lr=opt.lr)

        print(
            '>>>>>>>>----------epoch {0} test finish---------<<<<<<<<'.format(
                epoch))
Example No. 7
def main():
    if opt.dataset == 'ycb':
        opt.num_obj = 21
        opt.sym_list = [12, 15, 18, 19, 20]
        opt.num_points = 1000
        writer = SummaryWriter('experiments/runs/ycb/{0}'.format(opt.experiment_name))
        opt.outf = 'trained_models/ycb/{0}'.format(opt.experiment_name)
        opt.log_dir = 'experiments/logs/ycb/{0}'.format(opt.experiment_name)
        opt.repeat_num = 1
        if not os.path.exists(opt.outf):
            os.mkdir(opt.outf)
        if not os.path.exists(opt.log_dir):
            os.mkdir(opt.log_dir)
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points = opt.num_points, num_vote = 9, num_obj = opt.num_obj)
    estimator.cuda()
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_obj)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.lr = opt.lr_refine
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataset = PoseDataset_ycb('train', opt.num_points, True, opt.dataset_root)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'.format(len(dataset), len(test_dataset), opt.num_points))

    criterion = Loss(opt.num_points, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points, opt.sym_list)
    best_test = np.Inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()
    train_scalar = 0

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        train_count = 0
        train_loss_avg = 0.0
        train_loss = 0.0
        train_dis_avg = 0.0
        train_dis = 0.0
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()
        for rep in range(opt.repeat_num):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
                points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
                vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
                vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, idx, target_r, target_t)
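                # the vertex (keypoint-voting) loss is weighted 10x against
                # the pose loss in the total objective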
                loss = 10 * vertex_loss + pose_loss
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, idx)
                        dis.backward()
                else:
                    loss.backward()
                train_loss_avg += loss.item()
                train_loss += loss.item()
                train_dis_avg += dis.item()
                train_dis += dis.item()
                train_count += 1
                train_scalar += 1

                if train_count % opt.batch_size == 0:
                    logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_loss:{4} Avg_diss:{5}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, int(train_count / opt.batch_size), train_count, train_loss_avg / opt.batch_size, train_dis_avg / opt.batch_size))
                    writer.add_scalar('ycb training loss', train_loss_avg / opt.batch_size, train_scalar)
                    writer.add_scalar('ycb training dis', train_dis_avg / opt.batch_size, train_scalar)
                    optimizer.step()
                    optimizer.zero_grad()
                    train_loss_avg = 0
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf))
                    else:
                        torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))
        train_loss = train_loss / train_count
        train_dis = train_dis / train_count
        logger.info('Train time {0} Epoch {1} TRAIN FINISH Avg loss: {2} Avg dis: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, train_loss, train_dis))

        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        test_loss = 0.0
        test_vertex_loss = 0.0
        test_pose_loss = 0.0
        test_dis = 0.0
        test_count = 0
        success_count = 0
        estimator.eval()
        refiner.eval()
        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = data
            points, choose, img, target, model_points, model_kp, vertex_gt, idx, target_r, target_t = points.cuda(), choose.cuda(), img.cuda(), target.cuda(), model_points.cuda(), model_kp.cuda(), vertex_gt.cuda(), idx.cuda(), target_r.cuda(), target_t.cuda()
            vertex_pred, c_pred, emb = estimator(img, points, choose, idx)
            vertex_loss, pose_loss, dis, new_points, new_target = criterion(vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, idx, target_r, target_t)
            loss = 10 * vertex_loss + pose_loss
            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_points, new_target, model_points, idx)
            test_loss += loss.item()
            test_vertex_loss += vertex_loss.item()
            test_pose_loss += pose_loss.item()
            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} loss:{2} dis:{3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count, loss, dis))
            test_count += 1
            if dis.item() < 0.02:
                success_count += 1

        test_loss = test_loss / test_count
        test_vertex_loss = test_vertex_loss / test_count
        test_pose_loss = test_pose_loss / test_count
        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg loss: {2} Avg dis: {3}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, test_loss, test_dis))
        logger.info('Success rate: {}'.format(float(success_count) / test_count))
        writer.add_scalar('ycb test loss', test_loss, epoch)
        writer.add_scalar('ycb test vertex loss', test_vertex_loss, epoch)
        writer.add_scalar('ycb test pose loss', test_pose_loss, epoch)
        writer.add_scalar('ycb test dis', test_dis, epoch)
        writer.add_scalar('ycb success rate', float(success_count) / test_count, epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        if test_dis <= best_test:
            best_test = test_dis
        if opt.refine_start:
            torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        else:
            torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
        print(epoch, '>>>>>>>>----------MODEL SAVED---------<<<<<<<<')

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.lr = opt.lr_refine
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
            print('>>>>>>>>----------Refine started---------<<<<<<<<')

    writer.close()
Example No. 8
def main():
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    opt.num_objects = 21  #number of object classes in the dataset
    opt.num_points = 1000  #number of points on the input pointcloud
    opt.outf = 'trained_models/ycb_rot'  #folder to save trained models
    opt.log_dir = 'experiments/logs/ycb_rot'  #folder to save logs
    opt.repeat_epoch = 1  #number of repeat times for one epoch training

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    object_list = list(range(1, 22))
    output_format = [
        otypes.DEPTH_POINTS_MASKED_AND_INDEXES,
        otypes.IMAGE_CROPPED,
        otypes.MODEL_POINTS_TRANSFORMED,
        otypes.MODEL_POINTS,
        otypes.OBJECT_LABEL,
    ]
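    # note: DEPTH_POINTS_MASKED_AND_INDEXES presumably yields two tensors
    # (points and chosen indices), giving the six items unpacked from each
    # batch in the loops below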

    dataset = YCBDataset(opt.dataset_root,
                         mode='train_syn_grid_valid',
                         object_list=object_list,
                         output_data=output_format,
                         resample_on_error=True,
                         preprocessors=[
                             YCBOcclusionAugmentor(opt.dataset_root),
                             ColorJitter(),
                             InplaneRotator()
                         ],
                         postprocessors=[ImageNormalizer(),
                                         PointShifter()],
                         refine=opt.refine_start,
                         image_size=[640, 480],
                         num_points=1000)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers - 1)

    test_dataset = YCBDataset(opt.dataset_root,
                              mode='valid',
                              object_list=object_list,
                              output_data=output_format,
                              resample_on_error=True,
                              preprocessors=[],
                              postprocessors=[ImageNormalizer()],
                              refine=opt.refine_start,
                              image_size=[640, 480],
                              num_points=1000)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 shuffle=True,
                                                 batch_size=1,
                                                 num_workers=1)
    opt.sym_list = [12, 15, 18, 19, 20]
    opt.num_points_mesh = dataset.num_pt_mesh_small

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.Inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
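        # during refinement the estimator is frozen in eval mode and only the
        # refiner is trained; before that the estimator trains directly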
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = data
                idx = idx - 1
                points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                                 Variable(choose).cuda(), \
                                                                 Variable(img).cuda(), \
                                                                 Variable(target).cuda(), \
                                                                 Variable(model_points).cuda(), \
                                                                 Variable(idx).cuda()
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))
                if train_count >= 100000:
                    break
        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = data
            idx = idx - 1
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), test_count,
                dis))

            test_count += 1
            if test_count >= 3000:
                break
        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
            epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

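        # once best_test drops below refine_margin, switch to training the
        # refiner; the effective batch is divided by opt.iteration since each
        # refinement iteration contributes its own backward pass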
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            dataset = YCBDataset(
                opt.dataset_root,
                mode='train_syn_grid',
                object_list=object_list,
                output_data=output_format,
                resample_on_error=True,
                preprocessors=[
                    YCBOcclusionAugmentor(opt.dataset_root),
                    ColorJitter(),
                    InplaneRotator()
                ],
                postprocessors=[ImageNormalizer(),
                                PointShifter()],
                refine=opt.refine_start,
                image_size=[640, 480],
                num_points=1000)
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=opt.workers)

            test_dataset = YCBDataset(opt.dataset_root,
                                      mode='valid',
                                      object_list=object_list,
                                      output_data=output_format,
                                      resample_on_error=True,
                                      preprocessors=[],
                                      postprocessors=[ImageNormalizer()],
                                      refine=opt.refine_start,
                                      image_size=[640, 480],
                                      num_points=1000)
            testdataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=opt.workers)
            opt.num_points_mesh = dataset.num_pt_mesh_large

            print(
                '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
                .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                        opt.sym_list))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
Example No. 9
def main():
    # opt.manualSeed = random.randint(1, 10000)
    # # opt.manualSeed = 1
    # random.seed(opt.manualSeed)
    # torch.manual_seed(opt.manualSeed)

    torch.set_printoptions(threshold=5000)
    # device_ids = [0,1]
    cudnn.benchmark = True
    if opt.dataset == 'ycb':
        opt.num_objects = 21  #number of object classes in the dataset
        opt.num_points = 1000  #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  #folder to save logs
        opt.repeat_epoch = 3  #number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)

    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()
    # estimator = nn.DataParallel(estimator, device_ids=device_ids)

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        print('no refinement')
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)
        # optimizer = nn.DataParallel(optimizer, device_ids=device_ids)

    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, False,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
        # print(dataset.list)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans,
                                      opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)
    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0,
                                           opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    # print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    # criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.Inf
    best_epoch = 0

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    count_gen = 0

    mode = 1
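    # mode 1: train, then reload the best checkpoint; any other value skips
    # training and loads a fixed checkpoint for the evaluation loop below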

    if mode == 1:

        for epoch in range(opt.start_epoch, opt.nepoch):
            logger = setup_logger(
                'epoch%d' % epoch,
                os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
            logger.info('Train time {0}'.format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() -
                                                         st_time)) + ', ' +
                'Training started'))
            train_count = 0
            train_dis_avg = 0.0
            if opt.refine_start:
                estimator.eval()
                refiner.train()
            else:
                estimator.train()
            optimizer.zero_grad()

            for rep in range(opt.repeat_epoch):
                for i, data in enumerate(dataloader, 0):
                    points, choose, img, target_sym, target_cen, idx, file_list_idx = data

                    if idx.item() == 9 or idx.item() == 16:
                        continue
                    # points, choose, img, target_sym, target_cen, target, idx, file_list_idx = data
                    # generate_obj_file(target_sym, target_cen, target, idx.squeeze())
                    # import pdb;pdb.set_trace()
                    points, choose, img, target_sym, target_cen, idx = \
                        Variable(points).cuda(), \
                        Variable(choose).cuda(), \
                        Variable(img).cuda(), \
                        Variable(target_sym).cuda(), \
                        Variable(target_cen).cuda(), \
                        Variable(idx).cuda()
                    # points, choose, img, target_sym, target_cen, idx = Variable(points), \
                    #                                                 Variable(choose), \
                    #                                                 Variable(img), \
                    #                                                 Variable(target_sym), \
                    #                                                 Variable(target_cen), \
                    #                                                 Variable(idx)
                    pred_norm, pred_on_plane, emb = estimator(
                        img, points, choose, idx)
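                    # the network predicts a per-point symmetry-plane normal
                    # (pred_norm) and an on-plane score (pred_on_plane); the
                    # loss compares them against the ground-truth mirror
                    # symmetry planes and centers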

                    # pred_norm_new = torch.cat((pred_norm, torch.zeros(1,pred_norm.size(1),1)),2)

                    # for i in range(pred_norm.size(1)):
                    #     pred_norm_new[0,i,2] = torch.sqrt(1 - pred_norm[0,i,0] * pred_norm[0,i,0] - pred_norm[0,i,1] * pred_norm[0,i,1])
                    # if epoch % 10 == 0:
                    #     generate_obj_file_pred(pred_norm, pred_on_plane, points, count_gen, idx)
                    #     count_gen += 1
                    # print(pred_norm[0,0,:])

                    loss = criterion(pred_norm, pred_on_plane, target_sym,
                                     target_cen, idx, points, opt.w,
                                     opt.refine_start)

                    # scene_idx = dataset.list[file_list_idx]

                    loss.backward()

                    # train_dis_avg += dis.item()
                    train_count += 1

                    if train_count % opt.batch_size == 0:
                        logger.info(
                            'Train time {0} Epoch {1} Batch {2} Frame {3}'.
                            format(
                                time.strftime(
                                    "%Hh %Mm %Ss",
                                    time.gmtime(time.time() - st_time)), epoch,
                                int(train_count / opt.batch_size),
                                train_count))
                        optimizer.step()
                        # for param_lr in optimizer.module.param_groups:
                        #         param_lr['lr'] /= 2
                        optimizer.zero_grad()
                        train_dis_avg = 0

                    if train_count % 5000 == 0:
                        print(pred_on_plane.max())
                        print(pred_on_plane.mean())

                    if train_count != 0 and train_count % 1000 == 0:
                        if opt.refine_start:
                            torch.save(
                                refiner.state_dict(),
                                '{0}/pose_refine_model_current.pth'.format(
                                    opt.outf))
                        else:
                            torch.save(
                                estimator.state_dict(),
                                '{0}/pose_model_current.pth'.format(opt.outf))

            print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.
                  format(epoch))

            logger = setup_logger(
                'epoch%d_test' % epoch,
                os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
            logger.info('Test time {0}'.format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() -
                                                         st_time)) + ', ' +
                'Testing started'))
            test_loss = 0.0
            test_count = 0
            estimator.eval()
            # refiner.eval()

            # for rep in range(opt.repeat_epoch):
            #     for j, data in enumerate(testdataloader, 0):
            #         points, choose, img, target_sym, target_cen, idx, img_idx = data
            #         # points, choose, img, target, model_points, idx = Variable(points).cuda(), \
            #         #                                                  Variable(choose).cuda(), \
            #         #                                                  Variable(img).cuda(), \
            #         #                                                  Variable(target).cuda(), \
            #         #                                                  Variable(model_points).cuda(), \
            #         #                                                  Variable(idx).cuda()
            #         points, choose, img, target_sym, target_cen, idx = Variable(points), \
            #                                                             Variable(choose), \
            #                                                             Variable(img), \
            #                                                             Variable(target_sym), \
            #                                                             Variable(target_cen), \
            #                                                             Variable(idx)

            #         pred_norm, pred_on_plane, emb = estimator(img, points, choose, idx)
            #         loss = criterion(pred_norm, pred_on_plane, target_sym, target_cen, idx, points, opt.w, opt.refine_start)
            #         test_loss += loss

            #         logger.info('Test time {0} Test Frame No.{1}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count))

            #         test_count += 1

            # test_loss = test_loss / test_count
            logger.info(
                'Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - st_time)), epoch,
                    test_loss))
            print(pred_on_plane.max())
            print(pred_on_plane.mean())
            bs, num_p, _ = pred_on_plane.size()
            # if epoch % 40 == 0:
            #     import pdb;pdb.set_trace()
            best_test = test_loss
            best_epoch = epoch
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_loss))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_loss))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

            if best_test < opt.decay_margin and not opt.decay_start:
                opt.decay_start = True
                opt.lr *= opt.lr_rate
                # opt.w *= opt.w_rate
                optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        estimator.load_state_dict(
            torch.load('{0}/pose_model_{1}_{2}.pth'.format(
                opt.outf, best_epoch, best_test)))
    else:
        estimator.load_state_dict(
            torch.load('{0}/pose_model_11_0.0.pth'.format(opt.outf)))

    product_list = []
    dist_list = []

    true_positives = 0
    false_positives = 0
    false_negatives = 0

    for index in range(len(test_dataset.list)):
        img = Image.open('{0}/data_v1/{1}-color.png'.format(
            test_dataset.root, test_dataset.list[index]))
        depth = np.array(
            Image.open('{0}/data_v1/{1}-depth.png'.format(
                test_dataset.root, test_dataset.list[index])))
        label = np.array(
            Image.open('{0}/data_v1/{1}-label.png'.format(
                test_dataset.root, test_dataset.list[index])))
        meta = scio.loadmat('{0}/data_v1/{1}-meta.mat'.format(
            test_dataset.root, test_dataset.list[index]))

        cam_cx = test_dataset.cam_cx_1
        cam_cy = test_dataset.cam_cy_1
        cam_fx = test_dataset.cam_fx_1
        cam_fy = test_dataset.cam_fy_1
        mask_back = ma.getmaskarray(ma.masked_equal(label, 0))

        obj = meta['cls_indexes'].flatten().astype(np.int32)
        for idx in range(0, len(obj)):
            print('object index: ', obj[idx])
            mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
            mask_label = ma.getmaskarray(ma.masked_equal(label, obj[idx]))
            mask = mask_label * mask_depth
            if not (len(mask.nonzero()[0]) > test_dataset.minimum_num_pt
                    and len(test_dataset.symmetry[obj[idx]]['mirror']) > 0):
                continue

            rmin, rmax, cmin, cmax = get_bbox(mask_label)
            img_temp = np.transpose(np.array(img)[:, :, :3],
                                    (2, 0, 1))[:, rmin:rmax, cmin:cmax]

            img_masked = img_temp
            target_r = meta['poses'][:, :, idx][:, 0:3]
            target_t = np.array(meta['poses'][:, :, idx][:, 3:4].flatten())
            add_t = np.array([
                random.uniform(-test_dataset.noise_trans,
                               test_dataset.noise_trans) for i in range(3)
            ])

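            # sample exactly num_pt pixel indices inside the bounding box:
            # randomly subsample if the mask is large, wrap-pad if it is small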
            choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
            if len(choose) > test_dataset.num_pt:
                c_mask = np.zeros(len(choose), dtype=int)
                c_mask[:test_dataset.num_pt] = 1
                np.random.shuffle(c_mask)
                choose = choose[c_mask.nonzero()]
            else:
                choose = np.pad(choose, (0, test_dataset.num_pt - len(choose)),
                                'wrap')

            depth_masked = depth[
                rmin:rmax,
                cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
            xmap_masked = test_dataset.xmap[
                rmin:rmax,
                cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
            ymap_masked = test_dataset.ymap[
                rmin:rmax,
                cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
            choose = np.array([choose])

            cam_scale = meta['factor_depth'][0][0]
            pt2 = depth_masked / cam_scale
            pt0 = (ymap_masked - cam_cx) * pt2 / cam_fx
            pt1 = (xmap_masked - cam_cy) * pt2 / cam_fy
            cloud = np.concatenate((pt0, pt1, pt2), axis=1)

            dellist = [j for j in range(0, len(test_dataset.cld[obj[idx]]))]

            # dellist = random.sample(dellist, len(test_dataset.cld[obj[idx]]) - test_dataset.num_pt_mesh_small)
            # model_points = np.delete(test_dataset.cld[obj[idx]], dellist, axis=0)
            model_points = test_dataset.cld[obj[idx]]

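            # rotate each ground-truth mirror-symmetry normal by the ground-truth
            # rotation and offset the symmetry center by the ground-truth translation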
            target_sym = []
            for sym in test_dataset.symmetry[obj[idx]]['mirror']:
                target_sym.append(np.dot(sym, target_r.T))
            target_sym = np.array(target_sym)

            target_cen = np.add(test_dataset.symmetry[obj[idx]]['center'],
                                target_t)

            target = np.dot(model_points, target_r.T)
            target = np.add(target, target_t)

            print('ground truth norm: ', target_sym)
            print('ground truth center: ', target_cen)
            points_ten, choose_ten, img_ten, target_sym_ten, target_cen_ten, target_ten, idx_ten = \
               torch.from_numpy(cloud.astype(np.float32)).unsqueeze(0), \
               torch.LongTensor(choose.astype(np.int32)).unsqueeze(0), \
               test_dataset.norm(torch.from_numpy(img_masked.astype(np.float32))).unsqueeze(0), \
               torch.from_numpy(target_sym.astype(np.float32)).unsqueeze(0), \
               torch.from_numpy(target_cen.astype(np.float32)).unsqueeze(0), \
               torch.from_numpy(target.astype(np.float32)).unsqueeze(0), \
               torch.LongTensor([obj[idx]-1]).unsqueeze(0)

            # print(img_ten.size())
            # print(points_ten.size())
            # print(choose_ten.size())
            # print(idx_ten.size())

            points_ten, choose_ten, img_ten, target_sym_ten, target_cen_ten, idx_ten = Variable(points_ten).cuda(), \
                                                                Variable(choose_ten).cuda(), \
                                                                Variable(img_ten).cuda(), \
                                                                Variable(target_sym_ten).cuda(), \
                                                                Variable(target_cen_ten).cuda(), \
                                                                Variable(idx_ten).cuda()

            pred_norm, pred_on_plane, emb = estimator(img_ten, points_ten,
                                                      choose_ten, idx_ten)

            # import pdb;pdb.set_trace()

            bs, num_p, _ = pred_on_plane.size()

            # pred_norm = torch.cat((pred_norm, torch.zeros(1,pred_norm.size(1),1)),2)

            # for i in range(pred_norm.size(1)):
            #     pred_norm[0,i,2] = torch.sqrt(1 - pred_norm[0,i,0] * pred_norm[0,i,0] - pred_norm[0,i,1] * pred_norm[0,i,1])
            # pred_norm = pred_norm / (torch.norm(pred_norm, dim=2).view(bs, num_p, 1))

            generate_obj_file_norm_pred(
                pred_norm / (torch.norm(pred_norm, dim=2).view(bs, num_p, 1)),
                pred_on_plane, points_ten,
                test_dataset.list[index].split('/')[0],
                test_dataset.list[index].split('/')[1], obj[idx])

            loss = criterion(pred_norm, pred_on_plane, target_sym_ten,
                             target_cen_ten, idx_ten, points_ten, opt.w,
                             opt.refine_start)
            # print('test loss: ', loss)

            # bs, num_p, _ = pred_on_plane.size()
            pred_norm = pred_norm / (torch.norm(pred_norm, dim=2).view(
                bs, num_p, 1))
            pred_norm = pred_norm.cpu().detach().numpy()
            pred_on_plane = pred_on_plane.cpu().detach().numpy()
            points = points_ten.cpu().detach().numpy()

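            # candidate on-plane points: confidence above a blend of the max and
            # mean prediction, weighted by PRED_ON_PLANE_FACTOR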
            clustering_points_idx = np.where(
                pred_on_plane > pred_on_plane.max() * PRED_ON_PLANE_FACTOR +
                pred_on_plane.mean() * (1 - PRED_ON_PLANE_FACTOR))[1]
            clustering_norm = pred_norm[0, clustering_points_idx, :]
            clustering_points = points[0, clustering_points_idx, :]
            num_points = len(clustering_points_idx)

            # import pdb;pdb.set_trace()

            close_thresh = 5e-3
            broad_thresh = 7e-3

            sym_flag = [0 for i in range(target_sym.shape[0])]
            sym_max_product = [0.0 for i in range(target_sym.shape[0])]
            sym_dist = [0.0 for i in range(target_sym.shape[0])]

            count_pred = 0
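            # RANSAC-style plane extraction: propose a plane from a random
            # candidate point and its predicted normal, accept it when enough
            # broad inliers agree, refine it, then remove its inliers and
            # repeat (at most one plane per ground-truth symmetry)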
            while True:
                if num_points == 0:
                    break
                count_pred += 1
                if count_pred > target_sym.shape[0]:
                    break
                best_fit_num = 0

                count_try = 0
                while True:
                    if count_try > 3 or num_points <= 1:
                        break

                    pick_idx = np.random.randint(0, num_points)
                    pick_point = clustering_points[pick_idx]
                    # proposal_norm = np.array(Plane(Point3D(pick_points[0]),Point3D(pick_points[1]),Point3D(pick_points[2])).normal_vector).astype(np.float32)
                    proposal_norm = clustering_norm[pick_idx]
                    proposal_norm = proposal_norm[:, np.newaxis]

                    # import pdb;pdb.set_trace()
                    proposal_point = pick_point
                    # highest_pred_idx = np.argmax(pred_on_plane[0,clustering_points_idx,:])
                    # highest_pred_loc = clustering_points[highest_pred_idx]
                    # proposal_norm = clustering_norm[highest_pred_idx][:,np.newaxis]
                    clustering_diff = clustering_points - proposal_point
                    clustering_dist = np.abs(
                        np.matmul(clustering_diff, proposal_norm))

                    broad_inliers = np.where(clustering_dist < broad_thresh)[0]
                    broad_inlier_num = len(broad_inliers)

                    close_inliers = np.where(clustering_dist < close_thresh)[0]
                    close_inlier_num = len(close_inliers)

                    if broad_inlier_num > num_points / (5 - count_pred):
                        best_fit_num = close_inlier_num
                        best_fit_norm = proposal_norm
                        best_fit_cen = clustering_points[close_inliers].mean(0)
                        best_fit_idx = clustering_points_idx[close_inliers]
                        scrub_idx = clustering_points_idx[broad_inliers]
                        break
                    else:
                        count_try += 1
                    # else:
                    #     np.delete(clustering_points_idx, highest_pred_idx)
                    #     num_points -= 1

                if count_try > 3 or num_points <= 1:
                    break

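                # refine the accepted plane with a Nelder-Mead search over
                # (nx, ny, d), reconstructing nz = sqrt(1 - nx^2 - ny^2) and
                # minimising the summed point-to-plane distance of the inliers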
                for i in range(2):

                    def f(x):
                        dist = 0
                        x = x / LA.norm(x)
                        for point in clustering_points[broad_inliers]:
                            dist += np.abs(point[0] * x[0] + point[1] * x[1] +
                                           point[2] * np.sqrt(1 - x[0] * x[0] -
                                                              x[1] * x[1]) +
                                           x[2])
                        return dist

                    start_point = np.copy(proposal_norm)
                    start_point[2] = (-proposal_point *
                                      proposal_norm[:, 0]).sum()

                    min_point = fmin(f, start_point)
                    new_pred_loc = np.array([
                        0, 0, -min_point[2] /
                        np.sqrt(1 - min_point[0] * min_point[0] -
                                min_point[1] * min_point[1])
                    ])

                    min_point[2] = np.sqrt(1 - min_point[0] * min_point[0] -
                                           min_point[1] * min_point[1])
                    new_proposal_norm = min_point
                    clustering_diff = clustering_points - new_pred_loc
                    clustering_dist = np.abs(
                        np.matmul(clustering_diff, new_proposal_norm))

                    close_inliers = np.where(clustering_dist < close_thresh)[0]
                    new_close_inlier_num = len(close_inliers)

                    broad_inliers = np.where(clustering_dist < broad_thresh)[0]
                    new_broad_inlier_num = len(broad_inliers)
                    # import pdb;pdb.set_trace()
                    if new_close_inlier_num > close_inlier_num:
                        best_fit_num = new_close_inlier_num
                        # proposal_point = clustering_points_idx[clustering_dist.argmin()]
                        proposal_point = new_pred_loc
                        best_fit_norm = new_proposal_norm[:, np.newaxis]
                        best_fit_idx = clustering_points_idx[close_inliers]
                        scrub_idx = clustering_points_idx[broad_inliers]
                        best_fit_cen = new_pred_loc
                        close_inlier_num = new_close_inlier_num
                        proposal_norm = best_fit_norm

                # other_idx_pick = other_idx[other_idx_pick]

                # if len(other_idx_pick) > num_points//6:
                #     pick_idx = np.concatenate((pick_idx, other_idx_pick), 0)
                #     norm_proposal_new = clustering_norm[pick_idx,:].mean(0)
                #     norm_proposal_new = norm_proposal_new / LA.norm(norm_proposal_new)
                #     inlier_num_new = len(np.where(np.abs(clustering_norm-norm_proposal_new).sum(1) < thresh)[0])
                #     if inlier_num_new > inlier_num:
                #         best_fit_num = inlier_num_new
                #         best_fit_idx = np.where(np.abs(clustering_norm-norm_proposal_new).sum(1) < thresh_scrap)
                #         best_fit_norm = norm_proposal_new
                #         best_fit_cen = clustering_points[best_fit_idx].mean(0)

                if best_fit_num == 0:
                    break
                else:
                    print('predicted norm:{}, predicted point:{}'.format(
                        best_fit_norm, best_fit_cen))

                    max_idx = np.argmax(np.matmul(target_sym, best_fit_norm))
                    sym_flag[max_idx] += 1
                    # cosine between the predicted and ground-truth symmetry
                    # normals; keep the best-aligned detection and record the
                    # offset of its center along the ground-truth normal
                    sym_product = np.dot(target_sym[max_idx],
                                         best_fit_norm[:, 0])
                    if sym_max_product[max_idx] < sym_product:
                        sym_max_product[max_idx] = sym_product
                        sym_dist[max_idx] = np.abs(
                            np.dot(target_sym[max_idx],
                                   best_fit_cen - target_cen))

                    # generate_obj_file_sym_pred(best_fit_norm, best_fit_cen, target_ten, test_dataset.list[index].split('/')[0], test_dataset.list[index].split('/')[1], obj[idx], count_pred)
                    # import pdb;pdb.set_trace()
                    clustering_points_idx = np.setdiff1d(
                        clustering_points_idx, scrub_idx)

                    clustering_norm = pred_norm[0, clustering_points_idx, :]
                    clustering_points = points[0, clustering_points_idx, :]
                    num_points = len(clustering_points_idx)

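            # a symmetry matched at least once contributes its best alignment
            # and distance; extra matches are false positives, misses false
            # negatives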
            for i in range(target_sym.shape[0]):
                if sym_flag[i] >= 1:
                    dist_list.append(sym_dist[i])
                    product_list.append(sym_max_product[i])
                    false_positives += sym_flag[i] - 1
                else:
                    false_negatives += 1

    product_list = np.array(product_list)
    dist_list = np.array(dist_list)
    # import pdb;pdb.set_trace()
    total_num = len(product_list)

    prec = []
    recall = []
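    # sweep 1000 joint thresholds (center distance up to 0.5 m, normal angle
    # up to 45 degrees) and record precision/recall at each operating point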
    for t in range(1000):
        good_ones = int(
            np.sum(
                np.logical_and(
                    dist_list < 0.5 * t / 1000,
                    product_list > math.cos(math.pi * 0.25 * t / 1000))))

        prec.append(good_ones * 1.0 / (false_positives + total_num))
        recall.append(good_ones * 1.0 / (good_ones + false_negatives))

    print(prec)
    print(recall)
    plt.plot(recall, prec, 'r')
    plt.axis([0, 1, 0, 1])
    plt.savefig('prec-recall.png')
Ejemplo n.º 10
0
def main():
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        output_results = 'check_linemod.txt'
        opt.repeat_epoch = 20

    elif opt.dataset == 'ycb':
        opt.num_objects = 21  #number of object classes in the dataset
        opt.num_points = 1000  #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  #folder to save logs
        opt.repeat_epoch = 1  #number of repeat times for one epoch training

    elif opt.dataset == 'ycb-syn':
        opt.num_objects = 31  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.dataset_root = '/data/Akeaveny/Datasets/ycb_syn'
        opt.outf = 'trained_models/ycb_syn/ycb_syn2'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb_syn/ycb_syn2'  # folder to save logs
        output_results = 'check_ycb_syn.txt'

        opt.w = 0.05
        opt.refine_margin = 0.01

    elif opt.dataset == 'arl':
        opt.num_objects = 10  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.dataset_root = '/data/Akeaveny/Datasets/arl_dataset'
        opt.outf = 'trained_models/arl/clutter/arl_finetune_syn_2'  # folder to save trained models
        opt.log_dir = '/home/akeaveny/catkin_ws/src/object-rpe-ak/DenseFusion/experiments/logs/arl/clutter/arl_finetune_syn_2'  # folder to save logs
        output_results = 'check_arl_syn.txt'

        opt.nepoch = 750

        opt.w = 0.05
        opt.refine_margin = 0.0045

        # TODO
        opt.repeat_epoch = 20
        opt.start_epoch = 0
        opt.resume_posenet = 'pose_model_1_0.012397416144377301.pth'
        opt.resume_refinenet = 'pose_refine_model_153_0.004032851301599294.pth'

    elif opt.dataset == 'arl1':
        opt.num_objects = 5  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.dataset_root = '/data/Akeaveny/Datasets/arl_dataset'
        opt.outf = 'trained_models/arl1/clutter/arl_real_2'  # folder to save trained models
        opt.log_dir = '/home/akeaveny/catkin_ws/src/object-rpe-ak/DenseFusion/experiments/logs/arl1/clutter/arl_real_2'  # folder to save logs
        output_results = 'check_arl_syn.txt'

        opt.nepoch = 750

        opt.w = 0.05
        opt.refine_margin = 0.015

        # opt.start_epoch = 120
        # opt.resume_posenet = 'pose_model_current.pth'
        # opt.resume_refinenet = 'pose_refine_model_115_0.008727498716640046.pth'

    elif opt.dataset == 'elevator':
        opt.num_objects = 1  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.dataset_root = '/data/Akeaveny/Datasets/elevator_dataset'
        opt.outf = 'trained_models/elevator/elevator_2'  # folder to save trained models
        opt.log_dir = '/home/akeaveny/catkin_ws/src/object-rpe-ak/DenseFusion/experiments/logs/elevator/elevator_2'  # folder to save logs
        output_results = 'check_arl_syn.txt'

        opt.nepoch = 750

        opt.w = 0.05
        opt.refine_margin = 0.015

        # TODO
        opt.repeat_epoch = 40
        # opt.start_epoch = 47
        # opt.resume_posenet = 'pose_model_current.pth'
        # opt.resume_refinenet = 'pose_refine_model_46_0.007581770288279472.pth'

    else:
        print('Unknown dataset')
        return

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans,
                                      opt.refine_start)
    elif opt.dataset == 'ycb-syn':
        dataset = PoseDataset_ycb_syn('train', opt.num_points, True,
                                      opt.dataset_root, opt.noise_trans,
                                      opt.refine_start)
    elif opt.dataset == 'arl':
        dataset = PoseDataset_arl('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
    elif opt.dataset == 'arl1':
        dataset = PoseDataset_arl1('train', opt.num_points, True,
                                   opt.dataset_root, opt.noise_trans,
                                   opt.refine_start)
    elif opt.dataset == 'elevator':
        dataset = PoseDataset_elevator('train', opt.num_points, True,
                                       opt.dataset_root, opt.noise_trans,
                                       opt.refine_start)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)

    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False,
                                           opt.dataset_root, 0.0,
                                           opt.refine_start)
    elif opt.dataset == 'ycb-syn':
        test_dataset = PoseDataset_ycb_syn('test', opt.num_points, True,
                                           opt.dataset_root, 0.0,
                                           opt.refine_start)
    elif opt.dataset == 'arl':
        test_dataset = PoseDataset_arl('test', opt.num_points, True,
                                       opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'arl1':
        test_dataset = PoseDataset_arl1('test', opt.num_points, True,
                                        opt.dataset_root, 0.0,
                                        opt.refine_start)
    elif opt.dataset == 'elevator':
        test_dataset = PoseDataset_elevator('test', opt.num_points, True,
                                            opt.dataset_root, 0.0,
                                            opt.refine_start)

    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.Inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    ######################
    ######################

    # TODO (ak): set up tensor board
    # if not os.path.exists(opt.log_dir):
    #     os.makedirs(opt.log_dir)
    #
    # writer = SummaryWriter(opt.log_dir)

    ######################
    ######################

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
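        # once refinement has started the estimator is switched to eval mode
        # (dropout/batch-norm fixed) and only the refiner keeps training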
        if opt.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):

            ##################
            # train
            ##################

            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = data

                # TODO: txt file
                # fw = open(test_folder + output_results, 'w')
                # fw.write('Points\n{0}\n\nchoose\n{1}\n\nimg\n{2}\n\ntarget\n{3}\n\nmodel_points\n{4}'.format(points, choose, img, target, model_points))
                # fw.close()

                points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                                 Variable(choose).cuda(), \
                                                                 Variable(img).cuda(), \
                                                                 Variable(target).cuda(), \
                                                                 Variable(model_points).cuda(), \
                                                                 Variable(idx).cuda()
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

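                # the DataLoader yields single frames, so gradients are
                # accumulated for batch_size frames before each optimizer step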
                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {} Epoch {} Batch {} Frame {}/{} Avg_dis: {:.2f} [cm]'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, len(dataset.list),
                            train_dis_avg / opt.batch_size * 100))
                    optimizer.step()
                    optimizer.zero_grad()

                    # TODO: tensorboard
                    # if train_count != 0 and train_count % 250 == 0:
                    #     scalar_info = {'loss': loss.item(),
                    #                    'dis': train_dis_avg / opt.batch_size}
                    #     for key, val in scalar_info.items():
                    #         writer.add_scalar(key, val, train_count)

                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))

                    # TODO: tensorboard
                    # scalar_info = {'loss': loss.item(),
                    #                'dis': dis.item()}
                    # for key, val in scalar_info.items():
                    #     writer.add_scalar(key, val, train_count)

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = data
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {} Test Frame No.{} dis: {} [cm]'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), test_count,
                dis * 100))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info(
            'Test time {} Epoch {} TEST FINISH Avg dis: {} [cm]'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), epoch,
                test_dis * 100))

        # TODO: tensorboard
        # scalar_info = {'test dis': test_dis}
        # for key, val in scalar_info.items():
        #     writer.add_scalar(key, val, train_count)

        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

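            # rebuild the datasets and loaders so they are created with
            # refine_start=True (e.g. a different mesh point sampling)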
            if opt.dataset == 'ycb':
                dataset = PoseDataset_ycb('train', opt.num_points, True,
                                          opt.dataset_root, opt.noise_trans,
                                          opt.refine_start)
            elif opt.dataset == 'linemod':
                dataset = PoseDataset_linemod('train', opt.num_points, True,
                                              opt.dataset_root,
                                              opt.noise_trans,
                                              opt.refine_start)
            elif opt.dataset == 'ycb-syn':
                dataset = PoseDataset_ycb_syn('train', opt.num_points, True,
                                              opt.dataset_root,
                                              opt.noise_trans,
                                              opt.refine_start)
            elif opt.dataset == 'arl':
                dataset = PoseDataset_arl('train', opt.num_points, True,
                                          opt.dataset_root, opt.noise_trans,
                                          opt.refine_start)
            elif opt.dataset == 'arl1':
                dataset = PoseDataset_arl1('train', opt.num_points, True,
                                           opt.dataset_root, opt.noise_trans,
                                           opt.refine_start)
            elif opt.dataset == 'elevator':
                dataset = PoseDataset_elevator('train', opt.num_points, True,
                                               opt.dataset_root,
                                               opt.noise_trans,
                                               opt.refine_start)

            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=opt.workers)

            if opt.dataset == 'ycb':
                test_dataset = PoseDataset_ycb('test', opt.num_points, False,
                                               opt.dataset_root, 0.0,
                                               opt.refine_start)
            elif opt.dataset == 'linemod':
                test_dataset = PoseDataset_linemod('test', opt.num_points,
                                                   False, opt.dataset_root,
                                                   0.0, opt.refine_start)
            elif opt.dataset == 'ycb-syn':
                test_dataset = PoseDataset_ycb_syn('test', opt.num_points,
                                                   True, opt.dataset_root, 0.0,
                                                   opt.refine_start)
            elif opt.dataset == 'arl':
                test_dataset = PoseDataset_arl('test', opt.num_points, True,
                                               opt.dataset_root, 0.0,
                                               opt.refine_start)
            elif opt.dataset == 'arl1':
                test_dataset = PoseDataset_arl1('test', opt.num_points, True,
                                                opt.dataset_root, 0.0,
                                                opt.refine_start)
            elif opt.dataset == 'elevator':
                test_dataset = PoseDataset_elevator('test', opt.num_points,
                                                    True, opt.dataset_root,
                                                    0.0, opt.refine_start)

            testdataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=opt.workers)

            opt.sym_list = dataset.get_sym_list()
            opt.num_points_mesh = dataset.get_num_points_mesh()

            print(
                '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'
                .format(len(dataset), len(test_dataset), opt.num_points_mesh,
                        opt.sym_list))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
Ejemplo n.º 11
0
def main():
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    opt.num_objects = 3
    opt.num_points = 500
    opt.outf = 'trained_models'
    opt.log_dir = 'experiments/logs'
    opt.repeat_epoch = 20

    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':
        estimator.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':
        refiner.load_state_dict(
            torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataset = PoseDataset('train', opt.num_points, True, opt.dataset_root,
                          opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=opt.workers)

    test_dataset = PoseDataset('test', opt.num_points, False, opt.dataset_root,
                               0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print(
        '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'
        .format(len(dataset), len(test_dataset), opt.num_points_mesh))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)

    best_test = np.Inf

    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))
    st_time = time.time()

    for epoch in range(opt.start_epoch, opt.nepoch):
        logger = setup_logger(
            'epoch%d' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0
        if opt.refine_start:
            estimator.eval()  # affects dropout and batch normalization
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = data
                #points        ->torch.Size([500, 3])  ->500 points sampled from the cropped pixel region, back-projected into a point cloud using the camera intrinsics and depth values
                #choose        ->torch.Size([1, 500])
                #img           ->torch.Size([3, 80, 80])
                #target        ->torch.Size([500, 3])  ->mesh points sampled from the object model, transformed by the ground-truth pose
                #model_points  ->torch.Size([500, 3])  ->the same sampled mesh points before the pose transform
                #idx           ->torch.Size([1])
                #tensor([4], device='cuda:0')
                #img and points carry the RGB and point-cloud information, which are fused inside the network
                points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                                 Variable(choose).cuda(), \
                                                                 Variable(img).cuda(), \
                                                                 Variable(target).cuda(), \
                                                                 Variable(model_points).cuda(), \
                                                                 Variable(idx).cuda()
                pred_r, pred_t, pred_c, emb = estimator(
                    img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(
                    pred_r, pred_t, pred_c, target, model_points, idx, points,
                    opt.w, opt.refine_start)

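                # with refinement active, unroll the refiner for opt.iteration
                # steps and backpropagate each step's distance; otherwise train
                # the estimator with the PoseNet loss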
                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(
                            pred_r, pred_t, new_target, model_points, idx,
                            new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                if train_count % opt.batch_size == 0:
                    logger.info(
                        'Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'
                        .format(
                            time.strftime("%Hh %Mm %Ss",
                                          time.gmtime(time.time() - st_time)),
                            epoch, int(train_count / opt.batch_size),
                            train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(
                            refiner.state_dict(),
                            '{0}/pose_refine_model_current.pth'.format(
                                opt.outf))
                    else:
                        torch.save(
                            estimator.state_dict(),
                            '{0}/pose_model_current.pth'.format(opt.outf))

        print(
            '>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(
                epoch))

        logger = setup_logger(
            'epoch%d_test' % epoch,
            os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
            ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        for j, data in enumerate(testdataloader, 0):
            points, choose, img, target, model_points, idx = data
            points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                             Variable(choose).cuda(), \
                                                             Variable(img).cuda(), \
                                                             Variable(target).cuda(), \
                                                             Variable(model_points).cuda(), \
                                                             Variable(idx).cuda()
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c,
                                                       target, model_points,
                                                       idx, points, opt.w,
                                                       opt.refine_start)

            if opt.refine_start:
                for ite in range(0, opt.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(
                        pred_r, pred_t, new_target, model_points, idx,
                        new_points)

            test_dis += dis.item()
            logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - st_time)), test_count,
                dis))

            test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(
            time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)),
            epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(
                    refiner.state_dict(),
                    '{0}/pose_refine_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            else:
                torch.save(
                    estimator.state_dict(),
                    '{0}/pose_model_{1}_{2}.pth'.format(
                        opt.outf, epoch, test_dis))
            print(epoch,
                  '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            dataset = PoseDataset('train', opt.num_points, True,
                                  opt.dataset_root, opt.noise_trans,
                                  opt.refine_start)
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=opt.workers)

            test_dataset = PoseDataset('test', opt.num_points, False,
                                       opt.dataset_root, 0.0, opt.refine_start)
            testdataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=opt.workers)

            opt.sym_list = dataset.get_sym_list()
            opt.num_points_mesh = dataset.get_num_points_mesh()

            print(
                '>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}'
                .format(len(dataset), len(test_dataset), opt.num_points_mesh))

            criterion = Loss(opt.num_points_mesh, opt.sym_list)
            criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)