Beispiel #1
0
class MTBottleneckConfig(NemoConfig):
    """Configuration container for the 'MTBottleneck' setup.

    Groups the model, trainer and experiment-manager sub-configurations
    together with a name and two run-mode flags.

    NOTE(review): no decorator is visible in this excerpt. The annotated
    class-level defaults look like a dataclass-style structured config;
    confirm against the full file. Each ``...Config()`` default below is
    evaluated once, when the class body executes.
    """

    # Configuration name; defaults to 'MTBottleneck'.
    name: Optional[str] = 'MTBottleneck'
    # Enabled by default. Presumably switches the training phase on in the
    # consuming script -- TODO confirm against the caller.
    do_training: bool = True
    # Disabled by default. Presumably switches the testing phase on in the
    # consuming script -- TODO confirm against the caller.
    do_testing: bool = False
    # Model sub-configuration, default-constructed.
    model: MTBottleneckModelConfig = MTBottleneckModelConfig()
    # Trainer sub-configuration, default-constructed; the annotation allows None.
    trainer: Optional[TrainerConfig] = TrainerConfig()
    # Experiment-manager sub-configuration, named after this setup and given
    # an empty files_to_copy list; the annotation allows None.
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name='MTBottleneck', files_to_copy=[])
Beispiel #2
0
class MTEncDecConfig(NemoConfig):
    """Configuration container for the 'MTEncDec' setup.

    Groups the model, trainer and experiment-manager sub-configurations.

    NOTE(review): no decorator is visible in this excerpt. The annotated
    class-level defaults look like a dataclass-style structured config;
    confirm against the full file. Each ``...Config()`` default below is
    evaluated once, when the class body executes.
    """

    # Model sub-configuration, default-constructed.
    model: AAYNBaseConfig = AAYNBaseConfig()
    # Trainer sub-configuration, default-constructed; the annotation allows None.
    trainer: Optional[TrainerConfig] = TrainerConfig()
    # Experiment-manager sub-configuration, named 'MTEncDec' and given an
    # empty files_to_copy list; the annotation allows None.
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(
        name='MTEncDec', files_to_copy=[])
Beispiel #3
0
class MTEncDecConfig(NemoConfig):
    """Configuration container for the 'MTEncDec' setup.

    Groups the model, trainer and experiment-manager sub-configurations
    together with a name and a training flag.

    NOTE(review): no decorator is visible in this excerpt. The annotated
    class-level defaults look like a dataclass-style structured config;
    confirm against the full file. Each ``...Config()`` default below is
    evaluated once, when the class body executes.
    """

    # Configuration name; defaults to 'MTEncDec'.
    name: Optional[str] = 'MTEncDec'
    # Enabled by default. Presumably switches the training phase on in the
    # consuming script -- TODO confirm against the caller.
    do_training: bool = True
    # Model sub-configuration, default-constructed.
    model: MTEncDecModelConfig = MTEncDecModelConfig()
    # Trainer sub-configuration, default-constructed; the annotation allows None.
    trainer: Optional[TrainerConfig] = TrainerConfig()
    # Experiment-manager sub-configuration, named 'MTEncDec' and given an
    # empty files_to_copy list; the annotation allows None.
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name='MTEncDec', files_to_copy=[])
Beispiel #4
0
def nemo_export(argv):
    """Convert a .nemo saved model into .riva Riva input format.

    The function parses ``argv`` with ``get_args``, configures logging,
    restores the model named by ``args.source``, runs one forward pass on
    the model's own input example, and exports the model to ``args.out``.
    When ``args.runtime_check`` is set, the exported artifact is passed to
    ``verify_runtime`` together with the recorded inputs and outputs.

    Args:
        argv: Command-line arguments, forwarded unchanged to ``get_args``.

    Raises:
        ValueError: If ``args.verbose`` is given but does not name an
            integer log level defined by the ``logging`` module.
        SystemExit: With status 1 if the restored model is not an
            ``Exportable`` instance.
        Exception: Any error raised while restoring or exporting the model
            is logged and then re-raised unchanged.
    """
    args = get_args(argv)
    loglevel = logging.INFO
    # args.verbose holds the level name from the command line. Upper-case it
    # so that both DEBUG and debug are accepted.
    if args.verbose is not None:
        numeric_level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(numeric_level, int):
            # Report what the user typed; numeric_level is not a usable
            # level on this path (typically None).
            raise ValueError('Invalid log level: %s' % args.verbose)
        loglevel = numeric_level

    logger = logging.getLogger(__name__)
    # Iterate over a snapshot: removeHandler() mutates logger.handlers, and
    # removing while iterating the live list would skip handlers.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
    logging.basicConfig(level=loglevel, format='%(asctime)s [%(levelname)s] %(message)s')
    logging.info("Logging level set to {}".format(loglevel))

    nemo_in = args.source
    out = args.out

    # Create a PL trainer object which is required for restoring Megatron models
    cfg_trainer = TrainerConfig(
        gpus=1,
        accelerator="ddp",
        num_nodes=1,
        # Need to set the following two to False as ExpManager will take care of them differently.
        logger=False,
        checkpoint_callback=False,
    )
    trainer = Trainer(cfg_trainer)

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        with torch.inference_mode():
            # Restore instance from .nemo file using generic model restore_from
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
    except Exception:
        logging.error(
            "Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        # Bare raise re-raises the active exception with its traceback intact.
        raise

    logging.info("Model {} restored from '{}'".format(model.cfg.target, nemo_in))

    if not isinstance(model, Exportable):
        logging.error("Your NeMo model class ({}) is not Exportable.".format(model.cfg.target))
        sys.exit(1)
    typecheck.set_typecheck_enabled(enabled=False)

    try:
        #
        #  Add custom export parameters here
        #
        in_args = {}
        if args.max_batch is not None:
            in_args["max_batch"] = args.max_batch
        if args.max_dim is not None:
            in_args["max_dim"] = args.max_dim

        # Default to a no-op context; swapped for CUDA autocast on request.
        autocast = nullcontext
        model = model.to(device=args.device)
        model.eval()
        with torch.inference_mode():
            input_example = model.input_module.input_example(**in_args)
        if args.autocast:
            autocast = torch.cuda.amp.autocast
        with autocast(), torch.inference_mode():
            logging.info("Getting output example")
            input_list, input_dict = parse_input_example(input_example)
            # Kept for the optional runtime check after the export.
            output_example = forward_method(model)(*input_list, **input_dict)
            logging.info(f"Exporting model with autocast={args.autocast}")
            input_names = model.input_names
            output_names = model.output_names

            _, descriptions = model.export(
                out,
                check_trace=False,
                input_example=input_example,
                onnx_opset_version=args.onnx_opset,
                verbose=args.verbose,
            )

    except Exception:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model.cfg.target
            )
        )
        # Bare raise re-raises the active exception with its traceback intact.
        raise

    logging.info("Successfully exported to {}".format(out))

    # Drop the reference to the model before the optional runtime check.
    del model

    if args.runtime_check:
        verify_runtime(out, input_list, input_dict, input_names, output_names, output_example)