def test_process_shape(self):
    """Check the tensor shapes produced by ``TextMultiField.process``.

    Two hand-written batches (1 and 5 examples) are used.  Every example
    holds one token list per field: the base field first, then the
    feature fields.  For each init case / params combination the
    processed data must have shape ``(max_len, batch_size, nfields)``
    and, when ``include_lengths`` is set, the returned lengths must have
    shape ``(batch_size,)``.
    """
    single_example_batch = [[
        ["this", "is", "for", "the", "unittest"],
        ["NOUN", "VERB", "PREP", "ART", "NOUN"],
        ["", "", "", "", "MODULE"],
    ]]
    five_example_batch = [
        [["this", "is", "for", "the", "unittest"],
         ["NOUN", "VERB", "PREP", "ART", "NOUN"],
         ["", "", "", "", "MODULE"]],
        [["batch", "2"],
         ["NOUN", "NUM"],
         ["", ""]],
        [["batch", "3", "is", "the", "longest", "batch"],
         ["NOUN", "NUM", "VERB", "ART", "ADJ", "NOUN"],
         ["", "", "", "", "", ""]],
        [["fourth", "batch"],
         ["ORD", "NOUN"],
         ["", ""]],
        [["and", "another", "one"],
         ["CONJ", "?", "NUM"],
         ["", "", ""]],
    ]
    # (batch size, longest sequence in the batch, the batch itself)
    batch_specs = [
        (1, 5, single_example_batch),
        (5, 6, five_example_batch),
    ]
    for bs, max_len, dummy_input in batch_specs:
        for init_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            case = self.initialize_case(init_case, params)
            mf = TextMultiField(**case)
            fields = [case["base_field"]]
            fields.extend(feat for _, feat in case["feats_fields"])
            nfields = len(fields)
            # Each field builds its vocab from its own column of the batch.
            for column, field in enumerate(fields):
                field.build_vocab(
                    [example[column] for example in dummy_input])
            # Drop the columns this case has no field for.
            trimmed_input = [example[:nfields] for example in dummy_input]
            data = mf.process(trimmed_input)
            if params["include_lengths"]:
                data, lengths = data
                self.assertEqual(lengths.shape, (bs,))
            self.assertEqual(data.shape, (max_len, bs, nfields))
def test_preprocess_shape(self):
    """Check that ``preprocess`` yields one entry per field.

    The expected count is the base field plus every feature field of
    the init case.
    """
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        preprocessed = multifield.preprocess("dummy input here .")
        # +1 accounts for the base field.
        expected_len = len(case["feats_fields"]) + 1
        self.assertEqual(len(preprocessed), expected_len)
def test_preprocess_shape(self):
    """Verify the number of items returned by ``preprocess``.

    For every init case / params combination, preprocessing a plain
    string must return ``len(feats_fields) + 1`` items (the extra one
    is the base field).
    """
    combos = itertools.product(self.INIT_CASES, self.PARAMS)
    for raw_case, params in combos:
        case = self.initialize_case(raw_case, params)
        mf = TextMultiField(**case)
        sample = "dummy input here ."
        result = mf.preprocess(sample)
        n_expected = len(case["feats_fields"]) + 1
        self.assertEqual(len(result), n_expected)
def test_preprocess_shape(self):
    """Check ``preprocess`` output length for a dict-valued sample.

    The sample maps names to strings; the result must still contain
    one entry for the base field plus one per feature field of the
    init case.
    """
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        sample = {
            "base_field": "dummy input here .",
            "a": "A A B D",
            "r": "C C C C",
            "b": "D F E D",
            "zbase_field": "another dummy input ."
        }
        preprocessed = multifield.preprocess(sample)
        # +1 accounts for the base field.
        expected_len = len(case["feats_fields"]) + 1
        self.assertEqual(len(preprocessed), expected_len)
def test_correct_n_fields(self):
    """Check that ``mf.fields`` holds the base field plus every feature field."""
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        n_fields = len(multifield.fields)
        # One base field + the feature fields passed at construction.
        n_expected = len(case["feats_fields"]) + 1
        self.assertEqual(n_fields, n_expected)
def test_getitem_0_returns_correct_field(self):
    """Check that index 0 of the multifield is the ``(base_name, base_field)`` pair.

    The name is compared by equality; the field must be the very same
    object that was passed in.
    """
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        self.assertEqual(multifield[0][0], case["base_name"])
        self.assertIs(multifield[0][1], case["base_field"])
def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
    """Convert a legacy vocab object into a fields dictionary.

    Args:
        vocab: a list of (field name, torchtext.vocab.Vocab) pairs. This
            is the format formerly saved in *.vocab.pt files.  An object
            accepted by ``_old_style_field_list`` is also handled (it is
            upgraded in place and returned).
        data_type: text, img, or audio.
        dynamic_dict: forwarded to ``get_fields``.

    Returns:
        A dictionary whose keys are the field names and whose values
        are lists of (name, Field) pairs.
    """
    if _old_style_field_list(vocab):
        # Upgrade to multifield: wrap the text sides in a TextMultiField.
        # Only existing keys are reassigned, so mutating while iterating
        # over items() is safe here.
        for base_name, vals in vocab.items():
            wraps_text = base_name == 'tgt' or (
                base_name == 'src' and data_type == 'text')
            if not wraps_text:
                continue
            head = vals[0]
            assert not isinstance(head[1], TextMultiField)
            multifield = TextMultiField(head[0], head[1], vals[1:])
            vocab[base_name] = [(base_name, multifield)]
        return vocab

    name_to_vocab = dict(vocab)
    # Feature counts are derived from the vocab names themselves.
    n_src_features = sum('src_feat_' in name for name in name_to_vocab)
    n_tgt_features = sum('tgt_feat_' in name for name in name_to_vocab)
    fields = get_fields(
        data_type, n_src_features, n_tgt_features,
        dynamic_dict=dynamic_dict)
    for pairs in fields.values():
        for name, field in pairs:
            try:
                sub_fields = iter(field)
            except TypeError:
                # A non-iterable field is treated as a single pair.
                sub_fields = [(name, field)]
            for sub_name, sub_field in sub_fields:
                if sub_name in name_to_vocab:
                    sub_field.vocab = name_to_vocab[sub_name]
    return fields
def test_getitem_has_correct_number_of_indexes(self):
    """Check that indexing one past the last field raises ``IndexError``.

    Valid indexes are ``0 .. len(feats_fields)``; the first invalid one
    is ``len(feats_fields) + 1``.
    """
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        first_invalid_index = len(case["feats_fields"]) + 1
        with self.assertRaises(IndexError):
            multifield[first_invalid_index]
def test_fields_order_correct(self):
    """Check the order of names in ``mf.fields``.

    The base name must come first, followed by the feature names in
    sorted order.
    """
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        feat_names = [name for name, _ in case["feats_fields"]]
        expected_names = [case["base_name"], *sorted(feat_names)]
        actual_names = [name for name, _ in multifield.fields]
        self.assertEqual(actual_names, expected_names)
def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
    """Update a legacy vocab/field format.

    Args:
        vocab: a list of (field name, torchtext.vocab.Vocab) pairs. This
            is the format formerly saved in *.vocab.pt files. Or, text
            data not using a :class:`TextMultiField`.
        data_type (str): text, img, or audio ('keyphrase' is handled
            like 'text' for the ``src`` side).
        dynamic_dict (bool): Used for copy attention.

    Returns:
        a dictionary whose keys are the field names and whose values
        are Fields.
    """
    if _old_style_vocab(vocab):
        # List[Tuple[str, Vocab]] -> List[Tuple[str, Field]]
        # -> dict[str, Field]
        name_to_vocab = dict(vocab)
        n_src_features = sum('src_feat_' in name for name in name_to_vocab)
        n_tgt_features = sum('tgt_feat_' in name for name in name_to_vocab)
        fields = get_fields(
            data_type, n_src_features, n_tgt_features,
            dynamic_dict=dynamic_dict)
        for name, field in fields.items():
            try:
                sub_fields = iter(field)
            except TypeError:
                # A non-iterable field is treated as a single pair.
                sub_fields = [(name, field)]
            for sub_name, sub_field in sub_fields:
                if sub_name in name_to_vocab:
                    sub_field.vocab = name_to_vocab[sub_name]
        return fields

    # NOTE(review): if neither of the two checks below matches, `fields`
    # is never bound and the final `return` raises UnboundLocalError.

    if _old_style_field_list(vocab):
        # Dict[str, List[Tuple[str, Field]]]
        # Upgrade to multifield; the structure is unchanged, so there is
        # no early return and the nesting check below still runs.
        fields = vocab
        for base_name, vals in fields.items():
            is_text_like_src = (
                base_name == 'src' and data_type in ('text', 'keyphrase'))
            if not (is_text_like_src or base_name == 'tgt'):
                continue
            head_name = vals[0][0]
            head_field = vals[0][1]
            # Rebuild the wrapper around the existing base field; entries
            # of any other type are left untouched.
            if isinstance(head_field, TextMultiField):
                rebuilt = TextMultiField(
                    head_name, head_field.base_field, vals[1:])
                fields[base_name] = [(base_name, rebuilt)]
            elif isinstance(head_field, KeyphraseField):
                rebuilt = KeyphraseField(head_name, head_field.base_field)
                fields[base_name] = [(base_name, rebuilt)]

    if _old_style_nesting(vocab):
        # Dict[str, List[Tuple[str, Field]]] -> List[Tuple[str, Field]]
        # -> dict[str, Field]
        fields = dict(chain.from_iterable(vocab.values()))

    return fields
def test_getitem_nonzero_returns_correct_field(self):
    """Check that indexes >= 1 return the feature fields in sorted-name order.

    Cases without feature fields are skipped.  For the others, index
    ``i`` (starting at 1) must hold the very field object registered
    under the i-th feature name in sorted order.
    """
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        feat_names = [name for name, _ in case["feats_fields"]]
        if not feat_names:
            continue
        field_by_name = dict(case["feats_fields"])
        # Index 0 is the base field, so feature fields start at 1.
        for index, name in enumerate(sorted(feat_names), 1):
            self.assertIs(multifield[index][1], field_by_name[name])
def test_process_shape(self):
    """Verify output shapes of ``TextMultiField.process``.

    Runs every init case / params combination against a 1-example and
    a 5-example batch.  Each example is a list of token lists, one per
    field (base field first).  The processed data must be shaped
    ``(max_len, batch_size, nfields)``; with ``include_lengths`` the
    lengths must be shaped ``(batch_size,)``.
    """
    dummy_input_bs_1 = [
        [["this", "is", "for", "the", "unittest"],
         ["NOUN", "VERB", "PREP", "ART", "NOUN"],
         ["", "", "", "", "MODULE"]],
    ]
    dummy_input_bs_5 = [
        [["this", "is", "for", "the", "unittest"],
         ["NOUN", "VERB", "PREP", "ART", "NOUN"],
         ["", "", "", "", "MODULE"]],
        [["batch", "2"],
         ["NOUN", "NUM"],
         ["", ""]],
        [["batch", "3", "is", "the", "longest", "batch"],
         ["NOUN", "NUM", "VERB", "ART", "ADJ", "NOUN"],
         ["", "", "", "", "", ""]],
        [["fourth", "batch"],
         ["ORD", "NOUN"],
         ["", ""]],
        [["and", "another", "one"],
         ["CONJ", "?", "NUM"],
         ["", "", ""]],
    ]
    # Tuples of (batch size, longest sequence length, batch).
    scenarios = [(1, 5, dummy_input_bs_1), (5, 6, dummy_input_bs_5)]
    for bs, max_len, batch in scenarios:
        for raw_case, params in itertools.product(
                self.INIT_CASES, self.PARAMS):
            case = self.initialize_case(raw_case, params)
            mf = TextMultiField(**case)
            feat_fields = [field for _, field in case["feats_fields"]]
            fields = [case["base_field"]] + feat_fields
            nfields = len(fields)
            # Build each field's vocab from its own column of the batch.
            for idx, field in enumerate(fields):
                column = [example[idx] for example in batch]
                field.build_vocab(column)
            # Keep only as many columns as this case has fields.
            batch_for_case = [example[:nfields] for example in batch]
            data = mf.process(batch_for_case)
            if params["include_lengths"]:
                data, lengths = data
                self.assertEqual(lengths.shape, (bs,))
            self.assertEqual(data.shape, (max_len, bs, nfields))
def test_base_field(self):
    """Check that ``mf.base_field`` is the same object passed as ``base_field``."""
    for init_case, params in itertools.product(
            self.INIT_CASES, self.PARAMS):
        case = self.initialize_case(init_case, params)
        multifield = TextMultiField(**case)
        expected_field = case["base_field"]
        self.assertIs(multifield.base_field, expected_field)