def build_crf(self):
    """Build the Dense + CRF model, load its saved weights, and store it on ``self.crf``.

    Inputs have shape ``(self.seq_len, self.feat_dim)`` per example. A linear
    ``Dense`` layer with an L2(1e-3) kernel regularizer projects every time step
    to ``self.num_classes`` scores, which are fed to a tf2crf ``CRF`` layer.
    The model is compiled with the CRF layer's ``loss`` / ``accuracy`` and an
    Adam optimizer (learning rate 0.01). It is then restored from
    ``self.crf_weights``.

    The architecture is rebuilt here because the CRF is a custom layer class.
    The saved weights can only be loaded into an identically structured model.

    Any error raised by ``model.load_weights`` (e.g. a missing or mismatched
    checkpoint) propagates to the caller.
    """
    reg = tf.keras.regularizers.L2(1e-3)

    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = Input(shape=(self.seq_len, self.feat_dim), dtype='float32')

    # Per-timestep linear class scores fed to the CRF layer. ``input_shape`` is
    # not passed: in the functional API the shape is taken from ``inputs``.
    mid = Dense(self.num_classes, activation='linear', kernel_regularizer=reg)(inputs)

    crf = CRF(dtype='float32', sparse_target=True)
    # NOTE(review): these attributes are assigned by hand before the layer is
    # called; presumably required by the tf2crf version in use -- confirm.
    crf.sequence_lengths = self.seq_len
    crf.output_dim = self.num_classes
    output = crf(mid)

    model = Model(inputs, output)
    opt = tf.keras.optimizers.Adam(learning_rate=0.01)
    model.compile(loss=crf.loss, optimizer=opt, metrics=[crf.accuracy])
    model.load_weights(self.crf_weights)
    self.crf = model
# Global network parameters. These must match the architecture the saved CRF
# weights were produced with, because only the weights are loaded below.
SEQ_LEN = 1024
NUM_CLASSES = 25
FEAT_DIM = 128

# tf2crf's CRF is a custom layer class, so the full model cannot be loaded
# directly: rebuild the CRF model and then load its weights.
reg = tf.keras.regularizers.L2(1e-3)
# Named ``crf_input`` rather than ``input`` so the builtin ``input`` is not
# shadowed for the rest of the module.
crf_input = Input(shape=(SEQ_LEN, FEAT_DIM), dtype='float32')
# Per-timestep linear class scores fed to the CRF layer. ``input_shape`` is
# redundant here: the functional call takes the shape from ``crf_input``.
mid = Dense(NUM_CLASSES, activation='linear', kernel_regularizer=reg)(crf_input)
crf = CRF(dtype='float32', sparse_target=True)
# NOTE(review): attributes assigned by hand before the layer is called;
# presumably required by the tf2crf version in use -- confirm.
crf.sequence_lengths = SEQ_LEN
crf.output_dim = NUM_CLASSES
output = crf(mid)
model = Model(crf_input, output)
opt = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(loss=crf.loss, optimizer=opt, metrics=[crf.accuracy])
load_from = './model_01'
model.load_weights(load_from)
# From here on ``crf`` refers to the full compiled model, not the CRF layer.
crf = model

# Load the FCNN feature extractor.
cnn = load_model('cnn_extractor.h5')
CONTEXT_SIZE = 7