def _augment_batch_data(self, batch_data):
    if self.normal_channel:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
Example #2
def _augment_batch_data(self, batch_data):
    if self.normal_channel:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
Example #3
def augment_batch_data(batch_data):
    if FLAGS.normal:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
            rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return rotated_data
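A minimal usage sketch for the function above, assuming `provider` is the PointNet++ provider.py module, `FLAGS.normal` is set, and the batch is a NumPy float array of shape (batch_size, num_point, 6) holding xyz plus normals:

import numpy as np

# Hypothetical batch: 8 clouds, 1024 points each, channels = xyz + normals.
batch = np.random.rand(8, 1024, 6).astype(np.float32)
augmented = augment_batch_data(batch)
assert augmented.shape == batch.shape  # augmentation preserves the batch shape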
Example #4
def _augment_batch_data(self, batch_data, augment, rotate=0):
    if augment:
        # augment points: random scale, shift, and jitter on xyz
        jittered_data = provider.random_scale_point_cloud(batch_data[:, :, 0:3])
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        batch_data[:, :, 0:3] = jittered_data
    if rotate == 2:
        # rotate points and normals
        batch_data = provider.rotate_point_cloud_with_normal(batch_data)
    elif rotate == 3:
        batch_data = provider.rotate_perturbation_point_cloud_with_normal(
            batch_data)
    return provider.shuffle_points(batch_data)
Example #5
def get_batch(dataset, idxs, start_idx, end_idx, rotate=0):
    bsize = end_idx - start_idx
    batch_data = np.zeros((bsize, NUM_POINT, 6))
    batch_label = np.zeros((bsize, NUM_POINT), dtype=np.int32)
    batch_cls_label = np.zeros((bsize, ), dtype=np.int32)
    for i in range(bsize):
        ps, normal, seg, cls = dataset[idxs[i + start_idx]]
        batch_data[i, :, 0:3] = ps
        batch_data[i, :, 3:6] = normal
        batch_label[i, :] = seg
        batch_cls_label[i] = cls
    if rotate == 2:
        # rotated points and normal
        batch_data = provider.rotate_point_cloud_with_normal(batch_data)
    elif rotate == 3:
        batch_data = provider.rotate_perturbation_point_cloud_with_normal(
            batch_data)
    return batch_data, batch_label, batch_cls_label
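A sketch of how get_batch might be driven from a training or evaluation loop; TRAIN_DATASET, BATCH_SIZE, and the per-epoch index shuffle are assumptions, not names taken from this snippet:

import numpy as np

train_idxs = np.arange(len(TRAIN_DATASET))   # hypothetical dataset object
np.random.shuffle(train_idxs)                # reshuffle once per epoch
num_batches = len(TRAIN_DATASET) // BATCH_SIZE
for batch_idx in range(num_batches):
    start_idx = batch_idx * BATCH_SIZE
    end_idx = (batch_idx + 1) * BATCH_SIZE
    batch_data, batch_label, batch_cls_label = get_batch(
        TRAIN_DATASET, train_idxs, start_idx, end_idx, rotate=2)
    # feed batch_data / batch_label / batch_cls_label to the network here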
Example #6
def augment_batch_data_MODELNET(batch_data, is_include_normal):
    '''
    is_include_normal=False: channels are xyz only
    is_include_normal=True: channels are xyz + nx ny nz (normals)
    '''
    if is_include_normal:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
            rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
Example #7
    def _augment_batch_data(self, batch_data):
        if self.normal_channel:
            rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
                rotated_data)
        else:
            rotated_data = provider.rotate_point_cloud(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud(
                rotated_data)

        jittered_data = provider.random_scale_point_cloud(
            rotated_data[:, :, 0:3],
            scale_low=self.scale_low,
            scale_high=self.scale_high)
        jittered_data = provider.shift_point_cloud(
            jittered_data, shift_range=self.shift_range)
        jittered_data = provider.jitter_point_cloud(jittered_data,
                                                    sigma=self.jitter_sigma,
                                                    clip=0.1)
        rotated_data[:, :, 0:3] = jittered_data
        if self.shuffle_points:
            return provider.shuffle_points(rotated_data)
        else:
            return rotated_data
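This variant reads its augmentation hyperparameters from attributes on self. A minimal sketch of the fields the method expects; the values mirror the defaults commonly used in the PointNet++ provider.py and are placeholders, not settings taken from this source:

class AugmentSettings:
    # Hypothetical attribute holder for the method above.
    normal_channel = True               # input carries xyz + normals
    scale_low, scale_high = 0.8, 1.25   # bounds for random_scale_point_cloud
    shift_range = 0.1                   # range for shift_point_cloud
    jitter_sigma = 0.01                 # noise std for jitter_point_cloud (clip is fixed at 0.1)
    shuffle_points = True               # shuffle point order after augmentation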
Example #8
def eval_one_epoch(sess, ops, test_writer):
	""" ops: dict mapping from string to tf ops """
	global EPOCH_CNT
	global BEST_ACC
	global BEST_CLS_ACC

	is_training = False

	# Make sure batch data is of same size
	cur_batch_data = np.zeros((BATCH_SIZE,NUM_POINT,TEST_DATASET.num_channel()))
	cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

	total_correct = 0
	total_seen = 0
	loss_sum = 0
	batch_idx = 0
	shape_ious = []
	total_seen_class = [0 for _ in range(NUM_CLASSES)]
	total_correct_class = [0 for _ in range(NUM_CLASSES)]
	
	log_string(str(datetime.now()))
	log_string('---- EPOCH %03d EVALUATION ----'%(EPOCH_CNT))

	while TEST_DATASET.has_next_batch():
		batch_data, batch_label = TEST_DATASET.next_batch(augment=False)
		bsize = batch_data.shape[0]
		# print('Batch: %03d, batch size: %d'%(batch_idx, bsize))
		# for the last batch in the epoch, entries from bsize onward still hold data from the previous batch
		cur_batch_data[0:bsize,...] = batch_data
		cur_batch_label[0:bsize] = batch_label

		if ROTATE_FLAG:
			batch_pred_sum = np.zeros((BATCH_SIZE, NUM_CLASSES)) # score for classes
			for vote_idx in range(12):
				# Shuffle point order to achieve different farthest samplings
				shuffled_indices = np.arange(NUM_POINT)
				np.random.shuffle(shuffled_indices)
				if NORMAL_FLAG:
					rotated_data = provider.rotate_point_cloud_by_angle_with_normal(cur_batch_data[:, shuffled_indices, :],
						vote_idx/float(12) * np.pi * 2)
					rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
				else:
					rotated_data = provider.rotate_point_cloud_by_angle(cur_batch_data[:, shuffled_indices, :],
						vote_idx/float(12) * np.pi * 2)
					rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

				jittered_data = provider.random_scale_point_cloud(rotated_data[:,:,0:3])
				jittered_data = provider.jitter_point_cloud(jittered_data)
				rotated_data[:,:,0:3] = jittered_data
				feed_dict = {ops['pointclouds_pl']: rotated_data,
							 ops['labels_pl']: cur_batch_label,
							 ops['is_training_pl']: is_training}
				loss_val, pred_val = sess.run([ops['loss'], ops['pred']], feed_dict=feed_dict)
				batch_pred_sum += pred_val
			pred_val = np.argmax(batch_pred_sum, 1)

		else:
			feed_dict = {ops['pointclouds_pl']: cur_batch_data,
					 ops['labels_pl']: cur_batch_label,
					 ops['is_training_pl']: is_training}
			summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
				ops['loss'], ops['pred']], feed_dict=feed_dict)
			test_writer.add_summary(summary, step)
			pred_val = np.argmax(pred_val, 1)

		correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
		total_correct += correct
		total_seen += bsize
		loss_sum += loss_val
		batch_idx += 1
		for i in range(bsize):
			l = batch_label[i]
			total_seen_class[l] += 1
			total_correct_class[l] += (pred_val[i] == l)
	
	current_acc = total_correct / float(total_seen)
	current_cls_acc = np.mean(np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))

	log_string('eval mean loss: %f' % (loss_sum / float(batch_idx)))
	log_string('eval accuracy: %f'% (current_acc))
	log_string('eval avg class acc: %f' % (current_cls_acc))

	best_acc_flag, best_cls_acc_flag = False, False
	if current_acc > BEST_ACC:
		BEST_ACC = current_acc
		best_acc_flag = True

	if current_cls_acc > BEST_CLS_ACC:
		BEST_CLS_ACC = current_cls_acc
		best_cls_acc_flag = True

	log_string('eval best accuracy: %f'% (BEST_ACC))
	log_string('eval best avg class acc: %f'% (BEST_CLS_ACC))

	EPOCH_CNT += 1

	TEST_DATASET.reset()
	return (best_acc_flag, best_cls_acc_flag)