Example #1
def main():
    params = argparse.ArgumentParser(
        description='CLI to build source and target vocab(s).')
    arguments.add_build_vocab_args(params)
    args = params.parse_args()

    num_words, num_words_other = args.num_words
    utils.check_condition(
        num_words == num_words_other,
        "Vocabulary CLI only allows a common value for --num-words")
    word_min_count, word_min_count_other = args.word_min_count
    utils.check_condition(
        word_min_count == word_min_count_other,
        "Vocabulary CLI only allows a common value for --word-min-count")

    global logger
    logger = log.setup_main_logger("build_vocab",
                                   file_logging=True,
                                   console=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    vocab = build_from_paths(args.inputs,
                             num_words=num_words,
                             min_count=word_min_count)
    logger.info("Vocabulary size: %d ", len(vocab))
    vocab_to_json(vocab, args.output + C.JSON_SUFFIX)
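
To show the artifact this CLI writes, here is a minimal round-trip sketch. It assumes the vocab module shown in the other examples (vocab_from_json appears in Example #5) and an invented output prefix "my_vocab":

# Hypothetical usage sketch: read back the JSON vocabulary written by the CLI above.
# "my_vocab" is a placeholder prefix; vocab_from_json mirrors the loader used in Example #5.
word_to_id = vocab_from_json("my_vocab" + C.JSON_SUFFIX)
logger.info("Loaded vocabulary with %d entries", len(word_to_id))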
Example #2
def main():
    """
    Command-line interface for building top-k lexicons for use during decoding.
    """

    params = argparse.ArgumentParser(description="Build a top-k lexicon for use during decoding.")
    arguments.add_lexicon_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    logger = setup_main_logger(__name__, console=not args.quiet, file_logging=False)
    log_sockeye_version(logger)

    logger.info("Reading source and target vocab from \"%s\"", args.model)
    vocab_source = vocab.vocab_from_json_or_pickle(os.path.join(args.model, C.VOCAB_SRC_NAME))
    vocab_target = vocab.vocab_from_json_or_pickle(os.path.join(args.model, C.VOCAB_TRG_NAME))

    logger.info("Creating top-k lexicon from \"%s\"", args.input)
    lexicon = TopKLexicon(vocab_source, vocab_target)
    lexicon.create(args.input, args.k)
    lexicon.save(args.output)
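
At translation time the saved lexicon is read back and passed to the translator, as Example #7 does. A minimal sketch, continuing from the vocab_source and vocab_target loaded above and using a placeholder path:

# Hypothetical usage sketch: reload the lexicon written by lexicon.save() above.
# "lexicon.topk" is a placeholder path; load() mirrors the call in Example #7.
restrict_lexicon = TopKLexicon(vocab_source, vocab_target)
restrict_lexicon.load("lexicon.topk")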
Example #3
"""
Evaluation CLI. Prints corpus BLEU.
"""
import argparse
import logging
import sys
from typing import Iterable, Optional

from contrib import sacrebleu
from log import setup_main_logger, log_sockeye_version
import arguments
import chrf
import constants as C
import data_io
import utils

logger = setup_main_logger(__name__, file_logging=False)


def raw_corpus_bleu(hypotheses: Iterable[str],
                    references: Iterable[str],
                    offset: Optional[float] = 0.01) -> float:
    """
    Simple wrapper around sacreBLEU's BLEU without tokenization and smoothing.

    :param hypotheses: Hypotheses stream.
    :param references: Reference stream.
    :param offset: Smoothing constant.
    :return: BLEU score as float between 0 and 1.
    """
    return sacrebleu.raw_corpus_bleu(hypotheses, [references],
                                     smooth_floor=offset).score / 100
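

A short usage sketch for raw_corpus_bleu; the inputs are whitespace-tokenized segment strings, and the sentences below are invented:

# Hypothetical usage sketch: score a tiny corpus with the wrapper defined above.
hypotheses = ["the cat sat on the mat", "a quick brown fox"]
references = ["the cat sat on the mat", "the quick brown fox"]
score = raw_corpus_bleu(hypotheses, references)
logger.info("BLEU: %.4f", score)  # float between 0 and 1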
Example #4
def main():
    params = argparse.ArgumentParser(
        description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(params)
    args = params.parse_args()

    utils.seedRNGs(args)

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training, training_state_dir = check_resume(args, output_folder)

    global logger
    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    utils.log_basic_info(args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        context = determine_context(args, exit_stack)
        vocab_source, vocab_target = load_or_create_vocabs(
            args, resume_training, output_folder)
        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size,
                    vocab_target_size)
        train_iter, eval_iter, config_data = create_data_iters(
            args, vocab_source, vocab_target)
        lr_scheduler_instance = create_lr_scheduler(args, resume_training,
                                                    training_state_dir)

        model_config = create_model_config(args, vocab_source_size,
                                           vocab_target_size, config_data)
        model_config.freeze()

        training_model = create_training_model(model_config, args, context,
                                               train_iter,
                                               lr_scheduler_instance,
                                               resume_training,
                                               training_state_dir)

        weight_initializer = initializer.get_initializer(
            default_init_type=args.weight_init,
            default_init_scale=args.weight_init_scale,
            default_init_xavier_rand_type=args.weight_init_xavier_rand_type,
            default_init_xavier_factor_type=args.weight_init_xavier_factor_type,
            embed_init_type=args.embed_weight_init,
            embed_init_sigma=vocab_source_size**-0.5,  # TODO
            rnn_init_type=args.rnn_h2h_init)

        optimizer, optimizer_params, kvstore, gradient_clipping_type, gradient_clipping_threshold = define_optimizer(
            args, lr_scheduler_instance)

        # Handle options that override training settings
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        max_num_epochs = args.max_num_epochs
        if min_num_epochs is not None and max_num_epochs is not None:
            check_condition(
                min_num_epochs <= max_num_epochs,
                "Minimum number of epochs must not exceed maximum number of epochs"
            )
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            max_updates = sum(num_updates
                              for (_,
                                   num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = None
            max_num_epochs = None

        decode_and_evaluate, decode_and_evaluate_context = determine_decode_and_evaluate_context(
            args, exit_stack, context)

        training_model.fit(
            train_iter,
            eval_iter,
            output_folder=output_folder,
            max_params_files_to_keep=args.keep_last_params,
            metrics=args.metrics,
            initializer=weight_initializer,
            allow_missing_params=args.allow_missing_params,
            max_updates=max_updates,
            checkpoint_frequency=args.checkpoint_frequency,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            optimized_metric=args.optimized_metric,
            gradient_clipping_type=gradient_clipping_type,
            clip_gradient_threshold=gradient_clipping_threshold,
            kvstore=kvstore,
            max_num_not_improved=max_num_checkpoint_not_improved,
            min_num_epochs=min_num_epochs,
            max_num_epochs=max_num_epochs,
            decode_and_evaluate=decode_and_evaluate,
            decode_and_evaluate_context=decode_and_evaluate_context,
            use_tensorboard=args.use_tensorboard,
            mxmonitor_pattern=args.monitor_pattern,
            mxmonitor_stat_func=args.monitor_stat_func,
            lr_decay_param_reset=args.learning_rate_decay_param_reset,
            lr_decay_opt_states_reset=args.learning_rate_decay_optimizer_states_reset)
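
The fixed-schedule branch above derives the update budget from args.learning_rate_schedule, a list of (learning rate, number of updates) pairs. A small sketch of that arithmetic with an invented schedule:

# Hypothetical sketch of the fixed-schedule computation above; the schedule values are invented.
learning_rate_schedule = [(0.0003, 10000), (0.0001, 20000)]
max_updates = sum(num_updates for (_, num_updates) in learning_rate_schedule)
assert max_updates == 30000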
Example #5
from typing import Dict, List, Optional

from log import setup_main_logger
import data_io
import decoder
import encoder
import initializer
import loss
import lr_scheduler
import model
import rnn
import convolution
import training
import transformer
import utils
import vocab

# Temporary logger; the real one (which will probably log to a file) is created in the main function.
logger = setup_main_logger(__name__, file_logging=False, console=True)


def none_if_negative(val):
    return None if val < 0 else val


def _build_or_load_vocab(existing_vocab_path: Optional[str],
                         data_paths: List[str], num_words: int,
                         word_min_count: int) -> Dict:
    # Build a new vocabulary from the data unless an existing one is given to load.
    if existing_vocab_path is None:
        vocabulary = vocab.build_from_paths(paths=data_paths,
                                            num_words=num_words,
                                            min_count=word_min_count)
    else:
        vocabulary = vocab.vocab_from_json(existing_vocab_path)
    return vocabulary
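
A short sketch of how _build_or_load_vocab might be called for the source side; the file name and limits are placeholders:

# Hypothetical usage sketch: build a source vocabulary from training data because no
# existing vocabulary path is given. "train.src", 50000 and 1 are placeholder values.
vocab_source = _build_or_load_vocab(existing_vocab_path=None,
                                    data_paths=["train.src"],
                                    num_words=50000,
                                    word_min_count=1)
logger.info("Source vocabulary size: %d", len(vocab_source))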
Example #6
        logger.info("Training...")
        writer = SummaryWriter()
        train_config = TrainConfig(args)
        print('Train config:\n', train_config)
        train_iter, test_iter = build_train_dataiters(args, vocab, edge_vocab)
        train(args, train_config, model, train_iter, test_iter, cuda_device,
              logger, writer)
        writer.close()
    elif args.mode == 'test':
        logger.info('Test...')
        inversed_vocab = reverse_vocab(vocab)
        test_iter = build_test_dataiters(args, vocab, edge_vocab)
        with torch.no_grad():
            test(args, model, test_iter, vocab, inversed_vocab, cuda_device)


if __name__ == "__main__":
    args = get_arguments()
    log_dir = os.path.join(args.log_dir, args.encoder_type, str(args.stadia))
    # log_dir is nested, so create intermediate directories as needed
    os.makedirs(log_dir, exist_ok=True)
    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(log_dir, C.LOG_NAME))
    if args.cuda_device is not None:
        # use the requested GPU rather than always defaulting to device 0
        cuda_device = torch.device(args.cuda_device)
    else:
        cuda_device = torch.device('cpu')
    main(args, logger, cuda_device)
Example #7
def main():
    params = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_translate_cli_args(params)
    args = params.parse_args()

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

    log_basic_info(args)

    out_handler = output_handler.get_output_handler(args.output_type,
                                                    args.output,
                                                    args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)
        with context:
            models, vocab_source, vocab_target = inference.load_models(
                context,
                args.max_input_len,
                args.beam_size,
                args.batch_size,
                args.models,
                args.checkpoints,
                args.softmax_temperature,
                args.max_output_length_num_stds,
                decoder_return_logit_inputs=args.restrict_lexicon is not None,
                cache_output_layer_w_b=args.restrict_lexicon is not None,
                vis_target_enc_attention_layer=args.vis_target_enc_attention_layer)
            restrict_lexicon = None  # type: Optional[TopKLexicon]
            if args.restrict_lexicon:
                restrict_lexicon = TopKLexicon(vocab_source, vocab_target)
                restrict_lexicon.load(args.restrict_lexicon)
            translator = inference.Translator(context,
                                              args.ensemble_mode,
                                              args.bucket_width,
                                              inference.LengthPenalty(args.length_penalty_alpha,
                                                                      args.length_penalty_beta),
                                              models,
                                              vocab_source,
                                              vocab_target,
                                              restrict_lexicon,
                                              lex_weight=args.lex_weight,
                                              align_weight=args.align_weight,
                                              align_skip_threshold=args.align_threshold,
                                              align_k_best=args.align_beam_size)
            translator.dictionary = data_io.read_dictionary(args.dictionary) if args.dictionary else None
            translator.dictionary_override_with_max_attention = args.dictionary_override_with_max_attention
            if translator.dictionary_override_with_max_attention:
                utils.check_condition(args.batch_size == 1, "batching not supported with dictionary override yet")
            translator.dictionary_ignore_se = args.dictionary_ignore_se
            if translator.dictionary_ignore_se:
                utils.check_condition(args.batch_size == 1, "batching not supported with dictionary_ignore_se yet")
            read_and_translate(translator, out_handler, args.chunk_size, args.input, args.reference)