Example #1
File: mt.py Project: oceanfly/lit
 def spec(self) -> Spec:
     return {
         'source': lit_types.TextSegment(),
         'source_language': lit_types.CategoryLabel(),
         'target': lit_types.TextSegment(),
         'target_language': lit_types.CategoryLabel(),
     }
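Every LIT dataset exposes a spec() like this one: a dict mapping field names to LitType instances describing what each example carries. Below is a minimal sketch of the surrounding class (the class name and example data are hypothetical; imports follow the lit_nlp.api layout used throughout these examples):

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types

class TranslationData(lit_dataset.Dataset):
    """Hypothetical dataset pairing examples with the spec above."""

    def __init__(self):
        # Each example supplies a value for every field declared in spec().
        self._examples = [{
            'source': 'Hello world',
            'source_language': 'en',
            'target': 'Hallo Welt',
            'target_language': 'de',
        }]

    def spec(self) -> lit_types.Spec:
        return {
            'source': lit_types.TextSegment(),
            'source_language': lit_types.CategoryLabel(),
            'target': lit_types.TextSegment(),
            'target_language': lit_types.CategoryLabel(),
        }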
Example #2
 def spec(self):
     return {
         "premise": lit_types.TextSegment(),
         "hypothesis": lit_types.TextSegment(),
         "label": lit_types.CategoryLabel(vocab=self.LABELS),
         "language": lit_types.CategoryLabel(),
     }
Example #3
 def spec(self):
   return {
       'sentence': lit_types.TextSegment(),
       'review_title': lit_types.TextSegment(),
       'product_name': lit_types.TextSegment(),
       'label': lit_types.CategoryLabel(vocab=self.LABELS)
   }
Example #4
 def spec(self) -> lit_types.Spec:
     return {
         "premise": lit_types.TextSegment(),
         "hypothesis": lit_types.TextSegment(),
         # 'label2' for 2-way NLI labels
         "label2": lit_types.CategoryLabel(vocab=self.LABELS),
         "heuristic": lit_types.CategoryLabel(),
         "template": lit_types.CategoryLabel(),
     }
Example #5
 def test_compatibility_fullmatch(self):
     """Test with an exact match."""
     mspec = model.ModelSpec(
         input={
             "text_a": types.TextSegment(),
             "text_b": types.TextSegment(),
         },
         output={})
     dspec = mspec.input
     self.assertTrue(mspec.is_compatible_with_dataset(dspec))
Example #6
 def spec(self) -> lit_types.Spec:
     """Should match MnliModel's input_spec()."""
     return {
         "premise": lit_types.TextSegment(),
         "hypothesis": lit_types.TextSegment(),
         # 'label' for 3-way NLI labels, 'label2' for binarized.
         "label": lit_types.CategoryLabel(vocab=self.LABELS3),
         "label2": lit_types.CategoryLabel(vocab=self.LABELS2),
         "genre": lit_types.CategoryLabel(),
     }
Example #7
 def config_spec(self) -> types.Spec:
     return {
         CLASS_KEY: types.TextSegment(default=str(CLASS_DEFAULT)),
         KERNEL_WIDTH_KEY:
             types.TextSegment(default=str(KERNEL_WIDTH_DEFAULT)),
         MASK_KEY: types.TextSegment(default=MASK_DEFAULT),
         NUM_SAMPLES_KEY:
             types.TextSegment(default=str(NUM_SAMPLES_DEFAULT)),
         SEED_KEY: types.TextSegment(default=SEED_DEFAULT),
     }
Example #8
 def input_spec(self) -> Spec:
   ret = {}
   ret[self.config.text_a_name] = lit_types.TextSegment()
   if self.config.text_b_name:
     ret[self.config.text_b_name] = lit_types.TextSegment()
   if self.is_regression:
     ret[self.config.label_name] = lit_types.RegressionScore(required=False)
   else:
     ret[self.config.label_name] = lit_types.CategoryLabel(
         required=False, vocab=self.config.labels)
   return ret
Example #9
 def spec(self) -> lit_types.Spec:
     """Should match MLM's input_spec()."""
     return {
         'input_text': lit_types.TextSegment(),
         'target_text': lit_types.TextSegment(),
         'input_tokens': lit_types.Tokens(required=False),
         'gece_tags': lit_types.SequenceTags(align='input_tokens',
                                             required=False),
     }
Example #10
 def test_compatibility_mismatch(self):
     """Test with specs that don't match."""
     mspec = model.ModelSpec(
         input={
             "text_a": types.TextSegment(),
             "text_b": types.TextSegment(),
         },
         output={})
     dspec = {
         "premise": types.TextSegment(),
         "hypothesis": types.TextSegment()
     }
     self.assertFalse(mspec.is_compatible_with_dataset(dspec))
Example #11
 def test_compatibility_extrafield(self):
     """Test with an extra field in the dataset."""
     mspec = model.ModelSpec(
         input={
             "text_a": types.TextSegment(),
             "text_b": types.TextSegment(),
         },
         output={})
     dspec = {
         "text_a": types.TextSegment(),
         "text_b": types.TextSegment(),
         "label": types.CategoryLabel(vocab=["0", "1"]),
     }
     self.assertTrue(mspec.is_compatible_with_dataset(dspec))
Example #12
 def config_spec(self) -> types.Spec:
     matcher_types = [
         'MulticlassPreds', 'SparseMultilabelPreds', 'RegressionScore'
     ]
     return {
         TARGET_HEAD_KEY: types.FieldMatcher(spec='output',
                                             types=matcher_types),
         CLASS_KEY: types.TextSegment(default='-1'),
         MASK_KEY: types.TextSegment(default='[MASK]'),
         KERNEL_WIDTH_KEY: types.TextSegment(default='256'),
         NUM_SAMPLES_KEY: types.TextSegment(default='256'),
         SEED_KEY: types.TextSegment(default=''),
     }
Example #13
 def test_find_spec_keys(self):
   spec = {
       "score": types.RegressionScore(),
       "scalar_foo": types.Scalar(),
       "text": types.TextSegment(),
       "emb_0": types.Embeddings(),
       "emb_1": types.Embeddings(),
       "tokens": types.Tokens(),
       "generated_text": types.GeneratedText(),
   }
   self.assertEqual(["score"], utils.find_spec_keys(spec,
                                                    types.RegressionScore))
   self.assertEqual(["text", "tokens", "generated_text"],
                    utils.find_spec_keys(spec,
                                         (types.TextSegment, types.Tokens)))
   self.assertEqual(["emb_0", "emb_1"],
                    utils.find_spec_keys(spec, types.Embeddings))
   self.assertEqual([], utils.find_spec_keys(spec, types.AttentionHeads))
   # Check subclasses
   self.assertEqual(
       list(spec.keys()), utils.find_spec_keys(spec, types.LitType))
   self.assertEqual(["text", "generated_text"],
                    utils.find_spec_keys(spec, types.TextSegment))
   self.assertEqual(["score", "scalar_foo"],
                    utils.find_spec_keys(spec, types.Scalar))
Example #14
 def test_compatibility_optionals(self):
     """Test with optionals in the model spec."""
     mspec = model.ModelSpec(
         input={
             "text": types.TextSegment(),
             "tokens": types.Tokens(parent="text", required=False),
             "label": types.CategoryLabel(vocab=["0", "1"], required=False),
         },
         output={})
     dspec = {
         "text": types.TextSegment(),
         "label": types.CategoryLabel(vocab=["0", "1"]),
     }
     self.assertTrue(mspec.is_compatible_with_dataset(dspec))
Example #15
 def config_spec(self) -> types.Spec:
     return {
         NUM_EXAMPLES_KEY:
             types.TextSegment(default=str(NUM_EXAMPLES_DEFAULT)),
         MAX_ABLATIONS_KEY:
             types.TextSegment(default=str(MAX_ABLATIONS_DEFAULT)),
         PREDICTION_KEY: types.FieldMatcher(
             spec="output", types=["MulticlassPreds", "RegressionScore"]),
         REGRESSION_THRESH_KEY:
             types.TextSegment(default=str(REGRESSION_THRESH_DEFAULT)),
         FIELDS_TO_ABLATE_KEY: types.MultiFieldMatcher(
             spec="input", types=["TextSegment", "SparseMultilabel"],
             select_all=True),
     }
Example #16
    def __init__(self, model, tasks):
        """Initialize with Stanza model and a dictionary of tasks.

    Args:
      model: A Stanza model
      tasks: A dictionary of tasks, grouped by task type.
        Keys are the grouping, which should be one of:
          ('sequence', 'span', 'edge').
        Values are a list of stanza task names as strings.
    """
        self.model = model
        # Store lists of task name strings by grouping
        self.sequence_tasks = tasks["sequence"]
        self.span_tasks = tasks["span"]
        self.edge_tasks = tasks["edge"]

        self._input_spec = {
            "sentence": lit_types.TextSegment(),
        }

        self._output_spec = {
            "tokens": lit_types.Tokens(),
        }

        # Output spec based on specified tasks
        for task in self.sequence_tasks:
            self._output_spec[task] = lit_types.SequenceTags(align="tokens")
        for task in self.span_tasks:
            self._output_spec[task] = lit_types.SpanLabels(align="tokens")
        for task in self.edge_tasks:
            self._output_spec[task] = lit_types.EdgeLabels(align="tokens")
Example #17
 def spec(self):
     return {
         "text": lit_types.TextSegment(),
         "tokens": lit_types.Tokens(parent="text"),
         "coref": lit_types.EdgeLabels(align="tokens"),
         # Metadata fields for filtering and analysis.
         "occupation": lit_types.CategoryLabel(),
         "participant": lit_types.CategoryLabel(),
         "answer": lit_types.CategoryLabel(vocab=ANSWER_VOCAB),
         "someone": lit_types.CategoryLabel(vocab=["True", "False"]),
         "pronouns": lit_types.CategoryLabel(
             vocab=list(PRONOUNS_BY_GENDER.values())),
         "pronoun_type": lit_types.CategoryLabel(
             vocab=["NOM", "POSS", "ACC"]),
         "gender": lit_types.CategoryLabel(vocab=[g.name for g in Gender]),
         "pf_bls": lit_types.Scalar(),
     }
Example #18
 def test_compatibility_optionals_mismatch(self):
     """Test with optionals that don't match metadata."""
     mspec = model.ModelSpec(
         input={
             "text": types.TextSegment(),
             "tokens": types.Tokens(parent="text", required=False),
             "label": types.CategoryLabel(vocab=["0", "1"], required=False),
         },
         output={})
     dspec = {
         "text": types.TextSegment(),
         # This label field doesn't match the one the model expects.
         "label": types.CategoryLabel(vocab=["foo", "bar"]),
     }
     self.assertFalse(mspec.is_compatible_with_dataset(dspec))
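Taken together, the compatibility tests above (Examples #5, #10, #11, #14, and #18) pin down the rules: every required model input must appear in the dataset spec with a matching type, extra dataset fields are ignored, and optional (required=False) model inputs may be absent but must match if present. A condensed sketch, assuming the same model and types imports as in those tests:

mspec = model.ModelSpec(
    input={
        "text": types.TextSegment(),
        "label": types.CategoryLabel(vocab=["0", "1"], required=False),
    },
    output={})

# Extra dataset fields and missing optional fields are both accepted.
assert mspec.is_compatible_with_dataset({
    "text": types.TextSegment(),
    "extra": types.TextSegment(),
})

# A required field under a different name (or an optional field with a
# mismatched vocab) is rejected.
assert not mspec.is_compatible_with_dataset({
    "sentence": types.TextSegment(),
})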
Example #19
 def config_spec(self) -> types.Spec:
     return {
         CLASS_KEY: types.TextSegment(default=''),
         NORMALIZATION_KEY: types.Boolean(default=True),
         INTERPOLATION_KEY: types.Scalar(min_val=5, max_val=100, default=30,
                                         step=1),
     }
Example #20
 def config_spec(self) -> types.Spec:
     return {
         NUM_EXAMPLES_KEY:
             types.TextSegment(default=str(NUM_EXAMPLES_DEFAULT)),
         MAX_FLIPS_KEY: types.TextSegment(default=str(MAX_FLIPS_DEFAULT)),
         TOKENS_TO_IGNORE_KEY: types.Tokens(default=TOKENS_TO_IGNORE_DEFAULT),
         PREDICTION_KEY: types.FieldMatcher(
             spec="output", types=["MulticlassPreds", "RegressionScore"]),
         REGRESSION_THRESH_KEY:
             types.TextSegment(default=str(REGRESSION_THRESH_DEFAULT)),
         FIELDS_TO_HOTFLIP_KEY: types.MultiFieldMatcher(
             spec="input", types=["Tokens"], select_all=True),
     }
Example #21
 def config_spec(self) -> types.Spec:
     return {
         # Requires a substitution string. Include a default.
         SUBSTITUTIONS_KEY: types.TextSegment(default='great -> terrible'),
         FIELDS_TO_REPLACE_KEY: types.MultiFieldMatcher(
             spec='input', types=['TextSegment'], select_all=True),
     }
Example #22
 def spec(self) -> lit_types.Spec:
     """Dataset spec, which should match the model"s input_spec()."""
     return {
         "sentence": lit_types.TextSegment(),
         "label": lit_types.CategoryLabel(vocab=self.LABELS),
         "identity_attack": lit_types.Boolean(),
         "insult": lit_types.Boolean(),
         "obscene": lit_types.Boolean(),
         "severe_toxicity": lit_types.Boolean(),
         "threat": lit_types.Boolean()
     }
Example #23
 def input_spec(self) -> Spec:
     ret = {}
     ret[self.config.text_a_name] = lit_types.TextSegment()
     if self.config.text_b_name:
         ret[self.config.text_b_name] = lit_types.TextSegment()
     if self.is_regression:
         ret[self.config.label_name] = lit_types.RegressionScore(
             required=False)
     else:
         ret[self.config.label_name] = lit_types.CategoryLabel(
             required=False, vocab=self.config.labels)
     # The input_embs_ and grad_class fields are used for Integrated Gradients.
     ret["input_embs_" +
         self.config.text_a_name] = lit_types.TokenEmbeddings(
             align="tokens", required=False)
     if self.config.text_b_name:
         ret["input_embs_" +
             self.config.text_b_name] = lit_types.TokenEmbeddings(
                 align="tokens", required=False)
     ret["grad_class"] = lit_types.CategoryLabel(required=False,
                                                 vocab=self.config.labels)
     return ret
Example #24
def create_train_dataset(config: Config) -> lit_dataset.Dataset:
    src_path = config.exp_dir / "train.src.txt"
    trg_path = config.exp_dir / "train.trg.txt"
    default_src_iso = config.default_src_iso
    default_trg_iso = config.default_trg_iso
    examples: List[lit_types.JsonDict] = []
    with src_path.open("r", encoding="utf-8") as src_file, trg_path.open(
            "r", encoding="utf-8") as trg_file:
        for src_line, trg_line in zip(src_file, trg_file):
            src_line = src_line.strip()
            trg_line = trg_line.strip()
            src_iso = default_src_iso
            if len(config.src_isos) > 1:
                src_iso = "?"
            trg_iso = default_trg_iso
            if src_line.startswith("<2"):
                index = src_line.index(">")
                val = src_line[2:index]
                if val != "qaa":
                    trg_iso = val
            example: lit_types.JsonDict = {
                "vref": "?",
                "src_text": decode_sp(src_line),
                "ref_text": decode_sp(trg_line),
                "src_iso": src_iso,
                "trg_iso": trg_iso,
            }
            examples.append(example)
            if len(examples) == 2000:
                break
    spec: lit_types.Spec = {
        "vref": lit_types.CategoryLabel(),
        "src_text": lit_types.TextSegment(),
        "ref_text": lit_types.TextSegment(),
        "src_iso": lit_types.CategoryLabel(),
        "trg_iso": lit_types.CategoryLabel(),
    }
    return lit_dataset.Dataset(spec, examples, description="train dataset")
Example #25
    def setUp(self):
        super(ThresholderTest, self).setUp()
        self.thresholder = thresholder.Thresholder()
        self.model = caching.CachingModelWrapper(
            glue_models.SST2Model(BERT_TINY_PATH), 'test')
        examples = [
            {'sentence': 'a', 'label': '1'},
            {'sentence': 'b', 'label': '1'},
            {'sentence': 'c', 'label': '1'},
            {'sentence': 'd', 'label': '1'},
            {'sentence': 'e', 'label': '1'},
            {'sentence': 'f', 'label': '0'},
            {'sentence': 'g', 'label': '0'},
            {'sentence': 'h', 'label': '0'},
            {'sentence': 'i', 'label': '0'},
        ]

        self.indexed_inputs = [{
            'id': caching.input_hash(ex),
            'data': ex
        } for ex in examples]
        self.dataset = lit_dataset.IndexedDataset(
            id_fn=caching.input_hash,
            spec={
                'sentence': lit_types.TextSegment(),
                'label': lit_types.CategoryLabel(vocab=['0', '1'])
            },
            indexed_examples=self.indexed_inputs)
        self.model_outputs = list(
            self.model.predict_with_metadata(self.indexed_inputs,
                                             dataset_name='test'))
Example #26
 def input_spec(self):
     return {
         'text': lit_types.TextSegment(),
         'tokens': lit_types.Tokens(parent='text'),
         'coref': lit_types.EdgeLabels(align='tokens'),
         # Index of predicted (single) edge for Winogender
         'answer': lit_types.CategoryLabel(
             vocab=winogender.ANSWER_VOCAB, required=False),
         # TODO(b/172975096): allow plotting of scalars from input data,
         # so we don't need to add this to the predictions.
         'pf_bls': lit_types.Scalar(required=False),
     }
Example #27
 def config_spec(self) -> lit_types.Spec:
     return {
         NUM_EXAMPLES_KEY: lit_types.Scalar(
             min_val=1, max_val=20, default=NUM_EXAMPLES_DEFAULT, step=1),
         MAX_FLIPS_KEY: lit_types.Scalar(
             min_val=1, max_val=10, default=MAX_FLIPS_DEFAULT, step=1),
         PREDICTION_KEY: lit_types.FieldMatcher(
             spec='output', types=['MulticlassPreds', 'RegressionScore']),
         REGRESSION_THRESH_KEY:
             lit_types.TextSegment(default=str(REGRESSION_THRESH_DEFAULT)),
     }
Example #28
 def test_remap(self):
     """Test remap method."""
     spec = {
         "score": types.Scalar(),
         "text": types.TextSegment(),
     }
     datapoints = [
         {"score": 0, "text": "a"},
         {"score": 0, "text": "b"},
     ]
     dset = lit_dataset.Dataset(spec, datapoints)
     remap_dict = {"score": "val", "nothing": "nada"}
     remapped_dset = dset.remap(remap_dict)
     self.assertIn("val", remapped_dset.spec())
     self.assertNotIn("score", remapped_dset.spec())
     self.assertEqual({"val": 0, "text": "a"}, remapped_dset.examples[0])
Example #29
 def input_spec(self) -> lit_types.Spec:
     return {
         "sentence": lit_types.TextSegment(),
         "label": lit_types.RegressionScore(required=False)
     }
Example #30
 def input_spec(self) -> lit_types.Spec:
     return {
         "sentence": lit_types.TextSegment(),
         "label": lit_types.CategoryLabel(vocab=self._labels, required=False)
     }
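On the model side, input_spec() mirrors the dataset specs above, and marking the label required=False lets the same model serve both labeled and unlabeled data. A minimal runnable sketch (the class name, output field, and placeholder scoring are hypothetical; Model, input_spec, output_spec, and predict_minibatch are the standard lit_nlp.api.model surface):

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types

class SentenceScorer(lit_model.Model):
    """Hypothetical regression model matching Example #29's input_spec."""

    def input_spec(self) -> lit_types.Spec:
        return {
            "sentence": lit_types.TextSegment(),
            "label": lit_types.RegressionScore(required=False),
        }

    def output_spec(self) -> lit_types.Spec:
        return {"score": lit_types.RegressionScore()}

    def predict_minibatch(self, inputs):
        # Placeholder scoring so the sketch runs end to end.
        return [{"score": float(len(ex["sentence"]))} for ex in inputs]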