def run_finetuning(train_path):
    """Fine-tune an SST-2 sentiment model and save it under `train_path`."""
    # Build the trainable encoder first, then load the SST-2 splits.
    model = glue_models.SST2Model(FLAGS.encoder_name, for_training=True)
    sst_train = glue.SST2Data("train")
    sst_val = glue.SST2Data("validation")
    model.train(sst_train.examples, validation_inputs=sst_val.examples)
    model.save(train_path)
def main(_):
    """Train the model for the task named by --task, then save it."""
    ##
    # Pick the model and datasets via a lazy dispatch table so only the
    # selected task's objects are ever constructed.
    # TODO(lit-dev): add remaining GLUE tasks? These three cover all the major
    # features (single segment, two segment, classification, regression).
    task_factories = {
        "sst2": (lambda: glue.SST2Data("train"),
                 lambda: glue.SST2Data("validation"),
                 lambda: glue_models.SST2Model(FLAGS.encoder_name)),
        "mnli": (lambda: glue.MNLIData("train"),
                 lambda: glue.MNLIData("validation_matched"),
                 lambda: glue_models.MNLIModel(FLAGS.encoder_name)),
        "stsb": (lambda: glue.STSBData("train"),
                 lambda: glue.STSBData("validation"),
                 lambda: glue_models.STSBModel(FLAGS.encoder_name)),
        "toxicity": (lambda: classification.ToxicityData("train"),
                     lambda: classification.ToxicityData("test"),
                     lambda: glue_models.ToxicityModel(FLAGS.encoder_name)),
    }
    if FLAGS.task not in task_factories:
        raise ValueError(f"Unrecognized task name: '{FLAGS.task:s}'")
    make_train, make_val, make_model = task_factories[FLAGS.task]
    train_data = make_train()
    val_data = make_val()
    model = make_model()

    ##
    # Run training and save model.
    train_and_save(model, train_data, val_data, FLAGS.train_path,
                   save_intermediates=FLAGS.save_intermediates,
                   num_epochs=FLAGS.num_epochs)
def main(_):
    """Load pre-trained GLUE models and dev sets, then start a LIT server."""
    models = {}
    datasets = {}
    model_root = FLAGS.models_path

    if "sst2" in FLAGS.tasks:
        datasets["sst_dev"] = glue.SST2Data("validation")
        models["sst2"] = glue_models.SST2Model(
            os.path.join(model_root, "sst2"))
        logging.info("Loaded models and data for SST-2 task.")

    if "stsb" in FLAGS.tasks:
        datasets["stsb_dev"] = glue.STSBData("validation")
        models["stsb"] = glue_models.STSBModel(
            os.path.join(model_root, "stsb"))
        logging.info("Loaded models and data for STS-B task.")

    if "mnli" in FLAGS.tasks:
        datasets["mnli_dev"] = glue.MNLIData("validation_matched")
        datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
        models["mnli"] = glue_models.MNLIModel(
            os.path.join(model_root, "mnli"))
        logging.info("Loaded models and data for MultiNLI task.")

    # Truncate datasets if --max_examples is set.
    for name, dataset in datasets.items():
        logging.info("Dataset: '%s' with %d examples", name, len(dataset))
        datasets[name] = dataset.slice[:FLAGS.max_examples]
        logging.info(" truncated to %d examples", len(datasets[name]))

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    lit_demo.serve()
def setUp(self):
    """Instantiate the HotFlip component and all model fixtures."""
    super(ModelBasedHotflipTest, self).setUp()
    self.hotflip = hotflip.HotFlip()

    # Classification model that classifies a given input sentence.
    self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
    self.classification_config = {hotflip.PREDICTION_KEY: 'probas'}

    # Wrapped variants of the classification model, each lacking one
    # capability HotFlip needs: embeddings, token inputs, or gradients.
    self.classification_model_without_embeddings = SSTModelWithoutEmbeddings(
        BERT_TINY_PATH)
    self.classification_model_without_tokens = SSTModelWithoutTokens(
        BERT_TINY_PATH)
    self.classification_model_without_gradients = SSTModelWithoutGradients(
        BERT_TINY_PATH)

    # Regression model determining similarity between two input sentences.
    self.regression_model = glue_models.STSBModel(STSB_PATH)
    self.regression_config = {hotflip.PREDICTION_KEY: 'score'}
def main(_):
    """Fine-tune an SST-2 model, then serve it in a LIT demo."""
    model_path = FLAGS.model_path or tempfile.mkdtemp()
    logging.info("Working directory: %s", model_path)
    run_finetuning(model_path)

    # Serve the freshly trained model on the SST-2 dev set.
    # See server_flags.py for server options.
    lit_demo = dev_server.Server(
        {"sst": glue_models.SST2Model(model_path)},
        {"sst_dev": glue.SST2Data("validation")},
        **server_flags.get_flags())
    lit_demo.serve()
def setUp(self):
    """Build a cached SST-2 model and a nine-example indexed dataset."""
    super(ThresholderTest, self).setUp()
    self.thresholder = thresholder.Thresholder()
    self.model = caching.CachingModelWrapper(
        glue_models.SST2Model(BERT_TINY_PATH), 'test')

    # Single-letter sentences: five labeled '1' followed by four labeled '0'.
    labels = ['1'] * 5 + ['0'] * 4
    examples = [{
        'sentence': sentence,
        'label': label
    } for sentence, label in zip('abcdefghi', labels)]
    self.indexed_inputs = [{
        'id': caching.input_hash(ex),
        'data': ex
    } for ex in examples]

    self.dataset = lit_dataset.IndexedDataset(
        id_fn=caching.input_hash,
        spec={
            'sentence': lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=['0', '1'])
        },
        indexed_examples=self.indexed_inputs)
    self.model_outputs = list(
        self.model.predict_with_metadata(
            self.indexed_inputs, dataset_name='test'))
def test_sst2_model_predict(self):
    """Smoke-tests SST2Model.predict and checks keys against the output spec."""
    # Download (and extract) the tiny SST-2 checkpoint.
    model_path = "https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz"  # pylint: disable=line-too-long
    if model_path.endswith(".tar.gz"):
        model_path = transformers.file_utils.cached_path(
            model_path, extract_compressed_file=True)
    model = glue_models.SST2Model(model_path)

    # Run prediction to ensure no failure.
    outputs = list(model.predict([{"sentence": "test sentence"}]))

    # Sanity-check output vs output spec.
    self.assertLen(outputs, 1)
    for key in model.output_spec().keys():
        self.assertIn(key, outputs[0].keys())
def setUp(self):
    """Instantiate the AblationFlip component and all model fixtures."""
    super(ModelBasedAblationFlipTest, self).setUp()
    self.ablation_flip = ablation_flip.AblationFlip()

    # Classification model that classifies a given input sentence.
    self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
    self.classification_config = {ablation_flip.PREDICTION_KEY: 'probas'}

    # Classification model with the 'sentence' field marked as non-required.
    self.classification_model_non_required_field = SST2ModelNonRequiredField(
        BERT_TINY_PATH)

    # Classification model with a counter to count number of predict calls.
    # TODO(ataly): Consider setting up a Mock object to count number of
    # predict calls.
    self.classification_model_with_predict_counter = (
        SST2ModelWithPredictCounter(BERT_TINY_PATH))

    # Regression model determining similarity between two input sentences.
    self.regression_model = glue_models.STSBModel(STSB_PATH)
    self.regression_config = {ablation_flip.PREDICTION_KEY: 'score'}
def setUp(self):
    """Create the TCAV component and a cached tiny-BERT SST-2 model."""
    super(ModelBasedTCAVTest, self).setUp()
    # Wrap the model in a cache so repeated predictions are cheap.
    self.model = caching.CachingModelWrapper(
        glue_models.SST2Model(BERT_TINY_PATH), 'test')
    self.tcav = tcav.TCAV()
def setUp(self):
    """Create the TCAV component and an (uncached) tiny-BERT SST-2 model."""
    super(TCAVTest, self).setUp()
    self.model = glue_models.SST2Model(BERT_TINY_PATH)
    self.tcav = tcav.TCAV()