def gravnet_model(Inputs, feature_dropout=-1.):
    """Build a model made of three FusedRaggedGravNetLinParse blocks.

    The input features are batch-normalised once; that normalised tensor is
    concatenated in front of the running features before every GravNet
    block. The output of each block (plus the normalised input) is collected
    and concatenated before the final dense stack.

    Args:
        Inputs: Sequence of model inputs. ``Inputs[0]`` is used as the
            feature tensor; ``Inputs[1]`` is cast to int32 and handed,
            together with the features, to ``RaggedConstructTensor``
            (presumably row splits -- confirm against the data pipeline).
        feature_dropout: Accepted for interface compatibility; it is not
            read anywhere in this function.

    Returns:
        A ``Model`` from ``Inputs`` to ``[predictions, predictions]``.
    """
    # Hyperparameters shared by every GravNet block.
    n_gravnet_layers = 3
    n_filters = 128
    n_propagate = (128, 64, 32, 16, 16, 8, 8, 4, 4)
    n_neighbours = 128

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    # mask_and_norm is just batch norm now
    x_basic = BatchNormalization(momentum=0.6)(x_data)
    x = x_basic

    feat = [x]
    for i in range(n_gravnet_layers):
        # The last block uses 2 coordinate dimensions, all others use 4.
        n_dim = 2 if i == n_gravnet_layers - 1 else 4

        x = Concatenate()([x_basic, x])
        # The layer also returns its coordinates; they are not used here.
        x, _ = FusedRaggedGravNetLinParse(
            n_neighbours=n_neighbours,
            n_dimensions=n_dim,
            n_filters=n_filters,
            # Fresh list per layer, so layers never share one list object.
            n_propagate=list(n_propagate),
            name=f"gravnet_{i}")([x, x_row_splits])
        x = BatchNormalization(momentum=0.6)(x)

        x = RaggedGlobalExchange(
            name=f"global_exchange_bottom_{i}")([x, x_row_splits])
        for suffix in ("a", "b", "c"):
            x = Dense(96, activation='elu',
                      name=f"dense_bottom_{i}_{suffix}")(x)
        x = BatchNormalization(momentum=0.6)(x)

        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(128, activation='elu', name="dense_last_a1")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = BatchNormalization(momentum=0.6)(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    predictions = create_default_outputs(input_features, x, x_row_splits,
                                         energy_block=True)

    # The same tensor is listed twice; callers presumably expect two
    # outputs -- confirm before collapsing this to a single output.
    return Model(inputs=Inputs, outputs=[predictions, predictions])
def gravnet_model(Inputs, feature_dropout=-1.):
    """Build a model made of seven FusedRaggedGravNet_simple blocks.

    Each block applies a global exchange, three 64-unit dense layers, batch
    normalisation, a GravNet layer and another batch normalisation. A
    48-unit compressed projection of every block's output is collected, and
    the projections are concatenated before the final dense stack.

    Args:
        Inputs: Sequence of model inputs. ``Inputs[0]`` is used as the
            feature tensor; ``Inputs[1]`` is cast to int32 and handed,
            together with the features, to ``RaggedConstructTensor``
            (presumably row splits -- confirm against the data pipeline).
        feature_dropout: Accepted for interface compatibility; it is not
            read anywhere in this function.

    Returns:
        A ``Model`` from ``Inputs`` to ``[predictions, predictions]``.
    """
    # Hyperparameters shared by every GravNet block.
    n_gravnet_layers = 7
    n_filters = 128
    n_propagate = 96
    n_neighbours = 200

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    # mask_and_norm is just batch norm now
    x_basic = BatchNormalization(momentum=0.6)(x_data)
    x = x_basic

    feat = []
    for i in range(n_gravnet_layers):
        x = RaggedGlobalExchange(
            name=f"global_exchange_{i}")([x, x_row_splits])
        for suffix in ("a", "b", "c"):
            x = Dense(64, activation='elu', name=f"dense_{i}_{suffix}")(x)
        x = BatchNormalization(momentum=0.6)(x)

        x = FusedRaggedGravNet_simple(
            n_neighbours=n_neighbours,
            n_dimensions=4,
            n_filters=n_filters,
            n_propagate=n_propagate,
            name=f"gravnet_{i}")([x, x_row_splits])
        x = BatchNormalization(momentum=0.6)(x)

        # Only a compressed view of the block output is kept for the final
        # concatenation; the uncompressed x feeds the next block.
        feat.append(
            Dense(48, activation='elu', name=f"dense_compress_{i}")(x))

    x = Concatenate(name="concat_gravout")(feat)
    x = Dense(128, activation='elu', name="dense_last_a")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    predictions = create_default_outputs(input_features, x, x_row_splits,
                                         energy_block=False)

    # The same tensor is listed twice; callers presumably expect two
    # outputs -- confirm before collapsing this to a single output.
    return Model(inputs=Inputs, outputs=[predictions, predictions])
def gravnet_model(Inputs, feature_dropout=-1.):
    """Build a model of four RaggedGravNet + message-passing blocks.

    The batch-normalised input goes through one global exchange and a
    64-unit dense layer. Each of the four blocks then runs RaggedGravNet,
    feeds its neighbour indices and distances into
    DistanceWeightedMessagePassing, and finishes with a dense /
    batch-norm / global-exchange stack. The normalised input, the initial
    dense output and every block output are concatenated before the final
    dense stack.

    Args:
        Inputs: Sequence of model inputs. ``Inputs[0]`` is used as the
            feature tensor; ``Inputs[1]`` is cast to int32 and handed,
            together with the features, to ``RaggedConstructTensor``
            (presumably row splits -- confirm against the data pipeline).
        feature_dropout: Accepted for interface compatibility; it is not
            read anywhere in this function.

    Returns:
        A ``Model`` from ``Inputs`` to ``[predictions, predictions]``.
    """
    # Hyperparameters shared by every GravNet block.
    n_gravnet_layers = 4
    n_filters = 196
    n_propagate = 128
    n_propagate_2 = (64, 32, 16, 8, 4, 2)
    n_neighbours = 32
    n_dim = 8

    I_data = Inputs[0]
    I_splits = tf.cast(Inputs[1], tf.int32)

    x_data, x_row_splits = RaggedConstructTensor(name="construct_ragged")(
        [I_data, I_splits])

    input_features = x_data  # these are going to be passed through for the loss

    # mask_and_norm is just batch norm now
    x_basic = BatchNormalization(momentum=0.6)(x_data)
    x = x_basic

    x = RaggedGlobalExchange(name="global_exchange")([x, x_row_splits])
    x = Dense(64, activation='elu', name="dense_start")(x)

    feat = [x_basic, x]
    for i in range(n_gravnet_layers):
        # The layer also returns its coordinates; they are not used here.
        x, _, neighbor_indices, neighbor_dist = RaggedGravNet(
            n_neighbours=n_neighbours,
            n_dimensions=n_dim,
            n_filters=n_filters,
            n_propagate=n_propagate,
            name=f"gravnet_{i}")([x, x_row_splits])
        # Fresh list per layer, so layers never share one list object.
        x = DistanceWeightedMessagePassing(list(n_propagate_2))(
            [x, neighbor_indices, neighbor_dist])
        x = BatchNormalization(momentum=0.6)(x)

        x = Dense(128, activation='elu', name=f"dense_bottom_{i}_a")(x)
        x = BatchNormalization(momentum=0.6, name=f"bn_a_{i}")(x)
        x = Dense(96, activation='elu', name=f"dense_bottom_{i}_b")(x)
        x = RaggedGlobalExchange(
            name=f"global_exchange_bot_{i}")([x, x_row_splits])
        x = Dense(96, activation='elu', name=f"dense_bottom_{i}_c")(x)
        x = BatchNormalization(momentum=0.6, name=f"bn_b_{i}")(x)

        feat.append(x)

    x = Concatenate(name="concat_gravout")(feat)
    for tag in ("a", "a1", "a2"):
        x = Dense(128, activation='elu', name=f"dense_last_{tag}")(x)
        x = BatchNormalization(momentum=0.6, name=f"bn_last_{tag}")(x)
    x = Dense(64, activation='elu', name="dense_last_b")(x)
    x = Dense(64, activation='elu', name="dense_last_c")(x)

    # NOTE(review): n_ccoords is not defined inside this function; it is
    # presumably a module-level constant -- confirm it exists, otherwise
    # this call raises NameError.
    predictions = create_default_outputs(input_features, x, x_row_splits,
                                         energy_block=False,
                                         n_ccoords=n_ccoords,
                                         scale_exp_e=False)

    # The same tensor is listed twice; callers presumably expect two
    # outputs -- confirm before collapsing this to a single output.
    return Model(inputs=Inputs, outputs=[predictions, predictions])