Example #1
    def sample_batch_in_all_files(self, batch_size, augment=True):
        batch_data = []
        batch_label = []
        batch_weights = []

        for _ in range(batch_size):
            points, labels, colors, geometry, weights = self.sample_in_all_files(
                is_training=True)
            data_list = [points]
            if self.use_color:
                data_list.append(colors)
            if self.use_geometry:
                data_list.append(geometry)
            batch_data.append(np.hstack(data_list))
            batch_label.append(labels)
            batch_weights.append(weights)

        batch_data = np.array(batch_data)  # (B, N, C), C = 3[+3][+7]
        batch_label = np.array(batch_label)  # (B, N)
        batch_weights = np.array(batch_weights)  # (B, N)

        if augment:
            batch_data[:, :, :3] = provider.jitter_point_cloud(
                batch_data[:, :, :3])
            batch_data[:, :, :3] = provider.shift_point_cloud(
                batch_data[:, :, :3])
            batch_data[:, :, :3] = provider.random_scale_point_cloud(
                batch_data[:, :, :3])

        return batch_data, batch_label, batch_weights
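
As a reference for what the three coordinate-only augmentations used above typically do, here is a minimal NumPy sketch modelled on the reference PointNet provider.py; the *_sketch names and the default parameter values are assumptions, not the actual provider API.

import numpy as np

def jitter_point_cloud_sketch(batch_xyz, sigma=0.01, clip=0.05):
    # Add clipped Gaussian noise to every point independently.
    noise = np.clip(sigma * np.random.randn(*batch_xyz.shape), -clip, clip)
    return batch_xyz + noise

def shift_point_cloud_sketch(batch_xyz, shift_range=0.1):
    # Translate each cloud in the batch by a single random offset.
    shifts = np.random.uniform(-shift_range, shift_range, (batch_xyz.shape[0], 1, 3))
    return batch_xyz + shifts

def random_scale_point_cloud_sketch(batch_xyz, scale_low=0.8, scale_high=1.25):
    # Scale each cloud in the batch by a single random factor.
    scales = np.random.uniform(scale_low, scale_high, (batch_xyz.shape[0], 1, 1))
    return batch_xyz * scales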
Example #2
    def _augment_batch_data(self, batch_data):
        # XYZ coordinates occupy channels 1:4 in this layout; the remaining
        # channels are left untouched by the geometric augmentations.
        rotated_data = provider.rotate_point_cloud(batch_data[:, :, 1:4])
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(rotated_data)
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        batch_data[:, :, 1:4] = jittered_data
        return provider.shuffle_points(batch_data)
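
Example #2 rotates the full cloud before jittering. Below is a minimal sketch of that rotation step, assuming the common provider.py convention of spinning each cloud by its own random angle about the up (Y) axis; rotate_point_cloud_sketch is a hypothetical stand-in, not the real provider function.

import numpy as np

def rotate_point_cloud_sketch(batch_xyz):
    # Rotate each (N, 3) cloud by a random angle around the Y (up) axis.
    rotated = np.empty_like(batch_xyz)
    for b in range(batch_xyz.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        c, s = np.cos(angle), np.sin(angle)
        rot = np.array([[c, 0.0, s],
                        [0.0, 1.0, 0.0],
                        [-s, 0.0, c]])
        rotated[b] = batch_xyz[b] @ rot
    return rotated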
Example #3
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True
    log_string(str(datetime.now()))

    # Make sure batch data is of same size
    cur_batch_data = np.zeros((BATCH_SIZE, NUM_POINT, TRAIN_DATASET.num_channel()))
    cur_batch_label = np.zeros((BATCH_SIZE), dtype=np.int32)

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    num_batch = int(len(TRAIN_DATASET) / BATCH_SIZE)

    with tqdm(total=num_batch) as pbar:
        while TRAIN_DATASET.has_next_batch():
            batch_data, batch_label = TRAIN_DATASET.next_batch()

            if FLAGS.rotation:
                if FLAGS.normal:
                    batch_data = provider.rotate_point_cloud_with_normal(batch_data)
                    batch_data = provider.rotate_perturbation_point_cloud_with_normal(batch_data)
                else:
                    batch_data = provider.rotate_point_cloud(batch_data)
                    batch_data = provider.rotate_perturbation_point_cloud(batch_data)

            batch_data[:, :, 0:3] = provider.random_scale_point_cloud(batch_data[:, :, 0:3])
            batch_data[:, :, 0:3] = provider.shift_point_cloud(batch_data[:, :, 0:3])
            batch_data = provider.shuffle_points(batch_data)
            batch_data = provider.random_point_dropout(batch_data)

            bsize = batch_data.shape[0]
            cur_batch_data[0:bsize, ...] = batch_data
            cur_batch_label[0:bsize] = batch_label

            feed_dict = {ops['pointclouds_pl']: cur_batch_data,
                         ops['labels_pl']: cur_batch_label,
                         ops['is_training_pl']: is_training, }
            summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                                                             ops['train_op'], ops['loss'], ops['pred']],
                                                            feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val[0:bsize] == batch_label[0:bsize])
            total_correct += correct
            total_seen += bsize
            loss_sum += loss_val

            if FLAGS.debug:
                break

            pbar.update(1)

    log_string('Current Learning Rate %.6f' % sess.run(get_learning_rate(step)))
    log_string('Training loss: %f' % (loss_sum / num_batch))
    log_string('Training accuracy: %f\n' % (total_correct / float(total_seen)))
    TRAIN_DATASET.reset()
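
random_point_dropout is the only augmentation in this loop that discards geometry instead of perturbing it. Here is a sketch of the usual behaviour, assuming the common convention that dropped points are overwritten with a copy of the first point so the (B, N, C) shape is preserved; the 0.875 ceiling and the *_sketch name are assumptions.

import numpy as np

def random_point_dropout_sketch(batch_pc, max_dropout_ratio=0.875):
    # For each cloud, draw a dropout ratio and overwrite the dropped points
    # with the cloud's first point, keeping the tensor shape fixed.
    out = batch_pc.copy()
    for b in range(out.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio
        drop_idx = np.where(np.random.random(out.shape[1]) <= dropout_ratio)[0]
        if drop_idx.size > 0:
            out[b, drop_idx, :] = out[b, 0, :]
    return out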
Example #4
def train_one_epoch(sess, ops, train_writer):
    is_training = True

    train_file_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_file_idxs)
    for fn in range(len(TRAIN_DATASET)):

        current_data, current_label = provider.loadDataFile(
            TRAIN_DATASET[train_file_idxs[fn]])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)
        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = start_idx + BATCH_SIZE
            pc_data = provider.random_scale_point_cloud(
                current_data[start_idx:end_idx, :, :])

            pc_data = provider.shift_point_cloud(pc_data)

            feed_dict = {
                ops['pointcloud_pl']: pc_data,
                ops['label_pl']: current_label[start_idx:end_idx],
                ops['is_training_pl']: is_training
            }

            summary, step, _, loss_val, pred_val = sess.run(
                [
                    ops["merged"], ops["step"], ops["train_op"], ops["loss"],
                    ops["pred"]
                ],
                feed_dict=feed_dict)

            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct = total_correct + correct
            total_seen = total_seen + BATCH_SIZE
            loss_sum = loss_sum + loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
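
Note that provider.shuffle_data permutes whole examples along the batch axis, unlike shuffle_points, which permutes points within a cloud. A minimal sketch under that assumption, matching the three return values unpacked above (shuffle_data_sketch is a hypothetical name):

import numpy as np

def shuffle_data_sketch(data, labels):
    # Permute examples along the first axis, keep labels aligned,
    # and return the permutation so other arrays can be reordered too.
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx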
Example #5
    def _augment_batch_data(self, batch_data):
        if self.normal_channel:
            rotated_data = provider.rotate_point_cloud_with_normal(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud_with_normal(
                rotated_data
            )
        else:
            rotated_data = provider.rotate_point_cloud(batch_data)
            rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)

        jittered_data = provider.random_scale_point_cloud(rotated_data[:, :, 0:3])
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        rotated_data[:, :, 0:3] = jittered_data
        return provider.shuffle_points(rotated_data)
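
The final shuffle_points call reorders the points inside each cloud; in the reference provider.py the same permutation is applied to every cloud in the batch. A sketch under that assumption (shuffle_points_sketch is a hypothetical name):

import numpy as np

def shuffle_points_sketch(batch_data):
    # Apply one random point permutation to every cloud in the batch;
    # any per-point labels would need the same permutation.
    idx = np.arange(batch_data.shape[1])
    np.random.shuffle(idx)
    return batch_data[:, idx, :]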