def _augment_batch_data(self, batch_data):
    """Apply random rotation, scale, shift and jitter, then shuffle points.

    Rotation and rotation perturbation use the ``*_with_normal`` provider
    helpers when ``self.normal_channel`` is truthy, and the plain variants
    otherwise. Scaling, shifting and jittering are applied only to channels
    0:3 (xyz). The result is written back into the rotated array, and its
    point order is then shuffled.

    Args:
        batch_data: batch indexed as ``[batch, point, channel]``; presumably
            (B, N, 3), or (B, N, 6) with normals -- TODO confirm.

    Returns:
        The augmented batch returned by ``provider.shuffle_points``.
    """
    if not self.normal_channel:
        rotate_fn = provider.rotate_point_cloud
        perturb_fn = provider.rotate_perturbation_point_cloud
    else:
        rotate_fn = provider.rotate_point_cloud_with_normal
        perturb_fn = provider.rotate_perturbation_point_cloud_with_normal
    augmented = perturb_fn(rotate_fn(batch_data))

    # Only the xyz slice goes through scale/shift/jitter; other channels are
    # not passed to these helpers.
    xyz = provider.random_scale_point_cloud(augmented[:, :, 0:3])
    xyz = provider.shift_point_cloud(xyz)
    xyz = provider.jitter_point_cloud(xyz)
    augmented[:, :, 0:3] = xyz

    return provider.shuffle_points(augmented)
def augment_batch_data(batch_data):
    """Randomly rotate, scale, shift and jitter a batch (point order kept).

    The ``*_with_normal`` rotation helpers are used when ``FLAGS.normal`` is
    truthy. Scale, shift and jitter are applied to channels 0:3 (xyz) only.
    Unlike a shuffling augmentation, the point order of the batch is not
    changed here.

    Args:
        batch_data: batch indexed as ``[batch, point, channel]``; presumably
            (B, N, 3), or (B, N, 6) with normals -- TODO confirm.

    Returns:
        The augmented batch (the array produced by the rotation helpers, with
        its xyz channels overwritten).
    """
    if FLAGS.normal:
        data = provider.rotate_perturbation_point_cloud_with_normal(
            provider.rotate_point_cloud_with_normal(batch_data))
    else:
        data = provider.rotate_perturbation_point_cloud(
            provider.rotate_point_cloud(batch_data))

    coords = data[:, :, 0:3]
    coords = provider.random_scale_point_cloud(coords)
    coords = provider.shift_point_cloud(coords)
    coords = provider.jitter_point_cloud(coords)
    data[:, :, 0:3] = coords
    return data
def _augment_batch_data(self, batch_data, augment, rotate=0):
    """Optionally augment xyz, optionally rotate, then shuffle point order.

    Args:
        batch_data: numpy batch indexed as ``[batch, point, channel]``;
            channels 0:3 are treated as xyz. Rotation modes 2 and 3 call the
            ``*_with_normal`` provider helpers, so presumably channels 3:6
            hold normals in that case -- TODO confirm.
        augment: if truthy, randomly scale, shift and jitter the xyz channels.
        rotate: 2 -> ``provider.rotate_point_cloud_with_normal``,
            3 -> ``provider.rotate_perturbation_point_cloud_with_normal``,
            any other value -> no rotation.

    Returns:
        The processed batch as returned by ``provider.shuffle_points``.
        This function no longer writes into the caller's ``batch_data``
        itself. Whether the provider helpers mutate their input is not
        visible here.
    """
    if augment:
        # Work on a private copy: the jittered xyz is slice-assigned back
        # below, which previously overwrote the caller's array in place.
        batch_data = batch_data.copy()
        jittered_data = provider.random_scale_point_cloud(batch_data[:, :, 0:3])
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        batch_data[:, :, 0:3] = jittered_data

    if rotate == 2:
        # rotate points and normals
        batch_data = provider.rotate_point_cloud_with_normal(batch_data)
    elif rotate == 3:
        # small rotation perturbation of points and normals
        batch_data = provider.rotate_perturbation_point_cloud_with_normal(
            batch_data)

    return provider.shuffle_points(batch_data)
def get_batch(dataset, idxs, start_idx, end_idx, rotate=0):
    """Build one batch from ``dataset[idxs[start_idx]] ... dataset[idxs[end_idx - 1]]``.

    Each dataset item is unpacked as ``(points, normals, seg, cls)``. Points
    fill channels 0:3 and normals fill channels 3:6 of a
    ``(bsize, NUM_POINT, 6)`` float array. Per-point ``seg`` labels and the
    per-shape ``cls`` label are stored in int32 arrays.

    Args:
        dataset: indexable returning ``(points, normals, seg, cls)``.
        idxs: index sequence into ``dataset``.
        start_idx: first position in ``idxs`` (inclusive).
        end_idx: last position in ``idxs`` (exclusive).
        rotate: 2 -> ``provider.rotate_point_cloud_with_normal``,
            3 -> ``provider.rotate_perturbation_point_cloud_with_normal``,
            any other value -> no rotation.

    Returns:
        ``(batch_data, batch_label, batch_cls_label)``.
    """
    bsize = end_idx - start_idx
    points = np.zeros((bsize, NUM_POINT, 6))
    point_labels = np.zeros((bsize, NUM_POINT), dtype=np.int32)
    shape_labels = np.zeros((bsize, ), dtype=np.int32)

    for offset in range(bsize):
        xyz, normals, seg, cls = dataset[idxs[start_idx + offset]]
        points[offset, :, 0:3] = xyz
        points[offset, :, 3:6] = normals
        point_labels[offset, :] = seg
        shape_labels[offset] = cls

    if rotate == 2:
        # rotate points and normals
        points = provider.rotate_point_cloud_with_normal(points)
    elif rotate == 3:
        # small rotation perturbation of points and normals
        points = provider.rotate_perturbation_point_cloud_with_normal(points)

    return points, point_labels, shape_labels
def augment_batch_data_MODELNET(batch_data, is_include_normal):
    """Randomly rotate, scale, shift, jitter and shuffle a ModelNet batch.

    Args:
        batch_data: batch indexed as ``[batch, point, channel]``.
        is_include_normal: False when the channels are xyz only; True when
            they are xyz followed by nx, ny, nz. Selects the
            ``*_with_normal`` rotation helpers.

    Returns:
        The augmented batch with shuffled point order. Scale, shift and
        jitter touch channels 0:3 (xyz) only.
    """
    if is_include_normal:
        rotate_fn = provider.rotate_point_cloud_with_normal
        perturb_fn = provider.rotate_perturbation_point_cloud_with_normal
    else:
        rotate_fn = provider.rotate_point_cloud
        perturb_fn = provider.rotate_perturbation_point_cloud

    out = rotate_fn(batch_data)
    out = perturb_fn(out)

    xyz = out[:, :, 0:3]
    for transform in (provider.random_scale_point_cloud,
                      provider.shift_point_cloud,
                      provider.jitter_point_cloud):
        xyz = transform(xyz)
    out[:, :, 0:3] = xyz

    return provider.shuffle_points(out)
def _augment_batch_data(self, batch_data):
    """Apply configurable random augmentation to a batch.

    Rotation uses the ``*_with_normal`` helpers when ``self.normal_channel``
    is truthy. Channels 0:3 (xyz) are then scaled within
    ``[self.scale_low, self.scale_high]``, shifted with
    ``self.shift_range`` and jittered with ``self.jitter_sigma``. The jitter
    clip is fixed at 0.1. Point order is shuffled only if
    ``self.shuffle_points`` is truthy.

    Args:
        batch_data: batch indexed as ``[batch, point, channel]``; presumably
            (B, N, 3), or (B, N, 6) with normals -- TODO confirm.

    Returns:
        The augmented batch, shuffled or not depending on
        ``self.shuffle_points``.
    """
    if self.normal_channel:
        rot = provider.rotate_point_cloud_with_normal
        perturb = provider.rotate_perturbation_point_cloud_with_normal
    else:
        rot = provider.rotate_point_cloud
        perturb = provider.rotate_perturbation_point_cloud
    out = perturb(rot(batch_data))

    xyz = provider.random_scale_point_cloud(
        out[:, :, 0:3], scale_low=self.scale_low, scale_high=self.scale_high)
    xyz = provider.shift_point_cloud(xyz, shift_range=self.shift_range)
    xyz = provider.jitter_point_cloud(xyz, sigma=self.jitter_sigma, clip=0.1)
    out[:, :, 0:3] = xyz

    return provider.shuffle_points(out) if self.shuffle_points else out
def eval_one_epoch(sess, ops, test_writer, num_votes=12):
    """Evaluate the model on TEST_DATASET for one epoch and track best scores.

    TEST_DATASET is consumed in batches. Each batch is copied into
    fixed-size BATCH_SIZE buffers before being fed to the graph.

    When ROTATE_FLAG is set, each batch is scored ``num_votes`` times. For
    each vote the point order is reshuffled and the cloud is rotated by
    ``vote_idx / num_votes * 2*pi``, with rotation perturbation, random
    scale and jitter applied. The summed scores are argmax-ed.

    Otherwise the batch is scored once, and a summary is written to
    ``test_writer`` at the current global step.

    Logs mean loss, overall accuracy and mean per-class accuracy. Updates the
    module globals BEST_ACC / BEST_CLS_ACC, increments EPOCH_CNT and resets
    TEST_DATASET.

    Args:
        sess: session used to run ``ops``.
        ops: dict mapping from string to tf ops. Keys read here:
            'pointclouds_pl', 'labels_pl', 'is_training_pl', 'loss', 'pred',
            plus 'merged' and 'step' when ROTATE_FLAG is false.
        test_writer: summary writer, only used when ROTATE_FLAG is false.
        num_votes: number of rotated views whose scores are summed per batch
            when ROTATE_FLAG is set. Defaults to 12, the previously
            hard-coded value.

    Returns:
        ``(best_acc_flag, best_cls_acc_flag)``: whether this epoch set a new
        best overall accuracy / mean class accuracy.

    Raises:
        ValueError: if ``num_votes`` is less than 1 (there would be no
            prediction or loss to report).
    """
    global EPOCH_CNT
    global BEST_ACC
    global BEST_CLS_ACC

    if num_votes < 1:
        raise ValueError('num_votes must be >= 1, got %r' % (num_votes,))

    is_training = False

    # Make sure batch data is of same size: the graph is always fed
    # BATCH_SIZE rows. For the last (short) batch, rows bsize: still hold
    # data from the previous batch and are excluded when scoring.
    cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT, TEST_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    batch_idx = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]

    log_string(str(datetime.now()))
    log_string('---- EPOCH %03d EVALUATION ----' % (EPOCH_CNT))

    while TEST_DATASET.has_next_batch():
        batch_data, batch_label = TEST_DATASET.next_batch(augment=False)
        bsize = batch_data.shape[0]
        cur_batch_data[0:bsize, ...] = batch_data
        cur_batch_label[0:bsize] = batch_label

        if ROTATE_FLAG:
            batch_pred_sum = np.zeros((BATCH_SIZE, NUM_CLASSES))  # score for classes
            for vote_idx in range(num_votes):
                # Shuffle point order to achieve different farthest samplings
                shuffled_indices = np.arange(NUM_POINT)
                np.random.shuffle(shuffled_indices)
                angle = vote_idx / float(num_votes) * np.pi * 2
                if NORMAL_FLAG:
                    rotated_data = provider.rotate_point_cloud_by_angle_with_normal(
                        cur_batch_data[:, shuffled_indices, :], angle)
                    rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
                else:
                    rotated_data = provider.rotate_point_cloud_by_angle(
                        cur_batch_data[:, shuffled_indices, :], angle)
                    rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
                jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
                jittered_data = provider.jitter_point_cloud(jittered_data)
                rotated_data[:, :, 0:3] = jittered_data

                feed_dict = {ops['pointclouds_pl']: rotated_data,
                             ops['labels_pl']: cur_batch_label,
                             ops['is_training_pl']: is_training}
                loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
                                              feed_dict=feed_dict)
                batch_pred_sum += pred_val
            # NOTE: only the last vote's loss_val is accumulated into loss_sum.
            pred_val = np.argmax(batch_pred_sum, 1)
        else:
            feed_dict = {ops['pointclouds_pl']: cur_batch_data,
                         ops['labels_pl']: cur_batch_label,
                         ops['is_training_pl']: is_training}
            summary, step, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            test_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)

        correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
        total_correct += correct
        total_seen += bsize
        loss_sum += loss_val
        batch_idx += 1
        for label, pred in zip(batch_label[0:bsize], pred_val[0:bsize]):
            total_seen_class[label] += 1
            total_correct_class[label] += (pred == label)

    current_acc = total_correct / float(total_seen)
    # Builtin float instead of np.float (removed in NumPy 1.24). A class with
    # zero test samples would make its ratio nan and hence the mean nan.
    current_cls_acc = np.mean(
        np.array(total_correct_class) / np.array(total_seen_class, dtype=float))
    log_string('eval mean loss: %f' % (loss_sum / float(batch_idx)))
    log_string('eval accuracy: %f' % (current_acc))
    log_string('eval avg class acc: %f' % (current_cls_acc))

    best_acc_flag, best_cls_acc_flag = False, False
    if current_acc > BEST_ACC:
        BEST_ACC = current_acc
        best_acc_flag = True
    if current_cls_acc > BEST_CLS_ACC:
        BEST_CLS_ACC = current_cls_acc
        best_cls_acc_flag = True
    log_string('eval best accuracy: %f' % (BEST_ACC))
    log_string('eval best avg class acc: %f' % (BEST_CLS_ACC))

    EPOCH_CNT += 1
    TEST_DATASET.reset()
    return (best_acc_flag, best_cls_acc_flag)