Example #1
    def next_batch(self, batch_size, augment=True, dropout=True):
        batch_data = []
        batch_label = []
        batch_weights = []
        feature_size = 0
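        # feature_size counts the extra per-point channels appended after XYZ
        # (0 means plain XYZ, so the plain rotation can be used below)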
        for _ in range(batch_size):
            if not self.z_feature:
                data, label, colors, weights = self.next_input(dropout)
                if self.use_color:
                    feature_size = 3
                    data = np.hstack((data, colors))
            else:
                feature_size = 1
                data, z_norm, label, colors, weights = self.next_input(dropout)
                if self.use_color:
                    feature_size = 4
                    data = np.hstack((data, colors))
                data = np.hstack((data, z_norm))
            batch_data.append(data)
            batch_label.append(label)
            batch_weights.append(weights)

        batch_data = np.array(batch_data)
        batch_label = np.array(batch_label)
        batch_weights = np.array(batch_weights)

        # Optional batch augmentation
        if augment:
            if feature_size:
                batch_data = provider.rotate_feature_point_cloud(
                    batch_data, feature_size)
            else:
                batch_data = provider.rotate_point_cloud(batch_data)

        return batch_data, batch_label, batch_weights
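The two augmentation branches exist because appended feature channels (color, normalized z) must not be rotated along with the XYZ coordinates. As a minimal sketch of that assumed contract (the actual provider.rotate_feature_point_cloud may differ), a feature-aware rotation touches only the first three columns:

import numpy as np

def rotate_feature_point_cloud_sketch(batch_data, feature_size):
    # Hypothetical stand-in for provider.rotate_feature_point_cloud:
    # rotate XYZ about the up axis, pass extra channels through unchanged.
    assert batch_data.shape[2] == 3 + feature_size
    rotated = batch_data.copy()
    for k in range(batch_data.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        rot = np.array([[c, 0.0, s],
                        [0.0, 1.0, 0.0],
                        [-s, 0.0, c]])
        rotated[k, :, 0:3] = batch_data[k, :, 0:3] @ rot
    return rotated  # shape (B, N, 3 + feature_size); features untouched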
Example #2
    def _augment_batch_data(self, batch_data):
        # Columns 1:4 are assumed to hold XYZ; column 0 (an extra per-point
        # channel in this variant) is left untouched by the geometric augmentations.
        rotated_data = provider.rotate_point_cloud(batch_data[:, :, 1:4])
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(rotated_data)
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        batch_data[:, :, 1:4] = jittered_data
        return provider.shuffle_points(batch_data)
Example #3
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True
    log_string(str(datetime.now()))

    # Make sure batch data is of same size
    cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT, TRAIN_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    num_batch = len(TRAIN_DATASET) // BATCH_SIZE

    with tqdm(total=num_batch) as pbar:
        while TRAIN_DATASET.has_next_batch():
            batch_data, batch_label = TRAIN_DATASET.next_batch()

            if FLAGS.rotation:
                if FLAGS.normal:
                    batch_data = provider.rotate_point_cloud_with_normal(batch_data)
                    batch_data = provider.rotate_perturbation_point_cloud_with_normal(batch_data)
                else:
                    batch_data = provider.rotate_point_cloud(batch_data)
                    batch_data = provider.rotate_perturbation_point_cloud(batch_data)

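            # Additional augmentation: randomly scale and shift XYZ,
            # shuffle the point order, and randomly drop points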
            batch_data[:, :, 0:3] = provider.random_scale_point_cloud(batch_data[:, :, 0:3])
            batch_data[:, :, 0:3] = provider.shift_point_cloud(batch_data[:, :, 0:3])
            batch_data = provider.shuffle_points(batch_data)
            batch_data = provider.random_point_dropout(batch_data)

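            # The last batch may be smaller than BATCH_SIZE; copy it into the
            # fixed-size buffer and only score the first bsize predictions below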
            bsize = batch_data.shape[0]
            cur_batch_data[0:bsize, ...] = batch_data
            cur_batch_label[0:bsize] = batch_label

            feed_dict = {ops['pointclouds_pl']: cur_batch_data,
                         ops['labels_pl']: cur_batch_label,
                         ops['is_training_pl']: is_training, }
            summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                                                             ops['train_op'], ops['loss'], ops['pred']],
                                                            feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
            total_correct += correct
            total_seen += bsize
            loss_sum += loss_val

            if FLAGS.debug:
                break

            pbar.update(1)

    log_string('Current Learning Rate %.6f' % sess.run(get_learning_rate(step)))
    log_string('Training loss: %f' % (loss_sum / num_batch))
    log_string('Training accuracy: %f\n' % (total_correct / float(total_seen)))
    TRAIN_DATASET.reset()
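provider.random_point_dropout keeps the tensor shape fixed, so dropped points cannot simply be removed. As commonly implemented in the PointNet++ reference provider (sketched here under that assumption), dropped points are overwritten with the cloud's first point:

import numpy as np

def random_point_dropout_sketch(batch_pc, max_dropout_ratio=0.875):
    # Hypothetical stand-in: per cloud, draw a dropout ratio, then overwrite
    # the dropped points with the first point so the (B, N, C) shape is kept.
    batch_pc = batch_pc.copy()
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(batch_pc.shape[1]) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]
    return batch_pc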
Example #4
    def _augment_batch_data(self, batch_data):
        if self.normal_channel:
            rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
                rotated_data
            )
        else:
            rotated_data = provider.rotate_point_cloud(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

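        # Scale, shift and jitter apply only to XYZ (columns 0:3); normals,
        # when present, receive only the rotations above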
        jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        rotated_data[:, :, 0:3] = jittered_data
        return provider.shuffle_points(rotated_data)
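Normals must be rotated with the same matrix as the coordinates, which is why the scale/shift/jitter chain above is restricted to columns 0:3. A minimal sketch of the assumed behavior of the *_with_normal rotation (the real provider implementation may differ):

import numpy as np

def rotate_point_cloud_with_normal_sketch(batch_xyz_normal):
    # Hypothetical stand-in: rotate XYZ (cols 0:3) and normals (cols 3:6)
    # about the up axis with the same per-cloud rotation matrix.
    rotated = batch_xyz_normal.copy()
    for k in range(batch_xyz_normal.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        rot = np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
        rotated[k, :, 0:3] = batch_xyz_normal[k, :, 0:3] @ rot
        rotated[k, :, 3:6] = batch_xyz_normal[k, :, 3:6] @ rot
    return rotated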
Example #5
    def next_batch(self, batch_size, augment=True, dropout=True):
        # Note: unlike the variant in Example #1, this one does not forward
        # the dropout flag to next_input().
        batch_data = []
        batch_label = []
        batch_weights = []

        for _ in range(batch_size):
            data, label, weights = self.next_input()
            batch_data.append(data)
            batch_label.append(label)
            batch_weights.append(weights)

        batch_data = np.array(batch_data)
        batch_label = np.array(batch_label)
        batch_weights = np.array(batch_weights)

        # Optional batch augmentation
        if augment:
            batch_data = provider.rotate_point_cloud(batch_data)

        return batch_data, batch_label, batch_weights
Example #6
def eval_one_epoch(sess, ops, test_writer):
    """Evaluate one epoch
    
    Args:
        sess (tf.Session): the session to evaluate tensors and operations
        ops (tf.Operation): the dict of operations
        test_writer (tf.summary.FileWriter): enable to log the evaluation on TensorBoard
    
    Returns:
        float: the overall accuracy computed on the test set
    """

    global EPOCH_CNT

    is_training = False
    test_idxs = np.arange(0, len(TEST_DATASET))
    num_batches = len(TEST_DATASET) // BATCH_SIZE

    # Reset metrics
    loss_sum = 0
    confusion_matrix = metric.ConfusionMatrix(NUM_CLASSES)

    log_string(str(datetime.now()))
    log_string('---- EPOCH %03d EVALUATION ----' % (EPOCH_CNT))

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label, batch_smpw = get_batch(
            TEST_DATASET, test_idxs, start_idx, end_idx)

        aug_data = provider.rotate_point_cloud(batch_data)

        feed_dict = {
            ops['pointclouds_pl']: aug_data,
            ops['labels_pl']: batch_label,
            ops['smpws_pl']: batch_smpw,
            ops['is_training_pl']: is_training
        }
        summary, step, loss_val, pred_val = sess.run(
            [ops['merged'], ops['step'], ops['loss'], ops['pred']],
            feed_dict=feed_dict)

        test_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 2)  # BxN

        # Update metrics
        for i in range(len(pred_val)):
            for j in range(len(pred_val[i])):
                confusion_matrix.count_predicted(batch_label[i][j],
                                                 pred_val[i][j])
        loss_sum += loss_val

    iou_per_class = confusion_matrix.get_intersection_union_per_class()

    # Display metrics
    log_string('eval mean loss: %f' % (loss_sum / float(num_batches)))
    log_string("Overall accuracy : %f" %
               (confusion_matrix.get_overall_accuracy()))
    log_string("Average IoU : %f" %
               (confusion_matrix.get_average_intersection_union()))
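    # Per-class IoU, skipping index 0 (presumably an "unlabeled" class)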
    for i in range(1, NUM_CLASSES):
        log_string("IoU of %s : %f" % (data.LABELS_NAMES[i], iou_per_class[i]))

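    # Presumably evaluation runs once every 5 training epochs, hence the
    # increment by 5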
    EPOCH_CNT += 5
    return confusion_matrix.get_overall_accuracy()
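The metric.ConfusionMatrix helper is not shown in these examples. A minimal sketch of the interface the loops above rely on (assumed behavior, with per-class IoU_c = TP_c / (TP_c + FP_c + FN_c)):

import numpy as np

class ConfusionMatrixSketch:
    # Hypothetical stand-in for metric.ConfusionMatrix.
    def __init__(self, num_classes):
        self.mat = np.zeros((num_classes, num_classes), dtype=np.int64)

    def count_predicted(self, ground_truth, predicted):
        self.mat[ground_truth][predicted] += 1

    def get_overall_accuracy(self):
        return np.trace(self.mat) / max(self.mat.sum(), 1)

    def get_intersection_union_per_class(self):
        tp = np.diag(self.mat)
        fp = self.mat.sum(axis=0) - tp
        fn = self.mat.sum(axis=1) - tp
        return tp / np.maximum(tp + fp + fn, 1)

    def get_average_intersection_union(self):
        return self.get_intersection_union_per_class().mean()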
Example #7
def train_one_epoch(sess, ops, train_writer):
    """Train one epoch
    
    Args:
        sess (tf.Session): the session to evaluate Tensors and ops
        ops (dict of tf.Operation): contain multiple operation mapped with with strings
        train_writer (tf.FileSaver): enable to log the training with TensorBoard
    """

    is_training = True

    # Shuffle train samples
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    num_batches = len(TRAIN_DATASET) // BATCH_SIZE

    log_string(str(datetime.now()))

    # Reset metrics
    loss_sum = 0
    confusion_matrix = metric.ConfusionMatrix(NUM_CLASSES)

    # Train over num_batches batches
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label, batch_smpw = get_batch(
            TRAIN_DATASET, train_idxs, start_idx, end_idx, True, INPUT_DROPOUT)

        # Augment batched point clouds by z-axis rotation
        aug_data = provider.rotate_point_cloud(batch_data)

        # Get predicted labels
        feed_dict = {
            ops['pointclouds_pl']: aug_data,
            ops['labels_pl']: batch_label,
            ops['smpws_pl']: batch_smpw,
            ops['is_training_pl']: is_training,
        }
        summary, step, _, loss_val, pred_val, _ = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
             ops['pred'], ops['update_iou']],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 2)

        # Update metrics
        for i in range(len(pred_val)):
            for j in range(len(pred_val[i])):
                confusion_matrix.count_predicted(batch_label[i][j],
                                                 pred_val[i][j])
        loss_sum += loss_val

        # Every 10 batches, print metrics and reset them
        if (batch_idx + 1) % 10 == 0:
            log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches))
            log_string('mean loss: %f' % (loss_sum / 10))
            log_string("Overall accuracy : %f" %
                       (confusion_matrix.get_overall_accuracy()))
            log_string("Average IoU : %f" %
                       (confusion_matrix.get_average_intersection_union()))
            iou_per_class = confusion_matrix.get_intersection_union_per_class()
            for i in range(1, NUM_CLASSES):
                log_string("IoU of %s : %f" %
                           (data.LABELS_NAMES[i], iou_per_class[i]))
            loss_sum = 0
            confusion_matrix = metric.ConfusionMatrix(NUM_CLASSES)