def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating 4-BLEU '
                                                 'score with respect to a reference set.')
    arguments.add_evaluate_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    hypotheses = [h.strip() for h in args.hypotheses]
    logger.info("%d hypotheses | %d references", len(hypotheses), len(references))

    if not args.not_strict:
        utils.check_condition(len(hypotheses) == len(references),
                              "Number of hypotheses (%d) and references (%d) does not match."
                              % (len(hypotheses), len(references)))

    if not args.sentence:
        bleu = raw_corpus_bleu(hypotheses, references, args.offset)
        print(bleu, file=sys.stdout)
    else:
        for h, r in zip(hypotheses, references):
            bleu = raw_corpus_bleu(h, r, args.offset)
            print(bleu, file=sys.stdout)
def init_embeddings(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.weight_files) != len(args.vocabularies_in) or \
       len(args.weight_files) != len(args.vocabularies_out) or \
       len(args.weight_files) != len(args.names):
        logger.error("Exactly the same number of 'input weight files', 'input vocabularies', "
                     "'output vocabularies' and 'Sockeye parameter names' should be provided.")
        sys.exit(1)

    params = {}  # type: Dict[str, mx.nd.NDArray]
    weight_file_cache = {}  # type: Dict[str, np.ndarray]
    for weight_file, vocab_in_file, vocab_out_file, name in zip(args.weight_files, args.vocabularies_in,
                                                                args.vocabularies_out, args.names):
        weight = load_weight(weight_file, name, weight_file_cache)
        logger.info('Loading input/output vocabularies: %s %s', vocab_in_file, vocab_out_file)
        vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding)
        vocab_out = vocab.vocab_from_json(vocab_out_file)
        logger.info('Initializing parameter: %s', name)
        initializer = mx.init.Normal(sigma=np.std(weight))
        params[name] = init_weight(weight, vocab_in, vocab_out, initializer)

    logger.info('Saving initialized parameters to %s', args.file)
    utils.save_params(params, args.file)
def main():
    params = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_inference_args(params)
    arguments.add_device_args(params)
    args = params.parse_args()

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__, file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models),
                        "must provide checkpoints for each model")

    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    output_handler = sockeye.output_handler.get_output_handler(args.output_type,
                                                               args.output,
                                                               args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        translator = sockeye.inference.Translator(context,
                                                  args.ensemble_mode,
                                                  *sockeye.inference.load_models(context,
                                                                                 args.max_input_len,
                                                                                 args.beam_size,
                                                                                 args.models,
                                                                                 args.checkpoints,
                                                                                 args.softmax_temperature))
        read_and_translate(translator, output_handler, args.input)
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
                                                 'respect to a reference set. If multiple hypotheses files are given, '
                                                 'the mean and standard deviation of the metrics are reported.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]

    if not args.not_strict:
        for hypotheses in all_hypotheses:
            utils.check_condition(len(hypotheses) == len(references),
                                  "Number of hypotheses (%d) and references (%d) does not match."
                                  % (len(hypotheses), len(references)))

    logger.info("%d hypothesis set(s) | %d hypotheses | %d references",
                len(all_hypotheses), len(all_hypotheses[0]), len(references))

    metric_info = ["%s\t(s_opt)" % name for name in args.metrics]
    logger.info("\t".join(metric_info))

    metrics = []  # type: List[Tuple[str, Callable]]
    for name in args.metrics:
        if name == C.BLEU:
            func = partial(raw_corpus_bleu, offset=args.offset)
        elif name == C.CHRF:
            func = raw_corpus_chrf
        elif name == C.ROUGE1:
            func = raw_corpus_rouge1
        elif name == C.ROUGE2:
            func = raw_corpus_rouge2
        elif name == C.ROUGEL:
            func = raw_corpus_rougel
        else:
            raise ValueError("Unknown metric %s." % name)
        metrics.append((name, func))

    if not args.sentence:
        scores = defaultdict(list)  # type: Dict[str, List[float]]
        for hypotheses in all_hypotheses:
            for name, metric in metrics:
                scores[name].append(metric(hypotheses, references))
        _print_mean_std_score(metrics, scores)
    else:
        for hypotheses in all_hypotheses:
            for h, r in zip(hypotheses, references):
                scores = defaultdict(list)  # type: Dict[str, List[float]]
                for name, metric in metrics:
                    scores[name].append(metric([h], [r]))
                _print_mean_std_score(metrics, scores)
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
                                                 'respect to a reference set. If multiple hypotheses files are given, '
                                                 'the mean and standard deviation of the metrics are reported.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]
    metrics = args.metrics
    logger.info("%d hypotheses | %d references", len(all_hypotheses), len(references))

    if not args.not_strict:
        for hypotheses in all_hypotheses:
            utils.check_condition(len(hypotheses) == len(references),
                                  "Number of hypotheses (%d) and references (%d) does not match."
                                  % (len(hypotheses), len(references)))

    metric_info = []
    for metric in metrics:
        metric_info.append("%s\t(s_opt)" % metric)
    logger.info("\t".join(metric_info))

    if not args.sentence:
        scores = defaultdict(list)
        for hypotheses in all_hypotheses:
            for metric in metrics:
                if metric == C.BLEU:
                    score = raw_corpus_bleu(hypotheses, references, args.offset)
                elif metric == C.CHRF:
                    score = raw_corpus_chrf(hypotheses, references)
                else:
                    raise ValueError("Unknown metric %s." % metric)
                scores[metric].append(score)
        _print_mean_std_score(metrics, scores)
    else:
        for hypotheses in all_hypotheses:
            for h, r in zip(hypotheses, references):
                scores = defaultdict(list)
                for metric in metrics:
                    if metric == C.BLEU:
                        score = raw_corpus_bleu([h], [r], args.offset)
                    elif metric == C.CHRF:
                        score = raw_corpus_chrf(h, r)
                    else:
                        raise ValueError("Unknown metric %s." % metric)
                    scores[metric].append(score)
                _print_mean_std_score(metrics, scores)
def log_basic_info(args) -> None:
    """
    Log basic information like version number, arguments, etc.

    :param args: Arguments as returned by argparse.
    """
    log_sockeye_version(logger)
    log_mxnet_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
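# Usage sketch (hypothetical, not part of the original listing): a CLI entry point would
# typically call log_basic_info right after argument parsing, assuming the argparse helpers
# used elsewhere in this file.
def _example_cli_entry():
    params = argparse.ArgumentParser(description='Example CLI using log_basic_info.')
    arguments.add_logging_args(params)
    args = params.parse_args()
    log_basic_info(args)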
def extract_parameters(args: argparse.Namespace):
    log_sockeye_version(logger)

    if os.path.isdir(args.input):
        param_path = os.path.join(args.input, C.PARAMS_BEST_NAME)
    else:
        param_path = args.input
    ext_params = extract(param_path, args.names, args.list_all)

    if len(ext_params) > 0:
        utils.check_condition(args.output is not None,
                              "An output filename must be specified. (Use --output)")
        logger.info("Writing extracted parameters to '%s'", args.output)
        np.savez_compressed(args.output, **ext_params)
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
                                                 'respect to a reference set.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    hypotheses = [h.strip() for h in args.hypotheses]
    logger.info("%d hypotheses | %d references", len(hypotheses), len(references))

    if not args.not_strict:
        utils.check_condition(len(hypotheses) == len(references),
                              "Number of hypotheses (%d) and references (%d) does not match."
                              % (len(hypotheses), len(references)))

    if not args.sentence:
        scores = []
        for metric in args.metrics:
            if metric == C.BLEU:
                bleu_score = raw_corpus_bleu(hypotheses, references, args.offset)
                scores.append("%.6f" % bleu_score)
            elif metric == C.CHRF:
                chrf_score = raw_corpus_chrf(hypotheses, references)
                scores.append("%.6f" % chrf_score)
        print("\t".join(scores), file=sys.stdout)
    else:
        for h, r in zip(hypotheses, references):
            scores = []
            for metric in args.metrics:
                if metric == C.BLEU:
                    bleu = raw_corpus_bleu([h], [r], args.offset)
                    scores.append("%.6f" % bleu)
                elif metric == C.CHRF:
                    chrf_score = raw_corpus_chrf(h, r)
                    scores.append("%.6f" % chrf_score)
            print("\t".join(scores), file=sys.stdout)
def average_parameters(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.inputs) > 1:
        avg_params = average(args.inputs)
    else:
        param_paths = find_checkpoints(model_path=args.inputs[0],
                                       size=args.n,
                                       strategy=args.strategy,
                                       metric=args.metric)
        avg_params = average(param_paths)

    mx.nd.save(args.output, avg_params)
    logger.info("Averaged parameters written to '%s'", args.output)
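# Minimal sketch of what checkpoint averaging amounts to (an illustrative assumption; the
# 'average' used above may differ in detail): each input is an MXNet .params file mapping
# parameter name -> NDArray, and the result is the element-wise mean over all checkpoints.
def _average_sketch(param_paths):
    all_params = [mx.nd.load(path) for path in param_paths]
    return {name: sum(p[name] for p in all_params) / len(all_params)
            for name in all_params[0]}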
def main(): """ Commandline interface to average parameters. """ log_sockeye_version(logger) params = argparse.ArgumentParser(description="Averages parameters from multiple models.") sockeye.arguments.add_average_args(params) args = params.parse_args() if len(args.inputs) > 1: avg_params = average(args.inputs) else: param_paths = find_checkpoints(args.inputs[0], args.n, args.strategy, args.max, args.metric) avg_params = average(param_paths) mx.nd.save(args.output, avg_params) logger.info("Averaged parameters written to '%s'", args.output)
def main(): """ Commandline interface to extract parameters. """ log_sockeye_version(logger) params = argparse.ArgumentParser(description="Extract specific parameters.") arguments.add_extract_args(params) args = params.parse_args() if os.path.isdir(args.input): param_path = os.path.join(args.input, C.PARAMS_BEST_NAME) else: param_path = args.input ext_params = extract(param_path, args.names, args.list_all) if len(ext_params) > 0: utils.check_condition(args.output is not None, "An output filename must be specified. (Use --output)") logger.info("Writting extracted parameters to '%s'", args.output) np.savez_compressed(args.output, **ext_params)
def main(): """ Commandline interface to initialize Sockeye embedding weights with pretrained word representations. """ log_sockeye_version(logger) params = argparse.ArgumentParser( description='Quick usage: python3 -m sockeye.init_embedding ' '-w embed-in-src.npy embed-in-tgt.npy ' '-i vocab-in-src.json vocab-in-tgt.json ' '-o vocab-out-src.json vocab-out-tgt.json ' '-n source_embed_weight target_embed_weight ' '-f params.init') arguments.add_init_embedding_args(params) args = params.parse_args() if len(args.weight_files) != len(args.vocabularies_in) or \ len(args.weight_files) != len(args.vocabularies_out) or \ len(args.weight_files) != len(args.names): logger.error( "Exactly the same number of 'input weight files', 'input vocabularies', " "'output vocabularies' and 'Sockeye parameter names' should be provided." ) sys.exit(1) params = {} # type: Dict[str, mx.nd.NDArray] weight_file_cache = {} # type: Dict[str, np.ndarray] for weight_file, vocab_in_file, vocab_out_file, name in zip( args.weight_files, args.vocabularies_in, args.vocabularies_out, args.names): weight = load_weight(weight_file, name, weight_file_cache) logger.info('Loading input/output vocabularies: %s %s', vocab_in_file, vocab_out_file) vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding) vocab_out = vocab.vocab_from_json(vocab_out_file) logger.info('Initializing parameter: %s', name) initializer = mx.init.Normal(sigma=np.std(weight)) params[name] = init_weight(weight, vocab_in, vocab_out, initializer) logger.info('Saving initialized parameters to %s', args.file) utils.save_params(params, args.file)
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating 4-BLEU '
                                                 'score with respect to a reference set')
    params.add_argument('--references', '-r', required=True, type=str, help="File with references")
    params.add_argument('--hypotheses', '-i', required=True, type=str, help="File with hypotheses")
    params.add_argument('--quiet', '-q', action="store_true", help="Do not print logging information")
    params.add_argument('--sentence', '-s', action="store_true", help="Show sentence-BLEU")
    params.add_argument('--offset', type=float, default=0.01,
                        help="Numerical value of the offset of zero n-gram counts")
    args = params.parse_args()

    check_condition(args.offset >= 0, "Offset should be non-negative.")

    logger = setup_main_logger(__name__, file_logging=False)
    log_sockeye_version(logger)

    if args.quiet:
        logger.setLevel(logging.ERROR)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    hypotheses = [' '.join(e) for e in read_content(args.hypotheses)]
    references = [' '.join(e) for e in read_content(args.references)]

    logger.info("Loaded %d hypotheses", len(hypotheses))
    logger.info("Loaded %d references", len(references))

    check_condition(len(hypotheses) == len(references),
                    "Hypotheses and references have different number of lines.")

    if not args.sentence:
        bleu = corpus_bleu(hypotheses, references, args.offset)
        print(bleu, file=sys.stdout)
    else:
        for h, r in zip(hypotheses, references):
            bleu = bleu_from_counts(bleu_counts(h, r), args.offset)
            print(bleu, file=sys.stdout)
def init_embeddings(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.weight_files) != len(args.vocabularies_in) or \
       len(args.weight_files) != len(args.vocabularies_out) or \
       len(args.weight_files) != len(args.names):
        logger.error("Exactly the same number of 'input weight files', 'input vocabularies', "
                     "'output vocabularies' and 'Sockeye parameter names' should be provided.")
        sys.exit(1)

    params = {}  # type: Dict[str, mx.nd.NDArray]
    weight_file_cache = {}  # type: Dict[str, np.ndarray]
    for weight_file, vocab_in_file, vocab_out_file, name in zip(args.weight_files, args.vocabularies_in,
                                                                args.vocabularies_out, args.names):
        weight = load_weight(weight_file, name, weight_file_cache)
        logger.info('Loading input/output vocabularies: %s %s', vocab_in_file, vocab_out_file)
        vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding)
        vocab_out = vocab.vocab_from_json(vocab_out_file)
        logger.info('Initializing parameter: %s', name)
        initializer = mx.init.Normal(sigma=np.std(weight))
        params[name] = init_weight(weight, vocab_in, vocab_out, initializer)

    logger.info('Saving initialized parameters to %s', args.file)
    utils.save_params(params, args.file)
def annotate_model_params(model_dir: str):
    log_sockeye_version(logger)

    params_best = os.path.join(model_dir, C.PARAMS_BEST_NAME)
    params_best_float32 = os.path.join(model_dir, C.PARAMS_BEST_NAME_FLOAT32)
    config = os.path.join(model_dir, C.CONFIG_NAME)
    config_float32 = os.path.join(model_dir, C.CONFIG_NAME_FLOAT32)

    for fname in params_best_float32, config_float32:
        check_condition(not os.path.exists(fname),
                        'File "%s" exists, indicating this model has already been quantized.' % fname)

    # Load model and compute scaling factors
    model, _, __ = sockeye.model.load_model(model_dir, for_disk_saving='float32', dtype='int8')
    # Move original params and config files
    os.rename(params_best, params_best_float32)
    os.rename(config, config_float32)
    # Write new params and config files with annotated scaling factors
    model.save_parameters(params_best)
    model.save_config(model_dir)
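# Usage sketch (hypothetical wiring, not from the original file): annotate_model_params is
# typically driven by a small CLI that takes the model directory as its only argument.
def _example_quantize_cli():
    params = argparse.ArgumentParser(description='Annotate a float32 Sockeye model with int8 scaling factors.')
    params.add_argument('--model', required=True, help='Model directory containing the best parameters and config.')
    args = params.parse_args()
    annotate_model_params(args.model)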
def main():
    params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(params)
    args = params.parse_args()

    # seed the RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    if args.use_fused_rnn:
        check_condition(not args.use_cpu, "GPU required for FusedRNN cells")

    check_condition(args.optimized_metric == C.BLEU or args.optimized_metric in args.metrics,
                    "Must optimize either BLEU or one of tracked metrics (--metrics)")

    # Checking status of output folder, resumption, etc.
    # Create temporary logger to console only
    logger = setup_main_logger(__name__, file_logging=False, console=not args.quiet)
    output_folder = os.path.abspath(args.output)
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "r") as fp:
                old_args = json.load(fp)
            arg_diffs = _dict_difference(vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # We do not have the logger yet
                logger.error("Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        elif os.path.exists(os.path.join(output_folder, C.PARAMS_BEST_NAME)):
            logger.error("Refusing to overwrite model folder %s as it seems to contain a trained model.",
                         output_folder)
            sys.exit(1)
        else:
            logger.info("The output folder %s already exists, but no training state or parameter file was found. "
                        "Will start training from scratch.", output_folder)
    else:
        os.makedirs(output_folder)

    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    log_sockeye_version(logger)
    log_mxnet_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        # context
        if args.use_cpu:
            logger.info("Device: CPU")
            context = [mx.cpu()]
        else:
            num_gpus = get_num_gpus()
            check_condition(num_gpus >= 1,
                            "No GPUs found, consider running on the CPU with --use-cpu "
                            "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                            "binary isn't on the path).")
            if args.disable_device_locking:
                context = expand_requested_device_ids(args.device_ids)
            else:
                context = exit_stack.enter_context(acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
            logger.info("Device(s): GPU %s", context)
            context = [mx.gpu(gpu_id) for gpu_id in context]

        # load existing or create vocabs
        if resume_training:
            vocab_source = vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_SRC_NAME))
            vocab_target = vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_TRG_NAME))
        else:
            num_words_source, num_words_target = args.num_words
            word_min_count_source, word_min_count_target = args.word_min_count

            # if the source and target embeddings are tied we build a joint vocabulary:
            if args.weight_tying and C.WEIGHT_TYING_SRC in args.weight_tying_type \
                    and C.WEIGHT_TYING_TRG in args.weight_tying_type:
                vocab_source = vocab_target = _build_or_load_vocab(args.source_vocab,
                                                                   [args.source, args.target],
                                                                   num_words_source,
                                                                   word_min_count_source)
            else:
                vocab_source = _build_or_load_vocab(args.source_vocab, [args.source],
                                                    num_words_source, word_min_count_source)
                vocab_target = _build_or_load_vocab(args.target_vocab, [args.target],
                                                    num_words_target, word_min_count_target)

            # write vocabularies
            vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)
            vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)

        # create data iterators
        max_seq_len_source, max_seq_len_target = args.max_seq_len
        batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
        train_iter, eval_iter, config_data = data_io.get_training_data_iters(
            source=os.path.abspath(args.source),
            target=os.path.abspath(args.target),
            validation_source=os.path.abspath(args.validation_source),
            validation_target=os.path.abspath(args.validation_target),
            vocab_source=vocab_source,
            vocab_target=vocab_target,
            vocab_source_path=args.source_vocab,
            vocab_target_path=args.target_vocab,
            batch_size=args.batch_size,
            batch_by_words=args.batch_type == C.BATCH_TYPE_WORD,
            batch_num_devices=batch_num_devices,
            fill_up=args.fill_up,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            bucketing=not args.no_bucketing,
            bucket_width=args.bucket_width)

        # learning rate scheduling
        learning_rate_half_life = none_if_negative(args.learning_rate_half_life)
        # TODO: The loading for continuation of the scheduler is done separately from the other parts
        if not resume_training:
            lr_scheduler_instance = lr_scheduler.get_lr_scheduler(args.learning_rate_scheduler_type,
                                                                  args.checkpoint_frequency,
                                                                  learning_rate_half_life,
                                                                  args.learning_rate_reduce_factor,
                                                                  args.learning_rate_reduce_num_not_improved,
                                                                  args.learning_rate_schedule,
                                                                  args.learning_rate_warmup)
        else:
            with open(os.path.join(training_state_dir, C.SCHEDULER_STATE_NAME), "rb") as fp:
                lr_scheduler_instance = pickle.load(fp)

        # model configuration
        num_embed_source, num_embed_target = args.num_embed
        encoder_num_layers, decoder_num_layers = args.num_layers
        encoder_embed_dropout, decoder_embed_dropout = args.embed_dropout
        encoder_rnn_dropout_inputs, decoder_rnn_dropout_inputs = args.rnn_dropout_inputs
        encoder_rnn_dropout_states, decoder_rnn_dropout_states = args.rnn_dropout_states
        if encoder_embed_dropout > 0 and encoder_rnn_dropout_inputs > 0:
            logger.warning("Setting encoder RNN AND source embedding dropout > 0 leads to "
                           "two dropout layers on top of each other.")
        if decoder_embed_dropout > 0 and decoder_rnn_dropout_inputs > 0:
            logger.warning("Setting decoder RNN AND target embedding dropout > 0 leads to "
                           "two dropout layers on top of each other.")
        encoder_rnn_dropout_recurrent, decoder_rnn_dropout_recurrent = args.rnn_dropout_recurrent
        if encoder_rnn_dropout_recurrent > 0 or decoder_rnn_dropout_recurrent > 0:
            check_condition(args.rnn_cell_type == C.LSTM_TYPE,
                            "Recurrent dropout without memory loss only supported for LSTMs right now.")
        encoder_transformer_preprocess, decoder_transformer_preprocess = args.transformer_preprocess
        encoder_transformer_postprocess, decoder_transformer_postprocess = args.transformer_postprocess

        config_conv = None
        if args.encoder == C.RNN_WITH_CONV_EMBED_NAME:
            config_conv = encoder.ConvolutionalEmbeddingConfig(num_embed=num_embed_source,
                                                               max_filter_width=args.conv_embed_max_filter_width,
                                                               num_filters=args.conv_embed_num_filters,
                                                               pool_stride=args.conv_embed_pool_stride,
                                                               num_highway_layers=args.conv_embed_num_highway_layers,
                                                               dropout=args.conv_embed_dropout)

        if args.encoder in (C.TRANSFORMER_TYPE, C.TRANSFORMER_WITH_CONV_EMBED_TYPE):
            config_encoder = transformer.TransformerConfig(
                model_size=args.transformer_model_size,
                attention_heads=args.transformer_attention_heads,
                feed_forward_num_hidden=args.transformer_feed_forward_num_hidden,
                num_layers=encoder_num_layers,
                vocab_size=vocab_source_size,
                dropout_attention=args.transformer_dropout_attention,
                dropout_relu=args.transformer_dropout_relu,
                dropout_prepost=args.transformer_dropout_prepost,
                weight_tying=args.weight_tying and C.WEIGHT_TYING_SRC in args.weight_tying_type,
                positional_encodings=not args.transformer_no_positional_encodings,
                preprocess_sequence=encoder_transformer_preprocess,
                postprocess_sequence=encoder_transformer_postprocess,
                conv_config=config_conv)
        else:
            config_encoder = encoder.RecurrentEncoderConfig(
                vocab_size=vocab_source_size,
                num_embed=num_embed_source,
                embed_dropout=encoder_embed_dropout,
                rnn_config=rnn.RNNConfig(cell_type=args.rnn_cell_type,
                                         num_hidden=args.rnn_num_hidden,
                                         num_layers=encoder_num_layers,
                                         dropout_inputs=encoder_rnn_dropout_inputs,
                                         dropout_states=encoder_rnn_dropout_states,
                                         dropout_recurrent=encoder_rnn_dropout_recurrent,
                                         residual=args.rnn_residual_connections,
                                         first_residual_layer=args.rnn_first_residual_layer,
                                         forget_bias=args.rnn_forget_bias),
                conv_config=config_conv,
                reverse_input=args.rnn_encoder_reverse_input)

        decoder_weight_tying = args.weight_tying and C.WEIGHT_TYING_TRG in args.weight_tying_type \
            and C.WEIGHT_TYING_SOFTMAX in args.weight_tying_type
        if args.decoder == C.TRANSFORMER_TYPE:
            config_decoder = transformer.TransformerConfig(
                model_size=args.transformer_model_size,
                attention_heads=args.transformer_attention_heads,
                feed_forward_num_hidden=args.transformer_feed_forward_num_hidden,
                num_layers=decoder_num_layers,
                vocab_size=vocab_target_size,
                dropout_attention=args.transformer_dropout_attention,
                dropout_relu=args.transformer_dropout_relu,
                dropout_prepost=args.transformer_dropout_prepost,
                weight_tying=decoder_weight_tying,
                positional_encodings=not args.transformer_no_positional_encodings,
                preprocess_sequence=decoder_transformer_preprocess,
                postprocess_sequence=decoder_transformer_postprocess,
                conv_config=None)
        else:
            attention_num_hidden = args.rnn_num_hidden if not args.attention_num_hidden else args.attention_num_hidden
            config_coverage = None
            if args.attention_type == C.ATT_COV:
                config_coverage = coverage.CoverageConfig(type=args.attention_coverage_type,
                                                          num_hidden=args.attention_coverage_num_hidden,
                                                          layer_normalization=args.layer_normalization)
            config_attention = attention.AttentionConfig(type=args.attention_type,
                                                         num_hidden=attention_num_hidden,
                                                         input_previous_word=args.attention_use_prev_word,
                                                         rnn_num_hidden=args.rnn_num_hidden,
                                                         layer_normalization=args.layer_normalization,
                                                         config_coverage=config_coverage,
                                                         num_heads=args.attention_mhdot_heads)
            config_decoder = decoder.RecurrentDecoderConfig(
                vocab_size=vocab_target_size,
                max_seq_len_source=max_seq_len_source,
                num_embed=num_embed_target,
                rnn_config=rnn.RNNConfig(cell_type=args.rnn_cell_type,
                                         num_hidden=args.rnn_num_hidden,
                                         num_layers=decoder_num_layers,
                                         dropout_inputs=decoder_rnn_dropout_inputs,
                                         dropout_states=decoder_rnn_dropout_states,
                                         dropout_recurrent=decoder_rnn_dropout_recurrent,
                                         residual=args.rnn_residual_connections,
                                         first_residual_layer=args.rnn_first_residual_layer,
                                         forget_bias=args.rnn_forget_bias),
                attention_config=config_attention,
                embed_dropout=decoder_embed_dropout,
                hidden_dropout=args.rnn_decoder_hidden_dropout,
                weight_tying=decoder_weight_tying,
                state_init=args.rnn_decoder_state_init,
                context_gating=args.rnn_context_gating,
                layer_normalization=args.layer_normalization,
                attention_in_upper_layers=args.attention_in_upper_layers)

        config_loss = loss.LossConfig(type=args.loss,
                                      vocab_size=vocab_target_size,
                                      normalize=args.normalize_loss,
                                      smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha)

        model_config = model.ModelConfig(config_data=config_data,
                                         max_seq_len_source=max_seq_len_source,
                                         max_seq_len_target=max_seq_len_target,
                                         vocab_source_size=vocab_source_size,
                                         vocab_target_size=vocab_target_size,
                                         config_encoder=config_encoder,
                                         config_decoder=config_decoder,
                                         config_loss=config_loss,
                                         lexical_bias=args.lexical_bias,
                                         learn_lexical_bias=args.learn_lexical_bias,
                                         weight_tying=args.weight_tying,
                                         weight_tying_type=args.weight_tying_type if args.weight_tying else None)
        model_config.freeze()

        # create training model
        training_model = training.TrainingModel(config=model_config,
                                                context=context,
                                                train_iter=train_iter,
                                                fused=args.use_fused_rnn,
                                                bucketing=not args.no_bucketing,
                                                lr_scheduler=lr_scheduler_instance)

        # We may consider loading the params in TrainingModule, for consistency
        # with the training state saving
        if resume_training:
            logger.info("Found partial training in directory %s. Resuming from saved state.", training_state_dir)
            training_model.load_params_from_file(os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
        elif args.params:
            logger.info("Training will initialize from parameters loaded from '%s'", args.params)
            training_model.load_params_from_file(args.params)

        lexicon_array = lexicon.initialize_lexicon(args.lexical_bias,
                                                   vocab_source, vocab_target) if args.lexical_bias else None

        weight_initializer = initializer.get_initializer(args.weight_init,
                                                         args.weight_init_scale,
                                                         args.rnn_h2h_init,
                                                         lexicon=lexicon_array)

        optimizer = args.optimizer
        optimizer_params = {'wd': args.weight_decay,
                            "learning_rate": args.initial_learning_rate}
        if lr_scheduler_instance is not None:
            optimizer_params["lr_scheduler"] = lr_scheduler_instance
        clip_gradient = none_if_negative(args.clip_gradient)
        if clip_gradient is not None:
            optimizer_params["clip_gradient"] = clip_gradient
        if args.momentum is not None:
            optimizer_params["momentum"] = args.momentum
        if args.normalize_loss:
            # When normalize_loss is turned on we normalize by the number of non-PAD symbols in a batch which
            # implicitly already contains the number of sentences and therefore we need to disable rescale_grad.
            optimizer_params["rescale_grad"] = 1.0
        else:
            # Making MXNet module API's default scaling factor explicit
            optimizer_params["rescale_grad"] = 1.0 / args.batch_size
        logger.info("Optimizer: %s", optimizer)
        logger.info("Optimizer Parameters: %s", optimizer_params)

        # Handle options that override training settings
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            max_updates = sum(num_updates for (_, num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = 0

        monitor_bleu = args.monitor_bleu
        # Turn on BLEU monitoring when the optimized metric is BLEU and it hasn't been enabled yet
        if args.optimized_metric == C.BLEU and monitor_bleu == 0:
            logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
                        "To control how many validation sentences are used for calculating bleu use "
                        "the --monitor-bleu argument.")
            monitor_bleu = -1

        training_model.fit(train_iter, eval_iter,
                           output_folder=output_folder,
                           max_params_files_to_keep=args.keep_last_params,
                           metrics=args.metrics,
                           initializer=weight_initializer,
                           max_updates=max_updates,
                           checkpoint_frequency=args.checkpoint_frequency,
                           optimizer=optimizer, optimizer_params=optimizer_params,
                           optimized_metric=args.optimized_metric,
                           max_num_not_improved=max_num_checkpoint_not_improved,
                           min_num_epochs=min_num_epochs,
                           monitor_bleu=monitor_bleu,
                           use_tensorboard=args.use_tensorboard,
                           mxmonitor_pattern=args.monitor_pattern,
                           mxmonitor_stat_func=args.monitor_stat_func)
def main():
    params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_io_args(params)
    arguments.add_model_parameters(params)
    arguments.add_training_args(params)
    arguments.add_device_args(params)
    args = params.parse_args()

    # seed the RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    if args.use_fused_rnn:
        check_condition(not args.use_cpu, "GPU required for FusedRNN cells")

    if args.rnn_residual_connections:
        check_condition(args.rnn_num_layers > 2, "Residual connections require at least 3 RNN layers")

    check_condition(args.optimized_metric == C.BLEU or args.optimized_metric in args.metrics,
                    "Must optimize either BLEU or one of tracked metrics (--metrics)")

    # Checking status of output folder, resumption, etc.
    # Create temporary logger to console only
    logger = setup_main_logger(__name__, file_logging=False, console=not args.quiet)
    output_folder = os.path.abspath(args.output)
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "r") as fp:
                old_args = json.load(fp)
            arg_diffs = _dict_difference(vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # We do not have the logger yet
                logger.error("Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        else:
            logger.error("Refusing to overwrite existing output folder %s.", output_folder)
            sys.exit(1)
    else:
        os.makedirs(output_folder)

    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        # context
        if args.use_cpu:
            logger.info("Device: CPU")
            context = [mx.cpu()]
        else:
            num_gpus = get_num_gpus()
            check_condition(num_gpus >= 1,
                            "No GPUs found, consider running on the CPU with --use-cpu "
                            "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                            "binary isn't on the path).")
            if args.disable_device_locking:
                context = expand_requested_device_ids(args.device_ids)
            else:
                context = exit_stack.enter_context(acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
            logger.info("Device(s): GPU %s", context)
            context = [mx.gpu(gpu_id) for gpu_id in context]

        # load existing or create vocabs
        if resume_training:
            vocab_source = vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_SRC_NAME))
            vocab_target = vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_TRG_NAME))
        else:
            num_words_source = args.num_words if args.num_words_source is None else args.num_words_source
            vocab_source = _build_or_load_vocab(args.source_vocab, args.source, num_words_source,
                                                args.word_min_count)
            vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)

            num_words_target = args.num_words if args.num_words_target is None else args.num_words_target
            vocab_target = _build_or_load_vocab(args.target_vocab, args.target, num_words_target,
                                                args.word_min_count)
            vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)

        config_data = data_io.DataConfig(os.path.abspath(args.source),
                                         os.path.abspath(args.target),
                                         os.path.abspath(args.validation_source),
                                         os.path.abspath(args.validation_target),
                                         args.source_vocab,
                                         args.target_vocab)

        # create data iterators
        max_seq_len_source = args.max_seq_len if args.max_seq_len_source is None else args.max_seq_len_source
        max_seq_len_target = args.max_seq_len if args.max_seq_len_target is None else args.max_seq_len_target
        train_iter, eval_iter = data_io.get_training_data_iters(source=config_data.source,
                                                                target=config_data.target,
                                                                validation_source=config_data.validation_source,
                                                                validation_target=config_data.validation_target,
                                                                vocab_source=vocab_source,
                                                                vocab_target=vocab_target,
                                                                batch_size=args.batch_size,
                                                                fill_up=args.fill_up,
                                                                max_seq_len_source=max_seq_len_source,
                                                                max_seq_len_target=max_seq_len_target,
                                                                bucketing=not args.no_bucketing,
                                                                bucket_width=args.bucket_width)

        # learning rate scheduling
        learning_rate_half_life = none_if_negative(args.learning_rate_half_life)
        # TODO: The loading for continuation of the scheduler is done separately from the other parts
        if not resume_training:
            lr_scheduler_instance = lr_scheduler.get_lr_scheduler(args.learning_rate_scheduler_type,
                                                                  args.checkpoint_frequency,
                                                                  learning_rate_half_life,
                                                                  args.learning_rate_reduce_factor,
                                                                  args.learning_rate_reduce_num_not_improved)
        else:
            with open(os.path.join(training_state_dir, C.SCHEDULER_STATE_NAME), "rb") as fp:
                lr_scheduler_instance = pickle.load(fp)

        # model configuration
        num_embed_source = args.num_embed if args.num_embed_source is None else args.num_embed_source
        num_embed_target = args.num_embed if args.num_embed_target is None else args.num_embed_target

        config_rnn = rnn.RNNConfig(cell_type=args.rnn_cell_type,
                                   num_hidden=args.rnn_num_hidden,
                                   num_layers=args.rnn_num_layers,
                                   dropout=args.dropout,
                                   residual=args.rnn_residual_connections,
                                   forget_bias=args.rnn_forget_bias)

        config_conv = None
        if args.encoder == C.RNN_WITH_CONV_EMBED_NAME:
            config_conv = encoder.ConvolutionalEmbeddingConfig(num_embed=num_embed_source,
                                                               max_filter_width=args.conv_embed_max_filter_width,
                                                               num_filters=args.conv_embed_num_filters,
                                                               pool_stride=args.conv_embed_pool_stride,
                                                               num_highway_layers=args.conv_embed_num_highway_layers,
                                                               dropout=args.dropout)

        config_encoder = encoder.RecurrentEncoderConfig(vocab_size=vocab_source_size,
                                                        num_embed=num_embed_source,
                                                        rnn_config=config_rnn,
                                                        conv_config=config_conv)

        config_decoder = decoder.RecurrentDecoderConfig(vocab_size=vocab_target_size,
                                                        num_embed=num_embed_target,
                                                        rnn_config=config_rnn,
                                                        dropout=args.dropout,
                                                        weight_tying=args.weight_tying,
                                                        context_gating=args.context_gating,
                                                        layer_normalization=args.layer_normalization)

        attention_num_hidden = args.rnn_num_hidden if not args.attention_num_hidden else args.attention_num_hidden
        config_coverage = None
        if args.attention_type == "coverage":
            config_coverage = coverage.CoverageConfig(type=args.attention_coverage_type,
                                                      num_hidden=args.attention_coverage_num_hidden,
                                                      layer_normalization=args.layer_normalization)
        config_attention = attention.AttentionConfig(type=args.attention_type,
                                                     num_hidden=attention_num_hidden,
                                                     input_previous_word=args.attention_use_prev_word,
                                                     rnn_num_hidden=config_rnn.num_hidden,
                                                     layer_normalization=args.layer_normalization,
                                                     config_coverage=config_coverage)

        config_loss = loss.LossConfig(type=args.loss,
                                      vocab_size=vocab_target_size,
                                      normalize=args.normalize_loss,
                                      smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha)

        model_config = model.ModelConfig(config_data=config_data,
                                         max_seq_len=max_seq_len_source,
                                         vocab_source_size=vocab_source_size,
                                         vocab_target_size=vocab_target_size,
                                         config_encoder=config_encoder,
                                         config_decoder=config_decoder,
                                         config_attention=config_attention,
                                         config_loss=config_loss,
                                         lexical_bias=args.lexical_bias,
                                         learn_lexical_bias=args.learn_lexical_bias)
        model_config.freeze()

        # create training model
        training_model = training.TrainingModel(config=model_config,
                                                context=context,
                                                train_iter=train_iter,
                                                fused=args.use_fused_rnn,
                                                bucketing=not args.no_bucketing,
                                                lr_scheduler=lr_scheduler_instance)

        # We may consider loading the params in TrainingModule, for consistency
        # with the training state saving
        if resume_training:
            logger.info("Found partial training in directory %s. Resuming from saved state.", training_state_dir)
            training_model.load_params_from_file(os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
        elif args.params:
            logger.info("Training will initialize from parameters loaded from '%s'", args.params)
            training_model.load_params_from_file(args.params)

        lexicon_array = lexicon.initialize_lexicon(args.lexical_bias,
                                                   vocab_source, vocab_target) if args.lexical_bias else None

        weight_initializer = initializer.get_initializer(args.rnn_h2h_init, lexicon=lexicon_array)

        optimizer = args.optimizer
        optimizer_params = {'wd': args.weight_decay,
                            "learning_rate": args.initial_learning_rate}
        if lr_scheduler_instance is not None:
            optimizer_params["lr_scheduler"] = lr_scheduler_instance
        clip_gradient = none_if_negative(args.clip_gradient)
        if clip_gradient is not None:
            optimizer_params["clip_gradient"] = clip_gradient
        if args.momentum is not None:
            optimizer_params["momentum"] = args.momentum
        if args.normalize_loss:
            # When normalize_loss is turned on we normalize by the number of non-PAD symbols in a batch which
            # implicitly already contains the number of sentences and therefore we need to disable rescale_grad.
            optimizer_params["rescale_grad"] = 1.0
        else:
            # Making MXNet module API's default scaling factor explicit
            optimizer_params["rescale_grad"] = 1.0 / args.batch_size
        logger.info("Optimizer: %s", optimizer)
        logger.info("Optimizer Parameters: %s", optimizer_params)

        training_model.fit(train_iter, eval_iter,
                           output_folder=output_folder,
                           max_params_files_to_keep=args.keep_last_params,
                           metrics=args.metrics,
                           initializer=weight_initializer,
                           max_updates=args.max_updates,
                           checkpoint_frequency=args.checkpoint_frequency,
                           optimizer=optimizer, optimizer_params=optimizer_params,
                           optimized_metric=args.optimized_metric,
                           max_num_not_improved=args.max_num_checkpoint_not_improved,
                           min_num_epochs=args.min_num_epochs,
                           monitor_bleu=args.monitor_bleu,
                           use_tensorboard=args.use_tensorboard)