示例#1
0
def gravnet_model(Inputs,
                  td,
                  debug_outdir=None,
                  plot_debug_every=1000,
                  ):
    """Assemble the GravNet object-condensation model behind a pre-selection stage.

    The inputs are interpreted by ``td``, run through
    ``pre_selection_model_full`` (called with ``trainable=False``), processed
    by ``total_iterations`` GravNet / message-passing blocks and finally fed
    to ``create_outputs`` and the ``LLFullObjectCondensation`` loss layer.

    Args:
        Inputs: Model inputs; passed to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned ``DictModel``.
        td: Object providing ``interpretAllModelInputs(Inputs, returndict=True)``.
        debug_outdir: Forwarded as ``outdir`` to every ``PlotCoordinates`` layer.
        plot_debug_every: First positional argument of every
            ``PlotCoordinates`` layer (presumably the plotting interval --
            TODO confirm against the layer).

    Returns:
        ``DictModel(inputs=Inputs, outputs=...)`` where the outputs are the
        dict returned by ``re_integrate_to_full_hits``.

    Note:
        ``total_iterations``, ``dense_activation``, ``batchnorm_options``,
        ``double_mp`` and ``loss_options`` are read here but are not defined
        in this function; they have to exist in an enclosing scope.
    """
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True)


    # Spectator weights for the full (not yet pre-selected) hits; stored in
    # the input dict so that the pre-selection model receives them as well.
    orig_t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([orig_inputs['t_spectator'],
                                                        orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs,trainable=False)

    #just for info what's available
    print('available pre-selection outputs',[k for k in pre_selection.keys()])


    # NOTE(review): t_spectator_weight is not referenced again in this
    # function; the loss below reads pre_selection['t_spectator_weight'].
    t_spectator_weight = CreateTruthSpectatorWeights(threshold=3.,
                                                     minimum=1e-1,
                                                     active=True
                                                     )([pre_selection['t_spectator'],
                                                        pre_selection['t_idx']])
    rs = pre_selection['rs']

    # Starting feature vector: pre-clustered coordinates, features and the
    # additional features delivered by the pre-selection.
    x_in = Concatenate()([pre_selection['coords'],
                          pre_selection['features'],
                          pre_selection['addfeat']])

    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']#physical coordinates
    c_coords = pre_selection['coords']#pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    # Output of every iteration; all entries are concatenated after the loop.
    allfeat = []

    n_cluster_space_coordinates = 3


    #extend coordinates already here if needed
    c_coords = extent_coords_if_needed(c_coords, x, n_cluster_space_coordinates)


    # total_iterations comes from the enclosing scope (see docstring note).
    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = GooeyBatchNorm(**batchnorm_options)(x)
        ### reduction done

        n_dims = 6
        #exchange information, create coordinates
        # c_coords is deliberately listed three times in this concatenation,
        # followed by the physical coordinates and the current features.
        x = Concatenate()([c_coords,c_coords,c_coords,coords,x])
        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=64,
                                                 n_dimensions=n_dims,
                                                 n_filters=64,
                                                 n_propagate=64,
                                                 record_metrics=True,
                                                 coord_initialiser_noise=1e-2,
                                                 use_approximate_knn=False #weird issue with that for now
                                                 )([x, rs])

        # Keep the GravNet input next to its output.
        x = Concatenate()([x,xgn])
        #just keep them in a reasonable range
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=1e-4,
                                            record_metrics=True
                                            )(gndist)

        # The plotting layer's output replaces gncoords, so it stays part of
        # the graph through the concatenation below.
        gncoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                                   name='gn_coords_'+str(i))([gncoords,
                                                                    energy,
                                                                    t_idx,
                                                                    rs])
        x = Concatenate()([gncoords,x])

        # Every message-passing step below rescales from these same
        # regularised distances rather than from the previous step's result.
        pre_gndist=gndist
        if double_mp:
            # One single-layer message-passing step per entry, each with its
            # own learned distance scale and its own named regulariser.
            for im,m in enumerate([64,64,32,32,16,16]):
                dscale=Dense(1)(x)
                gndist = LocalDistanceScaling(4.)([pre_gndist,dscale])
                gndist = AverageDistanceRegularizer(strength=1e-6,
                                            record_metrics=True,
                                            name='average_distance_dmp_'+str(i)+'_'+str(im)
                                            )(gndist)

                x = DistanceWeightedMessagePassing([m],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])
        else:
            # Same layer sizes, but in a single layer on unscaled distances.
            x = DistanceWeightedMessagePassing([64,64,32,32,16,16],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])

        x = GooeyBatchNorm(**batchnorm_options)(x)

        x = Dense(64,name='dense_past_mp_'+str(i),activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)

        x = GooeyBatchNorm(**batchnorm_options)(x)


        allfeat.append(x)



    # Combine the (possibly extended) pre-clustered coordinates, the output
    # of every iteration and the pre-selection noise score.
    x = Concatenate()([c_coords]+allfeat+[pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)
    x = Dense(64,activation=dense_activation)(x)


    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(**batchnorm_options,name='gooey_pre_out')(x)
    x = Concatenate()([c_coords]+[x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    # The loss layer's return value replaces pred_beta, so everything
    # downstream (plotting, re-integration) depends on the loss layer.
    pred_beta = LLFullObjectCondensation(scale=4.,
                                         position_loss_weight=1e-5,
                                         timing_loss_weight=1e-5,
                                         beta_loss_scale=1.,
                                         use_energy_weights=True,
                                         record_metrics=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )(  # oc output and payload
        [pred_beta, pred_ccoords, pred_dist,
         pred_energy_corr, pred_pos, pred_time, pred_id] +
        [energy]+
        # truth information
        [pre_selection['t_idx'] ,
         pre_selection['t_energy'] ,
         pre_selection['t_pos'] ,
         pre_selection['t_time'] ,
         pre_selection['t_pid'] ,
         pre_selection['t_spectator_weight'],
         pre_selection['t_fully_contained'],
         pre_selection['t_rec_energy'],
         pre_selection['t_is_unique'],
         pre_selection['rs']])

    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every, outdir = debug_outdir,
                    name='condensation')([pred_ccoords, pred_beta,pre_selection['t_idx'],
                                          rs])

    # Hand the per-hit predictions plus the pre-selection dict to
    # re_integrate_to_full_hits, which returns the dict of model outputs.
    model_outputs = re_integrate_to_full_hits(
        pre_selection,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist,
        dict_output=True
        )

    return DictModel(inputs=Inputs, outputs=model_outputs)
示例#2
0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    """Build a hierarchical GravNet model that repeatedly reduces the graph.

    After an explicit first coordinate transformation and neighbour search,
    ``total_iterations`` (5) rounds of ``LocalClusterReshapeFromNeighbours``
    followed by ``RaggedGravNet`` are applied. The per-iteration features are
    gathered back with ``MultiBackGather``, concatenated, and passed to
    ``create_outputs`` and the ``LLFullObjectCondensation`` loss layer.

    Args:
        Inputs: Model inputs; passed to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned ``Model``.
        feature_dropout: Not referenced in the function body.
        addBackGatherInfo: Not referenced in the function body.

    Returns:
        ``Model`` whose outputs are ``[pred_beta, pred_ccoords, pred_energy,
        pred_pos, pred_time, pred_id, rs]`` followed by the entries of
        ``backgatheredids`` and ``backgathered_coords``. ``pred_beta`` is the
        return value of the loss layer, and ``rs`` is the value assigned in
        the last loop iteration.

    Note:
        ``td`` is not a parameter; it has to exist in an enclosing scope.
    """

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    # feat = Lambda(lambda x: tf.squeeze(x,axis=1)) (feat)

    #tf.print([(t.shape, t.name) for t in [feat,  t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])

    #feat,  t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #    [feat,  t_idx, t_energy, t_pos, t_time, t_pid]
    #    )

    # Keep the untouched truth tensors for the loss: t_idx is overwritten by
    # the clustering layer inside the loop. orig_row_splits is not used again.
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    # row_splits is replaced by the second output of RaggedConstructTensor and
    # stays fixed from here on; rs is the copy that the loop reassigns.
    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    # energy and time are selected from the raw features, the coordinates
    # from the normalised ones (presumably the arguments are a half-open
    # column range -- TODO confirm against SelectFeatures).
    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []  # never filled or read in this function
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    #trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    # Bias-free linear map that is initialised as the identity.
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    # Only first_coords is used later (final concatenation); first_nidx and
    # first_dist are not read again.
    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###
    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    #should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    # NOTE(review): x already contains x_mp from the concatenation above, so
    # x_mp enters this concatenation twice -- confirm this is intended.
    x = Concatenate()([x, x_mp, ncov])
    #x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        #narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        # x and t_idx each appear twice in the input list; the layer's outputs
        # rebind rs, sel_gidx, energy, x, t_idx and coords.
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=
            0.2,  #doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        # gatherids accumulates one entry per iteration and is passed whole to
        # every MultiBackGather call below.
        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        #use EdgeConv operation to determine cluster properties
        # NOTE(review): `True or i` is always true, so this branch runs in
        # every iteration, including the first; the trailing comment is stale.
        if True or i:  #only after second iteration because of OOM
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])

        # StopGradient is applied to the coordinates appended here.
        x = Concatenate()([x_cl, StopGradient()(coords)])

        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)

        x = BatchNormalization(momentum=0.6)(x)
        #notice last relu for feature weighting later
        # NOTE(review): the dense layers above use 'elu' and are followed by a
        # batch normalisation; the "relu" remark above looks outdated.

        ### now these are the new cluster features, up for the next iteration of building new latent space
        #x = RaggedGlobalExchange()([x,rs])

        # The neighbour count grows with the iteration: 64, 96, 128, ...
        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        # Gradient-stopped copies, used wherever coordinates or distances are
        # consumed as plain inputs below.
        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics
        #dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])

        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps
        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16,
                                               8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        #x_r = Concatenate()([x_r, ng_coords, ng_dist])

        #x_r = Concatenate()([StopGradient()(coords),x_r]) ## add coordinates, might come handy for cluster space

        # Only in the last iteration: gather the energy back through all
        # recorded gather indices. The result is not used afterwards.
        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    # From here on the fixed row_splits from the top is used, not the
    # loop-updated rs.
    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    # The loss layer receives the predictions, the original truth tensors and
    # row_splits; its return value replaces pred_beta.
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  #seems broken
        timing_loss_weight=1e-5,  #1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
示例#3
0
def gravnet_model(
    Inputs,
    td,
    empty_pca=False,
    total_iterations=2,
    variance_only=True,
    viscosity=0.1,
    print_viscosity=False,
    fluidity_decay=5e-4,  # reaches after about 7k batches
    max_viscosity=0.95,
    debug_outdir=None,
    plot_debug_every=1000,
):
    """Assemble a GravNet + ApproxPCA object-condensation model behind a pre-selection.

    The inputs are interpreted by ``td``, run through
    ``pre_selection_model_full`` (called with ``trainable=False``), processed
    by ``total_iterations`` blocks of ``RaggedGravNet``, ``ApproxPCA`` and
    ``DistanceWeightedMessagePassing``, and finally fed to ``create_outputs``
    and the ``LLFullObjectCondensation`` loss layer.

    Args:
        Inputs: Model inputs; passed to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned ``DictModel``.
        td: Object providing ``interpretAllModelInputs(Inputs, returndict=True)``.
        empty_pca: Forwarded as ``empty`` to every ``ApproxPCA`` layer.
        total_iterations: Number of GravNet blocks built in the main loop.
        variance_only: Forwarded to every ``GooeyBatchNorm`` layer.
        viscosity: Forwarded to every ``GooeyBatchNorm`` layer.
        print_viscosity: Not referenced in the function body.
        fluidity_decay: Forwarded to every ``GooeyBatchNorm`` layer.
        max_viscosity: Forwarded to every ``GooeyBatchNorm`` layer.
        debug_outdir: Forwarded as ``outdir`` to every ``PlotCoordinates`` layer.
        plot_debug_every: First positional argument of every
            ``PlotCoordinates`` layer (presumably the plotting interval --
            TODO confirm against the layer).

    Returns:
        ``DictModel(inputs=Inputs, outputs=...)`` where the outputs are the
        dict returned by ``re_integrate_to_full_hits``.
    """
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    orig_inputs = td.interpretAllModelInputs(Inputs, returndict=True)

    # Spectator weights for the full (not yet pre-selected) hits; stored in
    # the input dict so that the pre-selection model receives them as well.
    orig_t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([orig_inputs['t_spectator'], orig_inputs['t_idx']])

    orig_inputs['t_spectator_weight'] = orig_t_spectator_weight
    #can be loaded - or use pre-selected dataset (to be made)
    pre_selection = pre_selection_model_full(orig_inputs, trainable=False)

    #just for info what's available
    print([k for k in pre_selection.keys()])

    # NOTE(review): t_spectator_weight is not referenced again in this
    # function; the loss below reads pre_selection['t_spectator_weight'].
    t_spectator_weight = CreateTruthSpectatorWeights(
        threshold=3., minimum=1e-1,
        active=True)([pre_selection['t_spectator'], pre_selection['t_idx']])
    rs = pre_selection['rs']

    # Starting feature vector: pre-clustered coordinates, features and the
    # additional features delivered by the pre-selection.
    x_in = Concatenate()([
        pre_selection['coords'], pre_selection['features'],
        pre_selection['addfeat']
    ])

    x = x_in
    energy = pre_selection['energy']
    coords = pre_selection['phys_coords']  #physical coordinates
    c_coords = pre_selection['coords']  #pre-clustered coordinates
    t_idx = pre_selection['t_idx']

    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    # Output of every iteration; all entries are concatenated after the loop.
    allfeat = []

    n_cluster_space_coordinates = 3

    #extend coordinates already here if needed
    # Here the physical coordinates (coords) are extended, not c_coords.
    coords = extent_coords_if_needed(coords, x, n_cluster_space_coordinates)

    for i in range(total_iterations):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)
        ### reduction done

        n_dims = 3
        #exchange information, create coordinates
        # coords is deliberately listed twice, followed by the pre-clustered
        # coordinates and the current features.
        x = Concatenate()([coords, coords, c_coords, x])
        # Unlike a skip connection, x is replaced by the GravNet output here.
        x, gncoords, gnnidx, gndist = RaggedGravNet(
            n_neighbours=64,
            n_dimensions=n_dims,
            n_filters=64,
            n_propagate=64,
            record_metrics=True,
            use_approximate_knn=False  #weird issue with that for now
        )([x, rs])

        # The plotting layer's output replaces gncoords, so it stays part of
        # the graph through the ApproxPCA call below.
        gncoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='gn_coords_' +
                                   str(i))([gncoords, energy, t_idx, rs])
        #just keep them in a reasonable range
        #safeguard against disappearing gradients on coordinates
        gndist = AverageDistanceRegularizer(strength=0.01,
                                            record_metrics=True)(gndist)
        x = Concatenate()([energy, x])
        # Reduce to 4 features before the PCA layer to limit its cost.
        x_pca = Dense(4, activation='relu')(x)  #pca is expensive
        x_pca = ApproxPCA(empty=empty_pca)([gncoords, gndist, x_pca, gnnidx])
        x = Concatenate()([x, x_pca])

        x = DistanceWeightedMessagePassing([64, 64, 32, 32, 16,
                                            16])([x, gnnidx, gndist])

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           record_metrics=True,
                           variance_only=variance_only,
                           fluidity_decay=fluidity_decay)(x)

        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)
        x = Dense(64, activation='relu')(x)

        x = GooeyBatchNorm(viscosity=viscosity,
                           max_viscosity=max_viscosity,
                           variance_only=variance_only,
                           record_metrics=True,
                           fluidity_decay=fluidity_decay)(x)

        allfeat.append(x)

    # Combine the pre-clustered coordinates, the output of every iteration
    # and the pre-selection noise score.
    x = Concatenate()([c_coords] + allfeat +
                      [pre_selection['not_noise_score']])
    #do one more exchange with all
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                       #############
    #######################################################################

    x = GooeyBatchNorm(viscosity=viscosity,
                       max_viscosity=max_viscosity,
                       fluidity_decay=fluidity_decay,
                       record_metrics=True,
                       variance_only=variance_only,
                       name='gooey_pre_out')(x)
    x = Concatenate()([c_coords] + [x])

    pred_beta, pred_ccoords, pred_dist, pred_energy_corr, \
    pred_pos, pred_time, pred_id = create_outputs(x, pre_selection['unproc_features'],
                                                  n_ccoords=n_cluster_space_coordinates)

    # loss
    # The loss layer's return value replaces pred_beta, so everything
    # downstream (plotting, re-integration) depends on the loss layer.
    pred_beta = LLFullObjectCondensation(
        scale=1.,
        energy_loss_weight=2.,
        #print_batch_time=True,
        position_loss_weight=1e-5,
        timing_loss_weight=1e-5,
        classification_loss_weight=1e-5,
        beta_loss_scale=1.,
        too_much_beta_scale=1e-4,
        use_energy_weights=True,
        record_metrics=True,
        q_min=0.2,
        #div_repulsion=True,
        # cont_beta_loss=True,
        # beta_gradient_damping=0.999,
        # phase_transition=1,
        #huber_energy_scale=0.1,
        use_average_cc_pos=0.2,  # smoothen it out a bit
        name="FullOCLoss")(  # oc output and payload
            [
                pred_beta, pred_ccoords, pred_dist, pred_energy_corr, pred_pos,
                pred_time, pred_id
            ] + [energy] +
            # truth information
            [
                pre_selection['t_idx'], pre_selection['t_energy'],
                pre_selection['t_pos'], pre_selection['t_time'],
                pre_selection['t_pid'], pre_selection['t_spectator_weight'],
                pre_selection['t_fully_contained'],
                pre_selection['t_rec_energy'], pre_selection['t_is_unique'],
                pre_selection['rs']
            ])

    #fast feedback
    pred_ccoords = PlotCoordinates(plot_debug_every,
                                   outdir=debug_outdir,
                                   name='condensation')([
                                       pred_ccoords, pred_beta,
                                       pre_selection['t_idx'], rs
                                   ])

    # Hand the per-hit predictions plus the pre-selection dict to
    # re_integrate_to_full_hits, which returns the dict of model outputs.
    model_outputs = re_integrate_to_full_hits(pre_selection,
                                              pred_ccoords,
                                              pred_beta,
                                              pred_energy_corr,
                                              pred_pos,
                                              pred_time,
                                              pred_id,
                                              pred_dist,
                                              dict_output=True)

    return DictModel(inputs=Inputs, outputs=model_outputs)
示例#4
0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    """Build a hierarchical GravNet model with a learned clustering space.

    After an initial neighbour search on linearly transformed coordinates,
    ``total_iterations`` (5) rounds of ``LocalClusterReshapeFromNeighbours2``
    followed by ``RaggedGravNet``, ``NeighbourApproxPCA`` and message passing
    are applied. The per-iteration features are gathered back with
    ``MultiBackGather``, concatenated, and passed to ``create_outputs`` and
    the ``LLFullObjectCondensation`` loss layer.

    Args:
        Inputs: Model inputs; passed to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned ``Model``.
        feature_dropout: Not referenced in the function body.
        addBackGatherInfo: Not referenced in the function body.

    Returns:
        ``Model`` whose outputs are ``[pred_beta, pred_ccoords, pred_energy,
        pred_pos, pred_time, pred_id, rs]`` followed by the entries of
        ``backgatheredids`` and ``backgathered_coords``. ``pred_beta`` is the
        return value of the loss layer, and ``rs`` is the value assigned in
        the last loop iteration.

    Note:
        ``td`` is not a parameter; it has to exist in an enclosing scope.
    """

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    # Keep the untouched truth tensors for the loss: t_idx is overwritten by
    # the clustering layer inside the loop. orig_row_splits is not used again.
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    # row_splits is replaced by the second output of RaggedConstructTensor and
    # stays fixed from here on; rs is the copy that the loop reassigns.
    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []  # never filled or read in this function
    backgathered_coords = []
    # Filled inside the loop; its only consumer is a commented-out line.
    energysums = []

    #really simple real coordinates
    # energy is selected from the raw features, the coordinates from the
    # normalised ones (presumably the arguments are a half-open column
    # range -- TODO confirm against SelectFeatures).
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)
    # Bias-free linear map that is initialised as the identity.
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       orig_coords)  #just rotation and scaling

    #see whats there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])
    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])(
        [coords, dist, x_c, nidx])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_c, x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    # The clustering-space distances and coordinates start out as the initial
    # KNN distances and the transformed input coordinates.
    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []  # never filled or read in this function

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  #make it plottable
        #derive new coordinates for clustering
        # From the second iteration on: add a learned, zero-initialised
        # correction to the clustering coordinates and redo the neighbour
        # search. The first iteration reuses nidx/cdist from the initial KNN.
        if i:
            ccoords = Add()([
                ccoords,
                (Dense(n_dimensions,
                       use_bias=False,
                       name='newcoords' + str(i),
                       kernel_initializer='zeros')(x))
            ])
            nidx, cdist = KNN(K=6 * cluster_neighbours,
                              radius=-1.0)([ccoords, rs])
            #here we use more neighbours to improve learning of the cluster space

        #cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        # The scaling input comes from a Dense layer with a zero-initialised
        # kernel.
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        #what if we let the individual distances scale here? so move hits in and out?

        # x appears twice and t_idx appears twice (the last entry) in the
        # input list; the layer's outputs rebind rs, sel_gidx, energy, x,
        # t_idx, coords, ccoords and cdist.
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  #.5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])

        # gatherids accumulates one entry per iteration and is passed whole to
        # every MultiBackGather call below.
        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        # coords, nidx and dist are rebound to the GravNet outputs; ccoords
        # and cdist (the clustering space) are left untouched here.
        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8,
                                               8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        #x_sp=x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        # Clustering coordinates and distances are appended with StopGradient
        # applied to both.
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        #record more and more the deeper we go
        x_r = x
        energysums.append(MultiBackGather()(
            [energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    # From here on the fixed row_splits from the top is used, not the
    # loop-updated rs.
    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    #x - Dropout(0.3)(x)#force to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    #x = Concatenate()([x]+energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    # The loss layer receives the predictions, the original truth tensors and
    # row_splits; its return value replaces pred_beta.
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  #1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)