Beispiel #1
0
def pretrain_model(Inputs, td, debug_outdir=None):
    """Build a trainable pre-selection model for standalone pre-training.

    Args:
        Inputs: model input tensors; interpreted via
            ``td.interpretAllModelInputs(..., returndict=True)`` and reused
            as the inputs of the returned ``DictModel``.
        td: object providing ``interpretAllModelInputs``.
        debug_outdir: forwarded positionally to ``pre_selection_model_full``
            (presumably the debug-plot output directory -- confirm).

    Returns:
        ``DictModel`` whose outputs are the pre-selection output dict with the
        ``'scatterids'`` entry removed.
    """
    model_inputs = td.interpretAllModelInputs(Inputs, returndict=True)

    selection_outputs = pre_selection_model_full(
        model_inputs,
        debug_outdir,
        trainable=True,
        debugplots_after=1500,
        record_metrics=True,
        eweighted=True,
    )

    # 'scatterids' is only needed when the pre-selection feeds a full-dimensional
    # model and causes problems as a model output, so drop it for pre-training.
    selection_outputs.pop('scatterids')

    return DictModel(inputs=Inputs, outputs=selection_outputs)
Beispiel #2
0
def gravnet_model(Inputs,
                  td,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    """Build a GravNet object-condensation model on top of a frozen pre-selection.

    The inputs are interpreted by ``td`` and passed through a non-trainable
    ``pre_selection_model_full``. The result is processed by
    ``total_iterations`` blocks of GravNet plus distance-weighted message
    passing. The concatenated features feed ``create_outputs`` and the
    ``LLFullObjectCondensation`` layer. The predictions are then mapped back
    to all hits with ``re_integrate_to_full_hits``.

    NOTE(review): this function reads ``total_iterations``, ``dense_activation``,
    ``batchnorm_options``, ``double_mp`` and ``loss_options``. None of these are
    parameters or locals; presumably they are module-level settings defined
    elsewhere in the file -- confirm.

    NOTE(review): layers are created in a fixed order here; reordering
    statements may change auto-generated layer names (relevant if weights are
    loaded by name -- confirm).

    Args:
        Inputs: model input tensors. They are interpreted via
            ``td.interpretAllModelInputs(..., returndict=True)`` and reused as
            the ``DictModel`` inputs.
        td: object providing ``interpretAllModelInputs``.
        debug_outdir: output directory passed to the ``PlotCoordinates``
            debug-plot layers.
        plot_debug_every: first positional argument of ``PlotCoordinates``
            (presumably the plotting interval in batches -- confirm).

    Returns:
        ``DictModel`` whose outputs are the dict returned by
        ``re_integrate_to_full_hits(..., dict_output=True)``.
    """
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True)
    
    
    # Truth spectator weights computed on the original (full) inputs and stored
    # in orig_inputs, which is then handed to pre_selection_model_full. The loss
    # below reads pre_selection['t_spectator_weight'], presumably forwarded from
    # here by the pre-selection -- confirm.
    orig_t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([orig_inputs['t_spectator'], 
                                                        orig_inputs['t_idx']])
                                                     
    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight                                                 
    #can be loaded - or use pre-selected dataset (to be made)
    # trainable=False: the pre-selection stage is not trained with this model.
    pre_selection = pre_selection_model_full(orig_inputs,trainable=False)
    
    #just for info what's available
    print('available pre-selection outputs',[k for k in pre_selection.keys()])
                                          
    
    # NOTE(review): t_spectator_weight is never used below (the loss reads
    # pre_selection['t_spectator_weight']); candidate for removal.
    t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([pre_selection['t_spectator'], 
                                                        pre_selection['t_idx']])
    # rs is passed alongside features to every ragged layer (GlobalExchange,
    # GravNet, PlotCoordinates, loss); presumably row splits -- confirm.
    rs = pre_selection['rs']
                               
    x_in = Concatenate()([pre_selection['coords'],
                          pre_selection['features'],
                          pre_selection['addfeat']])
                           
    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']#physical coordinates
    c_coords = pre_selection['coords']#pre-clustered coordinates
    t_idx = pre_selection['t_idx']
    
    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################
    
    # Per-iteration feature outputs; concatenated after the loop.
    allfeat = []
    
    n_cluster_space_coordinates = 3
    
    
    #extend coordinates already here if needed
    # (here the pre-clustered coordinates c_coords are extended to
    # n_cluster_space_coordinates dimensions)
    c_coords = extent_coords_if_needed(c_coords, x, n_cluster_space_coordinates)
        

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])
        
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = GooeyBatchNorm(**batchnorm_options)(x)
        ### reduction done
        
        n_dims = 6
        #exchange information, create coordinates
        # c_coords is repeated three times next to coords; presumably to give
        # the pre-clustered coordinates more weight -- confirm intent.
        x = Concatenate()([c_coords,c_coords,c_coords,coords,x])
        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=64,
                                                 n_dimensions=n_dims,
                                                 n_filters=64,
                                                 n_propagate=64,
                                                 record_metrics=True,
                                                 coord_initialiser_noise=1e-2,
                                                 use_approximate_knn=False #weird issue with that for now
                                                 )([x, rs])
        
        # Keep the GravNet input features and append the GravNet output.
        x = Concatenate()([x,xgn])                                                      
        #just keep them in a reasonable range  
        #safeguard against disappearing gradients on coordinates                                       
        gndist = AverageDistanceRegularizer(strength=1e-4,
                                            record_metrics=True
                                            )(gndist)
                                            
        gncoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                                   name='gn_coords_'+str(i))([gncoords, 
                                                                    energy,
                                                                    t_idx,
                                                                    rs]) 
        x = Concatenate()([gncoords,x])           
        
        # Keep the regularised distances so every double-MP step rescales the
        # same base distances rather than the previous step's scaled ones.
        pre_gndist=gndist
        if double_mp:
            # One message-passing layer per width, each with its own learned
            # per-vertex distance scaling computed from the current features.
            for im,m in enumerate([64,64,32,32,16,16]):
                dscale=Dense(1)(x)
                gndist = LocalDistanceScaling(4.)([pre_gndist,dscale])                                  
                gndist = AverageDistanceRegularizer(strength=1e-6,
                                            record_metrics=True,
                                            name='average_distance_dmp_'+str(i)+'_'+str(im)
                                            )(gndist)
                                            
                x = DistanceWeightedMessagePassing([m],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
        else:        
            # Single message-passing layer over all widths with fixed distances.
            x = DistanceWeightedMessagePassing([64,64,32,32,16,16],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
            
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        x = Dense(64,name='dense_past_mp_'+str(i),activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        
        allfeat.append(x)
        
        
    
    x = Concatenate()([c_coords]+allfeat+[pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    
    
    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################
    
    x = GooeyBatchNorm(**batchnorm_options,name='gooey_pre_out')(x)
    x = Concatenate()([c_coords]+[x])
    
    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'], 
                                                  n_ccoords=n_cluster_space_coordinates)
    
    # loss
    # The object-condensation layer takes predictions, hit energy and truth;
    # its output replaces pred_beta.
    pred_beta = LLFullObjectCondensation(scale=4.,
                                         position_loss_weight=1e-5,
                                         timing_loss_weight=1e-5,
                                         beta_loss_scale=1.,
                                         use_energy_weights=True,
                                         record_metrics=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )(  # oc output and payload
        [pred_beta, pred_ccoords, pred_dist,
         pred_energy_corr, pred_pos, pred_time, pred_id] +
        [energy]+
        # truth information
        [pre_selection['t_idx'] ,
         pre_selection['t_energy'] ,
         pre_selection['t_pos'] ,
         pre_selection['t_time'] ,
         pre_selection['t_pid'] ,
         pre_selection['t_spectator_weight'],
         pre_selection['t_fully_contained'],
         pre_selection['t_rec_energy'],
         pre_selection['t_is_unique'],
         pre_selection['rs']])
                                         
    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                    name='condensation')([pred_ccoords, pred_beta,pre_selection['t_idx'],
                                          rs])                                    

    # Map the predictions from the pre-selected objects back onto all hits.
    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True
        )
    
    return DictModel(inputs=Inputs, outputs=model_outputs)
Beispiel #3
0
def gravnet_model(
    Inputs,
    td,
    empty_pca=False,
    total_iterations=2,
    variance_only=True,
    viscosity=0.1,
    print_viscosity=False,
    fluidity_decay=5e-4,  # reaches after about 7k batches
    max_viscosity=0.95,
    debug_outdir=None,
    plot_debug_every=1000,
):
    """Build a GravNet + approximate-PCA object-condensation model on a frozen pre-selection.

    The inputs are interpreted by ``td`` and passed through a non-trainable
    ``pre_selection_model_full``. The result is processed by
    ``total_iterations`` blocks of GravNet, ``ApproxPCA`` and distance-weighted
    message passing. The concatenated features feed ``create_outputs`` and the
    ``LLFullObjectCondensation`` layer. The predictions are mapped back to all
    hits with ``re_integrate_to_full_hits``.

    NOTE(review): layers are created in a fixed order here; reordering
    statements may change auto-generated layer names (relevant if weights are
    loaded by name -- confirm).

    Args:
        Inputs: model input tensors. They are interpreted via
            ``td.interpretAllModelInputs(..., returndict=True)`` and reused as
            the ``DictModel`` inputs.
        td: object providing ``interpretAllModelInputs``.
        empty_pca: forwarded as ``empty=`` to every ``ApproxPCA`` layer.
        total_iterations: number of GravNet / message-passing iterations.
        variance_only: forwarded to every ``GooeyBatchNorm`` layer.
        viscosity: forwarded to every ``GooeyBatchNorm`` layer.
        print_viscosity: NOTE(review): not used anywhere in this function.
        fluidity_decay: forwarded to every ``GooeyBatchNorm`` layer.
        max_viscosity: forwarded to every ``GooeyBatchNorm`` layer.
        debug_outdir: output directory passed to the ``PlotCoordinates``
            debug-plot layers.
        plot_debug_every: first positional argument of ``PlotCoordinates``
            (presumably the plotting interval in batches -- confirm).

    Returns:
        ``DictModel`` whose outputs are the dict returned by
        ``re_integrate_to_full_hits(..., dict_output=True)``.
    """
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)

    # Truth spectator weights computed on the original (full) inputs and stored
    # in orig_inputs, which is then handed to pre_selection_model_full. The loss
    # below reads pre_selection['t_spectator_weight'], presumably forwarded from
    # here by the pre-selection -- confirm.
    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    # trainable=False: the pre-selection stage is not trained with this model.
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    #just for info what's available
    print([k for k in pre_selection.keys()])

    # NOTE(review): t_spectator_weight is never used below (the loss reads
    # pre_selection['t_spectator_weight']); candidate for removal.
    t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([pre_selection['t_spectator'], pre_selection['t_idx']])
    # rs is passed alongside features to every ragged layer (GlobalExchange,
    # GravNet, PlotCoordinates, loss); presumably row splits -- confirm.
    rs = pre_selection['rs']

    x_in = Concatenate()([
        pre_selection['coords'], pre_selection['features'],
        pre_selection['addfeat']
    ])

    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  #physical coordinates
    c_coords = pre_selection['coords']  #pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    # Per-iteration feature outputs; concatenated after the loop.
    allfeat = []

    n_cluster_space_coordinates = 3

    #extend coordinates already here if needed
    # (here the physical coordinates are extended; the sibling variant extends
    # c_coords instead -- confirm which is intended)
    coords = extent_coords_if_needed(coords, x, n_cluster_space_coordinates)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)
        ### reduction done

        n_dims = 3
        #exchange information, create coordinates
        x = Concatenate()([coords, coords, c_coords, x])
        # Unlike a concatenating variant, x is replaced by the GravNet output.
        x, gncoords, gnnidx, gndist = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=n_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False  #weird issue with that for now
        )([x, rs])

        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' +
                                   str(i))([gncoords, energy, t_idx, rs])
        #just keep them in a reasonable range
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=0.01,
                                            record_metrics=True)(gndist)
        x = Concatenate()([energy, x])
        # Project to 4 features before ApproxPCA to keep its cost down.
        x_pca = Dense(4, activation='relu')(x)  #pca is expensive
        x_pca = ApproxPCA(empty=empty_pca)([gncoords, gndist, x_pca, gnnidx])
        x = Concatenate()([x, x_pca])

        x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16,
                                            16])([x, gnnidx, gndist])

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           record_metrics=True,
                           variance_only=variance_only,
                           fluidity_decay=fluidity_decay)(x)

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)

        allfeat.append(x)

    x = Concatenate()([c_coords] + allfeat +
                      [pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(viscosity=viscosity,
                       max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay,
                       record_metrics=True,
                       variance_only=variance_only,
                       name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    # The object-condensation layer takes predictions, hit energy and truth;
    # its output replaces pred_beta. Commented-out kwargs are alternative
    # settings left in by the author.
    pred_beta = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        #print_batch_time=True,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        #div_repulsion=True,
        # cont_beta_loss=True,
        # beta_gradient_damping=0.999,
        # phase_transition=1,
        #huber_energy_scale=0.1,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss")(  # oc output and payload
            [
                pred_beta, pred_ccoords, pred_dist, pred_energy_corr, pred_pos,
                pred_time, pred_id
            ] + [energy] +
            # truth information
            [
                pre_selection['t_idx'], pre_selection['t_energy'],
                pre_selection['t_pos'], pre_selection['t_time'],
                pre_selection['t_pid'], pre_selection['t_spectator_weight'],
                pre_selection['t_fully_contained'],
                pre_selection['t_rec_energy'], pre_selection['t_is_unique'],
                pre_selection['rs']
            ])

    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([
                                       pred_ccoords, pred_beta,
                                       pre_selection['t_idx'], rs
                                   ])

    # Map the predictions from the pre-selected objects back onto all hits.
    model_outputs = re_integrate_to_full_hits(pre_selection,
                                              pred_ccoords,
                                              pred_beta,
                                              pred_energy_corr,
                                              pred_pos,
                                              pred_time,
                                              pred_id,
                                              pred_dist,
                                              dict_output=True)

    return DictModel(inputs=Inputs, outputs=model_outputs)