def test_config_file(plain_command_line, config_command_line, config_contents):
    """Args parsed from the plain command line must equal args loaded via --config."""
    parser = arguments.ConfigArgumentParser()
    # Capital letter arguments are required
    parser.add_argument("-a", type=int)
    parser.add_argument("-b", type=int)
    parser.add_argument("-C", type=int, required=True)
    parser.add_argument("-D", type=int, required=True)
    parser.add_argument("-e", type=int)
    # The option '--config <file>' will be added automatically to config_command_line
    with tempfile.NamedTemporaryFile("w") as config_fp:
        arguments.save_args(argparse.Namespace(**config_contents), config_fp.name)
        config_fp.flush()
        # Parse both variants and cast the namespaces to dicts directly
        parsed_plain = vars(
            parser.parse_args(args=plain_command_line.split()))
        full_line = config_command_line + (" --config %s" % config_fp.name)
        parsed_with_config = vars(parser.parse_args(args=full_line.split()))
        # Drop the 'config' entry itself before comparing
        del parsed_plain["config"]
        del parsed_with_config["config"]
        assert parsed_plain == parsed_with_config
def test_config_file_required(config_command_line, config_contents):
    """Missing required options must still trigger SystemExit when using --config."""
    parser = arguments.ConfigArgumentParser()
    # Capital letter arguments are required
    parser.add_argument("-a", type=int)
    parser.add_argument("-b", type=int)
    parser.add_argument("-C", type=int, required=True)
    parser.add_argument("-D", type=int, required=True)
    parser.add_argument("-e", type=int)
    # The option '--config <file>' will be added automatically to config_command_line
    with pytest.raises(SystemExit):
        # argparse does not raise finer-grained exceptions than SystemExit
        with tempfile.NamedTemporaryFile("w") as config_fp:
            arguments.save_args(argparse.Namespace(**config_contents), config_fp.name)
            config_fp.flush()
            # Parsing must fail: the required capital-letter options are absent
            full_line = config_command_line + (" --config %s" % config_fp.name)
            parser.parse_args(args=full_line.split())
def main():
    """CLI entry point: parse translate arguments and run translation."""
    arg_parser = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(arg_parser)
    run_translate(arg_parser.parse_args())
def initialize(self, context):
    """Initialize the service: load pre/post-processors, parse the saved
    Sockeye CLI arguments from the model directory, and build the translator.

    :param context: model server context; system_properties must contain
                    'model_dir' and may contain 'gpu_id'.
    """
    super(SockeyeService, self).initialize(context)
    self.basedir = context.system_properties.get('model_dir')
    self.preprocessor = ChineseCharPreprocessor(
        os.path.join(self.basedir, 'bpe.codes.zh-en'),
        os.path.join(self.basedir, 'scripts'),
        os.path.join(self.basedir, 'scripts'))
    self.postprocessor = Detokenizer(
        os.path.join(self.basedir, 'scripts', 'detokenize.pl'))
    params = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    # The translate CLI arguments were serialized at export time alongside the model.
    sockeye_args_path = os.path.join(self.basedir, 'sockeye-args.txt')
    sockeye_args = params.parse_args(read_sockeye_args(sockeye_args_path))
    # Override models directory: the deployed model lives in basedir,
    # not wherever it was trained.
    sockeye_args.models = [self.basedir]
    if 'gpu_id' in context.system_properties:
        self.device_ids.append(context.system_properties['gpu_id'])
    else:
        logging.warning('No gpu_id found in context')
        self.device_ids.append(0)
    if sockeye_args.checkpoints is not None:
        check_condition(
            len(sockeye_args.checkpoints) == len(sockeye_args.models),
            'must provide checkpoints for each model')
    if sockeye_args.skip_topk:
        check_condition(
            sockeye_args.beam_size == 1,
            '--skip-topk has no effect if beam size is larger than 1')
        check_condition(
            len(sockeye_args.models) == 1,
            '--skip-topk has no effect for decoding with more than 1 model')
    if sockeye_args.nbest_size > 1:
        check_condition(
            sockeye_args.beam_size >= sockeye_args.nbest_size,
            'Size of nbest list (--nbest-size) must be smaller or equal to beam size (--beam-size).')
        # BUG FIX: the parsed namespace attribute is 'beam_search_stop'
        # ('beam_search_drop' does not exist and raised AttributeError here);
        # consistent with the Translator(...) call below and with the
        # '--beam-search-stop all' wording of this very message.
        check_condition(
            sockeye_args.beam_search_stop == const.BEAM_SEARCH_STOP_ALL,
            '--nbest-size > 1 requires beam search to only stop after all hypotheses are finished '
            '(--beam-search-stop all)')
        if sockeye_args.output_type != const.OUTPUT_HANDLER_NBEST:
            logging.warning(
                'For nbest translation, output handler must be "%s", overriding option --output-type.',
                const.OUTPUT_HANDLER_NBEST)
            sockeye_args.output_type = const.OUTPUT_HANDLER_NBEST
    log_basic_info(sockeye_args)
    output_handler = get_output_handler(sockeye_args.output_type,
                                        sockeye_args.output,
                                        sockeye_args.sure_align_threshold)
    with ExitStack() as exit_stack:
        check_condition(
            len(self.device_ids) == 1,
            'translate only supports single device for now')
        translator_ctx = determine_context(
            device_ids=self.device_ids,
            use_cpu=sockeye_args.use_cpu,
            disable_device_locking=sockeye_args.disable_device_locking,
            lock_dir=sockeye_args.lock_dir,
            exit_stack=exit_stack)[0]
        logging.info('Translate Device: %s', translator_ctx)
        if sockeye_args.override_dtype == const.DTYPE_FP16:
            logging.warning(
                'Experimental feature \'--override-dtype float16\' has been used. '
                'This feature may be removed or change its behavior in the future. '
                'DO NOT USE IT IN PRODUCTION')
        models, source_vocabs, target_vocab = inference.load_models(
            context=translator_ctx,
            max_input_len=sockeye_args.max_input_len,
            beam_size=sockeye_args.beam_size,
            batch_size=sockeye_args.batch_size,
            model_folders=sockeye_args.models,
            checkpoints=sockeye_args.checkpoints,
            softmax_temperature=sockeye_args.softmax_temperature,
            max_output_length_num_stds=sockeye_args.max_output_length_num_stds,
            # Lexicon-restricted decoding needs raw logit inputs and the cached
            # output-layer parameters.
            decoder_return_logit_inputs=sockeye_args.restrict_lexicon is not None,
            cache_output_layer_w_b=sockeye_args.restrict_lexicon is not None,
            override_dtype=sockeye_args.override_dtype,
            output_scores=output_handler.reports_score())
        restrict_lexicon = None
        if sockeye_args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(sockeye_args.restrict_lexicon,
                                  k=sockeye_args.restrict_lexicon_topk)
        store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE
        self.translator = inference.Translator(
            context=translator_ctx,
            ensemble_mode=sockeye_args.ensemble_mode,
            bucket_source_width=sockeye_args.bucket_width,
            length_penalty=inference.LengthPenalty(
                sockeye_args.length_penalty_alpha,
                sockeye_args.length_penalty_beta),
            beam_prune=sockeye_args.beam_prune,
            beam_search_stop=sockeye_args.beam_search_stop,
            nbest_size=sockeye_args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=sockeye_args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=sockeye_args.strip_unknown_words,
            skip_topk=sockeye_args.skip_topk)
def get_translator(self, context):
    """
    Build and return a Sockeye translator for the given context.

    Parses the translate CLI arguments saved next to the model
    ('sockeye-args.txt' in self.basedir), selects a device, loads the
    model(s) and optional restriction lexicons, and constructs an
    inference.Translator.

    :param context: model server context; system_properties may contain
                    'gpu_id' (CPU device 0 is used otherwise).
    :return: a configured inference.Translator
    :raises ValueError: if the brevity penalty type is unrecognized
    """
    params = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    # CLI args were serialized at export time alongside the model artifacts.
    sockeye_args_path = os.path.join(self.basedir, 'sockeye-args.txt')
    sockeye_args = params.parse_args(read_sockeye_args(sockeye_args_path))
    # override models directory: the deployed model lives in basedir
    sockeye_args.models = [self.basedir]
    device_ids = []
    if 'gpu_id' in context.system_properties:
        device_ids.append(context.system_properties['gpu_id'])
    else:
        logging.warning('No gpu_id found in context')
        device_ids.append(0)
    log_basic_info(sockeye_args)
    if sockeye_args.nbest_size > 1:
        # n-best output only makes sense with the JSON output handler; override.
        if sockeye_args.output_type != const.OUTPUT_HANDLER_JSON:
            logging.warning(
                f'For n-best translation, you must specify --output-type {const.OUTPUT_HANDLER_JSON}'
            )
            sockeye_args.output_type = const.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(sockeye_args.output_type,
                                        sockeye_args.output,
                                        sockeye_args.sure_align_threshold)
    with ExitStack() as exit_stack:
        check_condition(
            len(device_ids) == 1,
            'translate only supports single device for now')
        translator_ctx = determine_context(
            device_ids=device_ids,
            use_cpu=sockeye_args.use_cpu,
            disable_device_locking=sockeye_args.disable_device_locking,
            lock_dir=sockeye_args.lock_dir,
            exit_stack=exit_stack)[0]
        logging.info(f'Translate Device: {translator_ctx}')
        models, source_vocabs, target_vocab = inference.load_models(
            context=translator_ctx,
            max_input_len=sockeye_args.max_input_len,
            beam_size=sockeye_args.beam_size,
            batch_size=sockeye_args.batch_size,
            model_folders=sockeye_args.models,
            checkpoints=sockeye_args.checkpoints,
            softmax_temperature=sockeye_args.softmax_temperature,
            max_output_length_num_stds=sockeye_args.max_output_length_num_stds,
            # Lexicon-restricted decoding needs raw logit inputs and cached
            # output-layer parameters.
            decoder_return_logit_inputs=sockeye_args.restrict_lexicon is not None,
            cache_output_layer_w_b=sockeye_args.restrict_lexicon is not None,
            override_dtype=sockeye_args.override_dtype,
            output_scores=output_handler.reports_score(),
            sampling=sockeye_args.sample)
        restrict_lexicon = None
        if sockeye_args.restrict_lexicon is not None:
            logging.info(str(sockeye_args.restrict_lexicon))
            if len(sockeye_args.restrict_lexicon) == 1:
                # Single lexicon used for all inputs
                restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                # Handle a single arg of key:path or path (parsed as path:path)
                restrict_lexicon.load(sockeye_args.restrict_lexicon[0][1],
                                      k=sockeye_args.restrict_lexicon_topk)
            else:
                check_condition(
                    sockeye_args.json_input,
                    'JSON input is required when using multiple lexicons for vocabulary restriction'
                )
                # Multiple lexicons with specified names
                restrict_lexicon = dict()
                for key, path in sockeye_args.restrict_lexicon:
                    lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                    lexicon.load(path, k=sockeye_args.restrict_lexicon_topk)
                    restrict_lexicon[key] = lexicon
        store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE
        brevity_penalty_weight = sockeye_args.brevity_penalty_weight
        if sockeye_args.brevity_penalty_type == const.BREVITY_PENALTY_CONSTANT:
            if sockeye_args.brevity_penalty_constant_length_ratio > 0.0:
                constant_length_ratio = sockeye_args.brevity_penalty_constant_length_ratio
            else:
                # No explicit ratio given: average the ratios stored in the
                # loaded model configs.
                constant_length_ratio = sum(
                    model.length_ratio_mean for model in models) / len(models)
                logging.info(
                    f'Using average of constant length ratios saved in the model configs: {constant_length_ratio}'
                )
        elif sockeye_args.brevity_penalty_type == const.BREVITY_PENALTY_LEARNED:
            # Sentinel: ratio comes from the learned model, not a constant.
            constant_length_ratio = -1.0
        elif sockeye_args.brevity_penalty_type == const.BREVITY_PENALTY_NONE:
            brevity_penalty_weight = 0.0
            constant_length_ratio = -1.0
        else:
            raise ValueError(
                f'Unknown brevity penalty type {sockeye_args.brevity_penalty_type}'
            )
        brevity_penalty = None
        if brevity_penalty_weight != 0.0:
            brevity_penalty = inference.BrevityPenalty(brevity_penalty_weight)
        return inference.Translator(
            context=translator_ctx,
            ensemble_mode=sockeye_args.ensemble_mode,
            bucket_source_width=sockeye_args.bucket_width,
            length_penalty=inference.LengthPenalty(
                sockeye_args.length_penalty_alpha,
                sockeye_args.length_penalty_beta),
            beam_prune=sockeye_args.beam_prune,
            beam_search_stop=sockeye_args.beam_search_stop,
            nbest_size=sockeye_args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=sockeye_args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=sockeye_args.strip_unknown_words,
            skip_topk=sockeye_args.skip_topk,
            sample=sockeye_args.sample,
            constant_length_ratio=constant_length_ratio,
            brevity_penalty=brevity_penalty)