Example #1
def gravnet_model(Inputs,
                  td,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True)
    
    
    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs,trainable=False)
    
    #just for info what's available
    print('available pre-selection outputs',[k for k in pre_selection.keys()])
                                          
    
    t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([pre_selection['t_spectator'], 
                                                        pre_selection['t_idx']])
    rs = pre_selection['rs']
                               
    x_in = Concatenate()([pre_selection['coords'],
                          pre_selection['features'],
                          pre_selection['addfeat']])
                           
    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']#physical coordinates
    c_coords = pre_selection['coords']#pre-clustered coordinates
    t_idx = pre_selection['t_idx']
    
    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################
    
    allfeat = []
    
    n_cluster_space_coordinates = 3
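    # NB: total_iterations, dense_activation, batchnorm_options, double_mp and
    # loss_options are assumed to be defined at module level in this snippet.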
    
    
    #extend coordinates already here if needed
    c_coords = extent_coords_if_needed(c_coords, x, n_cluster_space_coordinates)
        

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])
        
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = GooeyBatchNorm(**batchnorm_options)(x)
        ### reduction done
        
        n_dims = 6
        #exchange information, create coordinates
        x = Concatenate()([c_coords,c_coords,c_coords,coords,x])
        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=64,
                                                 n_dimensions=n_dims,
                                                 n_filters=64,
                                                 n_propagate=64,
                                                 record_metrics=True,
                                                 coord_initialiser_noise=1e-2,
                                                 use_approximate_knn=False #weird issue with that for now
                                                 )([x, rs])
        
        x = Concatenate()([x,xgn])                                                      
        #just keep them in a reasonable range
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=1e-4,
                                            record_metrics=True
                                            )(gndist)
                                            
        gncoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                                   name='gn_coords_'+str(i))([gncoords, 
                                                                    energy,
                                                                    t_idx,
                                                                    rs]) 
        x = Concatenate()([gncoords,x])           
        
        pre_gndist = gndist
        if double_mp:
            for im, m in enumerate([64,64,32,32,16,16]):
                dscale = Dense(1)(x)
                gndist = LocalDistanceScaling(4.)([pre_gndist, dscale])
                gndist = AverageDistanceRegularizer(strength=1e-6,
                                            record_metrics=True,
                                            name='average_distance_dmp_'+str(i)+'_'+str(im)
                                            )(gndist)
                                            
                x = DistanceWeightedMessagePassing([m],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
        else:        
            x = DistanceWeightedMessagePassing([64,64,32,32,16,16],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
            
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        x = Dense(64,name='dense_past_mp_'+str(i),activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        
        allfeat.append(x)
        
        
    
    x = Concatenate()([c_coords]+allfeat+[pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    
    
    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################
    
    x = GooeyBatchNorm(**batchnorm_options,name='gooey_pre_out')(x)
    x = Concatenate()([c_coords]+[x])
    
    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'], 
                                                  n_ccoords=n_cluster_space_coordinates)
    
    # loss
    pred_beta = LLFullObjectCondensation(scale=4.,
                                         position_loss_weight=1e-5,
                                         timing_loss_weight=1e-5,
                                         beta_loss_scale=1.,
                                         use_energy_weights=True,
                                         record_metrics=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )(  # oc output and payload
        [pred_beta, pred_ccoords, pred_dist,
         pred_energy_corr, pred_pos, pred_time, pred_id] +
        [energy]+
        # truth information
        [pre_selection['t_idx'] ,
         pre_selection['t_energy'] ,
         pre_selection['t_pos'] ,
         pre_selection['t_time'] ,
         pre_selection['t_pid'] ,
         pre_selection['t_spectator_weight'],
         pre_selection['t_fully_contained'],
         pre_selection['t_rec_energy'],
         pre_selection['t_is_unique'],
         pre_selection['rs']])
                                         
    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                    name='condensation')([pred_ccoords, pred_beta,pre_selection['t_idx'],
                                          rs])                                    

    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True
        )
    
    return DictModel(inputs=Inputs, outputs=model_outputs)
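
Throughout these examples, rs / row_splits partitions the flat per-hit tensors into events, following TensorFlow's ragged row-split convention: entries i and i+1 delimit the hits of event i. A minimal illustration in plain TensorFlow, independent of the HGCalML layers:

import tensorflow as tf

# Three events with 2, 3 and 1 hits: row_splits has length n_events + 1, and
# event i owns rows row_splits[i]:row_splits[i+1] of the flat hit tensor.
flat_hits = tf.constant([[1.], [2.], [3.], [4.], [5.], [6.]])
row_splits = tf.constant([0, 2, 5, 6])

ragged = tf.RaggedTensor.from_row_splits(flat_hits, row_splits)
print(ragged.row_lengths().numpy())  # -> [2 3 1]
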
Example #2
def gravnet_model(Inputs,
                  viscosity=0.2,
                  print_viscosity=False,
                  fluidity_decay=1e-3,
                  max_viscosity=0.9 # to start with
                  ):
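    # NB: td (the training-data interpreter) and the tf/keras layer imports are
    # assumed to be available at module level in this snippet.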

    feature_dropout = -1.
    addBackGatherInfo = True  #note: the original trailing comma made this a tuple

    feat,  t_idx, t_energy, t_pos, t_time, t_pid, t_spectator, t_fully_contained, row_splits = td.interpretAllModelInputs(Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    energy = SelectFeatures(0,1)(feat)
    time = SelectFeatures(8,9)(feat_norm)
    orig_coords = SelectFeatures(5,8)(feat)

    x = feat_norm

    allfeat=[x]
    backgatheredids=[]
    gatherids=[]
    backgathered = []
    backgathered_coords = []
    energysums = []

    n_reshape_dimensions=3

    #really simple real coordinates
    coords = ManualCoordTransform()(orig_coords)
    coords = Concatenate()([coords, time])
    coords = Dense(n_reshape_dimensions, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)  #just rotation and scaling


    #see what's there
    nidx, dist = KNN(K=24,radius=1.0)([coords,rs])

    x_mp = DistanceWeightedMessagePassing([32,16,16,8])([x,nidx,dist])
    x = Concatenate()([x,x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations=5

    fwdbgccoords=None

    for i in range(total_iterations):

        cluster_neighbours = 5
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,Dense(n_reshape_dimensions,name='newcoords'+str(i),
                                         kernel_initializer='zeros'
                                         )(x)
                ])
            nidx, cdist = KNN(K=5*cluster_neighbours,radius=-1.0)([ccoords,rs])
            #here we use more neighbours to improve learning of the cluster space
            #this can be adjusted in the final trained model to be equal to 'cluster_neighbours'
        cdist = LocalDistanceScaling(max_scale=10.)([cdist, Dense(1,kernel_initializer='zeros')(Concatenate()([x,cdist]))])

        # note: the first output (the reshaped cluster features) is discarded here,
        # since the sixth output re-assigns x (cf. x_cl in the later examples)
        x, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
                 K=cluster_neighbours,
                 radius=0.1,
                 print_reduction=True,
                 loss_enabled=True,
                 loss_scale = 3.,
                 loss_repulsion=0.8, # it's important to get this right here
                 hier_transforms=[64,32,16],
                 print_loss=False,
                 name='clustering_'+str(i)
                 )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(energy)#sums up all contained energy per cluster

        x = Dense(128, activation='elu',name='dense_clc_a'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
        x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)


        n_dimensions = 3+i #3 (plottable) in the first iteration, growing with depth
        nneigh = 64+32*i #this will be almost fully connected for the last clustering step
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords,x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=nfilt,
                                              n_propagate=nprop)([x, rs])


        x_pca = Dense(2+4*i)(x)
        x_pca = NeighbourApproxPCA()([coords,dist,x_pca,nidx])
        x_pca = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x_pca)

        x_mp = DistanceWeightedMessagePassing([64,64,32,32])([x,nidx,dist])
        x_mp = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x_mp)

        x = Concatenate()([x,x_pca,x_mp,x_gn])
        #check and compress it all
        x = Dense(128, activation='elu',name='dense_a_'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
        x = Dense(32+16*i, activation='elu',name='dense_c_'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)


        energysums.append( MultiBackGather()([energy, gatherids]) )#assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        bgccoords = MultiBackGather()([ccoords, gatherids])
        if fwdbgccoords is None:
            fwdbgccoords = bgccoords
        backgathered_coords.append(bgccoords)




    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)

    #this is going to be resource-intensive; give a good starting point with the last ccoords
    coords = Dense(3,use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(Concatenate()([ fwdbgccoords ,x]))

    nidx, dist = KNN(K=32,radius=1.0)([coords,row_splits])
    x_mp = DistanceWeightedMessagePassing(8*[8])([x,nidx,dist]) #few features per hop, but many hops
    x = Concatenate()([x,x_mp])
    #
    backgathered_coords.append(coords)

    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
    x = Dense(128, activation='elu')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
    x = Concatenate()([x]+energysums)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x,feat,add_distance_scale=True)

    #loss
    pred_beta = LLFullObjectCondensation(print_loss=True,
                                         energy_loss_weight=1e-4,
                                         position_loss_weight=1e-3,
                                         timing_loss_weight=1e-3,
                                         beta_loss_scale=1.,
                                         repulsion_scaling=5.,
                                         repulsion_q_min=1.,
                                         super_repulsion=False,
                                         q_min=0.5,
                                         use_average_cc_pos=0.5,
                                         prob_repulsion=True,
                                         phase_transition=1,
                                         huber_energy_scale = 3,
                                         alt_potential_norm=True,
                                         beta_gradient_damping=0.,
                                         payload_beta_gradient_damping_strength=0.,
                                         kalpha_damping_strength=0.,#1.,
                                         use_local_distances=True,
                                         name="FullOCLoss"
                                         )([pred_beta, pred_ccoords,
                                            pred_dist,
                                            pred_energy,
                                            pred_pos, pred_time, pred_id,
                                            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
                                            row_splits])

    model_outputs = [('pred_beta', pred_beta), ('pred_ccoords',pred_ccoords),
       ('pred_energy',pred_energy),
       ('pred_pos',pred_pos),
       ('pred_time',pred_time),
       ('pred_id',pred_id),
       ('pred_dist',pred_dist),
       ('row_splits',rs)]

    for i, (x, y) in enumerate(zip(backgatheredids, backgathered_coords)):
        model_outputs.append(('backgatheredids_'+str(i), x))
        model_outputs.append(('backgathered_coords_'+str(i), y))
    return RobustModel(model_inputs=Inputs, model_outputs=model_outputs)
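
The DistanceWeightedMessagePassing blocks above propagate neighbour features with a distance-dependent weight per hop; each entry of the node list (e.g. [64,64,32,32,16,16]) presumably adds one such hop with a dense transform. The kernel below is only a schematic NumPy stand-in, assuming a Gaussian-style weight exp(-dist) on the squared neighbour distances and a weighted mean:

import numpy as np

def distance_weighted_mp(x, nidx, dist):
    # one schematic hop: x (V, F) features, nidx (V, K) neighbour indices,
    # dist (V, K) squared distances -> (V, F) aggregated features
    w = np.exp(-dist)                  # assumed Gaussian-style kernel
    neigh = x[nidx]                    # (V, K, F) gathered neighbour features
    return (w[..., None] * neigh).sum(1) / np.clip(w.sum(1, keepdims=True), 1e-6, None)

x = np.random.rand(6, 4).astype(np.float32)
nidx = np.array([[1, 2], [0, 2], [0, 1], [4, 5], [3, 5], [3, 4]])
dist = np.random.rand(6, 2).astype(np.float32)
print(distance_weighted_mp(x, nidx, dist).shape)  # -> (6, 4)
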
Example #3
def gravnet_model(
    Inputs,
    td,
    empty_pca=False,
    total_iterations=2,
    variance_only=True,
    viscosity=0.1,
    print_viscosity=False,
    fluidity_decay=5e-4,  # reaches after about 7k batches
    max_viscosity=0.95,
    debug_outdir=None,
    plot_debug_every=1000,
):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)

    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    #just for info what's available
    print([k for k in pre_selection.keys()])

    t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([pre_selection['t_spectator'], pre_selection['t_idx']])
    rs = pre_selection['rs']

    x_in = Concatenate()([
        pre_selection['coords'], pre_selection['features'],
        pre_selection['addfeat']
    ])

    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  #physical coordinates
    c_coords = pre_selection['coords']  #pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    allfeat = []

    n_cluster_space_coordinates = 3

    #extend coordinates already here if needed
    coords = extent_coords_if_needed(coords, x, n_cluster_space_coordinates)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)
        ### reduction done

        n_dims = 3
        #exchange information, create coordinates
        x = Concatenate()([coords, coords, c_coords, x])
        x, gncoords, gnnidx, gndist = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=n_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False  #weird issue with that for now
        )([x, rs])

        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' +
                                   str(i))([gncoords, energy, t_idx, rs])
        #just keep them in a reasonable range
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=0.01,
                                            record_metrics=True)(gndist)
        x = Concatenate()([energy, x])
        x_pca = Dense(4, activation='relu')(x)  #pca is expensive
        x_pca = ApproxPCA(empty=empty_pca)([gncoords, gndist, x_pca, gnnidx])
        x = Concatenate()([x, x_pca])

        x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16,
                                            16])([x, gnnidx, gndist])

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           record_metrics=True,
                           variance_only=variance_only,
                           fluidity_decay=fluidity_decay)(x)

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)

        allfeat.append(x)

    x = Concatenate()([c_coords] + allfeat +
                      [pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(viscosity=viscosity,
                       max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay,
                       record_metrics=True,
                       variance_only=variance_only,
                       name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    pred_beta = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        #print_batch_time=True,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        #div_repulsion=True,
        # cont_beta_loss=True,
        # beta_gradient_damping=0.999,
        # phase_transition=1,
        #huber_energy_scale=0.1,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss")(  # oc output and payload
            [
                pred_beta, pred_ccoords, pred_dist, pred_energy_corr, pred_pos,
                pred_time, pred_id
            ] + [energy] +
            # truth information
            [
                pre_selection['t_idx'], pre_selection['t_energy'],
                pre_selection['t_pos'], pre_selection['t_time'],
                pre_selection['t_pid'], pre_selection['t_spectator_weight'],
                pre_selection['t_fully_contained'],
                pre_selection['t_rec_energy'], pre_selection['t_is_unique'],
                pre_selection['rs']
            ])

    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([
                                       pred_ccoords, pred_beta,
                                       pre_selection['t_idx'], rs
                                   ])

    model_outputs = re_integrate_to_full_hits(pre_selection,
                                              pred_ccoords,
                                              pred_beta,
                                              pred_energy_corr,
                                              pred_pos,
                                              pred_time,
                                              pred_id,
                                              pred_dist,
                                              dict_output=True)

    return DictModel(inputs=Inputs, outputs=model_outputs)
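
In all of these models, LLFullObjectCondensation only attaches the training loss and passes pred_beta through unchanged; at inference time, clusters are built from pred_beta and pred_ccoords (with pred_dist optionally scaling the per-point distance threshold). A rough NumPy sketch of the usual object-condensation clustering loop; the thresholds t_beta and t_d are hypothetical values:

import numpy as np

def oc_cluster(beta, ccoords, t_beta=0.1, t_d=0.5):
    # greedy condensation: repeatedly take the highest-beta unassigned point
    # and assign everything within t_d of it in clustering space
    asso = np.full(len(beta), -1)  # -1 marks noise
    free = beta.copy()
    k = 0
    while free.max() > t_beta:
        alpha = int(free.argmax())
        d = np.linalg.norm(ccoords - ccoords[alpha], axis=-1)
        sel = (d < t_d) & (asso < 0)
        asso[sel] = k
        free[sel] = -1.
        k += 1
    return asso
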
Example #4
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    # feat = Lambda(lambda x: tf.squeeze(x,axis=1)) (feat)

    #tf.print([(t.shape, t.name) for t in [feat,  t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])

    #feat,  t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #    [feat,  t_idx, t_energy, t_pos, t_time, t_pid]
    #    )

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    #trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###
    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    #should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov])
    #x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        #narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.2,  #doesn't really have an effect because of the local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        #use EdgeConv operation to determine cluster properties
        if True or i:  #guard disabled for now; originally ran only from the second iteration on (OOM)
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])

        x = Concatenate()([x_cl, StopGradient()(coords)])

        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)

        x = BatchNormalization(momentum=0.6)(x)
        #notice the last activation, used for feature weighting later

        ### now these are the new cluster features, up for the next iteration of building new latent space
        #x = RaggedGlobalExchange()([x,rs])

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics
        #dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])

        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps
        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16,
                                               8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        #x_r = Concatenate()([x_r, ng_coords, ng_dist])

        #x_r = Concatenate()([StopGradient()(coords),x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  #seems broken
        timing_loss_weight=1e-5,  #1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
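
Each clustering stage appends its selection indices to gatherids; MultiBackGather applies them in reverse to scatter cluster-level quantities (energy sums, features, coordinates) back onto every original hit. A minimal NumPy sketch of that idea; the real layer works on ragged tensors and keeps gradients:

import numpy as np

def multi_back_gather(x, gatherids):
    # x: (V_reduced, F) after all reductions; gatherids: one index array per
    # stage, mapping each pre-stage row to the cluster row it was merged into
    for idx in reversed(gatherids):
        x = x[idx]  # broadcast cluster values back to their members
    return x

# two stages: 6 hits -> 3 clusters -> 2 clusters
stage0 = np.array([0, 0, 1, 1, 2, 2])  # hit -> first-stage cluster
stage1 = np.array([0, 0, 1])           # first-stage cluster -> final cluster
x = np.array([[10.], [20.]])           # one value per final cluster
print(multi_back_gather(x, [stage0, stage1]).ravel())  # -> [10. 10. 10. 10. 20. 20.]
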
Example #5
def pre_selection_staged(
    indict,
    debug_outdir,
    trainable,
    name='pre_selection_add_stage_0',
    debugplots_after=-1,
    reduction_threshold=0.75,
    use_edges=True,
    print_info=False,
    record_metrics=False,
    n_coords=3,
    edge_nodes_0=16,
    edge_nodes_1=8,
):
    '''
    Takes the output of the preselection model and selects again :)
    The outputs are compatible, so this stage can be chained.
    
    This one uses full-blown GravNet.
    
    Gets as inputs:
    
    indict['scatterids']
    indict['orig_t_idx'] 
    indict['orig_t_energy'] 
    indict['orig_dim_coords']
    indict['rs']
    indict['orig_row_splits']
    
    indict['features'] 
    indict['orig_features']
    indict['coords'] 
    indict['addfeat']
    indict['energy']
    
    
    indict['t_idx']
    indict['t_energy']
    ... all the truth info
    
    '''

    from GravNetLayersRagged import RaggedGravNet, DistanceWeightedMessagePassing, ElementScaling
    from GravNetLayersRagged import SelectFromIndices, GooeyBatchNorm, MaskTracksAsNoise
    from GravNetLayersRagged import AccumulateNeighbours, KNN, MultiAttentionGravNetAdd
    from LossLayers import LLClusterCoordinates, LLFillSpace
    from DebugLayers import PlotCoordinates
    from MetricsLayers import MLReductionMetrics
    from Regularizers import MeanMaxDistanceRegularizer, AverageDistanceRegularizer

    #assume the inputs are normalised
    rs = indict['rs']
    t_idx = indict['t_idx']

    track_charge = SelectFeatures(2, 3)(
        indict['unproc_features'])  #zero for calo hits
    x = Concatenate()([indict['features'], indict['addfeat']])
    x = Dense(64, activation='elu', trainable=trainable)(x)
    gn_pre_coords = indict['coords']
    gn_pre_coords = ElementScaling(name=name + 'es1',
                                   trainable=trainable)(gn_pre_coords)
    x = Concatenate()([gn_pre_coords, x])

    x, coords, nidx, dist = RaggedGravNet(n_neighbours=32,
                                          n_dimensions=n_coords,
                                          n_filters=64,
                                          n_propagate=64,
                                          coord_initialiser_noise=1e-5,
                                          feature_activation=None,
                                          record_metrics=record_metrics,
                                          use_approximate_knn=True,
                                          use_dynamic_knn=True,
                                          trainable=trainable,
                                          name=name + '_gn1')([x, rs])

    #the two below mostly record metrics and kill very bad coordinate scalings
    dist = MeanMaxDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    dist = AverageDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_gn1_coords')(
                                     [coords, indict['energy'], t_idx, rs])

    x = DistanceWeightedMessagePassing([32, 32, 8, 8],
                                       name=name + 'dmp1',
                                       trainable=trainable)([x, nidx, dist])

    x_matt = Dense(16, activation='elu', name=name + '_matt_dense')(x)

    x_matt = MultiAttentionGravNetAdd(5,
                                      name=name + '_att_gn1',
                                      record_metrics=record_metrics)(
                                          [x, x_matt, coords, nidx])
    x = Concatenate()([x, x_matt])
    x = Dense(64, activation='elu', name=name + '_bef_coord_dense')(x)

    coords = Add()([
        Dense(n_coords,
              name=name + '_coord_add_dense',
              kernel_initializer='zeros')(x), coords
    ])
    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_red_coords')(
                                     [coords, indict['energy'], t_idx, rs])

    nidx, dist = KNN(
        K=16,
        radius='dynamic',  #use dynamic feature
        record_metrics=record_metrics,
        name=name + '_knn',
        min_bins=[7, 7]  #this can be fine grained
    )([coords, rs])

    coords = LLClusterCoordinates(print_loss=print_info,
                                  record_metrics=record_metrics,
                                  active=trainable,
                                  print_batch_time=False,
                                  scale=5.)([coords, t_idx, rs])

    coords = LLFillSpace(
        active=trainable,
        record_metrics=record_metrics,
        scale=0.025,  #just mild
        runevery=-1,  #give it a kick only every now and then - that's enough
    )([coords, rs])

    unred_rs = rs

    cluster_tidx = MaskTracksAsNoise(active=trainable)([t_idx, track_charge])

    gnidx, gsel, group_backgather, rs = reduce_indices(
        x,
        dist,
        nidx,
        rs,
        cluster_tidx,
        threshold=reduction_threshold,
        print_reduction=print_info,
        trainable=trainable,
        name=name + '_reduce_indices',
        use_edges=use_edges,
        edge_nodes_0=edge_nodes_0,
        edge_nodes_1=edge_nodes_1,
        return_backscatter=False)

    gsel = MLReductionMetrics(name=name + '_reduction', record_metrics=True)(
        [gsel, t_idx, indict['t_energy'], unred_rs, rs])

    selfeat = SelectFromIndices()([gsel, indict['features']])
    unproc_features = SelectFromIndices()([gsel, indict['unproc_features']])

    energy = indict['energy']

    x = AccumulateNeighbours('minmeanmax')([x, gnidx, energy])
    x = SelectFromIndices()([gsel, x])
    #add more useful things
    coords = AccumulateNeighbours('mean')([coords, gnidx, energy])
    coords = SelectFromIndices()([gsel, coords])
    phys_coords = AccumulateNeighbours('mean')(
        [indict['phys_coords'], gnidx, energy])
    phys_coords = SelectFromIndices()([gsel, phys_coords])

    energy = AccumulateNeighbours('sum')([energy, gnidx])
    energy = SelectFromIndices()([gsel, energy])

    out = {}
    out['not_noise_score'] = AccumulateNeighbours('mean')(
        [indict['not_noise_score'], gnidx])
    out['not_noise_score'] = SelectFromIndices()(
        [gsel, out['not_noise_score']])

    out['scatterids'] = indict['scatterids'] + [group_backgather]  #append new selection

    #re-build standard feature layout
    out['features'] = selfeat
    out['unproc_features'] = unproc_features
    out['coords'] = coords
    out['phys_coords'] = phys_coords
    out['addfeat'] = GooeyBatchNorm(name=name + '_gooey_norm',
                                    trainable=trainable)(x)  #norm them
    out['energy'] = energy
    out['rs'] = rs

    for k in indict.keys():
        if 't_' == k[0:2]:
            out[k] = SelectFromIndices()([gsel, indict[k]])

    #some pass throughs:
    out['orig_dim_coords'] = indict['orig_dim_coords']
    out['orig_t_idx'] = indict['orig_t_idx']
    out['orig_t_energy'] = indict['orig_t_energy']
    out['orig_row_splits'] = indict['orig_row_splits']

    #check
    anymissing = False
    for k in indict.keys():
        if k not in out:
            anymissing = True
            print(k, 'missing')
    if anymissing:
        raise ValueError("key not found")

    return out
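
Because the returned dict carries the same keys as indict (enforced by the check above), stages can be chained directly. A hypothetical usage sketch; the stage count, the debug directory, and feeding from pre_selection_model_full are illustrative only:

# hypothetical chaining of two additional reduction stages
sel = pre_selection_model_full(orig_inputs, trainable=True)
for s in range(2):
    sel = pre_selection_staged(sel,
                               debug_outdir='debug_plots',
                               trainable=True,
                               name='pre_selection_add_stage_' + str(s))
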
Example #6
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    #really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       orig_coords)  #just rotation and scaling

    #see what's there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])
    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])(
        [coords, dist, x_c, nidx])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_c, x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  #make it plottable
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                (Dense(n_dimensions,
                       use_bias=False,
                       name='newcoords' + str(i),
                       kernel_initializer='zeros')(x))
            ])
            nidx, cdist = KNN(K=6 * cluster_neighbours,
                              radius=-1.0)([ccoords, rs])
            #here we use more neighbours to improve learning of the cluster space

        #cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        #what if we let the individual distances scale here? so move hits in and out?

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  #.5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8,
                                               8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        #x_sp=x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        #record more and more the deeper we go
        x_r = x
        energysums.append(MultiBackGather()(
            [energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    #x = Dropout(0.3)(x) #force to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    #x = Concatenate()([x]+energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  #1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
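
ReduceSumEntirely above yields one summed energy per cluster after each reshape. Conceptually this is a segment sum over the hit-to-cluster assignment, which plain TensorFlow can illustrate (the real layer works on the reshaped ragged neighbour axis):

import tensorflow as tf

# each hit carries an energy and a cluster id; the reduction keeps one
# summed energy per cluster (segment ids must be sorted for segment_sum)
energy = tf.constant([1., 2., 3., 4., 5.])
cluster_id = tf.constant([0, 0, 1, 1, 1])
print(tf.math.segment_sum(energy, cluster_id).numpy())  # -> [ 3. 12.]
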
Example #7
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5
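    # NB: n_ccoords and the create_* output/truth-dict helpers are assumed to be
    # defined at module level in this snippet.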

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_feat = Inputs[0]
    I_truth = Inputs[2]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_feat, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic
    x = RaggedGlobalExchange(name="global_exchange")([x, x_row_splits])
    x = Dense(64, activation='elu', name="dense_start")(x)

    n_filters = 0
    n_gravnet_layers = 4
    feat = [x_basic, x]
    for i in range(n_gravnet_layers):
        n_filters = 196
        n_propagate = 128
        n_propagate_2 = [64, 32, 16, 8, 4, 2]
        n_neighbours = 32
        n_dim = 8
        #if n_dim < 2:
        #    n_dim = 2

        x, coords, neighbor_indices, neighbor_dist = RaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=n_dim,
            n_filters=n_filters,
            n_propagate=n_propagate,
            name='gravnet_' + str(i))([x, x_row_splits])
        x = DistanceWeightedMessagePassing(n_propagate_2)(
            [x, neighbor_indices, neighbor_dist])

        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu',
                  name="dense_bottom_" + str(i) + "_a")(x)
        x = BatchNormalization(momentum=0.6, name="bn_a_" + str(i))(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_b")(x)
        x = RaggedGlobalExchange(name="global_exchange_bot_" +
                                 str(i))([x, x_row_splits])
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6, name="bn_b_" + str(i))(x)

        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a1")(x)
    x = Dense(128, activation='elu', name="dense_last_a2")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a2")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    x = create_output_layers(x,
                             x_row_splits,
                             n_ccoords=n_ccoords,
                             scale_exp_e=False)

    truth_dict = create_ragged_cal_truth_dict(I_truth)
    pred_dict = create_ragged_cal_pred_dict(x, n_ccoords=n_ccoords)
    feat_dict = create_ragged_cal_feature_dict(I_feat)

    x = LLObjectCondensation()(
        [x, truth_dict, pred_dict, feat_dict, x_row_splits])

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=x)
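
The RaggedGravNet blocks follow the GravNet layer (Qasim, Kieseler, Iiyama, Pierini 2019): learn low-dimensional coordinates, take the k nearest neighbours in that space, and aggregate distance-weighted neighbour features back onto each vertex. A schematic single-event NumPy version, assuming the Gaussian potential exp(-10 d^2) from the paper and mean/max aggregation:

import numpy as np

def gravnet_layer_sketch(x, s_coords, k=4):
    # schematic GravNet aggregation for one event (assumes V > k):
    # x: (V, F) learned features, s_coords: (V, S) learned coordinates
    d2 = ((s_coords[:, None, :] - s_coords[None, :, :]) ** 2).sum(-1)  # (V, V)
    nidx = np.argsort(d2, axis=1)[:, 1:k + 1]  # k nearest neighbours, self excluded
    w = np.exp(-10. * np.take_along_axis(d2, nidx, axis=1))  # Gaussian potential
    neigh = x[nidx] * w[..., None]             # (V, k, F) distance-weighted features
    return np.concatenate([neigh.mean(1), neigh.max(1)], axis=-1)  # (V, 2F)
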
Example #8
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    allfeat = [feat_norm]
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)
    coords = orig_coords
    coords = Dense(3)(coords)  #just rotation and scaling
    #add gradient

    #see what's there
    nidx, dist = KNN(K=96, radius=1.0)([coords, rs])
    x = Dense(4, activation='elu',
              name='pre_pre_dense')(x)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = SoftPixelCNN(mode='full', subdivisions=4,
                       name='prePCNN')([coords, x, dist, nidx])
    x = Concatenate()([x, x_c])
    #this is going to be among the most expensive operations:
    x = Dense(128, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    #allfeat.append(Dense(16, activation='elu',name='feat_compress_pre')(x))
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    total_iterations = 5

    for i in range(total_iterations):

        #cluster first
        hier = Dense(1, activation='sigmoid')(x)
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx = LocalClusterReshapeFromNeighbours(
            K=8,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.3,
            print_loss=True,
            name='clustering_' + str(i))(
                [x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, t_idx])

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster

        gatherids.append(bidxs)
        x_cl = Dense(128, activation='elu', name='dense_clc_' + str(i))(x_cl)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x, x_cl, n_energy])

        pixelcompress = 4
        nneigh = 32 + 4 * i
        nfilt = 32 + 4 * i
        nprop = 32
        n_dimensions = 3  #make it plottable

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        subdivisions = 4

        #add more shape information
        x_sp = Dense(pixelcompress, activation='elu')(x)
        x_sp = BatchNormalization(momentum=0.6)(x_sp)
        x_sp = SoftPixelCNN(mode='full',
                            subdivisions=4)([coords, x_sp, dist, nidx])
        x_sp = Dense(128, activation='elu', name='dense_spc_' + str(i))(x_sp)

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(nfilt, activation='elu', name='dense_mpc_' + str(i))(x_mp)  #was fed x_sp, which discarded the message-passing output

        x = Concatenate()([x, x_sp, x_gn, x_mp])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x_r = x
        #record more and more the deeper we go
        if i < total_iterations - 1:
            x_r = Dense(12 * (i + 1),
                        activation='elu',
                        name='dense_rec_' + str(i))(x)
        else:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.3)(x) #force to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([x, energy])

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
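
Cluster-level tensors in this model are expanded back to the original vertex multiplicity with MultiBackGather over the recorded gatherids. A minimal sketch of what such a back-gather chain plausibly does, assuming each recorded index tensor maps every vertex of the finer level to its cluster row at the coarser level:

import tensorflow as tf

def multi_back_gather_sketch(x, gatherids):
    #walk the recorded index lists from the latest clustering back to the
    #first: each gather expands the cluster-level rows of x to the vertex
    #multiplicity of the previous, finer level
    for idx in reversed(gatherids):
        x = tf.gather(x, tf.reshape(idx, [-1]), axis=0)
    return x
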
Example #9
0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)
    energy = SelectFeatures(0, 1)(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits
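    #row_splits delimit the individual events inside the flat vertex tensor
    #(see the ragged-tensor sketch after this model)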

    #n_cluster_coords=6

    x = feat  #Dense(64,activation='elu')(feat)

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_en = []

    allfeat = [x]

    sel_gidx = gidx_orig
    for i in range(6):

        pixelcompress = 16 + 8 * i
        nneigh = 8 + 32 * i
        n_dimensions = 4 + i

        x = BatchNormalization(momentum=0.6)(x)
        #do some processing
        x = Dense(64, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=64,
                                              n_propagate=64)([x, rs])
        x = Concatenate()([x, MessagePassing([64, 32, 16, 8])([x, nidx])])

        #allfeat.append(x)

        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        x_sp = SoftPixelCNN(length_scale=1.)([coords, x, nidx])
        x = Concatenate()([x, x_sp])
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        hier = Dense(1, activation='sigmoid',
                     name='hier_' + str(i))(x)  #clustering hierarchy

        #scale distances down as the iterations proceed: with the fixed
        #clustering radius below, deeper iterations merge vertices more readily
        dist = ScalarMultiply(1. / (3. * (i + 0.5)))(dist)

        x, rs, bidxs, t_idx = LocalClusterReshapeFromNeighbours(
            K=5,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.5,
            print_loss=True)([x, dist, hier, nidx, rs, t_idx, t_idx])

        print('>>>>x0', x.shape)
        gatherids.append(bidxs)
        if addBackGatherInfo:
            backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))

        print('>>>>x1', x.shape)
        x = BatchNormalization(momentum=0.6)(x)
        x_record = Dense(2 * pixelcompress, activation='elu')(x)
        x_record = BatchNormalization(momentum=0.6)(x_record)
        allfeat.append(MultiBackGather()([x_record, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-3,
        position_loss_weight=1e-3,
        timing_loss_weight=1e-3,
    )([
        pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
        orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
        row_splits
    ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids)
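
The two models above thread a flat vertex tensor together with row_splits through every layer; the row splits are what delimit the individual events. A self-contained illustration of that representation in plain TensorFlow (independent of RaggedConstructTensor):

import tensorflow as tf

feat = tf.random.normal([7, 3])      #7 vertices with 3 features, from two events
row_splits = tf.constant([0, 4, 7])  #event 0 owns vertices 0..3, event 1 owns 4..6
ragged = tf.RaggedTensor.from_row_splits(feat, row_splits)
print(ragged.shape)                  #(2, None, 3)
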
Example #10
0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.identity())(coords)
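    #a small learned residual on top of the input coordinates, finished by an
    #identity-initialised linear layer, so training starts near the physical coordinates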

    first_coords = coords

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])
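    #one distance-weighted message-passing step over the 48 nearest
    #neighbours in the hand-built coordinate space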

    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()(
        [coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
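    #presumably a per-vertex summary of the spatial spread of energy and time
    #over the neighbourhood; ReluPlusEps keeps the weights positive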
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  #has little effect because of the local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
        xec = EdgeConvStatic([32, 32, 32],
                             name="ec_static_" + str(i))(x_cl_rs)
        x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
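        #(in plain TF this is roughly tf.math.segment_sum(energy, cluster_ids))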
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        #notice last relu for feature weighting later

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])
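        #return_self=True presumably keeps each vertex in its own neighbour list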

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu',
                    name='dense_out_c_' + str(i))(x)
        #coords_nograd = StopGradient()(coords)
        #x_r = Concatenate()([coords_nograd,x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.2)(x)
    x_mp = DistanceWeightedMessagePassing([32, 32,
                                           32])([x, first_nidx, first_dist])
    x = Concatenate()([x, x_mp])

    x = Dense(128, activation='elu', name='alldense')(x)
    # TO BE ADDED WITH E LOSS x = Concatenate()([x,energy])
    #x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    # double-scale phase transition with linear beta + q_min
    #  -> more high-beta points, but the payload loss will still scale with one
    #     of them (or two, but then it doesn't matter)

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  #seems broken
        timing_loss_weight=0.,  #1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[
                                    pred_beta, pred_ccoords, pred_energy,
                                    pred_pos, pred_time, pred_id, rs
                                ] + backgatheredids + backgathered_coords)
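
The q_min argument above sets the minimum charge in the object condensation loss. For reference, the per-vertex charge in the object condensation formulation is q_i = arctanh(beta_i)^2 + q_min; a minimal sketch of that quantity (the clipping threshold is an assumption for numerical safety):

import tensorflow as tf

def condensation_charge(pred_beta, q_min=1.5):
    #q_i = arctanh(beta_i)^2 + q_min; clip beta so arctanh stays finite
    beta = tf.clip_by_value(pred_beta, 0., 1. - 1e-4)
    return tf.math.atanh(beta) ** 2 + q_min
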