示例#1
0
def run_train_translate(train_params: str,
                        translate_params: str,
                        translate_params_equiv: Optional[str],
                        train_source_path: str,
                        train_target_path: str,
                        dev_source_path: str,
                        dev_target_path: str,
                        test_source_path: str,
                        test_target_path: str,
                        use_prepared_data: bool = False,
                        max_seq_len: int = 10,
                        restrict_lexicon: bool = False,
                        work_dir: Optional[str] = None,
                        seed: int = 13,
                        quiet: bool = False) -> Tuple[float, float, Optional[float], float]:
    """
    Train a model, translate the test source and report validation perplexity, BLEU and chrF.

    Every command line tool is run in-process: ``sys.argv`` is patched with the
    assembled argument list and the ``main()`` of the respective sockeye module
    is called. All outputs are written to a temporary directory that is removed
    when the function returns.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param translate_params_equiv: Second command line args for translation. Should produce the same outputs
           as ``translate_params``; the second run is skipped if None.
    :param train_source_path: Path to the source file.
    :param train_target_path: Path to the target file.
    :param dev_source_path: Path to the development source file.
    :param dev_target_path: Path to the development target file.
    :param test_source_path: Path to the test source file.
    :param test_target_path: Path to the test target file.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param max_seq_len: The maximum sequence length.
    :param restrict_lexicon: Additional translation run with top-k lexicon-based vocabulary restriction.
    :param work_dir: Parent directory in which the temporary working directory is created.
    :param seed: The seed used for training (only passed on when not using prepared data).
    :param quiet: Suppress the console output of training and decoding.
    :return: A tuple containing perplexity, BLEU for standard decoding, BLEU for restricted-vocabulary decoding
             (None unless ``restrict_lexicon`` is set) and chrF score.
    :raises AssertionError: If the two translation runs differ, no checkpoint is found, averaging yields
            nothing, or the number of hypotheses does not match the number of references.
    """
    quiet_arg = "--quiet" if quiet else ""

    with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as work_dir:
        model_path = os.path.join(work_dir, "model")

        if use_prepared_data:
            # Create the prepared data directory first, then train from it
            prepared_data_path = os.path.join(work_dir, "prepared_data")
            params = "{} {}".format(sockeye.prepare_data.__file__,
                                    _PREPARE_DATA_COMMON.format(train_source=train_source_path,
                                                                train_target=train_target_path,
                                                                output=prepared_data_path,
                                                                max_len=max_seq_len,
                                                                quiet=quiet_arg))
            logger.info("Creating prepared data folder.")
            with patch.object(sys, "argv", params.split()):
                sockeye.prepare_data.main()

            params = "{} {} {}".format(sockeye.train.__file__,
                                       _TRAIN_PARAMS_PREPARED_DATA_COMMON.format(prepared_data=prepared_data_path,
                                                                                 dev_source=dev_source_path,
                                                                                 dev_target=dev_target_path,
                                                                                 model=model_path,
                                                                                 max_len=max_seq_len,
                                                                                 quiet=quiet_arg),
                                       train_params)
        else:
            # Train directly from the raw training files
            params = "{} {} {}".format(sockeye.train.__file__,
                                       _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
                                                                   train_target=train_target_path,
                                                                   dev_source=dev_source_path,
                                                                   dev_target=dev_target_path,
                                                                   model=model_path,
                                                                   max_len=max_seq_len,
                                                                   seed=seed,
                                                                   quiet=quiet_arg),
                                       train_params)

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.train.main()

        logger.info("Translating with parameters %s.", translate_params)
        # Translate corpus with the 1st params
        out_path = os.path.join(work_dir, "out.txt")
        params = "{} {} {}".format(sockeye.translate.__file__,
                                   _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                   input=test_source_path,
                                                                   output=out_path,
                                                                   quiet=quiet_arg),
                                   translate_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.translate.main()

        # Translate corpus with the 2nd params
        if translate_params_equiv is not None:
            out_path_equiv = os.path.join(work_dir, "out_equiv.txt")
            params = "{} {} {}".format(sockeye.translate.__file__,
                                       _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                       input=test_source_path,
                                                                       output=out_path_equiv,
                                                                       quiet=quiet_arg),
                                       translate_params_equiv)
            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

            # Read in both outputs and ensure they are the same. Compare the full lists:
            # a zip-based comparison stops at the shorter file and would hide missing lines.
            with open(out_path, 'rt') as f:
                lines = f.readlines()
            with open(out_path_equiv, 'rt') as f:
                lines_equiv = f.readlines()
            assert lines == lines_equiv, "Outputs of the two translation runs differ"

        # Test restrict-lexicon
        out_restrict_path = os.path.join(work_dir, "out-restrict.txt")
        if restrict_lexicon:
            # fast_align lex table
            lex_path = os.path.join(work_dir, "lex")
            generate_fast_align_lex(lex_path)
            # Top-K JSON
            json_path = os.path.join(work_dir, "json")
            params = "{} {}".format(sockeye.lexicon.__file__,
                                    _LEXICON_PARAMS_COMMON.format(input=lex_path,
                                                                  model=model_path,
                                                                  json=json_path,
                                                                  quiet=quiet_arg))
            with patch.object(sys, "argv", params.split()):
                sockeye.lexicon.main()
            # Translate corpus with restrict-lexicon
            params = "{} {} {} {}".format(sockeye.translate.__file__,
                                          _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                          input=test_source_path,
                                                                          output=out_restrict_path,
                                                                          quiet=quiet_arg),
                                          translate_params,
                                          _TRANSLATE_PARAMS_RESTRICT.format(json=json_path))
            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

        # test averaging
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=1,
                                                  strategy='best',
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params

        # get best validation perplexity
        metrics = sockeye.utils.read_metrics_file(path=os.path.join(model_path, C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        # Use context managers so the file handles are closed deterministically
        with open(out_path, "r") as hyp_file:
            hypotheses = hyp_file.readlines()
        with open(test_target_path, "r") as ref_file:
            references = ref_file.readlines()
        assert len(hypotheses) == len(references)

        # compute metrics
        bleu = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)
        chrf = raw_corpus_chrf(hypotheses=hypotheses, references=references)

        bleu_restrict = None
        if restrict_lexicon:
            # Score the output of the restricted run, not the unrestricted hypotheses again
            with open(out_restrict_path, "r") as restrict_file:
                hypotheses_restrict = restrict_file.readlines()
            assert len(hypotheses_restrict) == len(references)
            bleu_restrict = raw_corpus_bleu(hypotheses=hypotheses_restrict, references=references, offset=0.01)

        # Run BLEU cli
        eval_params = "{} {} ".format(sockeye.evaluate.__file__,
                                      _EVAL_PARAMS_COMMON.format(hypotheses=out_path,
                                                                 references=test_target_path,
                                                                 metrics="bleu chrf",
                                                                 quiet=quiet_arg), )
        with patch.object(sys, "argv", eval_params.split()):
            sockeye.evaluate.main()

        return perplexity, bleu, bleu_restrict, chrf
示例#2
0
def run_train_translate(
        train_params: str,
        translate_params: str,
        translate_params_equiv: Optional[str],
        train_source_path: str,
        train_target_path: str,
        dev_source_path: str,
        dev_target_path: str,
        test_source_path: str,
        test_target_path: str,
        train_source_factor_paths: Optional[List[str]] = None,
        dev_source_factor_paths: Optional[List[str]] = None,
        test_source_factor_paths: Optional[List[str]] = None,
        use_prepared_data: bool = False,
        max_seq_len: int = 10,
        restrict_lexicon: bool = False,
        work_dir: Optional[str] = None,
        seed: int = 13,
        quiet: bool = False) -> Tuple[float, float, Optional[float], float]:
    """
    Train a model, translate the test source and report validation perplexity, BLEU and chrF.

    Every command line tool is run in-process: ``sys.argv`` is patched with the
    assembled argument list and the ``main()`` of the respective sockeye module
    is called. Besides training and translation this also exercises the
    checkpoint decoder, checkpoint averaging, parameter extraction and the
    evaluate CLI. All outputs are written to a temporary directory that is
    removed when the function returns.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param translate_params_equiv: Second command line args for translation. Should produce the same outputs
           as ``translate_params``; the second run is skipped if None.
    :param train_source_path: Path to the source file.
    :param train_target_path: Path to the target file.
    :param dev_source_path: Path to the development source file.
    :param dev_target_path: Path to the development target file.
    :param test_source_path: Path to the test source file.
    :param test_target_path: Path to the test target file.
    :param train_source_factor_paths: Optional list of paths to training source factor files.
    :param dev_source_factor_paths: Optional list of paths to dev source factor files.
    :param test_source_factor_paths: Optional list of paths to test source factor files.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param max_seq_len: The maximum sequence length.
    :param restrict_lexicon: Additional translation run with top-k lexicon-based vocabulary restriction.
    :param work_dir: Parent directory in which the temporary working directory is created.
    :param seed: The seed used for training (only passed on when not using prepared data).
    :param quiet: Suppress the console output of training and decoding.
    :return: A tuple containing perplexity, BLEU for standard decoding, BLEU for restricted-vocabulary decoding
             (None unless ``restrict_lexicon`` is set) and chrF score.
    :raises AssertionError: If the two translation runs differ, no checkpoint is found, averaging yields
            nothing, the extracted parameters lack "target_output_bias", or the number of hypotheses does
            not match the number of references.
    """
    quiet_arg = "--quiet" if quiet else ""

    with TemporaryDirectory(dir=work_dir,
                            prefix="test_train_translate.") as work_dir:
        model_path = os.path.join(work_dir, "model")

        if use_prepared_data:
            # Create the prepared data directory first, then train from it.
            # Training source factors are consumed by the preparation step here.
            prepared_data_path = os.path.join(work_dir, "prepared_data")
            params = "{} {}".format(
                sockeye.prepare_data.__file__,
                _PREPARE_DATA_COMMON.format(train_source=train_source_path,
                                            train_target=train_target_path,
                                            output=prepared_data_path,
                                            max_len=max_seq_len,
                                            quiet=quiet_arg))
            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(
                    source_factors=" ".join(train_source_factor_paths))

            logger.info("Creating prepared data folder.")
            with patch.object(sys, "argv", params.split()):
                sockeye.prepare_data.main()

            params = "{} {} {}".format(
                sockeye.train.__file__,
                _TRAIN_PARAMS_PREPARED_DATA_COMMON.format(
                    prepared_data=prepared_data_path,
                    dev_source=dev_source_path,
                    dev_target=dev_target_path,
                    model=model_path,
                    max_len=max_seq_len,
                    quiet=quiet_arg), train_params)
        else:
            # Train directly from the raw training files
            params = "{} {} {}".format(
                sockeye.train.__file__,
                _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
                                            train_target=train_target_path,
                                            dev_source=dev_source_path,
                                            dev_target=dev_target_path,
                                            model=model_path,
                                            max_len=max_seq_len,
                                            seed=seed,
                                            quiet=quiet_arg), train_params)

            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(
                    source_factors=" ".join(train_source_factor_paths))

        # Dev source factors are passed to training in both modes
        if dev_source_factor_paths is not None:
            params += _DEV_WITH_FACTORS_COMMON.format(
                dev_source_factors=" ".join(dev_source_factor_paths))

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.train.main()

        # Run checkpoint decoder on 1% of dev data, but on at least one sentence.
        # (min(1, ...) could never exceed 1 and yielded 0 for fewer than 100 sentences.)
        with open(dev_source_path) as dev_fd:
            num_dev_sent = sum(1 for _ in dev_fd)
        sample_size = max(1, int(num_dev_sent * 0.01))
        cp_decoder = sockeye.checkpoint_decoder.CheckpointDecoder(
            context=mx.cpu(),
            inputs=[dev_source_path],
            references=dev_target_path,
            model=model_path,
            sample_size=sample_size,
            batch_size=2,
            beam_size=2)
        cp_metrics = cp_decoder.decode_and_evaluate()
        logger.info("Checkpoint decoder metrics: %s", cp_metrics)

        logger.info("Translating with parameters %s.", translate_params)
        # Translate corpus with the 1st params
        out_path = os.path.join(work_dir, "out.txt")
        params = "{} {} {}".format(
            sockeye.translate.__file__,
            _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                            input=test_source_path,
                                            output=out_path,
                                            quiet=quiet_arg), translate_params)

        if test_source_factor_paths is not None:
            params += _TRANSLATE_WITH_FACTORS_COMMON.format(
                input_factors=" ".join(test_source_factor_paths))

        with patch.object(sys, "argv", params.split()):
            sockeye.translate.main()

        # Translate corpus with the 2nd params
        if translate_params_equiv is not None:
            out_path_equiv = os.path.join(work_dir, "out_equiv.txt")
            params = "{} {} {}".format(
                sockeye.translate.__file__,
                _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                input=test_source_path,
                                                output=out_path_equiv,
                                                quiet=quiet_arg),
                translate_params_equiv)

            if test_source_factor_paths is not None:
                params += _TRANSLATE_WITH_FACTORS_COMMON.format(
                    input_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

            # Read in both outputs and ensure they are the same. Compare the full lists:
            # a zip-based comparison stops at the shorter file and would hide missing lines.
            with open(out_path, 'rt') as f:
                lines = f.readlines()
            with open(out_path_equiv, 'rt') as f:
                lines_equiv = f.readlines()
            assert lines == lines_equiv, "Outputs of the two translation runs differ"

        # Test restrict-lexicon
        out_restrict_path = os.path.join(work_dir, "out-restrict.txt")
        if restrict_lexicon:
            # fast_align lex table
            ttable_path = os.path.join(work_dir, "ttable")
            generate_fast_align_lex(ttable_path)
            # Top-K lexicon
            lexicon_path = os.path.join(work_dir, "lexicon")
            params = "{} {}".format(
                sockeye.lexicon.__file__,
                _LEXICON_CREATE_PARAMS_COMMON.format(input=ttable_path,
                                                     model=model_path,
                                                     topk=20,
                                                     lexicon=lexicon_path,
                                                     quiet=quiet_arg))
            with patch.object(sys, "argv", params.split()):
                sockeye.lexicon.main()
            # Translate corpus with restrict-lexicon
            params = "{} {} {} {}".format(
                sockeye.translate.__file__,
                _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                input=test_source_path,
                                                output=out_restrict_path,
                                                quiet=quiet_arg),
                translate_params,
                _TRANSLATE_PARAMS_RESTRICT.format(lexicon=lexicon_path,
                                                  topk=1))

            if test_source_factor_paths is not None:
                params += _TRANSLATE_WITH_FACTORS_COMMON.format(
                    input_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

        # test averaging
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=1,
                                                  strategy='best',
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params

        # test parameter extraction
        extract_params = _EXTRACT_PARAMS.format(
            output=os.path.join(model_path, "params.extracted"),
            input=model_path)
        with patch.object(sys, "argv", extract_params.split()):
            sockeye.extract_parameters.main()
        with np.load(os.path.join(model_path, "params.extracted.npz")) as data:
            assert "target_output_bias" in data

        # get best validation perplexity
        metrics = sockeye.utils.read_metrics_file(
            path=os.path.join(model_path, C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        with open(out_path, "r") as out:
            hypotheses = out.readlines()
        with open(test_target_path, "r") as ref:
            references = ref.readlines()
        assert len(hypotheses) == len(references)

        # compute metrics
        bleu = raw_corpus_bleu(hypotheses=hypotheses,
                               references=references,
                               offset=0.01)
        chrf = raw_corpus_chrf(hypotheses=hypotheses, references=references)

        bleu_restrict = None
        if restrict_lexicon:
            # Score the output of the restricted run, not the unrestricted hypotheses again
            with open(out_restrict_path, "r") as out_restrict:
                hypotheses_restrict = out_restrict.readlines()
            assert len(hypotheses_restrict) == len(references)
            bleu_restrict = raw_corpus_bleu(hypotheses=hypotheses_restrict,
                                            references=references,
                                            offset=0.01)

        # Run evaluate cli
        eval_params = "{} {} ".format(
            sockeye.evaluate.__file__,
            _EVAL_PARAMS_COMMON.format(hypotheses=out_path,
                                       references=test_target_path,
                                       metrics="bleu chrf rouge1",
                                       quiet=quiet_arg),
        )
        with patch.object(sys, "argv", eval_params.split()):
            sockeye.evaluate.main()

        return perplexity, bleu, bleu_restrict, chrf
示例#3
0
def run_train_translate(
        train_params: str,
        translate_params: str,
        translate_params_equiv: Optional[str],
        train_source_path: str,
        train_target_path: str,
        dev_source_path: str,
        dev_target_path: str,
        test_source_path: str,
        test_target_path: str,
        train_source_factor_paths: Optional[List[str]] = None,
        dev_source_factor_paths: Optional[List[str]] = None,
        test_source_factor_paths: Optional[List[str]] = None,
        use_prepared_data: bool = False,
        use_target_constraints: bool = False,
        max_seq_len: int = 10,
        restrict_lexicon: bool = False,
        work_dir: Optional[str] = None,
        seed: int = 13,
        quiet: bool = False) -> Tuple[float, float, float, float]:
    """
    Train a model and translate a dev set.  Report validation perplexity and BLEU.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param translate_params_equiv: Second command line args for translation. Should produce the same outputs
    :param train_source_path: Path to the source file.
    :param train_target_path: Path to the target file.
    :param dev_source_path: Path to the development source file.
    :param dev_target_path: Path to the development target file.
    :param test_source_path: Path to the test source file.
    :param test_target_path: Path to the test target file.
    :param train_source_factor_paths: Optional list of paths to training source factor files.
    :param dev_source_factor_paths: Optional list of paths to dev source factor files.
    :param test_source_factor_paths: Optional list of paths to test source factor files.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param max_seq_len: The maximum sequence length.
    :param restrict_lexicon: Additional translation run with top-k lexicon-based vocabulary restriction.
    :param work_dir: The directory to store the model and other outputs in.
    :param seed: The seed used for training.
    :param quiet: Suppress the console output of training and decoding.
    :return: A tuple containing perplexity, bleu scores for standard and reduced vocab decoding, chrf score.
    """
    if quiet:
        quiet_arg = "--quiet"
    else:
        quiet_arg = ""
    with TemporaryDirectory(dir=work_dir,
                            prefix="test_train_translate.") as work_dir:
        # Optionally create prepared data directory
        if use_prepared_data:
            prepared_data_path = os.path.join(work_dir, "prepared_data")
            params = "{} {}".format(
                sockeye.prepare_data.__file__,
                _PREPARE_DATA_COMMON.format(train_source=train_source_path,
                                            train_target=train_target_path,
                                            output=prepared_data_path,
                                            max_len=max_seq_len,
                                            quiet=quiet_arg))
            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(
                    source_factors=" ".join(train_source_factor_paths))

            logger.info("Creating prepared data folder.")
            with patch.object(sys, "argv", params.split()):
                sockeye.prepare_data.main()
            # Train model
            model_path = os.path.join(work_dir, "model")
            params = "{} {} {}".format(
                sockeye.train.__file__,
                _TRAIN_PARAMS_PREPARED_DATA_COMMON.format(
                    prepared_data=prepared_data_path,
                    dev_source=dev_source_path,
                    dev_target=dev_target_path,
                    model=model_path,
                    max_len=max_seq_len,
                    quiet=quiet_arg), train_params)

            if dev_source_factor_paths is not None:
                params += _DEV_WITH_FACTORS_COMMON.format(
                    dev_source_factors=" ".join(dev_source_factor_paths))

            logger.info("Starting training with parameters %s.", train_params)
            with patch.object(sys, "argv", params.split()):
                sockeye.train.main()
        else:
            # Train model
            model_path = os.path.join(work_dir, "model")
            params = "{} {} {}".format(
                sockeye.train.__file__,
                _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
                                            train_target=train_target_path,
                                            dev_source=dev_source_path,
                                            dev_target=dev_target_path,
                                            model=model_path,
                                            max_len=max_seq_len,
                                            seed=seed,
                                            quiet=quiet_arg), train_params)

            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(
                    source_factors=" ".join(train_source_factor_paths))
            if dev_source_factor_paths is not None:
                params += _DEV_WITH_FACTORS_COMMON.format(
                    dev_source_factors=" ".join(dev_source_factor_paths))

            logger.info("Starting training with parameters %s.", train_params)
            with patch.object(sys, "argv", params.split()):
                sockeye.train.main()

        # run checkpoint decoder on 1% of dev data
        with open(dev_source_path) as dev_fd:
            num_dev_sent = sum(1 for _ in dev_fd)
        sample_size = min(1, int(num_dev_sent * 0.01))
        cp_decoder = sockeye.checkpoint_decoder.CheckpointDecoder(
            context=mx.cpu(),
            inputs=[dev_source_path],
            references=dev_target_path,
            model=model_path,
            sample_size=sample_size,
            batch_size=2,
            beam_size=2)
        cp_metrics = cp_decoder.decode_and_evaluate()
        logger.info("Checkpoint decoder metrics: %s", cp_metrics)

        logger.info("Translating with parameters %s.", translate_params)
        # Translate corpus with the 1st params
        out_path = os.path.join(work_dir, "out.txt")
        translate_score_path = os.path.join(work_dir, "out.scores.txt")
        params = "{} {} {} --output-type translation_with_score".format(
            sockeye.translate.__file__,
            _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                            input=test_source_path,
                                            output=out_path,
                                            quiet=quiet_arg), translate_params)

        if test_source_factor_paths is not None:
            params += _TRANSLATE_WITH_FACTORS_COMMON.format(
                input_factors=" ".join(test_source_factor_paths))

        with patch.object(sys, "argv", params.split()):
            sockeye.translate.main()

        # Break out translation and score
        with open(out_path) as out_fh:
            outputs = out_fh.readlines()
        with open(out_path,
                  'w') as out_translate, open(translate_score_path,
                                              'w') as out_scores:
            for output in outputs:
                output = output.strip()
                # blank lines on test input will have only one field output (-inf for the score)
                try:
                    score, translation = output.split('\t')
                except ValueError:
                    score = output
                    translation = ""
                print(translation, file=out_translate)
                print(score, file=out_scores)

        # Test target constraints
        if use_target_constraints:
            """
            Read in the unconstrained system output from the first pass and use it to generate positive
            and negative constraints. It is important to generate a mix of positive, negative, and no
            constraints per batch, to test these production-realistic interactions as well.
            """
            # 'constraint' = positive constraints (must appear), 'avoid' = negative constraints (must not appear)
            for constraint_type in ["constraints", "avoid"]:
                constrained_sources = []
                with open(test_source_path) as source_inp, open(
                        out_path) as system_out:
                    for sentno, (source, target) in enumerate(
                            zip(source_inp, system_out)):
                        target_words = target.rstrip().split()
                        target_len = len(target_words)
                        new_source = {'text': source.rstrip()}
                        # From the odd-numbered sentences that are not too long, create constraints. We do
                        # only odds to ensure we get batches with mixed constraints / lack of constraints.
                        if target_len > 0 and sentno % 2 == 0:
                            start_pos = 0
                            end_pos = min(target_len, 3)
                            constraint = ' '.join(
                                target_words[start_pos:end_pos])
                            new_source[constraint_type] = [constraint]
                        constrained_sources.append(json.dumps(new_source))

                new_test_source_path = os.path.join(work_dir,
                                                    "test_constrained.txt")
                with open(new_test_source_path, 'w') as out:
                    for json_line in constrained_sources:
                        print(json_line, file=out)

                out_path_constrained = os.path.join(work_dir,
                                                    "out_constrained.txt")
                params = "{} {} {} --json-input".format(
                    sockeye.translate.__file__,
                    _TRANSLATE_PARAMS_COMMON.format(
                        model=model_path,
                        input=new_test_source_path,
                        output=out_path_constrained,
                        quiet=quiet_arg), translate_params)

                with patch.object(sys, "argv", params.split()):
                    sockeye.translate.main()

                for json_input, constrained_out, unconstrained_out in zip(
                        open(new_test_source_path), open(out_path_constrained),
                        open(out_path)):
                    jobj = json.loads(json_input)
                    if jobj.get(constraint_type) is None:
                        # if there were no constraints, make sure the output is the same as the unconstrained output
                        assert constrained_out == unconstrained_out
                    else:
                        restriction = jobj[constraint_type][0]
                        if constraint_type == 'constraints':
                            # for positive constraints, ensure the constraint is in the constrained output
                            assert restriction in constrained_out
                        else:
                            # for negative constraints, ensure the constraints is *not* in the constrained output
                            assert restriction not in constrained_out

        # Test scoring by ensuring that the sockeye.scoring module produces the same scores when scoring the output
        # of sockeye.translate. However, since this training is on very small datasets, the output of sockeye.translate
        # is often pure garbage or empty and cannot be scored. So we only try to score if we have some valid output
        # to work with.

        # Skip if there are invalid tokens in the output, or if no valid outputs were found
        translate_output_is_valid = True
        with open(out_path) as out_fh:
            sentences = list(map(lambda x: x.rstrip(), out_fh.readlines()))
            # At least one output must be non-empty
            found_valid_output = any(sentences)

            # There must be no bad tokens
            found_bad_tokens = any([
                bad_token in ' '.join(sentences)
                for bad_token in C.VOCAB_SYMBOLS
            ])

            translate_output_is_valid = found_valid_output and not found_bad_tokens

        # Only run scoring under these conditions. Why?
        # - scoring isn't compatible with prepared data because that loses the source ordering
        # - translate splits up too-long sentences and translates them in sequence, invalidating the score, so skip that
        # - scoring requires valid translation output to compare against
        if not use_prepared_data \
           and '--max-input-len' not in translate_params \
           and translate_output_is_valid:

            # Score
            # We use the translation parameters, but have to remove irrelevant arguments from it.
            # Currently, the only relevant flag passed is the --softmax-temperature flag.
            score_params = ''
            if 'softmax-temperature' in translate_params:
                params = translate_params.split(C.TOKEN_SEPARATOR)
                for i, param in enumerate(params):
                    if param == '--softmax-temperature':
                        score_params = '--softmax-temperature {}'.format(
                            params[i + 1])
                        break

            scores_output_file = out_path + '.score'
            params = "{} {} {}".format(
                sockeye.score.__file__,
                _SCORE_PARAMS_COMMON.format(model=model_path,
                                            source=test_source_path,
                                            target=out_path,
                                            output=scores_output_file),
                score_params)

            if test_source_factor_paths is not None:
                params += _SCORE_WITH_FACTORS_COMMON.format(
                    source_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.score.main()

            # Compare scored output to original translation output. There are a few tricks: for blank source sentences,
            # inference will report a score of -inf, so skip these. Second, we don't know if the scores include the
            # generation of </s> and have had length normalization applied. So, skip all sentences that are as long
            # as the maximum length, in order to safely exclude them.
            with open(translate_score_path) as in_translate, open(
                    out_path) as in_words, open(
                        scores_output_file) as in_score:
                model_config = sockeye.model.SockeyeModel.load_config(
                    os.path.join(model_path, C.CONFIG_NAME))
                max_len = model_config.config_data.max_seq_len_target

                # Filter out sockeye.translate sentences that had -inf or were too long (which sockeye.score will have skipped)
                translate_scores = []
                translate_lens = []
                score_scores = in_score.readlines()
                for score, sent in zip(in_translate.readlines(),
                                       in_words.readlines()):
                    if score != '-inf\n' and len(sent.split()) < max_len:
                        translate_scores.append(score)
                        translate_lens.append(len(sent.split()))

                assert len(translate_scores) == len(score_scores)

                # Compare scores (using 0.002 which covers common noise comparing e.g., 1.234 and 1.235)
                for translate_score, translate_len, score_score in zip(
                        translate_scores, translate_lens, score_scores):
                    # Skip sentences that are close to the maximum length to avoid confusion about whether
                    # the length penalty was applied
                    if translate_len >= max_len - 2:
                        continue

                    translate_score = float(translate_score)
                    score_score = float(score_score)

                    assert abs(translate_score - score_score) < 0.002

        # Translate corpus with the 2nd params
        if translate_params_equiv is not None:
            out_path_equiv = os.path.join(work_dir, "out_equiv.txt")
            params = "{} {} {}".format(
                sockeye.translate.__file__,
                _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                input=test_source_path,
                                                output=out_path_equiv,
                                                quiet=quiet_arg),
                translate_params_equiv)

            if test_source_factor_paths is not None:
                params += _TRANSLATE_WITH_FACTORS_COMMON.format(
                    input_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

            # read-in both outputs, ensure they are the same
            with open(out_path, 'rt') as f:
                lines = f.readlines()
            with open(out_path_equiv, 'rt') as f:
                lines_equiv = f.readlines()
            assert all(a == b for a, b in zip(lines, lines_equiv))

        # Test restrict-lexicon
        out_restrict_path = os.path.join(work_dir, "out-restrict.txt")
        if restrict_lexicon:
            # fast_align lex table
            ttable_path = os.path.join(work_dir, "ttable")
            generate_fast_align_lex(ttable_path)
            # Top-K lexicon
            lexicon_path = os.path.join(work_dir, "lexicon")
            params = "{} {}".format(
                sockeye.lexicon.__file__,
                _LEXICON_CREATE_PARAMS_COMMON.format(input=ttable_path,
                                                     model=model_path,
                                                     topk=20,
                                                     lexicon=lexicon_path,
                                                     quiet=quiet_arg))
            with patch.object(sys, "argv", params.split()):
                sockeye.lexicon.main()
            # Translate corpus with restrict-lexicon
            params = "{} {} {} {}".format(
                sockeye.translate.__file__,
                _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                input=test_source_path,
                                                output=out_restrict_path,
                                                quiet=quiet_arg),
                translate_params,
                _TRANSLATE_PARAMS_RESTRICT.format(lexicon=lexicon_path,
                                                  topk=1))

            if test_source_factor_paths is not None:
                params += _TRANSLATE_WITH_FACTORS_COMMON.format(
                    input_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

        # test averaging
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=1,
                                                  strategy='best',
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params

        # test parameter extraction
        extract_params = _EXTRACT_PARAMS.format(output=os.path.join(
            model_path, "params.extracted"),
                                                input=model_path)
        with patch.object(sys, "argv", extract_params.split()):
            sockeye.extract_parameters.main()
        with np.load(os.path.join(model_path, "params.extracted.npz")) as data:
            assert "target_output_bias" in data

        # get best validation perplexity
        metrics = sockeye.utils.read_metrics_file(
            path=os.path.join(model_path, C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        with open(out_path, "r") as out:
            hypotheses = out.readlines()
        with open(test_target_path, "r") as ref:
            references = ref.readlines()
        assert len(hypotheses) == len(references)

        # compute metrics
        bleu = raw_corpus_bleu(hypotheses=hypotheses,
                               references=references,
                               offset=0.01)
        chrf = raw_corpus_chrf(hypotheses=hypotheses, references=references)

        bleu_restrict = None
        if restrict_lexicon:
            bleu_restrict = raw_corpus_bleu(hypotheses=hypotheses,
                                            references=references,
                                            offset=0.01)

        # Run evaluate cli
        eval_params = "{} {} ".format(
            sockeye.evaluate.__file__,
            _EVAL_PARAMS_COMMON.format(hypotheses=out_path,
                                       references=test_target_path,
                                       metrics="bleu chrf rouge1",
                                       quiet=quiet_arg),
        )
        with patch.object(sys, "argv", eval_params.split()):
            sockeye.evaluate.main()

        return perplexity, bleu, bleu_restrict, chrf
示例#4
0
def run_train_translate(train_params: str,
                        translate_params: str,
                        translate_params_equiv: Optional[str],
                        train_source_path: str,
                        train_target_path: str,
                        dev_source_path: str,
                        dev_target_path: str,
                        test_source_path: str,
                        test_target_path: str,
                        train_source_factor_paths: Optional[List[str]] = None,
                        dev_source_factor_paths: Optional[List[str]] = None,
                        test_source_factor_paths: Optional[List[str]] = None,
                        use_prepared_data: bool = False,
                        use_target_constraints: bool = False,
                        max_seq_len: int = 10,
                        restrict_lexicon: bool = False,
                        work_dir: Optional[str] = None,
                        seed: int = 13,
                        quiet: bool = False) -> Tuple[float, float, float, float]:
    """
    Train a model and translate a test set.  Report validation perplexity, BLEU and chrF.

    All sockeye command line entry points are invoked in-process by patching ``sys.argv``. Every
    artifact (prepared data, model, translations, lexicon) is written to a temporary directory
    created below ``work_dir`` and removed when the function returns.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param translate_params_equiv: Second command line args for translation. Should produce the same outputs
    :param train_source_path: Path to the source file.
    :param train_target_path: Path to the target file.
    :param dev_source_path: Path to the development source file.
    :param dev_target_path: Path to the development target file.
    :param test_source_path: Path to the test source file.
    :param test_target_path: Path to the test target file.
    :param train_source_factor_paths: Optional list of paths to training source factor files.
    :param dev_source_factor_paths: Optional list of paths to dev source factor files.
    :param test_source_factor_paths: Optional list of paths to test source factor files.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param use_target_constraints: Whether to additionally decode with positive ("constraints") and
                                   negative ("avoid") target constraints and check that they are honoured.
    :param max_seq_len: The maximum sequence length.
    :param restrict_lexicon: Additional translation run with top-k lexicon-based vocabulary restriction.
    :param work_dir: The directory to store the model and other outputs in.
    :param seed: The seed used for training.
    :param quiet: Suppress the console output of training and decoding.
    :return: A tuple containing perplexity, bleu scores for standard and reduced vocab decoding, chrf score.
             The reduced-vocab bleu entry is None when restrict_lexicon is False.
    """
    quiet_arg = "--quiet" if quiet else ""

    with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as work_dir:
        model_path = os.path.join(work_dir, "model")

        if use_prepared_data:
            # Create the prepared data directory first, then train from it.
            prepared_data_path = os.path.join(work_dir, "prepared_data")
            params = "{} {}".format(sockeye.prepare_data.__file__,
                                    _PREPARE_DATA_COMMON.format(train_source=train_source_path,
                                                                train_target=train_target_path,
                                                                output=prepared_data_path,
                                                                max_len=max_seq_len,
                                                                quiet=quiet_arg))
            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(source_factors=" ".join(train_source_factor_paths))

            logger.info("Creating prepared data folder.")
            with patch.object(sys, "argv", params.split()):
                sockeye.prepare_data.main()

            # Training source factors were already consumed by prepare_data above.
            params = "{} {} {}".format(sockeye.train.__file__,
                                       _TRAIN_PARAMS_PREPARED_DATA_COMMON.format(prepared_data=prepared_data_path,
                                                                                 dev_source=dev_source_path,
                                                                                 dev_target=dev_target_path,
                                                                                 model=model_path,
                                                                                 max_len=max_seq_len,
                                                                                 quiet=quiet_arg),
                                       train_params)
        else:
            # Train directly from the raw parallel data.
            params = "{} {} {}".format(sockeye.train.__file__,
                                       _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
                                                                   train_target=train_target_path,
                                                                   dev_source=dev_source_path,
                                                                   dev_target=dev_target_path,
                                                                   model=model_path,
                                                                   max_len=max_seq_len,
                                                                   seed=seed,
                                                                   quiet=quiet_arg),
                                       train_params)
            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(source_factors=" ".join(train_source_factor_paths))

        # Train model (shared by both data set-ups)
        if dev_source_factor_paths is not None:
            params += _DEV_WITH_FACTORS_COMMON.format(dev_source_factors=" ".join(dev_source_factor_paths))

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.train.main()

        # run checkpoint decoder on 1% of dev data
        with open(dev_source_path) as dev_fd:
            num_dev_sent = sum(1 for _ in dev_fd)
        # NOTE(review): min() caps the sample size at 1 and yields 0 for fewer than 100 dev sentences;
        # max() may have been intended. Left unchanged -- confirm how CheckpointDecoder treats 0.
        sample_size = min(1, int(num_dev_sent * 0.01))
        cp_decoder = sockeye.checkpoint_decoder.CheckpointDecoder(context=mx.cpu(),
                                                                  inputs=[dev_source_path],
                                                                  references=dev_target_path,
                                                                  model=model_path,
                                                                  sample_size=sample_size,
                                                                  batch_size=2,
                                                                  beam_size=2)
        cp_metrics = cp_decoder.decode_and_evaluate()
        logger.info("Checkpoint decoder metrics: %s", cp_metrics)

        logger.info("Translating with parameters %s.", translate_params)
        # Translate corpus with the 1st params
        out_path = os.path.join(work_dir, "out.txt")
        params = "{} {} {}".format(sockeye.translate.__file__,
                                   _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                   input=test_source_path,
                                                                   output=out_path,
                                                                   quiet=quiet_arg),
                                   translate_params)

        if test_source_factor_paths is not None:
            params += _TRANSLATE_WITH_FACTORS_COMMON.format(input_factors=" ".join(test_source_factor_paths))

        with patch.object(sys, "argv", params.split()):
            sockeye.translate.main()

        # Test target constraints
        if use_target_constraints:
            # Read in the unconstrained system output from the first pass and use it to generate positive
            # and negative constraints. It is important to generate a mix of positive, negative, and no
            # constraints per batch, to test these production-realistic interactions as well.
            #
            # 'constraints' = positive constraints (must appear), 'avoid' = negative constraints (must not appear)
            for constraint_type in ["constraints", "avoid"]:
                constrained_sources = []
                with open(test_source_path) as source_inp, open(out_path) as system_out:
                    for sentno, (source, target) in enumerate(zip(source_inp, system_out)):
                        target_words = target.rstrip().split()
                        target_len = len(target_words)
                        new_source = {'text': source.rstrip()}
                        # Only every other sentence (even zero-based index) with a non-empty output gets a
                        # constraint, built from its first (up to) three output tokens. Alternating ensures
                        # we get batches with mixed constraints / lack of constraints.
                        if target_len > 0 and sentno % 2 == 0:
                            start_pos = 0
                            end_pos = min(target_len, 3)
                            constraint = ' '.join(target_words[start_pos:end_pos])
                            new_source[constraint_type] = [constraint]
                        constrained_sources.append(json.dumps(new_source))

                # One JSON object per line, as consumed by translate's --json-input mode.
                new_test_source_path = os.path.join(work_dir, "test_constrained.txt")
                with open(new_test_source_path, 'w') as out:
                    for json_line in constrained_sources:
                        print(json_line, file=out)

                out_path_constrained = os.path.join(work_dir, "out_constrained.txt")
                params = "{} {} {} --json-input".format(sockeye.translate.__file__,
                                                        _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                                        input=new_test_source_path,
                                                                                        output=out_path_constrained,
                                                                                        quiet=quiet_arg),
                                                        translate_params)

                with patch.object(sys, "argv", params.split()):
                    sockeye.translate.main()

                # Open all three files in a context manager so the handles are closed deterministically,
                # even when one of the assertions below fails.
                with open(new_test_source_path) as json_inputs, \
                        open(out_path_constrained) as constrained_outputs, \
                        open(out_path) as unconstrained_outputs:
                    for json_input, constrained_out, unconstrained_out in zip(json_inputs,
                                                                              constrained_outputs,
                                                                              unconstrained_outputs):
                        jobj = json.loads(json_input)
                        if jobj.get(constraint_type) is None:
                            # if there were no constraints, make sure the output is the same as the
                            # unconstrained output
                            assert constrained_out == unconstrained_out
                        else:
                            restriction = jobj[constraint_type][0]
                            if constraint_type == 'constraints':
                                # for positive constraints, ensure the constraint is in the constrained output
                                assert restriction in constrained_out
                            else:
                                # for negative constraints, ensure the constraint is *not* in the
                                # constrained output
                                assert restriction not in constrained_out

        # Translate corpus with the 2nd params
        if translate_params_equiv is not None:
            out_path_equiv = os.path.join(work_dir, "out_equiv.txt")
            params = "{} {} {}".format(sockeye.translate.__file__,
                                       _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                       input=test_source_path,
                                                                       output=out_path_equiv,
                                                                       quiet=quiet_arg),
                                       translate_params_equiv)

            if test_source_factor_paths is not None:
                params += _TRANSLATE_WITH_FACTORS_COMMON.format(input_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

            # read-in both outputs, ensure they are the same
            with open(out_path, 'rt') as f:
                lines = f.readlines()
            with open(out_path_equiv, 'rt') as f:
                lines_equiv = f.readlines()
            assert all(a == b for a, b in zip(lines, lines_equiv))

        # Test restrict-lexicon
        out_restrict_path = os.path.join(work_dir, "out-restrict.txt")
        if restrict_lexicon:
            # fast_align lex table
            ttable_path = os.path.join(work_dir, "ttable")
            generate_fast_align_lex(ttable_path)
            # Top-K lexicon
            lexicon_path = os.path.join(work_dir, "lexicon")
            params = "{} {}".format(sockeye.lexicon.__file__,
                                    _LEXICON_CREATE_PARAMS_COMMON.format(input=ttable_path,
                                                                         model=model_path,
                                                                         topk=20,
                                                                         lexicon=lexicon_path,
                                                                         quiet=quiet_arg))
            with patch.object(sys, "argv", params.split()):
                sockeye.lexicon.main()
            # Translate corpus with restrict-lexicon
            params = "{} {} {} {}".format(sockeye.translate.__file__,
                                          _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                          input=test_source_path,
                                                                          output=out_restrict_path,
                                                                          quiet=quiet_arg),
                                          translate_params,
                                          _TRANSLATE_PARAMS_RESTRICT.format(lexicon=lexicon_path, topk=1))

            if test_source_factor_paths is not None:
                params += _TRANSLATE_WITH_FACTORS_COMMON.format(input_factors=" ".join(test_source_factor_paths))

            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

        # test averaging
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=1,
                                                  strategy='best',
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params

        # test parameter extraction
        extract_params = _EXTRACT_PARAMS.format(output=os.path.join(model_path, "params.extracted"),
                                                input=model_path)
        with patch.object(sys, "argv", extract_params.split()):
            sockeye.extract_parameters.main()
        with np.load(os.path.join(model_path, "params.extracted.npz")) as data:
            assert "target_output_bias" in data

        # get best validation perplexity
        metrics = sockeye.utils.read_metrics_file(path=os.path.join(model_path, C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        with open(out_path, "r") as out:
            hypotheses = out.readlines()
        with open(test_target_path, "r") as ref:
            references = ref.readlines()
        assert len(hypotheses) == len(references)

        # compute metrics
        bleu = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)
        chrf = raw_corpus_chrf(hypotheses=hypotheses, references=references)

        bleu_restrict = None
        if restrict_lexicon:
            # NOTE(review): this scores the unrestricted hypotheses read from out_path, not the contents of
            # out_restrict_path, so the value always equals `bleu`. Left unchanged because callers may
            # compare it against thresholds tuned to the current behaviour -- confirm before fixing.
            bleu_restrict = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)

        # Run evaluate cli
        eval_params = "{} {} ".format(sockeye.evaluate.__file__,
                                      _EVAL_PARAMS_COMMON.format(hypotheses=out_path,
                                                                 references=test_target_path,
                                                                 metrics="bleu chrf rouge1",
                                                                 quiet=quiet_arg), )
        with patch.object(sys, "argv", eval_params.split()):
            sockeye.evaluate.main()

        return perplexity, bleu, bleu_restrict, chrf
示例#5
0
文件: common.py 项目: lagka/sockeye
def run_train_translate(train_params: str,
                        translate_params: str,
                        translate_params_equiv: Optional[str],
                        train_source_path: str,
                        train_target_path: str,
                        dev_source_path: str,
                        dev_target_path: str,
                        test_source_path: str,
                        test_target_path: str,
                        train_source_factor_paths: Optional[List[str]] = None,
                        dev_source_factor_paths: Optional[List[str]] = None,
                        test_source_factor_paths: Optional[List[str]] = None,
                        use_prepared_data: bool = False,
                        max_seq_len: int = 10,
                        restrict_lexicon: bool = False,
                        work_dir: Optional[str] = None,
                        seed: int = 13,
                        quiet: bool = False) -> Tuple[float, float, Optional[float], float]:
    """
    Train a model, translate the test set and report validation perplexity, BLEU and chrF.

    Besides training and translating, this helper also exercises the checkpoint decoder,
    checkpoint averaging, parameter extraction and the evaluate CLI. All CLIs are invoked
    in-process by patching ``sys.argv``. Every artifact is written to a temporary directory
    that is removed when the function returns.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for translation.
    :param translate_params_equiv: Second command line args for translation. Should produce the same outputs
    :param train_source_path: Path to the source file.
    :param train_target_path: Path to the target file.
    :param dev_source_path: Path to the development source file.
    :param dev_target_path: Path to the development target file.
    :param test_source_path: Path to the test source file.
    :param test_target_path: Path to the test target file.
    :param train_source_factor_paths: Optional list of paths to training source factor files.
    :param dev_source_factor_paths: Optional list of paths to dev source factor files.
    :param test_source_factor_paths: Optional list of paths to test source factor files.
    :param use_prepared_data: Whether to use the prepared data functionality.
    :param max_seq_len: The maximum sequence length.
    :param restrict_lexicon: Additional translation run with top-k lexicon-based vocabulary restriction.
    :param work_dir: Parent directory in which the temporary working directory is created.
    :param seed: The seed used for training (only passed on when not using prepared data).
    :param quiet: Suppress the console output of training and decoding.
    :return: A tuple (perplexity, bleu, bleu_restrict, chrf). ``bleu_restrict`` is the BLEU score of the
             lexicon-restricted translation, or None if ``restrict_lexicon`` is False.
    """
    quiet_arg = "--quiet" if quiet else ""

    with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as tmp_dir:
        model_path = os.path.join(tmp_dir, "model")

        if use_prepared_data:
            # Create the prepared data directory first, then train from it
            prepared_data_path = os.path.join(tmp_dir, "prepared_data")
            params = "{} {}".format(sockeye.prepare_data.__file__,
                                    _PREPARE_DATA_COMMON.format(train_source=train_source_path,
                                                                train_target=train_target_path,
                                                                output=prepared_data_path,
                                                                max_len=max_seq_len,
                                                                quiet=quiet_arg))
            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(source_factors=" ".join(train_source_factor_paths))

            logger.info("Creating prepared data folder.")
            with patch.object(sys, "argv", params.split()):
                sockeye.prepare_data.main()

            # Train model (training factors are already part of the prepared data)
            params = "{} {} {}".format(sockeye.train.__file__,
                                       _TRAIN_PARAMS_PREPARED_DATA_COMMON.format(prepared_data=prepared_data_path,
                                                                                 dev_source=dev_source_path,
                                                                                 dev_target=dev_target_path,
                                                                                 model=model_path,
                                                                                 max_len=max_seq_len,
                                                                                 quiet=quiet_arg),
                                       train_params)
        else:
            # Train model directly from the raw training files
            params = "{} {} {}".format(sockeye.train.__file__,
                                       _TRAIN_PARAMS_COMMON.format(train_source=train_source_path,
                                                                   train_target=train_target_path,
                                                                   dev_source=dev_source_path,
                                                                   dev_target=dev_target_path,
                                                                   model=model_path,
                                                                   max_len=max_seq_len,
                                                                   seed=seed,
                                                                   quiet=quiet_arg),
                                       train_params)
            if train_source_factor_paths is not None:
                params += _TRAIN_WITH_FACTORS_COMMON.format(source_factors=" ".join(train_source_factor_paths))

        if dev_source_factor_paths is not None:
            params += _DEV_WITH_FACTORS_COMMON.format(dev_source_factors=" ".join(dev_source_factor_paths))

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.train.main()

        # run checkpoint decoder on 1% of dev data
        with open(dev_source_path) as dev_fd:
            num_dev_sent = sum(1 for _ in dev_fd)
        # max() keeps the sample at one sentence or more when 1% rounds down to zero
        sample_size = max(1, int(num_dev_sent * 0.01))
        cp_decoder = sockeye.checkpoint_decoder.CheckpointDecoder(context=mx.cpu(),
                                                                  inputs=[dev_source_path],
                                                                  references=dev_target_path,
                                                                  model=model_path,
                                                                  sample_size=sample_size,
                                                                  batch_size=2,
                                                                  beam_size=2)
        cp_metrics = cp_decoder.decode_and_evaluate()
        logger.info("Checkpoint decoder metrics: %s", cp_metrics)

        # The same factor arguments are appended to every translate call below
        translate_factor_args = ""
        if test_source_factor_paths is not None:
            translate_factor_args = _TRANSLATE_WITH_FACTORS_COMMON.format(
                input_factors=" ".join(test_source_factor_paths))

        logger.info("Translating with parameters %s.", translate_params)
        # Translate corpus with the 1st params
        out_path = os.path.join(tmp_dir, "out.txt")
        params = "{} {} {}".format(sockeye.translate.__file__,
                                   _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                   input=test_source_path,
                                                                   output=out_path,
                                                                   quiet=quiet_arg),
                                   translate_params)
        params += translate_factor_args
        with patch.object(sys, "argv", params.split()):
            sockeye.translate.main()

        # Translate corpus with the 2nd params
        if translate_params_equiv is not None:
            out_path_equiv = os.path.join(tmp_dir, "out_equiv.txt")
            params = "{} {} {}".format(sockeye.translate.__file__,
                                       _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                       input=test_source_path,
                                                                       output=out_path_equiv,
                                                                       quiet=quiet_arg),
                                       translate_params_equiv)
            params += translate_factor_args
            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

            # read-in both outputs, ensure they are the same
            with open(out_path, 'rt') as f:
                lines = f.readlines()
            with open(out_path_equiv, 'rt') as f:
                lines_equiv = f.readlines()
            # Compare whole lists: a zip()-based check would silently accept a truncated output
            assert lines == lines_equiv

        # Test restrict-lexicon
        out_restrict_path = os.path.join(tmp_dir, "out-restrict.txt")
        if restrict_lexicon:
            # fast_align lex table
            ttable_path = os.path.join(tmp_dir, "ttable")
            generate_fast_align_lex(ttable_path)
            # Top-K lexicon
            lexicon_path = os.path.join(tmp_dir, "lexicon")
            params = "{} {}".format(sockeye.lexicon.__file__,
                                    _LEXICON_CREATE_PARAMS_COMMON.format(input=ttable_path,
                                                                         model=model_path,
                                                                         topk=20,
                                                                         lexicon=lexicon_path,
                                                                         quiet=quiet_arg))
            with patch.object(sys, "argv", params.split()):
                sockeye.lexicon.main()
            # Translate corpus with restrict-lexicon
            params = "{} {} {} {}".format(sockeye.translate.__file__,
                                          _TRANSLATE_PARAMS_COMMON.format(model=model_path,
                                                                          input=test_source_path,
                                                                          output=out_restrict_path,
                                                                          quiet=quiet_arg),
                                          translate_params,
                                          _TRANSLATE_PARAMS_RESTRICT.format(lexicon=lexicon_path, topk=1))
            params += translate_factor_args
            with patch.object(sys, "argv", params.split()):
                sockeye.translate.main()

        # test averaging
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=1,
                                                  strategy='best',
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params

        # test parameter extraction
        extract_params = _EXTRACT_PARAMS.format(output=os.path.join(model_path, "params.extracted"),
                                                input=model_path)
        with patch.object(sys, "argv", extract_params.split()):
            sockeye.extract_parameters.main()
        with np.load(os.path.join(model_path, "params.extracted.npz")) as data:
            assert "target_output_bias" in data

        # get best (lowest) validation perplexity over all entries of the metrics file
        metrics = sockeye.utils.read_metrics_file(path=os.path.join(model_path, C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        with open(out_path, "r") as out:
            hypotheses = out.readlines()
        with open(test_target_path, "r") as ref:
            references = ref.readlines()
        assert len(hypotheses) == len(references)

        # compute metrics
        bleu = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)
        chrf = raw_corpus_chrf(hypotheses=hypotheses, references=references)

        bleu_restrict = None
        if restrict_lexicon:
            # Score the lexicon-restricted output, not the unrestricted hypotheses
            with open(out_restrict_path, "r") as out_restrict:
                hypotheses_restrict = out_restrict.readlines()
            assert len(hypotheses_restrict) == len(references)
            bleu_restrict = raw_corpus_bleu(hypotheses=hypotheses_restrict, references=references, offset=0.01)

        # Run evaluate cli
        eval_params = "{} {} ".format(sockeye.evaluate.__file__,
                                      _EVAL_PARAMS_COMMON.format(hypotheses=out_path,
                                                                 references=test_target_path,
                                                                 metrics="bleu chrf rouge1",
                                                                 quiet=quiet_arg), )
        with patch.object(sys, "argv", eval_params.split()):
            sockeye.evaluate.main()

        return perplexity, bleu, bleu_restrict, chrf
示例#6
0
def run_train_captioning(train_params: str,
                         translate_params: str,
                         translate_params_equiv: Optional[str],
                         train_source_path: str,
                         train_target_path: str,
                         dev_source_path: str,
                         dev_target_path: str,
                         test_source_path: str,
                         test_target_path: str,
                         max_seq_len: int = 10,
                         work_dir: Optional[str] = None,
                         seed: int = 13,
                         quiet: bool = False) -> Tuple[float, float, float]:
    """
    Train an image captioning model, caption the test set and report validation perplexity, BLEU and chrF.

    The training and captioning CLIs are invoked in-process by patching ``sys.argv``. Model and
    outputs are written to a temporary directory that is removed when the function returns.

    :param train_params: Command line args for model training.
    :param translate_params: First command line args for captioning.
    :param translate_params_equiv: Second command line args for captioning. Should produce the same outputs
    :param train_source_path: Path to the source file.
    :param train_target_path: Path to the target file.
    :param dev_source_path: Path to the development source file.
    :param dev_target_path: Path to the development target file.
    :param test_source_path: Path to the test source file.
    :param test_target_path: Path to the test target file.
    :param max_seq_len: The maximum sequence length.
    :param work_dir: Parent directory for the temporary working directory. The same value is also passed
                     to the CLIs as source root and dev root.
    :param seed: The seed used for training.
    :param quiet: Suppress the console output of training and decoding.
    :return: A tuple containing perplexity, bleu score and chrf score.
    """
    # The caller-supplied directory doubles as the root passed to the train and captioner CLIs.
    # NOTE(review): if work_dir is None this formats as the literal string "None" - confirm callers always set it.
    source_root = work_dir
    quiet_arg = "--quiet" if quiet else ""

    with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as tmp_dir:
        # Train model
        model_path = os.path.join(tmp_dir, "model")
        params = "{} {} {}".format(sockeye.image_captioning.train.__file__,
                                   _CAPTION_TRAIN_PARAMS_COMMON.format(
                                       source_root=source_root,
                                       train_source=train_source_path,
                                       train_target=train_target_path,
                                       dev_root=source_root,
                                       dev_source=dev_source_path,
                                       dev_target=dev_target_path,
                                       model=model_path,
                                       max_len=max_seq_len,
                                       seed=seed,
                                       quiet=quiet_arg),
                                   train_params)

        logger.info("Starting training with parameters %s.", train_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.image_captioning.train.main()

        logger.info("Translating with parameters %s.", translate_params)
        # Caption corpus with the 1st params
        out_path = os.path.join(tmp_dir, "out.txt")
        params = "{} {} {}".format(sockeye.image_captioning.captioner.__file__,
                                   _CAPTIONER_PARAMS_COMMON.format(model=model_path,
                                                                   source_root=source_root,
                                                                   input=test_source_path,
                                                                   output=out_path,
                                                                   quiet=quiet_arg),
                                   translate_params)
        with patch.object(sys, "argv", params.split()):
            sockeye.image_captioning.captioner.main()

        # Caption corpus with the 2nd params
        if translate_params_equiv is not None:
            out_path_equiv = os.path.join(tmp_dir, "out_equiv.txt")
            params = "{} {} {}".format(sockeye.image_captioning.captioner.__file__,
                                       _CAPTIONER_PARAMS_COMMON.format(model=model_path,
                                                                       source_root=source_root,
                                                                       input=test_source_path,
                                                                       output=out_path_equiv,
                                                                       quiet=quiet_arg),
                                       translate_params_equiv)
            with patch.object(sys, "argv", params.split()):
                sockeye.image_captioning.captioner.main()

            # read-in both outputs, ensure they are the same
            with open(out_path, 'rt') as f:
                lines = f.readlines()
            with open(out_path_equiv, 'rt') as f:
                lines_equiv = f.readlines()
            # Compare whole lists: a zip()-based check would silently accept a truncated output
            assert lines == lines_equiv

        # test averaging
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=1,
                                                  strategy='best',
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params

        # get best (lowest) validation perplexity over all entries of the metrics file
        metrics = sockeye.utils.read_metrics_file(path=os.path.join(model_path, C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        # Context managers close the handles before the temporary directory is cleaned up
        with open(out_path, "r") as out:
            hypotheses = out.readlines()
        with open(test_target_path, "r") as ref:
            references = ref.readlines()
        assert len(hypotheses) == len(references)

        # compute metrics
        bleu = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)
        chrf = raw_corpus_chrf(hypotheses=hypotheses, references=references)

        return perplexity, bleu, chrf