def gravnet_model(Inputs,
                  viscosity=0.2,
                  print_viscosity=False,
                  fluidity_decay=1e-3,
                  max_viscosity=0.9,  # to start with
                  feature_dropout=-1.,
                  addBackGatherInfo=True,
                  ):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, t_spectator, t_fully_contained, row_splits = td.interpretAllModelInputs(Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat_norm)
    orig_coords = SelectFeatures(5, 8)(feat)

    x = feat_norm
    allfeat = [x]

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    n_reshape_dimensions = 3

    # really simple real coordinates
    coords = ManualCoordTransform()(orig_coords)
    coords = Concatenate()([coords, time])
    coords = Dense(n_reshape_dimensions,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(coords)  # just rotation and scaling

    # see what's there
    nidx, dist = KNN(K=24, radius=1.0)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32, 16, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_mp])

    # this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    fwdbgccoords = None

    for i in range(total_iterations):

        cluster_neighbours = 5

        # derive new coordinates for clustering
        if i:
            ccoords = Add()([ccoords,
                             Dense(n_reshape_dimensions,
                                   name='newcoords' + str(i),
                                   kernel_initializer='zeros')(x)])

        nidx, cdist = KNN(K=5 * cluster_neighbours, radius=-1.0)([ccoords, rs])
        # here we use more neighbours to improve learning of the cluster space
        # this can be adjusted in the final trained model to be equal to 'cluster_neighbours'

        cdist = LocalDistanceScaling(max_scale=10.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])

        x, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=3.,
            loss_repulsion=0.8,  # it's important to get this right here
            hier_transforms=[64, 32, 16],
            print_loss=False,
            name='clustering_' + str(i)
        )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)

        # explicit
        energy = ReduceSumEntirely()(energy)  # sums up all contained energy per cluster

        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)
        x = Dense(128, activation='elu', name='dense_clc_b' + str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)

        n_dimensions = 3 + i  # make it plottable
        nneigh = 64 + 32 * i  # this will be almost fully connected for last clustering step
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_pca = Dense(2 + 4 * i)(x)
        x_pca = NeighbourApproxPCA()([coords, dist, x_pca, nidx])
        x_pca = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                               fluidity_decay=fluidity_decay)(x_pca)

        x_mp = DistanceWeightedMessagePassing([64, 64, 32, 32])([x, nidx, dist])
        x_mp = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                              fluidity_decay=fluidity_decay)(x_mp)

        x = Concatenate()([x, x_pca, x_mp, x_gn])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)
        x = Dense(32 + 16 * i, activation='elu', name='dense_c_' + str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)

        energysums.append(MultiBackGather()([energy, gatherids]))  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        bgccoords = MultiBackGather()([ccoords, gatherids])
        if fwdbgccoords is None:
            fwdbgccoords = bgccoords
        backgathered_coords.append(bgccoords)

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)

    # this is going to be resource intense, give a good starting point with the last ccoords
    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(Concatenate()([fwdbgccoords, x]))

    nidx, dist = KNN(K=32, radius=1.0)([coords, row_splits])
    x_mp = DistanceWeightedMessagePassing(8 * [8])([x, nidx, dist])  # only some, but a lot of hops
    x = Concatenate()([x, x_mp])
    # backgathered_coords.append(coords)

    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)
    x = Dense(128, activation='elu')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)

    x = Concatenate()([x] + energysums)

    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat, add_distance_scale=True)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-3,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=5.,
        repulsion_q_min=1.,
        super_repulsion=False,
        q_min=0.5,
        use_average_cc_pos=0.5,
        prob_repulsion=True,
        phase_transition=1,
        huber_energy_scale=3,
        alt_potential_norm=True,
        beta_gradient_damping=0.,
        payload_beta_gradient_damping_strength=0.,
        kalpha_damping_strength=0.,  # 1.,
        use_local_distances=True,
        name="FullOCLoss"
    )([pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id,
       orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
       row_splits])

    model_outputs = [('pred_beta', pred_beta),
                     ('pred_ccoords', pred_ccoords),
                     ('pred_energy', pred_energy),
                     ('pred_pos', pred_pos),
                     ('pred_time', pred_time),
                     ('pred_id', pred_id),
                     ('pred_dist', pred_dist),
                     ('row_splits', rs)]

    for i, (x, y) in enumerate(zip(backgatheredids, backgathered_coords)):
        model_outputs.append(('backgatheredids_' + str(i), x))
        model_outputs.append(('backgathered_coords_' + str(i), y))

    return RobustModel(model_inputs=Inputs, model_outputs=model_outputs)
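# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the model code above): a minimal,
# dense-tensor re-implementation of the idea behind the
# DistanceWeightedMessagePassing layer used throughout these models. The real
# layer runs on the ragged event structure, supports several hops, and handles
# padded neighbour indices (e.g. -1); this sketch assumes a fixed K valid
# neighbours per vertex and exists purely to document the mechanism.
# ---------------------------------------------------------------------------
import tensorflow as tf


class SimpleDistanceWeightedMP(tf.keras.layers.Layer):
    """Gathers the K neighbour features of each vertex, weights them with a
    Gaussian of the neighbour distance (the real layer's exact weighting may
    differ), aggregates by mean, and applies a dense transform to the
    concatenation of self and aggregated features."""

    def __init__(self, n_out, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(n_out, activation='elu')

    def call(self, inputs):
        x, nidx, dist = inputs                    # (V, F), (V, K), (V, K)
        neigh = tf.gather(x, nidx)                # (V, K, F) neighbour features
        w = tf.exp(-dist ** 2)[..., tf.newaxis]   # distance weighting
        agg = tf.reduce_mean(w * neigh, axis=1)   # (V, F) aggregated message
        return self.dense(tf.concat([x, agg], axis=-1))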
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)

    # feat = Lambda(lambda x: tf.squeeze(x, axis=1))(feat)
    # tf.print([(t.shape, t.name) for t in [feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])

    # feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #     [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)  # get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []
    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    # trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    # should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov])
    # x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        # narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.2,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i)
        )([x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords, t_idx])  # last is truth index used by layer

        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        # use EdgeConv operation to determine cluster properties
        if True or i:  # only after second iteration because of OOM
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  # get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])

        x = Concatenate()([x_cl, StopGradient()(coords)])
        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        # notice last relu for feature weighting later

        ### now these are the new cluster features, up for the next iteration of building new latent space
        # x = RaggedGlobalExchange()([x, rs])

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics

        # dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])
        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu', name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu', name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps

        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        # x_r = Concatenate()([x_r, ng_coords, ng_dist])
        # x_r = Concatenate()([StopGradient()(coords), x_r])  ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  # seems broken
        timing_loss_weight=1e-5,  # 1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False
    )([pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
       orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
       row_splits])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs]
                 + backgatheredids + backgathered_coords)
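# ---------------------------------------------------------------------------
# Illustrative sketch only: the back-gathering idea behind the MultiBackGather
# layer used above. Each clustering step shrinks the graph and appends to
# `gatherids` an index tensor mapping every pre-clustering vertex to the
# cluster it was merged into. Walking those maps in reverse broadcasts
# cluster-level quantities (features, energy sums, ids) back to all original
# vertices, which is how the backgathered_* outputs are produced. This is a
# plain eager-mode function with assumed int32 index tensors; the real layer
# works inside the Keras graph on the ragged structures above.
# ---------------------------------------------------------------------------
import tensorflow as tf


def multi_back_gather(x, gather_idx_stack):
    """x: (V_n, F) cluster-level tensor after n clustering steps.
    gather_idx_stack[i]: indices mapping step-i vertices to step-(i+1) clusters."""
    for idxs in reversed(gather_idx_stack):
        x = tf.gather(x, tf.reshape(idxs, [-1]))  # expand one clustering level
    return x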
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    allfeat = [feat_norm]
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    coords = orig_coords
    coords = Dense(3)(coords)  # just rotation and scaling

    # add gradient
    # see what's there
    nidx, dist = KNN(K=96, radius=1.0)([coords, rs])

    x = Dense(4, activation='elu', name='pre_pre_dense')(x)  # just a few features are enough here

    # this can be full blown because of the small number of input features
    x_c = SoftPixelCNN(mode='full', subdivisions=4, name='prePCNN')([coords, x, dist, nidx])
    x = Concatenate()([x, x_c])

    # this is going to be among the most expensive operations:
    x = Dense(128, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    # allfeat.append(Dense(16, activation='elu', name='feat_compress_pre')(x))
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    total_iterations = 5

    for i in range(total_iterations):

        # cluster first
        hier = Dense(1, activation='sigmoid')(x)

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx = LocalClusterReshapeFromNeighbours(
            K=8,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.3,
            print_loss=True,
            name='clustering_' + str(i)
        )([x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, t_idx])

        # explicit
        energy = ReduceSumEntirely()(energy)  # sums up all contained energy per cluster

        gatherids.append(bidxs)

        x_cl = Dense(128, activation='elu', name='dense_clc_' + str(i))(x_cl)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x, x_cl, n_energy])

        pixelcompress = 4
        nneigh = 32 + 4 * i
        nfilt = 32 + 4 * i
        nprop = 32
        n_dimensions = 3  # make it plottable

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        subdivisions = 4

        # add more shape information
        x_sp = Dense(pixelcompress, activation='elu')(x)
        x_sp = BatchNormalization(momentum=0.6)(x_sp)
        x_sp = SoftPixelCNN(mode='full', subdivisions=4)([coords, x_sp, dist, nidx])
        x_sp = Dense(128, activation='elu', name='dense_spc_' + str(i))(x_sp)

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(nfilt, activation='elu', name='dense_mpc_' + str(i))(x_mp)

        x = Concatenate()([x, x_sp, x_gn, x_mp])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x_r = x

        # record more and more the deeper we go
        if i < total_iterations - 1:
            x_r = Dense(12 * (i + 1), activation='elu', name='dense_rec_' + str(i))(x)
        else:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    # x = Dropout(0.3)(x)  # force to use different information sources

    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([x, energy])

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True
    )([pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
       orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
       row_splits])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs]
                 + backgatheredids + backgathered_coords)
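# ---------------------------------------------------------------------------
# Illustrative sketch only: an object-condensation output head in the spirit
# of the create_outputs helper used above. The real helper's signature, output
# dimensionalities, and optional distance scaling (add_distance_scale) may
# differ; the layer widths and activations below are assumptions chosen to
# match the general object-condensation scheme (beta in [0, 1], clustering
# coordinates, plus payload regressions and a PID softmax).
# ---------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.keras.layers import Dense


def sketch_condensation_outputs(x, n_classes=6):
    pred_beta = Dense(1, activation='sigmoid', name='pred_beta')(x)      # condensation strength
    pred_ccoords = Dense(2, name='pred_ccoords')(x)                      # clustering-space coordinates
    pred_energy = Dense(1, name='pred_energy')(x)                        # energy regression
    pred_pos = Dense(2, name='pred_pos')(x)                              # position regression
    pred_time = Dense(1, name='pred_time')(x)                            # time regression
    pred_id = Dense(n_classes, activation='softmax', name='pred_id')(x)  # PID classification
    return pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id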
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)

    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    # really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)

    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(orig_coords)  # just rotation and scaling

    # see what's there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])

    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  # just a few features are enough here
    # this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])([coords, dist, x_c, nidx])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])

    x = Concatenate()([x, x_c, x_mp])

    # this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  # make it plottable

        # derive new coordinates for clustering
        if i:
            ccoords = Add()([ccoords,
                             Dense(n_dimensions, use_bias=False,
                                   name='newcoords' + str(i),
                                   kernel_initializer='zeros')(x)])

        nidx, cdist = KNN(K=6 * cluster_neighbours, radius=-1.0)([ccoords, rs])
        # here we use more neighbours to improve learning of the cluster space

        # cluster first
        # hier = Dense(1, activation='sigmoid')(x)
        # distcorr = Dense(dist.shape[-1], activation='relu', kernel_initializer='zeros')(Concatenate()([x, dist]))
        # dist = Add()([distcorr, dist])

        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        # what if we let the individual distances scale here? so move hits in and out?

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  # .5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i)
        )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)

        # explicit
        energy = ReduceSumEntirely()(energy)  # sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        # x = Dense(128, activation='elu', name='dense_clc_b' + str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])([coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8, 8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        # x_sp = x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        # x = Dense(128, activation='elu', name='dense_b_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)

        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        # record more and more the deeper we go
        x_r = x

        energysums.append(MultiBackGather()([energy, gatherids]))  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    # x = Dropout(0.3)(x)  # force to use different information sources

    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    # x = Concatenate()([x] + energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  # 1.,
        name="FullOCLoss"
    )([pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
       orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
       row_splits])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs]
                 + backgatheredids + backgathered_coords)
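# ---------------------------------------------------------------------------
# Illustrative sketch only: brute-force K-nearest-neighbour search restricted
# to row splits (one segment per event), approximating what the KNN layer
# above provides. The real layer uses a dedicated kernel, supports a radius
# cut-off, and pads missing neighbours; this eager-mode version assumes every
# event has at least K vertices and returns squared distances.
# ---------------------------------------------------------------------------
import tensorflow as tf


def knn_per_event(coords, row_splits, K):
    all_idx, all_d2 = [], []
    for i in range(len(row_splits) - 1):
        c = coords[row_splits[i]:row_splits[i + 1]]   # (v, d) vertices of one event
        d2 = tf.reduce_sum((c[:, None, :] - c[None, :, :]) ** 2, axis=-1)
        neg_d2, idx = tf.math.top_k(-d2, k=K)         # K smallest squared distances
        all_idx.append(idx + row_splits[i])           # map back to global indices
        all_d2.append(-neg_d2)
    return tf.concat(all_idx, axis=0), tf.concat(all_d2, axis=0)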
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)
    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits

    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)  # get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []
    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.identity())(coords)

    first_coords = coords

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])
    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()(
        [coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i)
        )([x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords, t_idx])  # last is truth index used by layer

        gatherids.append(bidxs)

        if i or True:
            x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  # get to shape V x K x F
            xec = EdgeConvStatic([32, 32, 32], name="ec_static_" + str(i))(x_cl_rs)
            x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        # notice last relu for feature weighting later

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu', name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu', name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu', name='dense_out_c_' + str(i))(x)

        # coords_nograd = StopGradient()(coords)
        # x_r = Concatenate()([coords_nograd, x_r])  ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    # x = Dropout(0.2)(x)

    x_mp = DistanceWeightedMessagePassing([32, 32, 32])([x, first_nidx, first_dist])
    x = Concatenate()([x, x_mp])
    x = Dense(128, activation='elu', name='alldense')(x)

    # TO BE ADDED WITH E LOSS
    # x = Concatenate()([x, energy])

    # x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  # seems broken
        timing_loss_weight=0.,  # 1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False
    )([pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
       orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
       row_splits])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[pred_beta, pred_ccoords, pred_energy,
                                         pred_pos, pred_time, pred_id, rs]
                                + backgatheredids + backgathered_coords)
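# ---------------------------------------------------------------------------
# Hypothetical usage sketch: how one of these model builders might be wired
# up. Everything here is a placeholder: the Input shapes normally come from
# the TrainData definition (`td`, assumed to be defined in this module), the
# exact input list varies between the variants above (one expects additional
# spectator and containment truth tensors), and the object-condensation loss
# is attached inside the model by the LL* layers, so compile() receives no
# loss function.
# ---------------------------------------------------------------------------
import tensorflow as tf

inputs = [
    tf.keras.Input(shape=(9,)),                  # hit features (placeholder width)
    tf.keras.Input(shape=(1,)),                  # truth index
    tf.keras.Input(shape=(1,)),                  # truth energy
    tf.keras.Input(shape=(2,)),                  # truth position
    tf.keras.Input(shape=(1,)),                  # truth time
    tf.keras.Input(shape=(1,)),                  # truth PID
    tf.keras.Input(shape=(1,), dtype='int32'),   # row splits
]
model = gravnet_model(inputs)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4))  # losses come from the LL* layers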