示例#1
0
def aggregation_func(masked_pair_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid,
                     pool_type = "max", mlp_lst = None, scope='center_aggregation', bn_decay=0.9):
    '''
    Reduce per-neighbor pair features over the neighbor (P) axis, then optionally apply an MLP.

    :param masked_pair_feats: pair features, B * max_o_grid * max_p_grid * C
        (presumably channel-first B * C * max_o_grid * max_p_grid when the module-level
        ``BN`` flag is set -- TODO confirm against callers).
    :param neighbor_masks: neighbor masks; only forwarded to ``interpolate_all`` for
        ``pool_type == "inter_sum"``.
    :param geo_dist: center-to-neighbor distances; only used for ``"inter_sum"``.
    :param max_o_grid: number of center cells (not used by this function).
    :param max_p_grid: number of neighbors per center; the pooling window for ``"max_pooling"``.
    :param pool_type: one of "max_pooling", "max", "sum", "inter_sum".
    :param mlp_lst: optional list of output widths for an MLP applied after aggregation;
        ``None`` skips the MLP. The MLP helper is chosen by ``BN`` / ``configs["use_bn"]``.
    :param scope: name prefix for every symbol created here.
    :param bn_decay: batch-norm decay forwarded to the MLP helpers.
    :return: aggregated features with the P axis reduced away.
    :raises NotImplementedError: if ``pool_type`` is not one of the supported values.
    '''
    if pool_type == "max_pooling":
        # The data is always given to Pooling as "NCHW"; the kernel orientation is flipped so
        # the window spans the P axis in both layouts (last axis when BN, third axis otherwise).
        aggregated_feats = mx.sym.Pooling(name=scope + "/max_pooling", data=masked_pair_feats,
                                          kernel=(1, max_p_grid) if BN else (max_p_grid, 1),
                                          pool_type="max", layout="NCHW")
        aggregated_feats = mx.sym.squeeze(aggregated_feats, axis=P_dim)
    elif pool_type == "max":
        aggregated_feats = mx.symbol.max(masked_pair_feats, axis=P_dim, name=scope + '/max')
    elif pool_type == "sum":
        aggregated_feats = mx.symbol.sum(masked_pair_feats, axis=P_dim, name=scope + '/sum')
    elif pool_type == "inter_sum":
        aggregated_feats = interpolate_all(geo_dist, masked_pair_feats, neighbor_masks, scope=scope+'/interp')  # (B, M, C1)
    else:
        raise NotImplementedError("pool_type %s is not implemented" % pool_type)

    if mlp_lst is not None:
        if BN:
            aggregated_feats = mlp1d_c(aggregated_feats, mlp_lst, bn_decay=bn_decay, use_bn=BN, attr=None,
                                       scope=scope + '/agg_mlp/cov1d')
        elif configs["use_bn"] == 'p':
            aggregated_feats = fully_connected_mlp_withbn(aggregated_feats, mlp_lst, dim=3, bn_decay=bn_decay,
                                                          dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True,
                                                          bn_axis=1, attr=None, scope=scope + '/agg_fc', flip=False)
        else:
            aggregated_feats = fully_connected_mlp(aggregated_feats, mlp_lst, dim=3, bn_decay=bn_decay,
                                                   dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True,
                                                   bn_axis=1, attr=None, scope=scope + '/agg_fc', flip=False)
    return aggregated_feats
def update_func(center_feats,
                outDim=None,
                scope="center_update",
                bn_decay=0.9):
    '''
    Update per-center features: optional pre-ReLU followed by an optional MLP.

    :param center_feats: B * max_o_grid * C (presumably B * C * max_o_grid when the
        module-level ``BN`` flag is set -- TODO confirm).
    :param outDim: list of MLP output widths, e.g. [128, 256, ...]. Defaults to [64, 256];
        an empty list skips the MLP entirely.
    :param scope: name prefix for every symbol created here.
    :param bn_decay: batch-norm decay forwarded to the MLP helpers.
    :return: center_feats  B * max_o_grid * outDim[-1]
    '''
    if outDim is None:
        # Fresh list per call instead of a shared mutable default argument.
        outDim = [64, 256]
    if configs["relu"]:
        center_feats = mx.symbol.relu(center_feats, name=scope + '/pre_relu')
    if len(outDim) == 0:
        return center_feats

    # The MLP implementation is selected by the module-level BN flag and configs["use_bn"].
    if BN:
        center_feats = mlp1d_c(center_feats,
                               outDim,
                               bn_decay=bn_decay,
                               use_bn=BN,
                               attr=None,
                               scope=scope + '/cov1d')
    elif configs["use_bn"] == 'p':
        center_feats = fully_connected_mlp_withbn(center_feats,
                                                  outDim,
                                                  dim=3,
                                                  bn_decay=bn_decay,
                                                  dropout_ratio=0,
                                                  flatten=False,
                                                  use_bn=BN,
                                                  use_relu=True,
                                                  bn_axis=1,
                                                  attr=None,
                                                  scope=scope + '/cov1dfc',
                                                  flip=False)
    else:
        center_feats = fully_connected_mlp(center_feats,
                                           outDim,
                                           dim=3,
                                           bn_decay=bn_decay,
                                           dropout_ratio=0,
                                           flatten=False,
                                           use_bn=BN,
                                           use_relu=True,
                                           bn_axis=1,
                                           attr=None,
                                           scope=scope + '/cov1dfc',
                                           flip=False)
    return center_feats
示例#3
0
def _concat_geo_terms(fdim, geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                      geo_density, geo_global_density, name):
    '''
    Build the geometric feature symbol selected by ``fdim`` (used by ``sub_g_update`` for
    both the attention vector and the local geometric features).

    fdim <= 3 -> geo_vec alone (no concat)
    fdim == 4 -> [dist, vec]
    fdim == 5 -> [dist, vec, global_density]
    fdim == 10 -> [dist, vec, center_xyz, neighbor_xyz]
    fdim == 11 -> [dist, vec, center_xyz, neighbor_xyz, global_density]
    fdim == 12 -> [dist, vec, center_xyz, neighbor_xyz, density, global_density]

    :raises NotImplementedError: for any other ``fdim``.
    '''
    if fdim <= 3:
        return geo_vec
    if fdim == 4:
        terms = [geo_dist, geo_vec]
    elif fdim == 5:
        terms = [geo_dist, geo_vec, geo_global_density]
    elif fdim == 10:
        terms = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz]
    elif fdim == 11:
        terms = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz, geo_global_density]
    elif fdim == 12:
        terms = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz, geo_density, geo_global_density]
    else:
        raise NotImplementedError("geometric feature dim %s is not implemented" % fdim)
    return mx.sym.concat(*terms, dim=C_dim, name=name)


def sub_g_update(centers_xyz, center_den, neighbors, has_feats, center_masks, neighbor_masks, attfdim,
        center_ori_feats = None, pt_mlp_lst = None, outDim=None, shape=None, scope='layer',
        aggtype="gcn", pool_type="max_pooling", att_full=False, center_dim=None, recalden=False, bn_decay=0.9):
    '''
    RELU( MAX_POOLING ((geoRelation() * f_neib)) )

    Aggregate neighbor information into each center and update the center features.

    :param centers_xyz: center coordinates, B * max_o_grid * 3 (presumably channel-first
        when the module-level ``BN`` flag is set -- TODO confirm).
    :param center_den: per-center density denominator; only used when a dim-12 geometric
        feature is requested and ``recalden`` is False.
    :param neighbors: B * max_o_grid * max_p_grid * (4+C). Channels 0..2 are neighbor xyz,
        channel 3 is a per-neighbor count used for the density terms, channels 4.. are
        features (only read when ``has_feats``).
    :param has_feats: whether ``neighbors`` carries feature channels beyond index 3.
    :param center_masks: B * max_o_grid; multiplied into the output when not None.
    :param neighbor_masks: B * max_o_grid * max_p_grid; used for the recomputed density and
        to zero pair features before non-max pooling, when not None.
    :param attfdim: selects the geometric terms of the attention vector (see _concat_geo_terms).
    :param center_ori_feats: optional original center features fused according to
        ``configs["up_center_inte"]`` ("add" or "concat").
    :param pt_mlp_lst: per-pair MLP widths ("gcn") or post-aggregation MLP widths ("agg_gcn").
    :param outDim: MLP widths for the final center update; defaults to [64, 256].
    :param shape: 5-element sequence unpacked as (B, max_o_grid, max_p_grid, N, C); required.
    :param scope: name prefix for every symbol created here.
    :param aggtype: "gcn" or "agg_gcn".
    :param pool_type: pooling mode forwarded to aggregation_func.
    :param att_full: also concatenate the neighbor features into the attention vector.
    :param center_dim: MLP widths applied to ``center_ori_feats`` before fusion; empty/None
        uses them as-is.
    :param recalden: recompute the local density denominator from (masked) neighbor counts
        instead of using ``center_den``.
    :param bn_decay: batch-norm decay forwarded to the MLP helpers.
    :return: updated center features, B * max_o_grid * outDim[-1]
        (B * outDim[-1] * max_o_grid when ``BN``, per the original annotations).
    :raises ValueError: if ``shape`` is missing or does not have 5 entries.
    :raises NotImplementedError: for unsupported attfdim / localfdim / aggtype / pool_type.
    '''
    if outDim is None:
        outDim = [64, 256]
    if center_dim is None:
        center_dim = []
    if shape is None:
        # The former default ([64, 800000, 100, 6]) had only four entries and always failed
        # to unpack into the five names below, so an explicit shape was already mandatory.
        raise ValueError("sub_g_update requires shape=(B, max_o_grid, max_p_grid, N, C)")
    B, max_o_grid, max_p_grid, N, C = shape

    localfdim = configs["localfdim"]
    # Every geometric feature dim requested by either the attention vector or the local features.
    used_fdims = {attfdim, localfdim}

    neighbor_masks_expand = None
    if neighbor_masks is not None:
        neighbor_masks_expand = mx.sym.expand_dims(neighbor_masks, axis=C_dim, name=scope+"/concatmask")
    neighbor_feats = mx.sym.slice_axis(neighbors, axis=C_dim, begin=4, end=None) if has_feats else None
    # todo expand centers in Gridify
    centers_expand_xyz = mx.sym.tile(mx.sym.expand_dims(centers_xyz, axis=P_dim, name=scope + "/centerxyzexpand"),
                                     reps=(1, 1, 1, max_p_grid) if BN else (1, 1, max_p_grid, 1))
    neighbor_locs_xyz = mx.sym.slice_axis(neighbors, axis=C_dim, begin=0, end=3)
    geo_vec = neighbor_locs_xyz - centers_expand_xyz  # B * 3 * max_o_grid * max_p_grid (BN layout)
    # keepdims=True keeps a singleton channel axis: B * 1 * O * P or B * O * P * 1.
    geo_dist = mx.sym.sqrt(mx.sym.sum(mx.sym.square(geo_vec), axis=C_dim, keepdims=True))

    # Density terms are built only when some requested dim uses them. The previous gate on
    # max(attfdim, localfdim) skipped them for mixes like attfdim=5 / localfdim=10 -> NameError.
    geo_density = None
    geo_global_density = None
    if used_fdims & {5, 11, 12}:
        geo_num = mx.sym.slice_axis(neighbors, axis=C_dim, begin=3, end=4)  # B * 1 * O * P or B * O * P * 1
        num_points = mx.sym.ones_like(geo_num, name=scope+"/oneslike") * configs["num_points"]
        geo_global_density = mx.sym.elemwise_div(geo_num, num_points, name=scope+"/den_glo_ele_div")
        if 12 in used_fdims:  # only dim 12 consumes the local density
            if recalden:
                geo_num_masked = geo_num
                if neighbor_masks_expand is not None:
                    geo_num_masked = mx.sym.elemwise_mul(neighbor_masks_expand, geo_num)
                geo_den_sum = mx.sym.sum(geo_num_masked, axis=P_dim, keepdims=True)  # B * 1 * O * 1 or B * O * 1 * 1
                geo_den_sum = mx.sym.clip(geo_den_sum, 1, 1000000)  # keep the denominator >= 1
            else:
                geo_den_sum = mx.sym.tile(mx.sym.expand_dims(center_den, axis=P_dim,
                    name=scope + "/centerdenexpand"), reps=(1, 1, 1, max_p_grid) if BN else (1, 1, max_p_grid, 1))
            geo_density = mx.sym.broadcast_div(geo_num, geo_den_sum, name=scope+"/den_ele_div")

    geo_terms = (geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz, geo_density, geo_global_density)
    att_vec = _concat_geo_terms(attfdim, *geo_terms, name=scope+"/attconcat")
    geo_feats = _concat_geo_terms(localfdim, *geo_terms, name=scope+"/geo_concat")

    if neighbor_feats is None:
        neighbor_feats = geo_feats
    elif localfdim != 0:
        neighbor_feats = mx.symbol.concat(geo_feats, neighbor_feats, dim=C_dim, name=scope+"/neighbor_feats_concat")

    if aggtype == "gcn":
        if att_full:
            att_vec = mx.sym.concat(att_vec, neighbor_feats, dim=C_dim, name=scope+"/att_full_concat")
        pair_feats = verts_pair_func(neighbor_feats, att_vec, B, max_o_grid,
            max_p_grid, N, C, att_full=att_full, pt_mlp_lst=pt_mlp_lst, scope=scope, bn_decay=bn_decay)  # B * max_o_grid * max_p_grid * C
        # Max-style pooling is left unmasked; other pooling modes need padded neighbors zeroed.
        if neighbor_masks_expand is not None and pool_type not in ("max", "max_pooling"):
            # Scoped name, like every other symbol here, so layers do not share a node name.
            pair_feats = mx.sym.broadcast_mul(pair_feats, neighbor_masks_expand, name=scope+"/pairmask")
        agg_neigh_feats = aggregation_func(pair_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid,
            pool_type=pool_type, scope=scope+"/center_aggregation", bn_decay=bn_decay)  # B * max_o_grid * 1 * outDim[-1]
    elif aggtype == "agg_gcn":
        agg_neigh_feats = aggregation_func(neighbor_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid,
            pool_type=pool_type, mlp_lst=pt_mlp_lst, scope=scope + "/neigh_aggregation", bn_decay=bn_decay)
    else:
        raise NotImplementedError("aggtype %s is not implemented" % aggtype)

    if center_ori_feats is not None:
        if len(center_dim) > 0:
            if BN:
                center_feats = mlp1d_c(center_ori_feats, center_dim, bn_decay=bn_decay, use_bn=BN, attr=None,
                                       scope=scope + '/centerfeats_conv1d')
            elif configs["use_bn"] == 'p':
                center_feats = fully_connected_mlp_withbn(center_ori_feats, center_dim, dim=3, bn_decay=bn_decay,
                                                          dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True,
                                                          bn_axis=1, attr=None,
                                                          scope=scope + '/centerfeats_conv1d1d', flip=False)
            else:
                center_feats = fully_connected_mlp(center_ori_feats, center_dim, dim=3, bn_decay=bn_decay,
                                                   dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True,
                                                   bn_axis=1, attr=None,
                                                   scope=scope + '/centerfeats_conv1d1d', flip=False)
        else:
            center_feats = center_ori_feats
        if configs["up_center_inte"] == "add":
            agg_neigh_feats = center_feats + agg_neigh_feats
        elif configs["up_center_inte"] == "concat":
            # Channel axis of the pooled (P-less) tensor: 1 when BN, 2 otherwise.
            agg_neigh_feats = mx.symbol.concat(center_feats, agg_neigh_feats, dim=1 if BN else 2,
                                               name=scope + "/agg_feats_concat")

    agg_feats = update_func(agg_neigh_feats, outDim=outDim, scope=scope + "/center_update", bn_decay=bn_decay)  # B * max_o_grid * outDim[-1]
    if center_masks is not None:
        agg_feats = mx.sym.broadcast_mul(agg_feats, mx.sym.expand_dims(center_masks, axis=1 if BN else 2,
                                         name=scope+"/maskexpand"), name=scope+"/centermask")  # B * outDim[-1] * max_o_grid
    return agg_feats