def gravnet_model(Inputs,
                  td,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)
    orig_t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                          minimum=1e-1,
                                                          active=True
                                                          )([orig_inputs['t_spectator'],
                                                             orig_inputs['t_idx']])
    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight

    # can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    # just for info what's available
    print('available pre-selection outputs', [k for k in pre_selection.keys()])

    t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([pre_selection['t_spectator'],
                                                        pre_selection['t_idx']])
    rs = pre_selection['rs']

    x_in = Concatenate()([pre_selection['coords'],
                          pre_selection['features'],
                          pre_selection['addfeat']])
    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  # physical coordinates
    c_coords = pre_selection['coords']  # pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    allfeat = []

    n_cluster_space_coordinates = 3

    # extend coordinates already here if needed
    c_coords = extent_coords_if_needed(c_coords, x, n_cluster_space_coordinates)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation=dense_activation)(x)
        x = Dense(64, activation=dense_activation)(x)
        x = Dense(64, activation=dense_activation)(x)
        x = GooeyBatchNorm(**batchnorm_options)(x)
        ### reduction done

        n_dims = 6
        # exchange information, create coordinates
        x = Concatenate()([c_coords, c_coords, c_coords, coords, x])
        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=64,
                                                      n_dimensions=n_dims,
                                                      n_filters=64,
                                                      n_propagate=64,
                                                      record_metrics=True,
                                                      coord_initialiser_noise=1e-2,
                                                      use_approximate_knn=False  # weird issue with that for now
                                                      )([x, rs])
        x = Concatenate()([x, xgn])

        # just keep them in a reasonable range
        # safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=1e-4,
                                            record_metrics=True
                                            )(gndist)

        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' + str(i))([gncoords, energy, t_idx, rs])
        x = Concatenate()([gncoords, x])

        pre_gndist = gndist

        if double_mp:
            for im, m in enumerate([64, 64, 32, 32, 16, 16]):
                dscale = Dense(1)(x)
                gndist = LocalDistanceScaling(4.)([pre_gndist, dscale])
                gndist = AverageDistanceRegularizer(strength=1e-6,
                                                    record_metrics=True,
                                                    name='average_distance_dmp_' + str(i) + '_' + str(im)
                                                    )(gndist)
                x = DistanceWeightedMessagePassing([m],
                                                   activation=dense_activation
                                                   )([x, gnnidx, gndist])
        else:
            x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16, 16],
                                               activation=dense_activation
                                               )([x, gnnidx, gndist])

        x = GooeyBatchNorm(**batchnorm_options)(x)

        x = Dense(64, name='dense_past_mp_' + str(i), activation=dense_activation)(x)
        x = Dense(64, activation=dense_activation)(x)
        x = Dense(64, activation=dense_activation)(x)

        x = GooeyBatchNorm(**batchnorm_options)(x)

        allfeat.append(x)

    x = Concatenate()([c_coords] + allfeat + [pre_selection['not_noise_score']])
    # do one more exchange with all
    x = Dense(64, activation=dense_activation)(x)
    x = Dense(64, activation=dense_activation)(x)
    x = Dense(64, activation=dense_activation)(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(**batchnorm_options, name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    pred_beta = LLFullObjectCondensation(scale=4.,
                                         position_loss_weight=1e-5,
                                         timing_loss_weight=1e-5,
                                         beta_loss_scale=1.,
                                         use_energy_weights=True,
                                         record_metrics=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )(  # oc output and payload
        [pred_beta, pred_ccoords, pred_dist, pred_energy_corr,
         pred_pos, pred_time, pred_id] +
        [energy] +
        # truth information
        [pre_selection['t_idx'],
         pre_selection['t_energy'],
         pre_selection['t_pos'],
         pre_selection['t_time'],
         pre_selection['t_pid'],
         pre_selection['t_spectator_weight'],
         pre_selection['t_fully_contained'],
         pre_selection['t_rec_energy'],
         pre_selection['t_is_unique'],
         pre_selection['rs']])

    # fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([pred_ccoords, pred_beta,
                                                         pre_selection['t_idx'], rs])

    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True
    )

    return DictModel(inputs=Inputs, outputs=model_outputs)
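# --------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file): how a model
# function of this shape is typically wired up. `MyTrainData`, `inputs` and
# `td.createInputs` are placeholder names assumed for illustration only; the
# real Keras Input list and TrainData descriptor come from the training
# wrapper that calls this function.
#
#   td = MyTrainData()                       # assumed TrainData descriptor
#   inputs = td.createInputs()               # assumed: builds the Keras Input list
#   model = gravnet_model(inputs, td,
#                         debug_outdir='debug_plots',
#                         plot_debug_every=2000)
#
# The object-condensation loss is attached inside the graph via
# LLFullObjectCondensation, so the returned DictModel can be compiled
# without an external loss function.
# --------------------------------------------------------------------------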
def gravnet_model(Inputs,
                  viscosity=0.2,
                  print_viscosity=False,
                  fluidity_decay=1e-3,
                  max_viscosity=0.9  # to start with
                  ):

    feature_dropout = -1.
    addBackGatherInfo = True

    feat, t_idx, t_energy, t_pos, t_time, t_pid, t_spectator, t_fully_contained, row_splits = td.interpretAllModelInputs(Inputs)

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat_norm)
    orig_coords = SelectFeatures(5, 8)(feat)

    x = feat_norm
    allfeat = [x]

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    n_reshape_dimensions = 3

    # really simple real coordinates
    coords = ManualCoordTransform()(orig_coords)
    coords = Concatenate()([coords, time])
    coords = Dense(n_reshape_dimensions,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(coords)  # just rotation and scaling

    # see what's there
    nidx, dist = KNN(K=24, radius=1.0)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32, 16, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_mp])

    # this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    fwdbgccoords = None

    for i in range(total_iterations):

        cluster_neighbours = 5

        # derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                Dense(n_reshape_dimensions, name='newcoords' + str(i),
                      kernel_initializer='zeros')(x)
            ])

        nidx, cdist = KNN(K=5 * cluster_neighbours, radius=-1.0)([ccoords, rs])
        # here we use more neighbours to improve learning of the cluster space
        # this can be adjusted in the final trained model to be equal to 'cluster_neighbours'

        cdist = LocalDistanceScaling(max_scale=10.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])

        x, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=3.,
            loss_repulsion=0.8,  # it's important to get this right here
            hier_transforms=[64, 32, 16],
            print_loss=False,
            name='clustering_' + str(i)
        )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)

        # explicit
        energy = ReduceSumEntirely()(energy)  # sums up all contained energy per cluster

        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)
        x = Dense(128, activation='elu', name='dense_clc_b' + str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)

        n_dimensions = 3 + i  # make it plottable
        nneigh = 64 + 32 * i  # this will be almost fully connected for last clustering step
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])
        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_pca = Dense(2 + 4 * i)(x)
        x_pca = NeighbourApproxPCA()([coords, dist, x_pca, nidx])
        x_pca = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                               fluidity_decay=fluidity_decay)(x_pca)

        x_mp = DistanceWeightedMessagePassing([64, 64, 32, 32])([x, nidx, dist])
        x_mp = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                              fluidity_decay=fluidity_decay)(x_mp)

        x = Concatenate()([x, x_pca, x_mp, x_gn])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)
        x = Dense(32 + 16 * i, activation='elu', name='dense_c_' + str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                           fluidity_decay=fluidity_decay)(x)

        energysums.append(MultiBackGather()([energy, gatherids]))  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        bgccoords = MultiBackGather()([ccoords, gatherids])
        if fwdbgccoords is None:
            fwdbgccoords = bgccoords
        backgathered_coords.append(bgccoords)

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)

    # this is going to be resource intense, give a good starting point with the last ccoords
    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(Concatenate()([fwdbgccoords, x]))

    nidx, dist = KNN(K=32, radius=1.0)([coords, row_splits])
    x_mp = DistanceWeightedMessagePassing(8 * [8])([x, nidx, dist])  # only some but a lot of hops
    x = Concatenate()([x, x_mp])

    # backgathered_coords.append(coords)

    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)
    x = Dense(128, activation='elu')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay)(x)
    x = Concatenate()([x] + energysums)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat, add_distance_scale=True)

    # loss
    pred_beta = LLFullObjectCondensation(print_loss=True,
                                         energy_loss_weight=1e-4,
                                         position_loss_weight=1e-3,
                                         timing_loss_weight=1e-3,
                                         beta_loss_scale=1.,
                                         repulsion_scaling=5.,
                                         repulsion_q_min=1.,
                                         super_repulsion=False,
                                         q_min=0.5,
                                         use_average_cc_pos=0.5,
                                         prob_repulsion=True,
                                         phase_transition=1,
                                         huber_energy_scale=3,
                                         alt_potential_norm=True,
                                         beta_gradient_damping=0.,
                                         payload_beta_gradient_damping_strength=0.,
                                         kalpha_damping_strength=0.,  # 1.,
                                         use_local_distances=True,
                                         name="FullOCLoss"
                                         )([pred_beta, pred_ccoords, pred_dist, pred_energy,
                                            pred_pos, pred_time, pred_id,
                                            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time,
                                            orig_t_pid, row_splits])

    model_outputs = [('pred_beta', pred_beta),
                     ('pred_ccoords', pred_ccoords),
                     ('pred_energy', pred_energy),
                     ('pred_pos', pred_pos),
                     ('pred_time', pred_time),
                     ('pred_id', pred_id),
                     ('pred_dist', pred_dist),
                     ('row_splits', rs)]

    for i, (x, y) in enumerate(zip(backgatheredids, backgathered_coords)):
        model_outputs.append(('backgatheredids_' + str(i), x))
        model_outputs.append(('backgathered_coords_' + str(i), y))

    return RobustModel(model_inputs=Inputs, model_outputs=model_outputs)
def gravnet_model(Inputs,
                  td,
                  empty_pca=False,
                  total_iterations=2,
                  variance_only=True,
                  viscosity=0.1,
                  print_viscosity=False,
                  fluidity_decay=5e-4,  # reaches max after about 7k batches
                  max_viscosity=0.95,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)
    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])
    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight

    # can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    # just for info what's available
    print([k for k in pre_selection.keys()])

    t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([pre_selection['t_spectator'], pre_selection['t_idx']])
    rs = pre_selection['rs']

    x_in = Concatenate()([
        pre_selection['coords'],
        pre_selection['features'],
        pre_selection['addfeat']
    ])
    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  # physical coordinates
    c_coords = pre_selection['coords']  # pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    allfeat = []

    n_cluster_space_coordinates = 3

    # extend coordinates already here if needed
    coords = extent_coords_if_needed(coords, x, n_cluster_space_coordinates)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)
        ### reduction done

        n_dims = 3
        # exchange information, create coordinates
        x = Concatenate()([coords, coords, c_coords, x])
        x, gncoords, gnnidx, gndist = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=n_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False  # weird issue with that for now
        )([x, rs])

        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' + str(i))([gncoords, energy, t_idx, rs])

        # just keep them in a reasonable range
        # safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=0.01,
                                            record_metrics=True)(gndist)

        x = Concatenate()([energy, x])
        x_pca = Dense(4, activation='relu')(x)  # pca is expensive
        x_pca = ApproxPCA(empty=empty_pca)([gncoords, gndist, x_pca, gnnidx])
        x = Concatenate()([x, x_pca])

        x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16, 16])([x, gnnidx, gndist])

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           record_metrics=True,
                           variance_only=variance_only,
                           fluidity_decay=fluidity_decay)(x)

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)

        allfeat.append(x)

    x = Concatenate()([c_coords] + allfeat + [pre_selection['not_noise_score']])
    # do one more exchange with all
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(viscosity=viscosity,
                       max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay,
                       record_metrics=True,
                       variance_only=variance_only,
                       name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    pred_beta = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        # print_batch_time=True,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        # div_repulsion=True,
        # cont_beta_loss=True,
        # beta_gradient_damping=0.999,
        # phase_transition=1,
        # huber_energy_scale=0.1,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss")(
            # oc output and payload
            [pred_beta, pred_ccoords, pred_dist, pred_energy_corr,
             pred_pos, pred_time, pred_id] +
            [energy] +
            # truth information
            [pre_selection['t_idx'],
             pre_selection['t_energy'],
             pre_selection['t_pos'],
             pre_selection['t_time'],
             pre_selection['t_pid'],
             pre_selection['t_spectator_weight'],
             pre_selection['t_fully_contained'],
             pre_selection['t_rec_energy'],
             pre_selection['t_is_unique'],
             pre_selection['rs']])

    # fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([pred_ccoords, pred_beta,
                                                         pre_selection['t_idx'], rs])

    model_outputs = re_integrate_to_full_hits(pre_selection,
                                              pred_ccoords,
                                              pred_beta,
                                              pred_energy_corr,
                                              pred_pos,
                                              pred_time,
                                              pred_id,
                                              pred_dist,
                                              dict_output=True)

    return DictModel(inputs=Inputs, outputs=model_outputs)
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)

    # feat = Lambda(lambda x: tf.squeeze(x,axis=1))(feat)
    # tf.print([(t.shape, t.name) for t in [feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])
    # feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #     [feat, t_idx, t_energy, t_pos, t_time, t_pid]
    # )

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)  # get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []
    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    # trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    # should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov])
    # x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        # narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.2,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords, t_idx
            ])  # last is truth index used by layer

        gatherids.append(bidxs)
        energy = ReduceSumEntirely()(energy)

        # use EdgeConv operation to determine cluster properties
        if True or i:  # only after second iteration because of OOM
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  # get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])

        x = Concatenate()([x_cl, StopGradient()(coords)])
        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        # notice last relu for feature weighting later

        ### now these are the new cluster features, up for the next iteration of building new latent space
        # x = RaggedGlobalExchange()([x,rs])

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics

        # dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])
        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu', name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu', name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps

        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        # x_r = Concatenate()([x_r, ng_coords, ng_dist])
        # x_r = Concatenate()([StopGradient()(coords),x_r])  # add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  # seems broken
        timing_loss_weight=1e-5,  # 1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs] +
                         backgatheredids + backgathered_coords)
def pre_selection_staged(indict,
                         debug_outdir,
                         trainable,
                         name='pre_selection_add_stage_0',
                         debugplots_after=-1,
                         reduction_threshold=0.75,
                         use_edges=True,
                         print_info=False,
                         record_metrics=False,
                         n_coords=3,
                         edge_nodes_0=16,
                         edge_nodes_1=8,
                         ):
    '''
    Takes the output of the preselection model and selects again :)

    The outputs are compatible with the inputs, so this stage can be chained.
    This one uses full-blown GravNet.

    Gets as inputs:

    indict['scatterids']
    indict['orig_t_idx']
    indict['orig_t_energy']
    indict['orig_dim_coords']
    indict['rs']
    indict['orig_row_splits']

    indict['features']
    indict['orig_features']
    indict['coords']
    indict['addfeat']
    indict['energy']

    indict['t_idx']
    indict['t_energy']
    ... all the truth info
    '''

    from GravNetLayersRagged import RaggedGravNet, DistanceWeightedMessagePassing, ElementScaling
    from GravNetLayersRagged import SelectFromIndices, GooeyBatchNorm, MaskTracksAsNoise
    from GravNetLayersRagged import AccumulateNeighbours, KNN, MultiAttentionGravNetAdd
    from LossLayers import LLClusterCoordinates, LLFillSpace
    from DebugLayers import PlotCoordinates
    from MetricsLayers import MLReductionMetrics
    from Regularizers import MeanMaxDistanceRegularizer, AverageDistanceRegularizer

    # assume the inputs are normalised
    rs = indict['rs']
    t_idx = indict['t_idx']

    track_charge = SelectFeatures(2, 3)(indict['unproc_features'])  # zero for calo hits

    x = Concatenate()([indict['features'], indict['addfeat']])
    x = Dense(64, activation='elu', trainable=trainable)(x)

    gn_pre_coords = indict['coords']
    gn_pre_coords = ElementScaling(name=name + 'es1', trainable=trainable)(gn_pre_coords)
    x = Concatenate()([gn_pre_coords, x])

    x, coords, nidx, dist = RaggedGravNet(n_neighbours=32,
                                          n_dimensions=n_coords,
                                          n_filters=64,
                                          n_propagate=64,
                                          coord_initialiser_noise=1e-5,
                                          feature_activation=None,
                                          record_metrics=record_metrics,
                                          use_approximate_knn=True,
                                          use_dynamic_knn=True,
                                          trainable=trainable,
                                          name=name + '_gn1')([x, rs])

    # the two below are mostly running to record metrics and kill very bad coordinate scalings
    dist = MeanMaxDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)
    dist = AverageDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_gn1_coords')([coords, indict['energy'], t_idx, rs])

    x = DistanceWeightedMessagePassing([32, 32, 8, 8],
                                       name=name + 'dmp1',
                                       trainable=trainable)([x, nidx, dist])

    x_matt = Dense(16, activation='elu', name=name + '_matt_dense')(x)
    x_matt = MultiAttentionGravNetAdd(5,
                                      name=name + '_att_gn1',
                                      record_metrics=record_metrics)([x, x_matt, coords, nidx])
    x = Concatenate()([x, x_matt])
    x = Dense(64, activation='elu', name=name + '_bef_coord_dense')(x)

    coords = Add()([Dense(n_coords,
                          name=name + '_coord_add_dense',
                          kernel_initializer='zeros')(x),
                    coords])

    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_red_coords')([coords, indict['energy'], t_idx, rs])

    nidx, dist = KNN(K=16,
                     radius='dynamic',  # use dynamic feature
                     record_metrics=record_metrics,
                     name=name + '_knn',
                     min_bins=[7, 7]  # this can be fine grained
                     )([coords, rs])

    coords = LLClusterCoordinates(print_loss=print_info,
                                  record_metrics=record_metrics,
                                  active=trainable,
                                  print_batch_time=False,
                                  scale=5.)([coords, t_idx, rs])

    coords = LLFillSpace(active=trainable,
                         record_metrics=record_metrics,
                         scale=0.025,  # just mild
                         runevery=-1,  # give it a kick only every now and then - that's enough
                         )([coords, rs])

    unred_rs = rs
    cluster_tidx = MaskTracksAsNoise(active=trainable)([t_idx, track_charge])

    gnidx, gsel, group_backgather, rs = reduce_indices(x,
                                                       dist,
                                                       nidx,
                                                       rs,
                                                       cluster_tidx,
                                                       threshold=reduction_threshold,
                                                       print_reduction=print_info,
                                                       trainable=trainable,
                                                       name=name + '_reduce_indices',
                                                       use_edges=use_edges,
                                                       edge_nodes_0=edge_nodes_0,
                                                       edge_nodes_1=edge_nodes_1,
                                                       return_backscatter=False)

    gsel = MLReductionMetrics(name=name + '_reduction',
                              record_metrics=True)([gsel, t_idx, indict['t_energy'], unred_rs, rs])

    selfeat = SelectFromIndices()([gsel, indict['features']])
    unproc_features = SelectFromIndices()([gsel, indict['unproc_features']])

    energy = indict['energy']

    x = AccumulateNeighbours('minmeanmax')([x, gnidx, energy])
    x = SelectFromIndices()([gsel, x])

    # add more useful things
    coords = AccumulateNeighbours('mean')([coords, gnidx, energy])
    coords = SelectFromIndices()([gsel, coords])

    phys_coords = AccumulateNeighbours('mean')([indict['phys_coords'], gnidx, energy])
    phys_coords = SelectFromIndices()([gsel, phys_coords])

    energy = AccumulateNeighbours('sum')([energy, gnidx])
    energy = SelectFromIndices()([gsel, energy])

    out = {}
    out['not_noise_score'] = AccumulateNeighbours('mean')([indict['not_noise_score'], gnidx])
    out['not_noise_score'] = SelectFromIndices()([gsel, out['not_noise_score']])

    out['scatterids'] = indict['scatterids'] + [group_backgather]  # append new selection

    # re-build standard feature layout
    out['features'] = selfeat
    out['unproc_features'] = unproc_features
    out['coords'] = coords
    out['phys_coords'] = phys_coords
    out['addfeat'] = GooeyBatchNorm(name=name + '_gooey_norm', trainable=trainable)(x)  # norm them
    out['energy'] = energy
    out['rs'] = rs

    for k in indict.keys():
        if 't_' == k[0:2]:
            out[k] = SelectFromIndices()([gsel, indict[k]])

    # some pass-throughs:
    out['orig_dim_coords'] = indict['orig_dim_coords']
    out['orig_t_idx'] = indict['orig_t_idx']
    out['orig_t_energy'] = indict['orig_t_energy']
    out['orig_row_splits'] = indict['orig_row_splits']

    # check
    anymissing = False
    for k in indict.keys():
        if not k in out.keys():
            anymissing = True
            print(k, 'missing')

    if anymissing:
        raise ValueError("key not found")

    return out
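# --------------------------------------------------------------------------
# Hypothetical chaining sketch (not part of the original file). The docstring
# above states that the outputs of pre_selection_staged are compatible with
# its inputs, so stages can be chained; `pre_selection` is assumed here to be
# the dict returned by the first pre-selection model.
#
#   stage1 = pre_selection_staged(pre_selection, debug_outdir, trainable=True,
#                                 name='pre_selection_add_stage_0')
#   stage2 = pre_selection_staged(stage1, debug_outdir, trainable=True,
#                                 name='pre_selection_add_stage_1',
#                                 reduction_threshold=0.8)
#
# Each stage appends its own back-gather indices to out['scatterids'], so a
# later selection can still be scattered back to the original hits.
# --------------------------------------------------------------------------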
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)

    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    # really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)

    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(orig_coords)  # just rotation and scaling

    # see what's there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])

    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  # just a few features are enough here
    # this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])([coords, dist, x_c, nidx])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])

    x = Concatenate()([x, x_c, x_mp])

    # this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  # make it plottable

        # derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                (Dense(n_dimensions, use_bias=False,
                       name='newcoords' + str(i),
                       kernel_initializer='zeros')(x))
            ])

        nidx, cdist = KNN(K=6 * cluster_neighbours, radius=-1.0)([ccoords, rs])
        # here we use more neighbours to improve learning of the cluster space

        # cluster first
        # hier = Dense(1,activation='sigmoid')(x)
        # distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        # dist = Add()([distcorr,dist])

        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        # what if we let the individual distances scale here? so move hits in and out?

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  # .5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords,
                cdist, t_idx
            ])

        gatherids.append(bidxs)

        # explicit
        energy = ReduceSumEntirely()(energy)  # sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        # x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])
        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])([coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8, 8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        # x_sp = x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        # x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)

        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        # record more and more the deeper we go
        x_r = x

        energysums.append(MultiBackGather()([energy, gatherids]))  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    # x = Dropout(0.3)(x)  # force to use different information sources

    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    # x = Concatenate()([x]+energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  # 1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs] +
                         backgatheredids + backgathered_coords)
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_feat = Inputs[0]
    I_truth = Inputs[2]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")([I_feat, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    x = RaggedGlobalExchange(name="global_exchange")([x, x_row_splits])
    x = Dense(64, activation='elu', name="dense_start")(x)

    n_filters = 0
    n_gravnet_layers = 4

    feat = [x_basic, x]

    for i in range(n_gravnet_layers):
        n_filters = 196
        n_propagate = 128
        n_propagate_2 = [64, 32, 16, 8, 4, 2]
        n_neighbours = 32
        n_dim = 8
        # if n_dim < 2:
        #     n_dim = 2

        x, coords, neighbor_indices, neighbor_dist = RaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=n_dim,
            n_filters=n_filters,
            n_propagate=n_propagate,
            name='gravnet_' + str(i))([x, x_row_splits])

        x = DistanceWeightedMessagePassing(n_propagate_2)([x, neighbor_indices, neighbor_dist])

        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu', name="dense_bottom_" + str(i) + "_a")(x)
        x = BatchNormalization(momentum=0.6, name="bn_a_" + str(i))(x)
        x = Dense(96, activation='elu', name="dense_bottom_" + str(i) + "_b")(x)
        x = RaggedGlobalExchange(name="global_exchange_bot_" + str(i))([x, x_row_splits])
        x = Dense(96, activation='elu', name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6, name="bn_b_" + str(i))(x)
        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a1")(x)
    x = Dense(128, activation='elu', name="dense_last_a2")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a2")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    x = create_output_layers(x, x_row_splits,
                             n_ccoords=n_ccoords,
                             scale_exp_e=False)

    truth_dict = create_ragged_cal_truth_dict(I_truth)
    pred_dict = create_ragged_cal_pred_dict(x, n_ccoords=n_ccoords)
    feat_dict = create_ragged_cal_feature_dict(I_feat)

    x = LLObjectCondensation()([x, truth_dict, pred_dict, feat_dict, x_row_splits])

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=x)
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    allfeat = [feat_norm]
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    coords = orig_coords
    coords = Dense(3)(coords)  # just rotation and scaling
    # add gradient

    # see what's there
    nidx, dist = KNN(K=96, radius=1.0)([coords, rs])

    x = Dense(4, activation='elu', name='pre_pre_dense')(x)  # just a few features are enough here
    # this can be full blown because of the small number of input features
    x_c = SoftPixelCNN(mode='full', subdivisions=4, name='prePCNN')([coords, x, dist, nidx])
    x = Concatenate()([x, x_c])

    # this is going to be among the most expensive operations:
    x = Dense(128, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    # allfeat.append(Dense(16, activation='elu',name='feat_compress_pre')(x))
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    total_iterations = 5

    for i in range(total_iterations):

        # cluster first
        hier = Dense(1, activation='sigmoid')(x)

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx = LocalClusterReshapeFromNeighbours(
            K=8,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.3,
            print_loss=True,
            name='clustering_' + str(i))(
                [x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, t_idx])

        # explicit
        energy = ReduceSumEntirely()(energy)  # sums up all contained energy per cluster

        gatherids.append(bidxs)
        x_cl = Dense(128, activation='elu', name='dense_clc_' + str(i))(x_cl)

        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x, x_cl, n_energy])

        pixelcompress = 4
        nneigh = 32 + 4 * i
        nfilt = 32 + 4 * i
        nprop = 32
        n_dimensions = 3  # make it plottable

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        subdivisions = 4

        # add more shape information
        x_sp = Dense(pixelcompress, activation='elu')(x)
        x_sp = BatchNormalization(momentum=0.6)(x_sp)
        x_sp = SoftPixelCNN(mode='full', subdivisions=subdivisions)([coords, x_sp, dist, nidx])
        x_sp = Dense(128, activation='elu', name='dense_spc_' + str(i))(x_sp)

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(nfilt, activation='elu', name='dense_mpc_' + str(i))(x_mp)

        x = Concatenate()([x, x_sp, x_gn, x_mp])

        # check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x_r = x
        # record more and more the deeper we go
        if i < total_iterations - 1:
            x_r = Dense(12 * (i + 1), activation='elu', name='dense_rec_' + str(i))(x)
        else:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    # x = Dropout(0.3)(x)  # force to use different information sources

    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([x, energy])

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs] +
                         backgatheredids + backgathered_coords)
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)
    energy = SelectFeatures(0, 1)(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    # n_cluster_coords = 6

    x = feat  # Dense(64,activation='elu')(feat)
    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_en = []
    allfeat = [x]
    sel_gidx = gidx_orig

    for i in range(6):

        pixelcompress = 16 + 8 * i
        nneigh = 8 + 32 * i
        n_dimensions = int(4 + i)

        x = BatchNormalization(momentum=0.6)(x)
        # do some processing
        x = Dense(64, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=64,
                                              n_propagate=64)([x, rs])

        x = Concatenate()([x, MessagePassing([64, 32, 16, 8])([x, nidx])])

        # allfeat.append(x)

        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        x_sp = SoftPixelCNN(length_scale=1.)([coords, x, nidx])
        x = Concatenate()([x, x_sp])
        x = BatchNormalization(momentum=0.6)(x)

        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        hier = Dense(1, activation='sigmoid', name='hier_' + str(i))(x)  # clustering hierarchy
        dist = ScalarMultiply(1 / (3. * float(i + 0.5)))(dist)

        x, rs, bidxs, t_idx = LocalClusterReshapeFromNeighbours(
            K=5,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.5,
            print_loss=True)([x, dist, hier, nidx, rs, t_idx, t_idx])

        print('>>>>x0', x.shape)
        gatherids.append(bidxs)

        if addBackGatherInfo:
            backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))

        print('>>>>x1', x.shape)
        x = BatchNormalization(momentum=0.6)(x)

        x_record = Dense(2 * pixelcompress, activation='elu')(x)
        x_record = BatchNormalization(momentum=0.6)(x_record)
        allfeat.append(MultiBackGather()([x_record, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    # loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-3,
        position_loss_weight=1e-3,
        timing_loss_weight=1e-3,
    )([
        pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
        orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
        row_splits
    ])

    return Model(inputs=Inputs,
                 outputs=[pred_beta, pred_ccoords, pred_energy,
                          pred_pos, pred_time, pred_id, rs] +
                         backgatheredids)
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(Inputs)

    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)  # get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []
    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3, use_bias=False,
                   kernel_initializer=tf.keras.initializers.identity())(coords)

    first_coords = coords

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])

    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()([coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords, t_idx
            ])  # last is truth index used by layer

        gatherids.append(bidxs)

        if i or True:
            x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  # get to shape V x K x F
            xec = EdgeConvStatic([32, 32, 32], name="ec_static_" + str(i))(x_cl_rs)
            x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        # notice last relu for feature weighting later

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu', name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu', name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu', name='dense_out_c_' + str(i))(x)
        # coords_nograd = StopGradient()(coords)
        # x_r = Concatenate()([coords_nograd,x_r])  # add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()([energy, gatherids])  # assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    # x = Dropout(0.2)(x)

    x_mp = DistanceWeightedMessagePassing([32, 32, 32])([x, first_nidx, first_dist])
    x = Concatenate()([x, x_mp])

    x = Dense(128, activation='elu', name='alldense')(x)

    # TO BE ADDED WITH E LOSS
    # x = Concatenate()([x, energy])

    # x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  # seems broken
        timing_loss_weight=0.,  # 1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[pred_beta, pred_ccoords, pred_energy,
                                         pred_pos, pred_time, pred_id, rs] +
                                        backgatheredids + backgathered_coords)