예제 #1
0
def _default_dataset_hparams(data_type=None):
    r"""Builds the default hyperparameters for a single dataset.

    See :meth:`texar.torch.data.MultiAlignedData.default_hparams` for details.

    Args:
        data_type: Identifier of the dataset kind. ``None`` selects
            ``_DataType.TEXT``; any other value is converted with
            ``_DataType(data_type)``.

    Returns:
        The default hyperparameters for the text, scalar, or record dataset
        kind that matches :attr:`data_type`.

    Raises:
        ValueError: If the converted data type is not recognized as text,
            scalar, or record data.
    """
    # Normalize the argument first so every check below sees a `_DataType`.
    data_type = _DataType.TEXT if data_type is None else _DataType(data_type)

    if _is_text_data(data_type):
        text_hparams = _default_mono_text_dataset_hparams()
        # Text datasets carry extra sharing options on top of the mono-text
        # defaults; they all start out unset.
        text_hparams["data_type"] = _DataType.TEXT
        text_hparams["vocab_share_with"] = None
        text_hparams["embedding_init_share_with"] = None
        text_hparams["processing_share_with"] = None
        return text_hparams

    if _is_scalar_data(data_type):
        # Scalar defaults are returned as produced by the helper.
        return _default_scalar_dataset_hparams()

    if _is_record_data(data_type):
        record_hparams = _default_record_dataset_hparams()
        record_hparams.update({"data_type": _DataType.RECORD})
        return record_hparams

    raise ValueError(f"Invalid data type {data_type}")
def _default_paired_text_dataset_hparams():
    r"""Builds the default hyperparameters for a paired text dataset.

    See :meth:`texar.torch.data.PairedTextData.default_hparams` for details.

    Returns:
        A dict with two entries, ``"source_dataset"`` and
        ``"target_dataset"``, each holding the mono-text defaults adjusted
        for its side of the pair.
    """
    # Source side: mono-text defaults with the BOS token unset.
    src = _default_mono_text_dataset_hparams()
    src.update({"bos_token": None, "data_name": "source"})

    # Target side: mono-text defaults plus the sharing switches, all off.
    tgt = _default_mono_text_dataset_hparams()
    tgt["vocab_share"] = False
    tgt["embedding_init_share"] = False
    tgt["processing_share"] = False
    tgt["data_name"] = "target"

    paired = {}
    paired["source_dataset"] = src
    paired["target_dataset"] = tgt
    return paired