def gravnet_model(Inputs, feature_dropout=-1.):
    """Build a 3-block GravNet model over ragged per-event vertex data.

    Args:
        Inputs: sequence of two Keras inputs — ``Inputs[0]`` the per-vertex
            feature tensor, ``Inputs[1]`` the per-event split counts
            (cast to int32 here).
        feature_dropout: unused; kept for call-site compatibility.

    Returns:
        tf.keras.Model whose outputs are ``[predictions, predictions]``
        (the tensor is duplicated so two losses can attach to it).
    """
    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    # Turn the padded batch into a flat vertex list plus ragged row splits.
    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")([I_data, I_splits])

    input_features = x_data  # passed through unchanged for the loss

    x_basic = BatchNormalization(momentum=0.6)(x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    # Loop-invariant GravNet hyperparameters (hoisted out of the loop;
    # the original reassigned the same values every iteration).
    n_gravnet_layers = 3
    n_filters = 128
    n_propagate = [128, 64, 32, 16, 16, 8, 8, 4, 4]
    n_neighbours = 128

    feat = [x]
    for i in range(n_gravnet_layers):
        # Final block clusters in 2 learned coordinates, earlier ones in 4.
        n_dim = 2 if i == n_gravnet_layers - 1 else 4
        # Re-inject the normalized input features before every GravNet block.
        x = Concatenate()([x_basic, x])
        x, coords = FusedRaggedGravNetLinParse(n_neighbours=n_neighbours,
                                               n_dimensions=n_dim,
                                               n_filters=n_filters,
                                               n_propagate=n_propagate,
                                               name='gravnet_' + str(i))([x, x_row_splits])

        x = BatchNormalization(momentum=0.6)(x)
        x = RaggedGlobalExchange(name="global_exchange_bottom_" + str(i))([x, x_row_splits])
        x = Dense(96, activation='elu', name="dense_bottom_" + str(i) + "_a")(x)
        x = Dense(96, activation='elu', name="dense_bottom_" + str(i) + "_b")(x)
        x = Dense(96, activation='elu', name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)

        feat.append(x)

    # Dense head over the concatenated per-block features.
    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    predictions = create_default_outputs(input_features, x, x_row_splits, energy_block=True)

    return Model(inputs=Inputs, outputs=[predictions, predictions])
# ---- Example 2 ----
def gravnet_model(Inputs, feature_dropout=-1.):
    """Build a 7-block GravNet model with per-block compressed skip features.

    Args:
        Inputs: sequence of two Keras inputs — ``Inputs[0]`` the per-vertex
            feature tensor, ``Inputs[1]`` the per-event split counts
            (cast to int32 here).
        feature_dropout: unused; kept for call-site compatibility.

    Returns:
        tf.keras.Model whose outputs are ``[predictions, predictions]``
        (the tensor is duplicated so two losses can attach to it).
    """
    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    # Turn the padded batch into a flat vertex list plus ragged row splits.
    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # passed through unchanged for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic

    # Loop-invariant GravNet hyperparameters (hoisted out of the loop;
    # the original reassigned the same values every iteration).
    n_gravnet_layers = 7
    n_filters = 128
    n_propagate = 96
    n_neighbours = 200

    feat = []
    for i in range(n_gravnet_layers):
        x = RaggedGlobalExchange(name="global_exchange_" +
                                 str(i))([x, x_row_splits])
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_a")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_b")(x)
        x = Dense(64, activation='elu', name="dense_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6)(x)
        x = FusedRaggedGravNet_simple(n_neighbours=n_neighbours,
                                      n_dimensions=4,
                                      n_filters=n_filters,
                                      n_propagate=n_propagate,
                                      name='gravnet_' +
                                      str(i))([x, x_row_splits])
        x = BatchNormalization(momentum=0.6)(x)
        # Keep a compressed copy of every block's output for the final head.
        feat.append(
            Dense(48, activation='elu', name="dense_compress_" + str(i))(x))

    # Dense head over the concatenated per-block features.
    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    predictions = create_default_outputs(input_features,
                                         x,
                                         x_row_splits,
                                         energy_block=False)

    return Model(inputs=Inputs, outputs=[predictions, predictions])
# ---- Example 3 ----
def gravnet_model(Inputs, feature_dropout=-1.):
    """Build a 4-block GravNet model with distance-weighted message passing.

    Args:
        Inputs: sequence of two Keras inputs — ``Inputs[0]`` the per-vertex
            feature tensor, ``Inputs[1]`` the per-event split counts
            (cast to int32 here).
        feature_dropout: unused; kept for call-site compatibility.

    Returns:
        tf.keras.Model whose outputs are ``[predictions, predictions]``
        (the tensor is duplicated so two losses can attach to it).

    NOTE(review): this function reads ``n_ccoords`` as a free (module-level)
    variable — it is not defined here; confirm it exists at the module scope.
    """
    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    # Turn the padded batch into a flat vertex list plus ragged row splits.
    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # passed through unchanged for the loss

    x_basic = BatchNormalization(momentum=0.6)(
        x_data)  # mask_and_norm is just batch norm now
    x = x_basic
    x = RaggedGlobalExchange(name="global_exchange")([x, x_row_splits])
    x = Dense(64, activation='elu', name="dense_start")(x)

    # Loop-invariant GravNet hyperparameters (hoisted out of the loop;
    # the original reassigned the same values every iteration).
    n_gravnet_layers = 4
    n_filters = 196
    n_propagate = 128
    n_propagate_2 = [64, 32, 16, 8, 4, 2]
    n_neighbours = 32
    n_dim = 8

    feat = [x_basic, x]
    for i in range(n_gravnet_layers):
        # coords are returned but only the neighbour graph is reused below.
        x, _coords, neighbor_indices, neighbor_dist = RaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=n_dim,
            n_filters=n_filters,
            n_propagate=n_propagate,
            name='gravnet_' + str(i))([x, x_row_splits])
        # Refine features over the GravNet neighbour graph.
        x = DistanceWeightedMessagePassing(n_propagate_2)(
            [x, neighbor_indices, neighbor_dist])

        x = BatchNormalization(momentum=0.6)(x)
        x = Dense(128, activation='elu',
                  name="dense_bottom_" + str(i) + "_a")(x)
        x = BatchNormalization(momentum=0.6, name="bn_a_" + str(i))(x)
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_b")(x)
        x = RaggedGlobalExchange(name="global_exchange_bot_" +
                                 str(i))([x, x_row_splits])
        x = Dense(96, activation='elu',
                  name="dense_bottom_" + str(i) + "_c")(x)
        x = BatchNormalization(momentum=0.6, name="bn_b_" + str(i))(x)

        feat.append(x)

    # Dense head over the concatenated per-block features.
    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a1")(x)
    x = Dense(128, activation='elu', name="dense_last_a2")(x)
    x = BatchNormalization(momentum=0.6, name="bn_last_a2")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    predictions = create_default_outputs(input_features,
                                         x,
                                         x_row_splits,
                                         energy_block=False,
                                         n_ccoords=n_ccoords,
                                         scale_exp_e=False)

    return Model(inputs=Inputs, outputs=[predictions, predictions])