def test(model, loader, out_channel, criterion_rmse, criterion_cos):
    rot_error = []
    xyz_error = []
    network = model.eval()

    for j, data in tqdm(enumerate(loader), total=len(loader)):
        points = data[parameter.pcd_key].numpy()
        points = provider.normalize_data(points)
        points = torch.Tensor(points)
        delta_rot = data[parameter.delta_rot_key]
        delta_xyz = data[parameter.delta_xyz_key]

        if not args.use_cpu:
            points = points.cuda()
            delta_rot = delta_rot.cuda()
            delta_xyz = delta_xyz.cuda()

        points = points.transpose(2, 1)
        pred, _ = network(points)
        delta_rot_pred_6d = pred[:, 0:6]
        delta_rot_pred = compute_rotation_matrix_from_ortho6d(
            delta_rot_pred_6d, args.use_cpu)  # batch*3*3
        delta_xyz_pred = pred[:, 6:9].view(-1, 3)  # batch*3

        loss_r = criterion_rmse(delta_rot_pred, delta_rot)
        loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)
                  ).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)

        rot_error.append(loss_r.item())
        xyz_error.append(loss_t.item())

    rot_error = sum(rot_error) / len(rot_error)
    xyz_error = sum(xyz_error) / len(xyz_error)

    return rot_error, xyz_error
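
Note: compute_rotation_matrix_from_ortho6d is not defined in this snippet. The sketch below is a plausible stand-in based on the standard 6D rotation representation of Zhou et al. (Gram-Schmidt orthonormalization of two 3-vectors); the use_cpu argument is assumed to matter only for device placement, and the repository's actual helper may differ.

import torch

def compute_rotation_matrix_from_ortho6d(ortho6d, use_cpu=True):
    # Hypothetical version: Gram-Schmidt orthonormalization of the 6D representation.
    # ortho6d is batch x 6; the result is batch x 3 x 3 with x, y, z as columns.
    # use_cpu is assumed to affect only device placement in the original helper.
    x_raw = ortho6d[:, 0:3]
    y_raw = ortho6d[:, 3:6]
    x = torch.nn.functional.normalize(x_raw, dim=1)
    y = torch.nn.functional.normalize(y_raw - (x * y_raw).sum(dim=1, keepdim=True) * x, dim=1)
    z = torch.cross(x, y, dim=1)
    return torch.stack((x, y, z), dim=2)
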
Example 2
def test(model, loader, out_channel, criterion_rmse, criterion_cos,
         criterion_bce):
    rot_error = []
    xyz_error = []
    heatmap_error = []
    step_size_error = []
    network = model.eval()

    for j, data in tqdm(enumerate(loader), total=len(loader)):
        points = data[parameter.pcd_key].numpy()
        points = provider.normalize_data(points)
        points = torch.Tensor(points)
        delta_rot = data[parameter.delta_rot_key]
        delta_xyz = data[parameter.delta_xyz_key]
        heatmap_target = data[parameter.heatmap_key]
        segmentation_target = data[parameter.segmentation_key]
        unit_delta_xyz = data[parameter.unit_delta_xyz_key]
        step_size = data[parameter.step_size_key]

        if not args.use_cpu:
            points = points.cuda()
            delta_rot = delta_rot.cuda()
            delta_xyz = delta_xyz.cuda()
            unit_delta_xyz = unit_delta_xyz.cuda()
            step_size = step_size.cuda()
            heatmap_target = heatmap_target.cuda()

        points = points.transpose(2, 1)
        heatmap_pred, action_pred, step_size_pred = network(points)

        # action control
        delta_rot_pred_6d = action_pred[:, 0:6]
        delta_rot_pred = compute_rotation_matrix_from_ortho6d(
            delta_rot_pred_6d, args.use_cpu)  # batch*3*3
        delta_xyz_pred = action_pred[:, 6:9].view(-1, 3)  # batch*3

        # loss computation
        loss_heatmap = criterion_rmse(heatmap_pred, heatmap_target)
        loss_r = criterion_rmse(delta_rot_pred, delta_rot)
        #loss_t = (1-criterion_cos(delta_xyz_pred, delta_xyz)).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)
        loss_t = (1 - criterion_cos(delta_xyz_pred, unit_delta_xyz)).mean()
        loss_step_size = criterion_bce(step_size_pred, step_size)

        rot_error.append(loss_r.item())
        xyz_error.append(loss_t.item())
        heatmap_error.append(loss_heatmap.item())
        step_size_error.append(loss_step_size.item())

    rot_error = sum(rot_error) / len(rot_error)
    xyz_error = sum(xyz_error) / len(xyz_error)
    heatmap_error = sum(heatmap_error) / len(heatmap_error)
    step_size_error = sum(step_size_error) / len(step_size_error)

    return rot_error, xyz_error, heatmap_error, step_size_error
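
Every example calls provider.normalize_data before feeding points to the network, but provider.py is not reproduced here. A minimal sketch of the common implementation, assuming it centers each cloud and rescales it into the unit sphere:

import numpy as np

def normalize_data(batch_data):
    # Assumed behavior: center each cloud of a B x N x C batch at the origin and
    # scale it so the farthest point lies on the unit sphere.
    normalized = np.zeros(batch_data.shape, dtype=np.float32)
    for b in range(batch_data.shape[0]):
        pc = batch_data[b] - np.mean(batch_data[b], axis=0)
        scale = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        normalized[b] = pc / scale
    return normalized
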
Example 3
def eval_one_epoch(sess, ops, test_writer):
    """ ops: dict mapping from string to tf ops """
    global EPOCH_CNT
    is_training = False
    test_idxs = np.arange(0, len(TEST_DATASET))
    num_batches = int(len(TEST_DATASET) / BATCH_SIZE)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    total_iou_deno_class = [0 for _ in range(NUM_CLASSES)]

    log_string(str(datetime.now()))
    log_string('---- EPOCH %03d EVALUATION ----' % (EPOCH_CNT))

    labelweights = np.zeros(21)
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label, batch_smpw = get_batch(TEST_DATASET, test_idxs, start_idx, end_idx)
        batch_data[:, :, :3] = provider.normalize_data(batch_data[:, :, :3])
        batch_data[:, :, :3] = provider.rotate_point_cloud_z(batch_data[:, :, :3])

        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['smpws_pl']: batch_smpw,
                     ops['is_training_pl']: is_training}
        summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                                                      ops['loss'], ops['pred']], feed_dict=feed_dict)
        test_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 2)  # BxN
        correct = np.sum((pred_val == batch_label) & (batch_label > 0) & (batch_smpw > 0))  # evaluate only on 20 categories but not unknown
        total_correct += correct
        total_seen += np.sum((batch_label > 0) & (batch_smpw > 0))
        loss_sum += loss_val
        tmp, _ = np.histogram(batch_label, range(22))
        labelweights += tmp
        for l in range(NUM_CLASSES):
            total_seen_class[l] += np.sum((batch_label == l) & (batch_smpw > 0))
            total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l) & (batch_smpw > 0))
            total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)) & (batch_smpw > 0))
    mIoU = np.mean(np.array(total_correct_class[1:]) / (np.array(total_iou_deno_class[1:], dtype=np.float64) + 1e-6))
    log_string('Eval mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('Eval point avg class IoU: %f' % (mIoU))
    log_string('Eval point accuracy: %f' % (total_correct / float(total_seen)))
    log_string('Eval point avg class acc: %f' % (np.mean(np.array(total_correct_class[1:]) / (np.array(total_seen_class[1:], dtype=np.float64) + 1e-6))))

    EPOCH_CNT += 1
    return mIoU
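
The get_batch helper used above is not shown. A minimal sketch, assuming each dataset item yields (points, per-point labels, per-point sample weights) of fixed size:

import numpy as np

def get_batch(dataset, idxs, start_idx, end_idx):
    # Hypothetical helper: stack dataset items idxs[start_idx:end_idx] into dense arrays.
    bsize = end_idx - start_idx
    num_point, num_channel = dataset[idxs[start_idx]][0].shape
    batch_data = np.zeros((bsize, num_point, num_channel), dtype=np.float32)
    batch_label = np.zeros((bsize, num_point), dtype=np.int32)
    batch_smpw = np.zeros((bsize, num_point), dtype=np.float32)
    for i in range(bsize):
        points, labels, smpw = dataset[idxs[start_idx + i]]
        batch_data[i] = points
        batch_label[i] = labels
        batch_smpw[i] = smpw
    return batch_data, batch_label, batch_smpw
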
Example 4
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train samples
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    num_batches = int(len(TRAIN_DATASET) / BATCH_SIZE)

    log_string(str(datetime.now()))

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_iou_deno = 0
    for batch_idx in tqdm(range(num_batches), total=num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label, batch_smpw = get_batch(TRAIN_DATASET, train_idxs, start_idx, end_idx)
        # Augment batched point clouds by rotation
        batch_data[:, :, :3] = provider.rotate_point_cloud_z(batch_data[:, :, :3])
        batch_data[:, :, :3] = provider.normalize_data(batch_data[:, :, :3])

        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['smpws_pl']: batch_smpw,
                     ops['is_training_pl']: is_training, }
        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                                                         ops['train_op'], ops['loss'], ops['pred']],
                                                        feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 2)
        correct = np.sum(pred_val == batch_label)
        total_correct += correct
        total_seen += (BATCH_SIZE * NUM_POINT)
        iou_deno = 0
        for l in range(NUM_CLASSES):
            iou_deno += np.sum((pred_val == l) | (batch_label == l))
        total_iou_deno += iou_deno
        loss_sum += loss_val

    log_string('Training loss: %f' % (loss_sum / num_batches))
    log_string('Training accuracy: %f' % (total_correct / float(total_seen)))
    log_string('Training IoU: %f' % (total_correct / float(total_iou_deno)))
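
provider.rotate_point_cloud_z is another helper that is not reproduced here. A sketch of the usual implementation, assuming it applies an independent random rotation about the z axis to each cloud in the batch:

import numpy as np

def rotate_point_cloud_z(batch_data):
    # Assumed implementation: rotate every cloud of a B x N x 3 batch by an
    # independent random angle around the z (up) axis.
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        rotation = np.array([[c, s, 0.0],
                             [-s, c, 0.0],
                             [0.0, 0.0, 1.0]])
        rotated[k] = np.dot(batch_data[k].reshape((-1, 3)), rotation)
    return rotated
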
Example 5
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('reg_seg_heatmap_v3')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # Construct the dataset
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random split
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    # And the dataloader
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)
    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_pointnet2_reg_seg_heatmap_stepsize.py', str(exp_dir))

    #network = model.get_model(out_channel, normal_channel=args.use_normals)
    network = model.get_model(out_channel)
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()

    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_heatmap_error = 99.9
    best_step_size_error = 99.9
    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_rot_error = []
        train_xyz_error = []
        train_heatmap_error = []
        train_step_size_error = []
        network = network.train()

        scheduler.step()
        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            optimizer.zero_grad()

            points = data[parameter.pcd_key].numpy()
            points = provider.normalize_data(points)
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            heatmap_target = data[parameter.heatmap_key]
            segmentation_target = data[parameter.segmentation_key]
            #print('heatmap size', heatmap_target.size())
            #print('segmentation', segmentation_target.size())
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            unit_delta_xyz = data[parameter.unit_delta_xyz_key]
            step_size = data[parameter.step_size_key]

            if not args.use_cpu:
                points = points.cuda()
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                heatmap_target = heatmap_target.cuda()
                unit_delta_xyz = unit_delta_xyz.cuda()
                step_size = step_size.cuda()

            heatmap_pred, action_pred, step_size_pred = network(points)
            # action control
            delta_rot_pred_6d = action_pred[:, 0:6]
            delta_rot_pred = compute_rotation_matrix_from_ortho6d(
                delta_rot_pred_6d, args.use_cpu)  # batch*3*3
            delta_xyz_pred = action_pred[:, 6:9].view(-1, 3)  # batch*3

            # loss computation
            loss_heatmap = criterion_rmse(heatmap_pred, heatmap_target)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            #loss_t = (1-criterion_cos(delta_xyz_pred, delta_xyz)).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_t = (1 - criterion_cos(delta_xyz_pred, unit_delta_xyz)).mean()
            loss_step_size = criterion_bce(step_size_pred, step_size)
            loss = loss_r + loss_t + loss_heatmap + loss_step_size
            loss.backward()
            optimizer.step()
            global_step += 1

            train_rot_error.append(loss_r.item())
            train_xyz_error.append(loss_t.item())
            train_heatmap_error.append(loss_heatmap.item())
            train_step_size_error.append(loss_step_size.item())

        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_heatmap_error = sum(train_heatmap_error) / len(
            train_heatmap_error)
        train_step_size_error = sum(train_step_size_error) / len(
            train_step_size_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Heatmap Error: %f' % train_heatmap_error)
        log_string('Train Step size Error: %f' % train_step_size_error)

        with torch.no_grad():
            rot_error, xyz_error, heatmap_error, step_size_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce)

            log_string(
                'Test Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (rot_error, xyz_error, heatmap_error, step_size_error))
            log_string(
                'Best Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (best_rot_error, best_xyz_error, best_heatmap_error,
                   best_step_size_error))

            if (rot_error + xyz_error + heatmap_error + step_size_error) < (
                    best_rot_error + best_xyz_error + best_heatmap_error +
                    best_step_size_error):
                best_rot_error = rot_error
                best_xyz_error = xyz_error
                best_heatmap_error = heatmap_error
                best_step_size_error = step_size_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'rot_error': rot_error,
                    'xyz_error': xyz_error,
                    'heatmap_error': heatmap_error,
                    'step_size_error': step_size_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
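
Example 5 relies on an RMSELoss module and an inplace_relu hook that are defined elsewhere in the repository. Plausible minimal versions, offered as assumptions rather than the actual code:

import torch

class RMSELoss(torch.nn.Module):
    # Assumed definition: root-mean-square error with a small epsilon for numerical stability.
    def __init__(self, eps=1e-8):
        super().__init__()
        self.mse = torch.nn.MSELoss()
        self.eps = eps

    def forward(self, pred, target):
        return torch.sqrt(self.mse(pred, target) + self.eps)

def inplace_relu(m):
    # Assumed helper: flip every ReLU module to in-place mode to reduce activation memory.
    if m.__class__.__name__.find('ReLU') != -1:
        m.inplace = True
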
Example 6
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('sem_seg')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = 'data/stanford_indoor3d/'

    NUM_CLASSES = 13
    NUM_POINT = args.npoint
    BATCH_SIZE = args.batch_size
    FEATURE_CHANNEL = 3 if args.with_rgb else 0

    print("start loading training data ...")
    TRAIN_DATASET = S3DISDataset(root, split='train', with_rgb=args.with_rgb, test_area=args.test_area, block_points=NUM_POINT)
    print("start loading test data ...")
    TEST_DATASET = S3DISDataset(root, split='test', with_rgb=args.with_rgb, test_area=args.test_area, block_points=NUM_POINT)
    print("start loading whole scene validation data ...")
    TEST_DATASET_WHOLE_SCENE = S3DISDatasetWholeScene(root, split='test', with_rgb=args.with_rgb, test_area=args.test_area, block_points=NUM_POINT)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    weights = TRAIN_DATASET.labelweights
    weights = torch.Tensor(weights).cuda()

    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" %  len(TEST_DATASET_WHOLE_SCENE))

    '''MODEL LOADING'''
    MODEL = importlib.import_module(args.model)
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/pointnet_util.py', str(experiment_dir))

    classifier = MODEL.get_model(NUM_CLASSES, with_rgb=args.with_rgb).cuda()
    criterion = MODEL.get_loss().cuda()

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)

    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        classifier = classifier.apply(weights_init)

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECAY = 0.5
    MOMENTUM_DECAY_STEP = args.step_size

    global_epoch = 0
    best_iou = 0

    for epoch in range(start_epoch, args.epoch):
        '''Train on chopped scenes'''
        log_string('**** Epoch %d (%d/%s) ****' % (global_epoch + 1, epoch + 1, args.epoch))
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        log_string('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
        if momentum < 0.01:
            momentum = 0.01
        print('BN momentum updated to: %f' % momentum)
        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
        num_batches = len(trainDataLoader)
        total_correct = 0
        total_seen = 0
        loss_sum = 0
        for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            points, target, _ = data
            points = points.data.numpy()
            points[:, :, :3] = provider.normalize_data(points[:, :, :3])
            points[:, :, :3] = provider.random_scale_point_cloud(points[:, :, :3])
            points[:, :, :3] = provider.rotate_point_cloud_z(points[:, :, :3])
            points = torch.Tensor(points)
            points, target = points.float().cuda(), target.long().cuda()
            points = points.transpose(2, 1)
            optimizer.zero_grad()
            classifier = classifier.train()
            seg_pred, trans_feat = classifier(points)
            seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)
            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
            target = target.view(-1, 1)[:, 0]
            loss = criterion(seg_pred, target, trans_feat, weights)
            loss.backward()
            optimizer.step()
            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
            correct = np.sum(pred_choice == batch_label)
            total_correct += correct
            total_seen += (BATCH_SIZE * NUM_POINT)
            loss_sum += loss.item()
        log_string('Training mean loss: %f' % (loss_sum / num_batches))
        log_string('Training accuracy: %f' % (total_correct / float(total_seen)))

        if epoch % 10 == 0 and epoch < 800:
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            log_string('Saving model....')

        '''Evaluate on chopped scenes'''
        with torch.no_grad():
            num_batches = len(testDataLoader)
            total_correct = 0
            total_seen = 0
            loss_sum = 0
            log_string('---- EPOCH %03d EVALUATION ----' % (global_epoch + 1))
            for i, data in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                points, target, _ = data
                points = points.data.numpy()
                points[:, :, :3] = provider.normalize_data(points[:, :, :3])
                points = torch.Tensor(points)
                points, target = points.float().cuda(), target.long().cuda()
                points = points.transpose(2, 1)
                classifier = classifier.eval()
                seg_pred, trans_feat = classifier(points)
                seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)
                target = target.view(-1, 1)[:, 0]
                loss = criterion(seg_pred, target, trans_feat, weights)
                loss_sum += loss.item()
                batch_label = target.cpu().data.numpy()
                pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
                correct = np.sum(pred_choice == batch_label)
                total_correct += correct
                total_seen += (BATCH_SIZE * NUM_POINT)
            log_string('Eval mean loss: %f' % (loss_sum / num_batches))
            log_string('Eval accuracy: %f' % (total_correct / float(total_seen)))

        '''Evaluate on whole scenes'''
        if epoch % 5 == 0 and epoch > 800:
            with torch.no_grad():
                num_batches = len(TEST_DATASET_WHOLE_SCENE)
                log_string('---- EPOCH %03d EVALUATION WHOLE SCENE----' % (global_epoch + 1))
                total_correct = 0
                total_seen = 0
                loss_sum = 0
                total_seen_class = [0 for _ in range(NUM_CLASSES)]
                total_correct_class = [0 for _ in range(NUM_CLASSES)]
                total_iou_deno_class = [0 for _ in range(NUM_CLASSES)]

                labelweights = np.zeros(NUM_CLASSES)
                is_continue_batch = False

                extra_batch_data = np.zeros((0, NUM_POINT, 3 + FEATURE_CHANNEL))
                extra_batch_label = np.zeros((0, NUM_POINT))
                extra_batch_smpw = np.zeros((0, NUM_POINT))
                for batch_idx in tqdm(range(num_batches),total=num_batches):
                    if not is_continue_batch:
                        batch_data, batch_label, batch_smpw = TEST_DATASET_WHOLE_SCENE[batch_idx]
                        batch_data = np.concatenate((batch_data, extra_batch_data), axis=0)
                        batch_label = np.concatenate((batch_label, extra_batch_label), axis=0)
                        batch_smpw = np.concatenate((batch_smpw, extra_batch_smpw), axis=0)
                    else:
                        batch_data_tmp, batch_label_tmp, batch_smpw_tmp = TEST_DATASET_WHOLE_SCENE[batch_idx]
                        batch_data = np.concatenate((batch_data, batch_data_tmp), axis=0)
                        batch_label = np.concatenate((batch_label, batch_label_tmp), axis=0)
                        batch_smpw = np.concatenate((batch_smpw, batch_smpw_tmp), axis=0)
                    if batch_data.shape[0] < BATCH_SIZE:
                        is_continue_batch = True
                        continue
                    elif batch_data.shape[0] == BATCH_SIZE:
                        is_continue_batch = False
                        extra_batch_data = np.zeros((0, NUM_POINT, 3 + FEATURE_CHANNEL))
                        extra_batch_label = np.zeros((0, NUM_POINT))
                        extra_batch_smpw = np.zeros((0, NUM_POINT))
                    else:
                        is_continue_batch = False
                        extra_batch_data = batch_data[BATCH_SIZE:, :, :]
                        extra_batch_label = batch_label[BATCH_SIZE:, :]
                        extra_batch_smpw = batch_smpw[BATCH_SIZE:, :]
                        batch_data = batch_data[:BATCH_SIZE, :, :]
                        batch_label = batch_label[:BATCH_SIZE, :]
                        batch_smpw = batch_smpw[:BATCH_SIZE, :]

                    batch_data[:, :, :3] = provider.normalize_data(batch_data[:, :, :3])
                    batch_label = torch.Tensor(batch_label)
                    batch_data = torch.Tensor(batch_data)
                    batch_data, batch_label = batch_data.float().cuda(), batch_label.long().cuda()
                    batch_data = batch_data.transpose(2, 1)
                    classifier = classifier.eval()
                    seg_pred, _ = classifier(batch_data)
                    seg_pred = seg_pred.contiguous()
                    batch_label = batch_label.cpu().data.numpy()
                    pred_val = seg_pred.cpu().data.max(2)[1].numpy()
                    correct = np.sum((pred_val == batch_label) & (batch_smpw > 0))
                    total_correct += correct
                    total_seen += np.sum(batch_smpw > 0)
                    tmp, _ = np.histogram(batch_label, range(NUM_CLASSES + 1))
                    labelweights += tmp
                    for l in range(NUM_CLASSES):
                        total_seen_class[l] += np.sum((batch_label == l) & (batch_smpw > 0))
                        total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l) & (batch_smpw > 0))
                        total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)) & (batch_smpw > 0))

                mIoU = np.mean(np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6))
                log_string('eval whole scene mean loss: %f' % (loss_sum / float(num_batches)))
                log_string('eval point avg class IoU: %f' % mIoU)
                log_string('eval whole scene point accuracy: %f' % (total_correct / float(total_seen)))
                log_string('eval whole scene point avg class acc: %f' % (
                    np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6))))
                labelweights = labelweights.astype(np.float32) / np.sum(labelweights.astype(np.float32))

                iou_per_class_str = '------- IoU --------\n'
                for l in range(NUM_CLASSES):
                    iou_per_class_str += 'class %s weight: %.3f, IoU: %.3f \n' % (
                        seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), labelweights[l],
                        total_correct_class[l] / float(total_iou_deno_class[l]))
                log_string(iou_per_class_str)

                if (mIoU >= best_iou):
                    logger.info('Save model...')
                    savepath = str(checkpoints_dir) + '/best_model.pth'
                    log_string('Saving at %s' % savepath)
                    state = {
                        'epoch': epoch,
                        'class_avg_iou': mIoU,
                        'model_state_dict': classifier.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }
                    torch.save(state, savepath)
                    log_string('Saving model....')

        global_epoch += 1
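
The training loops above also augment the data with provider.random_scale_point_cloud and provider.shift_point_cloud. Typical implementations, assumed rather than taken from this repository:

import numpy as np

def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    # Assumed: scale every cloud in the batch by its own uniform random factor.
    scales = np.random.uniform(scale_low, scale_high, batch_data.shape[0])
    for k in range(batch_data.shape[0]):
        batch_data[k, :, :] *= scales[k]
    return batch_data

def shift_point_cloud(batch_data, shift_range=0.1):
    # Assumed: translate every cloud in the batch by its own uniform random offset.
    shifts = np.random.uniform(-shift_range, shift_range, (batch_data.shape[0], 3))
    for k in range(batch_data.shape[0]):
        batch_data[k, :, :] += shifts[k, :]
    return batch_data
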
Example 7
def eval_whole_scene_one_epoch(sess, ops, test_writer):
    """ ops: dict mapping from string to tf ops """
    global EPOCH_CNT_WHOLE
    is_training = False
    num_batches = len(TEST_DATASET_WHOLE_SCENE)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    total_iou_deno_class = [0 for _ in range(NUM_CLASSES)]

    log_string(str(datetime.now()))
    log_string('---- EPOCH %03d EVALUATION WHOLE SCENE----' % (EPOCH_CNT_WHOLE))

    labelweights = np.zeros(21)
    is_continue_batch = False

    extra_batch_data = np.zeros((0, NUM_POINT, 3 + feature_channel))
    extra_batch_label = np.zeros((0, NUM_POINT))
    extra_batch_smpw = np.zeros((0, NUM_POINT))

    for batch_idx in tqdm(range(num_batches),total=num_batches):
        if not is_continue_batch:
            batch_data, batch_label, batch_smpw = TEST_DATASET_WHOLE_SCENE[batch_idx]
            batch_data = np.concatenate((batch_data, extra_batch_data), axis=0)
            batch_label = np.concatenate((batch_label, extra_batch_label), axis=0)
            batch_smpw = np.concatenate((batch_smpw, extra_batch_smpw), axis=0)
        else:
            batch_data_tmp, batch_label_tmp, batch_smpw_tmp = TEST_DATASET_WHOLE_SCENE[batch_idx]
            batch_data = np.concatenate((batch_data, batch_data_tmp), axis=0)
            batch_label = np.concatenate((batch_label, batch_label_tmp), axis=0)
            batch_smpw = np.concatenate((batch_smpw, batch_smpw_tmp), axis=0)
        if batch_data.shape[0] < BATCH_SIZE:
            is_continue_batch = True
            continue
        elif batch_data.shape[0] == BATCH_SIZE:
            is_continue_batch = False
            extra_batch_data = np.zeros((0, NUM_POINT, 3 + feature_channel))
            extra_batch_label = np.zeros((0, NUM_POINT))
            extra_batch_smpw = np.zeros((0, NUM_POINT))
        else:
            is_continue_batch = False
            extra_batch_data = batch_data[BATCH_SIZE:, :, :]
            extra_batch_label = batch_label[BATCH_SIZE:, :]
            extra_batch_smpw = batch_smpw[BATCH_SIZE:, :]
            batch_data = batch_data[:BATCH_SIZE, :, :]
            batch_label = batch_label[:BATCH_SIZE, :]
            batch_smpw = batch_smpw[:BATCH_SIZE, :]

        batch_data[:, :, :3] = provider.normalize_data(batch_data[:, :, :3])
        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['smpws_pl']: batch_smpw,
                     ops['is_training_pl']: is_training}
        summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['loss'], ops['pred']], feed_dict=feed_dict)
        test_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 2)  # BxN
        correct = np.sum((pred_val == batch_label) & (batch_label > 0) & (
                    batch_smpw > 0))  # evaluate only on 20 categories but not unknown
        total_correct += correct
        total_seen += np.sum((batch_label > 0) & (batch_smpw > 0))
        loss_sum += loss_val
        tmp, _ = np.histogram(batch_label, range(22))
        labelweights += tmp
        for l in range(NUM_CLASSES):
            total_seen_class[l] += np.sum((batch_label == l) & (batch_smpw > 0))
            total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l) & (batch_smpw > 0))
            total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)) & (batch_smpw > 0))

    mIoU = np.mean(np.array(total_correct_class[1:]) / (np.array(total_iou_deno_class[1:], dtype=np.float64) + 1e-6))
    log_string('Eval whole scene mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('Eval point avg class IoU: %f' % mIoU)
    log_string('Eval whole scene point accuracy: %f' % (total_correct / float(total_seen)))
    log_string('Eval whole scene point avg class acc: %f' % (
        np.mean(np.array(total_correct_class[1:]) / (np.array(total_seen_class[1:], dtype=np.float64) + 1e-6))))
    labelweights = labelweights[1:].astype(np.float32) / np.sum(labelweights[1:].astype(np.float32))

    iou_per_class_str = '------- IoU --------\n'
    for l in range(1, NUM_CLASSES):
        iou_per_class_str += 'class %s weight: %.3f, IoU: %.3f \n' % (
        seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), labelweights[l - 1],
        total_correct_class[l] / float(total_iou_deno_class[l]))
    log_string(iou_per_class_str)

    EPOCH_CNT_WHOLE += 1
    return mIoU
Example 8
def main(args):
    def log_string(msg):
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    experiment_dir = 'log/sem_seg/' + args.log_dir
    visual_dir = experiment_dir + '/visual/'
    visual_dir = Path(visual_dir)
    visual_dir.mkdir(exist_ok=True)
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    NUM_CLASSES = 13
    WITH_RGB = args.with_rgb
    BATCH_SIZE = args.batch_size
    NUM_POINT = args.num_point

    root = 'data/stanford_indoor3d/'

    TEST_DATASET_WHOLE_SCENE = ScannetDatasetWholeScene_evaluation(
        root,
        split='test',
        with_rgb=WITH_RGB,
        test_area=args.test_area,
        block_points=NUM_POINT)
    log_string("The number of test data is: %d" %
               len(TEST_DATASET_WHOLE_SCENE))
    '''MODEL LOADING'''
    model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
    MODEL = importlib.import_module(model_name)
    classifier = MODEL.get_model(NUM_CLASSES, with_rgb=WITH_RGB).cuda()
    checkpoint = torch.load(
        str(experiment_dir) + '/checkpoints/best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])

    with torch.no_grad():
        scene_id = TEST_DATASET_WHOLE_SCENE.file_list
        scene_id = [x[:-4] for x in scene_id]
        num_batches = len(TEST_DATASET_WHOLE_SCENE)

        total_seen_class = [0 for _ in range(NUM_CLASSES)]
        total_correct_class = [0 for _ in range(NUM_CLASSES)]
        total_iou_deno_class = [0 for _ in range(NUM_CLASSES)]

        log_string('---- EVALUATION WHOLE SCENE----')

        for batch_idx in range(num_batches):
            print("visualize %d %s ..." % (batch_idx, scene_id[batch_idx]))
            total_seen_class_tmp = [0 for _ in range(NUM_CLASSES)]
            total_correct_class_tmp = [0 for _ in range(NUM_CLASSES)]
            total_iou_deno_class_tmp = [0 for _ in range(NUM_CLASSES)]
            if args.visual:
                fout = open(
                    os.path.join(visual_dir,
                                 scene_id[batch_idx] + '_pred.obj'), 'w')
                fout_gt = open(
                    os.path.join(visual_dir, scene_id[batch_idx] + '_gt.obj'),
                    'w')

            whole_scene_data = TEST_DATASET_WHOLE_SCENE.scene_points_list[
                batch_idx]
            whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[
                batch_idx]
            vote_label_pool = np.zeros(
                (whole_scene_label.shape[0], NUM_CLASSES))
            for _ in tqdm(range(args.num_votes), total=args.num_votes):
                scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[
                    batch_idx]
                num_blocks = scene_data.shape[0]
                s_batch_num = (num_blocks + BATCH_SIZE - 1) // BATCH_SIZE
                if WITH_RGB:
                    batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 6))
                else:
                    batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 3))
                batch_label = np.zeros((BATCH_SIZE, NUM_POINT))
                batch_point_index = np.zeros((BATCH_SIZE, NUM_POINT))
                batch_smpw = np.zeros((BATCH_SIZE, NUM_POINT))
                for sbatch in range(s_batch_num):
                    start_idx = sbatch * BATCH_SIZE
                    end_idx = min((sbatch + 1) * BATCH_SIZE, num_blocks)
                    real_batch_size = end_idx - start_idx
                    batch_data[0:real_batch_size,
                               ...] = scene_data[start_idx:end_idx, ...]
                    batch_label[0:real_batch_size,
                                ...] = scene_label[start_idx:end_idx, ...]
                    batch_point_index[0:real_batch_size,
                                      ...] = scene_point_index[
                                          start_idx:end_idx, ...]
                    batch_smpw[0:real_batch_size,
                               ...] = scene_smpw[start_idx:end_idx, ...]

                    if WITH_RGB:
                        batch_data[:, :, 3:6] /= 1.0

                    batch_data[:, :, :3] = provider.normalize_data(
                        batch_data[:, :, :3])

                    batch_data = torch.Tensor(batch_data)
                    batch_data = batch_data.float().cuda()
                    batch_data = batch_data.transpose(2, 1)
                    seg_pred, _ = classifier(batch_data)
                    batch_pred_label = seg_pred.contiguous().cpu().data.max(
                        2)[1].numpy()

                    vote_label_pool = add_vote(
                        vote_label_pool, batch_point_index[0:real_batch_size,
                                                           ...],
                        batch_pred_label[0:real_batch_size, ...],
                        batch_smpw[0:real_batch_size, ...])

            pred_label = np.argmax(vote_label_pool, 1)

            for l in range(NUM_CLASSES):
                total_seen_class_tmp[l] += np.sum((whole_scene_label == l))
                total_correct_class_tmp[l] += np.sum((pred_label == l) &
                                                     (whole_scene_label == l))
                total_iou_deno_class_tmp[l] += np.sum(
                    ((pred_label == l) | (whole_scene_label == l)))
                total_seen_class[l] += total_seen_class_tmp[l]
                total_correct_class[l] += total_correct_class_tmp[l]
                total_iou_deno_class[l] += total_iou_deno_class_tmp[l]

            iou_map = np.array(total_correct_class_tmp) / (
                np.array(total_iou_deno_class_tmp, dtype=np.float64) + 1e-6)
            print(iou_map)
            arr = np.array(total_seen_class_tmp)
            tmp_iou = np.mean(iou_map[arr != 0])
            log_string('Mean IoU of %s: %.4f' % (scene_id[batch_idx], tmp_iou))
            print('----------------------------')

            filename = os.path.join(visual_dir, scene_id[batch_idx] + '.txt')
            with open(filename, 'w') as pl_save:
                for i in pred_label:
                    pl_save.write(str(int(i)) + '\n')
            for i in range(whole_scene_label.shape[0]):
                color = g_label2color[pred_label[i]]
                color_gt = g_label2color[whole_scene_label[i]]
                if args.visual:
                    fout.write(
                        'v %f %f %f %d %d %d\n' %
                        (whole_scene_data[i, 0], whole_scene_data[i, 1],
                         whole_scene_data[i, 2], color[0], color[1], color[2]))
                    fout_gt.write(
                        'v %f %f %f %d %d %d\n' %
                        (whole_scene_data[i, 0], whole_scene_data[i, 1],
                         whole_scene_data[i, 2], color_gt[0], color_gt[1],
                         color_gt[2]))
            if args.visual:
                fout.close()
                fout_gt.close()

        IoU = np.array(total_correct_class) / (
            np.array(total_iou_deno_class, dtype=np.float64) + 1e-6)
        iou_per_class_str = '------- IoU --------\n'
        for l in range(NUM_CLASSES):
            iou_per_class_str += 'class %s, IoU: %.3f \n' % (
                seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])),
                total_correct_class[l] / float(total_iou_deno_class[l]))
        log_string(iou_per_class_str)
        log_string('eval point avg class IoU: %f' % np.mean(IoU))
        log_string('eval whole scene point avg class acc: %f' % (np.mean(
            np.array(total_correct_class) /
            (np.array(total_seen_class, dtype=np.float64) + 1e-6))))
        log_string('eval whole scene point accuracy: %f' %
                   (np.sum(total_correct_class) /
                    float(np.sum(total_seen_class) + 1e-6)))

        print("Done!")
Example 9
def eval_one_epoch(sess, ops, num_votes=1, NUM_NOISY_POINT=0):
    is_training = False

    # Make sure batch data is of same size
    cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT, TEST_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)
    num_batch = int(len(TEST_DATASET) / BATCH_SIZE)

    total_correct = 0
    total_object = 0
    total_seen = 0
    loss_sum = 0
    batch_idx = 0

    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]

    with tqdm(total=num_batch) as pbar:
        while TEST_DATASET.has_next_batch():
            batch_data, batch_label = TEST_DATASET.next_batch()
            # for the last batch in the epoch, the bsize:end are from last batch
            bsize = batch_data.shape[0]

            # noisy robustness
            if NUM_NOISY_POINT > 0:
                noisy_point = np.random.random((bsize, NUM_NOISY_POINT, 3))
                noisy_point = provider.normalize_data(noisy_point)
                batch_data[:bsize, :NUM_NOISY_POINT, :3] = noisy_point

            loss_vote = 0
            cur_batch_data[0:bsize,...] = batch_data
            cur_batch_label[0:bsize] = batch_label

            batch_pred_sum = np.zeros((BATCH_SIZE, NUM_CLASSES)) # score for classes
            for vote_idx in range(num_votes):
                # Shuffle point order to achieve different farthest samplings,
                # and feed the shuffled copy so each vote actually sees a new order
                shuffled_indices = np.arange(NUM_POINT)
                np.random.shuffle(shuffled_indices)

                feed_dict = {ops['pointclouds_pl']: cur_batch_data[:, shuffled_indices, :],
                             ops['labels_pl']: cur_batch_label,
                             ops['is_training_pl']: is_training}
                loss_val, pred_val = sess.run([ops['loss'], ops['pred']], feed_dict=feed_dict)
                batch_pred_sum += pred_val
                loss_vote += loss_val
            loss_vote /= num_votes

            pred_val = np.argmax(batch_pred_sum, 1)
            correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
            total_correct += correct
            total_seen += bsize
            loss_sum += loss_vote
            batch_idx += 1
            total_object += BATCH_SIZE
            for i in range(bsize):
                l = batch_label[i]
                total_seen_class[l] += 1
                total_correct_class[l] += (pred_val[i] == l)

            pbar.update(1)

    log_string('Eval mean loss: %f' % (loss_sum / float(total_object)))
    log_string('Eval accuracy: %f' % (total_correct / float(total_seen)))
    log_string('Eval avg class acc: %f' % (np.mean(np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))))

    class_accuracies = np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)
    for i, name in enumerate(SHAPE_NAMES):
        log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))
    TEST_DATASET.reset()
    return total_correct / float(total_seen)