Esempio n. 1
0
def compute_summary(end_points, labels_pl, centers_pl, heading_class_label_pl,
                    heading_residual_label_pl, size_class_label_pl,
                    size_residual_label_pl):
    '''
    Compute box IoUs and segmentation-mask accuracy for one batch.

    The IoUs are computed by ``provider.compute_box3d_iou``; the original
    author notes this is the numpy implementation supplied upstream and that
    it may be on the slow side.
    @author chonepeiceyb

    Shapes below are the ones stated by the original author and are not
    checked by this function -- TODO confirm against the caller.

    :param end_points: dict of prediction tensors. Every value is converted
        to a numpy array; the keys 'center', 'heading_scores',
        'heading_residuals', 'size_scores', 'size_residuals' and
        'mask_logits' are read.
    :param labels_pl: mask labels, documented as (B,2). NOTE(review): they
        are compared against an argmax over dim=1 of 'mask_logits', so
        (B,N) looks more plausible -- verify.
    :param centers_pl: (B,3)
    :param heading_class_label_pl: (B,)
    :param heading_residual_label_pl: (B,)
    :param size_class_label_pl: (B,)
    :param size_residual_label_pl: (B,3)
    :return:
    iou2ds: (B,) birdeye view oriented 2d box ious
    iou3ds: (B,) 3d box ious
    accuracy: mean mask prediction accuracy, as returned by ``torch.mean``
        (a tensor, not a plain Python float)
    '''
    # Bring every prediction tensor to the host as a numpy array.
    end_points_np = {
        name: tensor.cpu().data.numpy() for name, tensor in end_points.items()
    }

    # Predictions first, then the label arguments, which are forwarded
    # unchanged (presumably already numpy-compatible -- verify against caller).
    prediction_keys = ('center', 'heading_scores', 'heading_residuals',
                       'size_scores', 'size_residuals')
    predictions = [end_points_np[key] for key in prediction_keys]
    iou2ds, iou3ds = provider.compute_box3d_iou(
        *predictions,
        centers_pl,
        heading_class_label_pl, heading_residual_label_pl,
        size_class_label_pl, size_residual_label_pl)

    # Original author's note: end_points['mask_logits'] is (B,2,N) and this
    # comparison still "needs debugging".
    predicted_mask = torch.argmax(end_points['mask_logits'], dim=1)
    hits = torch.eq(predicted_mask, labels_pl.type(torch.int64))
    accuracy = torch.mean(hits.type(torch.float32))
    return iou2ds, iou3ds, accuracy
Esempio n. 2
0
    def eval_epoch(self):
        """Evaluate the model on the validation set and track the best one.

        Puts ``self.model`` in eval mode and runs it over
        ``len(self.valid_dataset) // self.val_batch_size`` full batches; a
        trailing partial batch is not evaluated. For every batch the loss
        from ``self.loss`` is accumulated and the boxes whose 3D IoU is at
        least 0.7 are counted. The mean loss and the box accuracy are passed
        to ``self.log_box_values`` with the tag ``'Val'``.

        When the mean loss is lower than ``self.best_val_loss`` the latter is
        updated and a snapshot of the model is stored in ``self.best_model``.

        Side effects: ``self.endpoints`` is left holding the outputs of the
        last batch, extended with the ``'iou2ds'`` and ``'iou3ds'`` entries.

        NOTE: the validation set must contain at least one full batch,
        otherwise the averages below divide by zero.
        """
        # Local import: only needed for the best-model snapshot below.
        import copy

        self.model.eval()
        test_idxs = np.arange(0, len(self.valid_dataset))
        num_batches = len(self.valid_dataset) // self.val_batch_size

        # To collect statistics
        loss_sum = 0
        iou3d_correct_cnt = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * self.val_batch_size
            end_idx = (batch_idx + 1) * self.val_batch_size

            batch_data, batch_label, batch_center, \
            batch_hclass, batch_hres, \
            batch_sclass, batch_sres, \
            batch_rot_angle, batch_one_hot_vec = \
                tuple(get_batch(self.valid_dataset, test_idxs, start_idx, end_idx,
                                self.config.NUM_POINT, self.config.NUM_CHANNELS))

            with torch.no_grad():
                # self.endpoints is assigned before self.loss is called;
                # presumably the loss reads the predictions from it -- verify.
                self.endpoints = self.model(batch_data, batch_one_hot_vec,
                                            batch_label)
                val_loss = self.loss(batch_center, batch_hclass, batch_hres,
                                     batch_sclass, batch_sres)

            loss_sum += val_loss

            iou2ds, iou3ds = compute_box3d_iou(
                self.endpoints['center'].detach().cpu().numpy(),
                self.endpoints['heading_scores'].detach().cpu().numpy(),
                self.endpoints['heading_residuals'].detach().cpu().numpy(),
                self.endpoints['size_scores'].detach().cpu().numpy(),
                self.endpoints['size_residuals'].detach().cpu().numpy(),
                batch_center.detach().cpu().numpy(),
                batch_hclass.detach().cpu().numpy(),
                batch_hres.detach().cpu().numpy(),
                batch_sclass.detach().cpu().numpy(),
                batch_sres.detach().cpu().numpy())
            self.endpoints['iou2ds'] = iou2ds
            self.endpoints['iou3ds'] = iou3ds

            iou3d_correct_cnt += np.sum(iou3ds >= 0.7)

        box_acc = float(iou3d_correct_cnt) / float(
            self.val_batch_size * num_batches)
        mean_loss = loss_sum / float(num_batches)
        self.log_box_values(batch_idx, mean_loss, box_acc, 'Val')

        if self.best_val_loss > mean_loss:
            self.best_val_loss = mean_loss
            # Store an independent copy: a plain reference to self.model
            # would keep changing as training continues, so "best" would
            # always end up being the most recent weights.
            self.best_model = copy.deepcopy(self.model)
def train():
    ''' Main function for training and simple evaluation.

    Seeds the Python, NumPy and PyTorch random generators, builds the
    network selected by ``FLAGS.model`` (restoring weights from
    ``FLAGS.ckpt`` when given) and trains it for ``MAX_EPOCH`` epochs on
    ``train_dataloader``. After every epoch the model is evaluated with
    ``test_one_epoch`` on ``test_dataloader``, the metrics are printed and,
    unless ``FLAGS.debug`` is set, written to a ``SummaryWriter`` rooted at
    ``runs/NAME`` (an existing directory of that name is deleted first).

    Whenever the test 3D-IoU accuracy matches or beats the best seen so far
    and more than a fifth of the epochs have elapsed, a checkpoint is saved
    under ``LOG_DIR/NAME`` and the previously saved best file is removed.

    Raises:
        ValueError: if ``FLAGS.model`` or ``OPTIMIZER`` names an option this
            function does not know how to build.
    '''
    start = time.perf_counter()
    SEED = 1
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    random.seed(SEED)
    torch.backends.cudnn.deterministic = True

    def blue(text):
        # Wrap text in the ANSI escape sequence used for console highlights.
        return '\033[94m' + text + '\033[0m'

    # set model
    if FLAGS.model == 'frustum_pointnets_v1':
        from frustum_pointnets_v1 import FrustumPointNetv1
        FrustumPointNet = FrustumPointNetv1(n_classes=n_classes).cuda()
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # further down.
        raise ValueError('unsupported model: %s' % FLAGS.model)

    # load pre-trained model
    if FLAGS.ckpt:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(FLAGS.ckpt)
        FrustumPointNet.load_state_dict(ckpt['model_state_dict'])

    # set optimizer and scheduler
    if OPTIMIZER == 'adam':
        optimizer = torch.optim.Adam(FrustumPointNet.parameters(),
                                     lr=BASE_LEARNING_RATE,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=FLAGS.weight_decay)
    else:
        raise ValueError('unsupported optimizer: %s' % OPTIMIZER)

    def lr_func(epoch,
                init=BASE_LEARNING_RATE,
                step_size=DECAY_STEP,
                gamma=DECAY_RATE,
                eta_min=0.00001):
        # Multiplicative factor for LambdaLR: step decay, floored so that the
        # effective learning rate never drops below eta_min.
        f = gamma**(epoch // step_size)
        if init * f > eta_min:
            return f
        # Derive the floor factor from init rather than hard-coding 0.01,
        # which only equals eta_min / init when init is 0.001.
        return eta_min / init

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                  lr_lambda=lr_func)

    # train
    if os.path.exists('runs/' + NAME):
        print('name has been existed')
        shutil.rmtree('runs/' + NAME)

    writer = SummaryWriter('runs/' + NAME)
    best_iou3d_acc = 0.0
    best_epoch = 1
    best_file = ''

    for epoch in range(MAX_EPOCH):
        log_string('**** EPOCH %03d ****' % (epoch + 1))
        sys.stdout.flush()
        print('Epoch %d/%s:' % (epoch + 1, MAX_EPOCH))

        # record for one epoch
        train_total_loss = 0.0
        train_iou2d = 0.0
        train_iou3d = 0.0
        train_acc = 0.0
        train_iou3d_acc = 0.0

        if FLAGS.return_all_loss:
            train_mask_loss = 0.0
            train_center_loss = 0.0
            train_heading_class_loss = 0.0
            train_size_class_loss = 0.0
            train_heading_residuals_normalized_loss = 0.0
            train_size_residuals_normalized_loss = 0.0
            train_stage1_center_loss = 0.0
            train_corners_loss = 0.0

        n_samples = 0
        # test_one_epoch() leaves the model in eval mode, so switch back once
        # per epoch before the first batch.
        FrustumPointNet = FrustumPointNet.train()
        for i, data in tqdm(enumerate(train_dataloader),
                            total=len(train_dataloader), smoothing=0.9):
            # Debug mode processes a single batch. Stop before counting the
            # samples of the batch that is skipped, otherwise the per-sample
            # averages below are diluted.
            if FLAGS.debug and i == 1:
                break
            n_samples += data[0].shape[0]

            # Batch layout after frustum rotation (sizes taken from the
            # original author's notes for a batch of 32 -- TODO confirm):
            #   segmentation: batch_data [32, 2048, 4] points in frustum,
            #                 batch_label [32, 2048] per-point instance label
            #   T-Net:        batch_center [32, 3]
            #   box est.:     batch_hclass [32], batch_hres [32],
            #                 batch_sclass [32], batch_sres [32, 3]
            #   others:       batch_rot_angle [32] (alpha, not rotation_y),
            #                 batch_one_hot_vec [32, 3]
            (batch_data, batch_label, batch_center,
             batch_hclass, batch_hres,
             batch_sclass, batch_sres,
             batch_rot_angle, batch_one_hot_vec) = data

            # batch_rot_angle is not used during training, so it is not
            # transferred to the GPU.
            batch_data = batch_data.transpose(2, 1).float().cuda()
            batch_label = batch_label.float().cuda()
            batch_center = batch_center.float().cuda()
            batch_hclass = batch_hclass.long().cuda()
            batch_hres = batch_hres.float().cuda()
            batch_sclass = batch_sclass.long().cuda()
            batch_sres = batch_sres.float().cuda()
            batch_one_hot_vec = batch_one_hot_vec.float().cuda()

            optimizer.zero_grad()
            # NOTE: no batch-norm momentum decay schedule (BN_INIT_DECAY,
            # BN_DECAY_DECAY_RATE, BN_DECAY_DECAY_STEP, BN_DECAY_CLIP) is
            # applied in this loop.
            (logits, mask, stage1_center, center_boxnet,
             heading_scores, heading_residuals_normalized, heading_residuals,
             size_scores, size_residuals_normalized, size_residuals,
             center) = FrustumPointNet(batch_data, batch_one_hot_vec)

            loss_out = Loss(logits, batch_label,
                            center, batch_center, stage1_center,
                            heading_scores, heading_residuals_normalized,
                            heading_residuals,
                            batch_hclass, batch_hres,
                            size_scores, size_residuals_normalized,
                            size_residuals,
                            batch_sclass, batch_sres)
            if FLAGS.return_all_loss:
                (total_loss, mask_loss, center_loss, heading_class_loss,
                 size_class_loss, heading_residuals_normalized_loss,
                 size_residuals_normalized_loss, stage1_center_loss,
                 corners_loss) = loss_out
            else:
                total_loss = loss_out

            total_loss.backward()
            optimizer.step()
            train_total_loss += total_loss.item()

            iou2ds, iou3ds = provider.compute_box3d_iou(
                center.cpu().detach().numpy(),
                heading_scores.cpu().detach().numpy(),
                heading_residuals.cpu().detach().numpy(),
                size_scores.cpu().detach().numpy(),
                size_residuals.cpu().detach().numpy(),
                batch_center.cpu().detach().numpy(),
                batch_hclass.cpu().detach().numpy(),
                batch_hres.cpu().detach().numpy(),
                batch_sclass.cpu().detach().numpy(),
                batch_sres.cpu().detach().numpy())
            train_iou2d += np.sum(iou2ds)
            train_iou3d += np.sum(iou3ds)
            train_iou3d_acc += np.sum(iou3ds >= 0.7)

            # Number of correctly labelled points in this batch; normalised
            # by n_samples * NUM_POINT after the loop.
            correct = torch.argmax(logits, 2).eq(
                batch_label.long()).detach().cpu().numpy()
            train_acc += np.sum(correct)

            if FLAGS.return_all_loss:
                train_mask_loss += mask_loss.item()
                train_center_loss += center_loss.item()
                train_heading_class_loss += heading_class_loss.item()
                train_size_class_loss += size_class_loss.item()
                train_heading_residuals_normalized_loss += \
                    heading_residuals_normalized_loss.item()
                train_size_residuals_normalized_loss += \
                    size_residuals_normalized_loss.item()
                train_stage1_center_loss += stage1_center_loss.item()
                train_corners_loss += corners_loss.item()

        train_total_loss /= n_samples
        train_acc /= n_samples * float(NUM_POINT)
        train_iou2d /= n_samples
        train_iou3d /= n_samples
        train_iou3d_acc /= n_samples

        if FLAGS.return_all_loss:
            train_mask_loss /= n_samples
            train_center_loss /= n_samples
            train_heading_class_loss /= n_samples
            train_size_class_loss /= n_samples
            train_heading_residuals_normalized_loss /= n_samples
            train_size_residuals_normalized_loss /= n_samples
            train_stage1_center_loss /= n_samples
            train_corners_loss /= n_samples

        print('[%d: %d/%d] train loss: %.6f' %
              (epoch + 1, i, len(train_dataloader), train_total_loss))
        print('segmentation accuracy: %.6f' % train_acc)
        print('box IoU(ground/3D): %.6f/%.6f' % (train_iou2d, train_iou3d))
        print('box estimation accuracy (IoU=0.7): %.6f' % (train_iou3d_acc))

        # test one epoch
        test_results = test_one_epoch(FrustumPointNet, test_dataloader)
        if FLAGS.return_all_loss:
            (test_total_loss, test_iou2d, test_iou3d, test_acc,
             test_iou3d_acc,
             test_mask_loss,
             test_center_loss,
             test_heading_class_loss,
             test_size_class_loss,
             test_heading_residuals_normalized_loss,
             test_size_residuals_normalized_loss,
             test_stage1_center_loss,
             test_corners_loss) = test_results
        else:
            (test_total_loss, test_iou2d, test_iou3d, test_acc,
             test_iou3d_acc) = test_results

        print('[%d] %s loss: %.6f' %
              (epoch + 1, blue('test'), test_total_loss))
        print('%s segmentation accuracy: %.6f' % (blue('test'), test_acc))
        print('%s box IoU(ground/3D): %.6f/%.6f' %
              (blue('test'), test_iou2d, test_iou3d))
        print('%s box estimation accuracy (IoU=0.7): %.6f' %
              (blue('test'), test_iou3d_acc))
        print("learning rate: {:.6f}".format(optimizer.param_groups[0]['lr']))
        scheduler.step()

        # Nothing is logged in debug mode. All scalars go through one loop so
        # that every one of them is written with the epoch as its step.
        if not FLAGS.debug:
            scalars = [
                ('train_total_loss', train_total_loss),
                ('train_iou2d', train_iou2d),
                ('train_iou3d', train_iou3d),
                ('train_acc', train_acc),
                ('train_iou3d_acc', train_iou3d_acc),
            ]
            if FLAGS.return_all_loss:
                scalars += [
                    ('train_mask_loss', train_mask_loss),
                    ('train_center_loss', train_center_loss),
                    ('train_heading_class_loss', train_heading_class_loss),
                    ('train_size_class_loss', train_size_class_loss),
                    ('train_heading_residuals_normalized_loss',
                     train_heading_residuals_normalized_loss),
                    ('train_size_residuals_normalized_loss',
                     train_size_residuals_normalized_loss),
                    ('train_stage1_center_loss', train_stage1_center_loss),
                    ('train_corners_loss', train_corners_loss),
                ]
            scalars += [
                ('test_total_loss', test_total_loss),
                ('test_iou2d_loss', test_iou2d),
                ('test_iou3d_loss', test_iou3d),
                ('test_acc', test_acc),
                ('test_iou3d_acc', test_iou3d_acc),
            ]
            if FLAGS.return_all_loss:
                scalars += [
                    ('test_mask_loss', test_mask_loss),
                    ('test_center_loss', test_center_loss),
                    ('test_heading_class_loss', test_heading_class_loss),
                    ('test_size_class_loss', test_size_class_loss),
                    ('test_heading_residuals_normalized_loss',
                     test_heading_residuals_normalized_loss),
                    ('test_size_residuals_normalized_loss',
                     test_size_residuals_normalized_loss),
                    ('test_stage1_center_loss', test_stage1_center_loss),
                    ('test_corners_loss', test_corners_loss),
                ]
            for tag, value in scalars:
                writer.add_scalar(tag, value, epoch)

        if test_iou3d_acc >= best_iou3d_acc:
            best_iou3d_acc = test_iou3d_acc
            best_epoch = epoch + 1
            # Checkpoints are only written once more than a fifth of the
            # epochs have elapsed.
            if epoch > MAX_EPOCH / 5:
                # NOTE(review): the file name carries the 0-based epoch while
                # the saved state and the console report use epoch + 1.
                savepath = LOG_DIR + '/' + NAME + '/%s-acc%04f-epoch%03d.pth' % \
                           (NAME, test_iou3d_acc, epoch)
                print('save to:', savepath)
                if os.path.exists(best_file):
                    os.remove(best_file)  # update to newest best epoch
                best_file = savepath
                state = {
                    'epoch': epoch + 1,
                    'train_iou3d_acc': train_iou3d_acc,
                    'test_iou3d_acc': test_iou3d_acc,
                    'model_state_dict': FrustumPointNet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
                print('Saving model to %s' % savepath)
        print('Best Test acc: %f(Epoch %d)' % (best_iou3d_acc, best_epoch))

    print("Time {} hours".format(float(time.perf_counter() - start) / 3600))
    writer.close()
def test_one_epoch(model, loader):
    '''Evaluate ``model`` on the batches yielded by ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    For every batch the loss from ``Loss``, the box IoUs from
    ``provider.compute_box3d_iou`` and the per-point segmentation accuracy
    are accumulated. When ``FLAGS.debug`` is set only the first batch is
    processed.

    :param model: network called as ``model(batch_data, batch_one_hot_vec)``
        and returning the 11 outputs unpacked below.
    :param loader: iterable of 9-element batches (data, label, center,
        heading class/residual, size class/residual, rotation angle,
        one-hot vector); it must also support ``len()``.
    :return: tuple of sums divided by the number of evaluated samples, in
        this order: total loss, 2D IoU, 3D IoU, segmentation accuracy,
        fraction of boxes with 3D IoU >= 0.7. When
        ``FLAGS.return_all_loss`` is set, eight more entries follow: mask,
        center, heading class, size class, normalized heading residual,
        normalized size residual, stage-1 center and corners loss.

    NOTE: an empty loader leaves the sample count at zero and the final
    divisions raise ZeroDivisionError.
    '''
    test_n_samples = 0
    test_total_loss = 0.0
    test_iou2d = 0.0
    test_iou3d = 0.0
    test_acc = 0.0
    test_iou3d_acc = 0.0

    if FLAGS.return_all_loss:
        test_mask_loss = 0.0
        test_center_loss = 0.0
        test_heading_class_loss = 0.0
        test_size_class_loss = 0.0
        test_heading_residuals_normalized_loss = 0.0
        test_size_residuals_normalized_loss = 0.0
        test_stage1_center_loss = 0.0
        test_corners_loss = 0.0

    # The mode does not change between batches, so set it once.
    model = model.eval()

    for i, data in tqdm(enumerate(loader),
                        total=len(loader), smoothing=0.9):
        # for debug: evaluate a single batch only
        if FLAGS.debug and i == 1:
            break
        test_n_samples += data[0].shape[0]

        # Batch layout (sizes taken from the original author's notes for a
        # batch of 32 -- TODO confirm):
        #   batch_data [32, 2048, 4] points in frustum
        #   batch_label [32, 2048] per-point instance label
        #   batch_center [32, 3], batch_hclass [32], batch_hres [32],
        #   batch_sclass [32], batch_sres [32, 3],
        #   batch_rot_angle [32], batch_one_hot_vec [32, 3]
        (batch_data, batch_label, batch_center,
         batch_hclass, batch_hres,
         batch_sclass, batch_sres,
         batch_rot_angle, batch_one_hot_vec) = data

        # batch_rot_angle is not used here, so it is not moved to the GPU.
        # NOTE(review): the class labels are cast to float here, whereas
        # train() casts them to long -- presumably Loss converts them
        # itself; verify.
        batch_data = batch_data.transpose(2, 1).float().cuda()
        batch_label = batch_label.float().cuda()
        batch_center = batch_center.float().cuda()
        batch_hclass = batch_hclass.float().cuda()
        batch_hres = batch_hres.float().cuda()
        batch_sclass = batch_sclass.float().cuda()
        batch_sres = batch_sres.float().cuda()
        batch_one_hot_vec = batch_one_hot_vec.float().cuda()

        # No gradients are needed for evaluation; skipping the autograd
        # graph saves memory and replaces the per-tensor detach() calls.
        with torch.no_grad():
            (logits, mask, stage1_center, center_boxnet,
             heading_scores, heading_residuals_normalized, heading_residuals,
             size_scores, size_residuals_normalized, size_residuals,
             center) = model(batch_data, batch_one_hot_vec)

            loss_out = Loss(logits, batch_label,
                            center, batch_center, stage1_center,
                            heading_scores, heading_residuals_normalized,
                            heading_residuals,
                            batch_hclass, batch_hres,
                            size_scores, size_residuals_normalized,
                            size_residuals,
                            batch_sclass, batch_sres)

        if FLAGS.return_all_loss:
            (total_loss, mask_loss, center_loss, heading_class_loss,
             size_class_loss, heading_residuals_normalized_loss,
             size_residuals_normalized_loss, stage1_center_loss,
             corners_loss) = loss_out
        else:
            total_loss = loss_out

        test_total_loss += total_loss.item()

        iou2ds, iou3ds = provider.compute_box3d_iou(
            center.cpu().detach().numpy(),
            heading_scores.cpu().detach().numpy(),
            heading_residuals.cpu().detach().numpy(),
            size_scores.cpu().detach().numpy(),
            size_residuals.cpu().detach().numpy(),
            batch_center.cpu().detach().numpy(),
            batch_hclass.cpu().detach().numpy(),
            batch_hres.cpu().detach().numpy(),
            batch_sclass.cpu().detach().numpy(),
            batch_sres.cpu().detach().numpy())
        test_iou2d += np.sum(iou2ds)
        test_iou3d += np.sum(iou3ds)

        # Correct points in the batch divided by NUM_POINT; the division by
        # the sample count happens in the return statement.
        correct = torch.argmax(logits, 2).eq(
            batch_label.detach().long()).cpu().numpy()
        test_acc += np.sum(correct) / float(NUM_POINT)

        test_iou3d_acc += np.sum(iou3ds >= 0.7)

        if FLAGS.return_all_loss:
            test_mask_loss += mask_loss.item()
            test_center_loss += center_loss.item()
            test_heading_class_loss += heading_class_loss.item()
            test_size_class_loss += size_class_loss.item()
            test_heading_residuals_normalized_loss += \
                heading_residuals_normalized_loss.item()
            test_size_residuals_normalized_loss += \
                size_residuals_normalized_loss.item()
            test_stage1_center_loss += stage1_center_loss.item()
            test_corners_loss += corners_loss.item()

    summary = (test_total_loss / test_n_samples,
               test_iou2d / test_n_samples,
               test_iou3d / test_n_samples,
               test_acc / test_n_samples,
               test_iou3d_acc / test_n_samples)
    if FLAGS.return_all_loss:
        return summary + (
            test_mask_loss / test_n_samples,
            test_center_loss / test_n_samples,
            test_heading_class_loss / test_n_samples,
            test_size_class_loss / test_n_samples,
            test_heading_residuals_normalized_loss / test_n_samples,
            test_size_residuals_normalized_loss / test_n_samples,
            test_stage1_center_loss / test_n_samples,
            test_corners_loss / test_n_samples)
    return summary
def test_one_epoch(model, loader):
    ''' Test frustum pointnets with GT 2D boxes.
    Write test results to KITTI format label files.
    todo (rqi): support variable number of points.
    '''
    ps_list = []
    seg_list = []
    segp_list = []
    center_list = []
    heading_cls_list = []
    heading_res_list = []
    size_cls_list = []
    size_res_list = []
    rot_angle_list = []
    score_list = []

    test_idxs = np.arange(0, len(TEST_DATASET))
    batch_size = BATCH_SIZE
    num_batches = len(TEST_DATASET) // batch_size

    test_n_samples = 0
    test_total_loss = 0.0
    test_iou2d = 0.0
    test_iou3d = 0.0
    test_acc = 0.0
    test_iou3d_acc = 0.0
    eval_time = 0.0
    if FLAGS.return_all_loss:
        test_mask_loss = 0.0
        test_center_loss = 0.0
        test_heading_class_loss = 0.0
        test_size_class_loss = 0.0
        test_heading_residuals_normalized_loss = 0.0
        test_size_residuals_normalized_loss = 0.0
        test_stage1_center_loss = 0.0
        test_corners_loss = 0.0

    pos_cnt = 0.0
    pos_pred_cnt = 0.0
    all_cnt = 0.0
    max_info = np.zeros(3)
    min_info = np.zeros(3)
    mean_info = np.zeros(3)
    for i, data in tqdm(enumerate(loader), \
                        total=len(loader), smoothing=0.9):
        # for debug
        if FLAGS.debug == True:
            if i == 1:
                break
        test_n_samples += data[0].shape[0]
        '''
        batch_data:[32, 2048, 4], pts in frustum
        batch_label:[32, 2048], pts ins seg label in frustum
        batch_center:[32, 3],
        batch_hclass:[32],
        batch_hres:[32],
        batch_sclass:[32],
        batch_sres:[32,3],
        batch_rot_angle:[32],
        batch_one_hot_vec:[32,3],
        '''
        # 1. Load data
        batch_data, batch_label, batch_center, \
        batch_hclass, batch_hres, \
        batch_sclass, batch_sres, \
        batch_rot_angle, batch_one_hot_vec = data

        batch_data = batch_data.transpose(2, 1).float().cuda()
        batch_label = batch_label.float().cuda()
        batch_center = batch_center.float().cuda()
        batch_hclass = batch_hclass.float().cuda()
        batch_hres = batch_hres.float().cuda()
        batch_sclass = batch_sclass.float().cuda()
        batch_sres = batch_sres.float().cuda()
        batch_rot_angle = batch_rot_angle.float().cuda()
        batch_one_hot_vec = batch_one_hot_vec.float().cuda()

        # 2. Eval one batch
        model = model.eval()
        eval_t1 = time.perf_counter()
        logits, mask, stage1_center, center_boxnet, \
        heading_scores, heading_residuals_normalized, heading_residuals, \
        size_scores, size_residuals_normalized, size_residuals, center = \
            model(batch_data, batch_one_hot_vec)
        #logits:[32, 1024, 2] , mask:[32, 1024]
        eval_t2 = time.perf_counter()
        eval_time += (eval_t2 - eval_t1)

        # 3. Compute Loss
        if FLAGS.return_all_loss:
            total_loss, mask_loss, center_loss, heading_class_loss, \
                size_class_loss, heading_residuals_normalized_loss, \
                size_residuals_normalized_loss, stage1_center_loss, \
                corners_loss = \
                Loss(logits, batch_label, \
                     center, batch_center, stage1_center, \
                     heading_scores, heading_residuals_normalized, \
                     heading_residuals, \
                     batch_hclass, batch_hres, \
                     size_scores, size_residuals_normalized, \
                     size_residuals, \
                     batch_sclass, batch_sres)
        else:
            total_loss = \
                Loss(logits, batch_label, \
                     center, batch_center, stage1_center, \
                     heading_scores, heading_residuals_normalized, \
                     heading_residuals, \
                     batch_hclass, batch_hres, \
                     size_scores, size_residuals_normalized, \
                     size_residuals, \
                     batch_sclass, batch_sres)

        test_total_loss += total_loss.item()
        if FLAGS.return_all_loss:
            test_mask_loss += mask_loss.item()
            test_center_loss += center_loss.item()
            test_heading_class_loss += heading_class_loss.item()
            test_size_class_loss += size_class_loss.item()
            test_heading_residuals_normalized_loss += heading_residuals_normalized_loss.item(
            )
            test_size_residuals_normalized_loss += size_residuals_normalized_loss.item(
            )
            test_stage1_center_loss += stage1_center_loss.item()
            test_corners_loss += corners_loss.item()

        # 4. compute seg acc, IoU and acc(IoU)
        correct = torch.argmax(logits, 2).eq(
            batch_label.detach().long()).cpu().numpy()
        accuracy = np.sum(correct) / float(NUM_POINT)
        test_acc += accuracy

        logits = logits.cpu().detach().numpy()
        mask = mask.cpu().detach().numpy()
        center_boxnet = center_boxnet.cpu().detach().numpy()
        #stage1_center = stage1_center.cpu().detach().numpy()#
        center = center.cpu().detach().numpy()
        heading_scores = heading_scores.cpu().detach().numpy()
        #heading_residuals_normalized = heading_residuals_normalized.cpu().detach().numpy()
        heading_residuals = heading_residuals.cpu().detach().numpy()
        size_scores = size_scores.cpu().detach().numpy()
        size_residuals = size_residuals.cpu().detach().numpy()
        #size_residuals_normalized = size_residuals_normalized.cpu().detach().numpy()#
        batch_rot_angle = batch_rot_angle.cpu().detach().numpy()
        batch_center = batch_center.cpu().detach().numpy()
        batch_hclass = batch_hclass.cpu().detach().numpy()
        batch_hres = batch_hres.cpu().detach().numpy()
        batch_sclass = batch_sclass.cpu().detach().numpy()
        batch_sres = batch_sres.cpu().detach().numpy()

        iou2ds, iou3ds = provider.compute_box3d_iou(
            center, heading_scores, heading_residuals, size_scores,
            size_residuals, batch_center, batch_hclass, batch_hres,
            batch_sclass, batch_sres)
        test_iou2d += np.sum(iou2ds)
        test_iou3d += np.sum(iou3ds)
        test_iou3d_acc += np.sum(iou3ds >= 0.7)

        # 5. Compute and write all Results
        batch_output = np.argmax(logits, 2)  #mask#torch.Size([32, 1024])
        batch_center_pred = center  #_boxnet#torch.Size([32, 3])
        batch_hclass_pred = np.argmax(heading_scores, 1)  # (32,)
        batch_hres_pred = np.array([heading_residuals[j, batch_hclass_pred[j]] \
                                    for j in range(batch_data.shape[0])]) # (32,)
        # batch_size_cls,batch_size_res
        batch_sclass_pred = np.argmax(size_scores, 1)  # (32,)
        batch_sres_pred = np.vstack([size_residuals[j, batch_sclass_pred[j], :] \
                                     for j in range(batch_data.shape[0])]) # (32,3)

        # batch_scores
        batch_seg_prob = softmax(logits)[:, :, 1]  # (32, 1024, 2) ->(32, 1024)
        batch_seg_mask = np.argmax(logits, 2)  # BxN
        mask_mean_prob = np.sum(batch_seg_prob * batch_seg_mask, 1)  # B,
        mask_mean_prob = mask_mean_prob / np.sum(batch_seg_mask, 1)  # B,
        heading_prob = np.max(softmax(heading_scores), 1)  # B
        size_prob = np.max(softmax(size_scores), 1)  # B,
        #batch_scores = np.log(mask_mean_prob) + np.log(heading_prob) + np.log(size_prob)
        # batch_scores = mask_mean_prob/3 + heading_prob/3 + size_prob/3
        batch_scores = np.ones_like(mask_mean_prob) + 0.1
        # batch_scores = heading_prob/2 + size_prob/2

        #ipdb.set_trace()
        # batch_scores = heading_prob
        for j in range(batch_output.shape[0]):
            ps_list.append(batch_data[j, ...])
            seg_list.append(batch_label[j, ...])
            segp_list.append(batch_output[j, ...])
            center_list.append(batch_center_pred[j, :])
            heading_cls_list.append(batch_hclass_pred[j])
            heading_res_list.append(batch_hres_pred[j])
            size_cls_list.append(batch_sclass_pred[j])
            size_res_list.append(batch_sres_pred[j, :])
            rot_angle_list.append(batch_rot_angle[j])
            score_list.append(batch_scores[j])
            pos_cnt += np.sum(batch_label[j, :].cpu().detach().numpy())
            pos_pred_cnt += np.sum(batch_output[j, :])
            pts_np = batch_data[j, :3, :].cpu().detach().numpy()  #(3,1024)
            max_xyz = np.max(pts_np, axis=1)
            max_info = np.maximum(max_info, max_xyz)
            min_xyz = np.min(pts_np, axis=1)
            min_info = np.minimum(min_info, min_xyz)
            mean_info += np.sum(pts_np, axis=1)
    '''
    return np.argmax(logits, 2), centers, heading_cls, heading_res, \
        size_cls, size_res, scores
        
	batch_output, batch_center_pred, \
        batch_hclass_pred, batch_hres_pred, \
        batch_sclass_pred, batch_sres_pred, batch_scores = \
            inference(sess, ops, batch_data,
                batch_one_hot_vec, batch_size=batch_size)
    '''
    if FLAGS.dump_result:
        print('dumping...')
        with open(output_filename, 'wp') as fp:
            pickle.dump(ps_list, fp)
            pickle.dump(seg_list, fp)
            pickle.dump(segp_list, fp)
            pickle.dump(center_list, fp)
            pickle.dump(heading_cls_list, fp)
            pickle.dump(heading_res_list, fp)
            pickle.dump(size_cls_list, fp)
            pickle.dump(size_res_list, fp)
            pickle.dump(rot_angle_list, fp)
            pickle.dump(score_list, fp)

    # Write detection results for KITTI evaluation
    print('Number of point clouds: %d' % (len(ps_list)))
    write_detection_results(result_dir, TEST_DATASET.id_list,
                            TEST_DATASET.type_list, TEST_DATASET.box2d_list,
                            center_list, heading_cls_list, heading_res_list,
                            size_cls_list, size_res_list, rot_angle_list,
                            score_list)
    # Make sure for each frame (no matter if we have measurment for that frame),
    # there is a TXT file
    output_dir = os.path.join(result_dir, 'data')
    if FLAGS.idx_path is not None:
        to_fill_filename_list = [line.rstrip()+'.txt' \
            for line in open(FLAGS.idx_path)]
        fill_files(output_dir, to_fill_filename_list)

    all_cnt = FLAGS.num_point * len(ps_list)
    print('Average pos ratio: %f' % (pos_cnt / float(all_cnt)))
    print('Average pos prediction ratio: %f' % (pos_pred_cnt / float(all_cnt)))
    print('Average npoints: %f' % (float(all_cnt) / len(ps_list)))
    mean_info = mean_info / len(ps_list) / FLAGS.num_point
    print('Mean points: x%f y%f z%f' %
          (mean_info[0], mean_info[1], mean_info[2]))
    print('Max points: x%f y%f z%f' % (max_info[0], max_info[1], max_info[2]))
    print('Min points: x%f y%f z%f' % (min_info[0], min_info[1], min_info[2]))
    '''
    2020.2.9
    
    nuscenes->nuscenes:
    Number of point clouds: 6408
    Average pos ratio: 151.052473
    Average pos prediction ratio: 0.982697
    Average npoints: 393.820069
    Mean points: x-0.064442 y0.845251 z35.271175
    Max points: x69.445435 y8.203144 z104.238876
    Min points: x-75.259071 y-14.208739 z0.000000
    test from 2d gt: Done
    test loss: 0.948371
    test segmentation accuracy: 0.842269
    test box IoU(ground/3D): 0.563429/0.483347
    test box estimation accuracy (IoU=0.7): 0.243914

    kitti->kitti:
    Number of point clouds: 12538
    Average pos ratio: 59.842735
    Average pos prediction ratio: 1.089077
    Average npoints: 447.111102
    Mean points: x0.026136 y0.987294 z24.958540
    Max points: x16.992741 y9.347202 z79.747406
    Min points: x-20.476559 y-3.882158 z0.000000
    test from 2d gt: Done
    test loss: 0.104301
    test segmentation accuracy: 0.901477
    test box IoU(ground/3D): 0.796052/0.743673
    test box estimation accuracy (IoU=0.7): 0.761605

    kitti->nuscenes:
    Number of point clouds: 6408
    Average pos ratio: 151.174013
    Average pos prediction ratio: 0.000000
    Average npoints: 393.582865
    Mean points: x-0.063360 y0.845064 z35.272771
    Max points: x69.348808 y9.001218 z104.238876
    Min points: x-67.831200 y-14.235371 z0.000000
    test from 2d gt: Done
    test loss: 25.182920
    test segmentation accuracy: 0.615642
    test box IoU(ground/3D): 0.023954/0.019012
    test box estimation accuracy (IoU=0.7): 0.000000

    


    '''
    if FLAGS.return_all_loss:
        return test_total_loss / test_n_samples, \
               test_iou2d / test_n_samples, \
               test_iou3d / test_n_samples, \
               test_acc / test_n_samples, \
               test_iou3d_acc / test_n_samples,\
               test_mask_loss / test_n_samples, \
               test_center_loss / test_n_samples, \
               test_heading_class_loss / test_n_samples, \
               test_size_class_loss / test_n_samples, \
               test_heading_residuals_normalized_loss / test_n_samples, \
               test_size_residuals_normalized_loss / test_n_samples, \
               test_stage1_center_loss / test_n_samples, \
               test_corners_loss / test_n_samples
    else:
        return test_total_loss/test_n_samples,  \
               test_iou2d/test_n_samples, \
               test_iou3d/test_n_samples, \
               test_acc/test_n_samples, \
               test_iou3d_acc/test_n_samples
    def forward(self, data_dicts):
        """Run segmentation, centre regression and box estimation, then score them.

        Args:
            data_dicts: mapping read with ``.get`` for the keys
                ``'point_cloud'``, ``'one_hot'``, ``'seg'``, ``'box3d_center'``,
                ``'size_class'``, ``'size_residual'``, ``'angle_class'`` and
                ``'angle_residual'``. The shape notes in the inline comments
                come from the original author (batch of 32, 1024 points) and
                are not validated here.

        Returns:
            tuple: ``(losses, metrics)``. ``losses`` is the mapping produced
            by ``self.Loss`` with every entry divided by the batch size;
            ``metrics`` is a dict with the keys ``'seg_acc'``, ``'iou2d'``,
            ``'iou3d'`` and ``'iou3d_0.7'``.
        """
        # dict_keys(['point_cloud', 'rot_angle', 'box3d_center', 'size_class',
        #            'size_residual', 'angle_class', 'angle_residual',
        #            'one_hot', 'seg'])
        point_cloud = data_dicts.get('point_cloud')  # torch.Size([32, 4, 1024])
        # Keep only the first n_channel input channels.
        point_cloud = point_cloud[:, :self.n_channel, :]
        one_hot = data_dicts.get('one_hot')  # torch.Size([32, 3])
        bs = point_cloud.shape[0]
        n_points = point_cloud.shape[-1]

        # Labels used for the loss and the metrics.
        # NOTE(review): every label is dereferenced unconditionally below, so
        # a missing key (``.get`` returning None) fails at the latest when the
        # metrics call ``.detach()`` on it.
        seg_label = data_dicts.get('seg')  # torch.Size([32, 1024])
        box3d_center_label = data_dicts.get('box3d_center')  # torch.Size([32, 3])
        size_class_label = data_dicts.get('size_class')  # torch.Size([32, 1])
        size_residual_label = data_dicts.get('size_residual')  # torch.Size([32, 3])
        heading_class_label = data_dicts.get('angle_class')  # torch.Size([32, 1])
        heading_residual_label = data_dicts.get('angle_residual')  # torch.Size([32, 1])

        # 3D Instance Segmentation PointNet
        logits = self.InsSeg(point_cloud, one_hot)  # bs,n,2

        # Mask Point Centroid
        object_pts_xyz, mask_xyz_mean, mask = point_cloud_masking(
            point_cloud, logits)

        # T-Net
        # Follow the input's device instead of hard-coding ``.cuda()`` so the
        # masked points always end up next to the tensors they are mixed with.
        object_pts_xyz = object_pts_xyz.to(point_cloud.device)
        center_delta = self.STN(object_pts_xyz, one_hot)  # (32,3)
        stage1_center = center_delta + mask_xyz_mean  # (32,3)

        # Debugging hook kept from the original: drop into ipdb as soon as the
        # regressed centre contains a NaN.
        if torch.isnan(stage1_center).any():
            ipdb.set_trace()

        # Re-centre the object points on the T-Net estimate; the trailing
        # singleton axis broadcasts the offset over the point axis.
        object_pts_xyz_new = object_pts_xyz - center_delta.view(
            center_delta.shape[0], -1, 1)

        # 3D Box Estimation
        box_pred = self.est(object_pts_xyz_new, one_hot)  # (32, 59)

        (center_boxnet,
         heading_scores, heading_residual_normalized, heading_residual,
         size_scores, size_residual_normalized, size_residual) = \
            parse_output_to_tensors(box_pred, logits, mask, stage1_center)

        box3d_center = center_boxnet + stage1_center  # bs,3

        losses = self.Loss(logits, seg_label,
                           box3d_center, box3d_center_label, stage1_center,
                           heading_scores, heading_residual_normalized,
                           heading_residual,
                           heading_class_label, heading_residual_label,
                           size_scores, size_residual_normalized,
                           size_residual,
                           size_class_label, size_residual_label)

        # Scale every loss term by the batch size.
        for key in losses:
            losses[key] = losses[key] / bs

        def _to_numpy(tensor):
            # Detached host-side copy for the NumPy IoU routine.
            return tensor.detach().cpu().numpy()

        with torch.no_grad():
            seg_pred = torch.argmax(logits.detach().cpu(), 2)
            seg_correct = seg_pred.eq(seg_label.detach().cpu()).numpy()
            # Divided by the number of points only, so this is the per-cloud
            # accuracy summed over the batch rather than a batch mean; kept
            # as-is because callers may rely on that scale.
            seg_accuracy = np.sum(seg_correct) / float(n_points)

            iou2ds, iou3ds = compute_box3d_iou(
                _to_numpy(box3d_center),
                _to_numpy(heading_scores),
                _to_numpy(heading_residual),
                _to_numpy(size_scores),
                _to_numpy(size_residual),
                _to_numpy(box3d_center_label),
                _to_numpy(heading_class_label),
                _to_numpy(heading_residual_label),
                _to_numpy(size_class_label),
                _to_numpy(size_residual_label))

        metrics = {
            'seg_acc': seg_accuracy,
            'iou2d': iou2ds.mean(),
            'iou3d': iou3ds.mean(),
            # Fraction of the batch whose 3D IoU reaches 0.7.
            'iou3d_0.7': np.sum(iou3ds >= 0.7) / bs
        }
        return losses, metrics
Esempio n. 7
0
    def eval_epoch(self):
        """Evaluate the model on the validation set, log the result and checkpoint on improvement.

        Iterates over ``len(self.valid_dataset) // self.val_batch_size`` full
        batches (a trailing partial batch is not evaluated), accumulates the
        loss, the point-wise segmentation accuracy and the box IoU statistics,
        reports them through ``self.log_values`` with the tag ``'Val'`` and,
        when the mean loss beats ``self.best_val_loss``, updates
        ``self.best_val_loss`` / ``self.best_model`` and calls
        ``save_checkpoint``.

        Raises:
            ZeroDivisionError: if the validation set holds fewer samples than
                one batch, because every average is then divided by zero.
        """
        self.model.eval()
        test_idxs = np.arange(0, len(self.valid_dataset))
        num_batches = len(self.valid_dataset) // self.val_batch_size

        # To collect statistics
        total_correct = 0
        total_seen = 0
        loss_sum = 0
        iou2ds_sum = 0
        iou3ds_sum = 0
        iou3d_correct_cnt = 0

        # Simple evaluation with batches
        for batch_idx in range(num_batches):
            start_idx = batch_idx * self.val_batch_size
            end_idx = start_idx + self.val_batch_size

            (batch_data, batch_label, batch_center,
             batch_hclass, batch_hres,
             batch_sclass, batch_sres,
             batch_rot_angle, batch_one_hot_vec) = tuple(
                get_batch(self.valid_dataset, test_idxs, start_idx, end_idx,
                          self.config.NUM_POINT, self.config.NUM_CHANNELS))

            with torch.no_grad():
                self.endpoints = self.model(batch_data, batch_one_hot_vec)
                val_loss = self.loss(batch_label, batch_center, batch_hclass,
                                     batch_hres, batch_sclass, batch_sres)

            # Point-wise segmentation accuracy; the label tensor is copied to
            # the host once per batch.
            preds_val = np.argmax(
                self.endpoints['mask_logits'].detach().cpu().numpy(), 2)
            labels_np = batch_label.detach().cpu().numpy()
            total_correct += np.sum(preds_val == labels_np)
            total_seen += (self.val_batch_size * self.config.NUM_POINT)
            loss_sum += val_loss

            iou2ds, iou3ds = compute_box3d_iou(
                self.endpoints['center'].detach().cpu().numpy(),
                self.endpoints['heading_scores'].detach().cpu().numpy(),
                self.endpoints['heading_residuals'].detach().cpu().numpy(),
                self.endpoints['size_scores'].detach().cpu().numpy(),
                self.endpoints['size_residuals'].detach().cpu().numpy(),
                batch_center.detach().cpu().numpy(),
                batch_hclass.detach().cpu().numpy(),
                batch_hres.detach().cpu().numpy(),
                batch_sclass.detach().cpu().numpy(),
                batch_sres.detach().cpu().numpy())
            # Expose the per-sample IoUs next to the network outputs.
            self.endpoints['iou2ds'] = iou2ds
            self.endpoints['iou3ds'] = iou3ds

            iou2ds_sum += np.sum(iou2ds)
            iou3ds_sum += np.sum(iou3ds)
            iou3d_correct_cnt += np.sum(iou3ds >= 0.7)

        seg_acc = (total_correct / float(total_seen))
        n_samples = float(self.val_batch_size * num_batches)
        iou_ground = iou2ds_sum / n_samples
        iou_3d = iou3ds_sum / n_samples
        box_acc = float(iou3d_correct_cnt) / n_samples
        mean_loss = loss_sum / float(num_batches)

        self.log_values(batch_idx, mean_loss, seg_acc,
                        iou_ground, iou_3d, box_acc, 'Val')

        if self.best_val_loss > mean_loss:
            self.best_val_loss = mean_loss
            # NOTE(review): this binds the live module object, not a copy, so
            # later training steps remain visible through self.best_model;
            # presumably the checkpoint written below is the real snapshot.
            self.best_model = self.model
            save_checkpoint('./models/best_model.pth', self.model, self.epoch,
                            self.optimizer, self.best_val_loss)
Esempio n. 8
0
    def train_epoch(self):
        """Train for one epoch over a shuffled index order, logging every ``self.log_interval`` batches.

        For each of ``self.num_batches`` batches the method increments
        ``self.global_step``, runs the model and ``self.loss``,
        back-propagates, steps the optimizer and accumulates the loss,
        segmentation accuracy and box IoU statistics. Every
        ``self.log_interval`` batches the window averages are sent to
        ``self.log_values`` with the tag ``'Train'`` and the accumulators are
        reset; statistics of a trailing incomplete window are not logged.
        """
        # this is the current iteration inside the epoch

        train_idxs = np.arange(0, self.train_dataset_length)
        np.random.shuffle(train_idxs)

        # To collect statistics
        total_correct = 0
        total_seen = 0
        loss_sum = 0
        iou2ds_sum = 0
        iou3ds_sum = 0
        iou3d_correct_cnt = 0

        # Number of samples covered by one logging window.
        window_samples = float(self.train_batch_size * self.log_interval)

        for batch_idx in range(self.num_batches):
            self.global_step += 1
            start_idx = batch_idx * self.train_batch_size
            end_idx = start_idx + self.train_batch_size

            (batch_data, batch_label, batch_center,
             batch_hclass, batch_hres,
             batch_sclass, batch_sres,
             batch_rot_angle, batch_one_hot_vec) = tuple(
                get_batch(self.train_dataset, train_idxs, start_idx, end_idx,
                          self.config.NUM_POINT, self.config.NUM_CHANNELS))

            self.model.zero_grad()
            self.endpoints = self.model(batch_data, batch_one_hot_vec)
            total_loss = self.loss(batch_label, batch_center, batch_hclass,
                                   batch_hres, batch_sclass, batch_sres)

            total_loss.backward()
            self.optimizer.step()

            # Point-wise segmentation accuracy for this batch.
            preds_val = np.argmax(
                self.endpoints['mask_logits'].detach().cpu().numpy(), 2)
            labels_np = batch_label.detach().cpu().numpy()
            total_correct += np.sum(preds_val == labels_np)
            total_seen += (self.train_batch_size * self.config.NUM_POINT)
            # Accumulate a detached value: adding the live loss tensor would
            # chain the autograd history of every batch in the window.
            loss_sum += total_loss.detach()

            iou2ds, iou3ds = compute_box3d_iou(
                self.endpoints['center'].detach().cpu().numpy(),
                self.endpoints['heading_scores'].detach().cpu().numpy(),
                self.endpoints['heading_residuals'].detach().cpu().numpy(),
                self.endpoints['size_scores'].detach().cpu().numpy(),
                self.endpoints['size_residuals'].detach().cpu().numpy(),
                batch_center.detach().cpu().numpy(),
                batch_hclass.detach().cpu().numpy(),
                batch_hres.detach().cpu().numpy(),
                batch_sclass.detach().cpu().numpy(),
                batch_sres.detach().cpu().numpy())
            # Expose the per-sample IoUs next to the network outputs.
            self.endpoints['iou2ds'] = iou2ds
            self.endpoints['iou3ds'] = iou3ds

            iou2ds_sum += np.sum(iou2ds)
            iou3ds_sum += np.sum(iou3ds)
            iou3d_correct_cnt += np.sum(iou3ds >= 0.7)

            if (batch_idx + 1) % self.log_interval == 0:
                seg_acc = (total_correct / float(total_seen))
                iou_ground = iou2ds_sum / window_samples
                iou_3d = iou3ds_sum / window_samples
                box_acc = float(iou3d_correct_cnt) / window_samples

                self.log_values(batch_idx, loss_sum / self.log_interval,
                                seg_acc, iou_ground, iou_3d, box_acc, 'Train')

                # Start a fresh logging window.
                total_correct = 0
                total_seen = 0
                loss_sum = 0
                iou2ds_sum = 0
                iou3ds_sum = 0
                iou3d_correct_cnt = 0
Esempio n. 9
0
    def train_epoch(self):
        """Train the box-estimation model for one epoch, logging every ``self.log_interval`` batches.

        For each of ``self.num_batches`` batches the method increments
        ``self.global_step``, runs the model (which here also receives the
        segmentation labels), evaluates ``self.loss``, back-propagates and
        steps the optimizer. Every ``self.log_interval`` batches the mean loss
        and the share of samples with 3D IoU >= 0.7 are sent to
        ``self.log_box_values`` with the tag ``'Train'`` and the accumulators
        are reset; a trailing incomplete window is not logged.
        """
        train_idxs = np.arange(0, self.train_dataset_length)
        np.random.shuffle(train_idxs)

        loss_sum = 0
        # The two IoU sums are accumulated but not read by the logging call
        # below; they are kept to mirror the original bookkeeping.
        iou2ds_sum = 0
        iou3ds_sum = 0
        iou3d_correct_cnt = 0

        for batch_idx in range(self.num_batches):
            self.global_step += 1
            start_idx = batch_idx * self.train_batch_size
            end_idx = start_idx + self.train_batch_size

            (batch_data, batch_label, batch_center,
             batch_hclass, batch_hres,
             batch_sclass, batch_sres,
             batch_rot_angle, batch_one_hot_vec) = tuple(
                get_batch(self.train_dataset, train_idxs, start_idx, end_idx,
                          self.config.NUM_POINT, self.config.NUM_CHANNELS))

            self.model.zero_grad()
            self.endpoints = self.model(batch_data, batch_one_hot_vec,
                                        batch_label)
            total_loss = self.loss(batch_center, batch_hclass, batch_hres,
                                   batch_sclass, batch_sres)

            total_loss.backward()
            self.optimizer.step()

            # Accumulate a detached value: adding the live loss tensor would
            # chain the autograd history of every batch in the window.
            loss_sum += total_loss.detach()

            iou2ds, iou3ds = compute_box3d_iou(
                self.endpoints['center'].detach().cpu().numpy(),
                self.endpoints['heading_scores'].detach().cpu().numpy(),
                self.endpoints['heading_residuals'].detach().cpu().numpy(),
                self.endpoints['size_scores'].detach().cpu().numpy(),
                self.endpoints['size_residuals'].detach().cpu().numpy(),
                batch_center.detach().cpu().numpy(),
                batch_hclass.detach().cpu().numpy(),
                batch_hres.detach().cpu().numpy(),
                batch_sclass.detach().cpu().numpy(),
                batch_sres.detach().cpu().numpy())
            # Expose the per-sample IoUs next to the network outputs.
            self.endpoints['iou2ds'] = iou2ds
            self.endpoints['iou3ds'] = iou3ds

            iou2ds_sum += np.sum(iou2ds)
            iou3ds_sum += np.sum(iou3ds)
            iou3d_correct_cnt += np.sum(iou3ds >= 0.7)

            if (batch_idx + 1) % self.log_interval == 0:
                box_acc = float(iou3d_correct_cnt) / float(
                    self.train_batch_size * self.log_interval)
                self.log_box_values(batch_idx, loss_sum / self.log_interval,
                                    box_acc, 'Train')

                # Start a fresh logging window.
                loss_sum = 0
                iou2ds_sum = 0
                iou3ds_sum = 0
                iou3d_correct_cnt = 0