def main(_):
    """Train an average-word-vector text classifier on aclImdb and export it.

    Downloads and untars the aclImdb sentiment dataset. It holds out part of
    the training split for validation, trains a model, and prints its accuracy
    on the test split. Finally it exports the model, label file and vocab file
    to the paths given by FLAGS.tflite_filename, FLAGS.label_filename and
    FLAGS.vocab_filename.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    logging.set_verbosity(logging.INFO)

    model_spec = AverageWordVecModelSpec()
    # One explicit label list shared by every loader, so the train and test
    # loaders are built from the same labels in the same order.
    class_labels = ['pos', 'neg']

    data_path = tf.keras.utils.get_file(
        fname='aclImdb',
        origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
        untar=True)
    full_train_data = TextClassifierDataLoader.from_folder(
        filename=os.path.join(data_path, 'train'),
        model_spec=model_spec,
        class_labels=class_labels)
    # split(0.9): 90% of the examples for training, 10% held out for validation.
    train_data, validation_data = full_train_data.split(0.9)
    test_data = TextClassifierDataLoader.from_folder(
        filename=os.path.join(data_path, 'test'),
        model_spec=model_spec,
        is_training=False,
        class_labels=class_labels)

    model = text_classifier.create(train_data,
                                   model_spec=model_spec,
                                   validation_data=validation_data)

    _, acc = model.evaluate(test_data)
    print('\nTest accuracy: %f' % acc)

    model.export(FLAGS.tflite_filename, FLAGS.label_filename,
                 FLAGS.vocab_filename)
Пример #2
0
 def test_average_wordvec_model_create_v1_incompatible(self):
     """The average-word-vec pipeline must fail with 'Incompatible versions'.

     Spec construction, data loading and model creation all run inside the
     assertion context. A matching ValueError from any of them passes the test.
     """
     with self.assertRaisesRegex(ValueError, 'Incompatible versions'):
         spec = ms.AverageWordVecModelSpec(seq_len=2)
         dataset = text_dataloader.TextClassifierDataLoader.from_folder(
             self.text_dir, model_spec=spec)
         text_classifier.create(dataset,
                                mef.ModelExportFormat.TFLITE,
                                model_spec=spec)
Пример #3
0
 def test_average_wordvec_model(self):
     """Train a tiny 'average_wordvec' model on self.data and check it.

     Runs the accuracy check and the TFLite export check on the result.
     """
     training_options = dict(
         model_name='average_wordvec',
         epochs=2,
         batch_size=4,
         sentence_len=2,
         shuffle=True,
     )
     classifier = text_classifier.create(self.data,
                                         mef.ModelExportFormat.TFLITE,
                                         **training_options)
     self._test_accuracy(classifier)
     self._test_export_to_tflite(classifier)
Пример #4
0
 def test_average_wordvec_model(self):
     """Train an average-word-vec model on self.train_data and check it.

     Runs the accuracy check, the TFLite export check and the top-k
     prediction check on the result.
     """
     export_format = mef.ModelExportFormat.TFLITE
     classifier = text_classifier.create(
         self.train_data,
         export_format,
         model_spec=ms.AverageWordVecModelSpec(sentence_len=2),
         epochs=2,
         batch_size=4,
         shuffle=True,
     )
     self._test_accuracy(classifier)
     self._test_export_to_tflite(classifier)
     self._test_predict_top_k(classifier)
Пример #5
0
    def test_average_wordvec_model(self):
        """Load the text fixture, split it, train an average-word-vec model.

        Stores the two splits on self.train_data / self.test_data. It then runs
        the accuracy, TFLite export and top-k prediction checks on the model.
        """
        spec = ms.AverageWordVecModelSpec(seq_len=2)
        dataset = text_dataloader.TextClassifierDataLoader.from_folder(
            self.text_dir, model_spec=spec)
        # 90% of the examples go to training, the remaining 10% to testing.
        train_split, test_split = dataset.split(0.9)
        self.train_data = train_split
        self.test_data = test_split

        classifier = text_classifier.create(
            self.train_data,
            mef.ModelExportFormat.TFLITE,
            model_spec=spec,
            epochs=2,
            batch_size=4,
            shuffle=True,
        )
        self._test_accuracy(classifier)
        self._test_export_to_tflite(classifier)
        self._test_predict_top_k(classifier)
Пример #6
0
def main(_):
    """Train a text classifier on aclImdb, evaluate it, and export it as TFLite.

    Downloads and untars the aclImdb sentiment dataset and trains on its
    'train' folder. It prints accuracy on the 'test' folder, then exports the
    model, label file and vocab file to the paths given by
    FLAGS.tflite_filename, FLAGS.label_filename and FLAGS.vocab_filename.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    logging.set_verbosity(logging.INFO)

    # One explicit label list shared by both loaders, so the train and test
    # loaders are built from the same labels in the same order.
    class_labels = ['pos', 'neg']

    data_path = tf.keras.utils.get_file(
        fname='aclImdb',
        origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
        untar=True)
    train_data = TextClassifierDataLoader.from_folder(
        filename=os.path.join(data_path, 'train'),
        class_labels=class_labels)
    test_data = TextClassifierDataLoader.from_folder(
        filename=os.path.join(data_path, 'test'),
        class_labels=class_labels)

    model = text_classifier.create(
        train_data, model_export_format=ModelExportFormat.TFLITE)

    _, acc = model.evaluate(test_data)
    print('Test accuracy: %f' % acc)

    model.export(FLAGS.tflite_filename, FLAGS.label_filename,
                 FLAGS.vocab_filename)