def pretrain_model(Inputs, td, debug_outdir=None):
    """Build a model consisting only of the trainable pre-selection stage.

    Args:
        Inputs: model input tensors; handed to
            ``td.interpretAllModelInputs`` and used as the model inputs.
        td: object providing ``interpretAllModelInputs(Inputs, returndict=True)``.
        debug_outdir: forwarded positionally to ``pre_selection_model_full``.

    Returns:
        ``DictModel`` mapping ``Inputs`` to the pre-selection outputs with the
        ``'scatterids'`` entry removed.
    """
    preselection_options = dict(
        trainable=True,
        debugplots_after=1500,
        record_metrics=True,
        eweighted=True,
    )

    interpreted = td.interpretAllModelInputs(Inputs, returndict=True)
    outputs = pre_selection_model_full(interpreted,
                                       debug_outdir,
                                       **preselection_options)

    # 'scatterids' would create issues with the model output and is only
    # needed when the pre-selection is embedded in a full-dimension model,
    # so it is dropped here for training.
    outputs.pop('scatterids')

    return DictModel(inputs=Inputs, outputs=outputs)
def gravnet_model(Inputs, td, debug_outdir=None, plot_debug_every=1000):
    """Assemble the GravNet object-condensation model on top of a frozen pre-selection.

    The function reads ``total_iterations``, ``dense_activation``,
    ``batchnorm_options``, ``double_mp`` and ``loss_options`` from the
    enclosing scope; they are not defined inside this function.

    NOTE(review): another function named ``gravnet_model`` is defined further
    down in this source. If both live in the same module, the later
    definition replaces this one - confirm which is intended.

    Args:
        Inputs: model input tensors, interpreted via
            ``td.interpretAllModelInputs`` and used as the model inputs.
        td: object providing ``interpretAllModelInputs(Inputs, returndict=True)``.
        debug_outdir: output directory passed to the ``PlotCoordinates`` layers.
        plot_debug_every: first positional argument of the ``PlotCoordinates``
            layers.

    Returns:
        ``DictModel`` mapping ``Inputs`` to the dictionary produced by
        ``re_integrate_to_full_hits``.
    """
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    raw_inputs = td.interpretAllModelInputs(Inputs, returndict=True)
    raw_inputs['t_spectator_weight'] = CreateTruthSpectatorWeights(
        threshold=3.,
        minimum=1e-1,
        active=True,
    )([raw_inputs['t_spectator'], raw_inputs['t_idx']])

    # can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(raw_inputs, trainable=False)

    # just for info what's available
    print('available pre-selection outputs', list(pre_selection.keys()))

    # The result of this layer call is not used anywhere below; the call is
    # kept so the set of instantiated layers is unchanged.
    CreateTruthSpectatorWeights(
        threshold=3.,
        minimum=1e-1,
        active=True,
    )([pre_selection['t_spectator'], pre_selection['t_idx']])

    row_splits = pre_selection['rs']
    feat = Concatenate()([
        pre_selection['coords'],
        pre_selection['features'],
        pre_selection['addfeat'],
    ])
    hit_energy = pre_selection['energy']
    phys_coords = pre_selection['phys_coords']   # physical coordinates
    cluster_coords = pre_selection['coords']     # pre-clustered coordinates
    truth_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    n_ccoords = 3
    gravnet_dims = 6
    mp_widths = (64, 64, 32, 32, 16, 16)
    per_iteration_feats = []

    # extend coordinates already here if needed
    cluster_coords = extent_coords_if_needed(cluster_coords, feat, n_ccoords)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        feat = RaggedGlobalExchange()([feat, row_splits])
        for _ in range(3):
            feat = Dense(64, activation=dense_activation)(feat)
        feat = GooeyBatchNorm(**batchnorm_options)(feat)
        ### reduction done

        # exchange information, create coordinates
        feat = Concatenate()([cluster_coords, cluster_coords, cluster_coords,
                              phys_coords, feat])
        gravnet = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=gravnet_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            coord_initialiser_noise=1e-2,
            use_approximate_knn=False,  # weird issue with that for now
        )
        gn_feat, gn_coords, gn_nidx, gn_dist = gravnet([feat, row_splits])
        feat = Concatenate()([feat, gn_feat])

        # just keep them in a reasonable range
        # safeguard against disappearing gradients on coordinates
        gn_dist = AverageDistanceRegularizer(strength=1e-4,
                                             record_metrics=True)(gn_dist)

        gn_coords = PlotCoordinates(
            plot_debug_every,
            outdir=debug_outdir,
            name='gn_coords_' + str(i),
        )([gn_coords, hit_energy, truth_idx, row_splits])
        feat = Concatenate()([gn_coords, feat])

        base_dist = gn_dist
        if double_mp:
            # One message-passing layer per width; every step rescales the
            # distances from before this loop (base_dist) with a Dense(1)
            # projection of the current features.
            for step, width in enumerate(mp_widths):
                dist_scale = Dense(1)(feat)
                gn_dist = LocalDistanceScaling(4.)([base_dist, dist_scale])
                gn_dist = AverageDistanceRegularizer(
                    strength=1e-6,
                    record_metrics=True,
                    name='average_distance_dmp_' + str(i) + '_' + str(step),
                )(gn_dist)
                feat = DistanceWeightedMessagePassing(
                    [width],
                    activation=dense_activation,
                )([feat, gn_nidx, gn_dist])
        else:
            feat = DistanceWeightedMessagePassing(
                list(mp_widths),
                activation=dense_activation,
            )([feat, gn_nidx, gn_dist])

        feat = GooeyBatchNorm(**batchnorm_options)(feat)

        feat = Dense(64,
                     name='dense_past_mp_' + str(i),
                     activation=dense_activation)(feat)
        for _ in range(2):
            feat = Dense(64, activation=dense_activation)(feat)

        feat = GooeyBatchNorm(**batchnorm_options)(feat)

        per_iteration_feats.append(feat)

    feat = Concatenate()(
        [cluster_coords] + per_iteration_feats + [pre_selection['not_noise_score']]
    )
    # do one more exchange with all
    for _ in range(3):
        feat = Dense(64, activation=dense_activation)(feat)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    feat = GooeyBatchNorm(**batchnorm_options, name='gooey_pre_out')(feat)
    feat = Concatenate()([cluster_coords] + [feat])

    (pred_beta, pred_ccoords, pred_dist, pred_energy_corr,
     pred_pos, pred_time, pred_id) = create_outputs(
        feat,
        pre_selection['unproc_features'],
        n_ccoords=n_ccoords,
    )

    # loss
    oc_loss = LLFullObjectCondensation(
        scale=4.,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        beta_loss_scale=1.,
        use_energy_weights=True,
        record_metrics=True,
        name="FullOCLoss",
        **loss_options
    )
    # oc output and payload
    oc_payload = [pred_beta, pred_ccoords, pred_dist, pred_energy_corr,
                  pred_pos, pred_time, pred_id]
    # truth information
    truth_keys = (
        't_idx',
        't_energy',
        't_pos',
        't_time',
        't_pid',
        't_spectator_weight',
        't_fully_contained',
        't_rec_energy',
        't_is_unique',
        'rs',
    )
    truth_info = [pre_selection[key] for key in truth_keys]
    pred_beta = oc_loss(oc_payload + [hit_energy] + truth_info)

    # fast feedback
    pred_ccoords = PlotCoordinates(
        plot_debug_every,
        outdir=debug_outdir,
        name='condensation',
    )([pred_ccoords, pred_beta, pre_selection['t_idx'], row_splits])

    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True,
    )

    return DictModel(inputs=Inputs, outputs=model_outputs)
def gravnet_model(
        Inputs,
        td,
        empty_pca=False,
        total_iterations=2,
        variance_only=True,
        viscosity=0.1,
        print_viscosity=False,
        fluidity_decay=5e-4,  # reaches after about 7k batches
        max_viscosity=0.95,
        debug_outdir=None,
        plot_debug_every=1000,
):
    """Assemble the GravNet object-condensation model with an ApproxPCA step.

    NOTE(review): an earlier function with the same name ``gravnet_model``
    appears above in this source. If both live in the same module, this
    definition is the one that stays bound - confirm which is intended.

    Args:
        Inputs: model input tensors, interpreted via
            ``td.interpretAllModelInputs`` and used as the model inputs.
        td: object providing ``interpretAllModelInputs(Inputs, returndict=True)``.
        empty_pca: passed as ``empty`` to every ``ApproxPCA`` layer.
        total_iterations: number of GravNet / message-passing iterations.
        variance_only: forwarded to every ``GooeyBatchNorm`` layer.
        viscosity: forwarded to every ``GooeyBatchNorm`` layer.
        print_viscosity: accepted but not referenced in this function body.
        fluidity_decay: forwarded to every ``GooeyBatchNorm`` layer.
        max_viscosity: forwarded to every ``GooeyBatchNorm`` layer.
        debug_outdir: output directory passed to the ``PlotCoordinates`` layers.
        plot_debug_every: first positional argument of the ``PlotCoordinates``
            layers.

    Returns:
        ``DictModel`` mapping ``Inputs`` to the dictionary produced by
        ``re_integrate_to_full_hits``.
    """
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    raw_inputs = td.interpretAllModelInputs(Inputs, returndict=True)
    raw_inputs['t_spectator_weight'] = CreateTruthSpectatorWeights(
        threshold=3.,
        minimum=1e-1,
        active=True,
    )([raw_inputs['t_spectator'], raw_inputs['t_idx']])

    # can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(raw_inputs, trainable=False)

    # just for info what's available
    print(list(pre_selection.keys()))

    # The result of this layer call is not used anywhere below; the call is
    # kept so the set of instantiated layers is unchanged.
    CreateTruthSpectatorWeights(
        threshold=3.,
        minimum=1e-1,
        active=True,
    )([pre_selection['t_spectator'], pre_selection['t_idx']])

    row_splits = pre_selection['rs']
    feat = Concatenate()([
        pre_selection['coords'],
        pre_selection['features'],
        pre_selection['addfeat'],
    ])
    hit_energy = pre_selection['energy']
    phys_coords = pre_selection['phys_coords']   # physical coordinates
    cluster_coords = pre_selection['coords']     # pre-clustered coordinates
    truth_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    # Every GooeyBatchNorm in this model uses the same settings.
    norm_options = dict(
        viscosity=viscosity,
        max_viscosity=max_viscosity,
        variance_only=variance_only,
        record_metrics=True,
        fluidity_decay=fluidity_decay,
    )

    n_ccoords = 3
    gravnet_dims = 3
    per_iteration_feats = []

    # extend coordinates already here if needed
    phys_coords = extent_coords_if_needed(phys_coords, feat, n_ccoords)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        feat = RaggedGlobalExchange()([feat, row_splits])
        for _ in range(3):
            feat = Dense(64, activation='relu')(feat)
        feat = GooeyBatchNorm(**norm_options)(feat)
        ### reduction done

        # exchange information, create coordinates
        feat = Concatenate()([phys_coords, phys_coords, cluster_coords, feat])
        gravnet = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=gravnet_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False,  # weird issue with that for now
        )
        feat, gn_coords, gn_nidx, gn_dist = gravnet([feat, row_splits])

        gn_coords = PlotCoordinates(
            plot_debug_every,
            outdir=debug_outdir,
            name='gn_coords_' + str(i),
        )([gn_coords, hit_energy, truth_idx, row_splits])

        # just keep them in a reasonable range
        # safeguard against disappearing gradients on coordinates
        gn_dist = AverageDistanceRegularizer(strength=0.01,
                                             record_metrics=True)(gn_dist)

        feat = Concatenate()([hit_energy, feat])

        # pca is expensive, so it only sees a 4-wide projection of the features
        pca_in = Dense(4, activation='relu')(feat)
        pca_out = ApproxPCA(empty=empty_pca)([gn_coords, gn_dist, pca_in, gn_nidx])
        feat = Concatenate()([feat, pca_out])

        feat = DistanceWeightedMessagePassing(
            [64, 64, 32, 32, 16, 16],
        )([feat, gn_nidx, gn_dist])

        feat = GooeyBatchNorm(**norm_options)(feat)

        for _ in range(3):
            feat = Dense(64, activation='relu')(feat)

        feat = GooeyBatchNorm(**norm_options)(feat)

        per_iteration_feats.append(feat)

    feat = Concatenate()(
        [cluster_coords] + per_iteration_feats + [pre_selection['not_noise_score']]
    )
    # do one more exchange with all
    for _ in range(3):
        feat = Dense(64, activation='elu')(feat)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    feat = GooeyBatchNorm(name='gooey_pre_out', **norm_options)(feat)
    feat = Concatenate()([cluster_coords] + [feat])

    (pred_beta, pred_ccoords, pred_dist, pred_energy_corr,
     pred_pos, pred_time, pred_id) = create_outputs(
        feat,
        pre_selection['unproc_features'],
        n_ccoords=n_ccoords,
    )

    # loss
    # Options that were present but disabled in this configuration:
    #   print_batch_time=True, div_repulsion=True, cont_beta_loss=True,
    #   beta_gradient_damping=0.999, phase_transition=1, huber_energy_scale=0.1
    oc_loss = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss",
    )
    # oc output and payload
    oc_payload = [pred_beta, pred_ccoords, pred_dist, pred_energy_corr,
                  pred_pos, pred_time, pred_id]
    # truth information
    truth_keys = (
        't_idx',
        't_energy',
        't_pos',
        't_time',
        't_pid',
        't_spectator_weight',
        't_fully_contained',
        't_rec_energy',
        't_is_unique',
        'rs',
    )
    truth_info = [pre_selection[key] for key in truth_keys]
    pred_beta = oc_loss(oc_payload + [hit_energy] + truth_info)

    # fast feedback
    pred_ccoords = PlotCoordinates(
        plot_debug_every,
        outdir=debug_outdir,
        name='condensation',
    )([pred_ccoords, pred_beta, pre_selection['t_idx'], row_splits])

    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True,
    )

    return DictModel(inputs=Inputs, outputs=model_outputs)