def get_text_field():
    """Build a TextFeatureField with an empty vocab and graft two tokens onto it.

    Returns:
        TextFeatureField: field whose vocab holds the specials produced by an
        empty ``build_vocab`` call plus the manually injected "good" and "boy".
    """
    field = TextFeatureField(DatasetFieldName.TEXT_FIELD)
    # Build over no data so the vocab contains only its special tokens.
    field.build_vocab([])
    base_size = len(field.vocab.itos)
    # Register the extra tokens by hand, keeping itos and stoi in sync.
    extra_tokens = ["good", "boy"]
    field.vocab.itos.extend(extra_tokens)
    for offset, token in enumerate(extra_tokens):
        field.vocab.stoi[token] = base_size + offset
    print(field.vocab.itos)
    return field
Example #2
0
    def from_config(cls, config: Config, feature_config: FeatureConfig, *args,
                    **kwargs):
        """
        Factory method to construct an instance of `LanguageModelDataHandler`
        from the module's config object and feature config object.

        Args:
            config (LanguageModelDataHandler.Config): Configuration object
                specifying all the parameters of `LanguageModelDataHandler`.
            feature_config (FeatureConfig): Configuration object specifying all
                the parameters of all input features.

        Returns:
            type: An instance of `LanguageModelDataHandler`.
        """
        # Language modeling has a single input: a collection of utterances.
        # Labels are derived by the data handler itself — the input at time
        # step t+1 serves as the label for the input at time step t.
        word_cfg = feature_config.word_feat
        text_field = TextFeatureField(
            # BOS/EOS markers are only attached when the config asks for them.
            eos_token=VocabMeta.EOS_TOKEN if config.append_eos else None,
            init_token=VocabMeta.INIT_TOKEN if config.append_bos else None,
            pretrained_embeddings_path=word_cfg.pretrained_embeddings_path,
            embed_dim=word_cfg.embed_dim,
            embedding_init_strategy=word_cfg.embedding_init_strategy,
            vocab_file=word_cfg.vocab_file,
            vocab_size=word_cfg.vocab_size,
            vocab_from_train_data=word_cfg.vocab_from_train_data,
        )
        features: Dict[str, Field] = {DatasetFieldName.TEXT_FIELD: text_field}
        labels: Dict[str, Field] = {}
        extra_fields: Dict[str, Field] = {
            DatasetFieldName.UTTERANCE_FIELD: RawField()
        }
        return cls(
            raw_columns=config.columns_to_read,
            features=features,
            labels=labels,
            extra_fields=extra_fields,
            train_path=config.train_path,
            eval_path=config.eval_path,
            test_path=config.test_path,
            train_batch_size=config.train_batch_size,
            eval_batch_size=config.eval_batch_size,
            test_batch_size=config.test_batch_size,
            **kwargs,
        )
    def from_config(
        cls,
        config: Config,
        feature_config: FeatureConfig,
        label_config: WordLabelConfig,
        **kwargs
    ):
        """
        Factory method to construct an instance of `BPTTLanguageModelDataHandler`
        from the module's config object and feature config object.

        Args:
            config (LanguageModelDataHandler.Config): Configuration object
                specifying all the parameters of `BPTTLanguageModelDataHandler`.
            feature_config (FeatureConfig): Configuration object specifying all
                the parameters of all input features.

        Returns:
            type: An instance of `BPTTLanguageModelDataHandler`.

        Raises:
            TypeError: if ``config.bptt_len`` is not a positive integer.
        """
        # Language modeling has a single input: a collection of utterances.
        # Labels are derived by the data handler itself — the input at time
        # step t+1 serves as the label for the input at time step t.
        bptt_len = config.bptt_len
        if bptt_len <= 0:
            # NOTE(review): ValueError would be the conventional choice for a
            # bad value, but TypeError is kept so existing callers' except
            # clauses continue to match.
            raise TypeError("BPTT Sequence length cannot be 0 or less.")
        # torchtext's BPTT iterator hardcodes the field name "text".
        bptt_features = {
            "text": TextFeatureField(
                eos_token=VocabMeta.EOS_TOKEN, include_lengths=False
            )
        }
        return cls(
            bptt_len=bptt_len,
            raw_columns=config.columns_to_read,
            features=bptt_features,
            labels={},
            extra_fields={},
            train_path=config.train_path,
            eval_path=config.eval_path,
            test_path=config.test_path,
            train_batch_size=config.train_batch_size,
            eval_batch_size=config.eval_batch_size,
            test_batch_size=config.test_batch_size,
            pass_index=False,
            **kwargs
        )
Example #4
0
    def create_language_model_data_handler(cls) -> LanguageModelDataHandler:
        """Build a `LanguageModelDataHandler` wired with synthetic test fixtures.

        Returns:
            LanguageModelDataHandler: handler reading the utterance column with
            a single BOS/EOS-wrapped text field and a default featurizer.
        """
        # TODO: Refactor this after Shicong refactors PyText config and removes
        # Thrift. After that directly use Data Handler's from config method
        # with synthetic configs
        text_field = TextFeatureField(
            eos_token=VocabMeta.EOS_TOKEN, init_token=VocabMeta.INIT_TOKEN
        )
        handler_features: Dict[str, Field] = {
            DatasetFieldName.TEXT_FIELD: text_field
        }
        return LanguageModelDataHandler(
            raw_columns=[DFColumn.UTTERANCE],
            features=handler_features,
            labels={},
            featurizer=create_featurizer(SimpleFeaturizer.Config(), FeatureConfig()),
        )
Example #5
0
    def from_config(
        cls,
        config: Config,
        feature_config: ModelInputConfig,
        target_config: None,
        **kwargs,
    ):
        """Build the data handler with one text Field shared by all three inputs.

        Sharing a single Field means the vocab is built exactly once across
        POS_RESPONSE, NEG_RESPONSE, and QUERY.

        Args:
            config (Config): data-handler configuration object.
            feature_config (ModelInputConfig): per-input feature configuration.
            target_config: unused; present to satisfy the factory signature.

        Returns:
            An instance of this data handler class.
        """
        shared_field = TextFeatureField.from_config(feature_config.pos_response)
        # All three inputs alias the same Field object on purpose.
        # TODO: if we're reading pretrained embeddings, they
        # will be read multiple times
        features: Dict[str, Field] = {
            ModelInput.POS_RESPONSE: shared_field,
            ModelInput.NEG_RESPONSE: shared_field,
            ModelInput.QUERY: shared_field,
        }
        assert len(features) == 3, "Expected three text features"

        kwargs.update(config.items())
        return cls(
            raw_columns=config.columns_to_read,
            labels=None,
            features=features,
            **kwargs,
        )