def indep_energy_block(x, ccoords, beta, x_row_splits):
    # cut all gradients: the energy block is trained independently
    # of the clustering part of the network
    x = StopGradient()(x)
    ccoords = StopGradient()(ccoords)
    beta = StopGradient()(beta)
    feat = [x]

    # condense the vertices into pseudo row splits, one per condensation point
    sx, psrs, sids, asso_idx, belongs_to_prs = CondensateToPseudoRS(
        radius=0.8, soft=True, threshold=0.1)([x, ccoords, beta, x_row_splits])

    # sum features within each pseudo row split and attach the sum to every vertex
    sx = Concatenate()([RaggedSumAndScatter()([sx, psrs, belongs_to_prs]), sx])
    feat.append(VertexScatterer()([sx, sids, sx]))

    sx, _ = FusedRaggedGravNetLinParse(n_neighbours=128,
                                       n_dimensions=4,
                                       n_filters=64,
                                       n_propagate=[32, 32, 32, 32],
                                       name='gravnet_enblock_prs')([sx, psrs])
    x = VertexScatterer()([sx, sids, sx])
    feat.append(x)

    x, _ = FusedRaggedGravNetLinParse(n_neighbours=128,
                                      n_dimensions=4,
                                      n_filters=64,
                                      n_propagate=[32, 32, 32, 32],
                                      name='gravnet_enblock_last')([x, x_row_splits])
    feat.append(x)

    x = Concatenate()(feat)
    x = Dense(64, activation='elu', name="dense_last_enblock_1")(x)
    x = Dense(64, activation='elu', name="dense_last_enblock_2")(x)
    # linear output for the energy regression
    energy = Dense(1, activation=None, name="dense_enblock_final")(x)
    return energy
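# Illustrative sketch, not part of the original model code: the energy block
# above relies on StopGradient layers to decouple the energy regression from
# the upstream clustering network. Assuming the layer simply wraps
# tf.stop_gradient (the actual implementation in the codebase may differ),
# a minimal equivalent could look like this; the name StopGradientSketch is
# hypothetical. (tf is imported here for self-containment; the original
# module already uses it.)
import tensorflow as tf

class StopGradientSketch(tf.keras.layers.Layer):
    """Identity in the forward pass; blocks all gradients in the backward pass."""

    def call(self, inputs):
        return tf.stop_gradient(inputs)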
def indep_energy_block2(x,
                        energy,
                        ccoords,
                        beta,
                        x_row_splits,
                        energy_proxy=None,
                        stopxgrad=True):
    # cut gradients so the energy block trains independently of the rest
    if stopxgrad:
        x = StopGradient()(x)
    energy = StopGradient()(energy)
    ccoords = StopGradient()(ccoords)
    beta = StopGradient()(beta)
    feat = [x]

    x = Dense(64, activation='elu', name="dense_last_start_enblock_1")(x)
    x = Dense(64, activation='elu', name="dense_last_start_enblock_2")(x)
    x = Concatenate()([energy, x])

    sx, psrs, sids, asso_idx, belongs_to_prs = CondensateToPseudoRS(
        radius=0.8, soft=True, threshold=0.2)([x, ccoords, beta, x_row_splits])

    # deep-set-like approach: transform per vertex, sum per pseudo row split,
    # and concatenate the sum back to each vertex
    sx = Dense(128, activation='elu', name="dense_set_sum_input")(sx)
    sx = Dense(128, activation='elu', name="dense_set_sum_input_b")(sx)
    sx = Dense(128, activation='elu', name="dense_set_sum_input_c")(sx)
    sx = Concatenate()([RaggedSumAndScatter()([sx, psrs, belongs_to_prs]), sx])
    feat.append(VertexScatterer()([sx, sids, sx]))

    sx, _ = FusedRaggedGravNetLinParse(n_neighbours=128,
                                       n_dimensions=4,
                                       n_filters=64,
                                       n_propagate=[32, 32, 32, 32],
                                       name='gravnet_enblock_prs')([sx, psrs])
    x = VertexScatterer()([sx, sids, sx])
    feat.append(x)

    x = Concatenate()([x, energy])
    x, _ = FusedRaggedGravNetAggAtt(n_neighbours=256,
                                    n_dimensions=4,
                                    n_filters=64,
                                    n_propagate=[32, 32, 32, 32],
                                    name='gravnet_enblock_last')([x, x_row_splits, beta])
    feat.append(x)

    x = Concatenate()(feat)
    x = Dense(64, activation='elu', name="dense_last_enblock_1")(x)
    x = Dense(64, activation='elu', name="dense_last_enblock_2")(x)
    energy = Dense(1, activation=None, name="predicted_energy")(x)
    return energy
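# Illustrative sketch, not part of the original model code: the "deep-set-like"
# step above sums transformed features within each pseudo row split and
# broadcasts the sum back to every contained vertex, which is what
# RaggedSumAndScatter does (up to the exact association handling via
# belongs_to_prs). Assuming plain row splits as segment boundaries, a minimal
# equivalent with hypothetical naming:
def sum_and_scatter_sketch(features, row_splits):
    # features: [V, F] vertex features; row_splits: [N+1] segment boundaries
    ragged = tf.RaggedTensor.from_row_splits(features, row_splits)  # [N, (v), F]
    per_segment_sum = tf.reduce_sum(ragged, axis=1)                 # [N, F]
    # broadcast each segment sum back to all of its vertices
    return tf.gather(per_segment_sum, ragged.value_rowids())        # [V, F]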
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    #feat = Lambda(lambda x: tf.squeeze(x,axis=1))(feat)
    #tf.print([(t.shape, t.name) for t in [feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])
    #feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #    [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = \
        t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)  # get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists
    allfeat = []
    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)
    #trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])
    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###
    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    # should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####
    x = Concatenate()([x, x_mp, ncov])
    #x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists
    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5
    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####
        hier = Dense(1)(x)
        # narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.2,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  # the last input is the truth index used by the layer
        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        # use an EdgeConv operation to determine cluster properties;
        # the i-guard (meant to skip the first iteration because of OOM)
        # is currently disabled, so this runs in every iteration
        if True or i:
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  # get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])

        x = Concatenate()([x_cl, StopGradient()(coords)])
        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        ### now these are the new cluster features, up for the next iteration of building a new latent space
        #x = RaggedGlobalExchange()([x,rs])
        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics
        #dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])
        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu', name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu', name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps
        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration
        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward the full 64-feature x to the next iteration
        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        #x_r = Concatenate()([x_r, ng_coords, ng_dist])
        #x_r = Concatenate()([StopGradient()(coords), x_r])  ## add coordinates, might come in handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    # double scale phase transition with linear beta + qmin
    # -> more high beta points, but: payload loss will still scale one
    #    (or two, but then it doesn't matter)
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  # seems broken
        timing_loss_weight=1e-5,  # 1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
            pred_id, orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time,
            orig_t_pid, row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos,
                     pred_time, pred_id, rs
                 ] + backgatheredids + backgathered_coords)
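# Illustrative sketch, not part of the original model code: each clustering
# step appends its back-gather indices to gatherids, and MultiBackGather
# expands per-cluster tensors back to the original vertices by applying those
# indices in reverse order of collection. Assuming each stored index tensor
# maps every vertex of the next-finer level to its cluster at the coarser
# level, a minimal equivalent with hypothetical naming:
def multi_back_gather_sketch(x, gather_indices):
    # x: [N_clusters, F] features at the coarsest level;
    # gather_indices: list of int index tensors, collected fine-to-coarse
    for idxs in reversed(gather_indices):
        x = tf.gather(x, tf.reshape(idxs, [-1]), axis=0)
    return x  # [N_vertices, F] at the original resolution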
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = \
        t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)

    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    # really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)

    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       orig_coords)  # just rotation and scaling

    # see what's there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])
    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  # just a few features are enough here
    # this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])([coords, dist, x_c, nidx])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])

    x = Concatenate()([x, x_c, x_mp])

    # this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5
    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  # make it plottable

        # derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                Dense(n_dimensions,
                      use_bias=False,
                      name='newcoords' + str(i),
                      kernel_initializer='zeros')(x)
            ])

        # here we use more neighbours to improve learning of the cluster space
        nidx, cdist = KNN(K=6 * cluster_neighbours, radius=-1.0)([ccoords, rs])

        # cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        # what if we let the individual distances scale here? so move hits in and out?
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  # .5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])
        gatherids.append(bidxs)

        # explicit energy sum: sums up all contained energy per cluster
        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])
        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8, 8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        #x_sp = x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        # record more and more the deeper we go
        x_r = x
        energysums.append(MultiBackGather()([energy, gatherids]))  # assign energy sum to all cluster components
        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    #x = Dropout(0.3)(x)  # force the use of different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    #x = Concatenate()([x] + energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  # 1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
            pred_id, orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time,
            orig_t_pid, row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos,
                     pred_time, pred_id, rs
                 ] + backgatheredids + backgathered_coords)
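# Illustrative sketch, not part of the original model code: both models use
# DistanceWeightedMessagePassing to aggregate neighbour features with a
# distance-dependent weight (in the actual layer, the listed node counts are
# per-step dense transforms, and the kernel/aggregation may differ). Assuming
# a Gaussian-like kernel exp(-d^2) and mean aggregation, one propagation step
# could look like this; the function name is hypothetical:
def distance_weighted_mp_sketch(features, neighbour_idx, distsq):
    # features: [V, F]; neighbour_idx: [V, K] indices into V; distsq: [V, K]
    neigh_feat = tf.gather(features, neighbour_idx)      # [V, K, F]
    weights = tf.exp(-distsq)[..., tf.newaxis]           # [V, K, 1]
    # distance-weighted mean over the K neighbours
    return tf.reduce_mean(neigh_feat * weights, axis=1)  # [V, F]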