예제 #1
0
class DeepEventEventRelationMentionDecoder:
    """Predict the relation between event mentions with a restored TensorFlow model.

    The constructor reads model file locations from a parameter file, builds a
    ``DataLoader`` and a ``MultiModel`` (``"pcnn"`` encoder, ``"att"`` selector)
    and restores the best saved weights. ``decode`` then labels one serialized
    instance per call.
    """

    def __init__(self, paramsFile):
        """Read parameters from *paramsFile* and restore the trained model.

        Args:
            paramsFile: path passed to ``wbq_params.read``. The result must
                provide the keys
                ``TensorFlowEventEventRelationMentionExtractionModelPrefix``,
                ``...Relation2id`` and ``...Word2vec`` read below.
        """
        params = wbq_params.read(paramsFile)

        print("params: " + str(params))

        self.tf_model_prefix = params['TensorFlowEventEventRelationMentionExtractionModelPrefix']
        self.tf_model_relation2id = params['TensorFlowEventEventRelationMentionExtractionModelRelation2id']
        self.tf_model_word2vec = params['TensorFlowEventEventRelationMentionExtractionModelWord2vec']

        # Fixed hyperparameters. NOTE(review): presumably these must match the
        # values the checkpoint under tf_model_prefix was trained with — confirm.
        self.max_length = 120
        self.batch_size = 160
        self.word_embedding_dim = 300
        self.encoder = "pcnn"
        self.selector = "att"

        self.test_data_loader = DataLoader(self.tf_model_word2vec,
                                           self.tf_model_relation2id,
                                           mode=DataLoader.MODE_INSTANCE,
                                           shuffle=False,
                                           max_length=self.max_length,
                                           batch_size=self.batch_size)

        # Start with an empty dataset; decode() replaces it on every call.
        # NOTE(review): presumably MultiModel needs a loader with a dataset in
        # place at construction time — verify.
        self.test_data_loader.create_dataset([])

        # The first argument (None) is left empty here; presumably it is the
        # training loader, which inference does not need — confirm.
        self.model = MultiModel(None, self.test_data_loader,
                                max_length=self.max_length,
                                batch_size=self.batch_size,
                                word_embedding_dim=self.word_embedding_dim,
                                encoder=self.encoder,
                                selector=self.selector)

        self.model.load_best_model(self.tf_model_prefix)

        # Presumably loads the id -> relation-name mapping used by predict().
        self.model.load_id2rel(self.tf_model_relation2id)

    def decode(self, serializedInstance):
        """Predict the relation label for a single serialized instance.

        The instance is replicated ``self.batch_size`` times so that a full
        batch is fed to the model. Only the first prediction is used.

        Args:
            serializedInstance: one instance, in whatever form
                ``DataLoader.create_dataset`` accepts as a list element.

        Returns:
            dict with ``'label'`` (the first predicted relation, with ``"NA"``
            mapped to ``"OTHER"``) and ``'confidence'`` (the first entry of the
            model's probability output, converted to ``str``).
        """
        # Fill one whole batch with copies of the instance, sized from the same
        # batch_size the loader and model were built with (previously a
        # hard-coded 160 that could silently drift from self.batch_size).
        # NOTE(review): presumably the model expects exactly one full batch.
        decoding_json = [serializedInstance] * self.batch_size

        self.test_data_loader.create_dataset(decoding_json)

        predicted_rels, predicted_prob = self.model.predict(self.test_data_loader)

        # Every row is the same instance, so the first prediction is the answer.
        label = predicted_rels[0]
        if label == "NA":
            # The model's "NA" label is reported to callers as "OTHER".
            label = "OTHER"
        confidence = predicted_prob[0]

        print("PREDICTION:\t" + label + "\t" + str(confidence) + "\tserializedInstance: " + str(serializedInstance))

        return {'label': label, 'confidence': str(confidence)}
예제 #2
0
def train(argv):
    """Train a YOLOv3 detector configured entirely from ``FLAGS``.

    Builds training/validation datasets from ``FLAGS.train_dir``, writes
    TensorBoard logs to a timestamped subdirectory of ``FLAGS.logdir`` and
    saves weight checkpoints as ``FLAGS.model_dir/model-<epoch>.ckpt``. When
    ``FLAGS.restore`` is set, weights are loaded from the checkpoint numbered
    ``FLAGS.initial_epoch`` and training resumes from that epoch.

    Args:
        argv: unused here. Presumably supplied by an ``absl.app.run``-style
            entry point — confirm against the caller.
    """

    def _path(*parts):
        # Join the parts, then normalise the result through PurePath.
        return str(PurePath(os.path.join(*parts)))

    dataset_loader = DataLoader(FLAGS.train_dir,
                                n_cls=FLAGS.ncls,
                                img_shape=FLAGS.img_shape)
    train_ds, val_ds = dataset_loader.create_dataset(batch_size=FLAGS.mb_size)

    run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = _path(FLAGS.logdir, run_stamp)

    # Keras fills "{epoch:04d}" with the epoch number when it writes a file.
    ckpt_saver = tf.keras.callbacks.ModelCheckpoint(
        filepath=_path(FLAGS.model_dir, "model-{epoch:04d}.ckpt"),
        verbose=1,
        save_weights_only=True,
        # NOTE(review): newer tf.keras deprecates `period` in favour of
        # `save_freq` (counted in batches) — check the TF version in use.
        period=FLAGS.save_freq,
    )

    detector = YOLOv3(
        img_shape=FLAGS.img_shape,
        ncls=FLAGS.ncls,
        use_spp=True,
        use_pretrained_weights=True,
    )

    start_epoch = FLAGS.initial_epoch if FLAGS.restore else 0
    if FLAGS.restore:
        # Resume from the checkpoint file named after start_epoch, using the
        # same naming pattern as ckpt_saver.
        detector.load_weights(
            _path(FLAGS.model_dir, f"model-{start_epoch:04d}.ckpt"))

    adam = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    tb_logger = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                               write_images=True)
    # Project-specific callback. Judging by its arguments it periodically
    # renders predictions on val_ds into log_dir — confirm in `callbacks`.
    preview = callbacks.YOLOCallback(
        detector.serving,
        val_data=val_ds,
        val_steps=5,
        logdir=log_dir,
        encoder=dataset_loader.encoder,
        write_every_n_epochs=10,
    )

    detector.compile(optimizer=adam)
    detector.fit(
        train_ds,
        steps_per_epoch=FLAGS.steps_per_epoch,
        validation_data=val_ds,
        validation_steps=100,
        epochs=FLAGS.max_epochs,
        callbacks=[preview, tb_logger, ckpt_saver],
        initial_epoch=start_epoch,
    )