Ejemplo n.º 1
0
    "Disables adding carry-over value to all categorical slots for nemotracker.",
    dest="add_carry_value",
)

# Negative flag: passing "--no_carry_status" stores False into
# args.add_carry_status (action="store_false"); without the flag the
# destination defaults to True.
parser.add_argument(
    "--no_carry_status",
    dest="add_carry_status",
    action="store_false",
    help="Disables adding carry-over status to the slots for nemotracker.",
)

args = parser.parse_args()
logging.info(args)

# Raise logger verbosity only when --debug_mode was given.
if args.debug_mode:
    logging.setLevel("DEBUG")

# Schema size limits depend on the task: "multiwoz" has its own limits and
# every other task name falls through to the second set.
if args.task_name != "multiwoz":
    schema_config = {
        "MAX_NUM_CAT_SLOT": 6,
        "MAX_NUM_NONCAT_SLOT": 12,
        "MAX_NUM_VALUE_PER_CAT_SLOT": 12,
        "MAX_NUM_INTENT": 4,
    }
else:
    schema_config = {
        "MAX_NUM_CAT_SLOT": 9,
        "MAX_NUM_NONCAT_SLOT": 4,
        "MAX_NUM_VALUE_PER_CAT_SLOT": 47,
        "MAX_NUM_INTENT": 1,
    }
Ejemplo n.º 2
0
    choices=["", "Attention"],
    help="transformation to use for computing head. Default uses linear projection.",
)
# Boolean switch; False unless "--debug_mode" is passed.
parser.add_argument(
    "--debug_mode",
    action="store_true",
    help="Enables debug mode with more info on data preprocessing and evaluation",
)

# Integer count of the most recent checkpoints to keep (default: 1).
parser.add_argument(
    "--checkpoints_to_keep",
    type=int,
    default=1,
    help="The number of last checkpoints to keep",
)

args = parser.parse_args()
logging.info(args)

if args.debug_mode:
    # Use the symbolic level name rather than the bare number 10 (the numeric
    # value of DEBUG in the standard logging levels).
    logging.setLevel("DEBUG")

# Schema size limits depend on the task: "multiwoz" has its own limits and
# every other task name uses the second set.
if args.task_name == "multiwoz":
    schema_config = {
        "MAX_NUM_CAT_SLOT": 9,
        "MAX_NUM_NONCAT_SLOT": 4,
        "MAX_NUM_VALUE_PER_CAT_SLOT": 47,
        "MAX_NUM_INTENT": 1,
    }
else:
    schema_config = {
        "MAX_NUM_CAT_SLOT": 6,
        "MAX_NUM_NONCAT_SLOT": 12,
        "MAX_NUM_VALUE_PER_CAT_SLOT": 12,
        "MAX_NUM_INTENT": 4,
    }
Ejemplo n.º 3
0
    # Additional command-line options (the parser itself is created before this excerpt).
    parser.add_argument(
        "--show_all_output", action="store_true", help="Set to True to show output of all dialogue modules"
    )
    parser.add_argument("--work_dir", default='outputs', type=str, help='Path to where to store logs')

    args = parser.parse_args()

    # Get the absolute path.
    # NOTE(review): assuming `expanduser` is os.path.expanduser, it only expands a
    # leading "~"; a relative data_dir stays relative despite the variable name.
    abs_data_dir = expanduser(args.data_dir)

    # Check if data dir exists; fail fast with a ValueError otherwise.
    # NOTE(review): if `exists` is os.path.exists, it also accepts a regular file;
    # os.path.isdir would match the "Data folder" wording more strictly.
    if not exists(abs_data_dir):
        raise ValueError(f"Data folder `{abs_data_dir}` not found")

    # --show_all_output switches the logger to DEBUG verbosity.
    if args.show_all_output:
        logging.setLevel('DEBUG')

    # Initialize NF: placed on CPU, local_rank=None (presumably non-distributed),
    # logs written under --work_dir, and no checkpoint directory.
    nf = NeuralModuleFactory(placement=DeviceType.CPU, local_rank=None, log_dir=args.work_dir, checkpoint_dir=None)

    # Initialize the modules.

    # List of the domains to be considered, mapped to integer ids.
    # NOTE(review): id 4 is skipped ("taxi" is 5); presumably this matches the ids a
    # pretrained model expects — confirm before renumbering.
    domains = {"attraction": 0, "restaurant": 1, "train": 2, "hotel": 3, "taxi": 5}

    # Create DataDescriptor that contains information about domains, slots, and associated vocabulary
    data_desc = MultiWOZDataDesc(abs_data_dir, domains)
    vocab_size = len(data_desc.vocab)

    # Encoder changing the "user utterance" into format accepted by TRADE encoderRNN.
    user_utterance_encoder = UserUtteranceEncoder(data_desc=data_desc)
Ejemplo n.º 4
0
from nemo_text_processing.text_normalization.en.taggers.word import WordFst
from nemo_text_processing.text_normalization.en.verbalizers.date import DateFst as vDateFst
from nemo_text_processing.text_normalization.en.verbalizers.ordinal import OrdinalFst as vOrdinalFst
from nemo_text_processing.text_normalization.en.verbalizers.time import TimeFst as vTimeFst

from nemo.utils import logging

# pynini is optional: try to import it and record the outcome in
# PYNINI_AVAILABLE instead of letting the import error propagate.
try:
    import pynini
    from pynini.lib import pynutil

    PYNINI_AVAILABLE = True
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so this single clause
    # covers both the missing-module case and other import failures.
    PYNINI_AVAILABLE = False

# Import-time side effect: sets the level of the `nemo.utils` logger imported above
# to INFO whenever this module is imported.
# NOTE(review): this overrides any level the caller configured before importing
# (e.g. DEBUG); consider moving it out of module import.
logging.setLevel("INFO")


class ClassifyFst(GraphFst):
    """
    Final class that composes all other classification grammars. This class can process an entire sentence including punctuation.
    For deployment, this grammar will be compiled and exported to OpenFst Finate State Archiv (FAR) File. 
    More details to deployment at NeMo/tools/text_processing_deployment.
    
    Args:
        input_case: accepting either "lower_cased" or "cased" input.
        deterministic: if True will provide a single transduction option,
            for False multiple options (used for audio-based normalization)
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements