Code example #1
File: evaluate.py  Project: martinpopel/sockeye
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating 4-BLEU '
                                                 'score with respect to a reference set.')
    arguments.add_evaluate_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    hypotheses = [h.strip() for h in args.hypotheses]
    logger.info("%d hypotheses | %d references", len(hypotheses), len(references))

    if not args.not_strict:
        utils.check_condition(len(hypotheses) == len(references),
                              "Number of hypotheses (%d) and references (%d) does not match." % (len(hypotheses),
                                                                                                 len(references)))

    if not args.sentence:
        bleu = raw_corpus_bleu(hypotheses, references, args.offset)
        print(bleu, file=sys.stdout)
    else:
        for h, r in zip(hypotheses, references):
            bleu = raw_corpus_bleu(h, r, args.offset)
            print(bleu, file=sys.stdout)
Code example #2
def init_embeddings(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.weight_files) != len(args.vocabularies_in) or \
            len(args.weight_files) != len(args.vocabularies_out) or \
            len(args.weight_files) != len(args.names):
        logger.error("Exactly the same number of 'input weight files', 'input vocabularies', "
                     "'output vocabularies' and 'Sockeye parameter names' should be provided.")
        sys.exit(1)

    params = {}  # type: Dict[str, mx.nd.NDArray]
    weight_file_cache = {}  # type: Dict[str, np.ndarray]
    for weight_file, vocab_in_file, vocab_out_file, name in zip(args.weight_files, args.vocabularies_in,
                                                                args.vocabularies_out, args.names):
        weight = load_weight(weight_file, name, weight_file_cache)
        logger.info('Loading input/output vocabularies: %s %s', vocab_in_file, vocab_out_file)
        logger.info('%s', args.encoding)
        #vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding)
        vocab_in = vocab.vocab_from_json(vocab_in_file)
        vocab_out = vocab.vocab_from_json(vocab_out_file)
        logger.info('Initializing parameter: %s', name)
        initializer = mx.init.Normal(sigma=np.std(weight))
        params[name] = init_weight(weight, vocab_in, vocab_out, initializer)

    logger.info('Saving initialized parameters to %s', args.file)
    utils.save_params(params, args.file)
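Note: the example above delegates the actual matrix construction to init_weight, which is not part of this listing. As a rough illustration only, here is a numpy-only sketch of the idea: rows for words present in the pretrained vocabulary are copied over, and the remaining rows are sampled from a normal distribution with the pretrained weights' standard deviation, mirroring mx.init.Normal(sigma=np.std(weight)) above. The function name and the plain-dict vocabulary format are assumptions, not Sockeye's API.

import numpy as np
from typing import Dict

def init_embedding_sketch(pretrained: np.ndarray,
                          vocab_in: Dict[str, int],
                          vocab_out: Dict[str, int]) -> np.ndarray:
    # Hypothetical sketch, not Sockeye's init_weight.
    num_embed = pretrained.shape[1]
    sigma = float(np.std(pretrained))
    out = np.random.normal(0.0, sigma, size=(len(vocab_out), num_embed))
    for word, idx_out in vocab_out.items():
        idx_in = vocab_in.get(word)
        if idx_in is not None:
            # Copy the pretrained row for words shared by both vocabularies.
            out[idx_out] = pretrained[idx_in]
    return out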
Code example #3
def main():
    params = argparse.ArgumentParser(description='Translate CLI')
    arguments.add_inference_args(params)
    arguments.add_device_args(params)
    args = params.parse_args()

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(
            len(args.checkpoints) == len(args.models),
            "must provide checkpoints for each model")

    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    output_handler = sockeye.output_handler.get_output_handler(
        args.output_type, args.output, args.sure_align_threshold)

    with ExitStack() as exit_stack:
        context = _setup_context(args, exit_stack)

        translator = sockeye.inference.Translator(
            context, args.ensemble_mode,
            *sockeye.inference.load_models(context, args.max_input_len,
                                           args.beam_size, args.models,
                                           args.checkpoints,
                                           args.softmax_temperature))
        read_and_translate(translator, output_handler, args.input)
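Note: the with ExitStack() block above is standard contextlib usage for acquiring a variable number of resources (here, device contexts set up by _setup_context) and releasing them together on exit. A minimal, self-contained sketch of the same pattern, independent of Sockeye:

from contextlib import ExitStack, contextmanager

@contextmanager
def acquire(resource_id):
    # Stand-in for a real resource lock such as acquire_gpus used elsewhere in this listing.
    print("acquired", resource_id)
    try:
        yield resource_id
    finally:
        print("released", resource_id)

with ExitStack() as stack:
    resources = [stack.enter_context(acquire(i)) for i in range(3)]
    print("using", resources)
# All three resources are released here, in reverse order of acquisition.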
Code example #4
File: evaluate.py  Project: zhongxia96/DCGCN
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
                                                 'respect to a reference set. If multiple hypotheses files are given '
                                                 'the mean and standard deviation of the metrics are reported.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]
    if not args.not_strict:
        for hypotheses in all_hypotheses:
            utils.check_condition(len(hypotheses) == len(references),
                                  "Number of hypotheses (%d) and references (%d) does not match." % (len(hypotheses),
                                                                                                     len(references)))
    logger.info("%d hypothesis set(s) | %d hypotheses | %d references",
                len(all_hypotheses), len(all_hypotheses[0]), len(references))

    metric_info = ["%s\t(s_opt)" % name for name in args.metrics]
    logger.info("\t".join(metric_info))

    metrics = []  # type: List[Tuple[str, Callable]]
    for name in args.metrics:
        if name == C.BLEU:
            func = partial(raw_corpus_bleu, offset=args.offset)
        elif name == C.CHRF:
            func = raw_corpus_chrf
        elif name == C.ROUGE1:
            func = raw_corpus_rouge1
        elif name == C.ROUGE2:
            func = raw_corpus_rouge2
        elif name == C.ROUGEL:
            func = raw_corpus_rougel
        else:
            raise ValueError("Unknown metric %s." % name)
        metrics.append((name, func))

    if not args.sentence:
        scores = defaultdict(list)  # type: Dict[str, List[float]]
        for hypotheses in all_hypotheses:
            for name, metric in metrics:
                scores[name].append(metric(hypotheses, references))
        _print_mean_std_score(metrics, scores)
    else:
        for hypotheses in all_hypotheses:
            for h, r in zip(hypotheses, references):
                scores = defaultdict(list)  # type: Dict[str, List[float]]
                for name, metric in metrics:
                    scores[name].append(metric([h], [r]))
                _print_mean_std_score(metrics, scores)
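Note: the example above maps metric names to callables (using functools.partial to bind the BLEU offset) and then aggregates per-hypothesis-set scores with _print_mean_std_score, which is not shown. A plausible sketch of that aggregation step, assuming one "mean (std)" column per metric in the order of the logged header:

from typing import Callable, Dict, List, Tuple
import numpy as np

def print_mean_std_score_sketch(metrics: List[Tuple[str, Callable]],
                                scores: Dict[str, List[float]]) -> None:
    # Hypothetical stand-in for _print_mean_std_score.
    fields = []
    for name, _ in metrics:
        values = np.asarray(scores[name], dtype=float)
        fields.append("%.3f\t(%.3f)" % (values.mean(), values.std()))
    print("\t".join(fields))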
Code example #5
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
                                                 'respect to a reference set. If multiple hypotheses files are given '
                                                 'the mean and standard deviation of the metrics are reported.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    all_hypotheses = [[h.strip() for h in hypotheses] for hypotheses in args.hypotheses]
    metrics = args.metrics
    logger.info("%d hypotheses | %d references", len(all_hypotheses), len(references))

    if not args.not_strict:
        for hypotheses in all_hypotheses:
            utils.check_condition(len(hypotheses) == len(references),
                                  "Number of hypotheses (%d) and references (%d) does not match." % (len(hypotheses),
                                                                                                     len(references)))
    metric_info = []
    for metric in metrics:
        metric_info.append("%s\t(s_opt)" % metric)
    logger.info("\t".join(metric_info))

    if not args.sentence:
        scores = defaultdict(list)
        for hypotheses in all_hypotheses:
            for metric in metrics:
                if metric == C.BLEU:
                    score = raw_corpus_bleu(hypotheses, references, args.offset)
                elif metric == C.CHRF:
                    score = raw_corpus_chrf(hypotheses, references)
                else:
                    raise ValueError("Unknown metric %s." % metric)
                scores[metric].append(score)
        _print_mean_std_score(metrics, scores)
    else:
        for hypotheses in all_hypotheses:
            for h, r in zip(hypotheses, references):
                scores = defaultdict(list)
                for metric in metrics:
                    if metric == C.BLEU:
                        score = raw_corpus_bleu([h], [r], args.offset)
                    elif metric == C.CHRF:
                        score = raw_corpus_chrf(h, r)
                    else:
                        raise ValueError("Unknown metric %s." % metric)
                    scores[metric].append(score)
                _print_mean_std_score(metrics, scores)
Code example #6
File: utils.py  Project: lagka/sockeye
def log_basic_info(args) -> None:
    """
    Log basic information like version number, arguments, etc.

    :param args: Arguments as returned by argparse.
    """
    log_sockeye_version(logger)
    log_mxnet_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
Code example #7
File: utils.py  Project: msobrevillac/sockeye
def log_basic_info(args) -> None:
    """
    Log basic information like version number, arguments, etc.

    :param args: Arguments as returned by argparse.
    """
    log_sockeye_version(logger)
    log_mxnet_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
Code example #8
def extract_parameters(args: argparse.Namespace):
    log_sockeye_version(logger)

    if os.path.isdir(args.input):
        param_path = os.path.join(args.input, C.PARAMS_BEST_NAME)
    else:
        param_path = args.input
    ext_params = extract(param_path, args.names, args.list_all)
    
    if len(ext_params) > 0:
        utils.check_condition(args.output is not None, "An output filename must be specified. (Use --output)")
        logger.info("Writting extracted parameters to '%s'", args.output)
        np.savez_compressed(args.output, **ext_params)
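Note: the extracted parameters end up in a standard compressed .npz archive, so they can be inspected with plain numpy. A small usage sketch (the file name is hypothetical; np.savez_compressed appends the .npz suffix if it is missing):

import numpy as np

with np.load("extracted_params.npz") as data:
    for name in data.files:
        # Each entry is the numpy array stored under that parameter name.
        print(name, data[name].shape, data[name].dtype)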
Code example #9
def main():
    params = argparse.ArgumentParser(
        description='Evaluate translations by calculating metrics with '
        'respect to a reference set.')
    arguments.add_evaluate_args(params)
    arguments.add_logging_args(params)
    args = params.parse_args()

    if args.quiet:
        logger.setLevel(logging.ERROR)

    utils.check_condition(args.offset >= 0, "Offset should be non-negative.")
    log_sockeye_version(logger)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    references = [' '.join(e) for e in data_io.read_content(args.references)]
    hypotheses = [h.strip() for h in args.hypotheses]
    logger.info("%d hypotheses | %d references", len(hypotheses),
                len(references))

    if not args.not_strict:
        utils.check_condition(
            len(hypotheses) == len(references),
            "Number of hypotheses (%d) and references (%d) does not match." %
            (len(hypotheses), len(references)))

    if not args.sentence:
        scores = []
        for metric in args.metrics:
            if metric == C.BLEU:
                bleu_score = raw_corpus_bleu(hypotheses, references,
                                             args.offset)
                scores.append("%.6f" % bleu_score)
            elif metric == C.CHRF:
                chrf_score = raw_corpus_chrf(hypotheses, references)
                scores.append("%.6f" % chrf_score)
        print("\t".join(scores), file=sys.stdout)
    else:
        for h, r in zip(hypotheses, references):
            scores = []
            for metric in args.metrics:
                if metric == C.BLEU:
                    bleu = raw_corpus_bleu([h], [r], args.offset)
                    scores.append("%.6f" % bleu)
                elif metric == C.CHRF:
                    chrf_score = raw_corpus_chrf(h, r)
                    scores.append("%.6f" % chrf_score)
            print("\t".join(scores), file=sys.stdout)
Code example #10
def average_parameters(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.inputs) > 1:
        avg_params = average(args.inputs)
    else:
        param_paths = find_checkpoints(model_path=args.inputs[0],
                                       size=args.n,
                                       strategy=args.strategy,
                                       metric=args.metric)
        avg_params = average(param_paths)

    mx.nd.save(args.output, avg_params)
    logger.info("Averaged parameters written to '%s'", args.output)
Code example #11
File: average.py  Project: lagka/sockeye
def average_parameters(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.inputs) > 1:
        avg_params = average(args.inputs)
    else:
        param_paths = find_checkpoints(model_path=args.inputs[0],
                                       size=args.n,
                                       strategy=args.strategy,
                                       metric=args.metric)
        avg_params = average(param_paths)

    mx.nd.save(args.output, avg_params)
    logger.info("Averaged parameters written to '%s'", args.output)
Code example #12
def main():
    """
    Commandline interface to average parameters.
    """
    log_sockeye_version(logger)
    params = argparse.ArgumentParser(description="Averages parameters from multiple models.")
    sockeye.arguments.add_average_args(params)
    args = params.parse_args()

    if len(args.inputs) > 1:
        avg_params = average(args.inputs)
    else:
        param_paths = find_checkpoints(args.inputs[0], args.n, args.strategy,
                                       args.max, args.metric)
        avg_params = average(param_paths)

    mx.nd.save(args.output, avg_params)
    logger.info("Averaged parameters written to '%s'", args.output)
Code example #13
def main():
    """
    Commandline interface to extract parameters.
    """
    log_sockeye_version(logger)
    params = argparse.ArgumentParser(description="Extract specific parameters.")
    arguments.add_extract_args(params)
    args = params.parse_args()

    if os.path.isdir(args.input):
        param_path = os.path.join(args.input, C.PARAMS_BEST_NAME)
    else:
        param_path = args.input
    ext_params = extract(param_path, args.names, args.list_all)
    
    if len(ext_params) > 0:
        utils.check_condition(args.output is not None, "An output filename must be specified. (Use --output)")
        logger.info("Writting extracted parameters to '%s'", args.output)
        np.savez_compressed(args.output, **ext_params)
Code example #14
def main():
    """
    Commandline interface to initialize Sockeye embedding weights with pretrained word representations.
    """
    log_sockeye_version(logger)
    params = argparse.ArgumentParser(
        description='Quick usage: python3 -m sockeye.init_embedding '
        '-w embed-in-src.npy embed-in-tgt.npy '
        '-i vocab-in-src.json vocab-in-tgt.json '
        '-o vocab-out-src.json vocab-out-tgt.json '
        '-n source_embed_weight target_embed_weight '
        '-f params.init')
    arguments.add_init_embedding_args(params)
    args = params.parse_args()

    if len(args.weight_files) != len(args.vocabularies_in) or \
       len(args.weight_files) != len(args.vocabularies_out) or \
       len(args.weight_files) != len(args.names):
        logger.error(
            "Exactly the same number of 'input weight files', 'input vocabularies', "
            "'output vocabularies' and 'Sockeye parameter names' should be provided."
        )
        sys.exit(1)

    params = {}  # type: Dict[str, mx.nd.NDArray]
    weight_file_cache = {}  # type: Dict[str, np.ndarray]
    for weight_file, vocab_in_file, vocab_out_file, name in zip(
            args.weight_files, args.vocabularies_in, args.vocabularies_out,
            args.names):
        weight = load_weight(weight_file, name, weight_file_cache)
        logger.info('Loading input/output vocabularies: %s %s', vocab_in_file,
                    vocab_out_file)
        vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding)
        vocab_out = vocab.vocab_from_json(vocab_out_file)
        logger.info('Initializing parameter: %s', name)
        initializer = mx.init.Normal(sigma=np.std(weight))
        params[name] = init_weight(weight, vocab_in, vocab_out, initializer)

    logger.info('Saving initialized parameters to %s', args.file)
    utils.save_params(params, args.file)
Code example #15
File: evaluate.py  Project: rah9eu/p3
def main():
    params = argparse.ArgumentParser(description='Evaluate translations by calculating 4-BLEU '
                                                 'score with respect to a reference set')
    params.add_argument('--references', '-r', required=True, type=str, help="File with references")
    params.add_argument('--hypotheses', '-i', required=True, type=str, help="File with hypotheses")
    params.add_argument('--quiet', '-q', action="store_true", help="Do not print logging information")
    params.add_argument('--sentence', '-s', action="store_true", help="Show sentence-BLEU")
    params.add_argument('--offset', type=float, default=0.01,
                        help="Numerical value of the offset of zero n-gram counts")
    args = params.parse_args()

    check_condition(args.offset >= 0, "Offset should be non-negative.")

    logger = setup_main_logger(__name__, file_logging=False)
    log_sockeye_version(logger)

    if args.quiet:
        logger.setLevel(logging.ERROR)

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    hypotheses = [' '.join(e) for e in read_content(args.hypotheses)]
    references = [' '.join(e) for e in read_content(args.references)]

    logger.info("Loaded %d hypotheses", len(hypotheses))
    logger.info("Loaded %d references", len(references))

    check_condition(len(hypotheses) == len(references), "Hypotheses and references have different number of lines.")

    if not args.sentence:
        bleu = corpus_bleu(hypotheses, references, args.offset)
        print(bleu, file=sys.stdout)
    else:
        for h, r in zip(hypotheses, references):
            bleu = bleu_from_counts(bleu_counts(h, r), args.offset)
            print(bleu, file=sys.stdout)
Code example #16
File: init_embedding.py  Project: lagka/sockeye
def init_embeddings(args: argparse.Namespace):
    log_sockeye_version(logger)

    if len(args.weight_files) != len(args.vocabularies_in) or \
            len(args.weight_files) != len(args.vocabularies_out) or \
            len(args.weight_files) != len(args.names):
        logger.error("Exactly the same number of 'input weight files', 'input vocabularies', "
                     "'output vocabularies' and 'Sockeye parameter names' should be provided.")
        sys.exit(1)

    params = {}  # type: Dict[str, mx.nd.NDArray]
    weight_file_cache = {}  # type: Dict[str, np.ndarray]
    for weight_file, vocab_in_file, vocab_out_file, name in zip(args.weight_files, args.vocabularies_in,
                                                                args.vocabularies_out, args.names):
        weight = load_weight(weight_file, name, weight_file_cache)
        logger.info('Loading input/output vocabularies: %s %s', vocab_in_file, vocab_out_file)
        vocab_in = vocab.vocab_from_json(vocab_in_file, encoding=args.encoding)
        vocab_out = vocab.vocab_from_json(vocab_out_file)
        logger.info('Initializing parameter: %s', name)
        initializer = mx.init.Normal(sigma=np.std(weight))
        params[name] = init_weight(weight, vocab_in, vocab_out, initializer)

    logger.info('Saving initialized parameters to %s', args.file)
    utils.save_params(params, args.file)
Code example #17
File: quantize.py  Project: xingniu/sockeye
def annotate_model_params(model_dir: str):
    log_sockeye_version(logger)

    params_best = os.path.join(model_dir, C.PARAMS_BEST_NAME)
    params_best_float32 = os.path.join(model_dir, C.PARAMS_BEST_NAME_FLOAT32)
    config = os.path.join(model_dir, C.CONFIG_NAME)
    config_float32 = os.path.join(model_dir, C.CONFIG_NAME_FLOAT32)

    for fname in params_best_float32, config_float32:
        check_condition(
            not os.path.exists(fname),
            'File "%s" exists, indicating this model has already been quantized.'
            % fname)

    # Load model and compute scaling factors
    model, _, __ = sockeye.model.load_model(model_dir,
                                            for_disk_saving='float32',
                                            dtype='int8')
    # Move original params and config files
    os.rename(params_best, params_best_float32)
    os.rename(config, config_float32)
    # Write new params and config files with annotated scaling factors
    model.save_parameters(params_best)
    model.save_config(model_dir)
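Note: the example above only re-saves the model with int8 scaling factors annotated; the factors themselves are computed inside sockeye.model.load_model. As a rough, numpy-only illustration of the general idea behind symmetric int8 quantization (not the project's implementation), a per-tensor scaling factor maps the largest absolute weight onto the int8 range:

import numpy as np

def int8_quantize_sketch(weights: np.ndarray):
    # Hypothetical sketch of symmetric per-tensor int8 quantization.
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    # Approximate dequantization: quantized.astype(np.float32) * scale
    return quantized, scale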
Code example #18
def main():
    params = argparse.ArgumentParser(description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_train_cli_args(params)
    args = params.parse_args()

    # seed the RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    if args.use_fused_rnn:
        check_condition(not args.use_cpu, "GPU required for FusedRNN cells")

    check_condition(args.optimized_metric == C.BLEU or args.optimized_metric in args.metrics,
                    "Must optimize either BLEU or one of tracked metrics (--metrics)")

    # Checking status of output folder, resumption, etc.
    # Create temporary logger to console only
    logger = setup_main_logger(__name__, file_logging=False, console=not args.quiet)

    output_folder = os.path.abspath(args.output)
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "r") as fp:
                old_args = json.load(fp)
            arg_diffs = _dict_difference(vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # We do not have the logger yet
                logger.error("Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        elif os.path.exists(os.path.join(output_folder, C.PARAMS_BEST_NAME)):
            logger.error("Refusing to overwrite model folder %s as it seems to contain a trained model.", output_folder)
            sys.exit(1)
        else:
            logger.info("The output folder %s already exists, but no training state or parameter file was found. "
                        "Will start training from scratch.", output_folder)
    else:
        os.makedirs(output_folder)

    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
    log_sockeye_version(logger)
    log_mxnet_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        # context
        if args.use_cpu:
            logger.info("Device: CPU")
            context = [mx.cpu()]
        else:
            num_gpus = get_num_gpus()
            check_condition(num_gpus >= 1,
                            "No GPUs found, consider running on the CPU with --use-cpu "
                            "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                            "binary isn't on the path).")
            if args.disable_device_locking:
                context = expand_requested_device_ids(args.device_ids)
            else:
                context = exit_stack.enter_context(acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
            logger.info("Device(s): GPU %s", context)
            context = [mx.gpu(gpu_id) for gpu_id in context]

        # load existing or create vocabs
        if resume_training:
            vocab_source = vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_SRC_NAME))
            vocab_target = vocab.vocab_from_json_or_pickle(os.path.join(output_folder, C.VOCAB_TRG_NAME))
        else:
            num_words_source, num_words_target = args.num_words
            word_min_count_source, word_min_count_target = args.word_min_count

            # if the source and target embeddings are tied we build a joint vocabulary:
            if args.weight_tying and C.WEIGHT_TYING_SRC in args.weight_tying_type \
                    and C.WEIGHT_TYING_TRG in args.weight_tying_type:
                vocab_source = vocab_target = _build_or_load_vocab(args.source_vocab,
                                                                   [args.source, args.target],
                                                                   num_words_source,
                                                                   word_min_count_source)
            else:
                vocab_source = _build_or_load_vocab(args.source_vocab, [args.source],
                                                    num_words_source, word_min_count_source)
                vocab_target = _build_or_load_vocab(args.target_vocab, [args.target],
                                                    num_words_target, word_min_count_target)

            # write vocabularies
            vocab.vocab_to_json(vocab_source, os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)
            vocab.vocab_to_json(vocab_target, os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size, vocab_target_size)

        # create data iterators
        max_seq_len_source, max_seq_len_target = args.max_seq_len
        batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
        train_iter, eval_iter, config_data = data_io.get_training_data_iters(source=os.path.abspath(args.source),
                                                                             target=os.path.abspath(args.target),
                                                                             validation_source=os.path.abspath(
                                                                                 args.validation_source),
                                                                             validation_target=os.path.abspath(
                                                                                 args.validation_target),
                                                                             vocab_source=vocab_source,
                                                                             vocab_target=vocab_target,
                                                                             vocab_source_path=args.source_vocab,
                                                                             vocab_target_path=args.target_vocab,
                                                                             batch_size=args.batch_size,
                                                                             batch_by_words=args.batch_type == C.BATCH_TYPE_WORD,
                                                                             batch_num_devices=batch_num_devices,
                                                                             fill_up=args.fill_up,
                                                                             max_seq_len_source=max_seq_len_source,
                                                                             max_seq_len_target=max_seq_len_target,
                                                                             bucketing=not args.no_bucketing,
                                                                             bucket_width=args.bucket_width)

        # learning rate scheduling
        learning_rate_half_life = none_if_negative(args.learning_rate_half_life)
        # TODO: The loading for continuation of the scheduler is done separately from the other parts
        if not resume_training:
            lr_scheduler_instance = lr_scheduler.get_lr_scheduler(args.learning_rate_scheduler_type,
                                                                  args.checkpoint_frequency,
                                                                  learning_rate_half_life,
                                                                  args.learning_rate_reduce_factor,
                                                                  args.learning_rate_reduce_num_not_improved,
                                                                  args.learning_rate_schedule,
                                                                  args.learning_rate_warmup)
        else:
            with open(os.path.join(training_state_dir, C.SCHEDULER_STATE_NAME), "rb") as fp:
                lr_scheduler_instance = pickle.load(fp)

        # model configuration
        num_embed_source, num_embed_target = args.num_embed
        encoder_num_layers, decoder_num_layers = args.num_layers

        encoder_embed_dropout, decoder_embed_dropout = args.embed_dropout
        encoder_rnn_dropout_inputs, decoder_rnn_dropout_inputs = args.rnn_dropout_inputs
        encoder_rnn_dropout_states, decoder_rnn_dropout_states = args.rnn_dropout_states
        if encoder_embed_dropout > 0 and encoder_rnn_dropout_inputs > 0:
            logger.warning("Setting encoder RNN AND source embedding dropout > 0 leads to "
                           "two dropout layers on top of each other.")
        if decoder_embed_dropout > 0 and decoder_rnn_dropout_inputs > 0:
            logger.warning("Setting encoder RNN AND source embedding dropout > 0 leads to "
                           "two dropout layers on top of each other.")
        encoder_rnn_dropout_recurrent, decoder_rnn_dropout_recurrent = args.rnn_dropout_recurrent
        if encoder_rnn_dropout_recurrent > 0 or decoder_rnn_dropout_recurrent > 0:
            check_condition(args.rnn_cell_type == C.LSTM_TYPE,
                            "Recurrent dropout without memory loss only supported for LSTMs right now.")

        encoder_transformer_preprocess, decoder_transformer_preprocess = args.transformer_preprocess
        encoder_transformer_postprocess, decoder_transformer_postprocess = args.transformer_postprocess

        config_conv = None
        if args.encoder == C.RNN_WITH_CONV_EMBED_NAME:
            config_conv = encoder.ConvolutionalEmbeddingConfig(num_embed=num_embed_source,
                                                               max_filter_width=args.conv_embed_max_filter_width,
                                                               num_filters=args.conv_embed_num_filters,
                                                               pool_stride=args.conv_embed_pool_stride,
                                                               num_highway_layers=args.conv_embed_num_highway_layers,
                                                               dropout=args.conv_embed_dropout)

        if args.encoder in (C.TRANSFORMER_TYPE, C.TRANSFORMER_WITH_CONV_EMBED_TYPE):
            config_encoder = transformer.TransformerConfig(
                model_size=args.transformer_model_size,
                attention_heads=args.transformer_attention_heads,
                feed_forward_num_hidden=args.transformer_feed_forward_num_hidden,
                num_layers=encoder_num_layers,
                vocab_size=vocab_source_size,
                dropout_attention=args.transformer_dropout_attention,
                dropout_relu=args.transformer_dropout_relu,
                dropout_prepost=args.transformer_dropout_prepost,
                weight_tying=args.weight_tying and C.WEIGHT_TYING_SRC in args.weight_tying_type,
                positional_encodings=not args.transformer_no_positional_encodings,
                preprocess_sequence=encoder_transformer_preprocess,
                postprocess_sequence=encoder_transformer_postprocess,
                conv_config=config_conv)
        else:
            config_encoder = encoder.RecurrentEncoderConfig(
                vocab_size=vocab_source_size,
                num_embed=num_embed_source,
                embed_dropout=encoder_embed_dropout,
                rnn_config=rnn.RNNConfig(cell_type=args.rnn_cell_type,
                                         num_hidden=args.rnn_num_hidden,
                                         num_layers=encoder_num_layers,
                                         dropout_inputs=encoder_rnn_dropout_inputs,
                                         dropout_states=encoder_rnn_dropout_states,
                                         dropout_recurrent=encoder_rnn_dropout_recurrent,
                                         residual=args.rnn_residual_connections,
                                         first_residual_layer=args.rnn_first_residual_layer,
                                         forget_bias=args.rnn_forget_bias),
                conv_config=config_conv,
                reverse_input=args.rnn_encoder_reverse_input)

        decoder_weight_tying = args.weight_tying and C.WEIGHT_TYING_TRG in args.weight_tying_type \
                               and C.WEIGHT_TYING_SOFTMAX in args.weight_tying_type

        if args.decoder == C.TRANSFORMER_TYPE:
            config_decoder = transformer.TransformerConfig(
                model_size=args.transformer_model_size,
                attention_heads=args.transformer_attention_heads,
                feed_forward_num_hidden=args.transformer_feed_forward_num_hidden,
                num_layers=decoder_num_layers,
                vocab_size=vocab_target_size,
                dropout_attention=args.transformer_dropout_attention,
                dropout_relu=args.transformer_dropout_relu,
                dropout_prepost=args.transformer_dropout_prepost,
                weight_tying=decoder_weight_tying,
                positional_encodings=not args.transformer_no_positional_encodings,
                preprocess_sequence=decoder_transformer_preprocess,
                postprocess_sequence=decoder_transformer_postprocess,
                conv_config=None)

        else:
            attention_num_hidden = args.rnn_num_hidden if not args.attention_num_hidden else args.attention_num_hidden
            config_coverage = None
            if args.attention_type == C.ATT_COV:
                config_coverage = coverage.CoverageConfig(type=args.attention_coverage_type,
                                                          num_hidden=args.attention_coverage_num_hidden,
                                                          layer_normalization=args.layer_normalization)
            config_attention = attention.AttentionConfig(type=args.attention_type,
                                                         num_hidden=attention_num_hidden,
                                                         input_previous_word=args.attention_use_prev_word,
                                                         rnn_num_hidden=args.rnn_num_hidden,
                                                         layer_normalization=args.layer_normalization,
                                                         config_coverage=config_coverage,
                                                         num_heads=args.attention_mhdot_heads)
            config_decoder = decoder.RecurrentDecoderConfig(
                vocab_size=vocab_target_size,
                max_seq_len_source=max_seq_len_source,
                num_embed=num_embed_target,
                rnn_config=rnn.RNNConfig(cell_type=args.rnn_cell_type,
                                         num_hidden=args.rnn_num_hidden,
                                         num_layers=decoder_num_layers,
                                         dropout_inputs=decoder_rnn_dropout_inputs,
                                         dropout_states=decoder_rnn_dropout_states,
                                         dropout_recurrent=decoder_rnn_dropout_recurrent,
                                         residual=args.rnn_residual_connections,
                                         first_residual_layer=args.rnn_first_residual_layer,
                                         forget_bias=args.rnn_forget_bias),
                attention_config=config_attention,
                embed_dropout=decoder_embed_dropout,
                hidden_dropout=args.rnn_decoder_hidden_dropout,
                weight_tying=decoder_weight_tying,
                state_init=args.rnn_decoder_state_init,
                context_gating=args.rnn_context_gating,
                layer_normalization=args.layer_normalization,
                attention_in_upper_layers=args.attention_in_upper_layers)

        config_loss = loss.LossConfig(type=args.loss,
                                      vocab_size=vocab_target_size,
                                      normalize=args.normalize_loss,
                                      smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha)

        model_config = model.ModelConfig(config_data=config_data,
                                         max_seq_len_source=max_seq_len_source,
                                         max_seq_len_target=max_seq_len_target,
                                         vocab_source_size=vocab_source_size,
                                         vocab_target_size=vocab_target_size,
                                         config_encoder=config_encoder,
                                         config_decoder=config_decoder,
                                         config_loss=config_loss,
                                         lexical_bias=args.lexical_bias,
                                         learn_lexical_bias=args.learn_lexical_bias,
                                         weight_tying=args.weight_tying,
                                         weight_tying_type=args.weight_tying_type if args.weight_tying else None)
        model_config.freeze()

        # create training model
        training_model = training.TrainingModel(config=model_config,
                                                context=context,
                                                train_iter=train_iter,
                                                fused=args.use_fused_rnn,
                                                bucketing=not args.no_bucketing,
                                                lr_scheduler=lr_scheduler_instance)

        # We may consider loading the params in TrainingModule, for consistency
        # with the training state saving
        if resume_training:
            logger.info("Found partial training in directory %s. Resuming from saved state.", training_state_dir)
            training_model.load_params_from_file(os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
        elif args.params:
            logger.info("Training will initialize from parameters loaded from '%s'", args.params)
            training_model.load_params_from_file(args.params)

        lexicon_array = lexicon.initialize_lexicon(args.lexical_bias,
                                                   vocab_source, vocab_target) if args.lexical_bias else None

        weight_initializer = initializer.get_initializer(args.weight_init, args.weight_init_scale,
                                                         args.rnn_h2h_init, lexicon=lexicon_array)

        optimizer = args.optimizer
        optimizer_params = {'wd': args.weight_decay,
                            "learning_rate": args.initial_learning_rate}
        if lr_scheduler_instance is not None:
            optimizer_params["lr_scheduler"] = lr_scheduler_instance
        clip_gradient = none_if_negative(args.clip_gradient)
        if clip_gradient is not None:
            optimizer_params["clip_gradient"] = clip_gradient
        if args.momentum is not None:
            optimizer_params["momentum"] = args.momentum
        if args.normalize_loss:
            # When normalize_loss is turned on we normalize by the number of non-PAD symbols in a batch which implicitly
            # already contains the number of sentences and therefore we need to disable rescale_grad.
            optimizer_params["rescale_grad"] = 1.0
        else:
            # Making MXNet module API's default scaling factor explicit
            optimizer_params["rescale_grad"] = 1.0 / args.batch_size
        logger.info("Optimizer: %s", optimizer)
        logger.info("Optimizer Parameters: %s", optimizer_params)

        # Handle options that override training settings
        max_updates = args.max_updates
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_num_epochs = args.min_num_epochs
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            max_updates = sum(num_updates for (_, num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_num_epochs = 0

        monitor_bleu = args.monitor_bleu
        # Turn on BLEU monitoring when the optimized metric is BLEU and it hasn't been enabled yet
        if args.optimized_metric == C.BLEU and monitor_bleu == 0:
            logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
                        "To control how many validation sentences are used for calculating bleu use "
                        "the --monitor-bleu argument.")
            monitor_bleu = -1

        training_model.fit(train_iter, eval_iter,
                           output_folder=output_folder,
                           max_params_files_to_keep=args.keep_last_params,
                           metrics=args.metrics,
                           initializer=weight_initializer,
                           max_updates=max_updates,
                           checkpoint_frequency=args.checkpoint_frequency,
                           optimizer=optimizer, optimizer_params=optimizer_params,
                           optimized_metric=args.optimized_metric,
                           max_num_not_improved=max_num_checkpoint_not_improved,
                           min_num_epochs=min_num_epochs,
                           monitor_bleu=monitor_bleu,
                           use_tensorboard=args.use_tensorboard,
                           mxmonitor_pattern=args.monitor_pattern,
                           mxmonitor_stat_func=args.monitor_stat_func)
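Note: both training CLIs in this listing rely on a _dict_difference helper to decide whether a run can be resumed under the saved arguments. Its implementation is not shown; given that its result is treated as a set (combined with |, reduced with -=, and pruned with discard), one plausible sketch is:

def dict_difference_sketch(dict1, dict2):
    # Hypothetical stand-in for _dict_difference: keys of dict1 that are missing
    # from dict2 or mapped to a different value there.
    return {key for key, value in dict1.items()
            if key not in dict2 or dict2[key] != value}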
Code example #19
File: train.py  Project: mengjiexu/sockeye
def main():
    params = argparse.ArgumentParser(
        description='CLI to train sockeye sequence-to-sequence models.')
    arguments.add_io_args(params)
    arguments.add_model_parameters(params)
    arguments.add_training_args(params)
    arguments.add_device_args(params)
    args = params.parse_args()

    # seed the RNGs
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    if args.use_fused_rnn:
        check_condition(not args.use_cpu, "GPU required for FusedRNN cells")

    if args.rnn_residual_connections:
        check_condition(args.rnn_num_layers > 2,
                        "Residual connections require at least 3 RNN layers")

    check_condition(
        args.optimized_metric == C.BLEU
        or args.optimized_metric in args.metrics,
        "Must optimize either BLEU or one of tracked metrics (--metrics)")

    # Checking status of output folder, resumption, etc.
    # Create temporary logger to console only
    logger = setup_main_logger(__name__,
                               file_logging=False,
                               console=not args.quiet)
    output_folder = os.path.abspath(args.output)
    resume_training = False
    training_state_dir = os.path.join(output_folder, C.TRAINING_STATE_DIRNAME)
    if os.path.exists(output_folder):
        if args.overwrite_output:
            logger.info("Removing existing output folder %s.", output_folder)
            shutil.rmtree(output_folder)
            os.makedirs(output_folder)
        elif os.path.exists(training_state_dir):
            with open(os.path.join(output_folder, C.ARGS_STATE_NAME),
                      "r") as fp:
                old_args = json.load(fp)
            arg_diffs = _dict_difference(
                vars(args), old_args) | _dict_difference(old_args, vars(args))
            # Remove args that may differ without affecting the training.
            arg_diffs -= set(C.ARGS_MAY_DIFFER)
            # allow different device-ids provided their total count is the same
            if 'device_ids' in arg_diffs and len(
                    old_args['device_ids']) == len(vars(args)['device_ids']):
                arg_diffs.discard('device_ids')
            if not arg_diffs:
                resume_training = True
            else:
                # We do not have the logger yet
                logger.error(
                    "Mismatch in arguments for training continuation.")
                logger.error("Differing arguments: %s.", ", ".join(arg_diffs))
                sys.exit(1)
        else:
            logger.error("Refusing to overwrite existing output folder %s.",
                         output_folder)
            sys.exit(1)
    else:
        os.makedirs(output_folder)

    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    log_sockeye_version(logger)
    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    with ExitStack() as exit_stack:
        # context
        if args.use_cpu:
            logger.info("Device: CPU")
            context = [mx.cpu()]
        else:
            num_gpus = get_num_gpus()
            check_condition(
                num_gpus >= 1,
                "No GPUs found, consider running on the CPU with --use-cpu "
                "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                "binary isn't on the path).")
            if args.disable_device_locking:
                context = expand_requested_device_ids(args.device_ids)
            else:
                context = exit_stack.enter_context(
                    acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
            logger.info("Device(s): GPU %s", context)
            context = [mx.gpu(gpu_id) for gpu_id in context]

        # load existing or create vocabs
        if resume_training:
            vocab_source = vocab.vocab_from_json_or_pickle(
                os.path.join(output_folder, C.VOCAB_SRC_NAME))
            vocab_target = vocab.vocab_from_json_or_pickle(
                os.path.join(output_folder, C.VOCAB_TRG_NAME))
        else:
            num_words_source = args.num_words if args.num_words_source is None else args.num_words_source
            vocab_source = _build_or_load_vocab(args.source_vocab, args.source,
                                                num_words_source,
                                                args.word_min_count)
            vocab.vocab_to_json(
                vocab_source,
                os.path.join(output_folder, C.VOCAB_SRC_NAME) + C.JSON_SUFFIX)

            num_words_target = args.num_words if args.num_words_target is None else args.num_words_target
            vocab_target = _build_or_load_vocab(args.target_vocab, args.target,
                                                num_words_target,
                                                args.word_min_count)
            vocab.vocab_to_json(
                vocab_target,
                os.path.join(output_folder, C.VOCAB_TRG_NAME) + C.JSON_SUFFIX)

        vocab_source_size = len(vocab_source)
        vocab_target_size = len(vocab_target)
        logger.info("Vocabulary sizes: source=%d target=%d", vocab_source_size,
                    vocab_target_size)

        config_data = data_io.DataConfig(
            os.path.abspath(args.source), os.path.abspath(args.target),
            os.path.abspath(args.validation_source),
            os.path.abspath(args.validation_target), args.source_vocab,
            args.target_vocab)

        # create data iterators
        max_seq_len_source = args.max_seq_len if args.max_seq_len_source is None else args.max_seq_len_source
        max_seq_len_target = args.max_seq_len if args.max_seq_len_target is None else args.max_seq_len_target
        train_iter, eval_iter = data_io.get_training_data_iters(
            source=config_data.source,
            target=config_data.target,
            validation_source=config_data.validation_source,
            validation_target=config_data.validation_target,
            vocab_source=vocab_source,
            vocab_target=vocab_target,
            batch_size=args.batch_size,
            fill_up=args.fill_up,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            bucketing=not args.no_bucketing,
            bucket_width=args.bucket_width)

        # learning rate scheduling
        learning_rate_half_life = none_if_negative(
            args.learning_rate_half_life)
        # TODO: The loading for continuation of the scheduler is done separately from the other parts
        if not resume_training:
            lr_scheduler_instance = lr_scheduler.get_lr_scheduler(
                args.learning_rate_scheduler_type, args.checkpoint_frequency,
                learning_rate_half_life, args.learning_rate_reduce_factor,
                args.learning_rate_reduce_num_not_improved)
        else:
            with open(os.path.join(training_state_dir, C.SCHEDULER_STATE_NAME),
                      "rb") as fp:
                lr_scheduler_instance = pickle.load(fp)

        # model configuration
        num_embed_source = args.num_embed if args.num_embed_source is None else args.num_embed_source
        num_embed_target = args.num_embed if args.num_embed_target is None else args.num_embed_target

        config_rnn = rnn.RNNConfig(cell_type=args.rnn_cell_type,
                                   num_hidden=args.rnn_num_hidden,
                                   num_layers=args.rnn_num_layers,
                                   dropout=args.dropout,
                                   residual=args.rnn_residual_connections,
                                   forget_bias=args.rnn_forget_bias)

        config_conv = None
        if args.encoder == C.RNN_WITH_CONV_EMBED_NAME:
            config_conv = encoder.ConvolutionalEmbeddingConfig(
                num_embed=num_embed_source,
                max_filter_width=args.conv_embed_max_filter_width,
                num_filters=args.conv_embed_num_filters,
                pool_stride=args.conv_embed_pool_stride,
                num_highway_layers=args.conv_embed_num_highway_layers,
                dropout=args.dropout)

        config_encoder = encoder.RecurrentEncoderConfig(
            vocab_size=vocab_source_size,
            num_embed=num_embed_source,
            rnn_config=config_rnn,
            conv_config=config_conv)

        config_decoder = decoder.RecurrentDecoderConfig(
            vocab_size=vocab_target_size,
            num_embed=num_embed_target,
            rnn_config=config_rnn,
            dropout=args.dropout,
            weight_tying=args.weight_tying,
            context_gating=args.context_gating,
            layer_normalization=args.layer_normalization)

        attention_num_hidden = args.rnn_num_hidden if not args.attention_num_hidden else args.attention_num_hidden
        config_coverage = None
        if args.attention_type == "coverage":
            config_coverage = coverage.CoverageConfig(
                type=args.attention_coverage_type,
                num_hidden=args.attention_coverage_num_hidden,
                layer_normalization=args.layer_normalization)
        config_attention = attention.AttentionConfig(
            type=args.attention_type,
            num_hidden=attention_num_hidden,
            input_previous_word=args.attention_use_prev_word,
            rnn_num_hidden=config_rnn.num_hidden,
            layer_normalization=args.layer_normalization,
            config_coverage=config_coverage)

        config_loss = loss.LossConfig(
            type=args.loss,
            vocab_size=vocab_target_size,
            normalize=args.normalize_loss,
            smoothed_cross_entropy_alpha=args.smoothed_cross_entropy_alpha)

        model_config = model.ModelConfig(
            config_data=config_data,
            max_seq_len=max_seq_len_source,
            vocab_source_size=vocab_source_size,
            vocab_target_size=vocab_target_size,
            config_encoder=config_encoder,
            config_decoder=config_decoder,
            config_attention=config_attention,
            config_loss=config_loss,
            lexical_bias=args.lexical_bias,
            learn_lexical_bias=args.learn_lexical_bias)
        model_config.freeze()

        # create training model
        training_model = training.TrainingModel(
            config=model_config,
            context=context,
            train_iter=train_iter,
            fused=args.use_fused_rnn,
            bucketing=not args.no_bucketing,
            lr_scheduler=lr_scheduler_instance)

        # We may consider loading the params in TrainingModule, for consistency
        # with the training state saving
        if resume_training:
            logger.info(
                "Found partial training in directory %s. Resuming from saved state.",
                training_state_dir)
            training_model.load_params_from_file(
                os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
        elif args.params:
            logger.info(
                "Training will initialize from parameters loaded from '%s'",
                args.params)
            training_model.load_params_from_file(args.params)

        lexicon_array = lexicon.initialize_lexicon(
            args.lexical_bias, vocab_source,
            vocab_target) if args.lexical_bias else None

        weight_initializer = initializer.get_initializer(args.rnn_h2h_init,
                                                         lexicon=lexicon_array)

        optimizer = args.optimizer
        optimizer_params = {
            'wd': args.weight_decay,
            "learning_rate": args.initial_learning_rate
        }
        if lr_scheduler_instance is not None:
            optimizer_params["lr_scheduler"] = lr_scheduler_instance
        clip_gradient = none_if_negative(args.clip_gradient)
        if clip_gradient is not None:
            optimizer_params["clip_gradient"] = clip_gradient
        if args.momentum is not None:
            optimizer_params["momentum"] = args.momentum
        if args.normalize_loss:
            # When normalize_loss is turned on we normalize by the number of non-PAD symbols in a batch which implicitly
            # already contains the number of sentences and therefore we need to disable rescale_grad.
            optimizer_params["rescale_grad"] = 1.0
        else:
            # Making MXNet module API's default scaling factor explicit
            optimizer_params["rescale_grad"] = 1.0 / args.batch_size
        logger.info("Optimizer: %s", optimizer)
        logger.info("Optimizer Parameters: %s", optimizer_params)

        training_model.fit(
            train_iter,
            eval_iter,
            output_folder=output_folder,
            max_params_files_to_keep=args.keep_last_params,
            metrics=args.metrics,
            initializer=weight_initializer,
            max_updates=args.max_updates,
            checkpoint_frequency=args.checkpoint_frequency,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            optimized_metric=args.optimized_metric,
            max_num_not_improved=args.max_num_checkpoint_not_improved,
            min_num_epochs=args.min_num_epochs,
            monitor_bleu=args.monitor_bleu,
            use_tensorboard=args.use_tensorboard)