def contextvec_func(cntxt, max_p_grid, cntxt_mlp=[64, 64], bn_decay=0.9, scope=""):
    '''
    Max-pool a 4-D feature symbol with a kernel of extent max_p_grid along one
    axis, then tile the pooled result max_p_grid times along that same axis.

    An MLP is applied first when cntxt_mlp is non-empty.  The module-level BN
    flag selects both the MLP helper and which axis is pooled/tiled (the last
    axis when BN is set, the third axis otherwise).

    :param cntxt: input symbol (NOTE(review): presumably the per-neighbor
        features -- confirm the exact layout against the caller)
    :param max_p_grid: pooling kernel extent and tile repeat count
    :param cntxt_mlp: output widths for the optional MLP, [] skips the MLP
    :param bn_decay: forwarded to the MLP helper
    :param scope: prefix for the names of the created symbols
    :return: the pooled and tiled symbol
    '''
    print("cntxt_mlp:", cntxt_mlp)
    cnt_scope = scope + "/cnt"

    if len(cntxt_mlp) != 0:
        if not BN:
            cntxt = fully_connected_mlp(cntxt, cntxt_mlp, dim=4, bn_decay=bn_decay,
                                        dropout_ratio=0, flatten=False, use_bn=BN,
                                        use_relu=True, bn_axis=1, attr=None,
                                        scope=cnt_scope + '/cov2fc', flip=False)
        else:
            cntxt = mlp2d_c(cntxt, cntxt_mlp, use_bn=BN, bn_decay=bn_decay,
                            scope=cnt_scope)

    # Kernel and tile pattern always address the same axis.
    if BN:
        window, repeats = (1, max_p_grid), (1, 1, 1, max_p_grid)
    else:
        window, repeats = (max_p_grid, 1), (1, 1, max_p_grid, 1)

    pooled = mx.sym.Pooling(name=cnt_scope + "/max_pooling", data=cntxt,
                            kernel=window, pool_type="max", layout="NCHW")
    return mx.sym.tile(pooled, reps=repeats)
def update_func(center_feats, outDim=[64, 256], scope="center_update", bn_decay=0.9):
    '''
    Apply the optional pre-activation and the update MLP to the center features.

    A ReLU is inserted first when configs["relu"] is truthy.  When outDim is
    empty no MLP is built and the (possibly rectified) input is returned as is.

    :param center_feats: input symbol (B * max_o_grid * C per the original notes)
    :param outDim: output widths of the MLP layers, e.g. [128, 256, ...]
    :param scope: prefix for the names of the created symbols
    :param bn_decay: forwarded to the MLP helper
    :return: updated features (B * max_o_grid * outDim[-1] per the original notes)
    '''
    if configs["relu"]:
        center_feats = mx.symbol.relu(center_feats, name=scope + '/pre_relu')

    # Nothing to build when no layer widths are requested.
    if len(outDim) == 0:
        return center_feats

    if BN:
        return mlp1d_c(center_feats, outDim, bn_decay=bn_decay, use_bn=BN,
                       attr=None, scope=scope + '/cov1d')

    return fully_connected_mlp(center_feats, outDim, dim=3, bn_decay=bn_decay,
                               dropout_ratio=0, flatten=False, use_bn=BN,
                               use_relu=True, bn_axis=1, attr=None,
                               scope=scope + '/cov1dfc', flip=False)
def sub_g_update(centers_xyz, center_den, neighbors, has_feats, center_masks, neighbor_masks, attfdim,
                 center_ori_feats=None, pt_mlp_lst=None, outDim=[64, 256], shape=[64, 800000, 100, 6],
                 scope='layer', aggtype="gcn", pool_type="max_pooling", att_full=False, center_dim=[],
                 recalden=False, bn_decay=0.9):
    '''
    RELU( MAX_POOLING ((geoRelation() * f_neib)) )

    Builds the symbols for one center-update layer: geometric relation vectors
    between every center and its neighbors, pairwise features, aggregation over
    the neighbor axis, optional fusion with the centers' own features and the
    final update MLP.

    NOTE(review): this file defines sub_g_update a second time further down; in
    a single module the later definition replaces this one -- confirm which one
    callers are meant to get.

    :param centers_xyz: center coordinates (B * max_o_grid * 3 per the original notes)
    :param center_den: per-center density, only read when recalden is False and
        the density layout is being built
    :param neighbors: neighbor tensor; along C_dim channels 0..2 are read as
        coordinates, channel 3 as a point count and channels 4.. as features
    :param has_feats: whether channels 4.. of neighbors are used as features
    :param center_masks: B * max_o_grid mask or None
    :param neighbor_masks: B * max_o_grid * max_p_grid mask or None
    :param attfdim: layout width of the attention vector (<=3, 4, 5, 10, 11 or 12)
    :param center_ori_feats: optional center features merged into the aggregate
        according to configs["up_center_inte"] ("add" or "concat")
    :param pt_mlp_lst: MLP widths forwarded to verts_pair_func / aggregation_func
    :param outDim: widths of the update MLP
    :param shape: [B, max_o_grid, max_p_grid, N, C]; exactly five entries are
        unpacked, so the four-entry default cannot be used as is
    :param scope: prefix for the names of the created symbols
    :param aggtype: "gcn" or "agg_gcn"
    :param pool_type: forwarded to aggregation_func
    :param att_full: when truthy the neighbor features are appended to the
        attention vector; the value is also forwarded to verts_pair_func
    :param center_dim: MLP widths applied to center_ori_feats, [] leaves them untouched
    :param recalden: recompute the density denominator from the (masked)
        neighbor counts instead of using center_den
    :param bn_decay: forwarded to the MLP helpers
    :return: updated center features, multiplied by center_masks when given
    :raises NotImplementedError: unsupported attfdim / configs["localfdim"] / aggtype
    '''
    B, max_o_grid, max_p_grid, N, C = shape
    localfdim = configs["localfdim"]
    feature_dims = (attfdim, localfdim)
    # Repeat pattern that copies a per-center value to every neighbor slot.
    tile_reps = (1, 1, 1, max_p_grid) if BN else (1, 1, max_p_grid, 1)

    if neighbor_masks is not None:
        neighbor_masks_expand = mx.sym.expand_dims(neighbor_masks, axis=C_dim, name=scope+"/concatmask")
    neighbor_feats = mx.sym.slice_axis(neighbors, axis=C_dim, begin=4, end=None) if has_feats else None

    # todo expand centers in Gridify
    centers_expand_xyz = mx.sym.tile(
        mx.sym.expand_dims(centers_xyz, axis=P_dim, name=scope + "/centerxyzexpand"),
        reps=tile_reps)
    neighbor_locs_xyz = mx.sym.slice_axis(neighbors, axis=C_dim, begin=0, end=3)
    geo_vec = neighbor_locs_xyz - centers_expand_xyz
    geo_dist = mx.sym.sqrt(mx.sym.sum(mx.sym.square(geo_vec), axis=C_dim, keepdims=True))

    # The 5/11/12 layouts need the global density.  Either width may ask for it;
    # testing only max(attfdim, localfdim) missed pairs such as (5, 10) and left
    # geo_global_density unbound when the 5-wide layout was assembled.
    if any(fdim in (5, 11, 12) for fdim in feature_dims):
        geo_num = mx.sym.slice_axis(neighbors, axis=C_dim, begin=3, end=4)
        num_points = mx.sym.ones_like(geo_num, name=scope+"/oneslike") * configs["num_points"]
        geo_global_density = mx.sym.elemwise_div(geo_num, num_points, name=scope+"/den_glo_ele_div")

    # Original guard kept, plus the case where either width is 12 (the only
    # layout that consumes geo_density).
    if max(feature_dims) in (5, 12) or 12 in feature_dims:
        if recalden:
            geo_num_masked = geo_num
            if neighbor_masks is not None:
                geo_num_masked = mx.sym.elemwise_mul(neighbor_masks_expand, geo_num)
            geo_den_sum = mx.sym.sum(geo_num_masked, axis=P_dim, keepdims=True)
            # make sure it's bigger than zero
            geo_den_sum = mx.sym.clip(geo_den_sum, 1, 1000000)
        else:
            geo_den_sum = mx.sym.tile(
                mx.sym.expand_dims(center_den, axis=P_dim, name=scope + "/centerdenexpand"),
                reps=tile_reps)
        geo_density = mx.sym.broadcast_div(geo_num, geo_den_sum, name=scope+"/den_ele_div")

    def _geo_layout(fdim, name):
        # Assemble the geometric vector for one supported layout width.
        if fdim <= 3:
            return geo_vec
        if fdim == 4:
            parts = [geo_dist, geo_vec]
        elif fdim == 5:
            parts = [geo_dist, geo_vec, geo_global_density]
        elif fdim == 10:
            parts = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz]
        elif fdim == 11:
            parts = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                     geo_global_density]
        elif fdim == 12:
            parts = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                     geo_density, geo_global_density]
        else:
            raise NotImplementedError
        return mx.sym.concat(*parts, dim=C_dim, name=name)

    att_vec = _geo_layout(attfdim, scope+"/attconcat")
    geo_feats = _geo_layout(localfdim, scope+"/geo_concat")

    if neighbor_feats is None:
        neighbor_feats = geo_feats
    elif localfdim != 0:
        neighbor_feats = mx.symbol.concat(geo_feats, neighbor_feats, dim=C_dim,
                                          name=scope+"/neighbor_feats_concat")

    if aggtype == "gcn":
        if att_full:
            att_vec = mx.sym.concat(att_vec, neighbor_feats, dim=C_dim, name=scope+"/att_full_concat")
        pair_feats = verts_pair_func(neighbor_feats, att_vec, B, max_o_grid, max_p_grid, N, C,
                                     att_full=att_full, pt_mlp_lst=pt_mlp_lst, scope=scope,
                                     bn_decay=bn_decay)
        # Masking is only applied for the non-max pooling types.
        if neighbor_masks is not None and pool_type != "max" and pool_type != "max_pooling":
            # NOTE(review): this name carries no scope prefix, unlike the other
            # symbols built here -- confirm that is intended.
            pair_feats = mx.sym.broadcast_mul(pair_feats, neighbor_masks_expand, name="pairmask")
        agg_neigh_feats = aggregation_func(pair_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid,
                                           pool_type=pool_type, scope=scope+"/center_aggregation",
                                           bn_decay=bn_decay)
    elif aggtype == "agg_gcn":
        agg_neigh_feats = aggregation_func(neighbor_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid,
                                           pool_type=pool_type, mlp_lst=pt_mlp_lst,
                                           scope=scope + "/neigh_aggregation", bn_decay=bn_decay)
    else:
        raise NotImplementedError("aggtype %s is not implemented" % aggtype)

    if center_ori_feats is not None:
        if len(center_dim) > 0:
            if BN:
                center_feats = mlp1d_c(center_ori_feats, center_dim, bn_decay=bn_decay, use_bn=BN,
                                       attr=None, scope=scope + '/centerfeats_conv1d')
            elif configs["use_bn"] == 'p':
                center_feats = fully_connected_mlp_withbn(
                    center_ori_feats, center_dim, dim=3, bn_decay=bn_decay, dropout_ratio=0,
                    flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None,
                    scope=scope + '/centerfeats_conv1d1d', flip=False)
            else:
                center_feats = fully_connected_mlp(
                    center_ori_feats, center_dim, dim=3, bn_decay=bn_decay, dropout_ratio=0,
                    flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None,
                    scope=scope + '/centerfeats_conv1d1d', flip=False)
        else:
            center_feats = center_ori_feats

        # Any other configs["up_center_inte"] value leaves the aggregate unchanged.
        if configs["up_center_inte"] == "add":
            agg_neigh_feats = center_feats + agg_neigh_feats
        elif configs["up_center_inte"] == "concat":
            agg_neigh_feats = mx.symbol.concat(center_feats, agg_neigh_feats, dim=1 if BN else 2,
                                               name=scope + "/agg_feats_concat")

    agg_feats = update_func(agg_neigh_feats, outDim=outDim, scope=scope + "/center_update",
                            bn_decay=bn_decay)
    if center_masks is not None:
        agg_feats = mx.sym.broadcast_mul(
            agg_feats,
            mx.sym.expand_dims(center_masks, axis=1 if BN else 2, name=scope+"/maskexpand"),
            name=scope+"/centermask")
    return agg_feats
def aggregation_func(masked_pair_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid, pool_type="max",
                     mlp_lst=None, scope='center_aggregation', bn_decay=0.9):
    '''
    Reduce the pairwise features over the neighbor axis and optionally run an
    MLP on the reduced result.

    :param masked_pair_feats: B * max_o_grid * max_p_grid * C (per the original notes)
    :param neighbor_masks: only read by the "inter_sum" pooling
    :param geo_dist: only read by the "inter_sum" pooling
    :param max_o_grid: not read in this function
    :param max_p_grid: kernel extent of the "max_pooling" variant
    :param pool_type: one of "max_pooling", "max", "sum", "inter_sum"
    :param mlp_lst: MLP widths applied after the reduction, None skips the MLP
    :param scope: prefix for the names of the created symbols
    :param bn_decay: forwarded to the MLP helper
    :return: aggregated features
    :raises NotImplementedError: for any other pool_type
    '''
    if pool_type == "max_pooling":
        window = (1, max_p_grid) if BN else (max_p_grid, 1)
        pooled = mx.sym.Pooling(name=scope + "/max_pooling", data=masked_pair_feats,
                                kernel=window, pool_type="max", layout="NCHW")
        aggregated_feats = mx.sym.squeeze(pooled, axis=P_dim)
    elif pool_type == "max":
        aggregated_feats = mx.symbol.max(masked_pair_feats, axis=P_dim, name=scope + '/max')
    elif pool_type == "sum":
        aggregated_feats = mx.symbol.sum(masked_pair_feats, axis=P_dim, name=scope + '/sum')
    elif pool_type == "inter_sum":
        aggregated_feats = interpolate_all(geo_dist, masked_pair_feats, neighbor_masks,
                                           scope=scope + '/interp')
    else:
        raise NotImplementedError

    if mlp_lst is None:
        return aggregated_feats

    if BN:
        return mlp1d_c(aggregated_feats, mlp_lst, bn_decay=bn_decay, use_bn=BN, attr=None,
                       scope=scope + '/agg_mlp/cov1d')

    # Both fully connected variants take exactly the same arguments.
    if configs["use_bn"] == 'p':
        build_mlp = fully_connected_mlp_withbn
    else:
        build_mlp = fully_connected_mlp
    return build_mlp(aggregated_feats, mlp_lst, dim=3, bn_decay=bn_decay, dropout_ratio=0,
                     flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None,
                     scope=scope + '/agg_fc', flip=False)
def verts_pair_func(ori_neighbor_feats, att_vec, B, max_o_grid, max_p_grid, N, C, att_full="",
                    pt_mlp_lst=None, contextvec=None, scope="pair_relation", bn_decay=0.9):
    '''
    Build the pairwise features: neighbor features, optionally run through an
    MLP, multiplied by a learned attention vector.

    When configs["attfdim"] is not positive no attention is built and the
    neighbor features are returned directly.

    :param ori_neighbor_feats: B * max_o_grid * max_p_grid * C (per the original notes)
    :param att_vec: geometric attention input, mapped to C // 4 and then C channels
    :param B: not read in this function
    :param max_o_grid: not read in this function
    :param max_p_grid: not read in this function
    :param N: not read in this function
    :param C: attention output width; replaced by pt_mlp_lst[-1] when an MLP is requested
    :param att_full: "last" appends ori_neighbor_feats, "next" appends the MLP
        output to the attention vector between the two attention layers; any
        other value appends nothing
    :param pt_mlp_lst: MLP widths for the neighbor features; None or [] keeps
        ori_neighbor_feats unchanged
    :param contextvec: optional symbol appended to the attention vector
    :param scope: prefix for the names of the created symbols
    :param bn_decay: forwarded to the MLP helpers
    :return: pairwise features (B * max_o_grid * max_p_grid * C per the original notes)
    '''
    # Fall back to the raw features.  Previously neighbor_feats was only bound
    # inside the MLP branch, so the default pt_mlp_lst=None (or an empty list)
    # ended in an UnboundLocalError further down.
    neighbor_feats = ori_neighbor_feats
    if pt_mlp_lst is not None and len(pt_mlp_lst) > 0:
        C = pt_mlp_lst[-1]
        print("pt_mlp_lst:", pt_mlp_lst)
        if BN:
            neighbor_feats = mlp2d_c(ori_neighbor_feats, pt_mlp_lst, use_bn=BN, bn_decay=bn_decay,
                                     scope=scope)
        else:
            neighbor_feats = fully_connected_mlp(ori_neighbor_feats, pt_mlp_lst, dim=4,
                                                 bn_decay=bn_decay, dropout_ratio=0, flatten=False,
                                                 use_bn=BN, use_relu=True, bn_axis=1, attr=None,
                                                 scope=scope + '/cov2fc', flip=False)

    if configs["attfdim"] <= 0:
        return neighbor_feats

    def _att_layer(data, width, bn_suffix, fc_suffix):
        # One attention layer; BN picks the helper and the scope suffix.
        if BN:
            return mlp2d_c(data, [width], use_bn=BN, bn_decay=bn_decay, attr=None,
                           scope=scope + bn_suffix)
        return fully_connected_mlp(data, [width], dim=4, bn_decay=bn_decay, dropout_ratio=0,
                                   flatten=False, use_bn=BN, use_relu=True, bn_axis=1, attr=None,
                                   scope=scope + fc_suffix, flip=False)

    att_vec = _att_layer(att_vec, C // 4, "/update_att_mlp2d_frst", '/update_att_mlp2dfc_frst')

    if att_full == "last":
        att_vec = mx.sym.concat(att_vec, ori_neighbor_feats, dim=C_dim, name=scope + "/att_full_concat")
        print("has attfull last")
    elif att_full == "next":
        att_vec = mx.sym.concat(att_vec, neighbor_feats, dim=C_dim, name=scope + "/att_full_concat")
        print("has attfull next")

    if contextvec is not None:
        att_vec = mx.sym.concat(att_vec, contextvec, dim=C_dim, name=scope + "/att_contextvec_concat")
        print("att_vec has contextvec")

    att_vec = _att_layer(att_vec, C, "/update_att_mlp2d_scnd", '/update_att_mlp2dfc_scnd')
    return att_vec * neighbor_feats
def sub_g_update(centers_xyz, center_den, neighbors, has_feats, center_masks, neighbor_masks, attfdim,
                 pt_mlp_lst=None, att_ele_lst=None, outDim=[64, 256], cntxt_mlp=[],
                 shape=[64, 800000, 100, 6], scope='layer', aggtype="gcn", pool_type="max_pooling",
                 att_full="off", recalden=False, bn_decay=0.9):
    '''
    RELU( MAX_POOLING ((geoRelation() * f_neib)) )

    Builds the symbols for one center-update layer: geometric relation vectors,
    optional elevation MLP and context vector, pairwise features, aggregation
    over the neighbor axis and the final update MLP.

    NOTE(review): an earlier definition of sub_g_update appears above in this
    file; within one module this later definition is the one that stays bound.

    :param centers_xyz: center coordinates (B * max_o_grid * 3 per the original notes)
    :param center_den: per-center density, only read when recalden is False and
        the density layout is being built
    :param neighbors: neighbor tensor; along C_dim channels 0..2 are read as
        coordinates, channel 3 as a point count and channels 4.. as features
    :param has_feats: whether channels 4.. of neighbors are used as features
    :param center_masks: B * max_o_grid mask or None
    :param neighbor_masks: B * max_o_grid * max_p_grid mask or None
    :param attfdim: only consulted when deciding whether to build the density
        term; the attention layout itself is chosen by configs["attfdim"]
    :param pt_mlp_lst: MLP widths forwarded to verts_pair_func
    :param att_ele_lst: forwarded to verts_pair_func as a keyword only when it
        is not None.  The verts_pair_func defined in this file declares no such
        keyword, so supplying a value raises TypeError there.
    :param outDim: widths of the update MLP
    :param cntxt_mlp: widths forwarded to contextvec_func; None disables the
        context vector, [] still builds one without an MLP
    :param shape: [B, max_o_grid, max_p_grid, N, C]; exactly five entries are
        unpacked, so the four-entry default cannot be used as is
    :param scope: prefix for the names of the created symbols
    :param aggtype: only "gcn" is supported
    :param pool_type: forwarded to aggregation_func
    :param att_full: forwarded to verts_pair_func
    :param recalden: recompute the density denominator from the (masked)
        neighbor counts instead of using center_den
    :param bn_decay: forwarded to the MLP helpers
    :return: updated center features, multiplied by center_masks when given
    :raises NotImplementedError: unsupported configs["attfdim"] /
        configs["localfdim"] / aggtype
    '''
    # shape_array
    B, max_o_grid, max_p_grid, N, C = shape
    att_dim = configs["attfdim"]
    local_dim = configs["localfdim"]
    layout_dims = (att_dim, local_dim)
    # Repeat pattern that copies a per-center value to every neighbor slot.
    tile_reps = (1, 1, 1, max_p_grid) if BN else (1, 1, max_p_grid, 1)

    if neighbor_masks is not None:
        neighbor_masks_expand = mx.sym.expand_dims(neighbor_masks, axis=C_dim, name=scope + "/concatmask")
    neighbor_feats = mx.sym.slice_axis(neighbors, axis=C_dim, begin=4, end=None) if has_feats else None

    centers_expand_xyz = mx.sym.tile(
        mx.sym.expand_dims(centers_xyz, axis=P_dim, name=scope + "/centerxyzexpand"),
        reps=tile_reps)
    neighbor_locs_xyz = mx.sym.slice_axis(neighbors, axis=C_dim, begin=0, end=3)
    geo_vec = neighbor_locs_xyz - centers_expand_xyz
    geo_dist = mx.sym.sqrt(mx.sym.sum(mx.sym.square(geo_vec), axis=C_dim, keepdims=True))

    # The 5/11/12 layouts need the global density.  The original max() guard is
    # kept, and the test is widened to either width: pairs such as (5, 10)
    # previously left geo_global_density unbound.
    needs_global_density = any(fdim in (5, 11, 12) for fdim in layout_dims)
    if needs_global_density or max(layout_dims) in (4, 5, 11, 12):
        geo_num = mx.sym.slice_axis(neighbors, axis=C_dim, begin=3, end=4)
        num_points = mx.sym.ones_like(geo_num, name=scope + "/oneslike") * configs["num_points"]
        geo_global_density = mx.sym.elemwise_div(geo_num, num_points, name=scope + "/den_glo_ele_div")

    # Original guard (based on the attfdim argument) kept, plus the case where
    # either configured width is 12 -- the only layout that consumes
    # geo_density -- so a mismatch between the argument and configs["attfdim"]
    # no longer leaves it unbound.
    if max(attfdim, local_dim) in (5, 12) or 12 in layout_dims:
        if recalden:
            geo_num_masked = geo_num
            if neighbor_masks is not None:
                geo_num_masked = mx.sym.elemwise_mul(neighbor_masks_expand, geo_num)
            geo_den_sum = mx.sym.sum(geo_num_masked, axis=P_dim, keepdims=True)
            # make sure it's bigger than zero
            geo_den_sum = mx.sym.clip(geo_den_sum, 1, 1000000)
        else:
            geo_den_sum = mx.sym.tile(
                mx.sym.expand_dims(center_den, axis=P_dim, name=scope + "/centerdenexpand"),
                reps=tile_reps)
        geo_density = mx.sym.broadcast_div(geo_num, geo_den_sum, name=scope + "/den_ele_div")

    def _geo_layout(fdim):
        # Assemble the geometric vector for one supported layout width.
        if fdim <= 3:
            return geo_vec
        if fdim == 4:
            parts = [geo_dist, geo_vec]
        elif fdim == 5:
            parts = [geo_dist, geo_vec, geo_global_density]
        elif fdim == 10:
            parts = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz]
        elif fdim == 11:
            parts = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                     geo_global_density]
        elif fdim == 12:
            parts = [geo_dist, geo_vec, centers_expand_xyz, neighbor_locs_xyz,
                     geo_density, geo_global_density]
        else:
            raise NotImplementedError
        return mx.sym.concat(*parts, dim=C_dim)

    att_vec = _geo_layout(att_dim)
    geo_feats = _geo_layout(local_dim)

    if neighbor_feats is None:
        neighbor_feats = geo_feats
        if len(configs["elevation"]) > 0:
            if BN:
                neighbor_feats = mlp2d_c(neighbor_feats, configs["elevation"], use_bn=BN,
                                         bn_decay=bn_decay, scope=scope + "/elevation")
            else:
                neighbor_feats = fully_connected_mlp(
                    neighbor_feats, configs["elevation"], dim=4, bn_decay=bn_decay,
                    dropout_ratio=0, flatten=False, use_bn=BN, use_relu=True, bn_axis=1,
                    attr=None, scope=scope + '/elevation/cov2fc', flip=False)
    elif local_dim != 0:
        neighbor_feats = mx.symbol.concat(geo_feats, neighbor_feats, dim=C_dim,
                                          name=scope + "/neighbor_feats_concat")

    contextvec = None
    if cntxt_mlp is not None:
        contextvec = contextvec_func(neighbor_feats, max_p_grid, cntxt_mlp=cntxt_mlp,
                                     bn_decay=bn_decay, scope=scope)

    if aggtype != "gcn":
        raise NotImplementedError("aggtype %s is not implemented" % aggtype)

    pair_kwargs = dict(att_full=att_full, pt_mlp_lst=pt_mlp_lst, contextvec=contextvec,
                       scope=scope, bn_decay=bn_decay)
    if att_ele_lst is not None:
        # Forward only on request so the default call matches verts_pair_func
        # as it is defined in this file.
        pair_kwargs["att_ele_lst"] = att_ele_lst
    pair_feats = verts_pair_func(neighbor_feats, att_vec, B, max_o_grid, max_p_grid, N, C,
                                 **pair_kwargs)

    # Masking is only applied for the non-max pooling types.
    if neighbor_masks is not None and pool_type != "max" and pool_type != "max_pooling":
        # NOTE(review): this name carries no scope prefix, unlike the other
        # symbols built here -- confirm that is intended.
        pair_feats = mx.sym.broadcast_mul(pair_feats, neighbor_masks_expand, name="pairmask")

    # aggregation_func in this file takes neighbor_masks and geo_dist before
    # the grid sizes; the previous three-positional call could not bind.
    agg_neigh_feats = aggregation_func(pair_feats, neighbor_masks, geo_dist, max_o_grid, max_p_grid,
                                       pool_type=pool_type, scope=scope + "/center_aggregation",
                                       bn_decay=bn_decay)
    agg_feats = update_func(agg_neigh_feats, outDim=outDim, scope=scope + "/center_update",
                            bn_decay=bn_decay)
    if center_masks is not None:
        agg_feats = mx.sym.broadcast_mul(
            agg_feats,
            mx.sym.expand_dims(center_masks, axis=1 if BN else 2, name=scope + "/maskexpand"),
            name=scope + "/centermask")
    return agg_feats