Beispiel #1
0
 def setUp(self):
     """Build a one-example WiC validation set and mocked run arguments.

     Creates a scratch directory (``self.temp_dir``), a ``WiCTask`` rooted
     there, a single indexed ``Instance`` in ``self.wic.val_data``, a
     vocabulary built from that data, and a ``mock.Mock`` standing in for
     the experiment arguments (``self.args``).
     """
     self.temp_dir = tempfile.mkdtemp()
     self.path = os.path.join(self.temp_dir, "temp_dataset.tsv")
     self.wic = tasks.WiCTask(self.temp_dir, 100, "wic", tokenizer_name="MosesTokenizer")
     indexers = {"bert_cased": SingleIdTokenIndexer("bert-xe-cased")}
     self.wic.val_data = [
         Instance(
             {
                 "sent1_str": MetadataField("Room and board yo."),
                 "sent2_str": MetadataField("He nailed boards"),
                 "idx": LabelField(1, skip_indexing=True),
                 "idx2": NumericField(2),
                 "idx1": NumericField(3),
                 # NOTE(review): unlike the other WiC fixtures, this sequence has
                 # no trailing "[SEP]"; confirm that is intentional.
                 "inputs": self.sentence_to_text_field(
                     ["[CLS]", "Room", "and", "board", "yo", "[SEP]", "He", "nailed", "boards"],
                     indexers,
                 ),
                 # skip_indexing is a boolean flag; pass True (not 1) to match "idx".
                 "labels": LabelField(0, skip_indexing=True),
             }
         )
     ]
     self.vocab = vocabulary.Vocabulary.from_instances(self.wic.val_data)
     # The label fields skip indexing, so they never add "True" to "wic_tags".
     # Register it explicitly (presumably the code under test looks it up).
     self.vocab.add_token_to_namespace("True", "wic_tags")
     for data in self.wic.val_data:
         data.index_fields(self.vocab)
     # Minimal stand-in for the experiment config: CPU only (cuda=-1), and
     # outputs go under the scratch directory.
     self.args = mock.Mock()
     self.args.batch_size = 4
     self.args.cuda = -1
     self.args.run_dir = self.temp_dir
     self.args.exp_dir = ""
 def setUp(self):
     """Build the STS-B and WiC fixtures used by the write_preds tests.

     Because write_preds is under test, model predictions are mocked as pandas
     DataFrames (``self.val_preds``). The parts of the model, arguments, and
     trainer needed to write predictions are mocked as well. Unlike the
     update_metrics tests, the exact contents of the examples in ``val_data``
     are not important, as long as they follow the API each task expects.
     """
     self.temp_dir = tempfile.mkdtemp()
     self.path = os.path.join(self.temp_dir, "temp_dataset.tsv")
     self.stsb = tasks.STSBTask(self.temp_dir, 100, "sts-b", tokenizer_name="MosesTokenizer")
     self.wic = tasks.WiCTask(self.temp_dir, 100, "wic", tokenizer_name="MosesTokenizer")
     stsb_val_preds = pd.DataFrame(data=[
         {
             "idx": 0,
             "labels": 1.00,
             "preds": 1.00,
             "sent1_str": "A man with a hard hat is dancing.",
             "sent2_str": "A man wearing a hard hat is dancing",
         },
         {
             "idx": 1,
             "labels": 0.950,
             "preds": 0.34,
             "sent1_str": "A young child is riding a horse.",
             "sent2_str": "A child is riding a horse.",
         },
     ])
     # NOTE(review): these rows use "sent1"/"sent2" keys, while the STS-B rows use
     # "sent1_str"/"sent2_str". Confirm this matches what write_preds reads per task.
     wic_val_preds = pd.DataFrame(data=[
         {
             "idx": 0,
             "sent1": "Room and board. ",
             "sent2": "He nailed boards across the windows.",
             "labels": 0,
             "preds": 0,
         },
         {
             "idx": 1,
             "sent1": "Hook a fish",
             "sent2": "He hooked a snake accidentally.",
             "labels": 1,
             "preds": 1,
         },
     ])
     indexers = {"bert_cased": SingleIdTokenIndexer("bert-xe-cased")}
     self.wic.val_data = [
         self._make_wic_instance(
             idx=0,
             sent1="Room and board.",
             sent2="He nailed boards",
             tokens=["[CLS]", "Room", "and", "Board", ".", "[SEP]",
                     "He", "nailed", "boards", "[SEP]"],
             label=0,
             indexers=indexers,
         ),
         self._make_wic_instance(
             idx=1,
             sent1="C ##ir ##culate a rumor .",
             sent2="This letter is being circulated",
             tokens=["[CLS]", "C", "##ir", "##culate", "a", "rumor", "[SEP]",
                     "This", "##let", "##ter", "is", "being", "c", "##ir", "##culated",
                     "[SEP]"],
             label=0,
             indexers=indexers,
         ),
         self._make_wic_instance(
             idx=2,
             sent1="Hook a fish'",
             sent2="He hooked a snake accidentally",
             tokens=["[CLS]", "Hook", "a", "fish", "[SEP]",
                     "He", "hooked", "a", "snake", "accidentally", "[SEP]"],
             label=1,
             indexers=indexers,
         ),
         self._make_wic_instance(
             idx=3,
             sent1="For recreation he wrote poetry.",
             sent2="Drug abuse is often regarded as recreation .",
             tokens=["[CLS]", "For", "re", "##creation", "he", "wrote", "poetry", "[SEP]",
                     "Drug", "abuse", "is", "often", "re", "##garded", "as", "re",
                     "##creation", "[SEP]"],
             label=1,
             indexers=indexers,
         ),
     ]
     self.val_preds = {"sts-b": stsb_val_preds, "wic": wic_val_preds}
     self.vocab = vocabulary.Vocabulary.from_instances(self.wic.val_data)
     # The label fields skip indexing, so they never add "True" to "wic_tags".
     # Register it explicitly (presumably the code under test looks it up).
     self.vocab.add_token_to_namespace("True", "wic_tags")
     for data in self.wic.val_data:
         data.index_fields(self.vocab)
     self.glue_tasks = [self.stsb, self.wic]
     # Minimal stand-in for the experiment config: CPU only (cuda=-1), and
     # outputs go under the scratch directory.
     self.args = mock.Mock()
     self.args.batch_size = 4
     self.args.cuda = -1
     self.args.run_dir = self.temp_dir
     self.args.exp_dir = ""

 def _make_wic_instance(self, idx, sent1, sent2, tokens, label, indexers):
     """Return one WiC validation ``Instance`` for the write_preds fixtures.

     Args:
         idx: example index, stored in a ``LabelField`` with indexing skipped.
         sent1: raw first sentence, kept as metadata under "sent1_str".
         sent2: raw second sentence, kept as metadata under "sent2_str".
         tokens: token sequence for the "inputs" text field.
         label: integer gold label, stored with indexing skipped.
         indexers: token indexers passed to ``sentence_to_text_field``.
     """
     return Instance({
         "sent1_str": MetadataField(sent1),
         "sent2_str": MetadataField(sent2),
         "idx": LabelField(idx, skip_indexing=True),
         # Every write_preds fixture shares the same fixed idx2/idx1 values.
         # NOTE(review): confirm no test depends on them being example-specific.
         "idx2": NumericField(2),
         "idx1": NumericField(3),
         "inputs": self.sentence_to_text_field(tokens, indexers),
         "labels": LabelField(label, skip_indexing=True),
     })
 def setUp(self):
     self.temp_dir = tempfile.mkdtemp()
     self.wic = tasks.WiCTask(self.temp_dir,
                              100,
                              "wic",
                              tokenizer_name="MosesTokenizer")