Example #1
def gravnet_model(
    Inputs,
    td,
    empty_pca=False,
    total_iterations=2,
    variance_only=True,
    viscosity=0.1,
    print_viscosity=False,
    fluidity_decay=5e-4,  # viscosity reaches max_viscosity after about 7k batches
    max_viscosity=0.95,
    debug_outdir=None,
    plot_debug_every=1000,
):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)

    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    #just for info what's available
    print([k for k in pre_selection.keys()])

    t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([pre_selection['t_spectator'], pre_selection['t_idx']])
    rs = pre_selection['rs']

    x_in = Concatenate()([
        pre_selection['coords'], pre_selection['features'],
        pre_selection['addfeat']
    ])

    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  #physical coordinates
    c_coords = pre_selection['coords']  #pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    allfeat = []

    n_cluster_space_coordinates = 3

    #extend coordinates already here if needed
    coords = extent_coords_if_needed(coords, x, n_cluster_space_coordinates)
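    # Each iteration below follows the same pattern:
    #   1) RaggedGlobalExchange mixes event-level information back into every hit
    #   2) three Dense layers plus GooeyBatchNorm compress the features
    #   3) RaggedGravNet learns clustering coordinates and builds a kNN graph
    #   4) ApproxPCA adds local-neighbourhood features in those coordinates
    #   5) DistanceWeightedMessagePassing propagates features along the graph
    # The per-iteration outputs are collected in allfeat and concatenated after the loop.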

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)
        ### reduction done

        n_dims = 3
        #exchange information, create coordinates
        x = Concatenate()([coords, coords, c_coords, x])
        x, gncoords, gnnidx, gndist = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=n_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False  #weird issue with that for now
        )([x, rs])

        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' +
                                   str(i))([gncoords, energy, t_idx, rs])
        #just keep them in a reasonable range
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=0.01,
                                            record_metrics=True)(gndist)
        x = Concatenate()([energy, x])
        x_pca = Dense(4, activation='relu')(x)  #pca is expensive
        x_pca = ApproxPCA(empty=empty_pca)([gncoords, gndist, x_pca, gnnidx])
        x = Concatenate()([x, x_pca])

        x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16,
                                            16])([x, gnnidx, gndist])

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           record_metrics=True,
                           variance_only=variance_only,
                           fluidity_decay=fluidity_decay)(x)

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)

        allfeat.append(x)

    x = Concatenate()([c_coords] + allfeat +
                      [pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(viscosity=viscosity,
                       max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay,
                       record_metrics=True,
                       variance_only=variance_only,
                       name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
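    # LLFullObjectCondensation is a loss-as-layer: it forwards its first input
    # (pred_beta) and attaches the object condensation loss plus the payload
    # terms (energy, position, timing, ID) to the model, which is why the truth
    # tensors are passed in here as layer inputs rather than as Keras targets.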
    pred_beta = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        #print_batch_time=True,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        #div_repulsion=True,
        # cont_beta_loss=True,
        # beta_gradient_damping=0.999,
        # phase_transition=1,
        #huber_energy_scale=0.1,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss")(  # oc output and payload
            [
                pred_beta, pred_ccoords, pred_dist, pred_energy_corr, pred_pos,
                pred_time, pred_id
            ] + [energy] +
            # truth information
            [
                pre_selection['t_idx'], pre_selection['t_energy'],
                pre_selection['t_pos'], pre_selection['t_time'],
                pre_selection['t_pid'], pre_selection['t_spectator_weight'],
                pre_selection['t_fully_contained'],
                pre_selection['t_rec_energy'], pre_selection['t_is_unique'],
                pre_selection['rs']
            ])

    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([
                                       pred_ccoords, pred_beta,
                                       pre_selection['t_idx'], rs
                                   ])

    model_outputs = re_integrate_to_full_hits(pre_selection,
                                              pred_ccoords,
                                              pred_beta,
                                              pred_energy_corr,
                                              pred_pos,
                                              pred_time,
                                              pred_id,
                                              pred_dist,
                                              dict_output=True)

    return DictModel(inputs=Inputs, outputs=model_outputs)
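The function above only builds the model; in the original repository it is handed to a DeepJetCore-style training driver. The sketch below shows one plausible way to wire it up; the driver, its methods and their keyword arguments are assumptions, not taken from the example above.

# hypothetical training-script sketch -- the driver API below is an assumption
from DeepJetCore.training.training_base import training_base

train = training_base()  # parses the usual command-line arguments (data, output dir, ...)
if not train.modelSet():
    # the builder receives the Keras input placeholders plus the keyword arguments
    train.setModel(gravnet_model,
                   td=train.train_data.dataclass(),
                   debug_outdir=train.outputDir + '/intplots')
    # losses are attached inside the model by the LL* layers, so no external loss here
    train.compileModel(learningrate=1e-4, loss=None)
model, history = train.trainModel(nepochs=5, batchsize=30000)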
Example #2
def gravnet_model(Inputs,
                  td,
                  # the following were module-level settings in the original
                  # training script; they are exposed as keyword arguments here
                  # so the snippet is self-contained (defaults are assumptions)
                  total_iterations=2,
                  dense_activation='elu',
                  double_mp=False,
                  batchnorm_options=None,
                  loss_options=None,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    batchnorm_options = batchnorm_options or {}  # unpacked into GooeyBatchNorm below
    loss_options = loss_options or {}            # unpacked into LLFullObjectCondensation below
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True)
    
    
    orig_t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([orig_inputs['t_spectator'], 
                                                        orig_inputs['t_idx']])
                                                     
    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight                                                 
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs,trainable=False)
    
    #just for info what's available
    print('available pre-selection outputs',[k for k in pre_selection.keys()])
                                          
    
    t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([pre_selection['t_spectator'], 
                                                        pre_selection['t_idx']])
    rs = pre_selection['rs']
                               
    x_in = Concatenate()([pre_selection['coords'],
                          pre_selection['features'],
                          pre_selection['addfeat']])
                           
    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']#physical coordinates
    c_coords = pre_selection['coords']#pre-clustered coordinates
    t_idx = pre_selection['t_idx']
    
    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################
    
    allfeat = []
    
    n_cluster_space_coordinates = 3
    
    
    #extend coordinates already here if needed
    c_coords = extent_coords_if_needed(c_coords, x, n_cluster_space_coordinates)
        

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])
        
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = GooeyBatchNorm(**batchnorm_options)(x)
        ### reduction done
        
        n_dims = 6
        #exchange information, create coordinates
        x = Concatenate()([c_coords,c_coords,c_coords,coords,x])
        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=64,
                                                 n_dimensions=n_dims,
                                                 n_filters=64,
                                                 n_propagate=64,
                                                 record_metrics=True,
                                                 coord_initialiser_noise=1e-2,
                                                 use_approximate_knn=False #weird issue with that for now
                                                 )([x, rs])
        
        x = Concatenate()([x,xgn])                                                      
        #just keep them in a reasonable range  
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=1e-4,
                                            record_metrics=True
                                            )(gndist)
                                            
        gncoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                                   name='gn_coords_'+str(i))([gncoords, 
                                                                    energy,
                                                                    t_idx,
                                                                    rs]) 
        x = Concatenate()([gncoords,x])           
        
        pre_gndist=gndist
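        # with double_mp, every message-passing step gets its own learned rescaling
        # of the (unscaled) GravNet distances via Dense(1) -> LocalDistanceScaling;
        # otherwise a single multi-layer message passing uses the shared distances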
        if double_mp:
            for im,m in enumerate([64,64,32,32,16,16]):
                dscale=Dense(1)(x)
                gndist = LocalDistanceScaling(4.)([pre_gndist,dscale])                                  
                gndist = AverageDistanceRegularizer(strength=1e-6,
                                            record_metrics=True,
                                            name='average_distance_dmp_'+str(i)+'_'+str(im)
                                            )(gndist)
                                            
                x = DistanceWeightedMessagePassing([m],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
        else:        
            x = DistanceWeightedMessagePassing([64,64,32,32,16,16],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
            
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        x = Dense(64,name='dense_past_mp_'+str(i),activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        
        allfeat.append(x)
        
        
    
    x = Concatenate()([c_coords]+allfeat+[pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    
    
    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################
    
    x = GooeyBatchNorm(**batchnorm_options,name='gooey_pre_out')(x)
    x = Concatenate()([c_coords]+[x])
    
    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'], 
                                                  n_ccoords=n_cluster_space_coordinates)
    
    # loss
    pred_beta = LLFullObjectCondensation(scale=4.,
                                         position_loss_weight=1e-5,
                                         timing_loss_weight=1e-5,
                                         beta_loss_scale=1.,
                                         use_energy_weights=True,
                                         record_metrics=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )(  # oc output and payload
        [pred_beta, pred_ccoords, pred_dist,
         pred_energy_corr, pred_pos, pred_time, pred_id] +
        [energy]+
        # truth information
        [pre_selection['t_idx'] ,
         pre_selection['t_energy'] ,
         pre_selection['t_pos'] ,
         pre_selection['t_time'] ,
         pre_selection['t_pid'] ,
         pre_selection['t_spectator_weight'],
         pre_selection['t_fully_contained'],
         pre_selection['t_rec_energy'],
         pre_selection['t_is_unique'],
         pre_selection['rs']])
                                         
    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                    name='condensation')([pred_ccoords, pred_beta,pre_selection['t_idx'],
                                          rs])                                    

    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True
        )
    
    return DictModel(inputs=Inputs, outputs=model_outputs)
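Example #2 pulls the GooeyBatchNorm and object condensation loss settings out of the code into option dictionaries. A hypothetical configuration sketch follows; every key below appears as an explicit argument in Example #1, and the values are placeholders.

# hypothetical option dictionaries for Example #2 (placeholder values);
# the keys mirror the explicit constructor arguments used in Example #1
batchnorm_options = dict(
    viscosity=0.1,
    max_viscosity=0.95,
    fluidity_decay=5e-4,
    variance_only=True,
    record_metrics=True,
)
loss_options = dict(
    energy_loss_weight=2.,
    classification_loss_weight=1e-5,
    too_much_beta_scale=1e-4,
    q_min=0.2,
    use_average_cc_pos=0.2,
)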
Example #3
def pre_selection_staged(
    indict,
    debug_outdir,
    trainable,
    name='pre_selection_add_stage_0',
    debugplots_after=-1,
    reduction_threshold=0.75,
    use_edges=True,
    print_info=False,
    record_metrics=False,
    n_coords=3,
    edge_nodes_0=16,
    edge_nodes_1=8,
):
    '''
    Takes the output of the pre-selection model and selects again :)
    The outputs are compatible with the inputs, so this stage can be chained.

    This stage uses a full-blown GravNet.
    
    Gets as inputs:
    
    indict['scatterids']
    indict['orig_t_idx'] 
    indict['orig_t_energy'] 
    indict['orig_dim_coords']
    indict['rs']
    indict['orig_row_splits']
    
    indict['features'] 
    indict['orig_features']
    indict['coords'] 
    indict['addfeat']
    indict['energy']
    
    
    indict['t_idx']
    indict['t_energy']
    ... all the truth info
    
    '''

    from GravNetLayersRagged import RaggedGravNet, DistanceWeightedMessagePassing, ElementScaling
    from GravNetLayersRagged import SelectFromIndices, GooeyBatchNorm, MaskTracksAsNoise
    from GravNetLayersRagged import AccumulateNeighbours, KNN, MultiAttentionGravNetAdd
    from LossLayers import LLClusterCoordinates, LLFillSpace
    from DebugLayers import PlotCoordinates
    from MetricsLayers import MLReductionMetrics
    from Regularizers import MeanMaxDistanceRegularizer, AverageDistanceRegularizer

    #assume the inputs are normalised
    rs = indict['rs']
    t_idx = indict['t_idx']

    track_charge = SelectFeatures(2, 3)(
        indict['unproc_features'])  #zero for calo hits
    x = Concatenate()([indict['features'], indict['addfeat']])
    x = Dense(64, activation='elu', trainable=trainable)(x)
    gn_pre_coords = indict['coords']
    gn_pre_coords = ElementScaling(name=name + 'es1',
                                   trainable=trainable)(gn_pre_coords)
    x = Concatenate()([gn_pre_coords, x])

    x, coords, nidx, dist = RaggedGravNet(n_neighbours=32,
                                          n_dimensions=n_coords,
                                          n_filters=64,
                                          n_propagate=64,
                                          coord_initialiser_noise=1e-5,
                                          feature_activation=None,
                                          record_metrics=record_metrics,
                                          use_approximate_knn=True,
                                          use_dynamic_knn=True,
                                          trainable=trainable,
                                          name=name + '_gn1')([x, rs])

    #the two regularizers below mostly record metrics and suppress very bad coordinate scalings
    dist = MeanMaxDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    dist = AverageDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_gn1_coords')(
                                     [coords, indict['energy'], t_idx, rs])

    x = DistanceWeightedMessagePassing([32, 32, 8, 8],
                                       name=name + 'dmp1',
                                       trainable=trainable)([x, nidx, dist])

    x_matt = Dense(16, activation='elu', name=name + '_matt_dense')(x)

    x_matt = MultiAttentionGravNetAdd(5,
                                      name=name + '_att_gn1',
                                      record_metrics=record_metrics)(
                                          [x, x_matt, coords, nidx])
    x = Concatenate()([x, x_matt])
    x = Dense(64, activation='elu', name=name + '_bef_coord_dense')(x)

    coords = Add()([
        Dense(n_coords,
              name=name + '_coord_add_dense',
              kernel_initializer='zeros')(x), coords
    ])
    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_red_coords')(
                                     [coords, indict['energy'], t_idx, rs])

    nidx, dist = KNN(
        K=16,
        radius='dynamic',  #use dynamic feature
        record_metrics=record_metrics,
        name=name + '_knn',
        min_bins=[7, 7]  #this can be fine grained
    )([coords, rs])
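    # the two LL* layers below are loss-only layers: they return the coordinates
    # unchanged and, while trainable, add a clustering-quality term and a mild
    # "fill the space" regularisation term to the model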

    coords = LLClusterCoordinates(print_loss=print_info,
                                  record_metrics=record_metrics,
                                  active=trainable,
                                  print_batch_time=False,
                                  scale=5.)([coords, t_idx, rs])

    coords = LLFillSpace(
        active=trainable,
        record_metrics=record_metrics,
        scale=0.025,  #just mild
        runevery=-1,  #give it a kick only every now and then - that's enough
    )([coords, rs])

    unred_rs = rs

    cluster_tidx = MaskTracksAsNoise(active=trainable)([t_idx, track_charge])
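    # reduce_indices groups neighbouring hits (with tracks masked as noise above) and
    # returns: accumulation indices over the original hits (gnidx), the indices of
    # the selected representatives (gsel), a back-gather map to undo the selection
    # later, and the row splits after the reduction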

    gnidx, gsel, group_backgather, rs = reduce_indices(
        x,
        dist,
        nidx,
        rs,
        cluster_tidx,
        threshold=reduction_threshold,
        print_reduction=print_info,
        trainable=trainable,
        name=name + '_reduce_indices',
        use_edges=use_edges,
        edge_nodes_0=edge_nodes_0,
        edge_nodes_1=edge_nodes_1,
        return_backscatter=False)

    gsel = MLReductionMetrics(name=name + '_reduction', record_metrics=True)(
        [gsel, t_idx, indict['t_energy'], unred_rs, rs])

    selfeat = SelectFromIndices()([gsel, indict['features']])
    unproc_features = SelectFromIndices()([gsel, indict['unproc_features']])

    energy = indict['energy']

    x = AccumulateNeighbours('minmeanmax')([x, gnidx, energy])
    x = SelectFromIndices()([gsel, x])
    #add more useful things
    coords = AccumulateNeighbours('mean')([coords, gnidx, energy])
    coords = SelectFromIndices()([gsel, coords])
    phys_coords = AccumulateNeighbours('mean')(
        [indict['phys_coords'], gnidx, energy])
    phys_coords = SelectFromIndices()([gsel, phys_coords])

    energy = AccumulateNeighbours('sum')([energy, gnidx])
    energy = SelectFromIndices()([gsel, energy])
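    # rebuild an output dict with the same layout as the pre-selection model,
    # so that several of these stages can be chained (see the docstring above)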

    out = {}
    out['not_noise_score'] = AccumulateNeighbours('mean')(
        [indict['not_noise_score'], gnidx])
    out['not_noise_score'] = SelectFromIndices()(
        [gsel, out['not_noise_score']])

    out['scatterids'] = indict['scatterids'] + [group_backgather]  #append new selection

    #re-build standard feature layout
    out['features'] = selfeat
    out['unproc_features'] = unproc_features
    out['coords'] = coords
    out['phys_coords'] = phys_coords
    out['addfeat'] = GooeyBatchNorm(name=name + '_gooey_norm',
                                    trainable=trainable)(x)  #norm them
    out['energy'] = energy
    out['rs'] = rs

    for k in indict.keys():
        if k.startswith('t_'):
            out[k] = SelectFromIndices()([gsel, indict[k]])

    #some pass throughs:
    out['orig_dim_coords'] = indict['orig_dim_coords']
    out['orig_t_idx'] = indict['orig_t_idx']
    out['orig_t_energy'] = indict['orig_t_energy']
    out['orig_row_splits'] = indict['orig_row_splits']

    #check
    anymissing = False
    for k in indict.keys():
        if k not in out:
            anymissing = True
            print(k, 'missing')
    if anymissing:
        raise ValueError("key not found")

    return out
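As the docstring says, the output dict has the same layout as its input, so stages can be chained. A minimal chaining sketch, assuming `pre_selection` is the dict produced by the upstream pre-selection model (as in Examples #1 and #2):

# hypothetical chaining sketch -- `pre_selection` is assumed to be the dict
# returned by the upstream pre-selection model
stage0 = pre_selection_staged(pre_selection, debug_outdir='debug_plots',
                              trainable=True, name='pre_selection_add_stage_0')
stage1 = pre_selection_staged(stage0, debug_outdir='debug_plots',
                              trainable=True, name='pre_selection_add_stage_1')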