Exemplo n.º 1
0
def get_model(points, cls_label, num_cls, is_training, config=None):
    """Build the encoder-decoder graph that outputs per-point logits.

    Pipeline: a point-wise layer ('mlp1'), an encoder of separable
    convolution blocks ('conv<i>') with optional pooling ('pool<i>'), a
    mirrored decoder ('deconv<i>' + 'unpool<i>') with skip connections to
    the encoder, a second point-wise layer ('mlp2'), and a point-wise
    classifier ('logits').  Before classification the one-hot encoding of
    ``cls_label`` is tiled over all points and appended to every point's
    feature vector.

    Args:
        points: Input tensor indexed as [batch, point, channel].  The first
            three channels are taken as xyz coordinates; all channels are
            fed to 'mlp1'.
        cls_label: Category ids, one-hot encoded with depth
            ``NUM_CATEGORIES`` (a module-level constant defined outside this
            function).  The encoding is reshaped to
            [batch, 1, NUM_CATEGORIES], so one id per sample is expected.
        num_cls: Number of output channels of the final classifier.
        is_training: Training-mode flag forwarded to every layer.
        config: Settings object.  Reads ``mlp``, ``radius``, ``nn_uplimit``,
            ``num_sample``, ``sample``, ``kernel``, ``channels``,
            ``binSize``, ``multiplier``, ``weight_decay``, ``with_bn``,
            ``with_bias``, ``pool_method``, ``unpool_method`` and
            ``nnsearch``.  The object is not modified.

    Returns:
        A ``(logits, end_points)`` tuple.  ``end_points['feats']`` holds the
        features that are fed to the classifier.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        raise ValueError('get_model requires a config object, got None')

    # Dynamic shapes, so the graph does not depend on a static batch size.
    batch_size = tf.shape(points)[0]
    num_point = tf.shape(points)[1]
    end_points = {}
    xyz = points[:, :, 0:3]

    reuse = None
    net = s3g_util.pointwise_conv3d(points,
                                    config.mlp,
                                    'mlp1',
                                    weight_decay=config.weight_decay,
                                    with_bn=config.with_bn,
                                    with_bias=config.with_bias,
                                    reuse=reuse,
                                    is_training=is_training)

    # encoder[0] is the 'mlp1' output; it is re-used after 'mlp2' below.
    xyz_layers = [xyz]
    encoder = [net]
    # ===============================================Encoder================================================
    for layer in range(len(config.radius)):
        intra_idx, intra_cnt, intra_dst, indices = s3g_util.build_graph(
            xyz,
            config.radius[layer],
            config.nn_uplimit[layer],
            config.num_sample[layer],
            sample_method=config.sample)
        filt_idx = s3g_util.spherical_kernel(xyz,
                                             xyz,
                                             intra_idx,
                                             intra_cnt,
                                             intra_dst,
                                             config.radius[layer],
                                             kernel=config.kernel)
        net = _separable_conv3d_block(net,
                                      config.channels[layer],
                                      config.binSize,
                                      intra_idx,
                                      intra_cnt,
                                      filt_idx,
                                      'conv' + str(layer + 1),
                                      config.multiplier[layer],
                                      reuse=reuse,
                                      weight_decay=config.weight_decay,
                                      with_bn=config.with_bn,
                                      with_bias=config.with_bias,
                                      is_training=is_training)

        # Saved before pooling: this is the skip connection for the decoder.
        encoder.append(net)
        if config.num_sample[layer] > 1:
            # Keep only the sampled points and their neighbourhoods.
            xyz = tf.gather_nd(xyz, indices)
            xyz_layers.append(xyz)
            inter_idx = tf.gather_nd(intra_idx, indices)
            inter_cnt = tf.gather_nd(intra_cnt, indices)

            net = s3g_util.pool3d(net,
                                  inter_idx,
                                  inter_cnt,
                                  method=config.pool_method,
                                  scope='pool' + str(layer + 1))
    # ===============================================The End================================================

    # The decoder walks the layers coarse-to-fine.  Use reversed *copies* of
    # the per-layer settings: reversing the config lists in place would
    # corrupt the config for any later call that shares it.
    dec_radius = config.radius[::-1]
    dec_nn_uplimit = config.nn_uplimit[::-1]
    dec_channels = config.channels[::-1]
    dec_multiplier = config.multiplier[::-1]
    xyz_layers.reverse()
    encoder.reverse()
    # ===============================================Decoder================================================
    # NOTE: xyz_layers only grows when an encoder layer subsamples
    # (num_sample > 1), so indexing layer + 1 below requires every encoder
    # layer to have pooled.
    for layer in range(len(dec_radius)):
        xyz = xyz_layers[layer]
        xyz_unpool = xyz_layers[layer + 1]

        intra_idx, intra_cnt, intra_dst, \
        inter_idx, inter_cnt, inter_dst = s3g_util.build_graph_deconv(
            xyz,
            xyz_unpool,
            dec_radius[layer],
            dec_nn_uplimit[layer],
            nnsearch=config.nnsearch)
        filt_idx = s3g_util.spherical_kernel(xyz,
                                             xyz,
                                             intra_idx,
                                             intra_cnt,
                                             intra_dst,
                                             dec_radius[layer],
                                             kernel=config.kernel)
        net = _separable_conv3d_block(net,
                                      dec_channels[layer],
                                      config.binSize,
                                      intra_idx,
                                      intra_cnt,
                                      filt_idx,
                                      'deconv' + str(layer + 1),
                                      dec_multiplier[layer],
                                      reuse=reuse,
                                      weight_decay=config.weight_decay,
                                      with_bn=config.with_bn,
                                      with_bias=config.with_bias,
                                      is_training=is_training)

        net = s3g_util.unpool3d(net,
                                inter_idx,
                                inter_cnt,
                                inter_dst,
                                method=config.unpool_method,
                                scope='unpool' + str(layer + 1))

        # Skip connection from the matching encoder layer.
        net = tf.concat((net, encoder[layer]), axis=2)
    # ===============================================The End================================================
    net = s3g_util.pointwise_conv3d(net,
                                    config.mlp,
                                    'mlp2',
                                    weight_decay=config.weight_decay,
                                    with_bn=config.with_bn,
                                    with_bias=config.with_bias,
                                    reuse=reuse,
                                    is_training=is_training)
    # After the reversal, encoder[-1] is the 'mlp1' output.
    net = tf.concat((net, encoder[-1]), axis=2)

    # Append the one-hot object category to every point's feature vector.
    cls_label_one_hot = tf.one_hot(cls_label,
                                   depth=NUM_CATEGORIES,
                                   on_value=1.0,
                                   off_value=0.0,
                                   dtype=tf.float32)
    cls_label_one_hot = tf.reshape(cls_label_one_hot,
                                   [batch_size, 1, NUM_CATEGORIES])
    cls_label_one_hot = tf.tile(cls_label_one_hot, [1, num_point, 1])
    net = tf.concat((net, cls_label_one_hot), axis=2)
    end_points['feats'] = net

    # point-wise classifier: raw logits, no batch norm and no activation
    net = s3g_util.pointwise_conv3d(net,
                                    num_cls,
                                    scope='logits',
                                    with_bn=False,
                                    with_bias=config.with_bias,
                                    activation_fn=None,
                                    is_training=is_training)

    return net, end_points
def get_model(points, is_training, config=None):
    """Build the encoder-decoder graph that outputs per-point logits.

    Pipeline: a point-wise layer ('mlp1'), an encoder of separable
    convolution blocks ('conv<i>') with optional pooling ('pool<i>'), a
    mirrored decoder ('deconv<i>' + 'unpool<i>') with skip connections to
    the encoder, and a point-wise classifier ('logits').

    Args:
        points: Input tensor indexed as [batch, point, channel].  The first
            three channels are taken as xyz coordinates; the remaining
            channels are extra per-point features.
        is_training: Training-mode flag forwarded to every layer.
        config: Settings object.  Reads ``normalize``, ``mlp``, ``radius``,
            ``nn_uplimit``, ``num_sample``, ``sample``, ``kernel``,
            ``channels``, ``binSize``, ``multiplier``, ``weight_decay``,
            ``with_bn``, ``with_bias``, ``pool_method``, ``unpool_method``
            and ``num_cls``.  The object is not modified.

    Returns:
        A ``(logits, end_points)`` tuple.  ``end_points['feats']`` holds the
        decoder output that is fed to the classifier.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        raise ValueError('get_model requires a config object, got None')

    end_points = {}
    xyz = points[:, :, 0:3]
    # Normalised coordinates are only used as input features; the
    # neighbourhood graphs below are built on the raw xyz.  Fall back to the
    # raw coordinates when normalisation is disabled (previously this path
    # hit an unassigned variable).
    feat_xyz = normalize_xyz(xyz) if config.normalize else xyz

    reuse = None
    net = tf.concat((feat_xyz, points[:, :, 3:]), axis=2)
    net = s3g_util.pointwise_conv3d(net, config.mlp, 'mlp1',
                                    weight_decay=config.weight_decay,
                                    with_bn=config.with_bn,
                                    with_bias=config.with_bias,
                                    reuse=reuse, is_training=is_training)

    xyz_layers = [xyz]
    encoder = []
    # ===============================================Encoder================================================
    for layer in range(len(config.radius)):
        intra_idx, intra_cnt, intra_dst, indices = s3g_util.build_graph(
            xyz, config.radius[layer], config.nn_uplimit[layer],
            config.num_sample[layer], sample_method=config.sample)
        filt_idx = s3g_util.spherical_kernel(xyz, xyz, intra_idx, intra_cnt,
                                             intra_dst, config.radius[layer],
                                             kernel=config.kernel)
        net = _separable_conv3d_block(net, config.channels[layer],
                                      config.binSize, intra_idx, intra_cnt,
                                      filt_idx, 'conv' + str(layer + 1),
                                      config.multiplier[layer], reuse=reuse,
                                      weight_decay=config.weight_decay,
                                      with_bn=config.with_bn,
                                      with_bias=config.with_bias,
                                      is_training=is_training)

        # Saved before pooling: this is the skip connection for the decoder.
        encoder.append(net)
        if config.num_sample[layer] > 1:
            # Keep only the sampled points and their neighbourhoods.
            xyz = tf.gather_nd(xyz, indices)
            xyz_layers.append(xyz)
            inter_idx = tf.gather_nd(intra_idx, indices)
            inter_cnt = tf.gather_nd(intra_cnt, indices)

            net = s3g_util.pool3d(net, inter_idx, inter_cnt,
                                  method=config.pool_method,
                                  scope='pool' + str(layer + 1))
    # ===============================================The End================================================

    # The decoder walks the layers coarse-to-fine.  Use reversed *copies* of
    # the per-layer settings: reversing the config lists in place would
    # corrupt the config for any later call that shares it.
    dec_radius = config.radius[::-1]
    dec_nn_uplimit = config.nn_uplimit[::-1]
    dec_channels = config.channels[::-1]
    dec_multiplier = config.multiplier[::-1]
    xyz_layers.reverse()
    encoder.reverse()
    # ===============================================Decoder================================================
    # NOTE: xyz_layers only grows when an encoder layer subsamples
    # (num_sample > 1), so indexing layer + 1 below requires every encoder
    # layer to have pooled.
    for layer in range(len(dec_radius)):
        xyz = xyz_layers[layer]
        xyz_unpool = xyz_layers[layer + 1]

        intra_idx, intra_cnt, intra_dst, \
        inter_idx, inter_cnt, inter_dst = s3g_util.build_graph_deconv(
            xyz, xyz_unpool, dec_radius[layer], dec_nn_uplimit[layer])
        filt_idx = s3g_util.spherical_kernel(xyz, xyz, intra_idx, intra_cnt,
                                             intra_dst, dec_radius[layer],
                                             kernel=config.kernel)
        net = _separable_conv3d_block(net, dec_channels[layer],
                                      config.binSize, intra_idx, intra_cnt,
                                      filt_idx, 'deconv' + str(layer + 1),
                                      dec_multiplier[layer], reuse=reuse,
                                      weight_decay=config.weight_decay,
                                      with_bn=config.with_bn,
                                      with_bias=config.with_bias,
                                      is_training=is_training)

        net = s3g_util.unpool3d(net, inter_idx, inter_cnt, inter_dst,
                                method=config.unpool_method,
                                scope='unpool' + str(layer + 1))

        # Skip connection from the matching encoder layer.
        net = tf.concat((net, encoder[layer]), axis=2)
    # ===============================================The End================================================
    end_points['feats'] = net

    # point-wise classifier: raw logits, no batch norm and no activation
    net = s3g_util.pointwise_conv3d(net, config.num_cls, scope='logits',
                                    with_bn=False,
                                    with_bias=config.with_bias,
                                    activation_fn=None,
                                    is_training=is_training)

    return net, end_points
Exemplo n.º 3
0
def get_model(points, is_training, config=None):
    """Build the regression network: one scalar prediction per sample.

    Pipeline: a point-wise layer ('mlp1') on the xyz coordinates, a stack of
    separable convolution blocks ('conv<i>') with optional pooling
    ('pool<i>'), a global convolution ('global_conv') centred on the mean
    point of the cloud, and three fully connected layers ('fc1', 'fc2',
    'logits') with dropout after the first two.  The max-pooled features of
    every encoder layer are concatenated with the global convolution output
    before the fully connected head.

    Args:
        points: Input tensor indexed as [batch, point, channel].  Only the
            first three channels (xyz) are used.
        is_training: Training-mode flag forwarded to every layer and to the
            dropout ops.
        config: Settings object.  Reads ``normalize``, ``mlp``, ``radius``,
            ``nn_uplimit``, ``num_sample``, ``sample``, ``kernel``,
            ``channels``, ``binSize``, ``multiplier``, ``use_raw``,
            ``global_channels``, ``global_multiplier``, ``weight_decay``,
            ``with_bn``, ``with_bias``, ``pool_method`` and ``num_cls``.
            ``num_cls`` must be 1, because the final logits are squeezed to
            one value per sample.

    Returns:
        A ``(prediction, end_points)`` tuple.  ``prediction`` has one entry
        per batch element; ``end_points`` is returned empty.

    Raises:
        ValueError: If ``config`` is None.
    """
    if config is None:
        raise ValueError('get_model requires a config object, got None')

    end_points = {}

    xyz = points[:, :, 0:3]
    if config.normalize:
        # Unlike a feature-only normalisation, the normalised coordinates
        # are also used to build the neighbourhood graphs below.
        xyz = normalize_xyz(xyz)

    # the global viewing point
    query = tf.reduce_mean(xyz, axis=1, keepdims=True)
    reuse = None
    net = s3g_util.pointwise_conv3d(xyz,
                                    config.mlp,
                                    'mlp1',
                                    weight_decay=config.weight_decay,
                                    with_bn=config.with_bn,
                                    with_bias=config.with_bias,
                                    reuse=reuse,
                                    is_training=is_training)

    global_feat = []
    for layer in range(len(config.radius)):
        if config.use_raw:
            # Re-inject the coordinates of the current resolution.
            net = tf.concat([net, xyz], axis=-1)

        intra_idx, intra_cnt, intra_dst, indices = s3g_util.build_graph(
            xyz,
            config.radius[layer],
            config.nn_uplimit[layer],
            config.num_sample[layer],
            sample_method=config.sample)
        filt_idx = s3g_util.spherical_kernel(xyz,
                                             xyz,
                                             intra_idx,
                                             intra_cnt,
                                             intra_dst,
                                             config.radius[layer],
                                             kernel=config.kernel)

        net = _separable_conv3d_block(net,
                                      config.channels[layer],
                                      config.binSize,
                                      intra_idx,
                                      intra_cnt,
                                      filt_idx,
                                      'conv' + str(layer + 1),
                                      config.multiplier[layer],
                                      reuse=reuse,
                                      weight_decay=config.weight_decay,
                                      with_bn=config.with_bn,
                                      with_bias=config.with_bias,
                                      is_training=is_training)

        if config.num_sample[layer] > 1:
            # Keep only the sampled points and their neighbourhoods.
            xyz = tf.gather_nd(xyz, indices)
            inter_idx = tf.gather_nd(intra_idx, indices)
            inter_cnt = tf.gather_nd(intra_cnt, indices)

            net = s3g_util.pool3d(net,
                                  inter_idx,
                                  inter_cnt,
                                  method=config.pool_method,
                                  scope='pool' + str(layer + 1))

        # One max-pooled descriptor per layer, kept as [batch, 1, channels].
        global_feat.append(tf.reduce_max(net, axis=1, keepdims=True))

    # =============================global feature extraction in the final layer=============================
    # global_radius(>=2.0) should connect all points to each point in the cloud
    global_radius = 100.0
    global_kernel = [8, 2, 1]
    # NOTE(review): 17 is presumably prod(global_kernel) + 1 bins -- confirm
    # against s3g_util before changing either value.
    global_bin_size = 17
    nn_idx, nn_cnt, nn_dst = s3g_util.build_global_graph(
        xyz, query, global_radius)
    filt_idx = s3g_util.spherical_kernel(xyz,
                                         query,
                                         nn_idx,
                                         nn_cnt,
                                         nn_dst,
                                         global_radius,
                                         kernel=global_kernel)
    net = s3g_util.separable_conv3d(net,
                                    config.global_channels,
                                    global_bin_size,
                                    config.global_multiplier,
                                    'global_conv',
                                    nn_idx,
                                    nn_cnt,
                                    filt_idx,
                                    reuse=reuse,
                                    weight_decay=config.weight_decay,
                                    with_bn=config.with_bn,
                                    with_bias=config.with_bias,
                                    is_training=is_training)
    global_feat.append(net)
    net = tf.concat(global_feat, axis=2)
    # =====================================================================================================

    # MLP on global point cloud vector.
    # Every concatenated piece has size 1 on axis 1 (keepdims max-pool), so
    # dropping that axis flattens to [batch, channels].  Squeezing instead of
    # reshaping to a static batch size also works when the batch dimension
    # is unknown at graph-construction time.
    net = tf.squeeze(net, axis=1)
    net = s3g_util.fully_connected(net,
                                   512,
                                   scope='fc1',
                                   weight_decay=config.weight_decay,
                                   with_bn=config.with_bn,
                                   with_bias=config.with_bias,
                                   is_training=is_training)
    net = tf.layers.dropout(net, 0.5, training=is_training, name='fc1_dp')
    net = s3g_util.fully_connected(net,
                                   256,
                                   scope='fc2',
                                   weight_decay=config.weight_decay,
                                   with_bn=config.with_bn,
                                   with_bias=config.with_bias,
                                   is_training=is_training)
    net = tf.layers.dropout(net, 0.5, training=is_training, name='fc2_dp')
    net = s3g_util.fully_connected(net,
                                   config.num_cls,
                                   scope='logits',
                                   with_bn=False,
                                   with_bias=config.with_bias,
                                   activation_fn=None,
                                   is_training=is_training)

    # One scalar per sample; fails loudly if config.num_cls is not 1.
    net = tf.squeeze(net, axis=1)
    return net, end_points