Example 1
def gravnet_model(Inputs,
                  td,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################
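    # Note: total_iterations, dense_activation, batchnorm_options, double_mp and
    # loss_options are not parameters of this function; they are assumed to be
    # defined at module scope in the original training script.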

    orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True)
    
    
    orig_t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([orig_inputs['t_spectator'], 
                                                        orig_inputs['t_idx']])
                                                     
    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight                                                 
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs,trainable=False)
    
    #just for info what's available
    print('available pre-selection outputs',[k for k in pre_selection.keys()])
                                          
    
    t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([pre_selection['t_spectator'], 
                                                        pre_selection['t_idx']])
    rs = pre_selection['rs']
                               
    x_in = Concatenate()([pre_selection['coords'],
                          pre_selection['features'],
                          pre_selection['addfeat']])
                           
    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']#physical coordinates
    c_coords = pre_selection['coords']#pre-clustered coordinates
    t_idx = pre_selection['t_idx']
    
    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################
    
    allfeat = []
    
    n_cluster_space_coordinates = 3
    
    
    #extend coordinates already here if needed
    c_coords = extent_coords_if_needed(c_coords, x, n_cluster_space_coordinates)
        

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])
        
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = GooeyBatchNorm(**batchnorm_options)(x)
        ### reduction done
        
        n_dims = 6
        #exchange information, create coordinates
        x = Concatenate()([c_coords,c_coords,c_coords,coords,x])
        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=64,
                                                 n_dimensions=n_dims,
                                                 n_filters=64,
                                                 n_propagate=64,
                                                 record_metrics=True,
                                                 coord_initialiser_noise=1e-2,
                                                 use_approximate_knn=False #weird issue with that for now
                                                 )([x, rs])
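        # RaggedGravNet returns the aggregated features (xgn), the learned clustering
        # coordinates (gncoords), the neighbour indices (gnnidx) and the neighbour
        # distances (gndist); the (gnnidx, gndist) graph is reused for message passing below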
        
        x = Concatenate()([x,xgn])                                                      
        #just keep them in a reasonable range  
        # safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=1e-4,
                                            record_metrics=True
                                            )(gndist)
                                            
        gncoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                                   name='gn_coords_'+str(i))([gncoords, 
                                                                    energy,
                                                                    t_idx,
                                                                    rs]) 
        x = Concatenate()([gncoords,x])           
        
        pre_gndist=gndist
        if double_mp:
            for im,m in enumerate([64,64,32,32,16,16]):
                dscale=Dense(1)(x)
                gndist = LocalDistanceScaling(4.)([pre_gndist,dscale])                                  
                gndist = AverageDistanceRegularizer(strength=1e-6,
                                            record_metrics=True,
                                            name='average_distance_dmp_'+str(i)+'_'+str(im)
                                            )(gndist)
                                            
                x = DistanceWeightedMessagePassing([m],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
        else:        
            x = DistanceWeightedMessagePassing([64,64,32,32,16,16],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
            
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        x = Dense(64,name='dense_past_mp_'+str(i),activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        
        x = GooeyBatchNorm(**batchnorm_options)(x)
        
        
        allfeat.append(x)
        
        
    
    x = Concatenate()([c_coords]+allfeat+[pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    
    
    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################
    
    x = GooeyBatchNorm(**batchnorm_options,name='gooey_pre_out')(x)
    x = Concatenate()([c_coords]+[x])
    
    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'], 
                                                  n_ccoords=n_cluster_space_coordinates)
    
    # loss
    pred_beta = LLFullObjectCondensation(scale=4.,
                                         position_loss_weight=1e-5,
                                         timing_loss_weight=1e-5,
                                         beta_loss_scale=1.,
                                         use_energy_weights=True,
                                         record_metrics=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )(  # oc output and payload
        [pred_beta, pred_ccoords, pred_dist,
         pred_energy_corr, pred_pos, pred_time, pred_id] +
        [energy]+
        # truth information
        [pre_selection['t_idx'] ,
         pre_selection['t_energy'] ,
         pre_selection['t_pos'] ,
         pre_selection['t_time'] ,
         pre_selection['t_pid'] ,
         pre_selection['t_spectator_weight'],
         pre_selection['t_fully_contained'],
         pre_selection['t_rec_energy'],
         pre_selection['t_is_unique'],
         pre_selection['rs']])
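    # LLFullObjectCondensation is a loss layer: it registers the object-condensation
    # loss via add_loss and passes pred_beta through unchanged, so the model can be
    # compiled without an explicit loss function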
                                         
    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                    name='condensation')([pred_ccoords, pred_beta,pre_selection['t_idx'],
                                          rs])                                    

    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True
        )
    
    return DictModel(inputs=Inputs, outputs=model_outputs)
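
A minimal usage sketch for the example above, assuming the HGCalML-style layers and the TrainData object `td` are importable; `build_inputs` is a hypothetical stand-in for however the surrounding training script creates the Keras input tensors, and is not part of the original source:

import tensorflow as tf

def build_and_compile(td, build_inputs):
    # build_inputs: hypothetical callable returning the list of tf.keras.Input
    # tensors that td.interpretAllModelInputs expects
    Inputs = build_inputs(td)
    model = gravnet_model(Inputs, td, debug_outdir='debug_plots', plot_debug_every=1000)
    # the object-condensation loss is attached inside LLFullObjectCondensation via
    # add_loss, so compile() only needs an optimizer
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))
    return model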
Example 2
def gravnet_model(Inputs,
                  viscosity=0.2,
                  print_viscosity=False,
                  fluidity_decay=1e-3,
                  max_viscosity=0.9 # to start with
                  ):

    feature_dropout=-1.
    addBackGatherInfo=True

    feat,  t_idx, t_energy, t_pos, t_time, t_pid, t_spectator, t_fully_contained, row_splits = td.interpretAllModelInputs(Inputs)
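    # `td` (used above) is not a parameter of this function; it is assumed to be
    # defined at module scope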
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    energy = SelectFeatures(0,1)(feat)
    time = SelectFeatures(8,9)(feat_norm)
    orig_coords = SelectFeatures(5,8)(feat)

    x = feat_norm

    allfeat=[x]
    backgatheredids=[]
    gatherids=[]
    backgathered = []
    backgathered_coords = []
    energysums = []

    n_reshape_dimensions=3

    #really simple real coordinates
    coords = ManualCoordTransform()(orig_coords)
    coords = Concatenate()([coords, time])
    coords = Dense(n_reshape_dimensions, use_bias=False, kernel_initializer=tf.keras.initializers.Identity()  )(coords)#just rotation and scaling


    # see what's there
    nidx, dist = KNN(K=24,radius=1.0)([coords,rs])

    x_mp = DistanceWeightedMessagePassing([32,16,16,8])([x,nidx,dist])
    x = Concatenate()([x,x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations=5

    fwdbgccoords=None

    for i in range(total_iterations):

        cluster_neighbours = 5
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,Dense(n_reshape_dimensions,name='newcoords'+str(i),
                                         kernel_initializer='zeros'
                                         )(x)
                ])
            nidx, cdist = KNN(K=5*cluster_neighbours,radius=-1.0)([ccoords,rs])
            #here we use more neighbours to improve learning of the cluster space
            #this can be adjusted in the final trained model to be equal to 'cluster_neighbours'
        cdist = LocalDistanceScaling(max_scale=10.)([cdist, Dense(1,kernel_initializer='zeros')(Concatenate()([x,cdist]))])

        x, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
                 K=cluster_neighbours,
                 radius=0.1,
                 print_reduction=True,
                 loss_enabled=True,
                 loss_scale = 3.,
                 loss_repulsion=0.8, # it's important to get this right here
                 hier_transforms=[64,32,16],
                 print_loss=False,
                 name='clustering_'+str(i)
                 )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)
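
        # each clustering step coarsens the graph; bidxs maps clusters back to the
        # vertices they absorbed and is stacked in gatherids so MultiBackGather can
        # later scatter per-cluster quantities back to all original hits.
        # Note: `x` appears twice on the left of the unpacking above, so the reshaped
        # cluster features (first output slot) are overwritten by the passed-through
        # features (sixth slot).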

        #explicit
        energy = ReduceSumEntirely()(energy)#sums up all contained energy per cluster

        x = Dense(128, activation='elu',name='dense_clc_a'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
        x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)


        n_dimensions = 3+i  # grows with depth; only the first iteration's 3D space is directly plottable
        nneigh = 64+32*i #this will be almost fully connected for last clustering step
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords,x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=nfilt,
                                              n_propagate=nprop)([x, rs])


        x_pca = Dense(2+4*i)(x)
        x_pca = NeighbourApproxPCA()([coords,dist,x_pca,nidx])
        x_pca = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x_pca)

        x_mp = DistanceWeightedMessagePassing([64,64,32,32])([x,nidx,dist])
        x_mp = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x_mp)

        x = Concatenate()([x,x_pca,x_mp,x_gn])
        #check and compress it all
        x = Dense(128, activation='elu',name='dense_a_'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
        x = Dense(32+16*i, activation='elu',name='dense_c_'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)


        energysums.append( MultiBackGather()([energy, gatherids]) )#assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        bgccoords = MultiBackGather()([ccoords, gatherids])
        if fwdbgccoords is None:
            fwdbgccoords = bgccoords
        backgathered_coords.append(bgccoords)




    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)

    #this is going to be resource intense, give a good starting point with the last ccoords
    coords = Dense(3,use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(Concatenate()([ fwdbgccoords ,x]))

    nidx, dist = KNN(K=32,radius=1.0)([coords,row_splits])
    x_mp = DistanceWeightedMessagePassing(8*[8])([x,nidx,dist])#only some but a lot of hops
    x = Concatenate()([x,x_mp])
    #
    backgathered_coords.append(coords)

    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
    x = Dense(128, activation='elu')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
    x = Concatenate()([x]+energysums)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x,feat,add_distance_scale=True)

    #loss
    pred_beta = LLFullObjectCondensation(print_loss=True,
                                         energy_loss_weight=1e-4,
                                         position_loss_weight=1e-3,
                                         timing_loss_weight=1e-3,
                                         beta_loss_scale=1.,
                                         repulsion_scaling=5.,
                                         repulsion_q_min=1.,
                                         super_repulsion=False,
                                         q_min=0.5,
                                         use_average_cc_pos=0.5,
                                         prob_repulsion=True,
                                         phase_transition=1,
                                         huber_energy_scale = 3,
                                         alt_potential_norm=True,
                                         beta_gradient_damping=0.,
                                         payload_beta_gradient_damping_strength=0.,
                                         kalpha_damping_strength=0.,#1.,
                                         use_local_distances=True,
                                         name="FullOCLoss"
                                         )([pred_beta, pred_ccoords,
                                            pred_dist,
                                            pred_energy,
                                            pred_pos, pred_time, pred_id,
                                            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
                                            row_splits])

    model_outputs = [('pred_beta', pred_beta), ('pred_ccoords',pred_ccoords),
       ('pred_energy',pred_energy),
       ('pred_pos',pred_pos),
       ('pred_time',pred_time),
       ('pred_id',pred_id),
       ('pred_dist',pred_dist),
       ('row_splits',rs)]

    for i, (x, y) in enumerate(zip(backgatheredids, backgathered_coords)):
        model_outputs.append(('backgatheredids_'+str(i), x))
        model_outputs.append(('backgathered_coords_'+str(i), y))
    return RobustModel(model_inputs=Inputs, model_outputs=model_outputs)
Example 3
def gravnet_model(
    Inputs,
    td,
    empty_pca=False,
    total_iterations=2,
    variance_only=True,
    viscosity=0.1,
    print_viscosity=False,
    fluidity_decay=5e-4,  # reaches max_viscosity after about 7k batches
    max_viscosity=0.95,
    debug_outdir=None,
    plot_debug_every=1000,
):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)

    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    #just for info what's available
    print([k for k in pre_selection.keys()])

    t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([pre_selection['t_spectator'], pre_selection['t_idx']])
    rs = pre_selection['rs']

    x_in = Concatenate()([
        pre_selection['coords'], pre_selection['features'],
        pre_selection['addfeat']
    ])

    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  #physical coordinates
    c_coords = pre_selection['coords']  #pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    allfeat = []

    n_cluster_space_coordinates = 3

    #extend coordinates already here if needed
    coords = extent_coords_if_needed(coords, x, n_cluster_space_coordinates)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)
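        # GooeyBatchNorm: a slowly "stiffening" normalisation; viscosity grows towards
        # max_viscosity at a rate set by fluidity_decay, so the running statistics
        # freeze gradually during training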
        ### reduction done

        n_dims = 3
        #exchange information, create coordinates
        x = Concatenate()([coords, coords, c_coords, x])
        x, gncoords, gnnidx, gndist = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=n_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False  #weird issue with that for now
        )([x, rs])

        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' +
                                   str(i))([gncoords, energy, t_idx, rs])
        #just keep them in a reasonable range
        # safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=0.01,
                                            record_metrics=True)(gndist)
        x = Concatenate()([energy, x])
        x_pca = Dense(4, activation='relu')(x)  #pca is expensive
        x_pca = ApproxPCA(empty=empty_pca)([gncoords, gndist, x_pca, gnnidx])
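        # ApproxPCA builds an approximate PCA-like summary of each vertex's
        # neighbourhood in the learned coordinate space; empty_pca is forwarded
        # as the layer's `empty` flag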
        x = Concatenate()([x, x_pca])

        x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16,
                                            16])([x, gnnidx, gndist])

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           record_metrics=True,
                           variance_only=variance_only,
                           fluidity_decay=fluidity_decay)(x)

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)

        allfeat.append(x)

    x = Concatenate()([c_coords] + allfeat +
                      [pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(viscosity=viscosity,
                       max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay,
                       record_metrics=True,
                       variance_only=variance_only,
                       name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    pred_beta = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        #print_batch_time=True,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        #div_repulsion=True,
        # cont_beta_loss=True,
        # beta_gradient_damping=0.999,
        # phase_transition=1,
        #huber_energy_scale=0.1,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss")(  # oc output and payload
            [
                pred_beta, pred_ccoords, pred_dist, pred_energy_corr, pred_pos,
                pred_time, pred_id
            ] + [energy] +
            # truth information
            [
                pre_selection['t_idx'], pre_selection['t_energy'],
                pre_selection['t_pos'], pre_selection['t_time'],
                pre_selection['t_pid'], pre_selection['t_spectator_weight'],
                pre_selection['t_fully_contained'],
                pre_selection['t_rec_energy'], pre_selection['t_is_unique'],
                pre_selection['rs']
            ])

    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([
                                       pred_ccoords, pred_beta,
                                       pre_selection['t_idx'], rs
                                   ])

    model_outputs = re_integrate_to_full_hits(pre_selection,
                                              pred_ccoords,
                                              pred_beta,
                                              pred_energy_corr,
                                              pred_pos,
                                              pred_time,
                                              pred_id,
                                              pred_dist,
                                              dict_output=True)

    return DictModel(inputs=Inputs, outputs=model_outputs)
Example 4
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    # feat = Lambda(lambda x: tf.squeeze(x,axis=1)) (feat)

    #tf.print([(t.shape, t.name) for t in [feat,  t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])

    #feat,  t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #    [feat,  t_idx, t_energy, t_pos, t_time, t_pid]
    #    )

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    #trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###
    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    #should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov])
    #x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        #narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.2,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        #use EdgeConv operation to determine cluster properties
        if True or i:  # always true: the i-gate (skip the first iteration to avoid OOM) is disabled here
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])
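
        # EdgeConvStatic applies a shared MLP across each vertex's K neighbour
        # features (EdgeConv-style) and aggregates them per vertex; add_mean=True
        # presumably appends a mean-aggregated component as well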

        x = Concatenate()([x_cl, StopGradient()(coords)])

        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)

        x = BatchNormalization(momentum=0.6)(x)
        #notice last relu for feature weighting later

        ### now these are the new cluster features, up for the next iteration of building new latent space
        #x = RaggedGlobalExchange()([x,rs])

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics
        #dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])

        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps
        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16,
                                               8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        #x_r = Concatenate()([x_r, ng_coords, ng_dist])

        #x_r = Concatenate()([StopGradient()(coords),x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  #seems broken
        timing_loss_weight=1e-5,  #1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
Example 5
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    #really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       orig_coords)  #just rotation and scaling

    # see what's there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])
    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])(
        [coords, dist, x_c, nidx])
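    # NeighbourApproxPCA approximates a per-neighbourhood PCA; the last hidden width
    # (9 = 3**2 here) matches the flattened n_dimensions x n_dimensions transform,
    # consistent with the hidden_nodes=[32, 32, n_dimensions**2] usage further below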
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_c, x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  #make it plottable
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                (Dense(n_dimensions,
                       use_bias=False,
                       name='newcoords' + str(i),
                       kernel_initializer='zeros')(x))
            ])
            nidx, cdist = KNN(K=6 * cluster_neighbours,
                              radius=-1.0)([ccoords, rs])
            #here we use more neighbours to improve learning of the cluster space

        #cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        #what if we let the individual distances scale here? so move hits in and out?

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  #.5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8,
                                               8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        #x_sp=x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        #record more and more the deeper we go
        x_r = x
        energysums.append(MultiBackGather()(
            [energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    #x = Dropout(0.3)(x)  #force to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    #x = Concatenate()([x]+energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  #1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
Example 6
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    allfeat = [feat_norm]
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)
    coords = orig_coords
    coords = Dense(3)(coords)  #just rotation and scaling
    #add gradient

    # see what's there
    nidx, dist = KNN(K=96, radius=1.0)([coords, rs])
    x = Dense(4, activation='elu',
              name='pre_pre_dense')(x)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = SoftPixelCNN(mode='full', subdivisions=4,
                       name='prePCNN')([coords, x, dist, nidx])
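    # SoftPixelCNN (mode='full'): roughly, soft-bins each neighbourhood onto a small
    # pixel grid (subdivisions per axis) around the vertex, giving a CNN-like local
    # shape descriptor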
    x = Concatenate()([x, x_c])
    #this is going to be among the most expensive operations:
    x = Dense(128, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    #allfeat.append(Dense(16, activation='elu',name='feat_compress_pre')(x))
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    total_iterations = 5

    for i in range(total_iterations):

        #cluster first
        hier = Dense(1, activation='sigmoid')(x)
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx = LocalClusterReshapeFromNeighbours(
            K=8,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.3,
            print_loss=True,
            name='clustering_' + str(i))(
                [x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, t_idx])

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster

        gatherids.append(bidxs)
        x_cl = Dense(128, activation='elu', name='dense_clc_' + str(i))(x_cl)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x, x_cl, n_energy])

        pixelcompress = 4
        nneigh = 32 + 4 * i
        nfilt = 32 + 4 * i
        nprop = 32
        n_dimensions = 3  #make it plottable

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        subdivisions = 4

        #add more shape information
        x_sp = Dense(pixelcompress, activation='elu')(x)
        x_sp = BatchNormalization(momentum=0.6)(x_sp)
        x_sp = SoftPixelCNN(mode='full',
                            subdivisions=4)([coords, x_sp, dist, nidx])
        x_sp = Dense(128, activation='elu', name='dense_spc_' + str(i))(x_sp)

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(nfilt, activation='elu', name='dense_mpc_' + str(i))(x_mp)

        x = Concatenate()([x, x_sp, x_gn, x_mp])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x_r = x
        #record more and more the deeper we go
        if i < total_iterations - 1:
            x_r = Dense(12 * (i + 1),
                        activation='elu',
                        name='dense_rec_' + str(i))(x)
        else:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.3)(x)  #force to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([x, energy])

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
Example 7
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)
    energy = SelectFeatures(0, 1)(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    #n_cluster_coords=6

    x = feat  #Dense(64,activation='elu')(feat)

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_en = []

    allfeat = [x]

    sel_gidx = gidx_orig
    for i in range(6):

        pixelcompress = 16 + 8 * i
        nneigh = 8 + 32 * i
        n_dimensions = int(4 + i)

        x = BatchNormalization(momentum=0.6)(x)
        #do some processing
        x = Dense(64, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=64,
                                              n_propagate=64)([x, rs])
        x = Concatenate()([x, MessagePassing([64, 32, 16, 8])([x, nidx])])

        #allfeat.append(x)

        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        x_sp = SoftPixelCNN(length_scale=1.)([coords, x, nidx])
        x = Concatenate()([x, x_sp])
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        hier = Dense(1, activation='sigmoid',
                     name='hier_' + str(i))(x)  #clustering hierarchy

        dist = ScalarMultiply(1 / (3. * float(i + 0.5)))(dist)

        x, rs, bidxs, t_idx = LocalClusterReshapeFromNeighbours(
            K=5,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.5,
            print_loss=True)([x, dist, hier, nidx, rs, t_idx, t_idx])
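        # this variant only routes x and t_idx through the reshape, so it returns just
        # the clustered features, new row splits, backgather indices and truth indices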

        print('>>>>x0', x.shape)
        gatherids.append(bidxs)
        if addBackGatherInfo:
            backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))

        print('>>>>x1', x.shape)
        x = BatchNormalization(momentum=0.6)(x)
        x_record = Dense(2 * pixelcompress, activation='elu')(x)
        x_record = BatchNormalization(momentum=0.6)(x_record)
        allfeat.append(MultiBackGather()([x_record, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-3,
        position_loss_weight=1e-3,
        timing_loss_weight=1e-3,
    )([
        pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
        orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
        row_splits
    ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids)
Example 8
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.identity())(coords)
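    # the block above learns a small residual correction (scaled by 0.1) on top of the
    # raw coordinates, then applies a bias-free, identity-initialised Dense as a
    # learnable rotation/scaling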

    first_coords = coords

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])

    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()(
        [coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  # doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        if i or True:  # always true; an i-gate (skip the first iteration) left disabled
            x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
            xec = EdgeConvStatic([32, 32, 32],
                                 name="ec_static_" + str(i))(x_cl_rs)
            x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        #notice last relu for feature weighting later

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu',
                    name='dense_out_c_' + str(i))(x)
        #coords_nograd = StopGradient()(coords)
        #x_r = Concatenate()([coords_nograd,x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.2)(x)
    x_mp = DistanceWeightedMessagePassing([32, 32,
                                           32])([x, first_nidx, first_dist])
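    # reuse the very first KNN graph (first_nidx, first_dist) so this final exchange
    # runs over the original, un-coarsened geometry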
    x = Concatenate()([x, x_mp])

    x = Dense(128, activation='elu', name='alldense')(x)
    # TO BE ADDED WITH E LOSS x = Concatenate()([x,energy])
    #x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  #seems broken
        timing_loss_weight=0.,  #1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[
                                    pred_beta, pred_ccoords, pred_energy,
                                    pred_pos, pred_time, pred_id, rs
                                ] + backgatheredids + backgathered_coords)