Beispiel #1
0
def run(data_dir, export_dir, spec='bert_classifier', **kwargs):
    """Runs the text-classification demo end to end.

    Loads `train.tsv` and `dev.tsv` from `data_dir`, fine-tunes a text
    classifier on the training split, prints its accuracy on the dev split,
    and exports the trained model.

    Args:
      data_dir: Directory containing the tab-separated `train.tsv` and
        `dev.tsv` files, each with `sentence` and `label` columns.
      export_dir: Destination passed to `model.export`.
      spec: Name of the model specification, resolved via `model_spec.get`.
      **kwargs: Extra keyword arguments forwarded to `text_classifier.create`.
    """
    # Chooses the model specification that represents the model.
    resolved_spec = model_spec.get(spec)

    def _load_split(file_name, is_training):
        """Loads one tab-separated split of the dataset from `data_dir`."""
        return TextClassifierDataLoader.from_csv(
            filename=os.path.join(data_dir, file_name),
            text_column='sentence',
            label_column='label',
            model_spec=resolved_spec,
            delimiter='\t',
            is_training=is_training)

    # Gets training data and validation data.
    train_data = _load_split('train.tsv', is_training=True)
    validation_data = _load_split('dev.tsv', is_training=False)

    # Fine-tunes the model.
    model = text_classifier.create(train_data,
                                   model_spec=resolved_spec,
                                   validation_data=validation_data,
                                   **kwargs)

    # Gets evaluation results; the first element (loss) is not used here.
    _, acc = model.evaluate(validation_data)
    print('Eval accuracy: %f' % acc)

    # Exports to TFLite format.
    model.export(export_dir)
Beispiel #2
0
 def _test_model_without_training(self, model_spec):
     """Builds a classifier with `do_train=False` and sanity-checks it.

     Both checks use a 0.0 threshold. They presumably exercise the accuracy
     and TFLite-export paths rather than judging model quality.
     """
     untrained = text_classifier.create(
         self.train_data, model_spec=model_spec, do_train=False)
     self._test_accuracy(untrained, threshold=0.0)
     self._test_export_to_tflite(untrained, threshold=0.0)
 def test_average_wordvec_model_create_v1_incompatible(self):
     """Expects building an AverageWordVec classifier to raise ValueError.

     The error message must match 'Incompatible versions'. Per the test name,
     this presumably targets TF1 behavior (confirm against the test setup).
     """
     with self.assertRaisesRegex(ValueError, 'Incompatible versions'):
         # All three steps stay inside the context manager: any of them is
         # allowed to be the one that raises.
         spec = ms.AverageWordVecModelSpec(seq_len=2)
         dataset = text_dataloader.TextClassifierDataLoader.from_folder(
             self.text_dir, model_spec=spec)
         text_classifier.create(dataset, model_spec=spec)
    def test_bert_model(self):
        """Trains a BERT classifier (`trainable=False`) for one epoch.

        Checks its accuracy against a 0.5 threshold.
        """
        spec = ms.BertClassifierModelSpec(seq_len=2, trainable=False)
        dataset = text_dataloader.TextClassifierDataLoader.from_folder(
            self.text_dir, model_spec=spec)
        # Train on 90% of the examples and hold out the remaining 10%.
        self.train_data, self.test_data = dataset.split(0.9)

        classifier = text_classifier.create(
            self.train_data,
            model_spec=spec,
            epochs=1,
            batch_size=1,
            shuffle=True,
        )
        self._test_accuracy(classifier, 0.5)
Beispiel #5
0
    def test_mobilebert_model(self):
        """Trains a MobileBERT classifier on the tiny dataset and exports it.

        Covers accuracy, float TFLite export (with `atol=1e-2`) and quantized
        TFLite export.
        """
        spec = ms.mobilebert_classifier_spec(seq_len=2, trainable=False)
        dataset = text_dataloader.TextClassifierDataLoader.from_folder(
            self.tiny_text_dir, model_spec=spec)
        # Split the tiny dataset in half: 50% train, 50% test.
        self.train_data, self.test_data = dataset.split(0.5)

        classifier = text_classifier.create(
            self.train_data,
            model_spec=spec,
            epochs=1,
            batch_size=1,
            shuffle=True,
        )
        self._test_accuracy(classifier, 0.0)
        self._test_export_to_tflite(classifier, threshold=0.0, atol=1e-2)
        self._test_export_to_tflite_quant(classifier)
Beispiel #6
0
    def test_average_wordvec_model(self):
        """Trains an AverageWordVec classifier with TFLite as the export format.

        Checks accuracy, TFLite export and top-k prediction.
        """
        spec = ms.AverageWordVecModelSpec(seq_len=2)
        dataset = text_dataloader.TextClassifierDataLoader.from_folder(
            self.text_dir, model_spec=spec)
        # Train on 90% of the examples and hold out the remaining 10%.
        self.train_data, self.test_data = dataset.split(0.9)

        # The export format is passed positionally as `create`'s second argument.
        classifier = text_classifier.create(
            self.train_data,
            mef.ModelExportFormat.TFLITE,
            model_spec=spec,
            epochs=2,
            batch_size=4,
            shuffle=True,
        )
        self._test_accuracy(classifier)
        self._test_export_to_tflite(classifier)
        self._test_predict_top_k(classifier)
    def test_mobilebert_model(self):
        """Trains a MobileBERT classifier for one epoch and exercises export.

        Checks accuracy against a 0.5 threshold. It then runs TFLite export
        (without prediction-accuracy checking) and quantized TFLite export.
        """
        import copy  # Local import: the file's import block is not visible here.

        # `ms.mobilebert_classifier_spec` is a module-level object. Overriding
        # its attributes in place would leak seq_len/trainable into any other
        # code that uses it, so work on a shallow copy instead.
        model_spec = copy.copy(ms.mobilebert_classifier_spec)
        model_spec.seq_len = 2
        model_spec.trainable = False
        all_data = text_dataloader.TextClassifierDataLoader.from_folder(
            self.text_dir, model_spec=model_spec)
        # Splits data, 90% data for training, 10% for testing
        self.train_data, self.test_data = all_data.split(0.9)

        model = text_classifier.create(self.train_data,
                                       model_spec=model_spec,
                                       epochs=1,
                                       batch_size=1,
                                       shuffle=True)
        self._test_accuracy(model, 0.5)
        self._test_export_to_tflite(model, test_predict_accuracy=False)
        self._test_export_to_tflite_quant(model)
Beispiel #8
0
    def test_bert_model(self):
        """Trains a BERT classifier (`trainable=False`) on the tiny dataset.

        Checks accuracy, runs TFLite export against the
        'bert_classifier_metadata.json' expectation, and repeats the checks
        without training.
        """
        spec = ms.BertClassifierModelSpec(seq_len=2, trainable=False)
        dataset = text_dataloader.TextClassifierDataLoader.from_folder(
            self.tiny_text_dir, model_spec=spec)
        # Split the tiny dataset in half: 50% train, 50% test.
        self.train_data, self.test_data = dataset.split(0.5)

        classifier = text_classifier.create(
            self.train_data,
            model_spec=spec,
            epochs=1,
            batch_size=1,
            shuffle=True,
        )
        self._test_accuracy(classifier, 0.0)
        self._test_export_to_tflite(
            classifier,
            threshold=0.0,
            expected_json_file='bert_classifier_metadata.json',
        )
        self._test_model_without_training(spec)
    def test_average_wordvec_model(self):
        """Trains an AverageWordVec classifier for one epoch.

        Exercises accuracy, top-k prediction, TFLite and SavedModel export,
        label and vocab export, and finally the untrained-model checks.
        """
        model_spec = ms.AverageWordVecModelSpec(seq_len=2)
        all_data = text_dataloader.TextClassifierDataLoader.from_folder(
            self.text_dir, model_spec=model_spec)
        # Splits data, 50% data for training, 50% for testing
        self.train_data, self.test_data = all_data.split(0.5)

        model = text_classifier.create(self.train_data,
                                       model_spec=model_spec,
                                       epochs=1,
                                       batch_size=1,
                                       shuffle=True)
        self._test_accuracy(model, threshold=0.0)
        self._test_predict_top_k(model)
        self._test_export_to_tflite(model, threshold=0.0)
        self._test_export_to_saved_model(model)
        self._test_export_labels(model)
        self._test_export_vocab(model)
        self._test_model_without_training(model_spec)
def main(_):
    """Downloads SST-2, fine-tunes a BERT classifier, evaluates and exports it.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    logging.set_verbosity(logging.INFO)

    # NOTE(review): the URL serves a zip archive although the cached file is
    # named 'sst.tar.gz'. `get_file`'s default archive_format='auto' presumably
    # still detects it during extraction; confirm before renaming, because
    # changing `fname` changes the cache path.
    data_path = tf.keras.utils.get_file(
        fname='sst.tar.gz',
        origin=
        'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
        extract=True)
    data_path = os.path.join(os.path.dirname(data_path), 'SST-2')

    # Chooses model specification that represents model.
    model_spec = BertModelSpec()

    def _load_split(file_name):
        """Loads one tab-separated SST-2 split from `data_path`."""
        return TextClassifierDataLoader.from_csv(
            filename=os.path.join(data_path, file_name),
            text_column='sentence',
            label_column='label',
            model_spec=model_spec,
            delimiter='\t')

    # Gets training data and validation data.
    train_data = _load_split('train.tsv')
    validation_data = _load_split('dev.tsv')

    # Fine-tunes the model.
    model = text_classifier.create(train_data,
                                   model_spec=model_spec,
                                   validation_data=validation_data)

    # Gets evaluation results; the first element (loss) is not used here.
    _, acc = model.evaluate(validation_data)
    print('Eval accuracy: %f' % acc)

    # Exports to TFLite format.
    model.export(FLAGS.tflite_filename, FLAGS.label_filename,
                 FLAGS.vocab_filename)
Beispiel #11
0
    def test_mobilebert_model(self):
        """Trains a MobileBERT classifier and checks that TFLite export fails.

        Both plain and quantized TFLite export must raise ValueError with
        exactly the expected message.
        """
        import copy  # Local import: the file's import block is not visible here.

        # `ms.mobilebert_classifier_spec` is a module-level object. Overriding
        # its attributes in place would leak seq_len/trainable into any other
        # code that uses it, so work on a shallow copy instead.
        model_spec = copy.copy(ms.mobilebert_classifier_spec)
        model_spec.seq_len = 2
        model_spec.trainable = False
        all_data = text_dataloader.TextClassifierDataLoader.from_folder(
            self.text_dir, model_spec=model_spec)
        # Splits data, 90% data for training, 10% for testing
        self.train_data, self.test_data = all_data.split(0.9)

        model = text_classifier.create(self.train_data,
                                       model_spec=model_spec,
                                       epochs=1,
                                       batch_size=1,
                                       shuffle=True)
        self._test_accuracy(model, 0.5)
        error_message = "Couldn't convert MobileBert to TFLite for now."
        with self.assertRaises(ValueError) as error:
            self._test_export_to_tflite(model, test_predict_accuracy=False)
        self.assertEqual(error_message, str(error.exception))

        with self.assertRaises(ValueError) as error:
            self._test_export_to_tflite_quant(model)
        self.assertEqual(error_message, str(error.exception))