Example #1
def main(_):

  models = {}
  datasets = {}

  if "sst2" in FLAGS.tasks:
    models["sst2"] = glue_models.SST2Model(
        os.path.join(FLAGS.models_path, "sst2"))
    datasets["sst_dev"] = glue.SST2Data("validation")
    logging.info("Loaded models and data for SST-2 task.")

  if "stsb" in FLAGS.tasks:
    models["stsb"] = glue_models.STSBModel(
        os.path.join(FLAGS.models_path, "stsb"))
    datasets["stsb_dev"] = glue.STSBData("validation")
    logging.info("Loaded models and data for STS-B task.")

  if "mnli" in FLAGS.tasks:
    models["mnli"] = glue_models.MNLIModel(
        os.path.join(FLAGS.models_path, "mnli"))
    datasets["mnli_dev"] = glue.MNLIData("validation_matched")
    datasets["mnli_dev_mm"] = glue.MNLIData("validation_mismatched")
    logging.info("Loaded models and data for MultiNLI task.")

  # Truncate datasets if --max_examples is set.
  for name in datasets:
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info("  truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
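
This snippet references absl-style flags (FLAGS.tasks, FLAGS.models_path, FLAGS.max_examples) that are defined elsewhere in the demo script. A minimal sketch of flag definitions that would satisfy those references, assuming absl.flags; the defaults and help strings here are assumptions, not the original demo's values:

from absl import flags

FLAGS = flags.FLAGS

# Flag names taken from the snippet above; defaults are placeholders.
flags.DEFINE_list("tasks", ["sst2", "stsb", "mnli"],
                  "Which GLUE tasks to load models and data for.")
flags.DEFINE_string("models_path", None,
                    "Directory containing one subdirectory per fine-tuned task model.")
flags.DEFINE_integer("max_examples", None,
                     "If set, truncate each dataset to this many examples.")
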
Example #2
    def setUp(self):
        super(ModelBasedHotflipTest, self).setUp()
        self.hotflip = hotflip.HotFlip()

        # Classification model that classifies a given input sentence.
        self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
        self.classification_config = {hotflip.PREDICTION_KEY: 'probas'}

        # A wrapped version of the classification model that does not expose
        # embeddings.
        self.classification_model_without_embeddings = SSTModelWithoutEmbeddings(
            BERT_TINY_PATH)

        # A wrapped version of the classification model that does not take tokens
        # as input.
        self.classification_model_without_tokens = SSTModelWithoutTokens(
            BERT_TINY_PATH)

        # A wrapped version of the classification model that does not expose
        # gradients.
        self.classification_model_without_gradients = SSTModelWithoutGradients(
            BERT_TINY_PATH)

        # Regression model determining similarity between two input sentences.
        self.regression_model = glue_models.STSBModel(STSB_PATH)
        self.regression_config = {hotflip.PREDICTION_KEY: 'score'}
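
The wrapped variants referenced above (SSTModelWithoutEmbeddings, SSTModelWithoutTokens, SSTModelWithoutGradients) are test doubles defined elsewhere in the test module. A minimal sketch of one of them, assuming LIT models expose their outputs through an output_spec() method and that gradient fields can be identified by name; the filtering rule is an assumption for illustration:

class SSTModelWithoutGradients(glue_models.SST2Model):
    """Test double that hides gradient outputs from the base SST-2 model."""

    def output_spec(self):
        spec = super().output_spec()
        # Drop gradient fields so HotFlip sees a model without gradients.
        return {k: v for k, v in spec.items() if "grad" not in k.lower()}
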
Example #3
def main(_):

    ##
    # Pick the model and datasets
    # TODO(lit-dev): add remaining GLUE tasks? These three cover all the major
    # features (single segment, two segment, classification, regression).
    if FLAGS.task == "sst2":
        train_data = glue.SST2Data("train")
        val_data = glue.SST2Data("validation")
        model = glue_models.SST2Model(FLAGS.encoder_name)
    elif FLAGS.task == "mnli":
        train_data = glue.MNLIData("train")
        val_data = glue.MNLIData("validation_matched")
        model = glue_models.MNLIModel(FLAGS.encoder_name)
    elif FLAGS.task == "stsb":
        train_data = glue.STSBData("train")
        val_data = glue.STSBData("validation")
        model = glue_models.STSBModel(FLAGS.encoder_name)
    elif FLAGS.task == "toxicity":
        train_data = classification.ToxicityData("train")
        val_data = classification.ToxicityData("test")
        model = glue_models.ToxicityModel(FLAGS.encoder_name)
    else:
        raise ValueError(f"Unrecognized task name: '{FLAGS.task:s}'")

    ##
    # Run training and save model.
    train_and_save(model,
                   train_data,
                   val_data,
                   FLAGS.train_path,
                   save_intermediates=FLAGS.save_intermediates,
                   num_epochs=FLAGS.num_epochs)
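
train_and_save is a helper defined earlier in the same training script; a rough sketch of its shape, based only on the call site above and assuming the GlueModel train/save interface (the keyword names and the handling of save_intermediates are assumptions):

def train_and_save(model, train_data, val_data, train_path,
                   save_intermediates=False, **train_kw):
    """Fine-tune the model and write the result under train_path."""
    # save_intermediates would add per-epoch checkpointing; omitted in this sketch.
    model.train(train_data.examples,
                validation_inputs=val_data.examples,
                **train_kw)
    model.save(train_path)
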
Example #4
    def setUp(self):
        super(ModelBasedAblationFlipTest, self).setUp()
        self.ablation_flip = ablation_flip.AblationFlip()

        # Classification model that classifies a given input sentence.
        self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
        self.classification_config = {ablation_flip.PREDICTION_KEY: 'probas'}

        # Classification model with the 'sentence' field marked as
        # non-required.
        self.classification_model_non_required_field = SST2ModelNonRequiredField(
            BERT_TINY_PATH)

        # Classification model with a counter that tracks the number of predict calls.
        # TODO(ataly): Consider setting up a Mock object to count number of
        # predict calls.
        self.classification_model_with_predict_counter = (
            SST2ModelWithPredictCounter(BERT_TINY_PATH))

        # Regression model determining similarity between two input sentences.
        self.regression_model = glue_models.STSBModel(STSB_PATH)
        self.regression_config = {ablation_flip.PREDICTION_KEY: 'score'}
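
SST2ModelWithPredictCounter is defined elsewhere in the test module (and is what the TODO suggests replacing with a mock). A minimal sketch, assuming predict_minibatch is the batch-level prediction hook of LIT models; the attribute name predict_counter is an assumption:

class SST2ModelWithPredictCounter(glue_models.SST2Model):
    """Test double that counts how many times prediction is run."""

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        self.predict_counter = 0

    def predict_minibatch(self, inputs, **kw):
        # Increment the counter on every batch-level predict call.
        self.predict_counter += 1
        return super().predict_minibatch(inputs, **kw)
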