Code Example #1
File: train.py  Project: xiaobai-1-1/aiimooc_lesson
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))  # the order of the train files differs on every epoch
    np.random.shuffle(train_file_idxs)

    for fn in range(len(TRAIN_FILES)):  # 对每一个train文件
        log_string('----train file ' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(
            TRAIN_FILES[train_file_idxs[fn]])
        current_data = current_data[:, 0:NUM_POINT, :]  # keep the first NUM_POINT points of every cloud in this file
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))  # shuffle the order of the point clouds
        current_label = np.squeeze(current_label)  # drop singleton dimensions

        file_size = current_data.shape[0]  # total number of point clouds in this file
        num_batches = file_size // BATCH_SIZE  # number of full batches

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            jittered_data = provider.random_scale_point_cloud(jittered_data)
            jittered_data = provider.rotate_perturbation_point_cloud(
                jittered_data)
            jittered_data = provider.shift_point_cloud(jittered_data)

            feed_dict = {
                ops['pointclouds_pl']: jittered_data,
                ops['labels_pl']: current_label[start_idx:end_idx],
                ops['is_training_pl']: is_training,
            }  # every feed_dict key must be a placeholder
            summary, step, _, loss_val, pred_val = sess.run(
                [
                    ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                    ops['pred']
                ],
                feed_dict=feed_dict)
            print("--train file:{}, batch_idx:{},step:{}".format(
                str(fn), str(batch_idx), str(step)))
            train_writer.add_summary(summary, step)  # only training logs summary curves
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
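
Every example on this page builds its augmentation pipeline from the PointNet `provider` module. For readers without that repository at hand, here is a minimal NumPy sketch of the three batched helpers Example #1 chains together. The function names follow the public PointNet code; the default parameter values are assumptions carried over from that codebase and may differ in the individual forks shown here.

import numpy as np

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Add clipped Gaussian noise to every point. batch_data: BxNx3. """
    B, N, C = batch_data.shape
    assert clip > 0
    noise = np.clip(sigma * np.random.randn(B, N, C), -clip, clip)
    return batch_data + noise

def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """ Multiply each cloud in the batch by its own random scale factor. """
    scales = np.random.uniform(scale_low, scale_high, batch_data.shape[0])
    return batch_data * scales[:, None, None]

def shift_point_cloud(batch_data, shift_range=0.1):
    """ Translate each cloud by its own random per-axis offset. """
    shifts = np.random.uniform(-shift_range, shift_range, (batch_data.shape[0], 3))
    return batch_data + shifts[:, None, :]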
Code Example #2
def train_one_epoch(sess, ops, train_writer, dataset, verbose=True):
  """
  Train model for one epoch
  """
  global EPOCH_CNT
  is_training = True

  # Shuffle train samples
  train_idxs = np.arange(0, len(dataset))
  np.random.shuffle(train_idxs)

  num_batches = len(dataset) // FLAGS['BATCH_SIZE']  # integer division; discards samples if dataset size is not divisible by batch size

  log_string('[' + str(datetime.now()) + ' | EPOCH ' + str(EPOCH_CNT) + '] Starting training.', printout=False)

  loss_sum, batch_print_steps = 0, 10
  for batch_idx in range(num_batches):
    start_idx, end_idx = batch_idx * FLAGS['BATCH_SIZE'], (batch_idx + 1) * FLAGS['BATCH_SIZE']
    batch_data, batch_label = get_batch(dataset, train_idxs, start_idx, end_idx)
    # Perturb point clouds:
    batch_data[:,:,:3] = provider.jitter_point_cloud(batch_data[:,:,:3])
    batch_data[:,:,:3] = provider.rotate_perturbation_point_cloud(batch_data[:,:,:3])
    batch_data[:,:,:3] = provider.shift_point_cloud(batch_data[:,:,:3])
    batch_data[:,:,:3] = provider.random_point_dropout(batch_data[:,:,:3],
                                                       max_dropout_ratio=FLAGS['MAX_POINT_DROPOUT_RATIO'])
    feed_dict = {ops['pointclouds_pl']: batch_data,
                 ops['labels_pl']: batch_label,
                 ops['is_training_pl']: is_training}
    summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'],
                                                     ops['loss'], ops['pred']], feed_dict=feed_dict)
    train_writer.add_summary(summary, step)
    loss_sum += loss_val
    if (batch_idx + 1) % batch_print_steps == 0:  # print after a full window, so loss_sum really holds batch_print_steps losses
      log_string('[Batch %03d] Mean Loss: %f' % ((batch_idx + 1), (loss_sum / batch_print_steps)), printout=verbose)
      loss_sum = 0
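
Example #2 additionally calls `provider.random_point_dropout`, the same idea Example #3 keeps only as commented-out inline code. A minimal sketch, assuming the usual PointNet++ behaviour in which dropped points are collapsed onto the first point (so the tensor shape stays fixed) rather than removed:

import numpy as np

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """ batch_pc: BxNx3. Overwrite a random subset of points with point 0. """
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # in [0, max_dropout_ratio)
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]  # duplicate point 0; shape unchanged
    return batch_pc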
Code Example #3
def get_batch_wdp(dataset, batch_idx):
    bsize = BATCH_SIZE
    batch_data = np.zeros((bsize, NUM_POINT, 3))
    batch_feats = np.zeros((bsize, NUM_POINT, 2))
    batch_label = np.zeros((bsize, NUM_POINT), dtype=np.int32)
    batch_smpw = np.zeros((bsize, NUM_POINT), dtype=np.float32)
    
    batch_ctx = np.zeros((bsize, NUM_CLASSES), dtype=np.float32)
        
    for i in range(bsize):
        ps, seg, smpw, feat = get_batch(dataset, batch_idx)
        if np.random.random() >= 0.65:
            ps = provider.rotate_perturbation_point_cloud(ps.reshape(1, *ps.shape), angle_sigma=0.01, angle_clip=0.01)
            ps = ps.squeeze()

        batch_data[i,...] = ps
        batch_label[i,:] = seg
        batch_smpw[i,:] = smpw
        batch_feats[i,:] = feat
        
        # Optional point dropout (disabled): collapse a random subset of
        # points onto point 0 and zero their sample weights.
        # dropout_ratio = np.random.random() * 0.875  # 0-0.875
        # drop_idx = np.where(np.random.random((ps.shape[0])) <= dropout_ratio)[0]
        # batch_data[i, drop_idx, :] = batch_data[i, 0, :]
        # batch_label[i, drop_idx] = batch_label[i, 0]
        # batch_smpw[i, drop_idx] *= 0
        # mask = np.ones(len(ps), dtype=np.int32)
        # mask[drop_idx] = 0

        inds, _ = np.histogram(seg, range(NUM_CLASSES + 1))
        batch_ctx[i] = np.array(inds > 0, dtype=np.int32)
        
    return batch_data, batch_label, batch_smpw, batch_feats, batch_ctx 
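
Example #3 passes explicit `angle_sigma`/`angle_clip` values to `rotate_perturbation_point_cloud`, and Example #13 later widens them to `np.pi/2`/`np.pi` to get nearly arbitrary rotations. A sketch under the standard PointNet formulation (small random Euler angles per cloud, composed as Rz·Ry·Rx); the defaults are assumptions taken from the PointNet codebase:

import numpy as np

def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ batch_data: BxNx3. Apply a small random rotation to each cloud. """
    rotated = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        a = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(a[0]), -np.sin(a[0])],
                       [0, np.sin(a[0]), np.cos(a[0])]])
        Ry = np.array([[np.cos(a[1]), 0, np.sin(a[1])],
                       [0, 1, 0],
                       [-np.sin(a[1]), 0, np.cos(a[1])]])
        Rz = np.array([[np.cos(a[2]), -np.sin(a[2]), 0],
                       [np.sin(a[2]), np.cos(a[2]), 0],
                       [0, 0, 1]])
        R = Rz @ Ry @ Rx
        rotated[k] = batch_data[k] @ R
    return rotated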
Code Example #4
def augment_batch_data(batch_data):
    rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.rotate_perturbation_point_cloud(jittered_data)
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return rotated_data
Code Example #5
 def _augment_batch_data(self, batch_data):
     rotated_data = provider.rotate_point_cloud(batch_data)
     rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
     jittered_data = provider.random_scale_point_cloud(rotated_data[:,:,0:3])
     jittered_data = provider.shift_point_cloud(jittered_data)
     jittered_data = provider.jitter_point_cloud(jittered_data)
     rotated_data[:,:,0:3] = jittered_data
     return provider.shuffle_points(rotated_data)
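
Examples #5, #7, #8, #9, #11, and #12 all finish by calling `provider.shuffle_points`. A minimal sketch, assuming the standard PointNet++ implementation, which draws one permutation and applies it to the whole batch (safe here because labels are per cloud, not per point):

import numpy as np

def shuffle_points(batch_data):
    """ batch_data: BxNxC. Randomize the point order within each cloud. """
    idx = np.arange(batch_data.shape[1])
    np.random.shuffle(idx)
    return batch_data[:, idx, :]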
Code Example #6
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(
            TRAIN_FILES[train_file_idxs[fn]])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            jittered_data = provider.random_scale_point_cloud(jittered_data)
            jittered_data = provider.rotate_perturbation_point_cloud(
                jittered_data)
            jittered_data = provider.shift_point_cloud(jittered_data)

            feed_dict = {
                ops['pointclouds_pl']: jittered_data,
                ops['labels_pl']: current_label[start_idx:end_idx],
                ops['is_training_pl']: is_training,
            }
            summary, step, _, loss_val, pred_val = sess.run(
                [
                    ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                    ops['pred']
                ],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
Code Example #7
 def _augment_batch_data(self, batch_data):
     if self.normal_channel:
         rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
         rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
     else:
         rotated_data = provider.rotate_point_cloud(batch_data)
         rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
     jittered_data = provider.random_scale_point_cloud(rotated_data[:,:,0:3])
     jittered_data = provider.shift_point_cloud(jittered_data)
     jittered_data = provider.jitter_point_cloud(jittered_data)
     rotated_data[:,:,0:3] = jittered_data
     return provider.shuffle_points(rotated_data)
Code Example #8
File: modelnet_dataset.py  Project: joosm/pointnet2
 def _augment_batch_data(self, batch_data):
     if self.normal_channel:
         rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
         rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
     else:
         rotated_data = provider.rotate_point_cloud(batch_data)
         rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
 
     jittered_data = provider.random_scale_point_cloud(rotated_data[:,:,0:3])
     jittered_data = provider.shift_point_cloud(jittered_data)
     jittered_data = provider.jitter_point_cloud(jittered_data)
     rotated_data[:,:,0:3] = jittered_data
     return provider.shuffle_points(rotated_data)
Code Example #9
def augment_batch_data_MODELNET(batch_data, is_include_normal):
    '''
    is_include_normal=False: xyz
    is_include_normal=True: xyznxnynz
    '''
    if is_include_normal:
        rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
            rotated_data)
    else:
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

    jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
    jittered_data = provider.shift_point_cloud(jittered_data)
    jittered_data = provider.jitter_point_cloud(jittered_data)
    rotated_data[:, :, 0:3] = jittered_data
    return provider.shuffle_points(rotated_data)
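
Examples #7-#9 branch on a normal channel because the plain helpers only rotate columns 0:3. With Nx6 inputs (xyznxnynz), the normals must be transformed by the same rotation matrix as the coordinates or they stop being valid surface normals. A sketch of `rotate_point_cloud_with_normal`, assuming the PointNet++ convention of a uniform random rotation about the up (y) axis:

import numpy as np

def rotate_point_cloud_with_normal(batch_xyz_normal):
    """ batch_xyz_normal: BxNx6; columns 0:3 are xyz, columns 3:6 are normals. """
    for k in range(batch_xyz_normal.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        R = np.array([[c, 0, s],
                      [0, 1, 0],
                      [-s, 0, c]])  # rotation about the up (y) axis
        batch_xyz_normal[k, :, 0:3] = batch_xyz_normal[k, :, 0:3] @ R
        batch_xyz_normal[k, :, 3:6] = batch_xyz_normal[k, :, 3:6] @ R  # rotate normals by the same matrix
    return batch_xyz_normal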
Code Example #10
 def get_example(self, i):
     """Return i-th data"""
     if self.augment:
         rotated_data = provider.rotate_point_cloud(self.data[i:i + 1, :, :])
         jittered_data = provider.jitter_point_cloud(rotated_data)
         jittered_data = provider.random_scale_point_cloud(jittered_data)
         jittered_data = provider.rotate_perturbation_point_cloud(
             jittered_data)
         jittered_data = provider.shift_point_cloud(jittered_data)
         point_data = jittered_data[0]
     else:
         point_data = self.data[i]
     # point_data (2048, 3): (num_point, k) --> convert to (k, num_point, 1)
     point_data = np.transpose(point_data.astype(np.float32), (1, 0))[:, :, None]
     assert point_data.dtype == np.float32
     assert self.label[i].dtype == np.int32
     return point_data, self.label[i]
Code Example #11
 def __data_generation(self, batch_idx):
     x = np.zeros((self.batch_size, self.npoints, 3))
     y = np.zeros((self.batch_size, ))
     for i, idx in enumerate(batch_idx, 0):
         x[i] = self.datas[idx, 0:self.npoints, :]  # take the first npoints points. TODO: random choice
         y[i] = self.labels[idx]
     if self.augment and np.random.rand() > 0.5:
         # implement data augmentation to the whole BATCH
         rotated_x = provider.rotate_point_cloud(x)  # rotate around the up axis
         rotated_x = provider.rotate_perturbation_point_cloud(
             rotated_x)  # slightly rotate around every axis
         jittered_x = provider.random_scale_point_cloud(
             rotated_x)  # random scale a little bit
         jittered_x = provider.shift_point_cloud(
             jittered_x)  # shift a little
         jittered_x = provider.jitter_point_cloud(
             jittered_x)  # add random noise (jitter)
         jittered_x = provider.shuffle_points(
             jittered_x)  # shuffle point order (for FPS)
         x = jittered_x
     return x, keras.utils.to_categorical(y, num_classes=len(self.cat))
Code Example #12
File: fold_dataset.py  Project: tryptofanik/bios2net
    def _augment_batch_data(self, batch_data):
        if self.normal_channel:
            rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
                rotated_data)
        else:
            rotated_data = provider.rotate_point_cloud(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud(
                rotated_data)

        jittered_data = provider.random_scale_point_cloud(
            rotated_data[:, :, 0:3],
            scale_low=self.scale_low,
            scale_high=self.scale_high)
        jittered_data = provider.shift_point_cloud(
            jittered_data, shift_range=self.shift_range)
        jittered_data = provider.jitter_point_cloud(jittered_data,
                                                    sigma=self.jitter_sigma,
                                                    clip=0.1)
        rotated_data[:, :, 0:3] = jittered_data
        if self.shuffle_points:
            return provider.shuffle_points(rotated_data)
        else:
            return rotated_data
Code Example #13
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    #train_file_idxs = np.arange(0, len(TRAIN_FILES))
    # np.random.shuffle(train_file_idxs)
    train_files = TRAIN_FILES
    # print(len(train_files))
    random.shuffle(train_files)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_labels = np.array([0, 0, 0], dtype=np.int32)

    gen = provider.data_generator(train_files, BATCH_SIZE)

    # for fn in range(len(TRAIN_FILES)):
    for fn in range(125000//BATCH_SIZE):  # total 125000 shards
        # log_string('----' + str(fn) + '-----')

        data_q, label_q, seg_q = next(gen)
        if data_q is None:
            print("an epoch is done in steps: ", fn)
            #print("total_labels:", total_labels)
            break
        #print(data_q[0].shape, label_q[0], seg_q[0].shape)
#        current_data, current_label, current_seg = provider.loadsegDataFile(train_files[fn])
#        current_data = current_data[:, 0:NUM_POINT, :]
#         current_data, current_label, current_seg_ = provider.shuffle_data(current_data, np.squeeze(current_label))
#        current_label = np.squeeze(current_label)
#        current_seg = np.squeeze(current_seg)

#        file_size = int(TRAIN_FILES[train_file_idxs[fn]].split('_')[1])
        file_size = BATCH_SIZE
        num_batches = 1

        current_data = np.array(data_q)
        current_label = np.array(label_q)
        current_seg = np.array(seg_q)

        rs_label = np.eye(5)[current_label]
        #print(current_label.shape, rs_label.shape, current_label, rs_label)
        # sys.exit()
        rs_seg = current_seg

        #train_data=np.concatenate((rs_current_data, rs_pottery),axis=0)
        train_data = current_data

        # Augment batched point clouds by rotation and jittering
        jittered_data = provider.rotate_perturbation_point_cloud(
            train_data, angle_sigma=np.pi/2, angle_clip=np.pi)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)
        # print(jittered_data.shape)
        # print(rs_seg.shape)

        feed_dict = {ops['pointclouds_pl']: jittered_data,
                     ops['labels_pl']: rs_seg,
                     ops['labels_pl_c']: rs_label,
                     ops['is_training_pl']: is_training
                     }

        summary, step, _, loss_val, pred_val, pred_c_val = sess.run([
            ops['merged'],
            ops['step'],
            ops['train_op'],
            ops['loss'],
            # ops['cls_loss'],
            # ops['reg_loss'],
            ops['pred'],
            ops['pred_c'],
        ], feed_dict=feed_dict)

        train_writer.add_summary(summary, step)

        #correct = np.sum(np.argmax(pred_val) == rs_seg0)
        #print('pred, label, loss:', np.argmax(pred_val), rs_seg0, loss_val)
        #correct = np.prod(np.argmax(pred_val, axis=1) == np.argmax(rs_seg, axis=1))
        print('loss:', loss_val)
        #total_correct += correct
        total_seen += 1
        loss_sum += loss_val
        #total_labels[rs_seg0] += 1

        if fn % 100 == 99:
            # COMBINE FEATURES
            log_string('mean loss: %f' % (loss_sum / float(total_seen)))
            print('label:', rs_seg)
            print('pred:', pred_val)
            print('label_c:', current_label)
            print('pred_c:', np.argmax(pred_c_val, axis=1))
            print('loss_val:', loss_val)
            correct = np.sum(np.argmax(pred_c_val, axis=1)
                             == np.argmax(rs_label, axis=1))
            print('cls_acc_c:', correct)
Code Example #14
        def train_one_epoch(sess, train_writer):
            """ Train one epoch using the model's own placeholders and ops """
            is_training = True

            # Shuffle train files
            train_file_idxs = np.arange(0, len(TRAIN_FILES))
            np.random.shuffle(train_file_idxs)

            mean_acc = 0
            for fn in range(len(TRAIN_FILES)):
                print('----' + str(fn) + '-----')
                current_data, current_label = provider.loadDataFile(
                    TRAIN_FILES[train_file_idxs[fn]])
                current_data = current_data[:, 0:self.num_points, :]
                current_data, current_label, _ = provider.shuffle_data(
                    current_data, np.squeeze(current_label))
                current_label = np.squeeze(current_label)

                file_size = current_data.shape[0]
                num_batches = file_size // self.batch_size

                total_correct = 0
                total_seen = 0
                loss_sum = 0

                for batch_idx in range(num_batches):
                    start_idx = batch_idx * self.batch_size
                    end_idx = (batch_idx + 1) * self.batch_size

                    # Augment batched point clouds by rotation and jittering
                    rotated_data = provider.rotate_point_cloud(
                        current_data[start_idx:end_idx, :, :])
                    jittered_data = provider.jitter_point_cloud(rotated_data)
                    jittered_data = provider.random_scale_point_cloud(
                        jittered_data)
                    jittered_data = provider.rotate_perturbation_point_cloud(
                        jittered_data)
                    jittered_data = provider.shift_point_cloud(jittered_data)

                    soft_targets = current_label[start_idx:end_idx]
                    n_values = np.zeros((soft_targets.size, self.num_classes))
                    n_values[np.arange(soft_targets.size), soft_targets] = 1
                    soft_targets = n_values
                    if teacher_flag:
                        soft_targets = teacher_model.predict(
                            current_data[start_idx:end_idx, :, :])

                    # print(soft_targets, current_label[start_idx:end_idx])

                    if not self.quantize_delay:
                        feed_dict = {
                            # ops['pointclouds_pl']: jittered_data,
                            self.pointclouds_pl: current_data[start_idx:end_idx, :, :],
                            self.labels_pl: current_label[start_idx:end_idx],
                            self.is_training: is_training,
                            self.soft_Y: soft_targets,
                            self.flag: teacher_flag,
                            self.softmax_temperature: self.temperature
                        }
                    else:
                        feed_dict = {
                            # ops['pointclouds_pl']: jittered_data,
                            self.pointclouds_pl: current_data[start_idx:end_idx, :, :],
                            self.labels_pl: current_label[start_idx:end_idx],
                            self.soft_Y: soft_targets,
                            self.flag: teacher_flag,
                            self.softmax_temperature: self.temperature
                            # ops['is_training_pl']: is_training,
                        }

                    # print(feed_dict)
                    summary, step, _, loss_val, pred_val = sess.run(
                        [
                            self.merged_summary_op, self.batch, self.train_op,
                            self.total_loss, self.prediction
                        ],
                        feed_dict=feed_dict)
                    # sess.run(ops['mask_update_op'])
                    train_writer.add_summary(summary, step)
                    pred_val = np.argmax(pred_val, 1)
                    correct = np.sum(
                        pred_val == current_label[start_idx:end_idx])
                    total_correct += correct
                    total_seen += self.batch_size
                    loss_sum += loss_val

                print('mean loss: %f' % (loss_sum / float(num_batches)))
                print('accuracy: %f' % (total_correct / float(total_seen)))

                mean_acc += total_correct / float(total_seen)
            return mean_acc / len(TRAIN_FILES)
Code Example #15
File: evaluate_hdf5.py  Project: zuochen33/pointnet2
def eval_one_epoch(sess, ops, num_votes=1, topk=1):
    error_cnt = 0
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w')
    np.random.seed(101)
    for fn in range(len(TEST_FILES)):
        log_string('----' + str(fn) + '----')
        current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
        if FLAGS.random_pc_order:
            current_data = change_point_cloud_order(current_data)
        current_data = current_data[:, 0:NUM_POINT, :]
        current_label = np.squeeze(current_label)
        print(current_data.shape)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        print(file_size)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE
            cur_batch_size = end_idx - start_idx

            # Aggregating BEG
            batch_loss_sum = 0  # sum of losses for the batch
            batch_pred_sum = np.zeros(
                (cur_batch_size, NUM_CLASSES))  # score for classes
            batch_pred_max = np.ones(
                (cur_batch_size, NUM_CLASSES)) * (-999999)  # score for classes
            batch_pred_classes = np.zeros(
                (cur_batch_size, NUM_CLASSES))  # 0/1 for classes
            for vote_idx in range(num_votes):
                # Shuffle point order to achieve different farthest samplings
                shuffled_indices = np.arange(NUM_POINT)
                np.random.shuffle(shuffled_indices)
                rotated_data = provider.rotate_point_cloud_by_angle(
                    current_data[start_idx:end_idx, shuffled_indices, :],
                    vote_idx / float(num_votes) * np.pi * 2)
                jittered_data = provider.random_scale_point_cloud(rotated_data)
                jittered_data = provider.rotate_perturbation_point_cloud(
                    jittered_data)
                jittered_data = provider.jitter_point_cloud(jittered_data)
                feed_dict = {
                    # feed the fully augmented batch (the original fed rotated_data, leaving the scale/perturb/jitter steps unused)
                    ops['pointclouds_pl']: jittered_data,
                    ops['labels_pl']: current_label[start_idx:end_idx],
                    ops['is_training_pl']: is_training
                }
                loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
                                              feed_dict=feed_dict)
                batch_pred_sum += pred_val
                batch_pred_val = np.argmax(pred_val, 1)
                for el_idx in range(cur_batch_size):
                    batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1
                batch_loss_sum += (loss_val * cur_batch_size /
                                   float(num_votes))
            # pred_val_topk = np.argsort(batch_pred_sum, axis=-1)[:,-1*np.array(range(topk))-1]
            # pred_val = np.argmax(batch_pred_classes, 1)
            pred_val = np.argmax(batch_pred_sum, 1)
            # Aggregating END

            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            # correct = np.sum(pred_val_topk[:,0:topk] == label_val)
            total_correct += correct
            total_seen += cur_batch_size
            loss_sum += batch_loss_sum

            for i in range(start_idx, end_idx):
                l = current_label[i]
                total_seen_class[l] += 1
                total_correct_class[l] += (pred_val[i - start_idx] == l)
                fout.write('%d, %d\n' % (pred_val[i - start_idx], l))

                if pred_val[i - start_idx] != l and FLAGS.visu:  # ERROR CASE, DUMP!
                    img_filename = '%d_label_%s_pred_%s.jpg' % (
                        error_cnt, SHAPE_NAMES[l],
                        SHAPE_NAMES[pred_val[i - start_idx]])
                    img_filename = os.path.join(DUMP_DIR, img_filename)
                    output_img = pc_util.point_cloud_three_views(
                        np.squeeze(current_data[i, :, :]))
                    scipy.misc.imsave(img_filename, output_img)  # note: removed in SciPy >= 1.2; imageio.imwrite is the usual replacement
                    error_cnt += 1

    log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
    log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
    log_string('eval avg class acc: %f' % (np.mean(
        np.array(total_correct_class) /
        np.array(total_seen_class, dtype=float))))  # builtin float: np.float is removed in NumPy >= 1.24

    class_accuracies = np.array(total_correct_class) / np.array(
        total_seen_class, dtype=float)
    for i, name in enumerate(SHAPE_NAMES):
        log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))
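
The voting loop in Example #15 (and the 12-view loop in Example #18) relies on `rotate_point_cloud_by_angle` to present each shape at evenly spaced headings before the per-class scores are summed. A sketch with the deterministic angle argument, again assuming rotation about the up (y) axis as in the PointNet repo:

import numpy as np

def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    """ batch_data: BxNx3. Rotate every cloud by the same fixed angle. """
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    R = np.array([[c, 0, s],
                  [0, 1, 0],
                  [-s, 0, c]])  # rotation about the up (y) axis
    flat = batch_data.reshape(-1, 3) @ R
    return flat.reshape(batch_data.shape).astype(np.float32)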
Code Example #16
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(LABELS))
    np.random.shuffle(train_file_idxs)

    current_data = np.empty([len(LABELS), NUM_POINT, 3], dtype=float)
    current_label = np.empty([len(LABELS), 1], dtype=int)

    for fn in range(len(LABELS)):
        cut1, cut2, _ = provider.loadDataFile_cut_2(
            TRAIN_FILES[train_file_idxs[fn]], False)
        idx = np.random.randint(cut1.shape[0], size=NUM_POINT // 2)  # //: size must be an int in Python 3
        cut1 = cut1[idx, :]
        idx = np.random.randint(cut2.shape[0], size=NUM_POINT // 2)
        cut2 = cut2[idx, :]
        current_data[fn] = np.concatenate((cut1, cut2), axis=0)
        current_label[fn] = LABELS[train_file_idxs[fn]]

    current_label = np.squeeze(current_label)

    file_size = current_data.shape[0]
    num_batches = file_size // BATCH_SIZE

    total_correct = 0
    total_seen = 0
    loss_sum = 0

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE

        # Augment batched point clouds by rotation and jittering
        rotated_data = provider.rotate_point_cloud(
            current_data[start_idx:end_idx, :, :])
        jittered_data = provider.jitter_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.rotate_perturbation_point_cloud(jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)

        feed_dict = {
            ops['pointclouds_pl']: jittered_data,
            ops['labels_pl']: current_label[start_idx:end_idx],
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val, feature = sess.run(
            [
                ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                ops['pred'], ops['feat']
            ],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val == current_label[start_idx:end_idx])
        total_correct += correct
        total_seen += BATCH_SIZE
        loss_sum += loss_val

        if batch_idx % 100 == 0:
            log_string('mean loss: {0:f}     accuracy: {1:f}'.format(
                loss_sum / float(batch_idx + 1),
                total_correct / float(total_seen)))

    log_string('mean loss: {0:f}     accuracy: {1:f}'.format(
        loss_sum / float(batch_idx + 1), total_correct / float(total_seen)))
Code Example #17
File: train.py  Project: xvjiarui/partnet_seg_exps
def train_one_epoch(sess, ops, writer, epoch):
    """ ops: dict mapping from string to tf ops """
    is_training = True
    
    log_string(str(datetime.now()))

    # shuffle training files order
    random.shuffle(train_h5_fn_list)

    for item in train_h5_fn_list:
        cur_h5_fn = os.path.join(data_in_dir, item)
        print('Reading data from ', cur_h5_fn)
        pts, gt_label, gt_mask, gt_valid, gt_other_mask, _ = load_data(cur_h5_fn)
        
        # shuffle data order
        n_shape = pts.shape[0]
        idx = np.arange(n_shape)
        np.random.shuffle(idx)
        
        pts = pts[idx, ...]
        gt_label = gt_label[idx, ...]
        gt_mask = gt_mask[idx, ...]
        gt_valid = gt_valid[idx, ...]
        gt_other_mask = gt_other_mask[idx, ...]

        # data augmentation to pts
        pts = provider.jitter_point_cloud(pts)
        pts = provider.shift_point_cloud(pts)
        pts = provider.random_scale_point_cloud(pts)
        pts = provider.rotate_perturbation_point_cloud(pts)

        num_batch = n_shape // BATCH_SIZE
        for i in range(num_batch):
            start_idx = i * BATCH_SIZE
            end_idx = (i + 1) * BATCH_SIZE

            cur_pts = pts[start_idx: end_idx, ...]
            cur_gt_label = gt_label[start_idx: end_idx, ...]
            cur_gt_mask = gt_mask[start_idx: end_idx, ...]
            cur_gt_valid = gt_valid[start_idx: end_idx, ...]
            cur_gt_other_mask = gt_other_mask[start_idx: end_idx, ...]

            feed_dict = {ops['pc_pl']: cur_pts,
                         ops['label_pl']: cur_gt_label,
                         ops['gt_mask_pl']: cur_gt_mask,
                         ops['gt_valid_pl']: cur_gt_valid,
                         ops['gt_other_mask_pl']: cur_gt_other_mask,
                         ops['is_training_pl']: is_training}

            summary, step, _, lr_val, bn_decay_val, seg_loss_val, ins_loss_val, \
                other_ins_loss_val, l21_norm_loss_val, conf_loss_val, loss_val, \
                seg_pred_val, mask_pred_val, other_mask_pred_val = sess.run(
                    [ops['merged'], ops['step'], ops['train_op'],
                     ops['learning_rate'], ops['bn_decay'], ops['seg_loss'],
                     ops['ins_loss'], ops['other_ins_loss'], ops['l21_norm_loss'],
                     ops['conf_loss'], ops['loss'], ops['seg_pred'],
                     ops['mask_pred'], ops['other_mask_pred']], feed_dict=feed_dict)

            writer.add_summary(summary, step)

            seg_pred_id = np.argmax(seg_pred_val, axis=-1)
            seg_acc = np.mean(seg_pred_id == cur_gt_label)

            log_string('[Train Epoch %03d, Batch %03d, LR: %f, BN_DECAY: %f] Loss: %f = %f x %f (seg_loss, Seg Acc: %f) + %f x %f (ins_loss) + %f x %f (other_ins_loss) + %f x %f (l21_norm_loss) + %f x %f (conf_loss)' \
                    % (epoch, i, lr_val, bn_decay_val, loss_val, FLAGS.seg_loss_weight, seg_loss_val, seg_acc, \
                    FLAGS.ins_loss_weight, ins_loss_val, FLAGS.other_ins_loss_weight, other_ins_loss_val, \
                    FLAGS.l21_norm_loss_weight, l21_norm_loss_val, FLAGS.conf_loss_weight, conf_loss_val))
Code Example #18
def eval_one_epoch(sess, ops, test_writer):
	""" ops: dict mapping from string to tf ops """
	global EPOCH_CNT
	global BEST_ACC
	global BEST_CLS_ACC

	is_training = False

	# Make sure batch data is of same size
	cur_batch_data = np.zeros((BATCH_SIZE,NUM_POINT,TEST_DATASET.num_channel()))
	cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

	total_correct = 0
	total_seen = 0
	loss_sum = 0
	batch_idx = 0
	shape_ious = []
	total_seen_class = [0 for _ in range(NUM_CLASSES)]
	total_correct_class = [0 for _ in range(NUM_CLASSES)]
	
	log_string(str(datetime.now()))
	log_string('---- EPOCH %03d EVALUATION ----'%(EPOCH_CNT))

	while TEST_DATASET.has_next_batch():
		batch_data, batch_label = TEST_DATASET.next_batch(augment=False)
		bsize = batch_data.shape[0]
		# print('Batch: %03d, batch size: %d'%(batch_idx, bsize))
		# for the last batch in the epoch, entries [bsize:] still hold data from the previous batch
		cur_batch_data[0:bsize,...] = batch_data
		cur_batch_label[0:bsize] = batch_label

		if ROTATE_FLAG:
			batch_pred_sum = np.zeros((BATCH_SIZE, NUM_CLASSES)) # score for classes
			for vote_idx in range(12):
				# Shuffle point order to achieve different farthest samplings
				shuffled_indices = np.arange(NUM_POINT)
				np.random.shuffle(shuffled_indices)
				if NORMAL_FLAG:
					rotated_data = provider.rotate_point_cloud_by_angle_with_normal(cur_batch_data[:, shuffled_indices, :],
						vote_idx/float(12) * np.pi * 2)
					rotated_data = provider.rotate_perturbation_point_cloud_with_normal(rotated_data)
				else:
					rotated_data = provider.rotate_point_cloud_by_angle(cur_batch_data[:, shuffled_indices, :],
						vote_idx / float(12) * np.pi * 2)
					rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

				jittered_data = provider.random_scale_point_cloud(rotated_data[:,:,0:3])
				
				jittered_data = provider.jitter_point_cloud(jittered_data)
				rotated_data[:,:,0:3] = jittered_data
				feed_dict = {ops['pointclouds_pl']: rotated_data,
							 ops['labels_pl']: cur_batch_label,
							 ops['is_training_pl']: is_training}
				loss_val, pred_val = sess.run([ops['loss'], ops['pred']], feed_dict=feed_dict)
				batch_pred_sum += pred_val
			pred_val = np.argmax(batch_pred_sum, 1)

		else:
			feed_dict = {ops['pointclouds_pl']: cur_batch_data,
					 ops['labels_pl']: cur_batch_label,
					 ops['is_training_pl']: is_training}
			summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
				ops['loss'], ops['pred']], feed_dict=feed_dict)
			test_writer.add_summary(summary, step)
			pred_val = np.argmax(pred_val, 1)

		correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
		total_correct += correct
		total_seen += bsize
		loss_sum += loss_val
		batch_idx += 1
		for i in range(bsize):
			l = batch_label[i]
			total_seen_class[l] += 1
			total_correct_class[l] += (pred_val[i] == l)
	
	current_acc = total_correct / float(total_seen)
	current_cls_acc = np.mean(np.array(total_correct_class) / np.array(total_seen_class, dtype=float))

	log_string('eval mean loss: %f' % (loss_sum / float(batch_idx)))
	log_string('eval accuracy: %f'% (current_acc))
	log_string('eval avg class acc: %f' % (current_cls_acc))

	best_acc_flag, best_cls_acc_flag = False, False
	if current_acc > BEST_ACC:
		BEST_ACC = current_acc
		best_acc_flag = True

	if current_cls_acc > BEST_CLS_ACC:
		BEST_CLS_ACC = current_cls_acc
		best_cls_acc_flag = True

	log_string('eval best accuracy: %f'% (BEST_ACC))
	log_string('eval best avg class acc: %f'% (BEST_CLS_ACC))

	EPOCH_CNT += 1

	TEST_DATASET.reset()
	return (best_acc_flag, best_cls_acc_flag)
Code Example #19
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    current_data_1 = np.empty([3 * len(TRAIN_FILES), NUM_POINT, 3],
                              dtype=float)
    current_data_2 = np.empty([3 * len(TRAIN_FILES), NUM_POINT, 3],
                              dtype=float)
    current_label = np.empty([3 * len(TRAIN_FILES), 1], dtype=int)

    fn = 0
    count = 0
    while fn < len(TRAIN_FILES) - 1:
        # log_string('----' + str(fn) + '-----')

        total_current = []
        a1, a2, _ = provider.loadDataFile_cut_2(
            TRAIN_FILES[train_file_idxs[fn]])

        idx = np.random.randint(a1.shape[0], size=NUM_POINT)
        a1 = a1[idx, :]
        idx = np.random.randint(a2.shape[0], size=NUM_POINT)
        a2 = a2[idx, :]
        total_current.append(a1)
        total_current.append(a2)

        fn = fn + 1

        b1, b2, _ = provider.loadDataFile_cut_2(
            TRAIN_FILES[train_file_idxs[fn]])

        idx = np.random.randint(b1.shape[0], size=NUM_POINT)
        b1 = b1[idx, :]
        idx = np.random.randint(b2.shape[0], size=NUM_POINT)
        b2 = b2[idx, :]
        total_current.append(b1)
        total_current.append(b2)

        fn = fn + 1

        pair_num = 0
        for index in range(len(total_current)):
            for index2 in range(index + 1, len(total_current)):
                current_data_1[6 * count + pair_num, :, :] = total_current[index]
                current_data_2[6 * count + pair_num, :, :] = total_current[index2]
                if (index < 2) and (index2 >= 2):
                    current_label[6 * count + pair_num, :] = 0
                else:
                    current_label[6 * count + pair_num, :] = 1

                pair_num = pair_num + 1
        count = count + 1

    current_label = np.squeeze(current_label)

    file_size = current_data_1.shape[0]
    num_batches = file_size // BATCH_SIZE

    total_correct = 0
    total_seen = 0
    loss_sum = 0

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE

        # shuffle each batch
        data_1 = current_data_1[start_idx:end_idx, :, :]
        data_2 = current_data_2[start_idx:end_idx, :, :]
        label = current_label[start_idx:end_idx]
        combine_data = np.concatenate((data_1, data_2), axis=2)
        combine_data, label, _ = provider.shuffle_data(combine_data,
                                                       np.squeeze(label))
        data_1 = combine_data[:, :, 0:3]
        data_2 = combine_data[:, :, 3:6]
        label = np.squeeze(label)

        # Augment batched point clouds by rotation and jittering
        rotated_data_1 = provider.rotate_point_cloud(data_1)
        jittered_data_1 = provider.jitter_point_cloud(rotated_data_1)
        jittered_data_1 = provider.random_scale_point_cloud(jittered_data_1)
        jittered_data_1 = provider.rotate_perturbation_point_cloud(
            jittered_data_1)
        jittered_data_1 = provider.shift_point_cloud(jittered_data_1)

        rotated_data_2 = provider.rotate_point_cloud(data_2)
        jittered_data_2 = provider.jitter_point_cloud(rotated_data_2)
        jittered_data_2 = provider.random_scale_point_cloud(jittered_data_2)
        jittered_data_2 = provider.rotate_perturbation_point_cloud(
            jittered_data_2)
        jittered_data_2 = provider.shift_point_cloud(jittered_data_2)

        feed_dict = {
            ops['pointclouds_pl_1']: jittered_data_1,
            ops['pointclouds_pl_2']: jittered_data_2,
            ops['labels_pl']: label,
            ops['is_training_pl']: is_training
        }
        summary, step, _, loss_val, pred_val = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
             ops['pred']], feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val == label)
        total_correct += correct
        total_seen += BATCH_SIZE
        loss_sum += loss_val

        if batch_idx % 50 == 0:
            log_string('mean loss: {0:f}     accuracy: {1:f}'.format(
                loss_sum / float(batch_idx + 1),
                total_correct / float(total_seen)))

    log_string('mean loss: {0:f}     accuracy: {1:f}'.format(
        loss_sum / float(batch_idx + 1), total_correct / float(total_seen)))
Code Example #20
def eval_one_epoch(sess, ops, test_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    loss_sum_c = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    total_labels = np.array([0, 0, 0], dtype=np.int32)

    gen = provider.data_generator(TEST_FILES, BATCH_SIZE)

    # for fn in range(len(TEST_FILES)):
    for fn in range(125000//BATCH_SIZE):  # total 125000 shards

        data_q, label_q, seg_q = next(gen)
        if data_q is None:
            print("an epoch is done in steps: ", fn)
            #print("total_labels:", total_labels)
            break

        #file_size = int(TEST_FILES[fn].split('_')[1])
        file_size = BATCH_SIZE
        num_batches = 1

        current_data = np.array(data_q)
        current_label = np.array(label_q)
        current_seg = np.array(seg_q)


        """
        if current_label_i == 0 or current_label_i == 4:  # pot #1 or #4
            if current_seg_i[0] == 1:
                rs_seg0 = 0
            elif current_seg_i[3] == 1:
                rs_seg0 = 2
            else:
                rs_seg0 = 1
        else:
            if current_seg_i[0] == 1:
                rs_seg0 = 0
            elif current_seg_i[2] == 1:
                rs_seg0 = 2
            else:
                rs_seg0 = 1

        rs_seg = np.eye(3)[rs_seg0]

        #print(rs_seg)
        #rs_seg = np.eye(16)[rs_seg0] # rs_seg.shape=(4,2)
        #rs_seg = np.eye(8)[rs_seg0] # rs_seg.shape=(4,2)
        #print(rs_seg.shape, rs_seg)
        """

        rs_label = np.eye(5)[current_label]
        rs_seg = current_seg

        #test_data=np.concatenate((rs_current_data, rs_pottery),axis=0)
        test_data = current_data

        jittered_data = provider.rotate_perturbation_point_cloud(
            test_data, angle_sigma=np.pi/2, angle_clip=np.pi)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        jittered_data = provider.random_scale_point_cloud(jittered_data)
        jittered_data = provider.shift_point_cloud(jittered_data)

        feed_dict = {ops['pointclouds_pl']: jittered_data,
                     ops['labels_pl']: rs_seg,
                     ops['labels_pl_c']: rs_label,
                     ops['is_training_pl']: is_training
                     }
        # summary,
        step, cls_loss, cls_loss_c, pred_val, pred_c_val = sess.run([
            # ops['merged'],
            ops['step'],
            ops['cls_loss'],
            ops['cls_loss_c'],
            ops['pred'],
            ops['pred_c'],
        ], feed_dict=feed_dict)

        #correct = np.sum(np.argmax(pred_val) == rs_seg0)
        #print(np.argmax(pred_c_val, axis=1), np.argmax(rs_label, axis=1))
        correct = np.sum(np.argmax(pred_c_val, axis=1)
                         == np.argmax(rs_label, axis=1))
        print('cls_loss, cls_loss_c, cls_acc_c:',
              cls_loss, cls_loss_c, correct)
        #print('pred, label, loss:', np.argmax(pred_val, axis=1), np.argmax(rs_seg, axis=1), loss_val)

        total_correct += correct / float(BATCH_SIZE)
        total_seen += 1
#            loss_sum += (loss_val*BATCH_SIZE)
        loss_sum += cls_loss
        loss_sum_c += cls_loss_c
        #total_labels[rs_seg0] += 1

    loss_avg = loss_sum / float(total_seen)
    loss_avg_c = loss_sum_c / float(total_seen)
    acc_avg_c = total_correct / float(total_seen)
    log_string('eval mean cls_loss: %f' % loss_avg)
    log_string('eval mean cls_loss_c: %f' % loss_avg_c)
    log_string('eval mean acc_c: %f' % acc_avg_c)
    val_loss_summary.value[0].simple_value = (loss_avg)
    val_loss_c_summary.value[0].simple_value = (loss_avg_c)
    val_acc_c_summary.value[0].simple_value = (acc_avg_c)
    test_writer.add_summary(val_loss_summary, step)
    test_writer.add_summary(val_loss_c_summary, step)
    test_writer.add_summary(val_acc_c_summary, step)
    test_writer.flush()
    #print("total_labels:", total_labels)
    print('label:', rs_seg)
    print('pred:', pred_val)
    print('cls_loss:', cls_loss)
    print('label_c:', current_label)
    print('pred_c:', np.argmax(pred_c_val, axis=1))
    print('cls_loss_c:', cls_loss_c)
    print('cls_acc_c:', correct)

    return loss_avg, loss_avg_c
Code Example #21
File: train.py  Project: xvjiarui/partnet_seg_exps
        def train_one_epoch(epoch_num):

            ### NOTE: is_training = False: We do not update bn parameters during training due to the small batch size. This requires pre-training PointNet with large batchsize (say 32).
            is_training = False

            allow_full_loss = 0.0
            if epoch_num > 5:
                allow_full_loss = 1.0

            printout(flog, str(datetime.now()))

            random.shuffle(train_h5_fn_list)
            for item in train_h5_fn_list:
                cur_h5_fn = os.path.join(data_in_dir, item)
                printout(flog, 'Reading data from: %s' % cur_h5_fn)

                pts, semseg_one_hot, semseg_mask, insseg_one_hot, insseg_mask = load_data(
                    cur_h5_fn)

                # shuffle data order
                n_shape = pts.shape[0]
                idx = np.arange(n_shape)
                np.random.shuffle(idx)

                pts = pts[idx, ...]
                semseg_one_hot = semseg_one_hot[idx, ...]
                semseg_mask = semseg_mask[idx, ...]
                insseg_one_hot = insseg_one_hot[idx, ...]
                insseg_mask = insseg_mask[idx, ...]

                # data augmentation to pts
                pts = provider.jitter_point_cloud(pts)
                pts = provider.shift_point_cloud(pts)
                pts = provider.random_scale_point_cloud(pts)
                pts = provider.rotate_perturbation_point_cloud(pts)

                total_loss = 0.0
                total_grouperr = 0.0
                total_same = 0.0
                total_diff = 0.0
                total_pos = 0.0
                total_acc = 0.0

                num_batch = n_shape // BATCH_SIZE
                for i in range(num_batch):
                    start_idx = i * BATCH_SIZE
                    end_idx = (i + 1) * BATCH_SIZE

                    feed_dict = {
                        pointclouds_ph: pts[start_idx:end_idx, ...],
                        ptsseglabel_ph: semseg_one_hot[start_idx:end_idx, ...],
                        ptsgroup_label_ph: insseg_one_hot[start_idx:end_idx,
                                                          ...],
                        pts_seglabel_mask_ph: semseg_mask[start_idx:end_idx,
                                                          ...],
                        pts_group_mask_ph: insseg_mask[start_idx:end_idx, ...],
                        is_training_ph: is_training,
                        alpha_ph: min(10., (float(epoch_num) / 5.) * 2. + 2.),
                        allow_full_loss_pl: allow_full_loss
                    }

                    _, loss_val, lr_val, simmat_val, semseg_logits_val, \
                        grouperr_val, same_val, same_cnt_val, diff_val, \
                        diff_cnt_val, pos_val, pos_cnt_val = sess.run(
                            [train_op, loss, learning_rate,
                             net_output['simmat'], net_output['semseg_logits'],
                             grouperr, same, same_cnt, diff, diff_cnt,
                             pos, pos_cnt], feed_dict=feed_dict)

                    seg_pred = np.argmax(semseg_logits_val, axis=-1)
                    gt_mask_val = (np.sum(
                        semseg_one_hot[start_idx:end_idx, ...], axis=-1) > 0)
                    seg_gt = np.argmax(semseg_one_hot[start_idx:end_idx, ...],
                                       axis=-1)
                    correct = ((seg_pred == seg_gt) + (~gt_mask_val) > 0)
                    acc_val = np.mean(correct)

                    total_acc += acc_val
                    total_loss += loss_val
                    total_grouperr += grouperr_val
                    total_diff += diff_val / max(1e-6, diff_cnt_val)
                    total_same += same_val / max(1e-6, same_cnt_val)
                    total_pos += pos_val / max(1e-6, pos_cnt_val)

                    if i % 10 == 9:
                        printout(flog, 'Batch: %d, LR: %f, loss: %f, FullLoss: %f, SegAcc: %f, grouperr: %f, same: %f, diff: %f, pos: %f' % \
                                (i, lr_val, total_loss/10, allow_full_loss, total_acc/10, total_grouperr/10, total_same/10, total_diff/10, total_pos/10))

                        lr_sum, batch_sum, train_loss_sum, group_err_sum = sess.run( \
                            [lr_op, batch, total_train_loss_sum_op, group_err_op], \
                            feed_dict={total_training_loss_ph: total_loss / 10.,
                                       group_err_loss_ph: total_grouperr / 10., })

                        train_writer.add_summary(train_loss_sum, batch_sum)
                        train_writer.add_summary(lr_sum, batch_sum)
                        train_writer.add_summary(group_err_sum, batch_sum)

                        total_grouperr = 0.0
                        total_loss = 0.0
                        total_diff = 0.0
                        total_same = 0.0
                        total_pos = 0.0
                        same_cnt0 = 0
                        total_acc = 0.0