示例#1
0
def test_config_file(plain_command_line, config_command_line, config_contents):
    """Arguments read from a config file must match the same arguments given directly.

    ``plain_command_line`` is parsed on its own. ``config_command_line`` is
    parsed together with ``--config <file>``, where ``<file>`` holds
    ``config_contents`` as written by ``arguments.save_args``. Apart from the
    ``config`` entry itself, both parses must produce identical values.
    """
    config_file_argparse = arguments.ConfigArgumentParser()
    # Capital letter arguments are required
    config_file_argparse.add_argument("-a", type=int)
    config_file_argparse.add_argument("-b", type=int)
    config_file_argparse.add_argument("-C", type=int, required=True)
    config_file_argparse.add_argument("-D", type=int, required=True)
    config_file_argparse.add_argument("-e", type=int)

    # The option '--config <file>' is appended to config_command_line below.
    with tempfile.NamedTemporaryFile("w") as fp:
        # save_args is given the path, not ``fp``; the temporary file object
        # only reserves a unique name and removes the file when the block ends.
        arguments.save_args(argparse.Namespace(**config_contents), fp.name)

        # Parse args and cast to dicts directly. The file name is passed as its
        # own argv item instead of being spliced into the string that gets
        # split, so a temporary path containing whitespace still parses.
        args_command_line = vars(
            config_file_argparse.parse_args(args=plain_command_line.split()))
        args_config = vars(
            config_file_argparse.parse_args(
                args=config_command_line.split() + ["--config", fp.name]))

        # Remove the config entry
        del args_command_line["config"]
        del args_config["config"]

        assert args_command_line == args_config
示例#2
0
def test_config_file_required(config_command_line, config_contents):
    """Parsing ``config_command_line`` plus ``--config <file>`` must fail.

    ``<file>`` holds ``config_contents`` as written by ``arguments.save_args``.
    Presumably the fixtures leave a required option (-C or -D) unset in both
    places; argparse reports that by raising ``SystemExit``.
    """
    config_file_argparse = arguments.ConfigArgumentParser()
    # Capital letter arguments are required
    config_file_argparse.add_argument("-a", type=int)
    config_file_argparse.add_argument("-b", type=int)
    config_file_argparse.add_argument("-C", type=int, required=True)
    config_file_argparse.add_argument("-D", type=int, required=True)
    config_file_argparse.add_argument("-e", type=int)

    # The option '--config <file>' is appended to config_command_line below.
    with tempfile.NamedTemporaryFile("w") as fp:
        arguments.save_args(argparse.Namespace(**config_contents), fp.name)

        # Only the parse is expected to fail: keeping file setup outside
        # pytest.raises stops an unrelated SystemExit there from passing the
        # test. argparse has no finer-grained exception; it exits on error.
        # The file name is its own argv item so paths with spaces still work.
        with pytest.raises(SystemExit):
            config_file_argparse.parse_args(
                args=config_command_line.split() + ["--config", fp.name])
示例#3
0
def main():
    """Command-line entry point: build the translate parser, parse, and translate.

    ``parse_args()`` is called without an explicit list, so arguments are
    taken from the process command line.
    """
    cli = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(cli)
    run_translate(cli.parse_args())
    def initialize(self, context):
        """Set up pre/post-processing and the Sockeye translator for this service.

        Reads the model directory from ``context.system_properties['model_dir']``
        and builds the Chinese pre-processor and the detokenizer from files in
        it. It then parses translate CLI arguments from ``sockeye-args.txt`` in
        that directory, validates them, loads the model(s), and stores an
        ``inference.Translator`` on ``self.translator``.

        :param context: model server context; ``system_properties`` is expected
            to contain ``model_dir`` and may contain ``gpu_id`` (device 0 is
            used when it does not).
        """
        super(SockeyeService, self).initialize(context)

        self.basedir = context.system_properties.get('model_dir')
        self.preprocessor = ChineseCharPreprocessor(
            os.path.join(self.basedir, 'bpe.codes.zh-en'),
            os.path.join(self.basedir, 'scripts'),
            os.path.join(self.basedir, 'scripts'))
        self.postprocessor = Detokenizer(
            os.path.join(self.basedir, 'scripts', 'detokenize.pl'))

        params = arguments.ConfigArgumentParser(description='Translate CLI')
        arguments.add_translate_cli_args(params)

        sockeye_args_path = os.path.join(self.basedir, 'sockeye-args.txt')
        sockeye_args = params.parse_args(read_sockeye_args(sockeye_args_path))
        # override models directory
        sockeye_args.models = [self.basedir]

        # NOTE(review): assumes self.device_ids is an empty list created
        # elsewhere (e.g. in __init__); a second initialize call would append
        # another id and fail the single-device check below — confirm.
        if 'gpu_id' in context.system_properties:
            self.device_ids.append(context.system_properties['gpu_id'])
        else:
            logging.warning('No gpu_id found in context')
            self.device_ids.append(0)

        if sockeye_args.checkpoints is not None:
            check_condition(
                len(sockeye_args.checkpoints) == len(sockeye_args.models),
                'must provide checkpoints for each model')

        if sockeye_args.skip_topk:
            check_condition(
                sockeye_args.beam_size == 1,
                '--skip-topk has no effect if beam size is larger than 1')
            check_condition(
                len(sockeye_args.models) == 1,
                '--skip-topk has no effect for decoding with more than 1 model'
            )

        if sockeye_args.nbest_size > 1:
            check_condition(
                sockeye_args.beam_size >= sockeye_args.nbest_size,
                'Size of nbest list (--nbest-size) must be smaller or equal to beam size (--beam-size).'
            )
            # The flag named in the message, --beam-search-stop, is stored as
            # ``beam_search_stop`` (the same attribute handed to the Translator
            # below).
            check_condition(
                sockeye_args.beam_search_stop == const.BEAM_SEARCH_STOP_ALL,
                '--nbest-size > 1 requires beam search to only stop after all hypotheses are finished '
                '(--beam-search-stop all)')
            if sockeye_args.output_type != const.OUTPUT_HANDLER_NBEST:
                logging.warning(
                    'For nbest translation, output handler must be "%s", overriding option --output-type.',
                    const.OUTPUT_HANDLER_NBEST)
                sockeye_args.output_type = const.OUTPUT_HANDLER_NBEST

        log_basic_info(sockeye_args)

        output_handler = get_output_handler(sockeye_args.output_type,
                                            sockeye_args.output,
                                            sockeye_args.sure_align_threshold)

        # NOTE(review): exit_stack is closed when this ``with`` ends, i.e.
        # before initialize returns, so anything determine_context registered
        # on it (possibly a device lock when locking is enabled) is released
        # while self.translator keeps using the device — confirm intended.
        with ExitStack() as exit_stack:
            check_condition(
                len(self.device_ids) == 1,
                'translate only supports single device for now')
            translator_ctx = determine_context(
                device_ids=self.device_ids,
                use_cpu=sockeye_args.use_cpu,
                disable_device_locking=sockeye_args.disable_device_locking,
                lock_dir=sockeye_args.lock_dir,
                exit_stack=exit_stack)[0]
            logging.info('Translate Device: %s', translator_ctx)

            if sockeye_args.override_dtype == const.DTYPE_FP16:
                logging.warning(
                    'Experimental feature \'--override-dtype float16\' has been used. '
                    'This feature may be removed or change its behavior in the future. '
                    'DO NOT USE IT IN PRODUCTION')

            # Logit inputs / cached output-layer weights are requested only
            # when a restricting lexicon is configured.
            models, source_vocabs, target_vocab = inference.load_models(
                context=translator_ctx,
                max_input_len=sockeye_args.max_input_len,
                beam_size=sockeye_args.beam_size,
                batch_size=sockeye_args.batch_size,
                model_folders=sockeye_args.models,
                checkpoints=sockeye_args.checkpoints,
                softmax_temperature=sockeye_args.softmax_temperature,
                max_output_length_num_stds=sockeye_args.
                max_output_length_num_stds,
                decoder_return_logit_inputs=sockeye_args.restrict_lexicon
                is not None,
                cache_output_layer_w_b=sockeye_args.restrict_lexicon
                is not None,
                override_dtype=sockeye_args.override_dtype,
                output_scores=output_handler.reports_score())
            restrict_lexicon = None
            if sockeye_args.restrict_lexicon:
                restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                restrict_lexicon.load(sockeye_args.restrict_lexicon,
                                      k=sockeye_args.restrict_lexicon_topk)
            store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE
            self.translator = inference.Translator(
                context=translator_ctx,
                ensemble_mode=sockeye_args.ensemble_mode,
                bucket_source_width=sockeye_args.bucket_width,
                length_penalty=inference.LengthPenalty(
                    sockeye_args.length_penalty_alpha,
                    sockeye_args.length_penalty_beta),
                beam_prune=sockeye_args.beam_prune,
                beam_search_stop=sockeye_args.beam_search_stop,
                nbest_size=sockeye_args.nbest_size,
                models=models,
                source_vocabs=source_vocabs,
                target_vocab=target_vocab,
                restrict_lexicon=restrict_lexicon,
                avoid_list=sockeye_args.avoid_list,
                store_beam=store_beam,
                strip_unknown_words=sockeye_args.strip_unknown_words,
                skip_topk=sockeye_args.skip_topk)
示例#5
0
    def get_translator(self, context):
        """
        Returns a translator for the given context

        Translate CLI arguments are parsed from ``sockeye-args.txt`` in
        ``self.basedir`` and the model directory is forced to ``self.basedir``.

        :param context: model server context; its ``system_properties`` may
            contain ``gpu_id`` (device 0 is used when it does not).
        :return: an ``inference.Translator`` configured from those arguments.
        :raises ValueError: if the brevity penalty type is unknown.
        """
        params = arguments.ConfigArgumentParser(description='Translate CLI')
        arguments.add_translate_cli_args(params)

        sockeye_args_path = os.path.join(self.basedir, 'sockeye-args.txt')
        sockeye_args = params.parse_args(read_sockeye_args(sockeye_args_path))
        # override models directory
        sockeye_args.models = [self.basedir]

        device_ids = []
        if 'gpu_id' in context.system_properties:
            device_ids.append(context.system_properties['gpu_id'])
        else:
            logging.warning('No gpu_id found in context')
            device_ids.append(0)

        log_basic_info(sockeye_args)

        # n-best output is only supported by the JSON output handler.
        if (sockeye_args.nbest_size > 1
                and sockeye_args.output_type != const.OUTPUT_HANDLER_JSON):
            logging.warning(
                'For n-best translation, you must specify --output-type %s',
                const.OUTPUT_HANDLER_JSON)
            sockeye_args.output_type = const.OUTPUT_HANDLER_JSON

        output_handler = get_output_handler(sockeye_args.output_type,
                                            sockeye_args.output,
                                            sockeye_args.sure_align_threshold)

        # NOTE(review): the ``return`` below closes exit_stack, so anything
        # determine_context registered on it (possibly a device lock when
        # locking is enabled) is released while the returned Translator still
        # uses the device — confirm intended.
        with ExitStack() as exit_stack:
            check_condition(
                len(device_ids) == 1,
                'translate only supports single device for now')
            translator_ctx = determine_context(
                device_ids=device_ids,
                use_cpu=sockeye_args.use_cpu,
                disable_device_locking=sockeye_args.disable_device_locking,
                lock_dir=sockeye_args.lock_dir,
                exit_stack=exit_stack)[0]
            logging.info('Translate Device: %s', translator_ctx)

            # Logit inputs / cached output-layer weights are requested only
            # when a restricting lexicon is configured.
            models, source_vocabs, target_vocab = inference.load_models(
                context=translator_ctx,
                max_input_len=sockeye_args.max_input_len,
                beam_size=sockeye_args.beam_size,
                batch_size=sockeye_args.batch_size,
                model_folders=sockeye_args.models,
                checkpoints=sockeye_args.checkpoints,
                softmax_temperature=sockeye_args.softmax_temperature,
                max_output_length_num_stds=sockeye_args.
                max_output_length_num_stds,
                decoder_return_logit_inputs=sockeye_args.restrict_lexicon
                is not None,
                cache_output_layer_w_b=sockeye_args.restrict_lexicon
                is not None,
                override_dtype=sockeye_args.override_dtype,
                output_scores=output_handler.reports_score(),
                sampling=sockeye_args.sample)

            restrict_lexicon = self._load_restrict_lexicon(
                sockeye_args, source_vocabs, target_vocab)

            store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE

            constant_length_ratio, brevity_penalty = \
                self._brevity_penalty_settings(sockeye_args, models)

            return inference.Translator(
                context=translator_ctx,
                ensemble_mode=sockeye_args.ensemble_mode,
                bucket_source_width=sockeye_args.bucket_width,
                length_penalty=inference.LengthPenalty(
                    sockeye_args.length_penalty_alpha,
                    sockeye_args.length_penalty_beta),
                beam_prune=sockeye_args.beam_prune,
                beam_search_stop=sockeye_args.beam_search_stop,
                nbest_size=sockeye_args.nbest_size,
                models=models,
                source_vocabs=source_vocabs,
                target_vocab=target_vocab,
                restrict_lexicon=restrict_lexicon,
                avoid_list=sockeye_args.avoid_list,
                store_beam=store_beam,
                strip_unknown_words=sockeye_args.strip_unknown_words,
                skip_topk=sockeye_args.skip_topk,
                sample=sockeye_args.sample,
                constant_length_ratio=constant_length_ratio,
                brevity_penalty=brevity_penalty)

    @staticmethod
    def _load_restrict_lexicon(sockeye_args, source_vocabs, target_vocab):
        """Load the vocabulary-restriction lexicon(s) named by the arguments.

        :param sockeye_args: parsed translate arguments; ``restrict_lexicon``
            is None or a sequence of ``(key, path)`` pairs.
        :param source_vocabs: source vocabularies; only the first is used.
        :param target_vocab: target vocabulary.
        :return: None when no lexicon is configured; a single ``TopKLexicon``
            when one is given; otherwise a dict mapping each key to its
            ``TopKLexicon``. Several lexicons require JSON input.
        """
        if sockeye_args.restrict_lexicon is None:
            return None
        logging.info('%s', sockeye_args.restrict_lexicon)
        topk = sockeye_args.restrict_lexicon_topk
        if len(sockeye_args.restrict_lexicon) == 1:
            # Single lexicon used for all inputs
            lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            # Handle a single arg of key:path or path (parsed as path:path)
            lexicon.load(sockeye_args.restrict_lexicon[0][1], k=topk)
            return lexicon
        check_condition(
            sockeye_args.json_input,
            'JSON input is required when using multiple lexicons for vocabulary restriction'
        )
        # Multiple lexicons with specified names
        lexicons = {}
        for key, path in sockeye_args.restrict_lexicon:
            lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            lexicon.load(path, k=topk)
            lexicons[key] = lexicon
        return lexicons

    @staticmethod
    def _brevity_penalty_settings(sockeye_args, models):
        """Derive the constant length ratio and brevity penalty from the arguments.

        :param sockeye_args: parsed translate arguments.
        :param models: loaded models; their ``length_ratio_mean`` values are
            averaged when a constant ratio is needed but not given (> 0.0).
        :return: ``(constant_length_ratio, brevity_penalty)``. The ratio is
            -1.0 when unused. The penalty is None when the effective weight is
            0.0, which is always the case for the "none" penalty type.
        :raises ValueError: for an unknown brevity penalty type.
        """
        weight = sockeye_args.brevity_penalty_weight
        penalty_type = sockeye_args.brevity_penalty_type
        if penalty_type == const.BREVITY_PENALTY_CONSTANT:
            if sockeye_args.brevity_penalty_constant_length_ratio > 0.0:
                constant_length_ratio = sockeye_args.brevity_penalty_constant_length_ratio
            else:
                constant_length_ratio = sum(
                    model.length_ratio_mean
                    for model in models) / len(models)
                logging.info(
                    'Using average of constant length ratios saved in the model configs: %s',
                    constant_length_ratio)
        elif penalty_type == const.BREVITY_PENALTY_LEARNED:
            constant_length_ratio = -1.0
        elif penalty_type == const.BREVITY_PENALTY_NONE:
            weight = 0.0
            constant_length_ratio = -1.0
        else:
            raise ValueError(f'Unknown brevity penalty type {penalty_type}')

        brevity_penalty = None
        if weight != 0.0:
            brevity_penalty = inference.BrevityPenalty(weight)
        return constant_length_ratio, brevity_penalty