Beispiel #1
0
def get_fields_with_postag(src_data_type,
                           n_src_feats,
                           n_src_pos_feats,
                           n_tgt_feats,
                           pad='<blank>',
                           bos='<s>',
                           eos='</s>',
                           dynamic_dict=False,
                           with_align=False,
                           src_truncate=None,
                           tgt_truncate=None):
    """Build the standard fields plus an extra ``src_pos`` text field.

    Args:
        src_data_type: type of the source input, forwarded to ``get_fields``.
        n_src_feats (int): number of source features (not counting tokens).
        n_src_pos_feats (int): number of features (not counting tokens) for
            the ``src_pos`` field.
        n_tgt_feats (int): number of target features (not counting tokens).
        pad (str): Special pad symbol. Used for src, src_pos and tgt.
        bos (str): Special beginning of sequence symbol, forwarded to
            ``get_fields``; not used for ``src_pos``.
        eos (str): Special end of sequence symbol, forwarded to
            ``get_fields``; not used for ``src_pos``.
        dynamic_dict (bool): forwarded to ``get_fields``.
        with_align (bool): request alignment fields from ``get_fields``. It
            is only forwarded when truthy, so this function also works with
            a ``get_fields`` that has no ``with_align`` parameter.
        src_truncate: Cut off src and src_pos sequences beyond this.
        tgt_truncate: Cut off tgt sequences beyond this.

    Returns:
        The dict returned by ``get_fields`` with an added ``"src_pos"`` entry.
    """
    # Passing ``with_align=False`` unconditionally raises TypeError against a
    # ``get_fields`` without that parameter (the module-level one here has
    # none), so only forward it when alignment is actually requested.
    optional_kwargs = {"with_align": with_align} if with_align else {}
    fields = get_fields(src_data_type,
                        n_src_feats,
                        n_tgt_feats,
                        pad=pad,
                        bos=bos,
                        eos=eos,
                        dynamic_dict=dynamic_dict,
                        src_truncate=src_truncate,
                        tgt_truncate=tgt_truncate,
                        **optional_kwargs)

    # Configured like a source field: lengths included, no bos/eos, and
    # truncated with the source truncation length.
    src_pos_field_kwargs = {
        "n_feats": n_src_pos_feats,
        "include_lengths": True,
        "pad": pad,
        "bos": None,
        "eos": None,
        "truncate": src_truncate,
        "base_name": "src_pos"
    }
    fields["src_pos"] = text_fields(**src_pos_field_kwargs)
    return fields
    def get_fields(
        cls,
        src_types: List[str],
        n_src_feats: int,
        n_tgt_feats: int,
        pad: str = '<blank>',
        bos: str = '<s>',
        eos: str = '</s>',
        dynamic_dict: bool = False,
        src_truncate: Optional[int] = None,
        tgt_truncate: Optional[int] = None,
    ) -> Dict[str, Union[Field, TextMultiField]]:
        """Build the fields for a multi-source, text-only dataset.

        Args:
            src_types: names of the source inputs. One ``"src.<type>"`` text
                field is created per entry (and one ``"src_map.<type>"``
                field per entry when ``dynamic_dict`` is set).
            n_src_feats (int): the number of features (not counting tokens)
                for each source text field.
            n_tgt_feats (int): the number of target features (not counting
                tokens).
            pad (str): Special pad symbol. Used on src and tgt side.
            bos (str): Special beginning of sequence symbol. Only used for
                tgt.
            eos (str): Special end of sequence symbol. Only used for tgt.
            dynamic_dict (bool): Whether to add the per-source ``src_map``
                fields plus the shared ``src_ex_vocab`` and ``alignment``
                fields.
            src_truncate: Cut off src sequences beyond this (passed to
                ``text_fields``).
            tgt_truncate: Cut off tgt sequences beyond this (passed to
                ``text_fields``).

        Returns:
            A dict mapping names to fields. These names need to match
            the dataset example attributes.
        """
        # NOTE(review): in this file the def sits after ``return fields``
        # inside get_fields_with_postag, so it is never executed; its ``cls``
        # parameter suggests it is a classmethod of some class -- confirm the
        # intended placement.
        # Only text sources are supported here (other data types were
        # dropped for the multi-source setup), so every source is a text field.
        fields: Dict = {}

        # The settings are the same for every source, so build them once;
        # text_fields() is still called once per source key.
        src_field_kwargs = {
            "n_feats": n_src_feats,
            "include_lengths": True,
            "pad": pad,
            "bos": None,
            "eos": None,
            "truncate": src_truncate,
            "base_name": "src"
        }
        for src_type in src_types:
            fields[f"src.{src_type}"] = text_fields(**src_field_kwargs)

        tgt_field_kwargs = {
            "n_feats": n_tgt_feats,
            "include_lengths": False,
            "pad": pad,
            "bos": bos,
            "eos": eos,
            "truncate": tgt_truncate,
            "base_name": "tgt"
        }
        fields["tgt"] = text_fields(**tgt_field_kwargs)

        fields["indices"] = Field(use_vocab=False, dtype=torch.long, sequential=False)

        if dynamic_dict:
            # One source map per source input ...
            for src_type in src_types:
                fields[f"src_map.{src_type}"] = Field(use_vocab=False,
                                                      dtype=torch.float,
                                                      postprocessing=cls.make_src,
                                                      sequential=False)

            # ... but a single extended-vocab field and alignment field.
            fields["src_ex_vocab"] = RawField()
            fields["alignment"] = Field(use_vocab=False,
                                        dtype=torch.long,
                                        postprocessing=cls.make_tgt,
                                        sequential=False)

        return fields
Beispiel #3
0
def get_fields(
    src_data_type,
    n_src_feats,
    n_tgt_feats,
    pad='<blank>',
    bos='<s>',
    eos='</s>',
    dynamic_dict=False,
    src_truncate=None,
    tgt_truncate=None
):
    """Build the dataset fields for the given source data type.

    Args:
        src_data_type: type of the source input. Options are
            [text|img|audio|vec|keyphrase]. ``keyphrase`` uses text fields
            on the source side and ``keyphrase_fields`` on the target side.
        n_src_feats (int): the number of source features (not counting tokens)
            to create a :class:`torchtext.data.Field` for. (If
            ``src_data_type=="text"``, these fields are stored together
            as a ``TextMultiField``).
        n_tgt_feats (int): See above.
        pad (str): Special pad symbol. Used on src and tgt side.
        bos (str): Special beginning of sequence symbol. Only relevant
            for tgt.
        eos (str): Special end of sequence symbol. Only relevant
            for tgt.
        dynamic_dict (bool): Whether or not to include source map and
            alignment fields. Only allowed for ``text`` and ``keyphrase``.
        src_truncate: Cut off src sequences beyond this (passed to
            ``src_data_type``'s data reader - see there for more details).
        tgt_truncate: Cut off tgt sequences beyond this (passed to
            :class:`TextDataReader` - see there for more details).

    Returns:
        A dict mapping names to fields. These names need to match
        the dataset example attributes. It always holds ``src``, ``tgt`` and
        ``indices``. ``src_map``, ``src_ex_vocab`` and ``alignment`` are added
        when ``dynamic_dict`` is set. ``id`` and ``sep_indices`` are added
        for ``keyphrase`` sources.
    """
    # Validate before touching any field factory.
    assert src_data_type in ['text', 'img', 'audio', 'vec', 'keyphrase'], \
        "Data type not implemented"
    assert not dynamic_dict or src_data_type in ('text', 'keyphrase'), \
        'it is not possible to use dynamic_dict with non-text input'
    fields = {}

    fields_getters = {"text": text_fields,
                      "img": image_fields,
                      "audio": audio_fields,
                      "vec": vec_fields,
                      "keyphrase": text_fields}

    src_field_kwargs = {"n_feats": n_src_feats,
                        "include_lengths": True,
                        "pad": pad, "bos": None, "eos": None,
                        "truncate": src_truncate,
                        "base_name": "src"}
    fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

    # NOTE(review): "sep" is passed to text_fields too for non-keyphrase data;
    # presumably ignored there -- confirm.
    tgt_field_kwargs = {"n_feats": n_tgt_feats,
                        "include_lengths": False,
                        "pad": pad, "bos": bos, "eos": eos, "sep": keyphrase_dataset.SEP_token,
                        "truncate": tgt_truncate,
                        "base_name": "tgt"}

    if src_data_type == "keyphrase":
        fields['tgt'] = keyphrase_fields(**tgt_field_kwargs)
    else:
        fields['tgt'] = text_fields(**tgt_field_kwargs)

    indices = Field(use_vocab=False, dtype=torch.long, sequential=False)
    fields["indices"] = indices

    if dynamic_dict:
        src_map = Field(
            use_vocab=False, dtype=torch.float,
            postprocessing=make_src, sequential=False)
        fields["src_map"] = src_map

        src_ex_vocab = RawField()
        fields["src_ex_vocab"] = src_ex_vocab

        align = Field(
            use_vocab=False, dtype=torch.long,
            postprocessing=make_tgt, sequential=False)
        fields["alignment"] = align

    if src_data_type == 'keyphrase':
        # Named id_field to avoid shadowing the builtin id().
        id_field = Field(use_vocab=False, dtype=torch.long, sequential=False)
        fields["id"] = id_field

        # for Orthogonal Regularization and Semantic Coverage
        sep_indices = Field(
            use_vocab=False, dtype=torch.long,
            postprocessing=make_tgt, sequential=False)
        fields["sep_indices"] = sep_indices

    return fields