Esempio n. 1
0
def verts_pair_func(neighbor_feats, att_vec, B, max_o_grid, max_p_grid, N, C,
                    att_full=False, pt_mlp_lst=None, scope="pair_relation",
                    bn_decay=0.9):
    """Modulate neighbor features with a learned attention vector.

    Steps:
      1. If ``pt_mlp_lst`` is a non-empty list of widths, the neighbor features
         are run through that point-wise MLP. ``mlp2d_c`` is used when the
         module-level ``BN`` flag is set, ``fully_connected_mlp`` otherwise.
         ``C`` is then taken from the MLP's last width.
      2. If ``configs["attfdim"] > 0``, ``att_vec`` is mapped through a
         two-layer MLP with widths ``[C, C]`` (``att_full`` truthy) or
         ``[C // 4, C]``. The element-wise product ``att_vec * neighbor_feats``
         is returned. Otherwise the neighbor features are returned unchanged.

    :param neighbor_feats: neighbor feature symbol (presumably
        B * max_o_grid * max_p_grid * C, or channel-first when ``BN`` is set --
        TODO confirm).
    :param att_vec: attention input symbol.
    :param B, max_o_grid, max_p_grid, N: not used by this function.
    :param C: feature width; replaced by ``pt_mlp_lst[-1]`` when a point MLP runs.
    :param att_full: use a full-width (C) hidden attention layer instead of C // 4.
    :param pt_mlp_lst: optional widths of the point-wise MLP.
    :param scope: name prefix for the created layers.
    :param bn_decay: batch-norm decay forwarded to the MLP builders.
    :return: pair_feats symbol.
    """
    use_point_mlp = pt_mlp_lst is not None and len(pt_mlp_lst) > 0
    if use_point_mlp:
        C = pt_mlp_lst[-1]
        print("pt_mlp_lst:", pt_mlp_lst)
        if not BN:
            neighbor_feats = fully_connected_mlp(
                neighbor_feats, pt_mlp_lst, dim=4, bn_decay=bn_decay,
                dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True,
                bn_axis=1, attr=None, scope=scope + '/cov2fc', flip=False)
        else:
            neighbor_feats = mlp2d_c(neighbor_feats, pt_mlp_lst, use_bn=BN,
                                     bn_decay=bn_decay, scope=scope)

    # Attention disabled: the (possibly transformed) features pass straight through.
    if not (configs["attfdim"] > 0):
        return neighbor_feats

    hidden_width = C if att_full else C // 4
    if not BN:
        att_vec = fully_connected_mlp(
            att_vec, [hidden_width, C], dim=4, bn_decay=bn_decay,
            dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True,
            bn_axis=1, attr=None, scope=scope + '/update_att_mlp2dfc',
            flip=False)
    else:
        att_vec = mlp2d_c(att_vec, [hidden_width, C], use_bn=BN,
                          bn_decay=bn_decay, attr=None,
                          scope=scope + "/update_att_mlp2d")
    return att_vec * neighbor_feats
Esempio n. 2
0
def verts_pair_func(ori_neighbor_feats, att_vec, B, max_o_grid, max_p_grid, N, C, att_full = "off", pt_mlp_lst= None, att_ele_lst = None, contextvec = None, scope="pair_relation", bn_decay=0.9):
    """Weight (optionally MLP-transformed) neighbor features by a learned attention vector.

    :param ori_neighbor_feats: neighbor feature symbol (presumably
        B * C * max_o_grid * max_p_grid when ``BN`` is set, else
        B * max_o_grid * max_p_grid * C -- TODO confirm).
    :param att_vec: attention input symbol.
    :param B, max_o_grid, max_p_grid, N: not used by this function.
    :param C: feature width. It is overridden by ``att_ele_lst[-1] + configs["localfdim"]``
        and then by ``pt_mlp_lst[-1]`` when those lists are non-empty.
    :param att_full: "last" concatenates ``ori_neighbor_feats`` and "next"
        concatenates the MLP-transformed neighbor features to the attention
        input after its first layer. Any other value concatenates nothing.
    :param pt_mlp_lst: widths of the point MLP applied to the neighbor features.
        ``None`` or empty leaves them untouched.
    :param att_ele_lst: widths of the attention MLP. A local copy has its last
        width set to the final ``C`` so the attention output matches the
        features it multiplies. ``None`` or empty disables attention. The
        caller's list is never modified.
    :param contextvec: optional context symbol concatenated to the attention input.
    :param scope: name prefix for the created layers.
    :param bn_decay: batch-norm decay forwarded to the MLP builders.
    :return: ``att_vec * neighbor_feats`` when attention is enabled, otherwise
        the (possibly transformed) neighbor features.
    """
    # Copy so that overwriting the last width below never leaks into the caller's
    # list (a shared per-layer list would otherwise grow C on every call).
    att_ele_lst = [] if att_ele_lst is None else list(att_ele_lst)
    # Without a point MLP the raw features are used as-is (previously unbound).
    neighbor_feats = ori_neighbor_feats

    if len(att_ele_lst) > 0:
        C = att_ele_lst[-1] + configs["localfdim"]
    if pt_mlp_lst is not None and len(pt_mlp_lst) > 0:
        C = pt_mlp_lst[-1]
        print("pt_mlp_lst:", pt_mlp_lst)
        if BN:
            neighbor_feats = mlp2d_c(ori_neighbor_feats, pt_mlp_lst, use_bn=BN, bn_decay=bn_decay, scope=scope + '/cov2fc')
        else:
            neighbor_feats = fully_connected_mlp(ori_neighbor_feats, pt_mlp_lst, dim=4, bn_decay=bn_decay, dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None, scope=scope + '/cov2fc', flip=False)

    if not (len(att_ele_lst) > 0 and configs["attfdim"] > 0):
        return neighbor_feats

    # The attention output width must equal the feature width it multiplies.
    att_ele_lst[-1] = C

    # First attention layer.
    if BN:
        att_vec = mlp2d_c(att_vec, [att_ele_lst[0]], use_bn=BN, bn_decay=bn_decay, attr=None,
                          scope=scope + "/update_att_mlp2d_frst")
    else:
        att_vec = fully_connected_mlp(att_vec, [att_ele_lst[0]], dim=4, bn_decay=bn_decay, dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None, scope=scope + '/update_att_mlp2dfc_frst', flip=False)

    # Optional extra inputs to the remaining attention layers (same for both layouts).
    if att_full == "last":
        att_vec = mx.sym.concat(att_vec, ori_neighbor_feats, dim=C_dim, name=scope + "/att_full_concat")
        print("has attfull last")
    elif att_full == "next":
        att_vec = mx.sym.concat(att_vec, neighbor_feats, dim=C_dim, name=scope + "/att_full_concat")
        print("has attfull next")
    if contextvec is not None:
        att_vec = mx.sym.concat(att_vec, contextvec, dim=C_dim, name=scope + "/att_contextvec_concat")
        print("att_vec has contextvec")

    # Remaining attention layers; the last one outputs C channels.
    if BN:
        att_vec = mlp2d_c(att_vec, att_ele_lst[1:], use_bn=BN, bn_decay=bn_decay, attr=None, scope=scope + "/update_att_mlp2d_scnd")
    else:
        att_vec = fully_connected_mlp(att_vec, att_ele_lst[1:], dim=4, bn_decay=bn_decay, dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None, scope=scope + '/update_att_mlp2dfc_scnd', flip=False)

    return att_vec * neighbor_feats
Esempio n. 3
0
def mlp_mask_and_pool(data, mlp_list, max_p_grid, pool_type="max", use_bn=BN, bn_decay=0.9, mask=None, scope=''):
    """
    Apply a point-wise MLP, then pool over the neighbor axis.

    Input:
        data: neighbor feature Symbol. Channel-first (B, C, fps_npoints, rs_npoints)
            when the module-level ``BN`` flag is set, channel-last
            (B, fps_npoints, rs_npoints, C) otherwise. This is the same layout
            switch used by contextvec_func.
        mlp_list: list of int, output widths of the MLP layers
        max_p_grid: number of neighbors per center (pooling window)
        pool_type: "max" (mx.symbol.max) or "max_pooling" (mx.sym.Pooling)
        use_bn: forwarded to mlp2d_c in the channel-first layout
        bn_decay: decay parameter in batch normalization (used by both MLP paths)
        mask: optional neighbor mask. It is only applied for pool types other
            than "max"/"max_pooling", and every such type currently raises
            NotImplementedError below.
        scope: name prefix for the created layers
    Output:
        pooled_data: (B, mlp_list[-1], fps_npoints) channel-first or
            (B, fps_npoints, mlp_list[-1]) channel-last, Symbol
    Raises:
        NotImplementedError: for any pool_type other than "max"/"max_pooling".
    """
    # The neighbor axis is last in the channel-first layout and third in the
    # channel-last one. Reducing axis 3 unconditionally would pool over channels
    # in the channel-last layout.
    neighbor_axis = 3 if BN else 2

    # mlp
    if BN:
        data = mlp2d_c(data, mlp_list, use_bn=use_bn, bn_decay=bn_decay, scope=scope)
    else:
        # Honor the caller's bn_decay instead of silently reading configs["bn_decay"].
        data = fully_connected_mlp(data, mlp_list, dim=4, bn_decay=bn_decay, dropout_ratio=0,
                                   flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None,
                                   scope=scope, flip=False)

    # masking (only reachable for pool types that are rejected below)
    if mask is not None and pool_type != "max" and pool_type != "max_pooling":
        data = mx.sym.broadcast_mul(data, mask, axis=1, name="pairmask")

    # pooling
    if pool_type == "max_pooling":
        data = mx.sym.Pooling(name=scope + "/max_pooling", data=data,
                              kernel=(1, max_p_grid) if BN else (max_p_grid, 1),
                              pool_type="max")
        return mx.sym.squeeze(data, axis=neighbor_axis)
    if pool_type == "max":
        return mx.symbol.max(data, axis=neighbor_axis, name=scope + '/max')
    raise NotImplementedError("agg %s is no implemented" % pool_type)
Esempio n. 4
0
def contextvec_func(cntxt, max_p_grid, cntxt_mlp=None, bn_decay=0.9, scope=""):
    """Build a per-center context vector and broadcast it back over the neighbors.

    :param cntxt: neighbor feature symbol. The pooling kernel and tile reps below
        treat the last axis as the neighbor axis when the module-level ``BN`` flag
        is set, and the third axis otherwise.
    :param max_p_grid: number of neighbors per center.
    :param cntxt_mlp: widths of the MLP applied before pooling. ``None`` means the
        historical default ``[64, 64]``. An empty list skips the MLP.
    :param bn_decay: batch-norm decay forwarded to the MLP builder.
    :param scope: name prefix for the created layers.
    :return: the max over the neighbor axis, tiled ``max_p_grid`` times along
        that axis so it lines up with ``cntxt``.
    """
    # None sentinel instead of a shared mutable list default.
    if cntxt_mlp is None:
        cntxt_mlp = [64, 64]
    print("cntxt_mlp:", cntxt_mlp)
    if len(cntxt_mlp) > 0:
        if BN:
            cntxt = mlp2d_c(cntxt, cntxt_mlp, use_bn=BN, bn_decay=bn_decay, scope=scope + "/cnt")
        else:
            cntxt = fully_connected_mlp(cntxt, cntxt_mlp, dim=4, bn_decay=bn_decay, dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None, scope=scope + '/cnt/cov2fc', flip=False)

    # Max over the neighbor axis (window spans all max_p_grid neighbors)...
    aggregated_cntxt = mx.sym.Pooling(name=scope + "/cnt/max_pooling", data=cntxt,
                                      kernel=(1, max_p_grid) if BN else (max_p_grid, 1),
                                      pool_type="max", layout="NCHW")
    # ...then repeat it along the same axis so every neighbor sees its center's context.
    aggregated_cntxt = mx.sym.tile(aggregated_cntxt,
                                   reps=(1, 1, 1, max_p_grid) if BN else (1, 1, max_p_grid, 1))
    return aggregated_cntxt
def _sub_g_geo_descriptor(fdim, geo_dist, geo_vec, centers_expand_xyz,
                          neighbor_locs_xyz, geo_density, geo_global_density,
                          name):
    """Concatenate, along C_dim, the geometric terms selected by ``fdim``.

    fdim <= 3 -> geo_vec only (no concat)
    4  -> dist, vec
    5  -> dist, vec, global density
    10 -> dist, vec, center xyz, neighbor xyz
    11 -> dist, vec, center xyz, neighbor xyz, global density
    12 -> dist, vec, center xyz, neighbor xyz, local density, global density
    Any other value raises NotImplementedError. The density arguments must have
    been built by the caller whenever fdim requires them.
    """
    if fdim <= 3:
        return geo_vec
    if fdim == 4:
        parts = (geo_dist, geo_vec)
    elif fdim == 5:
        parts = (geo_dist, geo_vec, geo_global_density)
    elif fdim == 10:
        parts = (geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz)
    elif fdim == 11:
        parts = (geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                 geo_global_density)
    elif fdim == 12:
        parts = (geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                 geo_density, geo_global_density)
    else:
        raise NotImplementedError
    return mx.sym.concat(*parts, dim=C_dim, name=name)


def sub_g_update(centers_xyz,
                 center_den,
                 neighbors,
                 has_feats,
                 center_masks,
                 neighbor_masks,
                 attfdim,
                 center_ori_feats=None,
                 pt_mlp_lst=None,
                 outDim=[64, 256],
                 cntxt_mlp=None,
                 shape=[64, 800000, 100, 6],
                 scope='layer',
                 aggtype="gcn",
                 pool_type="max_pooling",
                 att_full="",
                 center_dim=[],
                 recalden=False,
                 bn_decay=0.9):
    '''
    One graph-update layer: RELU( MAX_POOLING ((geoRelation() * f_neib)) ).

    :param centers_xyz: center coordinates, expanded and tiled along P_dim to
        align with the neighbors.
    :param center_den: per-center density denominator. It is used only when
        ``recalden`` is False and a local density is needed.
    :param neighbors: neighbor symbol whose C_dim channels are laid out as
        [0:3] xyz, [3:4] a per-neighbor count used for the densities, and
        [4:] features (only read when ``has_feats``).
    :param has_feats: whether ``neighbors`` carries features beyond channel 4.
    :param center_masks: optional B * max_o_grid mask applied to the output.
    :param neighbor_masks: optional B * max_o_grid * max_p_grid mask.
    :param attfdim: geometric descriptor width fed to the attention
        (<= 3, 4, 5, 10, 11 or 12; see _sub_g_geo_descriptor).
    :param center_ori_feats: optional center features merged into the
        aggregated neighbor features via configs["up_center_inte"].
    :param pt_mlp_lst: point-MLP widths forwarded to verts_pair_func / aggregation_func.
    :param outDim: widths forwarded to update_func.
    :param cntxt_mlp: if not None, a context vector is built with contextvec_func.
    :param shape: (B, max_o_grid, max_p_grid, N, C).
    :param scope: name prefix for the created layers.
    :param aggtype: "gcn" (attention-weighted pairs then aggregation) or
        "agg_gcn" (aggregation over the raw neighbor features).
    :param pool_type: forwarded to aggregation_func.
    :param att_full: forwarded to verts_pair_func.
    :param center_dim: widths of the MLP applied to ``center_ori_feats``.
    :param recalden: recompute the local density denominator from the (masked)
        neighbor counts instead of using ``center_den``.
    :param bn_decay: batch-norm decay forwarded to every MLP builder.
    :return: output of update_func, multiplied by ``center_masks`` when given.
    :raises NotImplementedError: for unsupported attfdim / localfdim / aggtype.
    '''
    B, max_o_grid, max_p_grid, N, C = shape
    localfdim = configs["localfdim"]
    geo_dims = {attfdim, localfdim}
    # Tile pattern that repeats a per-center tensor along the neighbor axis.
    p_reps = (1, 1, 1, max_p_grid) if BN else (1, 1, max_p_grid, 1)

    neighbor_masks_expand = None
    if neighbor_masks is not None:
        neighbor_masks_expand = mx.sym.expand_dims(neighbor_masks,
                                                   axis=C_dim,
                                                   name=scope + "/concatmask")
    neighbor_feats = mx.sym.slice_axis(
        neighbors, axis=C_dim, begin=4, end=None) if has_feats else None
    # todo expand centers in Gridify
    centers_expand_xyz = mx.sym.tile(
        mx.sym.expand_dims(centers_xyz,
                           axis=P_dim,
                           name=scope + "/centerxyzexpand"),
        reps=p_reps)
    neighbor_locs_xyz = mx.sym.slice_axis(neighbors, axis=C_dim, begin=0, end=3)
    geo_vec = neighbor_locs_xyz - centers_expand_xyz
    geo_dist = mx.sym.sqrt(
        mx.sym.sum(mx.sym.square(geo_vec), axis=C_dim, keepdims=True))

    # Density terms are built whenever EITHER descriptor needs them. The former
    # max(attfdim, localfdim) test missed mixed cases such as attfdim=10 with
    # localfdim=5, leaving geo_global_density undefined (NameError).
    geo_density = None
    geo_global_density = None
    if geo_dims & {5, 11, 12}:
        geo_num = mx.sym.slice_axis(neighbors, axis=C_dim, begin=3, end=4)
        num_points = mx.sym.ones_like(
            geo_num, name=scope + "/oneslike") * configs["num_points"]
        geo_global_density = mx.sym.elemwise_div(geo_num,
                                                 num_points,
                                                 name=scope + "/den_glo_ele_div")
        if geo_dims & {5, 12}:
            if recalden:
                geo_num_masked = geo_num
                if neighbor_masks is not None:
                    geo_num_masked = mx.sym.elemwise_mul(
                        neighbor_masks_expand, geo_num)
                geo_den_sum = mx.sym.sum(geo_num_masked, axis=P_dim,
                                         keepdims=True)
                # Lower bound of 1 keeps the division below away from zero.
                geo_den_sum = mx.sym.clip(geo_den_sum, 1, 1000000)
            else:
                geo_den_sum = mx.sym.tile(
                    mx.sym.expand_dims(center_den, axis=P_dim,
                                       name=scope + "/centerdenexpand"),
                    reps=p_reps)
            geo_density = mx.sym.broadcast_div(geo_num,
                                               geo_den_sum,
                                               name=scope + "/den_ele_div")

    att_vec = _sub_g_geo_descriptor(attfdim, geo_dist, geo_vec,
                                    centers_expand_xyz, neighbor_locs_xyz,
                                    geo_density, geo_global_density,
                                    name=scope + "/attconcat")
    geo_feats = _sub_g_geo_descriptor(localfdim, geo_dist, geo_vec,
                                      centers_expand_xyz, neighbor_locs_xyz,
                                      geo_density, geo_global_density,
                                      name=scope + "/geo_concat")

    if neighbor_feats is None:
        # No input features: the geometric descriptor (optionally lifted by the
        # "elevation" MLP) becomes the neighbor feature.
        neighbor_feats = geo_feats
        if len(configs["elevation"]) > 0:
            if BN:
                neighbor_feats = mlp2d_c(neighbor_feats,
                                         configs["elevation"],
                                         use_bn=BN,
                                         bn_decay=bn_decay,
                                         scope=scope + "/elevation")
            else:
                neighbor_feats = fully_connected_mlp(neighbor_feats,
                                                     configs["elevation"],
                                                     dim=4,
                                                     bn_decay=bn_decay,
                                                     dropout_ratio=0,
                                                     flatten=False,
                                                     use_bn=BN,
                                                     use_relu=True,
                                                     bn_axis=1,
                                                     attr=None,
                                                     scope=scope + '/elevation/cov2fc',
                                                     flip=False)
    elif localfdim != 0:
        neighbor_feats = mx.symbol.concat(geo_feats,
                                          neighbor_feats,
                                          dim=C_dim,
                                          name=scope + "/neighbor_feats_concat")

    contextvec = None
    if cntxt_mlp is not None:
        contextvec = contextvec_func(neighbor_feats,
                                     max_p_grid,
                                     cntxt_mlp=cntxt_mlp,
                                     bn_decay=bn_decay,
                                     scope=scope)

    if aggtype == "gcn":
        pair_feats = verts_pair_func(
            neighbor_feats,
            att_vec,
            B,
            max_o_grid,
            max_p_grid,
            N,
            C,
            contextvec=contextvec,
            att_full=att_full,
            pt_mlp_lst=pt_mlp_lst,
            scope=scope,
            bn_decay=bn_decay)
        if neighbor_masks is not None and pool_type != "max" and pool_type != "max_pooling":
            pair_feats = mx.sym.broadcast_mul(pair_feats, neighbor_masks_expand,
                                              name="pairmask")
        agg_neigh_feats = aggregation_func(
            pair_feats,
            neighbor_masks,
            geo_dist,
            max_o_grid,
            max_p_grid,
            pool_type=pool_type,
            scope=scope + "/center_aggregation",
            bn_decay=bn_decay)
    elif aggtype == "agg_gcn":
        agg_neigh_feats = aggregation_func(neighbor_feats,
                                           neighbor_masks,
                                           geo_dist,
                                           max_o_grid,
                                           max_p_grid,
                                           pool_type=pool_type,
                                           mlp_lst=pt_mlp_lst,
                                           scope=scope + "/neigh_aggregation",
                                           bn_decay=bn_decay)
    else:
        raise NotImplementedError("aggtype %s is not implemented" % aggtype)

    if center_ori_feats is not None:
        if len(center_dim) > 0:
            if BN:
                center_feats = mlp1d_c(center_ori_feats,
                                       center_dim,
                                       bn_decay=bn_decay,
                                       use_bn=BN,
                                       attr=None,
                                       scope=scope + '/centerfeats_conv1d')
            elif configs["use_bn"] == 'p':
                center_feats = fully_connected_mlp_withbn(
                    center_ori_feats,
                    center_dim,
                    dim=3,
                    bn_decay=bn_decay,
                    dropout_ratio=0,
                    flatten=False,
                    use_bn=BN,
                    use_relu=True,
                    bn_axis=1,
                    attr=None,
                    scope=scope + '/centerfeats_conv1d1d',
                    flip=False)
            else:
                center_feats = fully_connected_mlp(center_ori_feats,
                                                   center_dim,
                                                   dim=3,
                                                   bn_decay=bn_decay,
                                                   dropout_ratio=0,
                                                   flatten=False,
                                                   use_bn=BN,
                                                   use_relu=True,
                                                   bn_axis=1,
                                                   attr=None,
                                                   scope=scope + '/centerfeats_conv1d1d',
                                                   flip=False)
        else:
            center_feats = center_ori_feats
        # Any other up_center_inte value leaves the center features unused.
        if configs["up_center_inte"] == "add":
            agg_neigh_feats = center_feats + agg_neigh_feats
        elif configs["up_center_inte"] == "concat":
            agg_neigh_feats = mx.symbol.concat(center_feats,
                                               agg_neigh_feats,
                                               dim=1 if BN else 2,
                                               name=scope + "/agg_feats_concat")

    agg_feats = update_func(agg_neigh_feats,
                            outDim=outDim,
                            scope=scope + "/center_update",
                            bn_decay=bn_decay)
    if center_masks is not None:
        agg_feats = mx.sym.broadcast_mul(
            agg_feats,
            mx.sym.expand_dims(center_masks,
                               axis=1 if BN else 2,
                               name=scope + "/maskexpand"),
            name=scope + "/centermask")

    return agg_feats