Esempio n. 1
0
def run_finetuning(train_path):
    """Fine-tune an SST-2 classifier and save it.

    Loads the SST-2 "train" and "validation" splits, builds a trainable
    ``SST2Model`` from the encoder named by ``--encoder_name``, trains it with
    the validation split passed as ``validation_inputs``, and saves the result.

    Args:
      train_path: Location handed to ``model.save`` for the fine-tuned model.
    """
    # Construct "train" first, then "validation" (dicts keep insertion order).
    splits = {split: glue.SST2Data(split) for split in ("train", "validation")}
    classifier = glue_models.SST2Model(FLAGS.encoder_name, for_training=True)
    classifier.train(splits["train"].examples,
                     validation_inputs=splits["validation"].examples)
    classifier.save(train_path)
Esempio n. 2
0
def main(_):
    """Fine-tune the model selected by ``--task`` and save it to ``--train_path``.

    Raises:
      ValueError: if ``--task`` is not one of "sst2", "mnli", "stsb" or
        "toxicity".
    """
    ##
    # Pick the dataset class, its validation split, and the model class.
    # TODO(lit-dev): add remaining GLUE tasks? The current tasks cover all the
    # major features (single segment, two segment, classification, regression).
    task = FLAGS.task
    if task == "sst2":
        data_cls, val_split = glue.SST2Data, "validation"
        model_cls = glue_models.SST2Model
    elif task == "mnli":
        data_cls, val_split = glue.MNLIData, "validation_matched"
        model_cls = glue_models.MNLIModel
    elif task == "stsb":
        data_cls, val_split = glue.STSBData, "validation"
        model_cls = glue_models.STSBModel
    elif task == "toxicity":
        # Toxicity validates on its "test" split.
        data_cls, val_split = classification.ToxicityData, "test"
        model_cls = glue_models.ToxicityModel
    else:
        raise ValueError(f"Unrecognized task name: '{FLAGS.task:s}'")

    # Same construction order as before: train data, validation data, model.
    train_data = data_cls("train")
    val_data = data_cls(val_split)
    model = model_cls(FLAGS.encoder_name)

    ##
    # Train, then persist the model (optionally with intermediate checkpoints).
    train_and_save(
        model,
        train_data,
        val_data,
        FLAGS.train_path,
        save_intermediates=FLAGS.save_intermediates,
        num_epochs=FLAGS.num_epochs,
    )
Esempio n. 3
0
def main(_):
  """Load models and dev sets for each task in ``--tasks`` and serve LIT.

  Each task's model is loaded from ``<--models_path>/<task>``. Every dataset
  is then sliced to at most ``--max_examples`` examples before the server
  starts.
  """
  models = {}
  datasets = {}

  if "sst2" in FLAGS.tasks:
    sst2_dir = os.path.join(FLAGS.models_path, "sst2")
    models["sst2"] = glue_models.SST2Model(sst2_dir)
    datasets["sst_dev"] = glue.SST2Data("validation")
    logging.info("Loaded models and data for SST-2 task.")

  if "stsb" in FLAGS.tasks:
    stsb_dir = os.path.join(FLAGS.models_path, "stsb")
    models["stsb"] = glue_models.STSBModel(stsb_dir)
    datasets["stsb_dev"] = glue.STSBData("validation")
    logging.info("Loaded models and data for STS-B task.")

  if "mnli" in FLAGS.tasks:
    mnli_dir = os.path.join(FLAGS.models_path, "mnli")
    models["mnli"] = glue_models.MNLIModel(mnli_dir)
    # MultiNLI contributes both its matched and mismatched dev splits.
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
    logging.info("Loaded models and data for MultiNLI task.")

  # Truncate datasets if --max_examples is set. Only values are replaced, so
  # iterating over a snapshot of the items is safe.
  for dataset_name, full_dataset in list(datasets.items()):
    logging.info("Dataset: '%s' with %d examples", dataset_name,
                 len(full_dataset))
    truncated = full_dataset.slice[:FLAGS.max_examples]
    datasets[dataset_name] = truncated
    logging.info("  truncated to %d examples", len(truncated))

  # Start the LIT server. See server_flags.py for server options.
  server = dev_server.Server(models, datasets, **server_flags.get_flags())
  server.serve()
Esempio n. 4
0
    def setUp(self):
        """Create the HotFlip generator plus the models and configs it is tested on."""
        super().setUp()
        self.hotflip = hotflip.HotFlip()

        # Classifier over a single input sentence; predictions read from 'probas'.
        self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
        self.classification_config = {hotflip.PREDICTION_KEY: 'probas'}

        # Wrapped variants of the same classifier, each lacking one feature:
        # no exposed embeddings, no token inputs, and no exposed gradients.
        self.classification_model_without_embeddings = (
            SSTModelWithoutEmbeddings(BERT_TINY_PATH))
        self.classification_model_without_tokens = (
            SSTModelWithoutTokens(BERT_TINY_PATH))
        self.classification_model_without_gradients = (
            SSTModelWithoutGradients(BERT_TINY_PATH))

        # Regression model scoring the similarity of two input sentences;
        # predictions read from 'score'.
        self.regression_model = glue_models.STSBModel(STSB_PATH)
        self.regression_config = {hotflip.PREDICTION_KEY: 'score'}
Esempio n. 5
0
def main(_):
    """Fine-tune an SST-2 model, then serve it in LIT with the SST-2 dev set.

    The model goes to ``--model_path`` if set, otherwise to a fresh
    temporary directory.
    """
    if FLAGS.model_path:
        model_path = FLAGS.model_path
    else:
        model_path = tempfile.mkdtemp()
    logging.info("Working directory: %s", model_path)
    run_finetuning(model_path)

    # Reload the freshly trained model from disk for serving.
    models = {"sst": glue_models.SST2Model(model_path)}
    datasets = {"sst_dev": glue.SST2Data("validation")}

    # Start the LIT server. See server_flags.py for server options.
    server = dev_server.Server(models, datasets, **server_flags.get_flags())
    server.serve()
Esempio n. 6
0
    def setUp(self):
        """Build a cached tiny SST-2 model, a 9-example dataset, and its predictions."""
        super().setUp()
        self.thresholder = thresholder.Thresholder()
        self.model = caching.CachingModelWrapper(
            glue_models.SST2Model(BERT_TINY_PATH), 'test')

        # Nine one-letter sentences: 'a'..'e' are labeled '1', 'f'..'i' '0'.
        sentences = 'abcdefghi'
        labels = '111110000'
        examples = [{'sentence': sentence, 'label': label}
                    for sentence, label in zip(sentences, labels)]

        # Key each example by its content hash, the same function the dataset
        # below uses as its id_fn.
        self.indexed_inputs = [{'id': caching.input_hash(example), 'data': example}
                               for example in examples]
        dataset_spec = {
            'sentence': lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=['0', '1']),
        }
        self.dataset = lit_dataset.IndexedDataset(
            id_fn=caching.input_hash,
            spec=dataset_spec,
            indexed_examples=self.indexed_inputs)
        predictions = self.model.predict_with_metadata(
            self.indexed_inputs, dataset_name='test')
        self.model_outputs = list(predictions)
Esempio n. 7
0
    def test_sst2_model_predict(self):
        """Smoke test: SST2Model predicts once, with every output-spec key present."""
        # The tiny checkpoint is a .tar.gz archive, so resolve it through
        # transformers' cached_path with extraction enabled.
        model_url = ("https://storage.googleapis.com/what-if-tool-resources/"
                     "lit-models/sst2_tiny.tar.gz")
        model_path = model_url
        if model_url.endswith(".tar.gz"):
            model_path = transformers.file_utils.cached_path(
                model_url, extract_compressed_file=True)
        model = glue_models.SST2Model(model_path)

        # Prediction on one sentence must complete without error...
        outputs = list(model.predict([{"sentence": "test sentence"}]))

        # ...and produce exactly one result covering the declared output spec.
        self.assertLen(outputs, 1)
        prediction = outputs[0]
        for spec_key in model.output_spec().keys():
            self.assertIn(spec_key, prediction.keys())
Esempio n. 8
0
    def setUp(self):
        """Create the AblationFlip generator plus the models and configs it is tested on."""
        super().setUp()
        self.ablation_flip = ablation_flip.AblationFlip()

        # Classifier over a single input sentence; predictions read from 'probas'.
        self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
        self.classification_config = {ablation_flip.PREDICTION_KEY: 'probas'}

        # Classifier variant whose 'sentence' input field is marked non-required.
        self.classification_model_non_required_field = (
            SST2ModelNonRequiredField(BERT_TINY_PATH))

        # Classifier variant that counts how many times predict is called.
        # TODO(ataly): Consider setting up a Mock object to count number of
        # predict calls.
        self.classification_model_with_predict_counter = (
            SST2ModelWithPredictCounter(BERT_TINY_PATH))

        # Regression model scoring the similarity of two input sentences;
        # predictions read from 'score'.
        self.regression_model = glue_models.STSBModel(STSB_PATH)
        self.regression_config = {ablation_flip.PREDICTION_KEY: 'score'}
Esempio n. 9
0
 def setUp(self):
   """Create a TCAV interpreter and a cached tiny SST-2 model named 'test'."""
   super().setUp()
   self.tcav = tcav.TCAV()
   base_model = glue_models.SST2Model(BERT_TINY_PATH)
   self.model = caching.CachingModelWrapper(base_model, 'test')
Esempio n. 10
0
 def setUp(self):
   """Create a TCAV interpreter and a tiny SST-2 model for the tests."""
   super(TCAVTest, self).setUp()
   self.tcav = tcav.TCAV()
   # Plain SST2Model loaded from BERT_TINY_PATH (no caching wrapper here).
   self.model = glue_models.SST2Model(BERT_TINY_PATH)