Example #1
def create_outputs(x,
                   feat,
                   energy=None,
                   n_ccoords=3,
                   n_classes=4,
                   td=TrainData_NanoML(),
                   add_features=True,
                   fix_distance_scale=False,
                   scale_energy=False,
                   energy_factor=True,
                   energy_proxy=None,
                   name_prefix="output_module"):
    '''
    returns pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id
    '''
    assert scale_energy != energy_factor, "exactly one of scale_energy / energy_factor must be set"

    feat = td.createFeatureDict(feat)

    pred_beta = Dense(1, activation='sigmoid', name=name_prefix + '_beta')(x)
    pred_ccoords = Dense(
        n_ccoords,
        #this initialisation is much better than standard glorot
        kernel_initializer=EyeInitializer(stddev=0.001),
        use_bias=False,
        name=name_prefix + '_clustercoords')(x)  #bias has no effect

    if energy_proxy is None:
        energy_proxy = x
    else:
        energy_proxy = Concatenate()([energy_proxy, x])
    energy_act = None
    if energy_factor:
        energy_act = 'relu'
    pred_energy = Dense(
        1,
        name=name_prefix + '_energy',
        bias_initializer='ones',  #no effect if full scale, useful if corr factor
        activation=energy_act)(energy_proxy)
    if scale_energy:
        pred_energy = ScalarMultiply(10.)(pred_energy)
    if energy is not None:
        pred_energy = Multiply()([pred_energy, energy])

    pred_pos = Dense(2, use_bias=False, name=name_prefix + '_pos')(x)
    pred_time = ScalarMultiply(10.)(Dense(1)(x))

    if add_features:
        pred_pos = Add()([feat['recHitXY'], pred_pos])
    pred_id = Dense(n_classes,
                    activation="softmax",
                    name=name_prefix + '_class')(x)

    pred_dist = OnesLike()(pred_time)
    if not fix_distance_scale:
        pred_dist = ScalarMultiply(2.)(Dense(1,
                                             activation='sigmoid',
                                             name=name_prefix + '_dist')(x))
        #this needs to be bound, otherwise it is fully anti-correlated with the coordinate scale
    return pred_beta, pred_ccoords, pred_dist, pred_energy, pred_pos, pred_time, pred_id
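A minimal usage sketch (not from the original source): it assumes the custom layers (ScalarMultiply, EyeInitializer, OnesLike) and TrainData_NanoML are importable from the surrounding HGCalML-style codebase, and the feature widths below are made up.

import tensorflow as tf

x_in = tf.keras.Input(shape=(64,))     # hypothetical per-hit latent features
feat_in = tf.keras.Input(shape=(10,))  # hypothetical raw feature block for createFeatureDict

# returns a 7-tuple: beta, ccoords, dist, energy, pos, time, id
preds = create_outputs(x_in, feat_in)
model = tf.keras.Model(inputs=[x_in, feat_in], outputs=list(preds))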
Example #2
def create_outputs(x, feat, energy=None, n_ccoords=3, n_classes=6, td=TrainData_OC(), add_features=True):
    '''
    returns pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id
    '''
    
    feat = td.createFeatureDict(feat)
    
    pred_beta = Dense(1, activation='sigmoid')(x)
    pred_ccoords = Dense(n_ccoords,
                         #this initialisation is much better than standard glorot
                         kernel_initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1./float(x.shape[-1])),
                         use_bias=False)(x) #bias has no effect
    
    pred_energy = ScalarMultiply(10.)(Dense(1)(x))
    if energy is not None:
        pred_energy = Multiply()([pred_energy, energy])

    pred_pos = Dense(2)(x)
    pred_time = ScalarMultiply(10.)(Dense(1)(x))
    if add_features:
        pred_pos = Add()([feat['recHitXY'], pred_pos])
        pred_time = Add()([feat['recHitTime'], pred_time])
    pred_id = Dense(n_classes, activation="softmax")(x)
    
    return pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id
Example #3
def create_output_layers(x, x_row_splits, n_ccoords=2,
                           add_beta=None, add_beta_weight=0.2, use_e_proxy=False,
                           scale_exp_e=True, n_classes=0):
    beta = None
    if add_beta is not None:

        # the exact weighting can be learnt, but there has to be a positive correlation
        from tensorflow.keras.constraints import non_neg

        assert add_beta_weight < 1
        if isinstance(add_beta, list):
            n_adds = float(len(add_beta))
            add_beta = Concatenate()(add_beta)
            add_beta = ScalarMultiply(1. / n_adds)(add_beta)
        add_beta = Dense(1, activation='sigmoid', name="predicted_add_beta",
                         kernel_constraint=non_neg(),  # maybe it figures it out...?
                         kernel_initializer='ones'
                         )(add_beta)

        # tf.print(add_beta)

        add_beta = ScalarMultiply(add_beta_weight)(add_beta)

        beta = Dense(1, activation='sigmoid', name="pre_predicted_beta")(x)
        beta = ScalarMultiply(1 - add_beta_weight)(beta)
        beta = Add(name="predicted_beta")([beta, add_beta])

    else:
        beta = Dense(1, activation='sigmoid', name="predicted_beta")(x)

    # x_raw = BatchNormalization(momentum=0.6,name="pre_ccoords_bn")(raw_inputs)
    # pre_ccoords = Dense(64, activation='elu',name="pre_ccords")(Concatenate()([x,x_raw]))
    ccoords = Dense(2, activation=None, name="predicted_ccoords")(x)
    if n_ccoords > 2:
        ccoords = Concatenate()([ccoords, Dense(n_ccoords - 2, activation=None, name="predicted_ccoords_add")(x)])

    xy = Dense(2, activation=None, name="predicted_positions", kernel_initializer='zeros')(x)
    t = Dense(1, activation=None, name="predicted_time", kernel_initializer='zeros')(x)
    t = ScalarMultiply(1e-9)(t)
    xyt = Concatenate()([xy, t])

    energy = Dense(1, activation=None)(x)
    if scale_exp_e:
        energy = ExpMinusOne(name='predicted_energy')(energy)
    else:
        energy = ScalarMultiply(100.)(energy)

    if n_classes > 0:
        classes_scores = Dense(n_classes, activation=None, name="predicted_classification_scores")(x)
       
        return Concatenate(name="predicted_final")([beta, energy, xyt, ccoords, classes_scores])
    else:
        return Concatenate(name="predicted_final")([beta, energy, xyt, ccoords])
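Since create_output_layers packs everything into a single "predicted_final" tensor, downstream code has to slice it by the concatenation order (beta, energy, xyt, ccoords, optional class scores). A small helper, written here as a sketch against that order:

def split_predicted_final(pred, n_ccoords=2, n_classes=0):
    # slice order matches the Concatenate above: beta | energy | xyt | ccoords | classes
    beta = pred[..., 0:1]
    energy = pred[..., 1:2]
    xyt = pred[..., 2:5]                      # x, y, t
    ccoords = pred[..., 5:5 + n_ccoords]
    if n_classes > 0:
        classes_scores = pred[..., 5 + n_ccoords:5 + n_ccoords + n_classes]
        return beta, energy, xyt, ccoords, classes_scores
    return beta, energy, xyt, ccoords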
Example #4
def output_block(x, ids, energy_raw):
    p_beta = Dense(1, activation='sigmoid')(x)
    p_tpos = ScalarMultiply(10.)(Dense(2)(x))
    p_ID = Dense(2, activation='softmax')(x)

    p_E = Dense(1)(x)
    p_ccoords = ScalarMultiply(10.)(Dense(2)(x))

    predictions = Concatenate()(
        [p_beta, p_E, p_tpos, p_ID, p_ccoords, ids, energy_raw])

    print('predictions', predictions.shape)
    return predictions
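The same unpacking idea applies to output_block; here the fixed widths are 1+1+2+2+2, while the trailing ids and energy_raw blocks keep whatever width the caller passed in. A sketch, not part of the original code:

def split_output_block(pred):
    p_beta = pred[..., 0:1]
    p_E = pred[..., 1:2]
    p_tpos = pred[..., 2:4]
    p_ID = pred[..., 4:6]
    p_ccoords = pred[..., 6:8]
    passthrough = pred[..., 8:]  # ids + energy_raw, widths set by the caller
    return p_beta, p_E, p_tpos, p_ID, p_ccoords, passthrough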
Example #5
def my_model(Inputs, momentum=0.6):

    feat, mask, hitmatched = Inputs[0], Inputs[1], Inputs[2]

    x = Concatenate()([feat, mask, hitmatched])
    x = BatchNormalization(momentum=momentum)(x)

    x = Dense(32, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(32, activation='elu')(x)
    x = BatchNormalization(momentum=momentum)(x)
    #x = TestLayer()(x)
    allx = []
    for i in range(8):
        x = GravNet_simple(n_neighbours=12,
                           n_dimensions=3,
                           n_filters=64,
                           n_propagate=24)(x)
        x = BatchNormalization(momentum=momentum)(x)
        x = Dropout(0.05)(x)
        allx.append(x)

    x = Concatenate()(allx)
    x = Dense(32, activation='tanh')(x)

    #predict a correction factor for px, py, pz, so 3 values
    # sigmoid output scaled by 2, so each correction factor ranges between 0 and 2
    correction = ScalarMultiply(.1)(x)
    correction = Dense(3, activation='sigmoid', use_bias=False)(correction)
    correction = ScalarMultiply(2.)(correction)
    confidence = Dense(3, activation='sigmoid')(x)

    correction = ReduceMeanVertices()(correction)
    confidence = ReduceMeanVertices()(confidence)

    predictions = Concatenate()([correction, confidence])
    return tf.keras.models.Model(inputs=Inputs, outputs=[predictions])
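A hedged sketch of instantiating my_model: the three inputs are per-vertex tensors (the feature width 16 is an assumption), and GravNet_simple, ScalarMultiply and ReduceMeanVertices are assumed to come from the surrounding codebase.

import tensorflow as tf

Inputs = [
    tf.keras.Input(shape=(None, 16)),  # feat (hypothetical width)
    tf.keras.Input(shape=(None, 1)),   # mask
    tf.keras.Input(shape=(None, 1)),   # hitmatched
]
m = my_model(Inputs)
# placeholder loss; the real training objective lives elsewhere in the codebase
m.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='mse')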
Example #6
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_layers = 4
    feat = []
    for i in range(n_gravnet_layers):
        n_filters = 128
        n_propagate = 96
        n_neighbours = 200

        x = RaggedGlobalExchange(name="global_exchange_" +
                                 str(i))([x, x_row_splits])
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_a")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_b")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = FusedRaggedGravNet_simple(n_neighbours=n_neighbours,
                                      n_dimensions=4,
                                      n_filters=n_filters,
                                      n_propagate=n_propagate,
                                      name='gravnet_' +
                                      str(i))([x, x_row_splits])
        x = BatchNormalization(momentum=0.6)(x)
        feat.append(
            Dense(128, activation='elu', name="dense_compress_" + str(i))(x))

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    beta = Dense(1, activation='sigmoid', name="dense_beta")(x)

    xy = Dense(2, activation=None, name="dense_xy")(x)
    t = ScalarMultiply(1e-9)(Dense(1, activation=None, name="dense_t")(x))
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)

    x_en = Dense(64, activation='elu', name="dense_en_a")(
        x)  #herer so the other names remain the same
    input_energy = SelectFeatures(0, 1, name="select_en")(input_features)
    energy_condensates, idxs = CondensateAndSum(
        radius=0.5, min_beta=0.1,
        name="condensate_en")([ccoords, beta, input_energy, x_row_splits])
    x_en = Concatenate(name="concat_en_cond")(
        [ScalarMultiply(10, name="multi_en")(x_en), energy_condensates])
    x_en = Dense(64, activation='elu', name="dense_en_b")(x_en)
    energy = Dense(1, activation=None, name="dense_en_final")(x_en)

    print('input_features', input_features.shape)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
Example #7
def model(Inputs, feature_dropout=-1., momentum=0.6):

    x = Inputs[0]  #this is the self.x list from the TrainData data structure
    x_in = BatchNormalization(momentum=momentum)(x)
    x = x_in
    feat = []

    x = reduce_block(x,
                     depth=4,
                     pooling=2,
                     kernel=(4, 4),
                     filters=[32, 48, 64, 96])  # -> 4x4 grid
    x = Concatenate()([x_in, x])
    x = Conv2D(64, (3, 3), padding='same', activation='elu')(x)
    x = Conv2DGlobalExchange()(x)
    x = BatchNormalization(momentum=momentum)(x)

    x = reduce_block(x,
                     depth=4,
                     pooling=2,
                     kernel=(4, 4),
                     filters=[32, 48, 64, 96])
    x = Concatenate()([x_in, x])
    x = Conv2D(64, (4, 4), padding='same', activation='elu')(x)
    x = Conv2DGlobalExchange()(x)
    x = BatchNormalization(momentum=momentum)(x)

    x = reduce_block(x,
                     depth=4,
                     pooling=2,
                     kernel=(4, 4),
                     filters=[32, 48, 64, 96])
    x = Concatenate()([x_in, x])
    x = Conv2D(64, (5, 5), padding='same', activation='elu')(x)
    x = Conv2DGlobalExchange()(x)
    x = BatchNormalization(momentum=momentum)(x)

    x = Dense(128, activation='elu')(x)
    x = Dropout(0.02)(x)
    x = Dense(64, activation='elu')(x)
    x = Dense(64, activation='elu')(x)
    '''
    p_beta   =  tf.reshape(pred[:,:,:,0:1], [pred.shape[0],pred.shape[1]*pred.shape[2],-1])
    p_tpos   =  tf.reshape(pred[:,:,:,1:3], [pred.shape[0],pred.shape[1]*pred.shape[2],-1])
    p_ID     =  tf.reshape(pred[:,:,:,3:6], [pred.shape[0],pred.shape[1]*pred.shape[2],-1])
    p_dim    =  tf.reshape(pred[:,:,:,6:8], [pred.shape[0],pred.shape[1]*pred.shape[2],-1])
    p_object  = pred[:,0,0,8]
    p_ccoords = tf.reshape(pred[:,:,:,8:10], [pred.shape[0],pred.shape[1]*pred.shape[2],-1])
                 
    '''

    p_beta = Conv2D(
        1,
        (1, 1),
        padding='same',
        activation='sigmoid',
        #kernel_initializer='zeros',
        trainable=True)(x)
    p_tpos = ScalarMultiply(64.)(Conv2D(2, (1, 1), padding='same')(x))
    p_ID = Conv2D(3, (1, 1), padding='same', activation='softmax')(x)
    p_dim = ScalarMultiply(32.)(Conv2D(2, (1, 1), padding='same')(x))
    #p_object  = Conv2D(1, (1,1), padding='same')(x)
    p_ccoords = Conv2D(2, (1, 1), padding='same')(x)

    predictions = Concatenate()([
        p_beta,
        p_tpos,
        p_ID,
        p_dim,
        #p_object ,
        p_ccoords
    ])

    print('predictions', predictions.shape)

    return Model(inputs=Inputs, outputs=predictions)
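For orientation, the model can be instantiated on a hypothetical grid; the "-> 4x4 grid" comment in the first reduce_block suggests a 64x64 input (pooling 2 at depth 4 divides each side by 16). The input size and channel count below are assumptions, and reduce_block, Conv2DGlobalExchange and ScalarMultiply must come from the surrounding codebase.

import tensorflow as tf

Inputs = [tf.keras.Input(shape=(64, 64, 4))]  # hypothetical 64x64 grid, 4 channels
m = model(Inputs)
m.summary()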
Example #8
def gravnet_model(Inputs, feature_dropout=-1., momentum=0.6):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=momentum)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_layers = 7
    x = RaggedGlobalExchange(name="global_exchange_inputs")([x, x_row_splits])
    feat = [x]

    rs = x_row_splits
    allscatter = []
    proto = []
    all_beta_p = []
    beta_proto = []
    for i in range(n_gravnet_layers):

        n_filters = 128
        n_propagate = 96
        n_neighbours = 256

        x = Dense(64 + 8 * i, activation='elu',
                  name="dense_" + str(i) + "_a")(x)
        x = Dense(64 + 8 * i, activation='elu',
                  name="dense_" + str(i) + "_b")(x)
        x = Dense(64 + 8 * i, activation='elu',
                  name="dense_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=momentum)(x)
        x = FusedRaggedGravNet_simple(n_neighbours=n_neighbours,
                                      n_dimensions=4,
                                      n_filters=n_filters,
                                      n_propagate=n_propagate,
                                      name='gravnet_' + str(i))([x, rs])
        x = BatchNormalization(momentum=momentum)(x)

        if i < 2:
            feat.append(x)
            continue

        beta_p = Dense(1, activation='sigmoid', name='sel_x_' + str(i))(x)

        proto.append(x)
        beta_proto.append(beta_p)
        #tf.print('pre Sel')
        thresh = 0.5  #+0.1*i
        pre_rs = rs
        x, rs, scatter_idxs = RaggedSelectThreshold(thresh,
                                                    name='sel_thresh_' +
                                                    str(i))([beta_p, x, rs])
        beta_p, _, _ = RaggedSelectThreshold(thresh,
                                             name='sel_thresh_b_p_' +
                                             str(i))([beta_p, beta_p, pre_rs])
        #tf.print('post Sel')
        allscatter.append(scatter_idxs)

        print([p.shape for p in proto])
        x_scat = x
        b_scat = beta_p
        for k in range(len(allscatter)):
            l = len(proto) - k - 1
            x_scat = VertexScatterer(name='scat_' + str(i) + "_" +
                                     str(k))([x_scat, allscatter[l], proto[l]])
            b_scat = VertexScatterer(name='scat_beta_' + str(i) + "_" +
                                     str(k))([
                                         b_scat, allscatter[l], beta_proto[l]
                                     ])

        all_beta_p.append(b_scat)
        feat.append(x_scat)

    x = Concatenate(name="concat_gravout")(feat)

    x = Dense(256, activation='elu', name="dense_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a2")(x)

    x = FusedRaggedGravNet_simple(n_neighbours=n_neighbours,
                                  n_dimensions=2,
                                  n_filters=128,
                                  n_propagate=64,
                                  name='gravnet_last')([x, x_row_splits])

    x = Dense(128, activation='elu', name="dense_last_a3")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    all_beta_pt = Add()(all_beta_p)
    all_beta_p = ScalarMultiply(1 / (2. * len(all_beta_p)))(all_beta_pt)

    beta = ScalarMultiply(1 / 2.)(Dense(1,
                                        activation='sigmoid',
                                        name="dense_beta")(x))
    beta = Add()([all_beta_p, beta])

    xy = Dense(2, activation=None, name="dense_xy",
               kernel_initializer='zeros')(x)
    t = Dense(1, activation=None, name="dense_t",
              kernel_initializer='zeros')(x)
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)
    energy = Dense(1, activation=None, name="dense_en_final")(x)
    energy = ExpMinusOne(name="en_scaling")(energy)

    print('input_features', input_features.shape)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
Example #9
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_blocks = 3
    feat = [x]
    all_beta_p = []
    for i in range(n_gravnet_blocks):
        n_filters = 128
        n_propagate = [128, 64, 32, 32, 16, 16, 8, 8, 4, 4]
        n_neighbours = 128

        x, coords = FusedRaggedGravNet(n_neighbours=n_neighbours,
                                       n_dimensions=4,
                                       n_filters=n_filters,
                                       n_propagate=n_propagate,
                                       name='gravnet_' +
                                       str(i))([x, x_row_splits])

        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_a")(x)
        x = BatchNormalization(momentum=0.6)(x)
        feat.append(x)

        beta_p = Dense(1, activation='sigmoid', name='sel_x_' + str(i))(x)
        all_beta_p.append(beta_p)

        threshold = 0.3 * (i + 1)

        x = BatchNormalization(momentum=0.6)(x)
        x, coords = FusedMaskedRaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=4,
            n_filters=n_filters,
            n_propagate=128,
            direction='acc',  ##accumulate in potential condensation points
            threshold=threshold,
            ex_mode='xor',
            name='gravnet_masked_up_' + str(i))([x, x_row_splits, beta_p])

        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_b")(x)
        x = BatchNormalization(momentum=0.6)(x)
        feat.append(x)

        x = BatchNormalization(momentum=0.6)(x)
        x, coords = FusedMaskedRaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=4,
            n_filters=n_filters,
            n_propagate=n_propagate,
            direction='acc',  ##accumulate in potential condensation points
            threshold=threshold,
            ex_mode='and',
            name='gravnet_masked_up_ex_' + str(i))([x, x_row_splits, beta_p])

        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)
        feat.append(x)

        x = BatchNormalization(momentum=0.6)(x)
        x, coords = FusedMaskedRaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=4,
            n_filters=n_filters,
            n_propagate=128,
            direction='scat',  ##accumulate in potential condensation points
            threshold=threshold,
            ex_mode='xor',
            name='gravnet_masked_down_' + str(i))([x, x_row_splits, beta_p])

        x = RaggedGlobalExchange(name="global_exchange_bottom_" +
                                 str(i))([x, x_row_splits])
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_d")(x)
        x = BatchNormalization(momentum=0.6)(x)

        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    all_beta_pt = Add()(all_beta_p)
    all_beta_p = ScalarMultiply(1 / (2. * len(all_beta_p)))(all_beta_pt)

    beta = ScalarMultiply(1 / 2.)(Dense(1,
                                        activation='sigmoid',
                                        name="dense_beta")(x))
    beta = Add()([all_beta_p, beta])

    xy = Dense(2, activation=None, name="dense_xy",
               kernel_initializer='zeros')(x)
    t = ScalarMultiply(1e-9)(Dense(1,
                                   activation=None,
                                   name="dense_t",
                                   kernel_initializer='zeros')(x))
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)
    energy = Dense(1, activation=None, name="dense_en_final")(x)
    energy = ExpMinusOne(name="en_scaling")(energy)

    print('input_features', input_features.shape)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
Example #10
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_layers = 6

    feat = []
    coords = []
    for i in range(n_gravnet_layers):

        n_filters = 128
        n_propagate = 96
        n_neighbours = 256
        n_dim = 4 - int(i / 2)
        if n_dim < 2:
            n_dim = 2
            n_propagate = 4 * [32]

        x = RaggedGlobalExchange(name="global_exchange_" +
                                 str(i))([x, x_row_splits])
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_a")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_b")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)

        x, c = FusedRaggedGravNet(n_neighbours=n_neighbours,
                                  n_dimensions=n_dim,
                                  n_filters=n_filters,
                                  n_propagate=n_propagate,
                                  name='gravnet_' + str(i))([x, x_row_splits])
        #
        #
        # needs a larger RF branch somewhere here
        #
        #

        x = BatchNormalization(momentum=0.6)(x)
        feat.append(x)
        if n_dim == 2:
            print('add coords for GN', i)
            coords.append(c)

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    beta = Dense(1, activation='sigmoid', name="dense_beta")(x)
    xy = Dense(2, activation=None, name="dense_xy",
               kernel_initializer='zeros')(x)
    t = ScalarMultiply(1e-9)(Dense(1,
                                   activation=None,
                                   name="dense_t",
                                   kernel_initializer='zeros')(x))
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)

    n_cc = len(coords)
    coords = Add()(coords)
    coords = ScalarMultiply(1 / (4 * n_cc))(coords)
    ccoords = Add()([ScalarMultiply(3. / 4.)(ccoords), coords])

    energy = Dense(1, activation=None, name="dense_en_final")(x)
    energy = ExpMinusOne(name="en_scaling")(energy)

    print('input_features', input_features.shape)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
Example #11
def gravnet_model(Inputs, feature_dropout=-1.):
    nregressions = 5

    # I_data = tf.keras.Input(shape=(num_features,), dtype="float32")
    # I_splits = tf.keras.Input(shape=(1,), dtype="int32")

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    n_filters = 0
    n_gravnet_layers = 4
    feat = [x]
    betas = []
    for i in range(n_gravnet_layers):
        n_filters = 128
        n_propagate = [32, 32, 64, 64, 128, 128]
        n_neighbours = 128
        n_dim = 4 - i
        if n_dim < 2:
            n_dim = 2
        x = Concatenate()([x_basic, x])
        layer = None
        inputs = None
        if i < 2:
            layer = FusedRaggedGravNetLinParse(n_neighbours=n_neighbours,
                                               n_dimensions=n_dim,
                                               n_filters=n_filters,
                                               n_propagate=n_propagate,
                                               name='gravnet_' + str(i))
            inputs = [x, x_row_splits]
        else:
            layer = FusedRaggedGravNetGarNetLike(n_neighbours=n_neighbours,
                                                 n_dimensions=n_dim,
                                                 n_filters=n_filters,
                                                 n_propagate=n_propagate,
                                                 name='gravnet_bounce_' +
                                                 str(i))
            thresh = Dense(1, activation='sigmoid')(x)
            betas.append(thresh)
            inputs = [x, x_row_splits, thresh]

        x, coords = layer(inputs)

        #tf.print('minmax coords',tf.reduce_min(coords),tf.reduce_max(coords))

        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_a")(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_b")(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)

        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    x = RaggedGlobalExchange(name="global_exchange_bottom_" +
                             str(i))([x, x_row_splits])
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(128, activation='elu', name="dense_last_a2")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    beta = ScalarMultiply(3 / 4)(Dense(1,
                                       activation='sigmoid',
                                       name="dense_beta")(x))
    add_beta = ScalarMultiply(1 / (4. * float(len(betas))))(Add()(betas))
    beta = Add()([beta, add_beta])

    xy = Dense(2, activation=None, name="dense_xy",
               kernel_initializer='zeros')(x)
    t = ScalarMultiply(1e-9)(Dense(1,
                                   activation=None,
                                   name="dense_t",
                                   kernel_initializer='zeros')(x))
    ccoords = Dense(2, activation=None, name="dense_ccoords")(x)

    energy = indep_energy_block2(x,
                                 SelectFeatures(0, 1)(x_basic), ccoords, beta,
                                 x_row_splits)

    x = Concatenate(name="concat_final")(
        [input_features, beta, energy, xy, t, ccoords])

    # x = Concatenate(name="concatlast", axis=-1)([x,coords])#+[n_showers]+[etas_phis])
    predictions = x

    # outputs = tf.tuple([predictions, x_row_splits])

    return Model(inputs=Inputs, outputs=[predictions, predictions])
Example #12
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)
    energy = SelectFeatures(0, 1)(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    #n_cluster_coords=6

    x = feat  #Dense(64,activation='elu')(feat)

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_en = []

    allfeat = [x]

    sel_gidx = gidx_orig
    for i in range(6):

        pixelcompress = 16 + 8 * i
        nneigh = 8 + 32 * i
        n_dimensions = int(4 + i)

        x = BatchNormalization(momentum=0.6)(x)
        #do some processing
        x = Dense(64, activation='elu', name='dense_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                              n_dimensions=n_dimensions,
                                              n_filters=64,
                                              n_propagate=64)([x, rs])
        x = Concatenate()([x, MessagePassing([64, 32, 16, 8])([x, nidx])])

        #allfeat.append(x)

        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        x_sp = SoftPixelCNN(length_scale=1.)([coords, x, nidx])
        x = Concatenate()([x, x_sp])
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu')(x)
        x = Dense(pixelcompress, activation='elu')(x)

        hier = Dense(1, activation='sigmoid',
                     name='hier_' + str(i))(x)  #clustering hierarchy

        dist = ScalarMultiply(1 / (3. * float(i + 0.5)))(dist)

        x, rs, bidxs, t_idx = LocalClusterReshapeFromNeighbours(
            K=5,
            radius=0.1,
            print_reduction=True,
            loss_enabled=True,
            loss_scale=2.,
            loss_repulsion=0.5,
            print_loss=True)([x, dist, hier, nidx, rs, t_idx, t_idx])

        print('>>>>x0', x.shape)
        gatherids.append(bidxs)
        if addBackGatherInfo:
            backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))

        print('>>>>x1', x.shape)
        x = BatchNormalization(momentum=0.6)(x)
        x_record = Dense(2 * pixelcompress, activation='elu')(x)
        x_record = BatchNormalization(momentum=0.6)(x_record)
        allfeat.append(MultiBackGather()([x_record, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Concatenate()([feat, x])
    x = BatchNormalization(momentum=0.6)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-3,
        position_loss_weight=1e-3,
        timing_loss_weight=1e-3,
    )([
        pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
        orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
        row_splits
    ])

    return Model(inputs=Inputs,
                 outputs=[
                     pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time,
                     pred_id, rs
                 ] + backgatheredids)
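Note the pattern above: LLFullObjectCondensation is a loss-in-graph layer that wraps pred_beta and sees both predictions and truth, so the objective is attached inside the model (typically via add_loss in layers of this kind). Under that assumption, compiling needs only an optimizer:

import tensorflow as tf

model = gravnet_model(Inputs)  # Inputs as built from the training data definition
# no external loss: the LLFullObjectCondensation layer injects the objective
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss=None)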
Example #13
def gravnet_model(
    Inputs,
    beta_loss_scale,
    q_min,
    use_average_cc_pos,
    kalpha_damping_strength,
    batchnorm_momentum=0.9999  #that's actually the damping factor. High -> slow
):

    feature_dropout = -1.
    addBackGatherInfo = True

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(feat)
    #feat_norm = BatchNormalization(momentum=batchnorm_momentum)(feat_norm)
    allfeat = []
    x = feat_norm

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []
    energysums = []

    #really simple real coordinates
    energy = SelectFeatures(0, 1)(feat)
    orig_coords = SelectFeatures(5, 8)(feat)
    coords = ManualCoordTransform()(orig_coords)
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(
                       coords)  #just rotation and scaling

    #see whats there
    nidx, dist = KNN(K=64, radius=1.0)([coords, rs])

    x_c = Dense(32, activation='elu')(x)
    x_c = Dense(8)(x_c)  #just a few features are enough here
    #this can be full blown because of the small number of input features
    x_c = NeighbourApproxPCA()([coords, dist, x_c, nidx])
    x_mp = DistanceWeightedMessagePassing([32, 16, 8])([x, nidx, dist])
    x = Concatenate()([x, x_c, x_mp])
    #this is going to be among the most expensive operations:
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = Dense(32, activation='selu', name='pre_dense_b')(x)
    x = BatchNormalization(momentum=batchnorm_momentum)(x)

    allfeat.append(x)
    backgathered_coords.append(coords)

    sel_gidx = gidx_orig

    cdist = dist
    ccoords = coords

    total_iterations = 5

    collectedccoords = []

    for i in range(total_iterations):

        cluster_neighbours = 5
        n_dimensions = 3  #make it plottable
        #derive new coordinates for clustering
        if i:
            ccoords = Add()([
                ccoords,
                ScalarMultiply(0.3)(Dense(n_dimensions,
                                          name='newcoords' + str(i),
                                          kernel_initializer='zeros')(x))
            ])
            nidx, cdist = KNN(K=6 * cluster_neighbours,
                              radius=-1.0)([ccoords, rs])
            #here we use more neighbours to improve learning of the cluster space

        #cluster first
        #hier = Dense(1,activation='sigmoid')(x)
        #distcorr = Dense(dist.shape[-1],activation='relu',kernel_initializer='zeros')(Concatenate()([x,dist]))
        #dist = Add()([distcorr,dist])
        cdist = LocalDistanceScaling(max_scale=5.)([
            cdist,
            Dense(1, kernel_initializer='zeros')(Concatenate()([x, cdist]))
        ])
        #what if we let the individual distances scale here? so move hits in and out?

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords, ccoords, cdist = LocalClusterReshapeFromNeighbours2(
            K=cluster_neighbours,
            radius=0.1,
            print_reduction=False,
            loss_enabled=True,
            loss_scale=1.,
            loss_repulsion=0.4,  #.5
            hier_transforms=[64, 32, 32, 32],
            print_loss=False,
            name='clustering_' + str(i))([
                x, cdist, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                ccoords, cdist, t_idx
            ])

        gatherids.append(bidxs)

        #explicit
        energy = ReduceSumEntirely()(
            energy)  #sums up all contained energy per cluster
        #n_energy = BatchNormalization(momentum=batchnorm_momentum)(energy)

        x = x_cl  #Concatenate()([x_cl,n_energy])
        x = Dense(128, activation='elu', name='dense_clc_a' + str(i))(x)
        #x = BatchNormalization(momentum=batchnorm_momentum)(x)
        x = Dense(128, activation='elu', name='dense_clc_b' + str(i))(x)
        x = Dense(64, activation='selu')(x)
        x = BatchNormalization(momentum=batchnorm_momentum)(x)

        nneigh = 128
        nfilt = 64
        nprop = 64

        x = Concatenate()([coords, x])

        x_gn, coords, nidx, dist = RaggedGravNet(n_neighbours=nneigh,
                                                 n_dimensions=n_dimensions,
                                                 n_filters=nfilt,
                                                 n_propagate=nprop)([x, rs])

        x_sp = Dense(16)(x)
        x_sp = NeighbourApproxPCA(hidden_nodes=[32, 32, n_dimensions**2])(
            [coords, dist, x_sp, nidx])
        x_sp = BatchNormalization(momentum=batchnorm_momentum)(x_sp)

        x_mp = DistanceWeightedMessagePassing([32, 32, 16, 16, 8,
                                               8])([x, nidx, dist])
        #x_mp = BatchNormalization(momentum=batchnorm_momentum)(x_mp)
        #x_sp=x_mp

        x = Concatenate()([x, x_mp, x_sp, x_gn])
        #check and compress it all
        x = Dense(128, activation='elu', name='dense_a_' + str(i))(x)
        #x = BatchNormalization(momentum=batchnorm_momentum)(x)
        #x = Dense(128, activation='elu',name='dense_b_'+str(i))(x)
        x = Dense(64, activation='selu', name='dense_c_' + str(i))(x)
        x = Concatenate()([StopGradient()(ccoords), StopGradient()(cdist), x])
        x = BatchNormalization(momentum=batchnorm_momentum)(x)

        #record more and more the deeper we go
        x_r = x
        energysums.append(MultiBackGather()(
            [energy, gatherids]))  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))

        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([ccoords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    x = Concatenate()([x] + energysums)
    x = Dense(128, activation='elu', name='alldense')(x)
    x = RaggedGlobalExchange()([x, row_splits])
    x = Dense(64, activation='selu')(x)
    x = BatchNormalization(momentum=batchnorm_momentum)(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #loss
    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=1e-1,
        position_loss_weight=1e-1,
        timing_loss_weight=1e-1,
        beta_loss_scale=beta_loss_scale,
        repulsion_scaling=1.,
        q_min=q_min,
        use_average_cc_pos=use_average_cc_pos,
        prob_repulsion=True,
        phase_transition=1,
        huber_energy_scale=3,
        alt_potential_norm=True,
        payload_beta_gradient_damping_strength=0.,
        kalpha_damping_strength=kalpha_damping_strength,  #1.,
        name="FullOCLoss")([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return RobustModel(inputs=Inputs,
                       outputs=[
                           pred_beta, pred_ccoords, pred_energy, pred_pos,
                           pred_time, pred_id, rs
                       ] + backgatheredids + backgathered_coords)
Example #14
def gravnet_model(Inputs, feature_dropout=-1., addBackGatherInfo=True):

    ######## pre-process all inputs and create global indices etc. No DNN actions here

    feat, t_idx, t_energy, t_pos, t_time, t_pid, row_splits = td.interpretAllModelInputs(
        Inputs)
    feat, t_idx, t_energy, t_pos, t_time, t_pid = NormalizeInputShapes()(
        [feat, t_idx, t_energy, t_pos, t_time, t_pid])

    orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid, orig_row_splits = t_idx, t_energy, t_pos, t_time, t_pid, row_splits
    gidx_orig = CreateGlobalIndices()(feat)

    _, row_splits = RaggedConstructTensor()([feat, row_splits])
    rs = row_splits

    feat_norm = ProcessFeatures()(
        feat)  #get rid of unit scalings, almost normalise
    feat_norm = BatchNormalization(momentum=0.6)(feat_norm)
    x = feat_norm

    energy = SelectFeatures(0, 1)(feat)
    time = SelectFeatures(8, 9)(feat)
    orig_coords = SelectFeatures(5, 8)(feat_norm)

    ######## create output lists

    allfeat = []

    backgatheredids = []
    gatherids = []
    backgathered = []
    backgathered_coords = []

    ####### create simple first coordinate transformation explicitly (time critical)

    coords = orig_coords
    coords = Dense(16, activation='elu')(coords)
    coords = Dense(32, activation='elu')(coords)
    coords = Dense(3, use_bias=False)(coords)
    coords = ScalarMultiply(0.1)(coords)
    coords = Add()([coords, orig_coords])
    coords = Dense(3,
                   use_bias=False,
                   kernel_initializer=tf.keras.initializers.Identity())(coords)

    first_coords = coords

    ###### apply one gravnet-like transformation (explicit here because we have created coords by hand) ###

    nidx, dist = KNN(K=48)([coords, rs])
    x_mp = DistanceWeightedMessagePassing([32])([x, nidx, dist])

    first_nidx = nidx
    first_dist = dist

    ###### collect information about the surrounding energy and time distributions per vertex ###

    ncov = NeighbourCovariance()(
        [coords, ReluPlusEps()(Concatenate()([energy, time])), nidx])
    ncov = BatchNormalization(momentum=0.6)(ncov)
    ncov = Dense(64, activation='elu', name='pre_dense_ncov_a')(ncov)
    ncov = Dense(32, activation='elu', name='pre_dense_ncov_b')(ncov)

    ##### put together and process ####

    x = Concatenate()([x, x_mp, ncov, coords])
    x = Dense(64, activation='elu', name='pre_dense_a')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(32, activation='elu', name='pre_dense_b')(x)

    ####### add first set of outputs to output lists

    allfeat.append(x)
    backgathered_coords.append(coords)

    total_iterations = 5

    sel_gidx = gidx_orig

    for i in range(total_iterations):

        ###### reshape the graph to fewer vertices ####

        hier = Dense(1)(x)
        dist = LocalDistanceScaling()([dist, Dense(1)(x)])

        x_cl, rs, bidxs, sel_gidx, energy, x, t_idx, coords = LocalClusterReshapeFromNeighbours(
            K=6,
            radius=0.5,  #doesn't really have an effect because of local distance scaling
            print_reduction=True,
            loss_enabled=True,
            loss_scale=4.,
            loss_repulsion=0.5,
            print_loss=True,
            name='clustering_' + str(i))([
                x, dist, hier, nidx, rs, sel_gidx, energy, x, t_idx, coords,
                t_idx
            ])  #last is truth index used by layer

        gatherids.append(bidxs)

        x_cl_rs = Reshape([-1, x.shape[-1]])(x_cl)  #get to shape V x K x F
        xec = EdgeConvStatic([32, 32, 32],
                             name="ec_static_" + str(i))(x_cl_rs)
        x_cl = Concatenate()([x, xec])

        ### explicitly sum energy and re-add to features

        energy = ReduceSumEntirely()(energy)
        n_energy = BatchNormalization(momentum=0.6)(energy)
        x = Concatenate()([x_cl, n_energy])

        x = Dense(128, activation='elu', name='dense_clc0_' + str(i))(x)
        x = Dense(64, activation='relu', name='dense_clc1_' + str(i))(x)
        #notice last relu for feature weighting later

        x_gn, coords, nidx, dist = RaggedGravNet(
            n_neighbours=32 + 16 * i,
            n_dimensions=3,
            n_filters=64 + 16 * i,
            n_propagate=64,
            return_self=True)([Concatenate()([coords, x]), rs])

        ### add neighbour summary statistics

        x_ncov = NeighbourCovariance()([coords, ReluPlusEps()(x), nidx])
        x_ncov = Dense(128, activation='elu',
                       name='dense_ncov_a_' + str(i))(x_ncov)
        x_ncov = BatchNormalization(momentum=0.6)(x_ncov)
        x_ncov = Dense(64, activation='elu',
                       name='dense_ncov_b_' + str(i))(x_ncov)
        x = Concatenate()([x, x_ncov, x_gn])

        ### with all this information perform a few message passing steps

        x_mp = MessagePassing([32, 32, 16, 16, 8, 8])([x, nidx])
        x_mp = Dense(64, activation='elu', name='dense_mpc_' + str(i))(x_mp)
        x = Concatenate()([x, x_mp])

        ##### prepare output of this iteration

        x = Dense(128, activation='elu', name='dense_out_a_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(64, activation='elu', name='dense_out_b_' + str(i))(x)
        x = BatchNormalization(momentum=0.6)(x)

        #### compress further for output, but forward fill 64 feature x to next iteration

        x_r = Dense(8 + 16 * i, activation='elu',
                    name='dense_out_c_' + str(i))(x)
        #coords_nograd = StopGradient()(coords)
        #x_r = Concatenate()([coords_nograd,x_r]) ## add coordinates, might come handy for cluster space

        if i >= total_iterations - 1:
            energy = MultiBackGather()(
                [energy,
                 gatherids])  #assign energy sum to all cluster components

        allfeat.append(MultiBackGather()([x_r, gatherids]))
        backgatheredids.append(MultiBackGather()([sel_gidx, gatherids]))
        backgathered_coords.append(MultiBackGather()([coords, gatherids]))

    x = Concatenate(name='allconcat')(allfeat)
    #x = Dropout(0.2)(x)
    x_mp = DistanceWeightedMessagePassing([32, 32,
                                           32])([x, first_nidx, first_dist])
    x = Concatenate()([x, x_mp])

    x = Dense(128, activation='elu', name='alldense')(x)
    # TO BE ADDED WITH E LOSS x = Concatenate()([x,energy])
    #x = Dropout(0.2)(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu')(x)

    pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id = create_outputs(
        x, feat)

    #
    #
    # double scale phase transition with linear beta + qmin
    #  -> more high beta points, but: payload loss will still scale one
    #     (or two, but then doesn't matter)
    #

    pred_beta = LLFullObjectCondensation(
        print_loss=True,
        energy_loss_weight=0.,
        position_loss_weight=0.,  #seems broken
        timing_loss_weight=0.,  #1e-3,
        beta_loss_scale=1.,
        repulsion_scaling=1.,
        q_min=1.5,
        prob_repulsion=True,
        phase_transition=True,
        phase_transition_double_weight=False,
        alt_potential_norm=True,
        cut_payload_beta_gradient=False)([
            pred_beta, pred_ccoords, pred_energy, pred_pos, pred_time, pred_id,
            orig_t_idx, orig_t_energy, orig_t_pos, orig_t_time, orig_t_pid,
            row_splits
        ])

    return ExtendedMetricsModel(inputs=Inputs,
                                outputs=[
                                    pred_beta, pred_ccoords, pred_energy,
                                    pred_pos, pred_time, pred_id, rs
                                ] + backgatheredids + backgathered_coords)