Beispiel #1
0
def get_point_info(points, mask_ratio=0, mask=-1):
    """Extract point coordinates and labels, keeping only unmasked points.

    Args:
      points: point-cloud tensor, read through ``points_property``.
      mask_ratio: when > 0, a point is additionally kept only if a uniform
        random draw exceeds this value, i.e. roughly this fraction of the
        points is dropped at random (used to speed up training).
      mask: points whose label is not strictly greater than this value are
        discarded (the default -1 drops points labelled -1).

    Returns:
      ``(pts, label)``: the 'xyz' property (read with 4 channels) and the
      flattened 'label' property, both restricted to the kept points.
    """
    with tf.name_scope('points_info'):
        xyz = points_property(points, property_name='xyz', channel=4)
        flat_label = tf.reshape(
            points_property(points, property_name='label', channel=1), [-1])
        keep = flat_label > mask  # drop invalid points (label -1 by default)
        if mask_ratio > 0:
            # Random subsampling of the surviving points.
            keep = tf.logical_and(
                keep, tf.random.uniform(tf.shape(keep)) > mask_ratio)
        xyz_kept = tf.boolean_mask(xyz, keep)
        label_kept = tf.boolean_mask(flat_label, keep)
    return xyz_kept, label_kept
    def __call__(self, dataset='train', training=True, reuse=False, gpu_num=1):
        """Build the segmentation loss/metric graph on ``gpu_num`` towers.

        Args:
          dataset: 'train' reads ``FLAGS.DATA.train``; any other value reads
            ``FLAGS.DATA.test``. Also used as the prefix of every
            ``debug_checks`` key.
          training: forwarded to ``seg_network``.
          reuse: forwarded to ``seg_network`` for the first tower; every
            later tower is built with ``reuse=True``.
          gpu_num: number of towers, one per '/gpu:%d' device. Must be >= 1.

        Returns:
          A ``(tensors, names, debug_checks)`` tuple. With one tower,
          ``tensors`` is that tower's tensor list; with several, it is a list
          of tuples where the k-th tuple holds the k-th tensor of every
          tower. ``names`` labels the tensors position by position.
          ``debug_checks`` maps '<dataset>/...' keys to intermediate tensors;
          every tower writes the same keys, so it ends up holding the last
          tower's tensors.

        Raises:
          ValueError: if ``gpu_num`` is less than 1.
        """
        if gpu_num < 1:
            # With no towers, ``names`` would never be assigned below.
            raise ValueError(
                'gpu_num must be at least 1, got %r' % (gpu_num,))

        debug_checks = {}

        FLAGS = self.flags
        # The input pipeline is placed on the CPU.
        with tf.device('/cpu:0'):
            flags_data = FLAGS.DATA.train if dataset == 'train' else FLAGS.DATA.test
            data_iter = self.create_dataset(flags_data)

        tower_tensors = []
        for i in range(gpu_num):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('device_%d' % i):
                    octree, _labels, points = data_iter.get_next()
                    debug_checks["{}/input_octree".format(dataset)] = octree
                    debug_checks["{}/input_points".format(dataset)] = points
                    debug_checks["{}/input_labels".format(dataset)] = _labels
                    pts, label = get_point_info(points, flags_data.mask_ratio)
                    print("mask ratio for {} is {}".format(
                        dataset, flags_data.mask_ratio))
                    debug_checks["{}/input_point_info/points".format(
                        dataset)] = pts
                    debug_checks["{}/input_point_info/labels".format(
                        dataset)] = label
                    debug_checks["{}/input_point_info/normals".format(
                        dataset)] = points_property(points,
                                                    property_name='normal',
                                                    channel=3)
                    if not FLAGS.LOSS.point_wise:
                        # Labels come from get_seg_label(octree, depth_out)
                        # instead of the points, and seg_network gets pts=None.
                        pts, label = None, get_seg_label(
                            octree, FLAGS.MODEL.depth_out)
                        # Prefix with the dataset like every other key, so
                        # train and test graphs do not collide.
                        debug_checks["{}/input_seg_label/points".format(
                            dataset)] = pts
                        debug_checks["{}/input_seg_label/label".format(
                            dataset)] = label
                    logit = seg_network(octree,
                                        FLAGS.MODEL,
                                        training,
                                        reuse,
                                        pts=pts)
                    debug_checks["{}/logit".format(dataset)] = logit
                    losses, dc = loss_functions_seg_debug_checks(
                        logit,
                        label,
                        FLAGS.LOSS.num_class,
                        FLAGS.LOSS.weight_decay,
                        'ocnn',
                        mask=0)
                    debug_checks.update(dc)
                    # total_loss = losses[0] + losses[2] ('loss' + 'regularizer').
                    tensors = losses + [losses[0] + losses[2]]
                    names = ['loss', 'accu', 'regularizer', 'total_loss']

                    if flags_data.batch_size == 1:
                        # Per-shape IoU statistics are only built for
                        # single-shape batches.
                        num_class = FLAGS.LOSS.num_class
                        intsc, union = tf_IoU_per_shape(logit,
                                                        label,
                                                        num_class,
                                                        mask=0)
                        iou = tf.constant(
                            0.0)  # placeholder, its value is computed later
                        tensors = [iou] + tensors + intsc + union
                        names = (['iou'] + names
                                 + ['intsc_%d' % c for c in range(num_class)]
                                 + ['union_%d' % c for c in range(num_class)])

                    tower_tensors.append(tensors)
                    # Towers after the first reuse the first tower's variables.
                    reuse = True

        tensors = tower_tensors[0] if gpu_num == 1 else list(
            zip(*tower_tensors))
        return tensors, names, debug_checks