def gravnet_model(Inputs,
                  viscosity=0.2,
                  print_viscosity=False,
                  fluidity_decay=1e-3,
                  max_viscosity=0.9 # to start with
                  ):
    """Assemble the hierarchical-clustering GravNet model using GooeyBatchNorm.

    The graph is built in four stages:

    1. interpret the inputs and derive hand-made coordinates,
    2. run ``total_iterations`` rounds of local clustering
       (``LocalClusterReshapeFromNeighbours2``) followed by ``RaggedGravNet``,
    3. back-gather the per-iteration features onto the original vertices,
    4. create the output heads and attach the ``LLFullObjectCondensation``
       loss layer.

    Args:
        Inputs: model inputs; forwarded to ``td.interpretAllModelInputs`` and
            used as ``model_inputs`` of the returned model.
        viscosity: forwarded to every ``GooeyBatchNorm`` layer.
        print_viscosity: not read anywhere in this function.
        fluidity_decay: forwarded to every ``GooeyBatchNorm`` layer.
        max_viscosity: forwarded to every ``GooeyBatchNorm`` layer.

    Returns:
        A ``RobustModel`` built from ``Inputs`` and a list of
        ``(name, tensor)`` outputs: the predictions, the row splits left
        after the last clustering step, and the back-gathered ids and
        coordinates.
    """
    # every GooeyBatchNorm layer in this model shares the same settings
    gooey_kwargs = dict(viscosity=viscosity,
                        max_viscosity=max_viscosity,
                        fluidity_decay=fluidity_decay)

    feat,  t_idx, t_energy, t_pos, t_time, t_pid, t_spectator, t_fully_contained, row_splits = td.interpretAllModelInputs(Inputs)
    # keep the untouched truth tensors for the loss; t_idx is re-assigned by
    # the clustering layers further down
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid = t_idx, t_energy, t_pos, t_time, t_pid
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    energy = SelectFeatures(0,1)(feat)
    time = SelectFeatures(8,9)(feat_norm)
    orig_coords = SelectFeatures(5,8)(feat)

    x = feat_norm

    allfeat=[x]
    backgatheredids=[]
    gatherids=[]
    backgathered_coords = []
    energysums = []

    n_reshape_dimensions=3

    #really simple real coordinates
    coords = ManualCoordTransform()(orig_coords)
    coords = Concatenate()([coords, time])
    coords = Dense(n_reshape_dimensions, use_bias=False, kernel_initializer=tf.keras.initializers.Identity()  )(coords)#just rotation and scaling


    #see whats there
    nidx, dist = KNN(K=24,radius=1.0)([coords,rs])

    x_mp = DistanceWeightedMessagePassing([32,16,16,8])([x,nidx,dist])
    x = Concatenate()([x,x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = GooeyBatchNorm(**gooey_kwargs)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations=5

    # back-gathered cluster coordinates of the FIRST iteration; they seed the
    # final coordinate transform after the loop
    fwdbgccoords=None

    for i in range(total_iterations):

        cluster_neighbours = 5
        #derive new coordinates for clustering
        if i:
            # the zero-initialised Dense makes this start as an identity update
            ccoords = Add()([
                ccoords,Dense(n_reshape_dimensions,name='newcoords'+str(i),
                                         kernel_initializer='zeros'
                                         )(x)
                ])
            nidx, cdist = KNN(K=5*cluster_neighbours,radius=-1.0)([ccoords,rs])
            #here we use more neighbours to improve learning of the cluster space
            #this can be adjusted in the final trained model to be equal to 'cluster_neighbours'
        cdist = LocalDistanceScaling(max_scale=10.)([cdist, Dense(1,kernel_initializer='zeros')(Concatenate()([x,cdist]))])

        # The first output of the clustering layer is deliberately dropped:
        # the original code bound it to ``x`` and overwrote it with the sixth
        # output in the same unpacking, so only the sixth was ever used.
        _, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
                 K=cluster_neighbours,
                 radius=0.1,
                 print_reduction=True,
                 loss_enabled=True,
                 loss_scale = 3.,
                 loss_repulsion=0.8, # it's important to get this right here
                 hier_transforms=[64,32,16],
                 print_loss=False,
                 name='clustering_'+str(i)
                 )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(energy)#sums up all contained energy per cluster

        x = Dense(128, activation='elu',name='dense_clc_a'+str(i))(x)
        x = GooeyBatchNorm(**gooey_kwargs)(x)
        x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = GooeyBatchNorm(**gooey_kwargs)(x)


        n_dimensions = 3+i #make it plottable
        nneigh = 64+32*i #this will be almost fully connected for last clustering step
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords,x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=nfilt,
                                              n_propagate=nprop)([x, rs])


        x_pca = Dense(2+4*i)(x)
        x_pca = NeighbourApproxPCA()([coords,dist,x_pca,nidx])
        x_pca = GooeyBatchNorm(**gooey_kwargs)(x_pca)

        x_mp = DistanceWeightedMessagePassing([64,64,32,32])([x,nidx,dist])
        x_mp = GooeyBatchNorm(**gooey_kwargs)(x_mp)

        x = Concatenate()([x,x_pca,x_mp,x_gn])
        #check and compress it all
        x = Dense(128, activation='elu',name='dense_a_'+str(i))(x)
        x = GooeyBatchNorm(**gooey_kwargs)(x)
        x = Dense(32+16*i, activation='elu',name='dense_c_'+str(i))(x)
        x = GooeyBatchNorm(**gooey_kwargs)(x)


        energysums.append( MultiBackGather()([energy, gatherids]) )#assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        bgccoords = MultiBackGather()([ccoords, gatherids])
        if fwdbgccoords is None:
            fwdbgccoords = bgccoords
        backgathered_coords.append(bgccoords)


    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = GooeyBatchNorm(**gooey_kwargs)(x)

    #this is going to be resource intense, give a good starting point with
    #the back-gathered cluster coordinates of the first iteration
    coords = Dense(3,use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(Concatenate()([ fwdbgccoords ,x]))

    # x lives on the original vertices again, hence the original row splits
    nidx, dist = KNN(K=32,radius=1.0)([coords,row_splits])
    x_mp = DistanceWeightedMessagePassing(8*[8])([x,nidx,dist])#only some but a lot of hops
    x = Concatenate()([x,x_mp])
    #
    backgathered_coords.append(coords)

    x = GooeyBatchNorm(**gooey_kwargs)(x)
    x = Dense(128, activation='elu')(x)
    x = GooeyBatchNorm(**gooey_kwargs)(x)
    x = Concatenate()([x]+energysums)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x,feat,add_distance_scale=True)

    #loss
    pred_beta = LLFullObjectCondensation(print_loss=True,
                                         energy_loss_weight=1e-4,
                                         position_loss_weight=1e-3,
                                         timing_loss_weight=1e-3,
                                         beta_loss_scale=1.,
                                         repulsion_scaling=5.,
                                         repulsion_q_min=1.,
                                         super_repulsion=False,
                                         q_min=0.5,
                                         use_average_cc_pos=0.5,
                                         prob_repulsion=True,
                                         phase_transition=1,
                                         huber_energy_scale = 3,
                                         alt_potential_norm=True,
                                         beta_gradient_damping=0.,
                                         payload_beta_gradient_damping_strength=0.,
                                         kalpha_damping_strength=0.,#1.,
                                         use_local_distances=True,
                                         name="FullOCLoss"
                                         )([pred_beta, pred_ccoords,
                                            pred_dist,
                                            pred_energy,
                                            pred_pos, pred_time, pred_id,
                                            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
                                            row_splits])

    model_outputs = [('pred_beta', pred_beta), ('pred_ccoords',pred_ccoords),
       ('pred_energy',pred_energy),
       ('pred_pos',pred_pos),
       ('pred_time',pred_time),
       ('pred_id',pred_id),
       ('pred_dist',pred_dist),
       ('row_splits',rs)]

    # NOTE(review): backgathered_coords holds two more entries than
    # backgatheredids (one appended before the loop, one after it), so zip()
    # drops the last two and 'backgathered_coords_0' is the pre-clustering
    # coordinate set -- confirm this pairing is intended.
    for i, (bg_ids, bg_coords) in enumerate(zip(backgatheredids, backgathered_coords)):
        model_outputs.append(('backgatheredids_'+str(i), bg_ids))
        model_outputs.append(('backgathered_coords_'+str(i), bg_coords))
    return RobustModel(model_inputs=Inputs, model_outputs=model_outputs)
# Example #2
# 0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    """Assemble the GravNet model variant with EdgeConv and neighbour covariance.

    After interpreting the inputs, a first hand-made coordinate transform
    and one message-passing step are applied. ``total_iterations`` rounds of
    ``LocalClusterReshapeFromNeighbours`` + ``EdgeConvStatic`` +
    ``RaggedGravNet`` follow. The per-iteration features are back-gathered to
    the original vertices, globally exchanged, and fed to the output heads
    with the ``LLFullObjectCondensation`` loss layer attached.

    Args:
        Inputs: model inputs; forwarded to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned model.
        feature_dropout: not read anywhere in this function.
        addBackGatherInfo: not read anywhere in this function.

    Returns:
        A ``Model`` whose outputs are the six predictions, the row splits
        left after the last clustering step, and then all back-gathered ids
        followed by all back-gathered coordinates.
    """

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    # keep the untouched truth tensors for the loss; t_idx is re-assigned by
    # the clustering layers further down
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid = t_idx, t_energy, t_pos, t_time, t_pid
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    # concatenated to the final features right before the output heads
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###
    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    #should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    # NOTE(review): x already contains x_mp from the concatenation above, so
    # x_mp enters twice here -- kept as is to leave the architecture unchanged.
    x = Concatenate()([x, x_mp, ncov])
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        #narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=
            0.2,  #doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        #use EdgeConv operation to determine cluster properties
        #(the original guard was 'if True or i', i.e. this ran in every
        # iteration; the constant condition has been dropped)
        x_cl = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
        x_cl = EdgeConvStatic([64, 64, 64],
                              add_mean=True,
                              name="ec_static_" + str(i))(x_cl)
        x_cl = Concatenate()([x, x_cl])

        x = Concatenate()([x_cl, StopGradient()(coords)])

        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)

        x = BatchNormalization(momentum=0.6)(x)

        ### now these are the new cluster features, up for the next iteration of building new latent space

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        # gradient-stopped copies for the branches that only consume the
        # learned space
        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics

        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps
        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16,
                                               8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)

        if i >= total_iterations - 1:
            # NOTE(review): the back-gathered energy is not consumed by any
            # later statement in this function; kept so that the sequence of
            # created layers stays the same.
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    # features are back on the original vertices, hence the original row splits
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  #seems broken
        timing_loss_weight=1e-5,  #1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
# Example #3
# 0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    """Assemble the GravNet model variant that uses SoftPixelCNN blocks.

    After interpreting the inputs and a first ``SoftPixelCNN`` pass on
    learned coordinates, ``total_iterations`` rounds of
    ``LocalClusterReshapeFromNeighbours`` + ``RaggedGravNet`` +
    ``SoftPixelCNN`` + ``MessagePassing`` are applied. Per-iteration features
    are back-gathered to the original vertices and fed to the output heads
    with the ``LLFullObjectCondensation`` loss layer attached.

    Args:
        Inputs: model inputs; forwarded to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned model.
        feature_dropout: not read anywhere in this function.
        addBackGatherInfo: not read anywhere in this function.

    Returns:
        A ``Model`` whose outputs are the six predictions, the row splits
        left after the last clustering step, and then all back-gathered ids
        followed by all back-gathered coordinates.
    """

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    # keep the untouched truth tensors for the loss; t_idx is re-assigned by
    # the clustering layers further down
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid = t_idx, t_energy, t_pos, t_time, t_pid
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    allfeat = [feat_norm]
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered_coords = []

    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)
    coords = orig_coords
    coords = Dense(3)(coords)  #just rotation and scaling
    #add gradient

    #see whats there
    nidx, dist = KNN(K=96, radius=1.0)([coords, rs])
    x = Dense(4, activation='elu',
              name='pre_pre_dense')(x)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = SoftPixelCNN(mode='full', subdivisions=4,
                       name='prePCNN')([coords, x, dist, nidx])
    x = Concatenate()([x, x_c])
    #this is going to be among the most expensive operations:
    x = Dense(128, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    total_iterations = 5

    for i in range(total_iterations):

        #cluster first
        hier = Dense(1, activation='sigmoid')(x)
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx = LocalClusterReshapeFromNeighbours(
            K=8,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.3,
            print_loss=True,
            name='clustering_' + str(i))(
                [x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, t_idx])

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster

        gatherids.append(bidxs)
        x_cl = Dense(128, activation='elu', name='dense_clc_' + str(i))(x_cl)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x, x_cl, n_energy])

        pixelcompress = 4
        nneigh = 32 + 4 * i
        nfilt = 32 + 4 * i
        nprop = 32
        n_dimensions = 3  #make it plottable

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        subdivisions = 4

        #add more shape information
        x_sp = Dense(pixelcompress, activation='elu')(x)
        x_sp = BatchNormalization(momentum=0.6)(x_sp)
        x_sp = SoftPixelCNN(mode='full',
                            subdivisions=subdivisions)([coords, x_sp, dist, nidx])
        x_sp = Dense(128, activation='elu', name='dense_spc_' + str(i))(x_sp)

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        # Fix: this Dense used to be applied to x_sp, which threw the
        # message-passing result away and fed a second projection of x_sp
        # into the concatenation below.
        # NOTE(review): the input of 'dense_mpc_*' changes with this fix, so
        # weights saved with the old wiring presumably no longer fit these
        # layers -- confirm before loading old checkpoints.
        x_mp = Dense(nfilt, activation='elu', name='dense_mpc_' + str(i))(x_mp)

        x = Concatenate()([x, x_sp, x_gn, x_mp])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x_r = x
        #record more and more the deeper we go
        if i < total_iterations - 1:
            x_r = Dense(12 * (i + 1),
                        activation='elu',
                        name='dense_rec_' + str(i))(x)
        else:
            # last iteration: record the full x and bring the summed cluster
            # energy back to the original vertices for the final concat
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([x, energy])

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
# Example #4
# 0
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    """Assemble the GravNet model variant built around NeighbourApproxPCA.

    After interpreting the inputs and a first neighbour-PCA / message-passing
    step, ``total_iterations`` rounds of ``LocalClusterReshapeFromNeighbours2``
    (operating in a separate, incrementally updated clustering space
    ``ccoords``) followed by ``RaggedGravNet`` are applied. Per-iteration
    features are back-gathered to the original vertices and fed to the output
    heads with the ``LLFullObjectCondensation`` loss layer attached.

    Args:
        Inputs: model inputs; forwarded to ``td.interpretAllModelInputs`` and
            used as ``inputs`` of the returned model.
        feature_dropout: not read anywhere in this function.
        addBackGatherInfo: not read anywhere in this function.

    Returns:
        A ``Model`` whose outputs are the six predictions, the row splits
        left after the last clustering step, and then all back-gathered ids
        followed by all back-gathered coordinates.
    """

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    # keep the untouched truth tensors for the loss; t_idx is re-assigned by
    # the clustering layers further down (orig_row_splits is not used below)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    # rs is replaced by every clustering step; row_splits stays at the
    # original vertex granularity
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    # NOTE(review): 'backgathered' is never read or appended to below
    backgathered = []
    backgathered_coords = []
    energysums = []

    #really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    # NOTE(review): other variants in this file select features 5..8 as
    # coordinates -- confirm that 4..7 is the right slice for this input format
    orig_coords = SelectFeatures(4, 7)(feat_norm)
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       orig_coords)  #just rotation and scaling

    #see whats there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])
    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])(
        [coords, dist, x_c, nidx])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_c, x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    # the clustering space starts out as the first coordinate space and is
    # updated additively inside the loop
    cdist = dist
    ccoords = coords

    total_iterations = 5

    # NOTE(review): 'collectedccoords' is never read or appended to below
    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  #make it plottable
        #derive new coordinates for clustering
        if i:
            # from the second iteration on: add a zero-initialised correction
            # to the clustering coordinates and rebuild the neighbour graph
            ccoords = Add()([
                ccoords,
                (Dense(n_dimensions,
                       use_bias=False,
                       name='newcoords' + str(i),
                       kernel_initializer='zeros')(x))
            ])
            nidx, cdist = KNN(K=6 * cluster_neighbours,
                              radius=-1.0)([ccoords, rs])
            #here we use more neighbours to improve learning of the cluster space

        #cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        #what if we let the individual distances scale here? so move hits in and out?

        # x, t_idx, coords, ccoords and cdist are passed in and received
        # back from the clustering layer; t_idx is additionally passed as
        # the last input
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  #.5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])

        # bidxs are collected so MultiBackGather can undo all clustering
        # steps performed so far
        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        # RaggedGravNet returns new coords / nidx / dist; the clustering
        # space (ccoords, cdist) is kept separately
        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8,
                                               8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        #x_sp=x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        # clustering coordinates and distances are appended as features, but
        # wrapped in StopGradient
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        #record more and more the deeper we go
        x_r = x
        # NOTE(review): energysums is filled here but its only consumer (the
        # Concatenate after the loop) is commented out
        energysums.append(MultiBackGather()(
            [energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    # features are back on the original vertices, hence the original row splits
    x = RaggedGlobalExchange()([x, row_splits])
    #x - Dropout(0.3)(x)#force to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    #x = Concatenate()([x]+energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    # the loss layer returns the tensor that is used as pred_beta in the
    # model outputs below
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  #1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):
    """Build the iterative clustering / GravNet object-condensation model.

    The graph is assembled in four stages:

    1. Inputs are unpacked, shape-normalised and a learned 3D coordinate
       transformation is applied to the selected coordinate features.
    2. One explicit KNN + distance-weighted message-passing step is run and
       combined with neighbour covariance features.
    3. ``total_iterations`` rounds of local cluster reshaping, RaggedGravNet,
       neighbour covariance and message passing are applied; after each round
       the features, global indices and coordinates are gathered back with
       ``MultiBackGather`` and recorded.
    4. All recorded features are concatenated, passed through dense layers
       and ``create_outputs``, and the ``LLFullObjectCondensation`` loss layer
       is attached to the beta prediction.

    NOTE: layers without an explicit ``name`` get auto-generated names, so the
    order in which layers are instantiated here should not be changed lightly
    (presumably relevant when loading stored weights by name -- confirm).

    Args:
        Inputs: Model inputs; forwarded to ``td.interpretAllModelInputs`` and
            used as the ``inputs`` of the returned model.
        feature_dropout: Currently unused in this function; kept so existing
            callers keep working.
        addBackGatherInfo: Currently unused in this function; kept so existing
            callers keep working.

    Returns:
        An ``ExtendedMetricsModel`` whose outputs are
        ``[pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id, rs]``
        followed by the per-iteration back-gathered global indices and the
        back-gathered coordinates (the latter also contains the initial
        coordinates as its first entry). ``rs`` is the row-split tensor
        returned by the last clustering layer.
    """

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    # keep the untouched truth tensors for the loss layer; t_idx is
    # overwritten by the clustering layers inside the loop below
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid = t_idx, t_energy, t_pos, t_time, t_pid
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    # rs is updated by every clustering step; row_splits stays as constructed here
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    # energy and time are taken from the raw features, coordinates from the
    # normalised ones
    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    # small learned correction (scaled by 0.1) added on top of the input
    # coordinates, followed by a bias-free linear map initialised to identity
    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.identity())(coords)

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])

    # remembered for the final message-passing step after the loop
    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()(
        [coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        # x and t_idx are each passed twice on purpose: once as payload to be
        # carried along, once in the role the layer itself uses
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  #doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        # static edge convolution over the clustered neighbourhood; this used
        # to sit behind an always-true `if i or True:` guard
        x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
        xec = EdgeConvStatic([32, 32, 32],
                             name="ec_static_" + str(i))(x_cl_rs)
        x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        #notice last relu for feature weighting later

        # neighbour count and filter count grow with the iteration index
        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu',
                    name='dense_out_c_' + str(i))(x)
        #coords_nograd = StopGradient()(coords)
        #x_r = Concatenate()([coords_nograd,x_r]) ## add coordinates, might come handy for cluster space

        # only in the last iteration; the result is not consumed further down
        # yet (see the "TO BE ADDED WITH E LOSS" note below)
        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.2)(x)
    # one more message-passing step on the neighbour graph built before the loop
    x_mp = DistanceWeightedMessagePassing([32, 32,
                                           32])([x, first_nidx, first_dist])
    x = Concatenate()([x, x_mp])

    x = Dense(128, activation='elu', name='alldense')(x)
    # TO BE ADDED WITH E LOSS x = Concatenate()([x,energy])
    #x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    # the loss layer receives the original (pre-clustering) truth tensors and
    # row splits; energy, position and timing loss weights are set to zero here
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  #seems broken
        timing_loss_weight=0.,  #1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[
                                    pred_beta, pred_ccoords, pred_energy,
                                    pred_pos, pred_time, pred_id, rs
                                ] + backgatheredids + backgathered_coords)