def train_one_epoch(sess, ops, train_writer): """ ops: dict mapping from string to tf ops """ is_training = True # Shuffle train files train_file_idxs = np.arange(0, len(TRAIN_FILES)) # 每一次的5个train文件的顺序都是不一样的 np.random.shuffle(train_file_idxs) for fn in range(len(TRAIN_FILES)): # 对每一个train文件 log_string('----train file' + str(fn) + '-----') current_data, current_label = provider.loadDataFile( TRAIN_FILES[train_file_idxs[fn]]) current_data = current_data[:, 0: NUM_POINT, :] # 采样1024个点,current代表这个文件中所有的点云 current_data, current_label, _ = provider.shuffle_data( current_data, np.squeeze(current_label)) # 每个点云之间的顺序被打乱 current_label = np.squeeze(current_label) # 移除数组中单一维度的数据 file_size = current_data.shape[0] # 点云的总数 num_batches = file_size // BATCH_SIZE # 需要几个batch total_correct = 0 total_seen = 0 loss_sum = 0 for batch_idx in range(num_batches): start_idx = batch_idx * BATCH_SIZE end_idx = (batch_idx + 1) * BATCH_SIZE # Augment batched point clouds by rotation and jittering rotated_data = provider.rotate_point_cloud( current_data[start_idx:end_idx, :, :]) jittered_data = provider.jitter_point_cloud(rotated_data) jittered_data = provider.random_scale_point_cloud(jittered_data) jittered_data = provider.rotate_perturbation_point_cloud( jittered_data) jittered_data = provider.shift_point_cloud(jittered_data) feed_dict = { ops['pointclouds_pl']: jittered_data, ops['labels_pl']: current_label[start_idx:end_idx], ops['is_training_pl']: is_training, } # feed_dict的key一定是place_holder summary, step, _, loss_val, pred_val = sess.run( [ ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred'] ], feed_dict=feed_dict) print("--train file:{}, batch_idx:{},step:{}".format( str(fn), str(batch_idx), str(step))) train_writer.add_summary(summary, step) # 只有train才保存训练过程的曲线 pred_val = np.argmax(pred_val, 1) correct = np.sum(pred_val == current_label[start_idx:end_idx]) total_correct += correct total_seen += BATCH_SIZE loss_sum += loss_val log_string('mean loss: %f' % (loss_sum / float(num_batches))) log_string('accuracy: %f' % (total_correct / float(total_seen)))
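# The augmentation chain above comes from the PointNet provider module. As a
# reference, here is a minimal sketch of the two most common helpers,
# consistent with the usual provider.py behavior (treat the exact defaults as
# assumptions rather than guarantees):
import numpy as np

def rotate_point_cloud(batch_data):
    """Randomly rotate each (B, N, 3) point cloud about the up (Y) axis."""
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        R = np.array([[c, 0, s],
                      [0, 1, 0],
                      [-s, 0, c]])
        rotated[k, ...] = np.dot(batch_data[k, ...].reshape((-1, 3)), R)
    return rotated

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise to every point independently."""
    B, N, C = batch_data.shape
    jitter = np.clip(sigma * np.random.randn(B, N, C), -clip, clip)
    return batch_data + jitter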
def train_one_epoch(sess, ops, train_writer, dataset, verbose=True):
    """ Train model for one epoch """
    global EPOCH_CNT
    is_training = True

    # Shuffle train samples
    train_idxs = np.arange(0, len(dataset))
    np.random.shuffle(train_idxs)
    # Integer division: trailing samples are discarded if the dataset size
    # is not divisible by the batch size.
    num_batches = len(dataset) // FLAGS['BATCH_SIZE']

    log_string('[' + str(datetime.now()) + ' | EPOCH ' + str(EPOCH_CNT) +
               '] Starting training.', printout=False)

    loss_sum, batch_print_steps = 0, 10
    for batch_idx in range(num_batches):
        start_idx, end_idx = batch_idx * FLAGS['BATCH_SIZE'], (batch_idx + 1) * FLAGS['BATCH_SIZE']
        batch_data, batch_label = get_batch(dataset, train_idxs, start_idx, end_idx)

        # Perturb point clouds (xyz channels only):
        batch_data[:, :, :3] = provider.jitter_point_cloud(batch_data[:, :, :3])
        batch_data[:, :, :3] = provider.rotate_perturbation_point_cloud(batch_data[:, :, :3])
        batch_data[:, :, :3] = provider.shift_point_cloud(batch_data[:, :, :3])
        batch_data[:, :, :3] = provider.random_point_dropout(
            batch_data[:, :, :3], max_dropout_ratio=FLAGS['MAX_POINT_DROPOUT_RATIO'])

        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['is_training_pl']: is_training}
        summary, step, _, loss_val, pred_val = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        loss_sum += loss_val
        if batch_idx % batch_print_steps == 0:
            log_string('[Batch %03d] Mean Loss: %f' % ((batch_idx + 1), (loss_sum / batch_print_steps)),
                       printout=verbose)
            loss_sum = 0
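# random_point_dropout, used above, does not remove points (the tensor shape
# must stay fixed); it overwrites a random subset with the first point. A
# minimal sketch consistent with the PointNet++ provider (defaults assumed):
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """For each cloud, draw a dropout ratio in [0, max_dropout_ratio) and
    replace the dropped points with the first point, preserving shape."""
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]  # duplicate first point
    return batch_pc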
def get_batch_wdp(dataset, batch_idx):
    bsize = BATCH_SIZE
    batch_data = np.zeros((bsize, NUM_POINT, 3))
    batch_feats = np.zeros((bsize, NUM_POINT, 2))
    batch_label = np.zeros((bsize, NUM_POINT), dtype=np.int32)
    batch_smpw = np.zeros((bsize, NUM_POINT), dtype=np.float32)
    batch_ctx = np.zeros((bsize, NUM_CLASSES), dtype=np.float32)
    for i in range(bsize):
        ps, seg, smpw, feat = get_batch(dataset, batch_idx)
        # Apply a small random rotation to roughly 35% of the samples
        if np.random.random() >= 0.65:
            ps = provider.rotate_perturbation_point_cloud(
                ps.reshape(1, *ps.shape), angle_sigma=0.01, angle_clip=0.01)
            ps = ps.squeeze()
        batch_data[i, ...] = ps
        batch_label[i, :] = seg
        batch_smpw[i, :] = smpw
        batch_feats[i, :] = feat

        # Optional point dropout (currently disabled):
        # dropout_ratio = np.random.random() * 0.875  # 0-0.875
        # drop_idx = np.where(np.random.random((ps.shape[0])) <= dropout_ratio)[0]
        # batch_data[i, drop_idx, :] = batch_data[i, 0, :]
        # batch_label[i, drop_idx] = batch_label[i, 0]
        # batch_smpw[i, drop_idx] *= 0
        # mask = np.ones(len(ps), dtype=np.int32)
        # mask[drop_idx] = 0

        # Per-sample context vector: which classes are present in this segmentation
        inds, _ = np.histogram(seg, range(NUM_CLASSES + 1))
        batch_ctx[i] = np.array(inds > 0, dtype=np.int32)
    return batch_data, batch_label, batch_smpw, batch_feats, batch_ctx
def augment_batch_data(batch_data):
    rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.rotate_perturbation_point_cloud(jittered_data)
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return rotated_data
def _augment_batch_data(self, batch_data):
    rotated_data = provider.rotate_point_cloud(batch_data)
    rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
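# rotate_perturbation_point_cloud, used throughout these snippets, differs
# from rotate_point_cloud: it applies a small random rotation about all three
# axes instead of a full rotation about the up axis. A minimal sketch matching
# the usual provider.py behavior (the defaults are assumptions):
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Apply a small clipped-Gaussian rotation about X, Y and Z to each cloud."""
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        ax, ay, az = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(ax), -np.sin(ax)],
                       [0, np.sin(ax), np.cos(ax)]])
        Ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                       [0, 1, 0],
                       [-np.sin(ay), 0, np.cos(ay)]])
        Rz = np.array([[np.cos(az), -np.sin(az), 0],
                       [np.sin(az), np.cos(az), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        rotated[k, ...] = np.dot(batch_data[k, ...].reshape((-1, 3)), R)
    return rotated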
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(
            TRAIN_FILES[train_file_idxs[fn]])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            jittered_data = provider.random_scale_point_cloud(jittered_data)
            jittered_data = provider.rotate_perturbation_point_cloud(
                jittered_data)
            jittered_data = provider.shift_point_cloud(jittered_data)
            feed_dict = {
                ops['pointclouds_pl']: jittered_data,
                ops['labels_pl']: current_label[start_idx:end_idx],
                ops['is_training_pl']: is_training,
            }
            summary, step, _, loss_val, pred_val = sess.run(
                [
                    ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                    ops['pred']
                ],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
def _augment_batch_data(self, batch_data):
    if self.normal_channel:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
def augment_batch_data_MODELNET(batch_data, is_include_normal):
    '''
    is_include_normal=False: columns are xyz
    is_include_normal=True:  columns are xyz + nx ny nz
    '''
    if is_include_normal:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
            rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
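# A quick usage sketch for the helper above. The batch here is synthetic and
# purely illustrative; the only contract the function relies on is the
# (B, N, 6) shape when normals are included:
import numpy as np

batch = np.random.randn(8, 1024, 6).astype(np.float32)  # 8 clouds, xyz + normals
augmented = augment_batch_data_MODELNET(batch, is_include_normal=True)
assert augmented.shape == batch.shape  # augmentation preserves shape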
def get_example(self, i):
    """Return i-th data"""
    if self.augment:
        rotated_data = provider.rotate_point_cloud(self.data[i:i + 1, :, :])
        jittered_data = provider.jitter_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.rotate_perturbation_point_cloud(
            jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)
        point_data = jittered_data[0]
    else:
        point_data = self.data[i]
    # point_data (2048, 3): (num_point, k) --> convert to (k, num_point, 1)
    point_data = np.transpose(point_data.astype(np.float32), (1, 0))[:, :, None]
    assert point_data.dtype == np.float32
    assert self.label[i].dtype == np.int32
    return point_data, self.label[i]
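# The final transpose turns a (num_point, k) array into the (k, num_point, 1)
# channel-first layout the downstream convolutions expect. A quick shape check
# with synthetic data, for illustration only:
import numpy as np

pc = np.zeros((2048, 3), dtype=np.float32)   # (num_point, k)
out = np.transpose(pc, (1, 0))[:, :, None]   # (k, num_point, 1)
print(out.shape)  # (3, 2048, 1)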
def __data_generation(self, batch_idx):
    x = np.zeros((self.batch_size, self.npoints, 3))
    y = np.zeros((self.batch_size, ))
    for i, idx in enumerate(batch_idx, 0):
        # Take the first n points. TODO: random choice
        x[i] = self.datas[idx, 0:self.npoints, :]
        y[i] = self.labels[idx]
    if self.augment and np.random.rand() > 0.5:
        # Apply data augmentation to the whole batch
        rotated_x = provider.rotate_point_cloud(x)  # rotate around the up axis
        rotated_x = provider.rotate_perturbation_point_cloud(
            rotated_x)  # slightly rotate around every axis
        jittered_x = provider.random_scale_point_cloud(
            rotated_x)  # random scale a little bit
        jittered_x = provider.shift_point_cloud(
            jittered_x)  # shift a little
        jittered_x = provider.jitter_point_cloud(
            jittered_x)  # add random noise (jitter)
        jittered_x = provider.shuffle_points(
            jittered_x)  # shuffle the points, for FPS
        x = jittered_x
    return x, keras.utils.to_categorical(y, num_classes=len(self.cat))
def _augment_batch_data(self, batch_data):
    if self.normal_channel:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
            rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(
            rotated_data)

    jittered_data = provider.random_scale_point_cloud(
        rotated_data[:, :, 0:3],
        scale_low=self.scale_low,
        scale_high=self.scale_high)
    jittered_data = provider.shift_point_cloud(
        jittered_data, shift_range=self.shift_range)
    jittered_data = provider.jitter_point_cloud(jittered_data,
                                                sigma=self.jitter_sigma,
                                                clip=0.1)
    rotated_data[:, :, 0:3] = jittered_data
    if self.shuffle_points:
        return provider.shuffle_points(rotated_data)
    else:
        return rotated_data
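# Unlike the other variants, this one reads its augmentation hyperparameters
# off the dataset object. A hedged usage sketch: AugConfig is a hypothetical
# stand-in for that object; the attribute names come from the method body,
# but the values are assumptions, not the repo's actual defaults:
import numpy as np

class AugConfig:
    """Minimal host object for _augment_batch_data (values assumed)."""
    normal_channel = False
    scale_low, scale_high = 0.8, 1.25
    shift_range = 0.1
    jitter_sigma = 0.01
    shuffle_points = True

batch = np.random.randn(4, 1024, 3).astype(np.float32)
augmented = _augment_batch_data(AugConfig(), batch)  # call as an unbound method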
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    # train_file_idxs = np.arange(0, len(TRAIN_FILES))
    # np.random.shuffle(train_file_idxs)
    train_files = TRAIN_FILES
    # print(len(train_files))
    random.shuffle(train_files)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_labels = np.array([0, 0, 0], dtype=np.int32)

    gen = provider.data_generator(train_files, BATCH_SIZE)
    # for fn in range(len(TRAIN_FILES)):
    for fn in range(125000 // BATCH_SIZE):  # 125000 shards in total
        # log_string('----' + str(fn) + '-----')
        data_q, label_q, seg_q = next(gen)
        if data_q is None:
            print("an epoch is done in steps: ", fn)
            # print("total_labels:", total_labels)
            break
        # print(data_q[0].shape, label_q[0], seg_q[0].shape)

        # current_data, current_label, current_seg = provider.loadsegDataFile(train_files[fn])
        # current_data = current_data[:, 0:NUM_POINT, :]
        # current_data, current_label, current_seg_ = provider.shuffle_data(current_data, np.squeeze(current_label))
        # current_label = np.squeeze(current_label)
        # current_seg = np.squeeze(current_seg)
        # file_size = int(TRAIN_FILES[train_file_idxs[fn]].split('_')[1])
        file_size = BATCH_SIZE
        num_batches = 1

        current_data = np.array(data_q)
        current_label = np.array(label_q)
        current_seg = np.array(seg_q)
        rs_label = np.eye(5)[current_label]  # one-hot over 5 classes
        # print(current_label.shape, rs_label.shape, current_label, rs_label)
        # sys.exit()
        rs_seg = current_seg

        # train_data = np.concatenate((rs_current_data, rs_pottery), axis=0)
        train_data = current_data

        # Augment batched point clouds by rotation and jittering
        jittered_data = provider.rotate_perturbation_point_cloud(
            train_data, angle_sigma=np.pi / 2, angle_clip=np.pi)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)
        # print(jittered_data.shape)
        # print(rs_seg.shape)

        feed_dict = {ops['pointclouds_pl']: jittered_data,
                     ops['labels_pl']: rs_seg,
                     ops['labels_pl_c']: rs_label,
                     ops['is_training_pl']: is_training}
        summary, step, _, loss_val, pred_val, pred_c_val = sess.run([
            ops['merged'],
            ops['step'],
            ops['train_op'],
            ops['loss'],
            # ops['cls_loss'],
            # ops['reg_loss'],
            ops['pred'],
            ops['pred_c'],
        ], feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        # correct = np.sum(np.argmax(pred_val) == rs_seg0)
        # print('pred, label, loss:', np.argmax(pred_val), rs_seg0, loss_val)
        # correct = np.prod(np.argmax(pred_val, axis=1) == np.argmax(rs_seg, axis=1))
        print('loss:', loss_val)
        # total_correct += correct
        total_seen += 1
        loss_sum += loss_val
        # total_labels[rs_seg0] += 1

        if fn % 100 == 99:
            # COMBINE FEATURES
            log_string('mean loss: %f' % (loss_sum / float(total_seen)))
            print('label:', rs_seg)
            print('pred:', pred_val)
            print('label_c:', current_label)
            print('pred_c:', np.argmax(pred_c_val, axis=1))
            print('loss_val:', loss_val)
            correct = np.sum(np.argmax(pred_c_val, axis=1) ==
                             np.argmax(rs_label, axis=1))
            print('cls_acc_c:', correct)
def train_one_epoch(self, sess, train_writer):
    """ Train for one epoch; returns the mean accuracy over the train files.
    (Method of the trainer class: the body uses self.*, teacher_model and
    teacher_flag, so `self` is required in the signature.) """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)
    mean_acc = 0
    for fn in range(len(TRAIN_FILES)):
        print('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(
            TRAIN_FILES[train_file_idxs[fn]])
        current_data = current_data[:, 0:self.num_points, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // self.batch_size

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = (batch_idx + 1) * self.batch_size

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            jittered_data = provider.random_scale_point_cloud(jittered_data)
            jittered_data = provider.rotate_perturbation_point_cloud(
                jittered_data)
            jittered_data = provider.shift_point_cloud(jittered_data)

            # One-hot targets; replaced by the teacher's soft predictions
            # when distillation is enabled.
            soft_targets = current_label[start_idx:end_idx]
            n_values = np.zeros((soft_targets.size, self.num_classes))
            n_values[np.arange(soft_targets.size), soft_targets] = 1
            soft_targets = n_values
            if teacher_flag:
                soft_targets = teacher_model.predict(
                    current_data[start_idx:end_idx, :, :])
                # print(soft_targets, current_label[start_idx:end_idx])

            if not self.quantize_delay:
                feed_dict = {
                    # ops['pointclouds_pl']: jittered_data,
                    self.pointclouds_pl: current_data[start_idx:end_idx, :, :],
                    self.labels_pl: current_label[start_idx:end_idx],
                    self.is_training: is_training,
                    self.soft_Y: soft_targets,
                    self.flag: teacher_flag,
                    self.softmax_temperature: self.temperature
                }
            else:
                feed_dict = {
                    # ops['pointclouds_pl']: jittered_data,
                    self.pointclouds_pl: current_data[start_idx:end_idx, :, :],
                    self.labels_pl: current_label[start_idx:end_idx],
                    self.soft_Y: soft_targets,
                    self.flag: teacher_flag,
                    self.softmax_temperature: self.temperature
                    # ops['is_training_pl']: is_training,
                }
            # print(feed_dict)
            summary, step, _, loss_val, pred_val = sess.run(
                [
                    self.merged_summary_op, self.batch, self.train_op,
                    self.total_loss, self.prediction
                ],
                feed_dict=feed_dict)
            # sess.run(ops['mask_update_op'])
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(
                pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += self.batch_size
            loss_sum += loss_val

        print('mean loss: %f' % (loss_sum / float(num_batches)))
        print('accuracy: %f' % (total_correct / float(total_seen)))
        mean_acc += total_correct / float(total_seen)
    return mean_acc / len(TRAIN_FILES)
def eval_one_epoch(sess, ops, num_votes=1, topk=1):
    error_cnt = 0
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w')
    np.random.seed(101)
    for fn in range(len(TEST_FILES)):
        log_string('----' + str(fn) + '----')
        current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
        if FLAGS.random_pc_order:
            current_data = change_point_cloud_order(current_data)
        current_data = current_data[:, 0:NUM_POINT, :]
        current_label = np.squeeze(current_label)
        print(current_data.shape)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        print(file_size)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE
            cur_batch_size = end_idx - start_idx

            # Aggregating BEG
            batch_loss_sum = 0  # sum of losses for the batch
            batch_pred_sum = np.zeros(
                (cur_batch_size, NUM_CLASSES))  # score for classes
            batch_pred_max = np.ones(
                (cur_batch_size, NUM_CLASSES)) * (-999999)  # score for classes
            batch_pred_classes = np.zeros(
                (cur_batch_size, NUM_CLASSES))  # 0/1 for classes
            for vote_idx in range(num_votes):
                # Shuffle point order to achieve different farthest samplings
                shuffled_indices = np.arange(NUM_POINT)
                np.random.shuffle(shuffled_indices)
                rotated_data = provider.rotate_point_cloud_by_angle(
                    current_data[start_idx:end_idx, shuffled_indices, :],
                    vote_idx / float(num_votes) * np.pi * 2)
                # Note: jittered_data is computed below but never fed; the
                # network sees the rotated, un-jittered data.
                jittered_data = provider.random_scale_point_cloud(rotated_data)
                jittered_data = provider.rotate_perturbation_point_cloud(
                    jittered_data)
                jittered_data = provider.jitter_point_cloud(jittered_data)
                feed_dict = {
                    ops['pointclouds_pl']: rotated_data,
                    ops['labels_pl']: current_label[start_idx:end_idx],
                    ops['is_training_pl']: is_training
                }
                loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
                                              feed_dict=feed_dict)
                batch_pred_sum += pred_val
                batch_pred_val = np.argmax(pred_val, 1)
                for el_idx in range(cur_batch_size):
                    batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1
                batch_loss_sum += (loss_val * cur_batch_size / float(num_votes))
            # pred_val_topk = np.argsort(batch_pred_sum, axis=-1)[:, -1*np.array(range(topk))-1]
            # pred_val = np.argmax(batch_pred_classes, 1)
            pred_val = np.argmax(batch_pred_sum, 1)
            # Aggregating END

            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            # correct = np.sum(pred_val_topk[:, 0:topk] == label_val)
            total_correct += correct
            total_seen += cur_batch_size
            loss_sum += batch_loss_sum
            for i in range(start_idx, end_idx):
                l = current_label[i]
                total_seen_class[l] += 1
                total_correct_class[l] += (pred_val[i - start_idx] == l)
                fout.write('%d, %d\n' % (pred_val[i - start_idx], l))

                if pred_val[i - start_idx] != l and FLAGS.visu:  # ERROR CASE, DUMP!
                    img_filename = '%d_label_%s_pred_%s.jpg' % (
                        error_cnt, SHAPE_NAMES[l],
                        SHAPE_NAMES[pred_val[i - start_idx]])
                    img_filename = os.path.join(DUMP_DIR, img_filename)
                    output_img = pc_util.point_cloud_three_views(
                        np.squeeze(current_data[i, :, :]))
                    scipy.misc.imsave(img_filename, output_img)
                    error_cnt += 1

    log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
    log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
    log_string('eval avg class acc: %f' % (np.mean(
        np.array(total_correct_class) /
        np.array(total_seen_class, dtype=np.float))))

    class_accuracies = np.array(total_correct_class) / np.array(
        total_seen_class, dtype=np.float)
    for i, name in enumerate(SHAPE_NAMES):
        log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))
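# The voting scheme above sums raw class scores over num_votes rotated copies
# of each cloud and takes the argmax of the accumulated scores. Distilled to
# its core (the scores here are synthetic stand-ins for ops['pred']):
import numpy as np

num_votes, num_classes, bsize = 12, 40, 4
batch_pred_sum = np.zeros((bsize, num_classes))
for vote_idx in range(num_votes):
    scores = np.random.randn(bsize, num_classes)  # stand-in for one vote's logits
    batch_pred_sum += scores                      # accumulate scores per vote
pred = np.argmax(batch_pred_sum, axis=1)          # decision by summed score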
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(LABELS))
    np.random.shuffle(train_file_idxs)

    current_data = np.empty([len(LABELS), NUM_POINT, 3], dtype=float)
    current_label = np.empty([len(LABELS), 1], dtype=int)
    for fn in range(len(LABELS)):
        cut1, cut2, _ = provider.loadDataFile_cut_2(
            TRAIN_FILES[train_file_idxs[fn]], False)
        # Sample half the points from each cut (integer division so the
        # size argument is an int).
        idx = np.random.randint(cut1.shape[0], size=NUM_POINT // 2)
        cut1 = cut1[idx, :]
        idx = np.random.randint(cut2.shape[0], size=NUM_POINT // 2)
        cut2 = cut2[idx, :]
        current_data[fn] = np.concatenate((cut1, cut2), axis=0)
        current_label[fn] = LABELS[train_file_idxs[fn]]

    current_label = np.squeeze(current_label)

    file_size = current_data.shape[0]
    num_batches = file_size // BATCH_SIZE

    total_correct = 0
    total_seen = 0
    loss_sum = 0

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        rotated_data = provider.rotate_point_cloud(
            current_data[start_idx:end_idx, :, :])
        jittered_data = provider.jitter_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.rotate_perturbation_point_cloud(jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)
        feed_dict = {
            ops['pointclouds_pl']: jittered_data,
            ops['labels_pl']: current_label[start_idx:end_idx],
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val, feature = sess.run(
            [
                ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                ops['pred'], ops['feat']
            ],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val == current_label[start_idx:end_idx])
        total_correct += correct
        total_seen += BATCH_SIZE
        loss_sum += loss_val

        if batch_idx % 100 == 0:
            log_string('mean loss: {0:f} accuracy: {1:f}'.format(
                loss_sum / float(batch_idx + 1),
                total_correct / float(total_seen)))

    log_string('mean loss: {0:f} accuracy: {1:f}'.format(
        loss_sum / float(batch_idx + 1),
        total_correct / float(total_seen)))
def train_one_epoch(sess, ops, writer, epoch):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    log_string(str(datetime.now()))

    # shuffle training files order
    random.shuffle(train_h5_fn_list)

    for item in train_h5_fn_list:
        cur_h5_fn = os.path.join(data_in_dir, item)
        print('Reading data from ', cur_h5_fn)
        pts, gt_label, gt_mask, gt_valid, gt_other_mask, _ = load_data(cur_h5_fn)

        # shuffle data order
        n_shape = pts.shape[0]
        idx = np.arange(n_shape)
        np.random.shuffle(idx)

        pts = pts[idx, ...]
        gt_label = gt_label[idx, ...]
        gt_mask = gt_mask[idx, ...]
        gt_valid = gt_valid[idx, ...]
        gt_other_mask = gt_other_mask[idx, ...]

        # data augmentation to pts
        pts = provider.jitter_point_cloud(pts)
        pts = provider.shift_point_cloud(pts)
        pts = provider.random_scale_point_cloud(pts)
        pts = provider.rotate_perturbation_point_cloud(pts)

        num_batch = n_shape // BATCH_SIZE
        for i in range(num_batch):
            start_idx = i * BATCH_SIZE
            end_idx = (i + 1) * BATCH_SIZE

            cur_pts = pts[start_idx:end_idx, ...]
            cur_gt_label = gt_label[start_idx:end_idx, ...]
            cur_gt_mask = gt_mask[start_idx:end_idx, ...]
            cur_gt_valid = gt_valid[start_idx:end_idx, ...]
            cur_gt_other_mask = gt_other_mask[start_idx:end_idx, ...]

            feed_dict = {ops['pc_pl']: cur_pts,
                         ops['label_pl']: cur_gt_label,
                         ops['gt_mask_pl']: cur_gt_mask,
                         ops['gt_valid_pl']: cur_gt_valid,
                         ops['gt_other_mask_pl']: cur_gt_other_mask,
                         ops['is_training_pl']: is_training}

            summary, step, _, lr_val, bn_decay_val, seg_loss_val, ins_loss_val, \
                other_ins_loss_val, l21_norm_loss_val, conf_loss_val, loss_val, \
                seg_pred_val, mask_pred_val, other_mask_pred_val = sess.run(
                    [ops['merged'], ops['step'], ops['train_op'],
                     ops['learning_rate'], ops['bn_decay'],
                     ops['seg_loss'], ops['ins_loss'], ops['other_ins_loss'],
                     ops['l21_norm_loss'], ops['conf_loss'], ops['loss'],
                     ops['seg_pred'], ops['mask_pred'], ops['other_mask_pred']],
                    feed_dict=feed_dict)

            writer.add_summary(summary, step)

            seg_pred_id = np.argmax(seg_pred_val, axis=-1)
            seg_acc = np.mean(seg_pred_id == cur_gt_label)

            log_string('[Train Epoch %03d, Batch %03d, LR: %f, BN_DECAY: %f] '
                       'Loss: %f = %f x %f (seg_loss, Seg Acc: %f) + %f x %f (ins_loss) '
                       '+ %f x %f (other_ins_loss) + %f x %f (l21_norm_loss) + %f x %f (conf_loss)'
                       % (epoch, i, lr_val, bn_decay_val, loss_val,
                          FLAGS.seg_loss_weight, seg_loss_val, seg_acc,
                          FLAGS.ins_loss_weight, ins_loss_val,
                          FLAGS.other_ins_loss_weight, other_ins_loss_val,
                          FLAGS.l21_norm_loss_weight, l21_norm_loss_val,
                          FLAGS.conf_loss_weight, conf_loss_val))
def eval_one_epoch(sess, ops, test_writer):
    """ ops: dict mapping from string to tf ops """
    global EPOCH_CNT
    global BEST_ACC
    global BEST_CLS_ACC

    is_training = False

    # Make sure batch data is of same size
    cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT, TEST_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    batch_idx = 0
    shape_ious = []
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]

    log_string(str(datetime.now()))
    log_string('---- EPOCH %03d EVALUATION ----' % (EPOCH_CNT))

    while TEST_DATASET.has_next_batch():
        batch_data, batch_label = TEST_DATASET.next_batch(augment=False)
        bsize = batch_data.shape[0]
        # print('Batch: %03d, batch size: %d' % (batch_idx, bsize))
        # for the last batch in the epoch, the bsize:end entries are from the previous batch
        cur_batch_data[0:bsize, ...] = batch_data
        cur_batch_label[0:bsize] = batch_label

        if ROTATE_FLAG:
            batch_pred_sum = np.zeros((BATCH_SIZE, NUM_CLASSES))  # score for classes
            for vote_idx in range(12):
                # Shuffle point order to achieve different farthest samplings
                shuffled_indices = np.arange(NUM_POINT)
                np.random.shuffle(shuffled_indices)
                if NORMAL_FLAG:
                    rotated_data = provider.rotate_point_cloud_by_angle_with_normal(
                        cur_batch_data[:, shuffled_indices, :],
                        vote_idx / float(12) * np.pi * 2)
                    rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
                else:
                    rotated_data = provider.rotate_point_cloud_by_angle(
                        cur_batch_data[:, shuffled_indices, :],
                        vote_idx / float(12) * np.pi * 2)
                    rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
                jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
                jittered_data = provider.jitter_point_cloud(jittered_data)
                rotated_data[:, :, 0:3] = jittered_data
                # else:
                #     rotated_data = provider.rotate_point_cloud_by_angle(
                #         cur_batch_data[:, shuffled_indices, :],
                #         vote_idx / float(12) * np.pi * 2)
                feed_dict = {ops['pointclouds_pl']: rotated_data,
                             ops['labels_pl']: cur_batch_label,
                             ops['is_training_pl']: is_training}
                loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
                                              feed_dict=feed_dict)
                batch_pred_sum += pred_val
            pred_val = np.argmax(batch_pred_sum, 1)
        else:
            feed_dict = {ops['pointclouds_pl']: cur_batch_data,
                         ops['labels_pl']: cur_batch_label,
                         ops['is_training_pl']: is_training}
            summary, step, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            test_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)

        correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
        total_correct += correct
        total_seen += bsize
        loss_sum += loss_val
        batch_idx += 1
        for i in range(bsize):
            l = batch_label[i]
            total_seen_class[l] += 1
            total_correct_class[l] += (pred_val[i] == l)

    current_acc = total_correct / float(total_seen)
    current_cls_acc = np.mean(np.array(total_correct_class) /
                              np.array(total_seen_class, dtype=np.float))
    log_string('eval mean loss: %f' % (loss_sum / float(batch_idx)))
    log_string('eval accuracy: %f' % (current_acc))
    log_string('eval avg class acc: %f' % (current_cls_acc))

    best_acc_flag, best_cls_acc_flag = False, False
    if current_acc > BEST_ACC:
        BEST_ACC = current_acc
        best_acc_flag = True
    if current_cls_acc > BEST_CLS_ACC:
        BEST_CLS_ACC = current_cls_acc
        best_cls_acc_flag = True
    log_string('eval best accuracy: %f' % (BEST_ACC))
    log_string('eval best avg class acc: %f' % (BEST_CLS_ACC))

    EPOCH_CNT += 1
    TEST_DATASET.reset()
    return (best_acc_flag, best_cls_acc_flag)
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    current_data_1 = np.empty([3 * len(TRAIN_FILES), NUM_POINT, 3], dtype=float)
    current_data_2 = np.empty([3 * len(TRAIN_FILES), NUM_POINT, 3], dtype=float)
    current_label = np.empty([3 * len(TRAIN_FILES), 1], dtype=int)

    fn = 0
    count = 0
    while fn < len(TRAIN_FILES) - 1:
        # log_string('----' + str(fn) + '-----')
        total_current = []

        a1, a2, _ = provider.loadDataFile_cut_2(TRAIN_FILES[train_file_idxs[fn]])
        idx = np.random.randint(a1.shape[0], size=NUM_POINT)
        a1 = a1[idx, :]
        idx = np.random.randint(a2.shape[0], size=NUM_POINT)
        a2 = a2[idx, :]
        total_current.append(a1)
        total_current.append(a2)
        fn = fn + 1

        b1, b2, _ = provider.loadDataFile_cut_2(TRAIN_FILES[train_file_idxs[fn]])
        idx = np.random.randint(b1.shape[0], size=NUM_POINT)
        b1 = b1[idx, :]
        idx = np.random.randint(b2.shape[0], size=NUM_POINT)
        b2 = b2[idx, :]
        total_current.append(b1)
        total_current.append(b2)
        fn = fn + 1

        # Build all 6 pairs from the 4 cuts; label 0 for pairs taken from
        # different shapes, 1 for pairs taken from the same shape.
        pair_num = 0
        for index in range(len(total_current)):
            for index2 in range(index + 1, len(total_current)):
                current_data_1[6 * count + pair_num, :, :] = total_current[index]
                current_data_2[6 * count + pair_num, :, :] = total_current[index2]
                if (index < 2) and (index2 >= 2):
                    current_label[6 * count + pair_num, :] = 0
                else:
                    current_label[6 * count + pair_num, :] = 1
                pair_num = pair_num + 1
        count = count + 1

    current_label = np.squeeze(current_label)

    file_size = current_data_1.shape[0]
    num_batches = file_size // BATCH_SIZE

    total_correct = 0
    total_seen = 0
    loss_sum = 0

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE

        # shuffle each batch
        data_1 = current_data_1[start_idx:end_idx, :, :]
        data_2 = current_data_2[start_idx:end_idx, :, :]
        label = current_label[start_idx:end_idx]
        combine_data = np.concatenate((data_1, data_2), axis=2)
        combine_data, label, _ = provider.shuffle_data(combine_data, np.squeeze(label))
        data_1 = combine_data[:, :, 0:3]
        data_2 = combine_data[:, :, 3:6]
        label = np.squeeze(label)

        # Augment batched point clouds by rotation and jittering
        rotated_data_1 = provider.rotate_point_cloud(data_1)
        jittered_data_1 = provider.jitter_point_cloud(rotated_data_1)
        jittered_data_1 = provider.random_scale_point_cloud(jittered_data_1)
        jittered_data_1 = provider.rotate_perturbation_point_cloud(jittered_data_1)
        jittered_data_1 = provider.shift_point_cloud(jittered_data_1)

        rotated_data_2 = provider.rotate_point_cloud(data_2)
        jittered_data_2 = provider.jitter_point_cloud(rotated_data_2)
        jittered_data_2 = provider.random_scale_point_cloud(jittered_data_2)
        jittered_data_2 = provider.rotate_perturbation_point_cloud(jittered_data_2)
        jittered_data_2 = provider.shift_point_cloud(jittered_data_2)

        feed_dict = {
            ops['pointclouds_pl_1']: jittered_data_1,
            ops['pointclouds_pl_2']: jittered_data_2,
            ops['labels_pl']: label,
            ops['is_training_pl']: is_training
        }
        summary, step, _, loss_val, pred_val = sess.run([
            ops['merged'], ops['step'], ops['train_op'], ops['loss'],
            ops['pred']
        ], feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val == label)
        total_correct += correct
        total_seen += BATCH_SIZE
        loss_sum += loss_val

        if batch_idx % 50 == 0:
            log_string('mean loss: {0:f} accuracy: {1:f}'.format(
                loss_sum / float(batch_idx + 1),
                total_correct / float(total_seen)))

    log_string('mean loss: {0:f} accuracy: {1:f}'.format(
        loss_sum / float(batch_idx + 1),
        total_correct / float(total_seen)))
def eval_one_epoch(sess, ops, test_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = False

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    loss_sum_c = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    total_labels = np.array([0, 0, 0], dtype=np.int32)

    gen = provider.data_generator(TEST_FILES, BATCH_SIZE)
    # for fn in range(len(TEST_FILES)):
    for fn in range(125000 // BATCH_SIZE):  # 125000 shards in total
        data_q, label_q, seg_q = next(gen)
        if data_q is None:
            print("an epoch is done in steps: ", fn)
            # print("total_labels:", total_labels)
            break

        # file_size = int(TEST_FILES[fn].split('_')[1])
        file_size = BATCH_SIZE
        num_batches = 1

        current_data = np.array(data_q)
        current_label = np.array(label_q)
        current_seg = np.array(seg_q)

        # Earlier label-remapping logic, kept for reference:
        # if current_label_i == 0 or current_label_i == 4:  # pot #1 or #4
        #     if current_seg_i[0] == 1:
        #         rs_seg0 = 0
        #     elif current_seg_i[3] == 1:
        #         rs_seg0 = 2
        #     else:
        #         rs_seg0 = 1
        # else:
        #     if current_seg_i[0] == 1:
        #         rs_seg0 = 0
        #     elif current_seg_i[2] == 1:
        #         rs_seg0 = 2
        #     else:
        #         rs_seg0 = 1
        # rs_seg = np.eye(3)[rs_seg0]
        # # rs_seg = np.eye(16)[rs_seg0]  # rs_seg.shape=(4,2)
        # # rs_seg = np.eye(8)[rs_seg0]   # rs_seg.shape=(4,2)

        rs_label = np.eye(5)[current_label]  # one-hot over 5 classes
        rs_seg = current_seg

        # test_data = np.concatenate((rs_current_data, rs_pottery), axis=0)
        test_data = current_data
        jittered_data = provider.rotate_perturbation_point_cloud(
            test_data, angle_sigma=np.pi / 2, angle_clip=np.pi)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)

        feed_dict = {ops['pointclouds_pl']: jittered_data,
                     ops['labels_pl']: rs_seg,
                     ops['labels_pl_c']: rs_label,
                     ops['is_training_pl']: is_training}
        summary, step, cls_loss, cls_loss_c, pred_val, pred_c_val = sess.run([
            ops['merged'], ops['step'], ops['cls_loss'], ops['cls_loss_c'],
            ops['pred'], ops['pred_c'],
        ], feed_dict=feed_dict)

        # correct = np.sum(np.argmax(pred_val) == rs_seg0)
        # print(np.argmax(pred_c_val, axis=1), np.argmax(rs_label, axis=1))
        correct = np.sum(np.argmax(pred_c_val, axis=1) ==
                         np.argmax(rs_label, axis=1))
        print('cls_loss, cls_loss_c, cls_acc_c:', cls_loss, cls_loss_c, correct)
        # print('pred, label, loss:', np.argmax(pred_val, axis=1), np.argmax(rs_seg, axis=1), loss_val)
        total_correct += correct / float(BATCH_SIZE)
        total_seen += 1
        # loss_sum += (loss_val * BATCH_SIZE)
        loss_sum += cls_loss
        loss_sum_c += cls_loss_c
        # total_labels[rs_seg0] += 1

    loss_avg = loss_sum / float(total_seen)
    loss_avg_c = loss_sum_c / float(total_seen)
    acc_avg_c = total_correct / float(total_seen)
    log_string('eval mean cls_loss: %f' % loss_avg)
    log_string('eval mean cls_loss_c: %f' % loss_avg_c)
    log_string('eval mean cls_acc_c: %f' % acc_avg_c)

    val_loss_summary.value[0].simple_value = (loss_avg)
    val_loss_c_summary.value[0].simple_value = (loss_avg_c)
    val_acc_c_summary.value[0].simple_value = (acc_avg_c)
    test_writer.add_summary(val_loss_summary, step)
    test_writer.add_summary(val_loss_c_summary, step)
    test_writer.add_summary(val_acc_c_summary, step)
    test_writer.flush()

    # print("total_labels:", total_labels)
    print('label:', rs_seg)
    print('pred:', pred_val)
    print('cls_loss:', cls_loss)
    print('label_c:', current_label)
    print('pred_c:', np.argmax(pred_c_val, axis=1))
    print('cls_loss_c:', cls_loss_c)
    print('cls_acc_c:', correct)
    return loss_avg, loss_avg_c
def train_one_epoch(epoch_num):
    # NOTE: is_training = False: we do not update BN parameters during
    # training due to the small batch size. This requires pre-training
    # PointNet with a large batch size (say 32).
    is_training = False

    allow_full_loss = 0.0
    if epoch_num > 5:
        allow_full_loss = 1.0

    printout(flog, str(datetime.now()))

    random.shuffle(train_h5_fn_list)
    for item in train_h5_fn_list:
        cur_h5_fn = os.path.join(data_in_dir, item)
        printout(flog, 'Reading data from: %s' % cur_h5_fn)
        pts, semseg_one_hot, semseg_mask, insseg_one_hot, insseg_mask = load_data(cur_h5_fn)

        # shuffle data order
        n_shape = pts.shape[0]
        idx = np.arange(n_shape)
        np.random.shuffle(idx)

        pts = pts[idx, ...]
        semseg_one_hot = semseg_one_hot[idx, ...]
        semseg_mask = semseg_mask[idx, ...]
        insseg_one_hot = insseg_one_hot[idx, ...]
        insseg_mask = insseg_mask[idx, ...]

        # data augmentation to pts
        pts = provider.jitter_point_cloud(pts)
        pts = provider.shift_point_cloud(pts)
        pts = provider.random_scale_point_cloud(pts)
        pts = provider.rotate_perturbation_point_cloud(pts)

        total_loss = 0.0
        total_grouperr = 0.0
        total_same = 0.0
        total_diff = 0.0
        total_pos = 0.0
        total_acc = 0.0

        num_batch = n_shape // BATCH_SIZE
        for i in range(num_batch):
            start_idx = i * BATCH_SIZE
            end_idx = (i + 1) * BATCH_SIZE

            feed_dict = {
                pointclouds_ph: pts[start_idx:end_idx, ...],
                ptsseglabel_ph: semseg_one_hot[start_idx:end_idx, ...],
                ptsgroup_label_ph: insseg_one_hot[start_idx:end_idx, ...],
                pts_seglabel_mask_ph: semseg_mask[start_idx:end_idx, ...],
                pts_group_mask_ph: insseg_mask[start_idx:end_idx, ...],
                is_training_ph: is_training,
                alpha_ph: min(10., (float(epoch_num) / 5.) * 2. + 2.),
                allow_full_loss_pl: allow_full_loss
            }

            _, loss_val, lr_val, simmat_val, semseg_logits_val, \
                grouperr_val, same_val, same_cnt_val, diff_val, diff_cnt_val, \
                pos_val, pos_cnt_val = sess.run(
                    [train_op, loss, learning_rate,
                     net_output['simmat'], net_output['semseg_logits'],
                     grouperr, same, same_cnt, diff, diff_cnt, pos, pos_cnt],
                    feed_dict=feed_dict)

            # Segmentation accuracy, counting unlabeled points as correct
            seg_pred = np.argmax(semseg_logits_val, axis=-1)
            gt_mask_val = (np.sum(semseg_one_hot[start_idx:end_idx, ...], axis=-1) > 0)
            seg_gt = np.argmax(semseg_one_hot[start_idx:end_idx, ...], axis=-1)
            correct = ((seg_pred == seg_gt) + (~gt_mask_val) > 0)
            acc_val = np.mean(correct)

            total_acc += acc_val
            total_loss += loss_val
            total_grouperr += grouperr_val
            total_diff += diff_val / max(1e-6, diff_cnt_val)
            total_same += same_val / max(1e-6, same_cnt_val)
            total_pos += pos_val / max(1e-6, pos_cnt_val)

            if i % 10 == 9:
                printout(flog, 'Batch: %d, LR: %f, loss: %f, FullLoss: %f, SegAcc: %f, grouperr: %f, same: %f, diff: %f, pos: %f' %
                         (i, lr_val, total_loss / 10, allow_full_loss, total_acc / 10,
                          total_grouperr / 10, total_same / 10, total_diff / 10, total_pos / 10))

                lr_sum, batch_sum, train_loss_sum, group_err_sum = sess.run(
                    [lr_op, batch, total_train_loss_sum_op, group_err_op],
                    feed_dict={total_training_loss_ph: total_loss / 10.,
                               group_err_loss_ph: total_grouperr / 10.})
                train_writer.add_summary(train_loss_sum, batch_sum)
                train_writer.add_summary(lr_sum, batch_sum)
                train_writer.add_summary(group_err_sum, batch_sum)

                total_grouperr = 0.0
                total_loss = 0.0
                total_diff = 0.0
                total_same = 0.0
                total_pos = 0.0
                same_cnt0 = 0
                total_acc = 0.0