def test_process_shape(self):
    """Check ``TextMultiField.process`` output shapes for batch sizes 1 and 5.

    For every (init case, params) combination, vocabularies are built on
    the dummy sentences, the batch is processed, and the resulting tensor
    is expected to have shape ``(max_len, batch_size, n_fields)``.  When
    ``params["include_lengths"]`` is set, the returned lengths must have
    shape ``(batch_size,)``.
    """
    single_example = [
        [["this", "is", "for", "the", "unittest"],
         ["NOUN", "VERB", "PREP", "ART", "NOUN"],
         ["", "", "", "", "MODULE"]],
    ]
    five_examples = [
        [["this", "is", "for", "the", "unittest"],
         ["NOUN", "VERB", "PREP", "ART", "NOUN"],
         ["", "", "", "", "MODULE"]],
        [["batch", "2"], ["NOUN", "NUM"], ["", ""]],
        [["batch", "3", "is", "the", "longest", "batch"],
         ["NOUN", "NUM", "VERB", "ART", "ADJ", "NOUN"],
         ["", "", "", "", "", ""]],
        [["fourth", "batch"], ["ORD", "NOUN"], ["", ""]],
        [["and", "another", "one"], ["CONJ", "?", "NUM"], ["", "", ""]],
    ]
    # (batch size, length of the longest example, batch)
    scenarios = (
        (1, 5, single_example),
        (5, 6, five_examples),
    )
    for batch_size, longest, batch in scenarios:
        combos = itertools.product(self.INIT_CASES, self.PARAMS)
        for raw_case, params in combos:
            case = self.initialize_case(raw_case, params)
            multi_field = TextMultiField(**case)
            all_fields = [case["base_field"]]
            all_fields.extend(feat for _, feat in case["feats_fields"])
            n_fields = len(all_fields)
            # Each sub-field owns one "column" of every example.
            for column, field in enumerate(all_fields):
                field.build_vocab([example[column] for example in batch])
            # Drop columns for features this init case does not define.
            trimmed = [example[:n_fields] for example in batch]
            result = multi_field.process(trimmed)
            if params["include_lengths"]:
                result, lengths = result
                self.assertEqual(lengths.shape, (batch_size,))
            self.assertEqual(result.shape, (longest, batch_size, n_fields))
예제 #2
0
 def test_preprocess_shape(self):
     """Check that preprocessing a plain string yields base + feature parts."""
     for raw_case, params in itertools.product(self.INIT_CASES,
                                               self.PARAMS):
         case = self.initialize_case(raw_case, params)
         multi_field = TextMultiField(**case)
         processed = multi_field.preprocess("dummy input here .")
         # One entry for the base field plus one per feature field.
         expected = 1 + len(case["feats_fields"])
         self.assertEqual(len(processed), expected)
 def test_preprocess_shape(self):
     """Check ``preprocess`` on a string returns one part per sub-field."""
     sample = "dummy input here ."
     for combo in itertools.product(self.INIT_CASES, self.PARAMS):
         case = self.initialize_case(*combo)
         parts = TextMultiField(**case).preprocess(sample)
         n_expected = len(case["feats_fields"]) + 1
         self.assertEqual(len(parts), n_expected)
예제 #4
0
    def test_preprocess_shape(self):
        """Check ``preprocess`` on a dict input returns base + feature parts.

        The sample maps names to raw strings.  The test asserts only that
        the number of preprocessed entries equals one (the base field)
        plus the number of feature fields of the init case.
        """
        combos = itertools.product(self.INIT_CASES, self.PARAMS)
        for raw_case, params in combos:
            case = self.initialize_case(raw_case, params)
            multi_field = TextMultiField(**case)

            # Rebuilt every iteration so each preprocess call sees a
            # fresh dict.
            raw_inputs = dict(
                base_field="dummy input here .",
                a="A A B D",
                r="C C C C",
                b="D F E D",
                zbase_field="another dummy input .",
            )
            processed = multi_field.preprocess(raw_inputs)
            n_expected = 1 + len(case["feats_fields"])
            self.assertEqual(len(processed), n_expected)
예제 #5
0
 def test_correct_n_fields(self):
     """Check ``fields`` holds the base field plus every feature field."""
     for raw_case, params in itertools.product(self.INIT_CASES, self.PARAMS):
         case = self.initialize_case(raw_case, params)
         sub_fields = TextMultiField(**case).fields
         n_expected = 1 + len(case["feats_fields"])
         self.assertEqual(len(sub_fields), n_expected)
예제 #6
0
 def test_getitem_0_returns_correct_field(self):
     """Check index 0 gives the base name and the same base field object."""
     combos = itertools.product(self.INIT_CASES, self.PARAMS)
     for raw_case, params in combos:
         case = self.initialize_case(raw_case, params)
         multi_field = TextMultiField(**case)
         expected_name = case["base_name"]
         expected_field = case["base_field"]
         self.assertEqual(multi_field[0][0], expected_name)
         self.assertIs(multi_field[0][1], expected_field)
예제 #7
0
def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
    """Convert a legacy vocab/fields object into field lists.

    Two legacy inputs are handled:

    * An old-style field mapping (as detected by ``_old_style_field_list``).
      Its ``'tgt'`` entry is collapsed, in place, into a single
      ``(base_name, TextMultiField)`` pair.  So is its ``'src'`` entry when
      ``data_type`` is ``"text"``.  The same mapping object is returned.
    * Otherwise ``vocab`` is treated as (field name, torchtext Vocab) pairs,
      the format formerly saved in ``*.vocab.pt`` files.  Fresh fields are
      built with ``get_fields``, and every (sub-)field whose name appears
      in ``vocab`` gets that Vocab attached.

    Args:
        vocab: legacy vocab or field mapping, see above.
        data_type: text, img, or audio.
        dynamic_dict: forwarded to ``get_fields``.

    Returns:
        A dict whose keys are field names and whose values are lists of
        ``(name, Field)`` pairs.
    """
    if _old_style_field_list(vocab):  # upgrade to multifield
        # Only values of existing keys are replaced, so assigning while
        # iterating over items() is safe.
        for base_name, pairs in vocab.items():
            is_text_side = base_name == 'tgt' or (
                base_name == 'src' and data_type == 'text')
            if not is_text_side:
                continue
            head = pairs[0]
            assert not isinstance(head[1], TextMultiField)
            vocab[base_name] = [
                (base_name, TextMultiField(head[0], head[1], pairs[1:]))]
        return vocab

    vocab_by_name = dict(vocab)
    n_src_features = sum('src_feat_' in name for name in vocab_by_name)
    n_tgt_features = sum('tgt_feat_' in name for name in vocab_by_name)
    fields = get_fields(data_type,
                        n_src_features,
                        n_tgt_features,
                        dynamic_dict=dynamic_dict)
    for name_field_pairs in fields.values():
        for name, field in name_field_pairs:
            try:
                members = iter(field)
            except TypeError:
                # iter() failed: treat the field as its own single member.
                members = [(name, field)]
            for sub_name, sub_field in members:
                if sub_name in vocab_by_name:
                    sub_field.vocab = vocab_by_name[sub_name]
    return fields
예제 #8
0
 def test_getitem_has_correct_number_of_indexes(self):
     """Check that indexing one past the last sub-field raises IndexError."""
     for raw_case, params in itertools.product(self.INIT_CASES, self.PARAMS):
         case = self.initialize_case(raw_case, params)
         multi_field = TextMultiField(**case)
         past_end = 1 + len(case["feats_fields"])
         self.assertRaises(IndexError, lambda: multi_field[past_end])
예제 #9
0
 def test_fields_order_correct(self):
     """Check sub-fields are ordered base first, then features by name."""
     for raw_case, params in itertools.product(self.INIT_CASES, self.PARAMS):
         case = self.initialize_case(raw_case, params)
         multi_field = TextMultiField(**case)
         feature_names = sorted(name for name, _ in case["feats_fields"])
         expected = [case["base_name"]] + feature_names
         actual = [name for name, _ in multi_field.fields]
         self.assertEqual(actual, expected)
예제 #10
0
def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
    """Update a legacy vocab/field format.

    Three legacy layouts are recognised by the ``_old_style_*`` helpers:

    * ``List[Tuple[str, Vocab]]``: the format formerly saved in
      ``*.vocab.pt`` files.  Fresh fields are built with ``get_fields``,
      matching vocabs are attached, and the result is returned at once.
    * Old-style field list, ``Dict[str, List[Tuple[str, Field]]]``.  The
      ``'src'`` entry (for ``"text"``/``"keyphrase"`` data) and the
      ``'tgt'`` entry are rebuilt in place as a single
      ``(base_name, field)`` pair.  Processing then continues with the
      nesting step.
    * Old-style nesting.  The ``Dict[str, List[Tuple[str, Field]]]`` is
      flattened into ``dict[str, Field]``.

    Input matching none of these layouts is returned unchanged.

    Args:
        vocab: a legacy vocab or fields object, see above.
        data_type (str): text, img, audio, or keyphrase.
        dynamic_dict (bool): Used for copy attention.

    Returns:
        The converted fields.  Normally this is a dict mapping field names
        to Fields, as described for each layout above.
    """

    if _old_style_vocab(vocab):
        # List[Tuple[str, Vocab]] -> List[Tuple[str, Field]]
        # -> dict[str, Field]
        vocab = dict(vocab)
        n_src_features = sum('src_feat_' in k for k in vocab)
        n_tgt_features = sum('tgt_feat_' in k for k in vocab)
        fields = get_fields(
            data_type, n_src_features, n_tgt_features,
            dynamic_dict=dynamic_dict)
        for n, f in fields.items():
            try:
                # Iterable fields yield (name, sub-field) pairs; anything
                # iter() rejects is treated as a single pair.
                f_iter = iter(f)
            except TypeError:
                f_iter = [(n, f)]
            for sub_n, sub_f in f_iter:
                if sub_n in vocab:
                    sub_f.vocab = vocab[sub_n]
        return fields

    # Fallback for input that matches none of the legacy layouts below.
    # Without it, ``fields`` is unbound at the final ``return`` and the
    # call fails with UnboundLocalError.
    fields = vocab

    if _old_style_field_list(vocab):  # upgrade to multifield
        # Dict[str, List[Tuple[str, Field]]]
        # doesn't change structure - don't return early.
        for base_name, vals in fields.items():
            if ((base_name == 'src' and data_type in ('text', 'keyphrase'))
                    or base_name == 'tgt'):
                # The original code asserted that the base field was not
                # already a TextMultiField.  Changed by @memray to rebuild
                # existing TextMultiField / KeyphraseField objects instead,
                # which solves the vocab not being found while loading the
                # dataset.
                base_field = vals[0][1]
                if isinstance(base_field, TextMultiField):
                    fields[base_name] = [(base_name, TextMultiField(
                        vals[0][0], base_field.base_field, vals[1:]))]
                elif isinstance(base_field, KeyphraseField):
                    fields[base_name] = [(base_name, KeyphraseField(
                        vals[0][0], base_field.base_field))]
                # NOTE(review): a plain (non-multi) base field falls through
                # both branches and is NOT wrapped into a multifield, despite
                # this branch's purpose.  Confirm that is intended.

    if _old_style_nesting(vocab):
        # Dict[str, List[Tuple[str, Field]]] -> List[Tuple[str, Field]]
        # -> dict[str, Field]
        fields = dict(chain.from_iterable(vocab.values()))

    return fields
예제 #11
0
 def test_getitem_nonzero_returns_correct_field(self):
     """Check indices 1.. return feature fields in sorted-name order."""
     for raw_case, params in itertools.product(self.INIT_CASES, self.PARAMS):
         case = self.initialize_case(raw_case, params)
         multi_field = TextMultiField(**case)
         names = [name for name, _ in case["feats_fields"]]
         if not names:
             # No feature fields: nothing beyond index 0 to check.
             continue
         lookup = dict(case["feats_fields"])
         for position, name in enumerate(sorted(names), start=1):
             self.assertIs(multi_field[position][1], lookup[name])
예제 #12
0
 def test_process_shape(self):
     """Verify the tensor shapes produced by ``TextMultiField.process``.

     Two batches are used: one five-token example, and five examples whose
     longest has six tokens.  For every (init case, params) pair the
     processed data must have shape ``(max_len, batch_size, n_fields)``.
     When ``params["include_lengths"]`` is set, the lengths must have
     shape ``(batch_size,)``.
     """
     batch_of_one = [
         [["this", "is", "for", "the", "unittest"],
          ["NOUN", "VERB", "PREP", "ART", "NOUN"],
          ["", "", "", "", "MODULE"]],
     ]
     batch_of_five = [
         [["this", "is", "for", "the", "unittest"],
          ["NOUN", "VERB", "PREP", "ART", "NOUN"],
          ["", "", "", "", "MODULE"]],
         [["batch", "2"], ["NOUN", "NUM"], ["", ""]],
         [["batch", "3", "is", "the", "longest", "batch"],
          ["NOUN", "NUM", "VERB", "ART", "ADJ", "NOUN"],
          ["", "", "", "", "", ""]],
         [["fourth", "batch"], ["ORD", "NOUN"], ["", ""]],
         [["and", "another", "one"], ["CONJ", "?", "NUM"], ["", "", ""]],
     ]
     expectations = [(1, 5, batch_of_one), (5, 6, batch_of_five)]
     for batch_size, max_len, batch in expectations:
         for raw_case, params in itertools.product(self.INIT_CASES,
                                                   self.PARAMS):
             case = self.initialize_case(raw_case, params)
             multi_field = TextMultiField(**case)
             sub_fields = [case["base_field"]]
             sub_fields += [feat for _, feat in case["feats_fields"]]
             width = len(sub_fields)
             for idx, sub_field in enumerate(sub_fields):
                 sub_field.build_vocab([ex[idx] for ex in batch])
             # Keep only the columns this init case has fields for.
             processed = multi_field.process([ex[:width] for ex in batch])
             if params["include_lengths"]:
                 processed, lengths = processed
                 self.assertEqual(lengths.shape, (batch_size,))
             self.assertEqual(processed.shape, (max_len, batch_size, width))
예제 #13
0
 def test_base_field(self):
     """Check ``base_field`` is the very object passed in as the base."""
     for raw_case, params in itertools.product(self.INIT_CASES, self.PARAMS):
         case = self.initialize_case(raw_case, params)
         self.assertIs(TextMultiField(**case).base_field, case["base_field"])