Example #1
File: train.py Project: abao1999/DRNOC
def minimodel(Inputs,feature_dropout=-1.):
    x = Inputs[0] #this is the self.x list from the TrainData data structure
    energy_raw = SelectFeatures(0,3)(x)

    x = BatchNormalization(momentum=0.6)(x)
    feat=[x]

    for i in range(6):
        #add global exchange and another dense here
        x = GlobalExchange()(x)
        x = Dense(64, activation='elu')(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu')(x)
        x = GravNet_simple(n_neighbours=10, 
                 n_dimensions=4, 
                 n_filters=128, 
                 n_propagate=64)(x)
        x = BatchNormalization(momentum=0.6)(x)
        feat.append(Dense(32, activation='elu')(x))

    x = Concatenate()(feat)
    x = Dense(64, activation='elu')(x)

    return Model(inputs=Inputs, outputs=output_block(x,checkids(Inputs),energy_raw))
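Every example in this collection routes raw inputs through SelectFeatures to pick a feature range off the last axis (here, the raw energy columns 0..2). A minimal sketch of that behaviour, assuming it is plain feature-axis slicing (the class name below is hypothetical, not the library code):

import tensorflow as tf

class SelectFeaturesSketch(tf.keras.layers.Layer):
    def __init__(self, start, end, **kwargs):
        super().__init__(**kwargs)
        self.start, self.end = start, end

    def call(self, inputs):
        # keep all leading dimensions, slice only the feature axis
        return inputs[..., self.start:self.end]

# usage, mirroring the snippet above: energy_raw = SelectFeaturesSketch(0, 3)(x)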
Example #2
File: gravnet.py Project: gvonsem/HGCalML
def gravnet_model(Inputs, nclasses, nregressions, feature_dropout=0.1):

    x = Inputs[0]  #this is the self.x list from the TrainData data structure

    print('x', x.shape)
    coords = []

    etas = SelectFeatures(1, 2)(x)  #just to propagate to the prediction

    mask = CreateZeroMask(0)(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Multiply()([x, mask])
    x, coord = GravNet(n_neighbours=40,
                       n_dimensions=4,
                       n_filters=80,
                       n_propagate=16,
                       name='gravnet_pre',
                       also_coordinates=True)(x)
    coords.append(coord)
    x = BatchNormalization(momentum=0.9)(x)
    x = Multiply()([x, mask])

    feats = []
    for i in range(n_gravnet_layers):
        x = GlobalExchange()(x)
        x = Multiply()([x, mask])
        x = Dense(64, activation='tanh')(x)
        x = Dense(64, activation='tanh')(x)
        x = BatchNormalization(momentum=0.9)(x)
        x = Multiply()([x, mask])
        x = Dense(64, activation='sigmoid')(x)
        x = Multiply()([x, mask])
        x, coord = GravNet(n_neighbours=40,
                           n_dimensions=4,
                           n_filters=80,
                           n_propagate=16,
                           name='gravnet_' + str(i),
                           also_coordinates=True,
                           feature_dropout=feature_dropout)(x)
        coords.append(coord)
        x = BatchNormalization(momentum=0.9)(x)
        x = Multiply()([x, mask])
        feats.append(x)

    x = Concatenate()(feats)
    x = Dense(64, activation='elu', name='pre_last_correction')(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Multiply()([x, mask])
    x = Dense(nregressions, activation=None, kernel_initializer='zeros')(x)
    #x = Clip(-0.5, 1.5) (x)
    x = Multiply()([x, mask])

    #x = SortPredictionByEta(input_energy_index=0, input_eta_index=1)([x,Inputs[0]])

    x = Concatenate()([x] + coords + [etas])
    predictions = [x]
    return Model(inputs=Inputs, outputs=predictions)
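Throughout this model, the CreateZeroMask/Multiply pair re-zeroes zero-padded vertices after every Dense and BatchNormalization step, so padding cannot leak into real hits. A rough sketch of such a mask, assuming feature 0 is exactly zero for padded entries (hypothetical helper, not the library code):

import tensorflow as tf

def create_zero_mask_sketch(x, feature_index=0):
    # 1.0 where the chosen feature is non-zero, 0.0 for zero-padded vertices
    flag = x[..., feature_index:feature_index + 1]
    return tf.cast(tf.not_equal(flag, 0.0), x.dtype)

# usage: mask = create_zero_mask_sketch(x); x = tf.keras.layers.Multiply()([x, mask])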
Example #3
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_layers = 4
    feat = []
    for i in range(n_gravnet_layers):
        n_filters = 128
        n_propagate = 96
        n_neighbours = 200

        x = RaggedGlobalExchange(name="global_exchange_" +
                                 str(i))([x, x_row_splits])
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_a")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_b")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = FusedRaggedGravNet_simple(n_neighbours=n_neighbours,
                                      n_dimensions=4,
                                      n_filters=n_filters,
                                      n_propagate=n_propagate,
                                      name='gravnet_' +
                                      str(i))([x, x_row_splits])
        x = BatchNormalization(momentum=0.6)(x)
        feat.append(
            Dense(128, activation='elu', name="dense_compress_" + str(i))(x))

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    beta = Dense(1, activation='sigmoid', name="dense_beta")(x)

    xy = Dense(2, activation=None, name="dense_xy")(x)
    t = ScalarMultiply(1e-9)(Dense(1, activation=None, name="dense_t")(x))
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)

    x_en = Dense(64, activation='elu', name="dense_en_a")(
        x)  #here so the other names remain the same
    input_energy = SelectFeatures(0, 1, name="select_en")(input_features)
    energy_condensates, idxs = CondensateAndSum(
        radius=0.5, min_beta=0.1,
        name="condensate_en")([ccoords, beta, input_energy, x_row_splits])
    x_en = Concatenate(name="concat_en_cond")(
        [ScalarMultiply(10, name="multi_en")(x_en), energy_condensates])
    x_en = Dense(64, activation='elu', name="dense_en_b")(x_en)
    energy = Dense(1, activation=None, name="dense_en_final")(x_en)

    print('input_features', input_features.shape)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
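This variant replaces zero-padding with ragged batching: RaggedConstructTensor turns a flat hit list plus per-event split points into a ragged structure, and the FusedRagged* layers carry x_row_splits along. Assuming standard TensorFlow row-splits semantics, the underlying idea looks like:

import tensorflow as tf

hits = tf.random.normal([7, 4])        # 7 hits with 4 features, from two events
row_splits = tf.constant([0, 3, 7])    # event 0 owns hits 0..2, event 1 owns 3..6
ragged = tf.RaggedTensor.from_row_splits(hits, row_splits)
print(ragged.shape)                    # (2, None, 4)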
Example #4
def gravnet_model(Inputs,
                  viscosity=0.2,
                  print_viscosity=False,
                  fluidity_decay=1e-3,
                  max_viscosity=0.9 # to start with
                  ):

    feature_dropout=-1.
    addBackGatherInfo=True

    feat,  t_idx, t_energy, t_pos, t_time, t_pid, t_spectator, t_fully_contained, row_splits = td.interpretAllModelInputs(Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    energy = SelectFeatures(0,1)(feat)
    time = SelectFeatures(8,9)(feat_norm)
    orig_coords = SelectFeatures(5,8)(feat)

    x = feat_norm

    allfeat=[x]
    backgatheredids=[]
    gatherids=[]
    backgathered = []
    backgathered_coords = []
    energysums = []

    n_reshape_dimensions=3

    #really simple real coordinates
    coords = ManualCoordTransform()(orig_coords)
    coords = Concatenate()([coords, time])
    coords = Dense(n_reshape_dimensions, use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)  #just rotation and scaling


    #see what's there
    nidx, dist = KNN(K=24,radius=1.0)([coords,rs])

    x_mp = DistanceWeightedMessagePassing([32,16,16,8])([x,nidx,dist])
    x = Concatenate()([x,x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations=5

    fwdbgccoords=None

    for i in range(total_iterations):

        cluster_neighbours = 5
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,Dense(n_reshape_dimensions,name='newcoords'+str(i),
                                         kernel_initializer='zeros'
                                         )(x)
                ])
            nidx, cdist = KNN(K=5*cluster_neighbours,radius=-1.0)([ccoords,rs])
            #here we use more neighbours to improve learning of the cluster space
            #this can be adjusted in the final trained model to be equal to 'cluster_neighbours'
        cdist = LocalDistanceScaling(max_scale=10.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])

        x, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
                 K=cluster_neighbours,
                 radius=0.1,
                 print_reduction=True,
                 loss_enabled=True,
                 loss_scale = 3.,
                 loss_repulsion=0.8, # it's important to get this right here
                 hier_transforms=[64,32,16],
                 print_loss=False,
                 name='clustering_'+str(i)
                 )([x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist, t_idx])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(energy)#sums up all contained energy per cluster

        x = Dense(128, activation='elu',name='dense_clc_a'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
        x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)


        n_dimensions = 3+i  #3 (plottable) in the first iteration, growing after
        nneigh = 64+32*i #this will be almost fully connected for last clustering step
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords,x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=nfilt,
                                              n_propagate=nprop)([x, rs])


        x_pca = Dense(2+4*i)(x)
        x_pca = NeighbourApproxPCA()([coords,dist,x_pca,nidx])
        x_pca = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x_pca)

        x_mp = DistanceWeightedMessagePassing([64,64,32,32])([x,nidx,dist])
        x_mp = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x_mp)

        x = Concatenate()([x,x_pca,x_mp,x_gn])
        #check and compress it all
        x = Dense(128, activation='elu',name='dense_a_'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
        x = Dense(32+16*i, activation='elu',name='dense_c_'+str(i))(x)
        x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)


        energysums.append(MultiBackGather()([energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        bgccoords = MultiBackGather()([ccoords, gatherids])
        if fwdbgccoords is None:
            fwdbgccoords = bgccoords
        backgathered_coords.append(bgccoords)




    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)

    #this is going to be resource-intensive; give it a good starting point with the last ccoords
    coords = Dense(3,use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity()
                   )(Concatenate()([ fwdbgccoords ,x]))

    nidx, dist = KNN(K=32,radius=1.0)([coords,row_splits])
    x_mp = DistanceWeightedMessagePassing(8*[8])([x,nidx,dist])  #only a few features per hop, but many hops
    x = Concatenate()([x,x_mp])
    #
    backgathered_coords.append(coords)

    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
    x = Dense(128, activation='elu')(x)
    x = GooeyBatchNorm(viscosity=viscosity, max_viscosity=max_viscosity,fluidity_decay=fluidity_decay)(x)
    x = Concatenate()([x]+energysums)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id = create_outputs(x,feat,add_distance_scale=True)

    #loss
    pred_beta = LLFullObjectCondensation(print_loss=True,
                                         energy_loss_weight=1e-4,
                                         position_loss_weight=1e-3,
                                         timing_loss_weight=1e-3,
                                         beta_loss_scale=1.,
                                         repulsion_scaling=5.,
                                         repulsion_q_min=1.,
                                         super_repulsion=False,
                                         q_min=0.5,
                                         use_average_cc_pos=0.5,
                                         prob_repulsion=True,
                                         phase_transition=1,
                                         huber_energy_scale = 3,
                                         alt_potential_norm=True,
                                         beta_gradient_damping=0.,
                                         payload_beta_gradient_damping_strength=0.,
                                         kalpha_damping_strength=0.,#1.,
                                         use_local_distances=True,
                                         name="FullOCLoss"
                                         )([pred_beta, pred_ccoords,
                                            pred_dist,
                                            pred_energy,
                                            pred_pos, pred_time, pred_id,
                                            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
                                            row_splits])

    model_outputs = [('pred_beta', pred_beta), ('pred_ccoords',pred_ccoords),
       ('pred_energy',pred_energy),
       ('pred_pos',pred_pos),
       ('pred_time',pred_time),
       ('pred_id',pred_id),
       ('pred_dist',pred_dist),
       ('row_splits',rs)]

    for i, (x, y) in enumerate(zip(backgatheredids, backgathered_coords)):
        model_outputs.append(('backgatheredids_'+str(i), x))
        model_outputs.append(('backgathered_coords_'+str(i), y))
    return RobustModel(model_inputs=Inputs, model_outputs=model_outputs)
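After each clustering step the model stores the back-gather indices (bidxs), and MultiBackGather replays them in reverse to broadcast per-cluster quantities back to every original vertex. A simplified sketch of that replay, assuming each stored index tensor maps fine-level elements to their cluster (hypothetical helper):

import tensorflow as tf

def multi_back_gather_sketch(x, gather_ids):
    # walk the clustering hierarchy backwards: the most recent index set
    # expands the coarsest level one step, and so on down to the hits
    for idxs in reversed(gather_ids):
        x = tf.gather(x, tf.reshape(idxs, [-1]))
    return x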
Example #5
def pre_selection_staged(
    indict,
    debug_outdir,
    trainable,
    name='pre_selection_add_stage_0',
    debugplots_after=-1,
    reduction_threshold=0.75,
    use_edges=True,
    print_info=False,
    record_metrics=False,
    n_coords=3,
    edge_nodes_0=16,
    edge_nodes_1=8,
):
    '''
    Takes the output of the preselection model and selects again :)
    The outputs are compatible, so this stage can be chained.

    This one uses full-blown GravNet.
    
    Gets as inputs:
    
    indict['scatterids']
    indict['orig_t_idx'] 
    indict['orig_t_energy'] 
    indict['orig_dim_coords']
    indict['rs']
    indict['orig_row_splits']
    
    indict['features'] 
    indict['orig_features']
    indict['coords'] 
    indict['addfeat']
    indict['energy']
    
    
    indict['t_idx']
    indict['t_energy']
    ... all the truth info
    
    '''

    from GravNetLayersRagged import RaggedGravNet, DistanceWeightedMessagePassing, ElementScaling
    from GravNetLayersRagged import SelectFromIndices, GooeyBatchNorm, MaskTracksAsNoise
    from GravNetLayersRagged import AccumulateNeighbours, KNN, MultiAttentionGravNetAdd
    from LossLayers import LLClusterCoordinates, LLFillSpace
    from DebugLayers import PlotCoordinates
    from MetricsLayers import MLReductionMetrics
    from Regularizers import MeanMaxDistanceRegularizer, AverageDistanceRegularizer

    #assume the inputs are normalised
    rs = indict['rs']
    t_idx = indict['t_idx']

    track_charge = SelectFeatures(2, 3)(
        indict['unproc_features'])  #zero for calo hits
    x = Concatenate()([indict['features'], indict['addfeat']])
    x = Dense(64, activation='elu', trainable=trainable)(x)
    gn_pre_coords = indict['coords']
    gn_pre_coords = ElementScaling(name=name + 'es1',
                                   trainable=trainable)(gn_pre_coords)
    x = Concatenate()([gn_pre_coords, x])

    x, coords, nidx, dist = RaggedGravNet(n_neighbours=32,
                                          n_dimensions=n_coords,
                                          n_filters=64,
                                          n_propagate=64,
                                          coord_initialiser_noise=1e-5,
                                          feature_activation=None,
                                          record_metrics=record_metrics,
                                          use_approximate_knn=True,
                                          use_dynamic_knn=True,
                                          trainable=trainable,
                                          name=name + '_gn1')([x, rs])

    #the two below mostly serve to record metrics and to kill very bad coordinate scalings
    dist = MeanMaxDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    dist = AverageDistanceRegularizer(strength=1e-6 if trainable else 0.,
                                      record_metrics=record_metrics)(dist)

    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_gn1_coords')(
                                     [coords, indict['energy'], t_idx, rs])

    x = DistanceWeightedMessagePassing([32, 32, 8, 8],
                                       name=name + 'dmp1',
                                       trainable=trainable)([x, nidx, dist])

    x_matt = Dense(16, activation='elu', name=name + '_matt_dense')(x)

    x_matt = MultiAttentionGravNetAdd(5,
                                      name=name + '_att_gn1',
                                      record_metrics=record_metrics)(
                                          [x, x_matt, coords, nidx])
    x = Concatenate()([x, x_matt])
    x = Dense(64, activation='elu', name=name + '_bef_coord_dense')(x)

    coords = Add()([
        Dense(n_coords,
              name=name + '_coord_add_dense',
              kernel_initializer='zeros')(x), coords
    ])
    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name + '_red_coords')(
                                     [coords, indict['energy'], t_idx, rs])

    nidx, dist = KNN(
        K=16,
        radius='dynamic',  #use dynamic feature
        record_metrics=record_metrics,
        name=name + '_knn',
        min_bins=[7, 7]  #this can be fine grained
    )([coords, rs])

    coords = LLClusterCoordinates(print_loss=print_info,
                                  record_metrics=record_metrics,
                                  active=trainable,
                                  print_batch_time=False,
                                  scale=5.)([coords, t_idx, rs])

    coords = LLFillSpace(
        active=trainable,
        record_metrics=record_metrics,
        scale=0.025,  #just mild
        runevery=-1,  #give it a kick only every now and then - that's enough
    )([coords, rs])

    unred_rs = rs

    cluster_tidx = MaskTracksAsNoise(active=trainable)([t_idx, track_charge])

    gnidx, gsel, group_backgather, rs = reduce_indices(
        x,
        dist,
        nidx,
        rs,
        cluster_tidx,
        threshold=reduction_threshold,
        print_reduction=print_info,
        trainable=trainable,
        name=name + '_reduce_indices',
        use_edges=use_edges,
        edge_nodes_0=edge_nodes_0,
        edge_nodes_1=edge_nodes_1,
        return_backscatter=False)

    gsel = MLReductionMetrics(name=name + '_reduction', record_metrics=True)(
        [gsel, t_idx, indict['t_energy'], unred_rs, rs])

    selfeat = SelectFromIndices()([gsel, indict['features']])
    unproc_features = SelectFromIndices()([gsel, indict['unproc_features']])

    energy = indict['energy']

    x = AccumulateNeighbours('minmeanmax')([x, gnidx, energy])
    x = SelectFromIndices()([gsel, x])
    #add more useful things
    coords = AccumulateNeighbours('mean')([coords, gnidx, energy])
    coords = SelectFromIndices()([gsel, coords])
    phys_coords = AccumulateNeighbours('mean')(
        [indict['phys_coords'], gnidx, energy])
    phys_coords = SelectFromIndices()([gsel, phys_coords])

    energy = AccumulateNeighbours('sum')([energy, gnidx])
    energy = SelectFromIndices()([gsel, energy])

    out = {}
    out['not_noise_score'] = AccumulateNeighbours('mean')(
        [indict['not_noise_score'], gnidx])
    out['not_noise_score'] = SelectFromIndices()(
        [gsel, out['not_noise_score']])

    out['scatterids'] = indict['scatterids'] + [group_backgather]  #append new selection

    #re-build standard feature layout
    out['features'] = selfeat
    out['unproc_features'] = unproc_features
    out['coords'] = coords
    out['phys_coords'] = phys_coords
    out['addfeat'] = GooeyBatchNorm(name=name + '_gooey_norm',
                                    trainable=trainable)(x)  #norm them
    out['energy'] = energy
    out['rs'] = rs

    for k in indict.keys():
        if 't_' == k[0:2]:
            out[k] = SelectFromIndices()([gsel, indict[k]])

    #some pass throughs:
    out['orig_dim_coords'] = indict['orig_dim_coords']
    out['orig_t_idx'] = indict['orig_t_idx']
    out['orig_t_energy'] = indict['orig_t_energy']
    out['orig_row_splits'] = indict['orig_row_splits']

    #check
    anymissing = False
    for k in indict.keys():
        if not k in out.keys():
            anymissing = True
            print(k, 'missing')
    if anymissing:
        raise ValueError("key not found")

    return out
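The reduction above follows a fixed pattern: AccumulateNeighbours pools each quantity over its neighbour group (min/mean/max for features, mean for coordinates, sum for energy), then SelectFromIndices keeps one row per surviving cluster. Assuming gsel is a flat tensor of kept row indices, the selection itself is essentially a gather (sketch only, not the library implementation):

import tensorflow as tf

def select_from_indices_sketch(gsel, tensor):
    # keep one row per selected index; applies equally to features,
    # coordinates, energies and the 't_'-prefixed truth tensors
    return tf.gather(tensor, tf.reshape(gsel, [-1]))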
Example #6
def pre_selection_model_full(
    orig_inputs,
    debug_outdir='',
    trainable=False,
    name='pre_selection',
    debugplots_after=-1,
    reduction_threshold=0.75,
    noise_threshold=0.025,
    use_edges=True,
    n_coords=3,
    pass_through=False,
    print_info=False,
    record_metrics=False,
    omit_reduction=False,  #only trains coordinate transform. useful for pretrain phase
    use_multigrav=True,
    eweighted=True,
):

    from GravNetLayersRagged import AccumulateNeighbours, SelectFromIndices
    from GravNetLayersRagged import SortAndSelectNeighbours, NoiseFilter
    from GravNetLayersRagged import CastRowSplits, ProcessFeatures
    from GravNetLayersRagged import GooeyBatchNorm, MaskTracksAsNoise
    from DebugLayers import PlotCoordinates
    from LossLayers import LLClusterCoordinates, LLNotNoiseClassifier, LLFillSpace
    from MetricsLayers import MLReductionMetrics

    rs = CastRowSplits()(orig_inputs['row_splits'])
    t_idx = orig_inputs['t_idx']

    orig_processed_features = ProcessFeatures()(orig_inputs['features'])
    x = orig_processed_features
    energy = SelectFeatures(0, 1)(orig_inputs['features'])
    coords = SelectFeatures(5, 8)(x)
    track_charge = SelectFeatures(2, 3)(
        orig_inputs['features'])  #zero for calo hits
    phys_coords = coords

    # here the actual network starts
    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name +
                                 '_initial')([coords, energy, t_idx, rs])
    ############## Keep this part to reload the noise filter with pre-trained weights for other trainings

    out = {}
    if pass_through:  #do nothing but make output compatible
        for k in orig_inputs.keys():
            out[k] = orig_inputs[k]
        out['features'] = x
        out['coords'] = coords
        out['addfeat'] = x  #add more
        out['energy'] = energy
        out['not_noise_score'] = Dense(1, name=name + '_passthrough_noise')(x)
        out['orig_t_idx'] = orig_inputs['t_idx']
        out['orig_t_energy'] = orig_inputs['t_energy']  #for validation
        out['orig_dim_coords'] = coords
        out['rs'] = rs
        out['orig_row_splits'] = rs
        return out

    #this takes O(200ms) for 100k hits
    coords, nidx, dist, x = first_coordinate_adjustment(
        coords,
        x,
        energy,
        rs,
        t_idx,
        debug_outdir,
        trainable=trainable,
        name=name + '_first_coords',
        debugplots_after=debugplots_after,
        n_coords=n_coords,
        record_metrics=record_metrics,
        use_multigrav=use_multigrav)
    #create the gradients
    coords = LLClusterCoordinates(print_loss=trainable and print_info,
                                  active=trainable,
                                  print_batch_time=False,
                                  record_metrics=record_metrics,
                                  scale=5.)([coords, t_idx, rs])

    if debugplots_after > 0:
        coords = PlotCoordinates(debugplots_after,
                                 outdir=debug_outdir,
                                 name=name +
                                 '_bef_red')([coords, energy, t_idx, rs])

    if omit_reduction:
        return {'coords': coords, 'dist': dist, 'x': x}

    dist, nidx = SortAndSelectNeighbours(K=16)(
        [dist, nidx])  #only run reduction on the 16 closest
    '''
    run a full reduction block
    return the noise score in addition - don't select yet
    
    do not cluster tracks with anything here
    '''

    cluster_tidx = MaskTracksAsNoise(active=trainable)([t_idx, track_charge])

    unred_rs = rs
    gnidx, gsel, group_backgather, rs = reduce_indices(
        x,
        dist,
        nidx,
        rs,
        cluster_tidx,
        threshold=reduction_threshold,
        print_reduction=print_info,
        trainable=trainable,
        name=name + '_reduce_indices',
        use_edges=use_edges,
        record_metrics=record_metrics,
        return_backscatter=False)

    gsel = MLReductionMetrics(name=name + '_reduction_0',
                              record_metrics=record_metrics)([
                                  gsel, t_idx, orig_inputs['t_energy'],
                                  unred_rs, rs
                              ])

    #do it explicitly

    #selfeat = orig_inputs['features']
    selfeat = SelectFromIndices()([gsel, orig_processed_features])
    unproc_features = SelectFromIndices()([gsel, orig_inputs['features']])

    #save for later
    orig_dim_coords = coords

    energy_weight = energy
    if not eweighted:
        energy_weight = OnesLike()(energy)

    x = AccumulateNeighbours('minmeanmax')([x, gnidx, energy_weight])
    x = SelectFromIndices()([gsel, x])
    #add more useful things
    coords = AccumulateNeighbours('mean')([coords, gnidx, energy_weight])
    coords = SelectFromIndices()([gsel, coords])

    phys_coords = AccumulateNeighbours('mean')(
        [phys_coords, gnidx, energy_weight])
    phys_coords = SelectFromIndices()([gsel, phys_coords])

    energy = AccumulateNeighbours('sum')([energy, gnidx])
    energy = SelectFromIndices()([gsel, energy])

    #re-build standard feature layout
    out['features'] = selfeat
    out['unproc_features'] = unproc_features
    out['coords'] = coords
    out['phys_coords'] = phys_coords
    out['addfeat'] = GooeyBatchNorm(trainable=trainable)(x)  #norm them
    out['energy'] = energy

    ## all the truth
    for k in orig_inputs.keys():
        if 't_' == k[0:2]:
            out[k] = SelectFromIndices()([gsel, orig_inputs[k]])

    #debug
    if debugplots_after > 0:
        out['coords'] = PlotCoordinates(debugplots_after,
                                        outdir=debug_outdir,
                                        name=name + '_after_red')([
                                            out['coords'], out['energy'],
                                            out['t_idx'], rs
                                        ])

    ######## below is noise classifier

    #this does not work, but also might not be an issue for the studies
    #out['backscatter']=bg

    isnotnoise = Dense(
        1,
        activation='sigmoid',
        trainable=trainable,
        name=name + '_noisescore_d1',
    )(Concatenate()([out['addfeat'], out['coords']]))
    isnotnoise = LLNotNoiseClassifier(
        print_loss=trainable and print_info,
        scale=1.,
        active=trainable,
        record_metrics=record_metrics,
    )([isnotnoise, out['t_idx']])

    unred_rs = rs
    sel, rs, noise_backscatter = NoiseFilter(
        threshold=noise_threshold,  #high signal efficiency filter
        print_reduction=print_info,
        record_metrics=record_metrics)([isnotnoise, rs])

    out['not_noise_score'] = isnotnoise

    for k in out.keys():
        out[k] = SelectFromIndices()([sel, out[k]])

    out['coords'] = LLFillSpace(
        print_loss=trainable and print_info,
        active=trainable,
        record_metrics=record_metrics,
        scale=0.025,  #just mild
        runevery=-1,  #give it a kick only every now and then - that's enough
    )([out['coords'], rs])

    out['scatterids'] = [group_backgather,
                         noise_backscatter]  #add them here directly
    out['orig_t_idx'] = orig_inputs['t_idx']
    out['orig_t_energy'] = orig_inputs['t_energy']  #for validation
    out['orig_dim_coords'] = orig_dim_coords
    out['rs'] = rs
    out['orig_row_splits'] = orig_inputs['row_splits']
    '''
    So we have the following outputs at this stage:
    
    out['group_backgather']
    out['noise_backscatter_N']
    out['noise_backscatter_idx']
    out['orig_t_idx'] 
    out['orig_t_energy'] 
    out['orig_dim_coords']
    out['rs']
    out['orig_row_splits']
    
    out['features'] 
    out['unproc_features']
    out['coords'] 
    out['addfeat']
    out['energy']
    
    '''

    return out
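The noise stage ends by thresholding the sigmoid score: NoiseFilter keeps hits whose not-noise score clears noise_threshold and emits backscatter indices so the selection can be undone later. A bare-bones sketch of the thresholding step, assuming a per-hit score of shape (N, 1) and ignoring row splits (hypothetical helper):

import tensorflow as tf

def noise_filter_sketch(isnotnoise, threshold=0.025):
    # low threshold = high signal efficiency: only confident noise is dropped
    keep = tf.where(isnotnoise[:, 0] > threshold)[:, 0]
    return keep  # indices of surviving hits, usable with tf.gather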
Example #7
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)

    # feat = Lambda(lambda x: tf.squeeze(x,axis=1)) (feat)

    #tf.print([(t.shape, t.name) for t in [feat,  t_idx, t_energy, t_pos, t_time, t_pid, row_splits]])

    #feat,  t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
    #    [feat,  t_idx, t_energy, t_pos, t_time, t_pid]
    #    )

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    #trans_coords = ManualCoordTransform()(orig_coords)
    coords = orig_coords
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    nidx, dist = KNN(K=48)([coords, rs])

    first_nidx = nidx
    first_dist = dist
    first_coords = coords

    dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x)])

    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x_mp = Dense(32, activation='elu')(x_mp)
    x_mp = BatchNormalization(momentum=0.6)(x_mp)
    x = Concatenate()([x, x_mp])

    ###### collect information about the surrounding energy and time distributions per vertex ###
    ncov = Dense(4, kernel_initializer=tf.keras.initializers.Identity())(
        Concatenate()([energy, time, x]))
    ncov = NeighbourCovariance()([coords, dist, ncov, nidx])
    ncov = Dense(24, activation='elu', name='pre_dense_ncov_c')(ncov)
    ncov = BatchNormalization(momentum=0.6)(ncov)
    #should be enough for PCA, total info: 2 * (9+3)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov])
    #x = Dense(64, activation='elu',name='pre_dense_a')(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        #narrow the local scaling down at each iteration as the coords become more abstract
        dist = LocalDistanceScaling(max_scale=10.)(
            [dist, Dense(1)(Concatenate()([x, dist]))])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.2,  #doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        energy = ReduceSumEntirely()(energy)

        #use EdgeConv operation to determine cluster properties
        if True or i:  #the iteration gating (added against OOM) is disabled; this always runs
            x_cl = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
            x_cl = EdgeConvStatic([64, 64, 64],
                                  add_mean=True,
                                  name="ec_static_" + str(i))(x_cl)
            x_cl = Concatenate()([x, x_cl])

        x = Concatenate()([x_cl, StopGradient()(coords)])

        x = Dense(64, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_clc1_' + str(i))(x)

        x = BatchNormalization(momentum=0.6)(x)
        #notice last relu for feature weighting later

        ### now these are the new cluster features, up for the next iteration of building new latent space
        #x = RaggedGlobalExchange()([x,rs])

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=64 + 32 * i,
            n_dimensions=3,
            n_filters=64,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ng_coords = StopGradient()(coords)
        ng_dist = StopGradient()(dist)

        x_gn = Dense(32, activation='elu')(x_gn)

        ### add neighbour summary statistics
        #dist = LocalDistanceScaling(max_scale=10.)([dist, Dense(1)(x_gn)])

        x_ncov = Dense(16)(x)
        x_ncov = NeighbourCovariance()([ng_coords, ng_dist, x_ncov, nidx])
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)

        ### with all this information perform a few message passing steps
        x_mp = x
        x_mp = DistanceWeightedMessagePassing([32, 16,
                                               8])([x_mp, nidx, ng_dist])
        x_mp = Dense(64, activation='elu')(x_mp)
        x_mp = Dense(32, activation='elu')(x_mp)
        x_mp = BatchNormalization(momentum=0.6)(x_mp)

        x = Concatenate()([x, x_mp, x_ncov, x_gn, ng_coords, ng_dist])

        ##### prepare output of this iteration

        x = Dense(64, activation='elu', name='dense_out_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(32, activation='elu', name='dense_out_c_' + str(i))(x)
        #x_r = Concatenate()([x_r, ng_coords, ng_dist])

        #x_r = Concatenate()([StopGradient()(coords),x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(256, activation='elu', name='globalexdense')(x)
    x = RaggedGlobalExchange()([x, row_splits])

    x = Dense(128, activation='elu', name='alldense')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = Concatenate()([first_coords, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-5,
        position_loss_weight=1e-5,  #seems broken
        timing_loss_weight=1e-5,  #1e-3,
        beta_loss_scale=3.,
        repulsion_scaling=1.,
        q_min=2.0,
        use_average_cc_pos=False,
        use_energy_weights=True,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
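A detail worth noting in this model is the StopGradient wrapping around coordinates before they join the feature stack: features may read the learned coordinates, but feature losses cannot deform the coordinate space. Assuming it is a thin layer around tf.stop_gradient:

import tensorflow as tf

class StopGradientSketch(tf.keras.layers.Layer):
    def call(self, inputs):
        # identity in the forward pass; gradients are cut here
        return tf.stop_gradient(inputs)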
Example #8
File: train.py Project: abao1999/DRNOC
def checkids(Inputs):
    return SelectFeatures(5,6)(Inputs[0])
Example #9
def create_default_outputs(raw_inputs,
                           x,
                           x_row_splits,
                           energy_block=True,
                           n_ccoords=2,
                           add_beta=None,
                           add_beta_weight=0.2,
                           use_e_proxy=False,
                           scale_exp_e=True):

    beta = None
    if add_beta is not None:

        #the exact weighting can be learnt, but there has to be a positive correlation
        from tensorflow.keras.constraints import non_neg

        assert add_beta_weight < 1
        n_adds = float(len(add_beta))
        if isinstance(add_beta, list):
            add_beta = Concatenate()(add_beta)
            add_beta = ScalarMultiply(1. / n_adds)(add_beta)
        add_beta = Dense(
            1,
            activation='sigmoid',
            name="predicted_add_beta",
            kernel_constraint=non_neg(),  #maybe it figures it out...?
            kernel_initializer='ones')(add_beta)

        #tf.print(add_beta)

        add_beta = ScalarMultiply(add_beta_weight)(add_beta)

        beta = Dense(1, activation='sigmoid', name="pre_predicted_beta")(x)
        beta = ScalarMultiply(1 - add_beta_weight)(beta)
        beta = Add(name="predicted_beta")([beta, add_beta])

    else:
        beta = Dense(1, activation='sigmoid', name="predicted_beta")(x)

    #x_raw = BatchNormalization(momentum=0.6,name="pre_ccoords_bn")(raw_inputs)
    #pre_ccoords = Dense(64, activation='elu',name="pre_ccords")(Concatenate()([x,x_raw]))
    ccoords = Dense(2, activation=None, name="predicted_ccoords")(x)
    if n_ccoords > 2:
        ccoords = Concatenate()([
            ccoords,
            Dense(n_ccoords - 2, activation=None,
                  name="predicted_ccoords_add")(x)
        ])

    xy = Dense(2,
               activation=None,
               name="predicted_positions",
               kernel_initializer='zeros')(x)
    t = Dense(1,
              activation=None,
              name="predicted_time",
              kernel_initializer='zeros')(x)
    t = ScalarMultiply(1e-9)(t)
    xyt = Concatenate()([xy, t])

    energy = None
    if energy_block:
        e_proxy = None

        energy = indep_energy_block2(x,
                                     SelectFeatures(0, 1)(raw_inputs),
                                     ccoords,
                                     beta,
                                     x_row_splits,
                                     energy_proxy=e_proxy)
    else:
        energy = Dense(1, activation=None)(x)
        if scale_exp_e:
            energy = ExpMinusOne(name='predicted_energy')(energy)
        else:
            energy = ScalarMultiply(100.)(energy)

    #(None, 9) (None, 1) (None, 1) (None, 3) (None, 2)
    print(raw_inputs.shape, beta.shape, energy.shape, xyt.shape, ccoords.shape)
    return Concatenate(name="predicted_final")(
        [raw_inputs, beta, energy, xyt, ccoords])
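The scale_exp_e branch suggests ExpMinusOne maps the unbounded regression output onto a positive-leaning energy scale (near-linear around zero, exponential for large inputs). A sketch assuming it wraps tf.math.expm1, which the name suggests but the snippet does not show:

import tensorflow as tf

class ExpMinusOneSketch(tf.keras.layers.Layer):
    def call(self, inputs):
        # exp(x) - 1, numerically stable for small x
        return tf.math.expm1(inputs)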
Example #10
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    #really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(4, 7)(feat_norm)
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       orig_coords)  #just rotation and scaling

    #see what's there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x = RaggedGlobalExchange()([x, row_splits])
    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA(hidden_nodes=[32, 32, 9])(
        [coords, dist, x_c, nidx])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_c, x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  #make it plottable
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                (Dense(n_dimensions,
                       use_bias=False,
                       name='newcoords' + str(i),
                       kernel_initializer='zeros')(x))
            ])
            nidx, cdist = KNN(K=6 * cluster_neighbours,
                              radius=-1.0)([ccoords, rs])
            #here we use more neighbours to improve learning of the cluster space

        #cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        #what if we let the individual distances scale here? so move hits in and out?

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  #.5
            hier_transforms=[64, 32, 32],
            print_loss=True,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster
        n_energy = BatchNormalization(momentum=0.6)(energy)

        x = Concatenate()([x_cl, n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_clc_b'+str(i))(x)
        x = Dense(64, activation='elu')(x)
        x = BatchNormalization(momentum=0.6)(x)

        nneigh = 64
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=0.6)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8,
                                               8])([x, nidx, dist])
        x_mp = BatchNormalization(momentum=0.6)(x_mp)
        #x_sp=x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=0.6)(x)

        #record more and more the deeper we go
        x_r = x
        energysums.append(MultiBackGather()(
            [energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = RaggedGlobalExchange()([x, row_splits])
    #x = Dropout(0.3)(x)  #force the use of different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    #x = Concatenate()([x]+energysums)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        use_average_cc_pos=0.1,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True,
        kalpha_damping_strength=0.5,  #1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
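A recurring trick here is the zeros-initialized Dense feeding LocalDistanceScaling: training starts from unscaled neighbour distances and only gradually learns to pull hits in or push them out, bounded by max_scale. One plausible bounding choice (an assumption; the actual layer may differ) is:

import tensorflow as tf

def local_distance_scaling_sketch(dist, raw, max_scale=5.0):
    # raw == 0 (the zeros-init starting point) gives a scale of exactly 1,
    # and the scale stays within (1/max_scale, max_scale)
    scale = tf.pow(max_scale, tf.tanh(raw))
    return dist * scale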
Example #11
def gravnet_model(Inputs, nclasses, nregressions, feature_dropout=-1.):
    coords = []
    feats = []

    x = Inputs[0]  #this is the self.x list from the TrainData data structure
    x = CenterPhi(2)(x)
    mask = CreateZeroMask(0)(x)
    x_in = norm_and_mask(x, mask)

    etas_phis = SelectFeatures(1, 3)(
        x)  #eta, phi, just to propagate to the prediction
    r_coordinate = SelectFeatures(4, 5)(x)
    energy = SelectFeatures(0, 1)(x)
    x = Concatenate()([etas_phis, r_coordinate,
                       x])  #just for the kernel initializer

    x = norm_and_mask(x, mask)
    x, coord = GravNet(n_neighbours=40,
                       n_dimensions=3,
                       n_filters=80,
                       n_propagate=16,
                       name='gravnet_pre',
                       fix_coordinate_space=True,
                       also_coordinates=True,
                       masked_coordinate_offset=-10)([x, mask])
    x = norm_and_mask(x, mask)
    coords.append(coord)
    feats.append(x)

    for i in range(n_gravnet_layers):
        x = GlobalExchange()(x)

        x = Dense(64, activation='elu', name='dense_a_' + str(i))(x)
        x = norm_and_mask(x, mask)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        #x = Concatenate()([TransformCoordinates()(x),x])
        x = Dense(64, activation='elu', name='dense_c_' + str(i))(x)
        x = norm_and_mask(x, mask)
        x, coord = GravNet(
            n_neighbours=40,
            n_dimensions=4,
            n_filters=80,
            n_propagate=16,
            name='gravnet_' + str(i),
            also_coordinates=True,
            feature_dropout=feature_dropout,
            masked_coordinate_offset=-10)([
                x, mask
            ])  #shift+activation makes it impossible to mix real with zero-pad
        x = norm_and_mask(x, mask)
        coords.append(coord)
        feats.append(x)

    x = Concatenate()(feats)
    x = Dense(64, activation='elu', name='dense_a_last')(x)
    x = Dense(64, activation='elu', name='dense_b_last')(x)
    x = norm_and_mask(x, mask)
    x = Dense(64, activation='elu', name='dense_c_last')(x)
    x = norm_and_mask(x, mask)

    n_showers = AveragePoolVertices(keepdims=True)(x)
    n_showers = Dense(64, activation='elu',
                      name='dense_n_showers_a')(n_showers)
    n_showers = Dense(1, activation=None, name='dense_n_showers')(n_showers)

    x = Dense(nregressions, activation=None, name='dense_pre_fracs')(x)
    x = Concatenate()([x, x_in])
    x = Dense(64, activation='elu', name='dense_last_correction')(x)
    x = Dense(nregressions,
              activation=None,
              name='dense_fracs',
              kernel_initializer=keras.initializers.RandomNormal(
                  mean=0.0, stddev=0.01))(x)

    x = Concatenate(name="concatlast",
                    axis=-1)([x] + coords + [n_showers] + [etas_phis])
    x = Multiply()([x, mask])
    predictions = [x]
    return Model(inputs=Inputs, outputs=predictions)
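GlobalExchange, used at the top of each block here and in Examples #1 and #2, lets every vertex see an event-level summary. In GravNet-style implementations this is typically the per-event feature mean concatenated back onto each vertex; a sketch for the padded (B, V, F) layout used in this example:

import tensorflow as tf

def global_exchange_sketch(x):
    # append the vertex-mean of every feature to each vertex: (B, V, F) -> (B, V, 2F)
    mean = tf.reduce_mean(x, axis=1, keepdims=True)
    mean = tf.tile(mean, [1, tf.shape(x)[1], 1])
    return tf.concat([x, mean], axis=-1)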
Example #12
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_layers = 4
    feat = [x]
    betas = []
    for i in range(n_gravnet_layers):
        n_filters = 128
        n_propagate = [32, 32, 64, 64, 128, 128]
        n_neighbours = 128
        n_dim = 4 - i
        if n_dim < 2:
            n_dim = 2
        x = Concatenate()([x_basic, x])
        layer = None
        inputs = None
        if i < 2:
            layer = FusedRaggedGravNetLinParse(n_neighbours=n_neighbours,
                                               n_dimensions=n_dim,
                                               n_filters=n_filters,
                                               n_propagate=n_propagate,
                                               name='gravnet_' + str(i))
            inputs = [x, x_row_splits]
        else:
            layer = FusedRaggedGravNetGarNetLike(n_neighbours=n_neighbours,
                                                 n_dimensions=n_dim,
                                                 n_filters=n_filters,
                                                 n_propagate=n_propagate,
                                                 name='gravnet_bounce_' +
                                                 str(i))
            thresh = Dense(1, activation='sigmoid')(x)
            betas.append(thresh)
            inputs = [x, x_row_splits, thresh]

        x, coords = layer(inputs)

        #tf.print('minmax coords',tf.reduce_min(coords),tf.reduce_max(coords))

        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_a")(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_b")(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)

        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    x = RaggedGlobalExchange(name="global_exchange_bottom_" +
                             str(i))([x, x_row_splits])
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(128, activation='elu', name="dense_last_a2")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    beta = ScalarMultiply(3 / 4)(Dense(1,
                                       activation='sigmoid',
                                       name="dense_beta")(x))
    add_beta = ScalarMultiply(1 / (4. * float(len(betas))))(Add()(betas))
    beta = Add()([beta, add_beta])

    xy = Dense(2, activation=None, name="dense_xy",
               kernel_initializer='zeros')(x)
    t = ScalarMultiply(1e-9)(Dense(1,
                                   activation=None,
                                   name="dense_t",
                                   kernel_initializer='zeros')(x))
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)

    energy = indep_energy_block2(x,
                                 SelectFeatures(0, 1)(x_basic), ccoords, beta,
                                 x_row_splits)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
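The beta blend in this example is a convex combination: 3/4 of the final sigmoid plus 1/4 shared equally among the len(betas) per-layer thresholds, so the result is bounded by 0.75 + n/(4n) = 1 and remains a valid condensation weight. The ScalarMultiply helper doing the weighting is presumably a constant-factor layer; a sketch:

import tensorflow as tf

class ScalarMultiplySketch(tf.keras.layers.Layer):
    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        # multiply by a fixed, non-trainable scalar
        return self.factor * inputs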
Example #13
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    allfeat = [feat_norm]
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)
    coords = orig_coords
    coords = Dense(3)(coords)  #just rotation and scaling
    #add gradient

    #see whats there
    nidx, dist = KNN(K=96, radius=1.0)([coords, rs])
    x = Dense(4, activation='elu',
              name='pre_pre_dense')(x)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = SoftPixelCNN(mode='full', subdivisions=4,
                       name='prePCNN')([coords, x, dist, nidx])
    x = Concatenate()([x, x_c])
    #this is going to be among the most expensive operations:
    x = Dense(128, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=0.6)(x)

    #allfeat.append(Dense(16, activation='elu',name='feat_compress_pre')(x))
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    total_iterations = 5

    for i in range(total_iterations):

        #cluster first
        hier = Dense(1, activation='sigmoid')(x)
        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx = LocalClusterReshapeFromNeighbours(
            K=8,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.3,
            print_loss=True,
            name='clustering_' + str(i))(
                [x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, t_idx])
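        # duplicated inputs are intentional: the leading x is neighbour-reshaped
        # into x_cl, the trailing sel_gidx/energy/x/t_idx are returned re-indexed
        # to the clustered vertices, and the very last t_idx is the truth index
        # consumed by the layer's clustering loss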

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster

        gatherids.append(bidxs)
        x_cl = Dense(128, activation='elu', name='dense_clc_' + str(i))(x_cl)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x, x_cl, n_energy])

        pixelcompress = 4
        nneigh = 32 + 4 * i
        nfilt = 32 + 4 * i
        nprop = 32
        n_dimensions = 3  #make it plottable

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        subdivisions = 4

        #add more shape information
        x_sp = Dense(pixelcompress, activation='elu')(x)
        x_sp = BatchNormalization(momentum=0.6)(x_sp)
        x_sp = SoftPixelCNN(mode='full',
                            subdivisions=4)([coords, x_sp, dist, nidx])
        x_sp = Dense(128, activation='elu', name='dense_spc_' + str(i))(x_sp)

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(nfilt, activation='elu', name='dense_mpc_' + str(i))(x_mp)

        x = Concatenate()([x, x_sp, x_gn, x_mp])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        x = Dense(64, activation='elu', name='dense_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        x_r = x
        #record more and more the deeper we go
        if i < total_iterations - 1:
            x_r = Dense(12 * (i + 1),
                        activation='elu',
                        name='dense_rec_' + str(i))(x)
        else:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.3)(x)  #force the network to use different information sources
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([x, energy])

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-4,
        position_loss_weight=1e-2,
        timing_loss_weight=1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=1,
        alt_potential_norm=True)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids + backgathered_coords)
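The MultiBackGather pattern used above can be illustrated in plain numpy: each clustering step records, for every vertex of the previous level, the index of the cluster it was merged into, and back-gathering applies those index lists in reverse to broadcast cluster-level values back to every original vertex. A sketch of that idea, under the assumption that MultiBackGather follows exactly this convention (the repository implementation may differ in details):

import numpy as np

def multi_back_gather(values, gather_idxs):
    # values: array at the deepest (most clustered) level, shape (n_deep, F)
    # gather_idxs: one index array per clustering step, recorded shallow-to-deep;
    # gather_idxs[k][v] is the cluster that vertex v was merged into at step k
    for idx in reversed(gather_idxs):
        values = values[idx]
    return values

# toy example: 6 vertices -> 3 clusters -> 2 clusters
g0 = np.array([0, 0, 1, 1, 2, 2])  # step 0: which of 3 clusters each of 6 vertices joins
g1 = np.array([0, 0, 1])           # step 1: which of 2 clusters each of 3 clusters joins
cluster_energy = np.array([[10.], [20.]])  # per-cluster value at the deepest level
print(multi_back_gather(cluster_energy, [g0, g1]))
# -> each of the 6 original vertices receives its final cluster's energy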
Example #14
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)
    energy = SelectFeatures(0, 1)(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    #n_cluster_coords=6

    x = feat  #Dense(64,activation='elu')(feat)

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_en = []

    allfeat = [x]

    sel_gidx = gidx_orig
    for i in range(6):

        pixelcompress = 16 + 8 * i
        nneigh = 8 + 32 * i
        n_dimensions = int(4 + i)

        x = BatchNormalization(momentum=0.6)(x)
        #do some processing
        x = Dense(64, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=64,
                                              n_propagate=64)([x, rs])
        x = Concatenate()([x, MessagePassing([64, 32, 16, 8])([x, nidx])])

        #allfeat.append(x)

        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        x_sp = SoftPixelCNN(length_scale=1.)([coords, x, nidx])
        x = Concatenate()([x, x_sp])
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        hier = Dense(1, activation='sigmoid',
                     name='hier_' + str(i))(x)  #clustering hierarchy

        dist = ScalarMultiply(1 / (3. * float(i + 0.5)))(dist)
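        # shrink clustering distances with depth: 1/(3*(i+0.5)) is ~0.67 at i=0
        # and falls to ~0.06 at i=5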

        x, rs, bidxs, t_idx = LocalClusterReshapeFromNeighbours(
            K=5,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.5,
            print_loss=True)([x, dist, hier, nidx, rs, t_idx, t_idx])

        print('>>>>x0', x.shape)
        gatherids.append(bidxs)
        if addBackGatherInfo:
            backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))

        print('>>>>x1', x.shape)
        x = BatchNormalization(momentum=0.6)(x)
        x_record = Dense(2 * pixelcompress, activation='elu')(x)
        x_record = BatchNormalization(momentum=0.6)(x_record)
        allfeat.append(MultiBackGather()([x_record, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-3,
        position_loss_weight=1e-3,
        timing_loss_weight=1e-3,
    )([
        pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
        orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
        row_splits
    ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids)
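All of these models consume a flat vertex tensor plus row splits that mark event boundaries, which is exactly the tf.RaggedTensor encoding. A small sketch of that bookkeeping, independent of the custom layers:

import tensorflow as tf

# three events with 3, 1 and 2 vertices, each vertex carrying 2 features
flat = tf.constant([[1., 0.], [2., 0.], [3., 0.],
                    [4., 1.], [5., 2.], [6., 2.]])
row_splits = tf.constant([0, 3, 4, 6], dtype=tf.int64)

ragged = tf.RaggedTensor.from_row_splits(flat, row_splits)
print(ragged.shape)          # (3, None, 2): batch of variable-length events
print(ragged.row_lengths())  # [3, 1, 2]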
Example #15
def normalise_but_energy(x, name, momentum=0.6):
    e = SelectFeatures(0, 1)(x)
    r = SelectFeatures(1, x.shape[-1])(x)
    r = BatchNormalization(momentum=momentum, name=name)(r)
    return Concatenate()([e, r])
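normalise_but_energy keeps the first feature column (the energy, by the SelectFeatures(0, 1) convention used throughout) out of the batch normalisation and normalises everything else. The column bookkeeping, as a plain numpy sketch:

import numpy as np

x = np.array([[50., 1., 2.],   # column 0: energy, columns 1+: other features
              [80., 3., 4.]])
e = x[:, 0:1]                  # what SelectFeatures(0, 1) picks out
r = x[:, 1:x.shape[-1]]        # what SelectFeatures(1, x.shape[-1]) picks out
# ... r would be batch-normalised here ...
print(np.concatenate([e, r], axis=-1))  # energy column restored in front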
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.identity())(coords)
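    # net effect: a residual coordinate transform that starts out near the
    # identity (0.1-scaled learned correction on top of the original coords,
    # followed by an identity-initialised linear layer)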

    first_coords = coords

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])

    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()(
        [coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  #doesn't really have an effect because of the local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
        xec = EdgeConvStatic([32, 32, 32],
                             name="ec_static_" + str(i))(x_cl_rs)
        x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        #notice last relu for feature weighting later

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu',
                    name='dense_out_c_' + str(i))(x)
        #coords_nograd = StopGradient()(coords)
        #x_r = Concatenate()([coords_nograd,x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.2)(x)
    x_mp = DistanceWeightedMessagePassing([32, 32,
                                           32])([x, first_nidx, first_dist])
    x = Concatenate()([x, x_mp])

    x = Dense(128, activation='elu', name='alldense')(x)
    # TO BE ADDED WITH E LOSS x = Concatenate()([x,energy])
    #x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  #seems broken
        timing_loss_weight=0.,  #1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[
                                    pred_beta, pred_ccoords, pred_energy,
                                    pred_pos, pred_time, pred_id, rs
                                ] + backgatheredids + backgathered_coords)
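The models return per-vertex beta values and clustering coordinates (pred_beta, pred_ccoords) trained with the object-condensation loss. A hedged numpy sketch of the usual inference step that turns them into clusters: repeatedly pick the highest-beta unassigned vertex above a threshold t_beta as a condensation point and absorb everything within distance t_d in clustering space. The thresholds and details are illustrative, not taken from this repository:

import numpy as np

def condensation_clusters(betas, ccoords, t_beta=0.1, t_d=0.5):
    # betas: (V,) condensation score per vertex; ccoords: (V, C) clustering coords
    assignment = np.full(len(betas), -1, dtype=np.int64)
    unassigned = np.ones(len(betas), dtype=bool)
    cluster = 0
    order = np.argsort(-betas)  # visit vertices by decreasing beta
    for v in order:
        if not unassigned[v] or betas[v] < t_beta:
            continue
        d = np.linalg.norm(ccoords - ccoords[v], axis=-1)
        members = unassigned & (d < t_d)
        assignment[members] = cluster
        unassigned[members] = False
        cluster += 1
    return assignment  # -1 marks noise / unclustered vertices

betas = np.array([0.9, 0.8, 0.05, 0.7])
ccoords = np.array([[0., 0.], [0.1, 0.], [50., 50.], [5., 5.]])
print(condensation_clusters(betas, ccoords))  # -> [0 0 -1 1]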