Example #1
    def eval(self, name):
        # Evaluate on task [name]; accepts a single task name or a list of names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(
            self.list_tasks)
        if name == 'CR':
            self.evaluation = CREval(tpath + '/CR', seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/MR', seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/MPQA', seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/SUBJ', seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/SST/binary',
                                      nclasses=2,
                                      seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/SST/fine',
                                      nclasses=5,
                                      seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/TREC', seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/MRPC', seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/SICK',
                                                  seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath + '/STS/STSBenchmark',
                                               seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/SICK',
                                                 seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/SNLI', seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            # Hacky way of constructing a class from a string
            # STSClass will be one of STS[NN]Eval
            STSClass = eval(name + 'Eval')
            self.evaluation = STSClass(tpath + '/STS/' + fpath,
                                       seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath + '/COCO',
                                                        seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
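
For context, these snippets are variants of the task dispatcher in a SentEval-style engine: each branch instantiates the matching *Eval class with a task path and seed, after which do_prepare and run drive the user-supplied callbacks. Below is a minimal invocation sketch assuming the standard SentEval interface; the toy prepare/batcher callbacks and the params keys are illustrative assumptions, not part of the snippet above.

import numpy as np
import senteval

def prepare(params, samples):
    # nothing to precompute for this toy encoder
    return

def batcher(params, batch):
    # placeholder encoder: one random 128-d vector per sentence
    return np.random.rand(len(batch), 128)

params = {'task_path': 'data', 'seed': 1111, 'usepytorch': False, 'kfold': 5}
se = senteval.engine.SE(params, batcher, prepare)
results = se.eval(['MR', 'CR'])  # the list form recurses task by task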
Example #2
    def eval(self, name):
        # Evaluate on task [name]; accepts a single task name or a list of names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        if name == 'SST2':
            self.evaluation = SSTEval(tpath + '/SST/binary', 'SST binary', nclasses=2, seed=self.params.seed)
        elif name == 'SST3':
            self.evaluation = SSTEval(tpath + '/SST/dialog-2016', 'SST3', nclasses=3, seed=self.params.seed)
        elif name == 'ReadabilityCl':
            self.evaluation = SSTEval(tpath + '/Readability classifier', 'readability classifier', nclasses=10, seed=self.params.seed)
        elif name == 'TagCl':
            self.evaluation = SSTEval(tpath + '/Tags classifier', 'tag classifier', nclasses=6961, seed=self.params.seed)
        elif name == 'PoemsCl':
            self.evaluation = SSTEval(tpath + '/Poems classifier', 'poems classifier', nclasses=33, seed=self.params.seed)
        elif name == 'ProzaCl':
            self.evaluation = SSTEval(tpath + '/Proza classifier', 'proza classifier', nclasses=35, seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/TREC', seed=self.params.seed)
        elif name == 'STS':
            self.evaluation = STSBenchmarkEval(tpath + '/STS', seed=self.params.seed)
        elif name == 'SICK':
            self.evaluation = SICKRelatednessEval(tpath + '/SICK', seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/MRPC', seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        start = time.time()
        self.results = self.evaluation.run(self.params, self.batcher)
        end = time.time()
        self.results["time"] = end - start
        logging.debug('\nTime for task : {0} sec\n'.format(self.results["time"]))

        return self.results
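
Because the list form returns a dict keyed by task name, the per-task "time" entries recorded in this variant can be aggregated afterwards. A short sketch, with se an assumed engine instance wired up as in the sketch under Example #1 and logging already configured:

results = se.eval(['SST2', 'TREC', 'MRPC'])
total = sum(r['time'] for r in results.values())
logging.debug('Total eval time: {0} sec'.format(total))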
Example #3
    def eval(self, name):
        # Evaluate on task [name]; accepts a single task name or a list of names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(
            self.list_tasks)
        if name == 'CR':
            self.evaluation = CREval(tpath + '/CR', seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/MR', seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/MPQA', seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/SUBJ', seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/SST/binary',
                                      nclasses=2,
                                      seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/SST/fine',
                                      nclasses=5,
                                      seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/TREC', seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/MRPC', seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/SICK',
                                                  seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath + '/STS/STSBenchmark',
                                               seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/SICK',
                                                 seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/SNLI', seed=self.params.seed)
        elif name == 'DIS':
            self.evaluation = DISEval(tpath + '/DIS', seed=self.params.seed)
        elif name == 'PDTB':
            self.evaluation = PDTB_Eval(tpath + '/PDTB', seed=self.params.seed)
        elif name == "PDTB_EX":
            self.evaluation = PDTB_EX_Eval(tpath + '/PDTB_EX',
                                           seed=self.params.seed)
        elif name == "PDTB_IMEX":
            self.evaluation = PDTB_IMEX_Eval(tpath + '/PDTB_IMEX',
                                             seed=self.params.seed)
        elif name == 'DAT':
            self.evaluation = DAT_EVAL(tpath + '/DAT', seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            self.evaluation = eval(name + 'Eval')(tpath + '/STS/' + fpath,
                                                  seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath + '/COCO',
                                                        seed=self.params.seed)
        elif name == 'ABSA_CH':
            self.evaluation = ABSA_CHEval(tpath + '/ABSA_CH',
                                          seed=self.params.seed)
        elif name == 'ABSA_SP':
            self.evaluation = ABSA_SPEval(tpath + '/ABSA_SP',
                                          seed=self.params.seed)
        elif name == 'STS_SP':
            self.evaluation = STS_SPBenchmarkEval(tpath +
                                                  '/STS_SP/STSBenchmark',
                                                  seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
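
The STS12-STS16 branch constructs the evaluator class from a string with Python's builtin eval, which Example #1's own comment calls hacky. An explicit lookup table is a safer equivalent; this is a refactoring sketch over the class names these snippets already import (a drop-in for that elif branch), not the library's actual code:

STS_EVALS = {'STS12': STS12Eval, 'STS13': STS13Eval, 'STS14': STS14Eval,
             'STS15': STS15Eval, 'STS16': STS16Eval}

if name in STS_EVALS:
    # equivalent to eval(name + 'Eval')(...), without string evaluation
    self.evaluation = STS_EVALS[name](tpath + '/STS/' + name + '-en-test',
                                      seed=self.params.seed)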
Example #4
    def eval(self, name):
        # Evaluate on task [name]; accepts a single task name or a list of names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(
            self.list_tasks)

        # Original SentEval tasks
        if name == 'CR':
            self.evaluation = CREval(tpath + '/downstream/CR',
                                     seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/downstream/MR',
                                     seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/downstream/MPQA',
                                       seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/downstream/SUBJ',
                                       seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/downstream/SST/binary',
                                      nclasses=2,
                                      seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/downstream/SST/fine',
                                      nclasses=5,
                                      seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/downstream/TREC',
                                       seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/downstream/MRPC',
                                       seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK',
                                                  seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath +
                                               '/downstream/STS/STSBenchmark',
                                               seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK',
                                                 seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/downstream/SNLI',
                                       seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' +
                                                  fpath,
                                                  seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath +
                                                        '/downstream/COCO',
                                                        seed=self.params.seed)

        # Probing Tasks
        elif name == 'Length':
            self.evaluation = LengthEval(tpath + '/probing',
                                         seed=self.params.seed)
        elif name == 'WordContent':
            self.evaluation = WordContentEval(tpath + '/probing',
                                              seed=self.params.seed)
        elif name == 'Depth':
            self.evaluation = DepthEval(tpath + '/probing',
                                        seed=self.params.seed)
        elif name == 'TopConstituents':
            self.evaluation = TopConstituentsEval(tpath + '/probing',
                                                  seed=self.params.seed)
        elif name == 'BigramShift':
            self.evaluation = BigramShiftEval(tpath + '/probing',
                                              seed=self.params.seed)
        elif name == 'Tense':
            self.evaluation = TenseEval(tpath + '/probing',
                                        seed=self.params.seed)
        elif name == 'SubjNumber':
            self.evaluation = SubjNumberEval(tpath + '/probing',
                                             seed=self.params.seed)
        elif name == 'ObjNumber':
            self.evaluation = ObjNumberEval(tpath + '/probing',
                                            seed=self.params.seed)
        elif name == 'OddManOut':
            self.evaluation = OddManOutEval(tpath + '/probing',
                                            seed=self.params.seed)
        elif name == 'CoordinationInversion':
            self.evaluation = CoordinationInversionEval(tpath + '/probing',
                                                        seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
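
This variant adds the ten SentEval probing tasks next to the downstream ones; they are requested exactly like any other task name. A usage sketch, with se an assumed engine instance as in Example #1:

probing_tasks = ['Length', 'WordContent', 'Depth', 'TopConstituents',
                 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                 'OddManOut', 'CoordinationInversion']
results = se.eval(probing_tasks)  # dict mapping task name -> results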
Example #5
def get_datasets(path, dataset, MAX_NUM_WORDS, MAX_SEQUENCE_LENGTH,
                 isPairData):
    if dataset == "TREC":
        class_num = 6
    elif dataset in ["SICK_E", "SNLI"]:
        class_num = 3
    elif dataset == "SST5":
        class_num = 5
    elif dataset in ["SICK_R", "STS"]:
        class_num = 1
    else:
        class_num = 2

    if dataset == "MRPC":
        # Pair dataset; MRPC ships train/test splits only (no validation set)

        mrpc = MRPCEval(path + "data/MRPC/")

        sen1_train = mrpc.mrpc_data['train']['X_A']
        sen2_train = mrpc.mrpc_data['train']['X_B']
        sen1_test = mrpc.mrpc_data['test']['X_A']
        sen2_test = mrpc.mrpc_data['test']['X_B']

        corpus = sen1_train + sen2_train + sen1_test + sen2_test

        y_train = np.array(mrpc.mrpc_data['train']["y"])
        y_test = np.array(mrpc.mrpc_data['test']["y"])

        class_num = 2

    if dataset in ["SUBJ", "MR", "CR", "MPQA"]:

        if dataset == "CR":
            eval = CREval(path + "data/CR/")
        elif dataset == "MR":
            eval = MREval(path + "data/MR/")
        elif dataset == "SUBJ":
            eval = SUBJEval(path + "data/SUBJ/")
        elif dataset == "MPQA":
            eval = MPQAEval(path + "data/MPQA/")

        corpus = eval.samples
        labels = eval.labels

        x_train, x_test, y_train, y_test = train_test_split(
            corpus, labels, test_size=0.2, random_state=eval.seed)
        x_train, x_val, y_train, y_val = train_test_split(
            x_train, y_train, test_size=0.2, random_state=eval.seed)

        y_train = np.array(y_train)
        y_val = np.array(y_val)
        y_test = np.array(y_test)

        class_num = 2

    elif dataset == "TREC":
        trec = TRECEval(path + "data/TREC/")

        x_train = trec.train["X"]
        y_train = trec.train["y"]
        x_test = trec.test["X"]
        y_test = trec.test["y"]
        y_train = to_categorical(y_train)
        y_test = to_categorical(y_test)

        x_train, x_val, y_train, y_val = train_test_split(
            x_train, y_train, test_size=0.2, random_state=trec.seed)

        corpus = x_train + x_test
        class_num = 6

    elif dataset == "SNLI":
        snli = SNLIEval(path + "data/SNLI/")
        corpus = snli.samples

        sen1_train = snli.data['train'][0]
        sen2_train = snli.data['train'][1]
        cat2idx = {"contradiction": 0, "entailment": 1, "neutral": 2}
        y_train = to_categorical([cat2idx[i] for i in snli.data['train'][2]])

        sen1_val = snli.data['valid'][0]
        sen2_val = snli.data['valid'][1]
        y_val = to_categorical([cat2idx[i] for i in snli.data['valid'][2]])  # one-hot, matching train/test

        sen1_test = snli.data['test'][0]
        sen2_test = snli.data['test'][1]
        y_test = to_categorical([cat2idx[i] for i in snli.data['test'][2]])

        class_num = 3

    elif dataset in ["SICK_R", "SICK_E", "STS"]:
        if dataset == "SICK_R":
            sick = SICKRelatednessEval(path + "data/SICK/")
            class_num = 1
        elif dataset == "SICK_E":
            sick = SICKEntailmentEval(path + "data/SICK/")
            class_num = 3
        elif dataset == "STS":
            sick = STSBenchmarkEval(path + "data/STS/STSBenchmark/")
            class_num = 1

        sen1_train = sick.sick_data['train']['X_A']
        sen2_train = sick.sick_data['train']['X_B']

        sen1_val = sick.sick_data['dev']['X_A']
        sen2_val = sick.sick_data['dev']['X_B']

        sen1_test = sick.sick_data['test']['X_A']
        sen2_test = sick.sick_data['test']['X_B']

        corpus = sen1_train + sen2_train + sen1_val + sen2_val + sen1_test + sen2_test

        # one-hot labels for the 3-class SICK_E task, raw scores otherwise
        if dataset == "SICK_E":
            y_train = to_categorical(sick.sick_data['train']["y"])
            y_val = to_categorical(sick.sick_data['dev']["y"])
            y_test = to_categorical(sick.sick_data['test']["y"])
        else:
            y_train = np.array(sick.sick_data['train']["y"])
            y_val = np.array(sick.sick_data['dev']["y"])
            y_test = np.array(sick.sick_data['test']["y"])

    elif dataset in ["SST2", "SST5"]:
        if dataset == "SST2":
            sst = SSTEval(path + "data/SST/binary", nclasses=2)
        else:
            sst = SSTEval(path + "data/SST/fine/", nclasses=5)

        class_num = 2 if dataset == "SST2" else 1  # SST5 labels are 0 - 5 so regression task

        x_train = sst.sst_data["train"]["X"]
        y_train = np.array(sst.sst_data["train"]["y"])
        x_val = sst.sst_data["dev"]["X"]
        y_val = np.array(sst.sst_data["dev"]["y"])
        x_test = sst.sst_data["test"]["X"]
        y_test = np.array(sst.sst_data["test"]["y"])

        corpus = x_train + x_val + x_test

    elif dataset == "QQP":

        df = pd.read_csv(path + "data/glue_data/QQP/train.tsv",
                         sep="\t",
                         names=["id", "qid1", "qid2", "s1", "s2", "label"],
                         skiprows=1,
                         error_bad_lines=False)  # deprecated in pandas >= 1.3; use on_bad_lines='skip' there
        df_test = pd.read_csv(
            path + "data/glue_data/QQP/dev.tsv",
            sep="\t",
            names=["id", "qid1", "qid2", "s1", "s2", "label"],
            skiprows=1,
            error_bad_lines=False)

        df = df[~df.label.isna()]
        df_test = df_test[~df_test.label.isna()]

        df.s1 = df.s1.astype(str)
        df.s2 = df.s2.astype(str)

        df_test.s1 = df_test.s1.astype(str)
        df_test.s2 = df_test.s2.astype(str)

        y_train = df.label.values
        y_test = df_test.label.values

        sen1_train = df.s1.tolist()
        sen2_train = df.s2.tolist()
        sen1_test = df_test.s1.tolist()
        sen2_test = df_test.s2.tolist()

        corpus = sen1_train + sen2_train + sen1_test + sen2_test
        class_num = 2

    # finally, vectorize the text samples into a 2D integer tensor
    tokenizer = Tokenizer(num_words=MAX_NUM_WORDS)
    tokenizer.fit_on_texts(corpus)

    word_index = tokenizer.word_index
    print('Found %s unique tokens.' % len(word_index))

    # Cap the padding length at the longest corpus sample
    # (note: len(i) counts characters of each sample here, not tokens)
    MAX_SEQUENCE_LENGTH = min(np.max([len(i) for i in corpus]),
                              MAX_SEQUENCE_LENGTH)
    print("Updated maxlen: %d" % MAX_SEQUENCE_LENGTH)

    if isPairData:

        x1_train = tokenizer.texts_to_sequences(sen1_train)
        x2_train = tokenizer.texts_to_sequences(sen2_train)

        x1_val = tokenizer.texts_to_sequences(sen1_val)
        x2_val = tokenizer.texts_to_sequences(sen2_val)

        x1_test = tokenizer.texts_to_sequences(sen1_test)
        x2_test = tokenizer.texts_to_sequences(sen2_test)

        x1_train = pad_sequences(x1_train, maxlen=MAX_SEQUENCE_LENGTH)
        x2_train = pad_sequences(x2_train, maxlen=MAX_SEQUENCE_LENGTH)
        x_train = [x1_train, x2_train]

        x1_val = pad_sequences(x1_val, maxlen=MAX_SEQUENCE_LENGTH)
        x2_val = pad_sequences(x2_val, maxlen=MAX_SEQUENCE_LENGTH)
        x_val = [x1_val, x2_val]

        x1_test = pad_sequences(x1_test, maxlen=MAX_SEQUENCE_LENGTH)
        x2_test = pad_sequences(x2_test, maxlen=MAX_SEQUENCE_LENGTH)
        x_test = [x1_test, x2_test]

    else:
        x_train = pad_sequences(tokenizer.texts_to_sequences(x_train),
                                maxlen=MAX_SEQUENCE_LENGTH)
        x_val = pad_sequences(tokenizer.texts_to_sequences(x_val),
                              maxlen=MAX_SEQUENCE_LENGTH)
        x_test = pad_sequences(tokenizer.texts_to_sequences(x_test),
                               maxlen=MAX_SEQUENCE_LENGTH)

    return x_train, y_train, x_val, y_val, x_test, y_test, word_index, class_num, MAX_SEQUENCE_LENGTH
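
A hypothetical call to get_datasets; the path, dataset choice, and size limits below are placeholders. isPairData must be True for the pair datasets (MRPC, SNLI, SICK_R, SICK_E, STS, QQP) so the sen1_*/sen2_* branches run; note that MRPC defines no validation split above, so the pair path would fail on sen1_val for that dataset as written.

(x_train, y_train, x_val, y_val, x_test, y_test,
 word_index, class_num, max_len) = get_datasets(
    path='./', dataset='SST2',
    MAX_NUM_WORDS=20000, MAX_SEQUENCE_LENGTH=100,
    isPairData=False)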
Example #6
    def eval(self, name):
        # Evaluate on task [name]; accepts a single task name or a list of names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(
            self.list_tasks)
        max_seq_len, load_data, seed = self.params.max_seq_len, self.params.load_data, self.params.seed

        # Original SentEval tasks
        if name == 'CR':
            self.evaluation = CREval(tpath + '/downstream/CR',
                                     seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/downstream/MR',
                                     seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/downstream/MPQA',
                                       seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/downstream/SUBJ',
                                       seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/downstream/SST/binary',
                                      nclasses=2,
                                      seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/downstream/SST/fine',
                                      nclasses=5,
                                      seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/downstream/TREC',
                                       seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/downstream/MRPC',
                                       load_data=load_data,
                                       seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK',
                                                  seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath +
                                               '/downstream/STS/STSBenchmark',
                                               load_data=load_data,
                                               seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK',
                                                 seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/downstream/SNLI',
                                       seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' +
                                                  fpath,
                                                  seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath +
                                                        '/downstream/COCO',
                                                        seed=self.params.seed)

        # additional GLUE tasks; STS-B, SST2, MRPC are above
        # might want to have the same interface for these tasks as above

        elif name == 'MNLI':
            self.evaluation = MNLIEval(tpath + '/glue_data/MNLI',
                                       max_seq_len=max_seq_len,
                                       load_data=load_data,
                                       seed=seed)
        elif name == 'QQP':
            self.evaluation = QQPEval(tpath + '/glue_data/QQP',
                                      max_seq_len=max_seq_len,
                                      load_data=load_data,
                                      seed=seed)
        elif name == 'RTE':
            self.evaluation = RTEEval(tpath + '/glue_data/RTE',
                                      max_seq_len=max_seq_len,
                                      load_data=load_data,
                                      seed=seed)
        elif name == 'QNLI':
            self.evaluation = QNLIEval(tpath + '/glue_data/QNLI',
                                       max_seq_len=max_seq_len,
                                       load_data=load_data,
                                       seed=seed)
        elif name == 'QNLIv2':
            self.evaluation = QNLIv2Eval(tpath + '/glue_data/QNLIv2',
                                         max_seq_len=max_seq_len,
                                         load_data=load_data,
                                         seed=seed)
        elif name == 'WNLI':
            self.evaluation = WNLIEval(tpath + '/glue_data/WNLI',
                                       max_seq_len=max_seq_len,
                                       load_data=load_data,
                                       seed=seed)
        elif name == 'CoLA':
            self.evaluation = CoLAEval(tpath + '/glue_data/CoLA',
                                       max_seq_len=max_seq_len,
                                       load_data=load_data,
                                       seed=seed)
        elif name == 'ANLI':  # diagnostic dataset
            self.evaluation = ANLIEval(tpath + '/glue_data/ANLI',
                                       max_seq_len=max_seq_len,
                                       load_data=load_data,
                                       seed=seed)

        # Probing Tasks
        elif name == 'Length':
            self.evaluation = LengthEval(tpath + '/probing',
                                         seed=self.params.seed)
        elif name == 'WordContent':
            self.evaluation = WordContentEval(tpath + '/probing',
                                              seed=self.params.seed)
        elif name == 'Depth':
            self.evaluation = DepthEval(tpath + '/probing',
                                        seed=self.params.seed)
        elif name == 'TopConstituents':
            self.evaluation = TopConstituentsEval(tpath + '/probing',
                                                  seed=self.params.seed)
        elif name == 'BigramShift':
            self.evaluation = BigramShiftEval(tpath + '/probing',
                                              seed=self.params.seed)
        elif name == 'Tense':
            self.evaluation = TenseEval(tpath + '/probing',
                                        seed=self.params.seed)
        elif name == 'SubjNumber':
            self.evaluation = SubjNumberEval(tpath + '/probing',
                                             seed=self.params.seed)
        elif name == 'ObjNumber':
            self.evaluation = ObjNumberEval(tpath + '/probing',
                                            seed=self.params.seed)
        elif name == 'OddManOut':
            self.evaluation = OddManOutEval(tpath + '/probing',
                                            seed=self.params.seed)
        elif name == 'CoordinationInversion':
            self.evaluation = CoordinationInversionEval(tpath + '/probing',
                                                        seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
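
Every GLUE branch in this variant shares one constructor signature and a path of the form tpath + '/glue_data/' + name, so the chain could collapse into a table. A refactoring sketch over the names used above, not the library's actual code:

GLUE_EVALS = {'MNLI': MNLIEval, 'QQP': QQPEval, 'RTE': RTEEval,
              'QNLI': QNLIEval, 'QNLIv2': QNLIv2Eval, 'WNLI': WNLIEval,
              'CoLA': CoLAEval, 'ANLI': ANLIEval}

if name in GLUE_EVALS:
    self.evaluation = GLUE_EVALS[name](tpath + '/glue_data/' + name,
                                       max_seq_len=max_seq_len,
                                       load_data=load_data,
                                       seed=seed)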
Example #7
    def eval(self, name):
        # Evaluate on task [name]; accepts a single task name or a list of names.
        if isinstance(name, list):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(
            self.list_tasks)

        start = time.time()

        # Original SentEval tasks
        if name == 'CR':
            self.evaluation = CREval(tpath + '/downstream/CR',
                                     seed=self.params.seed)
        elif name == 'MR':
            self.evaluation = MREval(tpath + '/downstream/MR',
                                     seed=self.params.seed)
        elif name == 'MPQA':
            self.evaluation = MPQAEval(tpath + '/downstream/MPQA',
                                       seed=self.params.seed)
        elif name == 'SUBJ':
            self.evaluation = SUBJEval(tpath + '/downstream/SUBJ',
                                       seed=self.params.seed)
        elif name == 'SST2':
            self.evaluation = SSTEval(tpath + '/downstream/SST/binary',
                                      nclasses=2,
                                      seed=self.params.seed)
        elif name == 'SST5':
            self.evaluation = SSTEval(tpath + '/downstream/SST/fine',
                                      nclasses=5,
                                      seed=self.params.seed)
        elif name == 'TREC':
            self.evaluation = TRECEval(tpath + '/downstream/TREC',
                                       seed=self.params.seed)
        elif name == 'MRPC':
            self.evaluation = MRPCEval(tpath + '/downstream/MRPC',
                                       seed=self.params.seed)
        elif name == 'SICKRelatedness':
            self.evaluation = SICKRelatednessEval(tpath + '/downstream/SICK',
                                                  seed=self.params.seed)
        elif name == 'STSBenchmark':
            self.evaluation = STSBenchmarkEval(tpath +
                                               '/downstream/STS/STSBenchmark',
                                               seed=self.params.seed)
        elif name == 'SICKEntailment':
            self.evaluation = SICKEntailmentEval(tpath + '/downstream/SICK',
                                                 seed=self.params.seed)
        elif name == 'SNLI':
            self.evaluation = SNLIEval(tpath + '/downstream/SNLI',
                                       seed=self.params.seed)
        elif name in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
            fpath = name + '-en-test'
            self.evaluation = eval(name + 'Eval')(tpath + '/downstream/STS/' +
                                                  fpath,
                                                  seed=self.params.seed)
        elif name == 'ImageCaptionRetrieval':
            self.evaluation = ImageCaptionRetrievalEval(tpath +
                                                        '/downstream/COCO',
                                                        seed=self.params.seed)

        # added tasks
        elif name in ('BEAN', 'MASC'):
            self.evaluation = BeanMascEval(osp.join(tpath, 'downstream', name),
                                           name,
                                           seed=self.params.seed)
        elif name == 'AmBrit':
            self.evaluation = AmBritEval(tpath + '/downstream/AmBrit',
                                         seed=self.params.seed)
        elif name == 'AmazonJa':
            self.evaluation = AmazonJaEval(osp.join(tpath, 'downstream', name),
                                           seed=self.params.seed)
        elif name == 'Rite2JaBC-Entailment':
            self.evaluation = Rite2JaBCEntailmentEval(osp.join(
                tpath, 'downstream', 'Rite2'),
                                                      seed=self.params.seed)
        elif name == 'FormalityJa':
            self.evaluation = FormalityJaEval(osp.join(tpath, 'downstream',
                                                       name),
                                              seed=self.params.seed)
        elif name == 'StyleSimJa':
            self.evaluation = StyleSimJaEval(
                osp.join(tpath, 'downstream', name))
        elif name == 'WordContentJapanese':
            self.evaluation = WordContentJapaneseEval(tpath + '/probing',
                                                      seed=self.params.seed)

        # Probing Tasks
        elif name == 'Length':
            self.evaluation = LengthEval(tpath + '/probing',
                                         seed=self.params.seed)
        elif name == 'WordContent':
            self.evaluation = WordContentEval(tpath + '/probing',
                                              seed=self.params.seed)
        elif name == 'Depth':
            self.evaluation = DepthEval(tpath + '/probing',
                                        seed=self.params.seed)
        elif name == 'TopConstituents':
            self.evaluation = TopConstituentsEval(tpath + '/probing',
                                                  seed=self.params.seed)
        elif name == 'BigramShift':
            self.evaluation = BigramShiftEval(tpath + '/probing',
                                              seed=self.params.seed)
        elif name == 'Tense':
            self.evaluation = TenseEval(tpath + '/probing',
                                        seed=self.params.seed)
        elif name == 'SubjNumber':
            self.evaluation = SubjNumberEval(tpath + '/probing',
                                             seed=self.params.seed)
        elif name == 'ObjNumber':
            self.evaluation = ObjNumberEval(tpath + '/probing',
                                            seed=self.params.seed)
        elif name == 'OddManOut':
            self.evaluation = OddManOutEval(tpath + '/probing',
                                            seed=self.params.seed)
        elif name == 'CoordinationInversion':
            self.evaluation = CoordinationInversionEval(tpath + '/probing',
                                                        seed=self.params.seed)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        end = time.time()
        print(f'Eval {name} took {end - start} s')

        return self.results