Example #1
        open input vocabularies with specified encoding (default: utf-8)
"""

import argparse
import sys
from typing import Dict

import numpy as np
import mxnet as mx

from sockeye.log import setup_main_logger, log_sockeye_version
from . import arguments
from . import utils
from . import vocab

logger = setup_main_logger(__name__, console=True, file_logging=False)


def init_weight(weight: np.ndarray,
                vocab_in: Dict[str, int],
                vocab_out: Dict[str, int],
                initializer: mx.initializer.Initializer = mx.init.Constant(value=0.0)) -> mx.nd.NDArray:
    """
    Initialize vocabulary-sized weight by existing values given input and output vocabularies.

    :param weight: Input weight.
    :param vocab_in: Input vocabulary.
    :param vocab_out: Output vocabulary.
    :param initializer: MXNet initializer.
    :return: Initialized output weight.
    """
Example #2
import argparse
import sys
import time
from contextlib import ExitStack
from math import ceil
from typing import Generator, Optional, List

from sockeye.lexicon import TopKLexicon
from sockeye.log import setup_main_logger
from sockeye.output_handler import get_output_handler, OutputHandler
from sockeye.utils import determine_context, log_basic_info, check_condition, grouper
from . import arguments
from . import constants as C
from . import data_io
from . import inference

logger = setup_main_logger(__name__, file_logging=False)


def main():
    params = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    args = params.parse_args()
    run_translate(args)


def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))
Example #3
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type,
                                        args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning('Experimental feature \'--override-dtype float16\' has been used. '
                           'This feature may be removed or change its behaviour in future. '
                           'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon, k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(context=context,
                                          ensemble_mode=args.ensemble_mode,
                                          bucket_source_width=args.bucket_width,
                                          length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                                                 args.length_penalty_beta),
                                          beam_prune=args.beam_prune,
                                          beam_search_stop=args.beam_search_stop,
                                          models=models,
                                          source_vocabs=source_vocabs,
                                          target_vocab=target_vocab,
                                          restrict_lexicon=restrict_lexicon,
                                          avoid_list=args.avoid_list,
                                          store_beam=store_beam,
                                          strip_unknown_words=args.strip_unknown_words)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
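
A hypothetical end-to-end invocation of this CLI; the flag names are inferred from the attributes used above (args.models → --models, args.use_cpu → --use-cpu) and the file paths are placeholders:

params = arguments.ConfigArgumentParser(description='Translate CLI')
arguments.add_translate_cli_args(params)
args = params.parse_args(['--models', 'model_dir',
                          '--input', 'input.txt',
                          '--output', 'output.txt',
                          '--use-cpu'])
run_translate(args)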
Example #4
def run_translate(args: argparse.Namespace):

    # Seed randomly unless a seed has been passed
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=not args.no_logfile,
                          path="%s.%s" % (args.output, C.LOG_NAME),
                          level=args.loglevel)
    else:
        setup_main_logger(file_logging=False, level=args.loglevel)

    log_basic_info(args)

    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning(
                "For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype,
            output_scores=output_handler.reports_score(),
            sampling=args.sample)

        if any([model.config.num_pointers for model in models]):
            check_condition(
                args.restrict_lexicon is None,
                "The pointer mechanism does not currently work with vocabulary restriction."
            )

        restrict_lexicon = None  # type: Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]]
        if args.restrict_lexicon is not None:
            logger.info(str(args.restrict_lexicon))
            if len(args.restrict_lexicon) == 1:
                # Single lexicon used for all inputs
                restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                # Handle a single arg of key:path or path (parsed as path:path)
                restrict_lexicon.load(args.restrict_lexicon[0][1],
                                      k=args.restrict_lexicon_topk)
            else:
                check_condition(
                    args.json_input,
                    "JSON input is required when using multiple lexicons for vocabulary restriction"
                )
                # Multiple lexicons with specified names
                restrict_lexicon = dict()
                for key, path in args.restrict_lexicon:
                    lexicon = TopKLexicon(source_vocabs[0], target_vocab)
                    lexicon.load(path, k=args.restrict_lexicon_topk)
                    restrict_lexicon[key] = lexicon
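                # With multiple lexicons, each JSON input line would select one
                # by name, e.g. (hypothetical input):
                #   {"text": "hello world", "restrict_lexicon": "news"}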

        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE

        brevity_penalty_weight = args.brevity_penalty_weight
        if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
            if args.brevity_penalty_constant_length_ratio > 0.0:
                constant_length_ratio = args.brevity_penalty_constant_length_ratio
            else:
                constant_length_ratio = sum(model.length_ratio_mean
                                            for model in models) / len(models)
                logger.info(
                    "Using average of constant length ratios saved in the model configs: %f",
                    constant_length_ratio)
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_LEARNED:
            constant_length_ratio = -1.0
        elif args.brevity_penalty_type == C.BREVITY_PENALTY_NONE:
            brevity_penalty_weight = 0.0
            constant_length_ratio = -1.0
        else:
            raise ValueError("Unknown brevity penalty type %s" %
                             args.brevity_penalty_type)

        brevity_penalty = None  # type: Optional[inference.BrevityPenalty]
        if brevity_penalty_weight != 0.0:
            brevity_penalty = inference.BrevityPenalty(brevity_penalty_weight)

        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            nbest_size=args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words,
            skip_topk=args.skip_topk,
            sample=args.sample,
            constant_length_ratio=constant_length_ratio,
            brevity_penalty=brevity_penalty)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
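
The brevity penalty selection above reduces to a simple log-space formula. The sketch below assumes the standard BLEU-style penalty (an assumption about what inference.BrevityPenalty computes, not its verbatim implementation), with the reference length estimated as constant_length_ratio times the source length:

def log_brevity_penalty(hyp_len: int, ref_len: float, weight: float) -> float:
    # 0.0 (no penalty) when the hypothesis is at least as long as the
    # estimated reference; increasingly negative as it gets shorter.
    return weight * min(0.0, 1.0 - ref_len / hyp_len)

# e.g. constant_length_ratio=1.1 on a 20-token source estimates a 22-token
# reference, so an 18-token hypothesis is penalized:
print(log_brevity_penalty(hyp_len=18, ref_len=1.1 * 20, weight=1.0))  # ≈ -0.22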
Example #5
import argparse

from math import ceil
from typing import Generator, Optional, List

import mxnet as mx

from sockeye.lexicon import TopKLexicon
from sockeye.log import setup_main_logger
from sockeye.utils import determine_context, log_basic_info, check_condition, grouper
from sockeye.output_handler import get_output_handler, OutputHandler
from . import arguments
from . import constants as C
from . import data_io
from . import inference
from . import inference_adapt_train

logger = setup_main_logger(__name__, file_logging=False)


def main():
    params = arguments.ConfigArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    arguments.add_inference_adapt_args(params)
    args = params.parse_args()
    run_translate(args)


def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))
Example #6
from . import data_io
from . import decoder
from . import encoder
from . import initializer
from . import loss
from . import lr_scheduler
from . import model
from . import rnn
from . import rnn_attention
from . import training
from . import transformer
from . import utils
from . import vocab

# Temporary logger; the real one (probably logging to a file) will be created in the main function
logger = setup_main_logger(__name__, file_logging=False, console=True)


def none_if_negative(val):
    return None if val < 0 else val


def _list_to_tuple(v):
    """Convert v to a tuple if it is a list."""
    if isinstance(v, list):
        return tuple(v)
    return v


def _dict_difference(dict1: Dict, dict2: Dict):
    """Return the set of keys in dict1 that are missing from dict2 or map to different values."""
    diffs = set()
    for key, value in dict1.items():
        # Lists and tuples with the same values compare equal here, since JSON
        # round-trips deserialize former tuples as lists.
        if key not in dict2 or _list_to_tuple(value) != _list_to_tuple(dict2[key]):
            diffs.add(key)
    return diffs
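
A quick illustration (hypothetical values) of why the _list_to_tuple normalization matters when comparing a freshly parsed argparse namespace against arguments reloaded from JSON:

old_args = {"device_ids": [0, 1]}   # as deserialized from the saved args file
new_args = {"device_ids": (0, 1)}   # as freshly parsed
assert _dict_difference(new_args, old_args) == set()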
Example #7
def main():
    params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(params)
    args = params.parse_args()

    utils.seedRNGs(args.seed)

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training, training_state_dir = check_resume(args, output_folder)

    global logger
    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
    utils.log_basic_info(args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)


    with ExitStack() as exit_stack:
        context = determine_context(args, exit_stack)

        shared_vocab = use_shared_vocab(args)

        train_iter, eval_iter, config_data, vocab_source, vocab_target = create_data_iters_and_vocab(
            args=args,
            shared_vocab=shared_vocab,
            resume_training=resume_training,
            output_folder=output_folder)

        if not resume_training:
            vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)
            vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)
        lr_scheduler_instance = create_lr_scheduler(args, resume_training, training_state_dir)

        model_config = create_model_config(args, vocab_source_size, vocab_target_size, config_data)
        model_config.freeze()

        training_model = create_training_model(model_config, args,
                                               context, train_iter, lr_scheduler_instance,
                                               resume_training, training_state_dir)


        weight_initializer = initializer.get_initializer(
            default_init_type=args.weight_init,
            default_init_scale=args.weight_init_scale,
            default_init_xavier_rand_type=args.weight_init_xavier_rand_type,
            default_init_xavier_factor_type=args.weight_init_xavier_factor_type,
            embed_init_type=args.embed_weight_init,
            embed_init_sigma=vocab_source_size ** -0.5,  # TODO
            rnn_init_type=args.rnn_h2h_init)

        optimizer, optimizer_params, kvstore, gradient_clipping_type, gradient_clipping_threshold = define_optimizer(args, lr_scheduler_instance)

        # Handle options that override training settings
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        max_num_epochs = args.max_num_epochs
        if min_num_epochs is not None and max_num_epochs is not None:
            check_condition(min_num_epochs <= max_num_epochs,
                            "Minimum number of epochs must be smaller than maximum number of epochs")
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            max_updates = sum(num_updates for (_, num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = None
            max_num_epochs = None
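        # Illustrative example (hypothetical values): a parsed schedule of
        # [(0.001, 1000), (0.0005, 2000)] gives max_updates = 1000 + 2000 = 3000,
        # and the early-stopping and epoch limits above are disabled.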

        decode_and_evaluate, decode_and_evaluate_context = determine_decode_and_evaluate_context(args,
                                                                                                 exit_stack,
                                                                                                 context)
        training_model.fit(train_iter, eval_iter,
                           output_folder=output_folder,
                           max_params_files_to_keep=args.keep_last_params,
                           metrics=args.metrics,
                           initializer=weight_initializer,
                           allow_missing_params=args.allow_missing_params,
                           max_updates=max_updates,
                           checkpoint_frequency=args.checkpoint_frequency,
                           optimizer=optimizer, optimizer_params=optimizer_params,
                           optimized_metric=args.optimized_metric,
                           gradient_clipping_type=gradient_clipping_type,
                           clip_gradient_threshold=gradient_clipping_threshold,
                           kvstore=kvstore,
                           max_num_not_improved=max_num_checkpoint_not_improved,
                           min_num_epochs=min_num_epochs,
                           max_num_epochs=max_num_epochs,
                           decode_and_evaluate=decode_and_evaluate,
                           decode_and_evaluate_fname_source=args.validation_source,
                           decode_and_evaluate_fname_target=args.validation_target,
                           decode_and_evaluate_context=decode_and_evaluate_context,
                           use_tensorboard=args.use_tensorboard,
                           mxmonitor_pattern=args.monitor_pattern,
                           mxmonitor_stat_func=args.monitor_stat_func,
                           lr_decay_param_reset=args.learning_rate_decay_param_reset,
                           lr_decay_opt_states_reset=args.learning_rate_decay_optimizer_states_reset)
Example #8
def main():
    params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(params)
    args = params.parse_args()

    utils.seedRNGs(args.seed)

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training, training_state_dir = check_resume(args, output_folder)

    global logger
    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
    utils.log_basic_info(args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        context = determine_context(args, exit_stack)
        vocab_source, vocab_target = load_or_create_vocabs(args, resume_training, output_folder)
        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)
        train_iter, eval_iter, config_data = create_data_iters(args, vocab_source, vocab_target)
        lr_scheduler_instance = create_lr_scheduler(args, resume_training, training_state_dir)

        model_config = create_model_config(args, vocab_source_size, vocab_target_size, config_data)
        model_config.freeze()

        training_model = create_training_model(model_config, args,
                                               context, train_iter, lr_scheduler_instance,
                                               resume_training, training_state_dir)

        lexicon_array = lexicon.initialize_lexicon(args.lexical_bias,
                                                   vocab_source, vocab_target) if args.lexical_bias else None

        weight_initializer = initializer.get_initializer(
            default_init_type=args.weight_init,
            default_init_scale=args.weight_init_scale,
            default_init_xavier_factor_type=args.weight_init_xavier_factor_type,
            embed_init_type=args.embed_weight_init,
            embed_init_sigma=vocab_source_size ** -0.5,  # TODO
            rnn_init_type=args.rnn_h2h_init,
            lexicon=lexicon_array)

        optimizer, optimizer_params, kvstore = define_optimizer(args, lr_scheduler_instance)

        # Handle options that override training settings
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        max_num_epochs = args.max_num_epochs
        if min_num_epochs is not None and max_num_epochs is not None:
            check_condition(min_num_epochs <= max_num_epochs,
                            "Minimum number of epochs must be smaller than maximum number of epochs")
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            max_updates = sum(num_updates for (_, num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = None
            max_num_epochs = None

        monitor_bleu = args.monitor_bleu
        # Turn on BLEU monitoring when the optimized metric is BLEU and it hasn't been enabled yet
        if args.optimized_metric == C.BLEU and monitor_bleu == 0:
            logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
                        "To control how many validation sentences are used for calculating bleu use "
                        "the --monitor-bleu argument.")
            monitor_bleu = -1

        training_model.fit(train_iter, eval_iter,
                           output_folder=output_folder,
                           max_params_files_to_keep=args.keep_last_params,
                           metrics=args.metrics,
                           initializer=weight_initializer,
                           max_updates=max_updates,
                           checkpoint_frequency=args.checkpoint_frequency,
                           optimizer=optimizer, optimizer_params=optimizer_params,
                           optimized_metric=args.optimized_metric,
                           kvstore=kvstore,
                           max_num_not_improved=max_num_checkpoint_not_improved,
                           min_num_epochs=min_num_epochs,
                           max_num_epochs=max_num_epochs,
                           monitor_bleu=monitor_bleu,
                           use_tensorboard=args.use_tensorboard,
                           mxmonitor_pattern=args.monitor_pattern,
                           mxmonitor_stat_func=args.monitor_stat_func,
                           lr_decay_param_reset=args.learning_rate_decay_param_reset,
                           lr_decay_opt_states_reset=args.learning_rate_decay_optimizer_states_reset)
Example #9
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    if args.skip_topk:
        check_condition(
            args.beam_size == 1,
            "--skip-topk has no effect if beam size is larger than 1")
        check_condition(
            len(args.models) == 1,
            "--skip-topk has no effect for decoding with more than 1 model")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        edge_vocab = vocab.vocab_from_json(args.edge_vocab)
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning(
                'Experimental feature \'--override-dtype float16\' has been used. '
                'This feature may be removed or change its behaviour in future. '
                'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab, edge_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            edge_vocab=edge_vocab,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            edge_vocab=edge_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words,
            skip_topk=args.skip_topk)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #10
def run_translate(args: argparse.Namespace):

    # Seed randomly unless a seed has been passed
    utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))

    if args.output is not None:
        setup_main_logger(console=not args.quiet,
                          file_logging=True,
                          path="%s.%s" % (args.output, C.LOG_NAME))
    else:
        setup_main_logger(file_logging=False)

    log_basic_info(args)

    if args.nbest_size > 1:
        if args.output_type != C.OUTPUT_HANDLER_JSON:
            logger.warning(
                "For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
                C.OUTPUT_HANDLER_JSON, args.output_type)
            args.output_type = C.OUTPUT_HANDLER_JSON
    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(
            len(args.device_ids) == 1,
            "translate only supports single device for now")
        context = determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype,
            output_scores=output_handler.reports_score(),
            sampling=args.sample)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            nbest_size=args.nbest_size,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words,
            skip_topk=args.skip_topk,
            sample=args.sample)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
Example #11
def main():
    params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_io_args(params)
    arguments.add_model_parameters(params)
    arguments.add_training_args(params)
    arguments.add_device_args(params)
    args = params.parse_args()

    # seed the RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    if args.use_fused_rnn:
        check_condition(not args.use_cpu, "GPU required for FusedRNN cells")

    if args.rnn_residual_connections:
        check_condition(args.rnn_num_layers > 2, "Residual connections require at least 3 RNN layers")

    check_condition(args.optimized_metric == C.BLEU or args.optimized_metric in args.metrics,
                    "Must optimize either BLEU or one of tracked metrics (--metrics)")

    # Checking status of output folder, resumption, etc.
    # Create temporary logger to console only
    logger = setup_main_logger(__name__, file_logging=False, console=not args.quiet)
    output_folder = os.path.abspath(args.output)
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "r") as fp:
                old_args = json.load(fp)
            arg_diffs = _dict_difference(vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # Only the temporary console logger exists at this point
                logger.error("Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        else:
            logger.error("Refusing to overwrite existing output folder %s.", output_folder)
            sys.exit(1)
    else:
        os.makedirs(output_folder)

    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        # context
        if args.use_cpu:
            logger.info("Device: CPU")
            context = [mx.cpu()]
        else:
            num_gpus = get_num_gpus()
            check_condition(num_gpus >= 1,
                            "No GPUs found, consider running on the CPU with --use-cpu "
                            "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                            "binary isn't on the path).")
            if args.disable_device_locking:
                context = expand_requested_device_ids(args.device_ids)
            else:
                context = exit_stack.enter_context(acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
            logger.info("Device(s): GPU %s", context)
            context = [mx.gpu(gpu_id) for gpu_id in context]

        # load existing or create vocabs
        if resume_training:
            vocab_source = sockeye.vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_SRC_NAME))
            vocab_target = sockeye.vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_TRG_NAME))
        else:
            vocab_source = _build_or_load_vocab(args.source_vocab, args.source, args.num_words, args.word_min_count)
            sockeye.vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)

            vocab_target = _build_or_load_vocab(args.target_vocab, args.target, args.num_words, args.word_min_count)
            sockeye.vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)

        data_info = sockeye.data_io.DataInfo(os.path.abspath(args.source),
                                             os.path.abspath(args.target),
                                             os.path.abspath(args.validation_source),
                                             os.path.abspath(args.validation_target),
                                             args.source_vocab,
                                             args.target_vocab)

        # create data iterators
        max_seq_len_source = args.max_seq_len if args.max_seq_len_source is None else args.max_seq_len_source
        max_seq_len_target = args.max_seq_len if args.max_seq_len_target is None else args.max_seq_len_target
        train_iter, eval_iter = sockeye.data_io.get_training_data_iters(source=data_info.source,
                                                                        target=data_info.target,
                                                                        validation_source=data_info.validation_source,
                                                                        validation_target=data_info.validation_target,
                                                                        vocab_source=vocab_source,
                                                                        vocab_target=vocab_target,
                                                                        batch_size=args.batch_size,
                                                                        fill_up=args.fill_up,
                                                                        max_seq_len_source=max_seq_len_source,
                                                                        max_seq_len_target=max_seq_len_target,
                                                                        bucketing=not args.no_bucketing,
                                                                        bucket_width=args.bucket_width)

        # learning rate scheduling
        learning_rate_half_life = none_if_negative(args.learning_rate_half_life)
        # TODO: The loading for continuation of the scheduler is done separately from the other parts
        if not resume_training:
            lr_scheduler = sockeye.lr_scheduler.get_lr_scheduler(args.learning_rate_scheduler_type,
                                                                 args.checkpoint_frequency,
                                                                 learning_rate_half_life,
                                                                 args.learning_rate_reduce_factor,
                                                                 args.learning_rate_reduce_num_not_improved)
        else:
            with open(os.path.join(training_state_dir, C.SCHEDULER_STATE_NAME), "rb") as fp:
                lr_scheduler = pickle.load(fp)

        # model configuration
        num_embed_source = args.num_embed if args.num_embed_source is None else args.num_embed_source
        num_embed_target = args.num_embed if args.num_embed_target is None else args.num_embed_target
        attention_num_hidden = args.rnn_num_hidden if not args.attention_num_hidden else args.attention_num_hidden
        model_config = sockeye.model.ModelConfig(max_seq_len=max_seq_len_source,
                                                 vocab_source_size=vocab_source_size,
                                                 vocab_target_size=vocab_target_size,
                                                 num_embed_source=num_embed_source,
                                                 num_embed_target=num_embed_target,
                                                 attention_type=args.attention_type,
                                                 attention_num_hidden=attention_num_hidden,
                                                 attention_coverage_type=args.attention_coverage_type,
                                                 attention_coverage_num_hidden=args.attention_coverage_num_hidden,
                                                 attention_use_prev_word=args.attention_use_prev_word,
                                                 dropout=args.dropout,
                                                 rnn_cell_type=args.rnn_cell_type,
                                                 rnn_num_layers=args.rnn_num_layers,
                                                 rnn_num_hidden=args.rnn_num_hidden,
                                                 rnn_residual_connections=args.rnn_residual_connections,
                                                 weight_tying=args.weight_tying,
                                                 context_gating=args.context_gating,
                                                 lexical_bias=args.lexical_bias,
                                                 learn_lexical_bias=args.learn_lexical_bias,
                                                 data_info=data_info,
                                                 loss=args.loss,
                                                 normalize_loss=args.normalize_loss,
                                                 smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha,
                                                 layer_normalization=args.layer_normalization)

        # create training model
        model = sockeye.training.TrainingModel(model_config=model_config,
                                               context=context,
                                               train_iter=train_iter,
                                               fused=args.use_fused_rnn,
                                               bucketing=not args.no_bucketing,
                                               lr_scheduler=lr_scheduler,
                                               rnn_forget_bias=args.rnn_forget_bias)

        # We may consider loading the params in TrainingModule, for consistency
        # with the training state saving
        if resume_training:
            logger.info("Found partial training in directory %s. Resuming from saved state.", training_state_dir)
            model.load_params_from_file(os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
        elif args.params:
            logger.info("Training will initialize from parameters loaded from '%s'", args.params)
            model.load_params_from_file(args.params)

        lexicon = sockeye.lexicon.initialize_lexicon(args.lexical_bias,
                                                     vocab_source, vocab_target) if args.lexical_bias else None

        initializer = sockeye.initializer.get_initializer(args.rnn_h2h_init, lexicon=lexicon)

        optimizer = args.optimizer
        optimizer_params = {'wd': args.weight_decay,
                            "learning_rate": args.initial_learning_rate}
        if lr_scheduler is not None:
            optimizer_params["lr_scheduler"] = lr_scheduler
        clip_gradient = none_if_negative(args.clip_gradient)
        if clip_gradient is not None:
            optimizer_params["clip_gradient"] = clip_gradient
        if args.momentum is not None:
            optimizer_params["momentum"] = args.momentum
        if args.normalize_loss:
            # When normalize_loss is turned on we normalize by the number of non-PAD tokens
            # in a batch, which already accounts for the number of sentences, so the extra
            # rescale_grad scaling must be disabled (set to 1.0).
            optimizer_params["rescale_grad"] = 1.0
        else:
            # Making MXNet module API's default scaling factor explicit
            optimizer_params["rescale_grad"] = 1.0 / args.batch_size
        logger.info("Optimizer: %s", optimizer)
        logger.info("Optimizer Parameters: %s", optimizer_params)

        model.fit(train_iter, eval_iter,
                  output_folder=output_folder,
                  metrics=args.metrics,
                  initializer=initializer,
                  max_updates=args.max_updates,
                  checkpoint_frequency=args.checkpoint_frequency,
                  optimizer=optimizer, optimizer_params=optimizer_params,
                  optimized_metric=args.optimized_metric,
                  max_num_not_improved=args.max_num_checkpoint_not_improved,
                  min_num_epochs=args.min_num_epochs,
                  monitor_bleu=args.monitor_bleu,
                  use_tensorboard=args.use_tensorboard)
Example #12
def main():
    params = argparse.ArgumentParser(
        description='Translate from STDIN to STDOUT')
    params = arguments.add_inference_args(params)
    params = arguments.add_device_args(params)
    args = params.parse_args()

    logger = setup_main_logger(__name__, file_logging=False)

    assert args.beam_size > 0, "Beam size must be 1 or greater."
    if args.checkpoints is not None:
        assert len(args.checkpoints) == len(
            args.models), "must provide checkpoints for each model"

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    output_stream = sys.stdout
    output_handler = sockeye.output_handler.get_output_handler(
        args.output_type, output_stream, args.align_plot_prefix,
        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        if args.use_cpu:
            context = mx.cpu()
        else:
            num_gpus = get_num_gpus()
            assert num_gpus > 0, "No GPUs found, consider running on the CPU with --use-cpu " \
                                 "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi " \
                                 "binary isn't on the path)."
            assert len(
                args.device_ids) == 1, "cannot run on multiple devices for now"
            gpu_id = args.device_ids[0]
            if gpu_id < 0:
                # get a gpu id automatically:
                gpu_id = exit_stack.enter_context(acquire_gpu())
            context = mx.gpu(gpu_id)

        translator = sockeye.inference.Translator(
            context, args.ensemble_mode,
            *sockeye.inference.load_models(context, args.max_input_len,
                                           args.beam_size, args.models,
                                           args.checkpoints,
                                           args.softmax_temperature))
        total_time = 0
        i = 0
        for i, line in enumerate(sys.stdin, 1):
            trans_input = translator.make_input(i, line)
            logger.debug(" IN: %s", trans_input)

            tic = time.time()
            trans_output = translator.translate(trans_input)
            trans_wall_time = time.time() - tic
            total_time += trans_wall_time

            logger.debug("OUT: %s", trans_output)
            logger.debug("OUT: time=%.2f", trans_wall_time)

            output_handler.handle(trans_input, trans_output)

        if i > 0:
            logger.info(
                "Processed %d lines. Total time: %.4f, sec/sent: %.4f, sent/sec: %.4f",
                i, total_time, total_time / i, i / total_time)
Example #13
def main():
    params = argparse.ArgumentParser(
        description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_io_args(params)
    arguments.add_model_parameters(params)
    arguments.add_training_args(params)
    arguments.add_device_args(params)
    args = params.parse_args()

    # seed the RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    if args.use_fused_rnn:
        check_condition(not args.use_cpu, "GPU required for FusedRNN cells")

    check_condition(
        args.optimized_metric == C.BLEU
        or args.optimized_metric in args.metrics,
        "Must optimize either BLEU or one of tracked metrics (--metrics)")

    # Checking status of output folder, resumption, etc.
    # Create temporary logger to console only
    logger = setup_main_logger(__name__,
                               file_logging=False,
                               console=not args.quiet)
    output_folder = os.path.abspath(args.output)
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            with open(os.path.join(output_folder, C.ARGS_STATE_NAME),
                      "r") as fp:
                old_args = json.load(fp)
            arg_diffs = _dict_difference(
                vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(
                    old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # Only the temporary console logger exists at this point
                logger.error(
                    "Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        else:
            logger.error("Refusing to overwrite existing output folder %s.",
                         output_folder)
            sys.exit(1)
    else:
        os.makedirs(output_folder)

    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        # context
        if args.use_cpu:
            logger.info("Device: CPU")
            context = [mx.cpu()]
        else:
            num_gpus = get_num_gpus()
            check_condition(
                num_gpus >= 1,
                "No GPUs found, consider running on the CPU with --use-cpu "
                "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                "binary isn't on the path).")
            if args.disable_device_locking:
                context = expand_requested_device_ids(args.device_ids)
            else:
                context = exit_stack.enter_context(
                    acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
            logger.info("Device(s): GPU %s", context)
            context = [mx.gpu(gpu_id) for gpu_id in context]

        # load existing or create vocabs
        if resume_training:
            vocab_source = vocab.vocab_from_json_or_pickle(
                os.path.join(output_folder, C.VOCAB_SRC_NAME))
            vocab_target = vocab.vocab_from_json_or_pickle(
                os.path.join(output_folder, C.VOCAB_TRG_NAME))
        else:
            num_words_source, num_words_target = args.num_words
            word_min_count_source, word_min_count_target = args.word_min_count

            # if the source and target embeddings are tied we build a joint vocabulary:
            if args.weight_tying and C.WEIGHT_TYING_SRC in args.weight_tying_type \
                    and C.WEIGHT_TYING_TRG in args.weight_tying_type:
                vocab_source = vocab_target = _build_or_load_vocab(
                    args.source_vocab, [args.source, args.target],
                    num_words_source, word_min_count_source)
            else:
                vocab_source = _build_or_load_vocab(args.source_vocab,
                                                    [args.source],
                                                    num_words_source,
                                                    word_min_count_source)
                vocab_target = _build_or_load_vocab(args.target_vocab,
                                                    [args.target],
                                                    num_words_target,
                                                    word_min_count_target)

            # write vocabularies
            vocab.vocab_to_json(
                vocab_source,
                os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)
            vocab.vocab_to_json(
                vocab_target,
                os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size,
                    vocab_target_size)

        config_data = data_io.DataConfig(
            os.path.abspath(args.source), os.path.abspath(args.target),
            os.path.abspath(args.validation_source),
            os.path.abspath(args.validation_target), args.source_vocab,
            args.target_vocab)

        # create data iterators
        max_seq_len_source, max_seq_len_target = args.max_seq_len
        train_iter, eval_iter = data_io.get_training_data_iters(
            source=config_data.source,
            target=config_data.target,
            validation_source=config_data.validation_source,
            validation_target=config_data.validation_target,
            vocab_source=vocab_source,
            vocab_target=vocab_target,
            batch_size=args.batch_size,
            fill_up=args.fill_up,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            bucketing=not args.no_bucketing,
            bucket_width=args.bucket_width)
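        # Bucketing (illustrative): with bucket_width 10 each sentence pair is
        # padded only to its bucket's boundary, e.g. buckets (10, 10),
        # (20, 20), ... up to (max_seq_len_source, max_seq_len_target),
        # rather than every batch padding to the global maximum length.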

        # learning rate scheduling
        learning_rate_half_life = none_if_negative(
            args.learning_rate_half_life)
        # TODO: The loading for continuation of the scheduler is done separately from the other parts
        if not resume_training:
            lr_scheduler_instance = lr_scheduler.get_lr_scheduler(
                args.learning_rate_scheduler_type, args.checkpoint_frequency,
                learning_rate_half_life, args.learning_rate_reduce_factor,
                args.learning_rate_reduce_num_not_improved,
                args.learning_rate_schedule, args.learning_rate_warmup)
        else:
            with open(os.path.join(training_state_dir, C.SCHEDULER_STATE_NAME),
                      "rb") as fp:
                lr_scheduler_instance = pickle.load(fp)

        # model configuration
        num_embed_source, num_embed_target = args.num_embed
        encoder_num_layers, decoder_num_layers = args.num_layers

        encoder_embed_dropout, decoder_embed_dropout = args.embed_dropout
        encoder_rnn_dropout, decoder_rnn_dropout = args.rnn_dropout
        if encoder_embed_dropout > 0 and encoder_rnn_dropout > 0:
            logger.warning(
                "Setting encoder RNN AND source embedding dropout > 0 leads to "
                "two dropout layers on top of each other.")
        if decoder_embed_dropout > 0 and decoder_rnn_dropout > 0:
            logger.warning(
                "Setting decoder RNN AND target embedding dropout > 0 leads to "
                "two dropout layers on top of each other.")

        config_conv = None
        if args.encoder == C.RNN_WITH_CONV_EMBED_NAME:
            config_conv = encoder.ConvolutionalEmbeddingConfig(
                num_embed=num_embed_source,
                max_filter_width=args.conv_embed_max_filter_width,
                num_filters=args.conv_embed_num_filters,
                pool_stride=args.conv_embed_pool_stride,
                num_highway_layers=args.conv_embed_num_highway_layers,
                dropout=args.conv_embed_dropout)

        if args.encoder in (C.TRANSFORMER_TYPE,
                            C.TRANSFORMER_WITH_CONV_EMBED_TYPE):
            config_encoder = transformer.TransformerConfig(
                model_size=args.transformer_model_size,
                attention_heads=args.transformer_attention_heads,
                feed_forward_num_hidden=args.transformer_feed_forward_num_hidden,
                num_layers=encoder_num_layers,
                vocab_size=vocab_source_size,
                dropout_attention=args.transformer_dropout_attention,
                dropout_relu=args.transformer_dropout_relu,
                dropout_residual=args.transformer_dropout_residual,
                layer_normalization=args.layer_normalization,
                weight_tying=args.weight_tying,
                positional_encodings=not args.transformer_no_positional_encodings,
                conv_config=config_conv)
        else:
            config_encoder = encoder.RecurrentEncoderConfig(
                vocab_size=vocab_source_size,
                num_embed=num_embed_source,
                embed_dropout=encoder_embed_dropout,
                rnn_config=rnn.RNNConfig(
                    cell_type=args.rnn_cell_type,
                    num_hidden=args.rnn_num_hidden,
                    num_layers=encoder_num_layers,
                    dropout=encoder_rnn_dropout,
                    residual=args.rnn_residual_connections,
                    first_residual_layer=args.rnn_first_residual_layer,
                    forget_bias=args.rnn_forget_bias),
                conv_config=config_conv,
                reverse_input=args.rnn_encoder_reverse_input)

        if args.decoder == C.TRANSFORMER_TYPE:
            config_decoder = transformer.TransformerConfig(
                model_size=args.transformer_model_size,
                attention_heads=args.transformer_attention_heads,
                feed_forward_num_hidden=args.transformer_feed_forward_num_hidden,
                num_layers=decoder_num_layers,
                vocab_size=vocab_target_size,
                dropout_attention=args.transformer_dropout_attention,
                dropout_relu=args.transformer_dropout_relu,
                dropout_residual=args.transformer_dropout_residual,
                layer_normalization=args.layer_normalization,
                weight_tying=args.weight_tying,
                positional_encodings=not args.transformer_no_positional_encodings)
        else:
            attention_num_hidden = args.attention_num_hidden or args.rnn_num_hidden
            config_coverage = None
            if args.attention_type == "coverage":
                config_coverage = coverage.CoverageConfig(
                    type=args.attention_coverage_type,
                    num_hidden=args.attention_coverage_num_hidden,
                    layer_normalization=args.layer_normalization)
            config_attention = attention.AttentionConfig(
                type=args.attention_type,
                num_hidden=attention_num_hidden,
                input_previous_word=args.attention_use_prev_word,
                rnn_num_hidden=args.rnn_num_hidden,
                layer_normalization=args.layer_normalization,
                config_coverage=config_coverage,
                num_heads=args.attention_mhdot_heads)
            decoder_weight_tying = args.weight_tying and C.WEIGHT_TYING_TRG in args.weight_tying_type \
                                   and C.WEIGHT_TYING_SOFTMAX in args.weight_tying_type
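            # i.e. tie the target embedding and the output (softmax) layer
            # weights only when both 'trg' and 'softmax' were requested via
            # --weight-tying-type.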
            config_decoder = decoder.RecurrentDecoderConfig(
                vocab_size=vocab_target_size,
                max_seq_len_source=max_seq_len_source,
                num_embed=num_embed_target,
                rnn_config=rnn.RNNConfig(
                    cell_type=args.rnn_cell_type,
                    num_hidden=args.rnn_num_hidden,
                    num_layers=decoder_num_layers,
                    dropout=decoder_rnn_dropout,
                    residual=args.rnn_residual_connections,
                    first_residual_layer=args.rnn_first_residual_layer,
                    forget_bias=args.rnn_forget_bias),
                attention_config=config_attention,
                embed_dropout=decoder_embed_dropout,
                hidden_dropout=args.rnn_decoder_hidden_dropout,
                weight_tying=decoder_weight_tying,
                zero_state_init=args.rnn_decoder_zero_init,
                context_gating=args.rnn_context_gating,
                layer_normalization=args.layer_normalization)

        config_loss = loss.LossConfig(
            type=args.loss,
            vocab_size=vocab_target_size,
            normalize=args.normalize_loss,
            smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha)

        model_config = model.ModelConfig(
            config_data=config_data,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            vocab_source_size=vocab_source_size,
            vocab_target_size=vocab_target_size,
            config_encoder=config_encoder,
            config_decoder=config_decoder,
            config_loss=config_loss,
            lexical_bias=args.lexical_bias,
            learn_lexical_bias=args.learn_lexical_bias,
            weight_tying=args.weight_tying,
            weight_tying_type=args.weight_tying_type if args.weight_tying else None)
        model_config.freeze()
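        # freeze() makes the configuration immutable; the frozen config is
        # serialized alongside the trained model, so it must not change later.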

        # create training model
        training_model = training.TrainingModel(
            config=model_config,
            context=context,
            train_iter=train_iter,
            fused=args.use_fused_rnn,
            bucketing=not args.no_bucketing,
            lr_scheduler=lr_scheduler_instance)

        # TODO: consider loading the parameters in TrainingModel instead, for
        # consistency with how the training state is saved.
        if resume_training:
            logger.info(
                "Found partial training in directory %s. Resuming from saved state.",
                training_state_dir)
            training_model.load_params_from_file(
                os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
        elif args.params:
            logger.info(
                "Training will initialize from parameters loaded from '%s'",
                args.params)
            training_model.load_params_from_file(args.params)

        lexicon_array = None
        if args.lexical_bias:
            lexicon_array = lexicon.initialize_lexicon(
                args.lexical_bias, vocab_source, vocab_target)

        weight_initializer = initializer.get_initializer(
            args.weight_init,
            args.weight_init_scale,
            args.rnn_h2h_init,
            lexicon=lexicon_array)

        optimizer = args.optimizer
        optimizer_params = {
            'wd': args.weight_decay,
            "learning_rate": args.initial_learning_rate
        }
        if lr_scheduler_instance is not None:
            optimizer_params["lr_scheduler"] = lr_scheduler_instance
        clip_gradient = none_if_negative(args.clip_gradient)
        if clip_gradient is not None:
            optimizer_params["clip_gradient"] = clip_gradient
        if args.momentum is not None:
            optimizer_params["momentum"] = args.momentum
        if args.normalize_loss:
            # When normalize_loss is enabled, the loss is normalized by the
            # number of non-PAD symbols in a batch, which implicitly accounts
            # for the number of sentences; gradient rescaling must therefore
            # be disabled (set to 1.0).
            optimizer_params["rescale_grad"] = 1.0
        else:
            # Making MXNet module API's default scaling factor explicit
            optimizer_params["rescale_grad"] = 1.0 / args.batch_size
        logger.info("Optimizer: %s", optimizer)
        logger.info("Optimizer Parameters: %s", optimizer_params)

        # Handle options that override training settings
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            max_updates = sum(
                num_updates for _, num_updates in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = 0
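            # Example (hypothetical schedule): [(0.001, 1000), (0.0005, 2000)]
            # trains 1000 updates at rate 0.001 and 2000 at 0.0005, so
            # max_updates = 1000 + 2000 = 3000.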

        training_model.fit(
            train_iter,
            eval_iter,
            output_folder=output_folder,
            max_params_files_to_keep=args.keep_last_params,
            metrics=args.metrics,
            initializer=weight_initializer,
            max_updates=max_updates,
            checkpoint_frequency=args.checkpoint_frequency,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            optimized_metric=args.optimized_metric,
            max_num_not_improved=max_num_checkpoint_not_improved,
            min_num_epochs=min_num_epochs,
            monitor_bleu=args.monitor_bleu,
            use_tensorboard=args.use_tensorboard,
            mxmonitor_pattern=args.monitor_pattern,
            mxmonitor_stat_func=args.monitor_stat_func)
Example #14
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    if args.beam_search_stop == C.BEAM_SEARCH_STOP_FIRST:
        check_condition(
            args.batch_size == 1,
            "Early stopping (--beam-search-stop %s) not supported with batching"
            % (C.BEAM_SEARCH_STOP_FIRST))

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type, args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning(
                'Experimental feature \'--override-dtype float16\' has been used. '
                'This feature may be removed or change its behaviour in future. '
                'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon,
                                  k=args.restrict_lexicon_topk)
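            # The lexicon maps each source word to its top-k target words;
            # decoding then scores only this reduced output vocabulary, which
            # can substantially speed up translation.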
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE
        translator = inference.Translator(
            context=context,
            ensemble_mode=args.ensemble_mode,
            bucket_source_width=args.bucket_width,
            length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                   args.length_penalty_beta),
            beam_prune=args.beam_prune,
            beam_search_stop=args.beam_search_stop,
            models=models,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            restrict_lexicon=restrict_lexicon,
            avoid_list=args.avoid_list,
            store_beam=store_beam,
            strip_unknown_words=args.strip_unknown_words)
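        # Note: LengthPenalty implements a GNMT-style normalizer,
        # ((beta + length) / (beta + 1)) ** alpha; with alpha = 0.0 the
        # penalty is constant, effectively disabling length normalization.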
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)