Example #1
def augment(i, cur_data, cur_label, cur_meta, points, labels, meta):
    """Jitter and translate the i-th cloud, appending the result and its
    label/metadata to the running output lists."""
    # rotated_data = provider.rotate_point_cloud(np.expand_dims(cur_data[i], axis=0))
    jittered_data = provider.jitter_point_cloud(
        np.expand_dims(cur_data[i], axis=0))
    translated_data = provider.translate_point_cloud(jittered_data)
    points.append(np.squeeze(translated_data))
    labels.append(cur_label[i])
    meta.append(cur_meta[i])
    return points, labels, meta
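
A minimal usage sketch for this helper, assuming cur_data is an (N, num_points, 3) array with parallel cur_label and cur_meta arrays; the driver loop itself is hypothetical, not from the original project:

# Hypothetical driver loop; only augment() comes from the example above.
points, labels, meta = [], [], []
for i in range(cur_data.shape[0]):
    points, labels, meta = augment(i, cur_data, cur_label, cur_meta,
                                   points, labels, meta)
points = np.stack(points)   # (N, num_points, 3) augmented batch
labels = np.asarray(labels)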
Example #2
def train_one_epoch(sess, ops, gmm, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]], compensate=False)
        # points_idx = range(0, NUM_POINT)
        points_idx = np.random.choice(range(0, 2048), NUM_POINT)
        current_data = current_data[:, points_idx, :]
        current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE  # integer division; '/' would make range() fail in Python 3

        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds (scale / rotate / translate / jitter / outliers, as configured)

            augmented_data = current_data[start_idx:end_idx, :, :]
            if augment_scale:
                augmented_data = provider.scale_point_cloud(augmented_data, smin=0.66, smax=1.5)
            if augment_rotation:
                augmented_data = provider.rotate_point_cloud(augmented_data)
            if augment_translation:
                augmented_data = provider.translate_point_cloud(augmented_data, tval=0.2)
            if augment_jitter:
                augmented_data = provider.jitter_point_cloud(
                    augmented_data, sigma=0.01, clip=0.05)
            if augment_outlier:
                augmented_data = provider.insert_outliers_to_point_cloud(augmented_data, outlier_ratio=0.02)

            feed_dict = {ops['points_pl']: augmented_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['w_pl']: gmm.weights_,
                         ops['mu_pl']: gmm.means_,
                         ops['sigma_pl']: np.sqrt(gmm.covariances_),
                         ops['is_training_pl']: is_training, }
            summary, step, _, loss_val, reconstructed_points_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                 ops['reconstructed_points']],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
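
For context, the rotation and jitter helpers called above are commonly implemented along the following lines in PointNet-style provider.py modules; this is a sketch of that convention, not verified against this project's provider:

import numpy as np

def rotate_point_cloud(batch_data):
    """Randomly rotate each (N, 3) cloud in the batch about the up (Y) axis."""
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angle = np.random.uniform() * 2 * np.pi
        cosval, sinval = np.cos(angle), np.sin(angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        rotated_data[k, ...] = np.dot(batch_data[k].reshape((-1, 3)), rotation_matrix)
    return rotated_data

def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise independently to every point coordinate."""
    assert clip > 0
    noise = np.clip(sigma * np.random.randn(*batch_data.shape), -clip, clip)
    return batch_data + noise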
Example #3
def train_one_epoch(sess, ops, gmm, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    if (".h5" in TRAIN_FILE):
        current_data, current_label = data_utils.get_current_data_h5(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    else:
        current_data, current_label = data_utils.get_current_data(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)

    current_label = np.squeeze(current_label)

    num_batches = current_data.shape[0]//BATCH_SIZE

    total_correct = 0
    total_seen = 0
    loss_sum = 0

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE

        # Augment batched point clouds (scale / rotate / translate / jitter / outliers, as configured)

        augmented_data = current_data[start_idx:end_idx, :, :]
        if augment_scale:
            augmented_data = provider.scale_point_cloud(augmented_data, smin=0.66, smax=1.5)
        if augment_rotation:
            augmented_data = provider.rotate_point_cloud(augmented_data)
        if augment_translation:
            augmented_data = provider.translate_point_cloud(augmented_data, tval=0.2)
        if augment_jitter:
            augmented_data = provider.jitter_point_cloud(
                augmented_data, sigma=0.01, clip=0.05)
        if augment_outlier:
            augmented_data = provider.insert_outliers_to_point_cloud(augmented_data, outlier_ratio=0.02)

        feed_dict = {ops['points_pl']: augmented_data,
                     ops['labels_pl']: current_label[start_idx:end_idx],
                     ops['w_pl']: gmm.weights_,
                     ops['mu_pl']: gmm.means_,
                     ops['sigma_pl']: np.sqrt(gmm.covariances_),
                     ops['is_training_pl']: is_training, }
        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                                                         ops['train_op'], ops['loss'], ops['pred']],
                                                        feed_dict=feed_dict)

        train_writer.add_summary(summary, step)
        pred_val = np.argmax(pred_val, 1)
        correct = np.sum(pred_val == current_label[start_idx:end_idx])
        total_correct += correct
        total_seen += BATCH_SIZE
        loss_sum += loss_val

    log_string('mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('accuracy: %f' % (total_correct / float(total_seen)))
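
insert_outliers_to_point_cloud is specific to 3DmFV-style providers. One plausible implementation, assumed here purely for illustration (replace a random subset of points with uniform noise), is:

def insert_outliers_to_point_cloud(batch_data, outlier_ratio=0.02):
    """Replace a fraction of each cloud's points with uniform random outliers.
    Hypothetical sketch; the project's provider may differ in range and sampling."""
    B, N, C = batch_data.shape
    n_outliers = int(N * outlier_ratio)
    out = batch_data.copy()
    for k in range(B):
        idx = np.random.choice(N, n_outliers, replace=False)
        out[k, idx, :] = np.random.uniform(-1.0, 1.0, (n_outliers, C))
    return out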
Example #4
File: train.py  Project: Usertlcc/cc
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)
    log_string(str(datetime.now()))

    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(
            TRAIN_FILES[train_file_idxs[fn]])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by scaling, translation, and jittering
            batch_data = provider.scale_point_cloud(
                current_data[start_idx:end_idx, :, :])
            batch_data = provider.translate_point_cloud(batch_data)
            batch_data = provider.jitter_point_cloud(batch_data)

            feed_dict = {
                ops['pointclouds_pl']: batch_data,
                ops['labels_pl']: current_label[start_idx:end_idx],
                ops['is_training_pl']: is_training,
            }
            summary, step, _, loss_val, pred_val = sess.run(
                [
                    ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                    ops['pred']
                ],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
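
provider.shuffle_data, used here and in Example #2, is conventionally a one-permutation helper in PointNet-style codebases (a sketch of that convention, not this project's verified source):

def shuffle_data(data, labels):
    """Shuffle clouds and labels with one shared permutation; also return it."""
    idx = np.arange(len(labels))
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx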
Example #5
    def train(self):
        print("**** Start training ")
        is_training = True

        print('**** Training with max epoch %03d' %
              (self.train_config.max_epoch))
        for epoch in range(self.train_config.max_epoch):
            pointcloud_data, labels, idx = provider.shuffle_data(
                self.train_pointcloud_data, self.train_labels)

            file_size = pointcloud_data.shape[0]
            num_batches = file_size // self.train_config.batch_size

            curr_step = 0
            total_correct = 0
            total_seen = 0
            loss_sum = 0

            for batch_idx in range(num_batches):
                start_idx = batch_idx * self.train_config.batch_size
                end_idx = (batch_idx + 1) * self.train_config.batch_size

                points_batch = pointcloud_data[start_idx:end_idx, ...]

                if self.data_config.augment_scale:
                    points_batch = provider.scale_point_cloud(points_batch,
                                                              smin=0.66,
                                                              smax=1.5)
                if self.data_config.augment_rotation:
                    points_batch = provider.rotate_point_cloud(points_batch)
                if self.data_config.augment_translation:
                    points_batch = provider.translate_point_cloud(points_batch,
                                                                  tval=0.2)
                if self.data_config.augment_jitter:
                    points_batch = provider.jitter_point_cloud(points_batch,
                                                               sigma=0.01,
                                                               clip=0.05)
                if self.data_config.augment_outlier:
                    points_batch = provider.insert_outliers_to_point_cloud(
                        points_batch, outlier_ratio=0.02)

                points_batch = utils.scale_to_unit_sphere(points_batch)
                label_batch = labels[start_idx:end_idx]

                xforms_np, rotations_np = pf.get_xforms(
                    self.train_config.batch_size,
                    rotation_range=setting.rotation_range,
                    scaling_range=setting.scaling_range,
                    order=setting.rotation_order)

                feed_dict = {
                    self.ops['points_pl']: points_batch,
                    self.ops['labels_pl']: label_batch,
                    self.ops['w_pl']: self.gmm.weights_ if self.model == "3DmFV" else [1],
                    self.ops['mu_pl']: self.gmm.means_ if self.model == "3DmFV" else [1],
                    self.ops['sigma_pl']: np.sqrt(
                        self.gmm.covariances_ if self.model == "3DmFV" else [1]),
                    self.ops['is_training_pl']: is_training,
                    self.ops['xforms']: xforms_np if self.model == "PointCNN" else [1],
                    self.ops['rotations']: rotations_np if self.model == "PointCNN" else [1],
                    self.ops['jitter_range']: np.array(
                        [setting.jitter] if self.model == "PointCNN" else [1]),
                }

                # Log the summaries every 10 steps (skipping step 0).
                if curr_step % 10 == 0 and curr_step > 0:
                    summary, step, gstep, _top, _mop, loss_val, pred_val = self.sess.run(
                        [
                            self.ops['summary_op'], self.ops['step'],
                            self.ops['global_step'], self.ops['train_op'],
                            self.ops['metrics_op'], self.ops['loss'],
                            self.ops['predictions']
                        ],
                        feed_dict=feed_dict)

                    self.sv.summary_computed(self.sess, summary)
                else:
                    step, gstep, _top, _mop, loss_val, pred_val = self.sess.run(
                        [
                            self.ops['step'], self.ops['global_step'],
                            self.ops['train_op'], self.ops['metrics_op'],
                            self.ops['loss'], self.ops['predictions']
                        ],
                        feed_dict=feed_dict)

                    if curr_step % 100 == 0 or curr_step % 75 == 0:
                        print('global step {}: loss: {} '.format(
                            gstep, loss_val))

                correct = np.sum(pred_val == label_batch)

                total_correct += correct
                total_seen += self.train_config.batch_size
                loss_sum += loss_val

                curr_step += 1

            # Evaluate every other epoch and on the final epoch
            if epoch % 2 == 0 or epoch == self.train_config.max_epoch - 1:
                acc, acc_avg_cls = self.evaluate()
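
utils.scale_to_unit_sphere, applied to each batch above, normalizes clouds so scales are comparable after augmentation. A minimal sketch, assuming a (B, N, 3) batch, centroid-centering, and division by the farthest point's norm (the project's utils may differ):

def scale_to_unit_sphere(batch_data):
    # batch_data: (B, N, 3). Center each cloud on its centroid and scale so
    # the farthest point lies on the unit sphere.
    centroid = batch_data.mean(axis=1, keepdims=True)      # (B, 1, 3)
    pc = batch_data - centroid
    furthest = np.linalg.norm(pc, axis=2).max(axis=1)      # (B,)
    return pc / furthest[:, None, None]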