Example #1
def main():
    params = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    args = params.parse_args()

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None)
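        # decoder_return_logit_inputs and cache_output_layer_w_b are tied to
        # lexicon restriction: top-k restriction needs the raw decoder outputs
        # and the cached output-layer weights/bias (inferred from the names).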
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           inp=args.input,
                           inp_factors=args.input_factors,
                           json_input=args.json_input)
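
Example #1 parses its arguments at call time, so main is normally exposed as a
console entry point. A minimal sketch of that wiring (the packaging detail is
assumed, not shown in the example):

# Hypothetical entry-point wiring; main is the function from Example #1.
if __name__ == "__main__":
    main()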
Example #2
def main():
    params = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    args = params.parse_args()

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    log_basic_info(args)

    output_handler = sockeye.output_handler.get_output_handler(
        args.output_type, args.output, args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        models, vocab_source, vocab_target = sockeye.inference.load_models(
            context,
            args.max_input_len,
            args.beam_size,
            args.batch_size,
            args.models,
            args.checkpoints,
            args.softmax_temperature,
            args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            input_dim=args.input_dim)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(vocab_source, vocab_target)
            restrict_lexicon.load(args.restrict_lexicon)
        translator = sockeye.inference.Translator(
            context,
            args.ensemble_mode,
            args.bucket_width,
            sockeye.inference.LengthPenalty(args.length_penalty_alpha,
                                            args.length_penalty_beta),
            models,
            vocab_source,
            vocab_target,
            restrict_lexicon,
            input_dim=args.input_dim)
        read_and_translate(translator, output_handler, args.chunk_size,
                           args.input)
Example #3
def main():
    params = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    args = params.parse_args()

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    log_basic_info(args)

    output_handler = sockeye.output_handler.get_output_handler(
        args.output_type, args.output, args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        bucket_source_width, bucket_target_width = args.bucket_width
        translator = sockeye.inference.Translator(
            context, args.ensemble_mode, bucket_source_width,
            bucket_target_width,
            sockeye.inference.LengthPenalty(args.length_penalty_alpha,
                                            args.length_penalty_beta),
            *sockeye.inference.load_models(context, args.max_input_len,
                                           args.beam_size, args.batch_size,
                                           args.models, args.checkpoints,
                                           args.softmax_temperature,
                                           args.max_output_length_num_stds))

        logger.info("Using batches of size %d", args.batch_size)
        read_and_translate(translator, output_handler, args.chunk_size,
                           args.input)
Example #4
def run_translate(args: argparse.Namespace):
    # Seed randomly unless a seed has been passed
    seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning(
                "For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output)

    use_cpu = args.use_cpu
    if not pt.cuda.is_available():
        logger.info("CUDA not available, using cpu")
        use_cpu = True
    device = pt.device('cpu') if use_cpu else pt.device('cuda', args.device_id)
    logger.info(f"Translate Device: {device}")
    models, source_vocabs, target_vocabs = load_models(
        device=device,
        model_folders=args.models,
        checkpoints=args.checkpoints,
        dtype=args.dtype,
        inference_only=True)

    restrict_lexicon = None  # type: Optional[Union[RestrictLexicon, Dict[str, RestrictLexicon]]]
    if args.restrict_lexicon is not None:
        logger.info(str(args.restrict_lexicon))
        if len(args.restrict_lexicon) == 1:
            # Single lexicon used for all inputs.
            # Handle a single arg of key:path or path (parsed as path:path)
            restrict_lexicon = load_restrict_lexicon(
                args.restrict_lexicon[0][1],
                source_vocabs[0],
                target_vocabs[0],
                k=args.restrict_lexicon_topk)
            logger.info(
                f"Loaded a single lexicon ({args.restrict_lexicon[0][0]}) that will be applied to all inputs."
            )
        else:
            check_condition(
                args.json_input,
                "JSON input is required when using multiple lexicons for vocabulary restriction"
            )
            # Multiple lexicons with specified names
            restrict_lexicon = dict()
            for key, path in args.restrict_lexicon:
                lexicon = load_restrict_lexicon(path,
                                                source_vocabs[0],
                                                target_vocabs[0],
                                                k=args.restrict_lexicon_topk)
                restrict_lexicon[key] = lexicon
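            # With multiple named lexicons, each JSON input line is expected to
            # select one by key; the schema sketched here is assumed:
            #   {"text": "ein haus", "restrict_lexicon": "news"}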

    brevity_penalty_weight = args.brevity_penalty_weight
    if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
        if args.brevity_penalty_constant_length_ratio > 0.0:
            constant_length_ratio = args.brevity_penalty_constant_length_ratio
        else:
            constant_length_ratio = sum(model.length_ratio_mean
                                        for model in models) / len(models)
            logger.info(
                "Using average of constant length ratios saved in the model configs: %f",
                constant_length_ratio)
    elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
        constant_length_ratio = -1.0
    elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
        brevity_penalty_weight = 0.0
        constant_length_ratio = -1.0
    else:
        raise ValueError("Unknown brevity penalty type %s" %
                         args.brevity_penalty_type)
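    # Worked constant case: two models with length_ratio_mean 0.9 and 1.1
    # average to constant_length_ratio = (0.9 + 1.1) / 2 = 1.0. The -1.0 above
    # appears to be a sentinel meaning "no constant ratio" (inferred).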

    for model in models:
        model.eval()

    scorer = inference.CandidateScorer(
        length_penalty_alpha=args.length_penalty_alpha,
        length_penalty_beta=args.length_penalty_beta,
        brevity_penalty_weight=brevity_penalty_weight)
    scorer.to(models[0].dtype)

    translator = inference.Translator(
        device=device,
        ensemble_mode=args.ensemble_mode,
        scorer=scorer,
        batch_size=args.batch_size,
        beam_size=args.beam_size,
        beam_search_stop=args.beam_search_stop,
        nbest_size=args.nbest_size,
        models=models,
        source_vocabs=source_vocabs,
        target_vocabs=target_vocabs,
        restrict_lexicon=restrict_lexicon,
        strip_unknown_words=args.strip_unknown_words,
        sample=args.sample,
        output_scores=output_handler.reports_score(),
        constant_length_ratio=constant_length_ratio,
        max_output_length_num_stds=args.max_output_length_num_stds,
        max_input_length=args.max_input_length,
        max_output_length=args.max_output_length,
        prevent_unk=args.prevent_unk,
        greedy=args.greedy)

    read_and_translate(translator=translator,
                       output_handler=output_handler,
                       chunk_size=args.chunk_size,
                       input_file=args.input,
                       input_factors=args.input_factors,
                       input_is_json=args.json_input)
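
Example #4 is the PyTorch-based variant: it falls back to CPU whenever CUDA is
unavailable. A standalone sketch of the same fallback, assuming only torch is
installed:

import torch as pt

def pick_device(use_cpu: bool, device_id: int = 0) -> pt.device:
    # Mirror Example #4: force CPU when requested or when CUDA is absent.
    if use_cpu or not pt.cuda.is_available():
        return pt.device('cpu')
    return pt.device('cuda', device_id)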
Example #5
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    if args.skip_topk:
        check_condition(
            args.beam_size == 1,
            "--skip-topk has no effect if beam size is larger than 1")
        check_condition(
            len(args.models) == 1,
            "--skip-topk has no effect for decoding with more than 1 model")

    if args.single_hyp_max:
        check_condition(args.single_hyp_max <= args.beam_size,
                        "--single-hyp-max should be at most the beam size")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning(
                'Experimental feature \'--override-dtype float16\' has been used. '
                'This feature may be removed or change its behaviour in future. '
                'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE

        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words,
            skip_topk=args.skip_topk,
            beam_block_ngram=args.beam_block_ngram,
            single_hyp_max=args.single_hyp_max,
            beam_sibling_penalty=args.beam_sibling_penalty,
            stochastic_search=args.stochastic_search,
            stochastic_search_size=args.stochastic_search_size)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input,
                           num_translations=args.num_translations)
Example #6
def run_translate(args: argparse.Namespace):

    # Seed randomly unless a seed has been passed
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning(
                "For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output)
    hybridize = not args.no_hybridization

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        models, source_vocabs, target_vocabs = load_models(
            context=context,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            dtype=args.dtype,
            hybridize=hybridize,
            inference_only=True,
            mc_dropout=args.mc_dropout)

        restrict_lexicon = None  # type: Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]]
        if args.restrict_lexicon is not None:
            logger.info(str(args.restrict_lexicon))
            if len(args.restrict_lexicon) == 1:
                # Single lexicon used for all inputs.
                restrict_lexicon = TopKLexicon(source_vocabs[0],
                                               target_vocabs[0])
                # Handle a single arg of key:path or path (parsed as path:path)
                restrict_lexicon.load(args.restrict_lexicon[0][1],
                                      k=args.restrict_lexicon_topk)
            else:
                check_condition(
                    args.json_input,
                    "JSON input is required when using multiple lexicons for vocabulary restriction"
                )
                # Multiple lexicons with specified names
                restrict_lexicon = dict()
                for key, path in args.restrict_lexicon:
                    lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
                    lexicon.load(path, k=args.restrict_lexicon_topk)
                    restrict_lexicon[key] = lexicon

        brevity_penalty_weight = args.brevity_penalty_weight
        if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
            if args.brevity_penalty_constant_length_ratio > 0.0:
                constant_length_ratio = args.brevity_penalty_constant_length_ratio
            else:
                constant_length_ratio = sum(model.length_ratio_mean
                                            for model in models) / len(models)
                logger.info(
                    "Using average of constant length ratios saved in the model configs: %f",
                    constant_length_ratio)
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
            constant_length_ratio = -1.0
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
            brevity_penalty_weight = 0.0
            constant_length_ratio = -1.0
        else:
            raise ValueError("Unknown brevity penalty type %s" %
                             args.brevity_penalty_type)

        scorer = inference.CandidateScorer(
            length_penalty_alpha=args.length_penalty_alpha,
            length_penalty_beta=args.length_penalty_beta,
            brevity_penalty_weight=brevity_penalty_weight,
            prefix='scorer_')

        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            scorer=scorer,
            batch_size=args.batch_size,
            beam_size=args.beam_size,
            beam_search_stop=args.beam_search_stop,
            nbest_size=args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocabs=target_vocabs,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            strip_unknown_words=args.strip_unknown_words,
            sample=args.sample,
            output_scores=output_handler.reports_score(),
            constant_length_ratio=constant_length_ratio,
            max_output_length_num_stds=args.max_output_length_num_stds,
            max_input_length=args.max_input_length,
            max_output_length=args.max_output_length,
            hybridize=hybridize,
            softmax_temperature=args.softmax_temperature)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #7
def run_translate(args: argparse.Namespace):

    # Seed randomly unless a seed has been passed
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=True,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning(
                "For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype,
            output_scores=output_handler.reports_score(),
            sampling=args.sample)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            nbest_size=args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words,
            skip_topk=args.skip_topk,
            sample=args.sample)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #8
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    if args.beam_search_stop == C.BEAM_SEARCH_STOP_FIRST:
        check_condition(
            args.batch_size == 1,
            "Early stopping (--beam-search-stop %s) not supported with batching"
            % (C.BEAM_SEARCH_STOP_FIRST))

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning(
                'Experimental feature \'--override-dtype float16\' has been used. '
                'This feature may be removed or change its behaviour in future. '
                'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #9
def run_translate(args: argparse.Namespace):

    # Seed randomly unless a seed has been passed
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=True,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning(
                "For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            max_ctx_input_len=args.max_ctx_input_len + 1,  # +1 for the special context token
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype,
            output_scores=output_handler.reports_score(),
            sampling=args.sample)

        if any(model.config.num_pointers for model in models):
            check_condition(
                args.restrict_lexicon is None,
                "The pointer mechanism does not currently work with vocabulary restriction."
            )

        restrict_lexicon = None  # type: Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]]
        if args.restrict_lexicon is not None:
            logger.info(str(args.restrict_lexicon))
            if len(args.restrict_lexicon) == 1:
                # Single lexicon used for all inputs
                restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                # Handle a single arg of key:path or path (parsed as path:path)
                restrict_lexicon.load(args.restrict_lexicon[0][1],
                                      k=args.restrict_lexicon_topk)
            else:
                check_condition(
                    args.json_input,
                    "JSON input is required when using multiple lexicons for vocabulary restriction"
                )
                # Multiple lexicons with specified names
                restrict_lexicon = dict()
                for key, path in args.restrict_lexicon:
                    lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                    lexicon.load(path, k=args.restrict_lexicon_topk)
                    restrict_lexicon[key] = lexicon

        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE

        brevity_penalty_weight = args.brevity_penalty_weight
        if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
            if args.brevity_penalty_constant_length_ratio > 0.0:
                constant_length_ratio = args.brevity_penalty_constant_length_ratio
            else:
                constant_length_ratio = sum(model.length_ratio_mean
                                            for model in models) / len(models)
                logger.info(
                    "Using average of constant length ratios saved in the model configs: %f",
                    constant_length_ratio)
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
            constant_length_ratio = -1.0
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
            brevity_penalty_weight = 0.0
            constant_length_ratio = -1.0
        else:
            raise ValueError("Unknown brevity penalty type %s" %
                             args.brevity_penalty_type)

        brevity_penalty = None  # type: Optional[inference.BrevityPenalty]
        if brevity_penalty_weight != 0.0:
            brevity_penalty = inference.BrevityPenalty(brevity_penalty_weight)

        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            nbest_size=args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words,
            skip_topk=args.skip_topk,
            sample=args.sample,
            constant_length_ratio=constant_length_ratio,
            brevity_penalty=brevity_penalty,
            ctx_step_size=args.ctx_step_size)

        if models[0].config.config_encoder.use_doc_pool:
            encoder_config = models[0].config.config_encoder
            translator.use_doc_pool = encoder_config.use_doc_pool
            translator.pool_window = encoder_config.doc_pool_window
            translator.pool_stride = encoder_config.doc_pool_stride

        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #10
    def initialize(self, context):
        super(SockeyeService, self).initialize(context)

        self.basedir = context.system_properties.get('model_dir')
        self.preprocessor = ChineseCharPreprocessor(
            os.path.join(self.basedir, 'bpe.codes.zh-en'),
            os.path.join(self.basedir, 'scripts'),
            os.path.join(self.basedir, 'scripts'))
        self.postprocessor = Detokenizer(
            os.path.join(self.basedir, 'scripts', 'detokenize.pl'))

        params = arguments.ConfigArgumentParser(description='Translate CLI')
        arguments.add_translate_cli_args(params)

        sockeye_args_path = os.path.join(self.basedir, 'sockeye-args.txt')
        sockeye_args = params.parse_args(read_sockeye_args(sockeye_args_path))
        # override models directory
        sockeye_args.models = [self.basedir]

        if 'gpu_id' in context.system_properties:
            self.device_ids.append(context.system_properties['gpu_id'])
        else:
            logging.warning('No gpu_id found in context')
            self.device_ids.append(0)

        if sockeye_args.checkpoints is not None:
            check_condition(
                len(sockeye_args.checkpoints) == len(sockeye_args.models),
                'must provide checkpoints for each model')

        if sockeye_args.skip_topk:
            check_condition(
                sockeye_args.beam_size == 1,
                '--skip-topk has no effect if beam size is larger than 1')
            check_condition(
                len(sockeye_args.models) == 1,
                '--skip-topk has no effect for decoding with more than 1 model'
            )

        if sockeye_args.nbest_size > 1:
            check_condition(
                sockeye_args.beam_size >= sockeye_args.nbest_size,
                'Size of nbest list (--nbest-size) must be smaller than or equal to beam size (--beam-size).'
            )
            check_condition(
                sockeye_args.beam_search_stop == const.BEAM_SEARCH_STOP_ALL,
                '--nbest-size > 1 requires beam search to only stop after all hypotheses are finished '
                '(--beam-search-stop all)')
            if sockeye_args.output_type != const.OUTPUT_HANDLER_NBEST:
                logging.warning(
                    'For nbest translation, output handler must be "%s", overriding option --output-type.',
                    const.OUTPUT_HANDLER_NBEST)
                sockeye_args.output_type = const.OUTPUT_HANDLER_NBEST

        log_basic_info(sockeye_args)

        output_handler = get_output_handler(sockeye_args.output_type,
                                            sockeye_args.output,
                                            sockeye_args.sure_align_threshold)

        with ExitStack() as exit_stack:
            check_condition(
                len(self.device_ids) == 1,
                'translate only supports single device for now')
            translator_ctx = determine_context(
                device_ids=self.device_ids,
                use_cpu=sockeye_args.use_cpu,
                disable_device_locking=sockeye_args.disable_device_locking,
                lock_dir=sockeye_args.lock_dir,
                exit_stack=exit_stack)[0]
            logging.info('Translate Device: %s', translator_ctx)

            if sockeye_args.override_dtype == const.DTYPE_FP16:
                logging.warning(
                    'Experimental feature \'--override-dtype float16\' has been used. '
                    'This feature may be removed or change its behavior in the future. '
                    'DO NOT USE IT IN PRODUCTION')

            models, source_vocabs, target_vocab = inference.load_models(
                context=translator_ctx,
                max_input_len=sockeye_args.max_input_len,
                beam_size=sockeye_args.beam_size,
                batch_size=sockeye_args.batch_size,
                model_folders=sockeye_args.models,
                checkpoints=sockeye_args.checkpoints,
                softmax_temperature=sockeye_args.softmax_temperature,
                max_output_length_num_stds=sockeye_args.max_output_length_num_stds,
                decoder_return_logit_inputs=sockeye_args.restrict_lexicon is not None,
                cache_output_layer_w_b=sockeye_args.restrict_lexicon is not None,
                override_dtype=sockeye_args.override_dtype,
                output_scores=output_handler.reports_score())
            restrict_lexicon = None
            if sockeye_args.restrict_lexicon:
                restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                restrict_lexicon.load(sockeye_args.restrict_lexicon,
                                      k=sockeye_args.restrict_lexicon_topk)
            store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE
            self.translator = inference.Translator(
                context=translator_ctx,
                ensemble_mode=sockeye_args.ensemble_mode,
                bucket_source_width=sockeye_args.bucket_width,
                length_penalty=inference.LengthPenalty(
                    sockeye_args.length_penalty_alpha,
                    sockeye_args.length_penalty_beta),
                beam_prune=sockeye_args.beam_prune,
                beam_search_stop=sockeye_args.beam_search_stop,
                nbest_size=sockeye_args.nbest_size,
                models=models,
                source_vocabs=source_vocabs,
                target_vocab=target_vocab,
                restrict_lexicon=restrict_lexicon,
                avoid_list=sockeye_args.avoid_list,
                store_beam=store_beam,
                strip_unknown_words=sockeye_args.strip_unknown_words,
                skip_topk=sockeye_args.skip_topk)
Example #11
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type,
                                        args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning('Experimental feature \'--override-dtype float16\' has been used. '
                           'This feature may be removed or change its behaviour in future. '
                           'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon, k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(context=context,
                                          ensemble_mode=args.ensemble_mode,
                                          bucket_source_width=args.bucket_width,
                                          length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                                                 args.length_penalty_beta),
                                          beam_prune=args.beam_prune,
                                          beam_search_stop=args.beam_search_stop,
                                          models=models,
                                          source_vocabs=source_vocabs,
                                          target_vocab=target_vocab,
                                          restrict_lexicon=restrict_lexicon,
                                          avoid_list=args.avoid_list,
                                          store_beam=store_beam,
                                          strip_unknown_words=args.strip_unknown_words)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #12
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

    if args.skip_topk:
        check_condition(args.beam_size == 1, "--skip-topk has no effect if beam size is larger than 1")
        check_condition(len(args.models) == 1, "--skip-topk has no effect for decoding with more than 1 model")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type,
                                        args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning('Experimental feature \'--override-dtype float16\' has been used. '
                           'This feature may be removed or change its behaviour in future. '
                           'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon, k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE

        inference_adapt_model = None
        if args.inference_adapt:
            model = models[0]  # for now, just use the first of the loaded models
            bucketing = True  # for now, just set this here; modify decode CLI args later
            model_config = model.config
            default_bucket_key = (model_config.config_data.max_seq_len_source,
                                  model_config.config_data.max_seq_len_target)
            provide_data = [
                mx.io.DataDesc(name=C.SOURCE_NAME,
                               shape=(args.batch_size, default_bucket_key[0],
                                      model_config.config_data.num_source_factors),
                               layout=C.BATCH_MAJOR),
                mx.io.DataDesc(name=C.TARGET_NAME,
                               shape=(args.batch_size, default_bucket_key[1]),
                               layout=C.BATCH_MAJOR)
            ]
            provide_label = [
                mx.io.DataDesc(name=C.TARGET_LABEL_NAME,
                               shape=(args.batch_size, default_bucket_key[1]),
                               layout=C.BATCH_MAJOR)
            ]
            inference_adapt_model = inference_adapt_train.create_inference_adapt_model(
                config=model_config,
                context=context,
                provide_data=provide_data,
                provide_label=provide_label,
                default_bucket_key=default_bucket_key,
                bucketing=bucketing,
                args=args)
        translator = inference.Translator(context=context,
                                          ensemble_mode=args.ensemble_mode,
                                          bucket_source_width=args.bucket_width,
                                          length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                                                 args.length_penalty_beta),
                                          beam_prune=args.beam_prune,
                                          beam_search_stop=args.beam_search_stop,
                                          models=models,
                                          source_vocabs=source_vocabs,
                                          target_vocab=target_vocab,
                                          inference_adapt_model=inference_adapt_model,
                                          restrict_lexicon=restrict_lexicon,
                                          avoid_list=args.avoid_list,
                                          store_beam=store_beam,
                                          strip_unknown_words=args.strip_unknown_words,
                                          skip_topk=args.skip_topk,
                                          adapt_args=args)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
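
Example #12 binds its adaptation module to fixed shapes via mx.io.DataDesc. A
minimal standalone sketch of the same descriptor pattern, with hypothetical
sizes (batch 16, bucket length 60, one source factor):

import mxnet as mx

# Batch-major sequence data: (batch_size, seq_len, num_factors).
data_desc = mx.io.DataDesc(name='source', shape=(16, 60, 1), layout='NTC')
print(data_desc.name, data_desc.shape, data_desc.layout)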
Example #13
    def get_translator(self, context):
        """
        Returns a translator for the given context.

        :param context: Model server context.
        :return: A configured inference.Translator.
        """
        params = arguments.ConfigArgumentParser(description='Translate CLI')
        arguments.add_translate_cli_args(params)

        sockeye_args_path = os.path.join(self.basedir, 'sockeye-args.txt')
        sockeye_args = params.parse_args(read_sockeye_args(sockeye_args_path))
        # override models directory
        sockeye_args.models = [self.basedir]

        device_ids = []
        if 'gpu_id' in context.system_properties:
            device_ids.append(context.system_properties['gpu_id'])
        else:
            logging.warning('No gpu_id found in context')
            device_ids.append(0)

        log_basic_info(sockeye_args)

        if sockeye_args.nbest_size > 1:
            if sockeye_args.output_type != const.OUTPUT_HANDLER_JSON:
                logging.warning(
                    f'For n-best translation, you must specify --output-type {const.OUTPUT_HANDLER_JSON}'
                )
                sockeye_args.output_type = const.OUTPUT_HANDLER_JSON

        output_handler = get_output_handler(sockeye_args.output_type,
                                            sockeye_args.output,
                                            sockeye_args.sure_align_threshold)

        with ExitStack() as exit_stack:
            check_condition(
                len(device_ids) == 1,
                'translate only supports single device for now')
            translator_ctx = determine_context(
                device_ids=device_ids,
                use_cpu=sockeye_args.use_cpu,
                disable_device_locking=sockeye_args.disable_device_locking,
                lock_dir=sockeye_args.lock_dir,
                exit_stack=exit_stack)[0]
            logging.info(f'Translate Device: {translator_ctx}')

            models, source_vocabs, target_vocab = inference.load_models(
                context=translator_ctx,
                max_input_len=sockeye_args.max_input_len,
                beam_size=sockeye_args.beam_size,
                batch_size=sockeye_args.batch_size,
                model_folders=sockeye_args.models,
                checkpoints=sockeye_args.checkpoints,
                softmax_temperature=sockeye_args.softmax_temperature,
                max_output_length_num_stds=sockeye_args.max_output_length_num_stds,
                decoder_return_logit_inputs=sockeye_args.restrict_lexicon is not None,
                cache_output_layer_w_b=sockeye_args.restrict_lexicon is not None,
                override_dtype=sockeye_args.override_dtype,
                output_scores=output_handler.reports_score(),
                sampling=sockeye_args.sample)

            restrict_lexicon = None
            if sockeye_args.restrict_lexicon is not None:
                logging.info(str(sockeye_args.restrict_lexicon))
                if len(sockeye_args.restrict_lexicon) == 1:
                    # Single lexicon used for all inputs
                    restrict_lexicon = TopKLexicon(source_vocabs[0],
                                                   target_vocab)
                    # Handle a single arg of key:path or path (parsed as path:path)
                    restrict_lexicon.load(sockeye_args.restrict_lexicon[0][1],
                                          k=sockeye_args.restrict_lexicon_topk)
                else:
                    check_condition(
                        sockeye_args.json_input,
                        'JSON input is required when using multiple lexicons for vocabulary restriction'
                    )
                    # Multiple lexicons with specified names
                    restrict_lexicon = dict()
                    for key, path in sockeye_args.restrict_lexicon:
                        lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                        lexicon.load(path,
                                     k=sockeye_args.restrict_lexicon_topk)
                        restrict_lexicon[key] = lexicon

            store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE

            brevity_penalty_weight = sockeye_args.brevity_penalty_weight
            if sockeye_args.brevity_penalty_type == const.BREVITY_PENALTY_CONSTANT:
                if sockeye_args.brevity_penalty_constant_length_ratio > 0.0:
                    constant_length_ratio = sockeye_args.brevity_penalty_constant_length_ratio
                else:
                    constant_length_ratio = sum(
                        model.length_ratio_mean
                        for model in models) / len(models)
                    logging.info(
                        f'Using average of constant length ratios saved in the model configs: {constant_length_ratio}'
                    )
            elif sockeye_args.brevity_penalty_type == const.BREVITY_PENALTY_LEARNED:
                constant_length_ratio = -1.0
            elif sockeye_args.brevity_penalty_type == const.BREVITY_PENALTY_NONE:
                brevity_penalty_weight = 0.0
                constant_length_ratio = -1.0
            else:
                raise ValueError(
                    f'Unknown brevity penalty type {sockeye_args.brevity_penalty_type}'
                )

            brevity_penalty = None
            if brevity_penalty_weight != 0.0:
                brevity_penalty = inference.BrevityPenalty(
                    brevity_penalty_weight)

            return inference.Translator(
                context=translator_ctx,
                ensemble_mode=sockeye_args.ensemble_mode,
                bucket_source_width=sockeye_args.bucket_width,
                length_penalty=inference.LengthPenalty(
                    sockeye_args.length_penalty_alpha,
                    sockeye_args.length_penalty_beta),
                beam_prune=sockeye_args.beam_prune,
                beam_search_stop=sockeye_args.beam_search_stop,
                nbest_size=sockeye_args.nbest_size,
                models=models,
                source_vocabs=source_vocabs,
                target_vocab=target_vocab,
                restrict_lexicon=restrict_lexicon,
                avoid_list=sockeye_args.avoid_list,
                store_beam=store_beam,
                strip_unknown_words=sockeye_args.strip_unknown_words,
                skip_topk=sockeye_args.skip_topk,
                sample=sockeye_args.sample,
                constant_length_ratio=constant_length_ratio,
                brevity_penalty=brevity_penalty)
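
Once constructed, the translator is used through Sockeye's standard inference
API. A hedged usage sketch against the 1.x-era API these examples target
(helper names assumed from that era's inference module):

# Hypothetical usage, assuming translator = service.get_translator(context).
trans_input = inference.make_input_from_plain_string(sentence_id=0,
                                                     string='translate me')
trans_output = translator.translate([trans_input])[0]
print(trans_output.translation)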