Example #1
def test_config_serialization():
    c1 = ConfigTest(param=1, config=ConfigTest(param=2))
    expected_serialization = """!ConfigTest
config: !ConfigTest
  config: null
  param: 2
param: 1
"""
    with tempfile.TemporaryDirectory() as tmp_dir:
        fname = os.path.join(tmp_dir, "config")
        c1.save(fname)
        assert os.path.exists(fname)
        with open(fname) as f:
            assert f.read() == expected_serialization

        c2 = Config.load(fname)
        assert c2.param == c1.param
        assert c2.config.param == c1.config.param
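For reference, here is a minimal sketch of the ConfigTest class that the test above assumes: a Config subclass whose YAML tag and fields match the expected serialization. The exact definition lives in the project's test code; this is only an approximation.

# Sketch only: the shape of ConfigTest inferred from the "!ConfigTest" tag and
# the fields in expected_serialization above. The base Config class is assumed
# to provide the save()/load() YAML round-trip exercised by the test.
class ConfigTest(Config):
    yaml_tag = "!ConfigTest"

    def __init__(self, param: int, config: 'ConfigTest' = None) -> None:
        super().__init__()
        self.param = param
        self.config = config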
Example #2
def create_data_iters_and_vocab(args: argparse.Namespace,
                                shared_vocab: bool,
                                resume_training: bool,
                                output_folder: str) -> Tuple['data_io.BaseParallelSampleIter',
                                                             'data_io.BaseParallelSampleIter',
                                                             'data_io.DataConfig', Dict, Dict]:
    """
    Create the data iterators and the vocabularies.

    :param args: Arguments as returned by argparse.
    :param shared_vocab: Whether to create a shared vocabulary.
    :param resume_training: Whether to resume training.
    :param output_folder: Output folder.
    :return: The data iterators (train, validation, config_data) as well as the source and target vocabularies.
    """

    max_seq_len_source, max_seq_len_target = args.max_seq_len
    num_words_source, num_words_target = args.num_words
    word_min_count_source, word_min_count_target = args.word_min_count
    batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
    batch_by_words = args.batch_type == C.BATCH_TYPE_WORD

    either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \
                                       "with %s." % (C.TRAINING_ARG_SOURCE,
                                                     C.TRAINING_ARG_TARGET,
                                                     C.TRAINING_ARG_PREPARED_DATA)
    if args.prepared_data is not None:
        utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg)
        if not resume_training:
            utils.check_condition(args.source_vocab is None and args.target_vocab is None,
                                  "You are using a prepared data folder, which is tied to a vocabulary. "
                                  "To change it you need to rerun data preparation with a different vocabulary.")
        train_iter, validation_iter, data_config, vocab_source, vocab_target = data_io.get_prepared_data_iters(
            prepared_data_dir=args.prepared_data,
            validation_source=os.path.abspath(args.validation_source),
            validation_target=os.path.abspath(args.validation_target),
            shared_vocab=shared_vocab,
            batch_size=args.batch_size,
            batch_by_words=batch_by_words,
            batch_num_devices=batch_num_devices,
            fill_up=args.fill_up)
        if resume_training:
            # Resuming training: make sure the vocabularies in the model and in the prepared data match up.
            model_vocab_source = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_SRC_NAME + C.JSON_SUFFIX))
            model_vocab_target = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME + C.JSON_SUFFIX))
            utils.check_condition(vocab.are_identical(vocab_source, model_vocab_source),
                                  "Prepared data and resumed model source vocabs do not match.")
            utils.check_condition(vocab.are_identical(vocab_target, model_vocab_target),
                                  "Prepared data and resumed model target vocabs do not match.")

        return train_iter, validation_iter, data_config, vocab_source, vocab_target
    else:
        utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
                              either_raw_or_prepared_error_msg)

        if resume_training:
            # Load the existing vocab created when starting the training run.
            vocab_source = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_SRC_NAME + C.JSON_SUFFIX))
            vocab_target = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME + C.JSON_SUFFIX))

            # Recover the vocabulary path from the existing config file:
            orig_config = cast(model.ModelConfig, Config.load(os.path.join(output_folder, C.CONFIG_NAME)))
            vocab_source_path = orig_config.config_data.vocab_source
            vocab_target_path = orig_config.config_data.vocab_target
        else:
            # Load vocab:
            vocab_source_path = args.source_vocab
            vocab_target_path = args.target_vocab
            vocab_source, vocab_target = vocab.load_or_create_vocabs(source=args.source, target=args.target,
                                                                     source_vocab_path=vocab_source_path,
                                                                     target_vocab_path=vocab_target_path,
                                                                     shared_vocab=shared_vocab,
                                                                     num_words_source=num_words_source,
                                                                     num_words_target=num_words_target,
                                                                     word_min_count_source=word_min_count_source,
                                                                     word_min_count_target=word_min_count_target)

        train_iter, validation_iter, config_data = data_io.get_training_data_iters(
            source=os.path.abspath(args.source),
            target=os.path.abspath(args.target),
            validation_source=os.path.abspath(args.validation_source),
            validation_target=os.path.abspath(args.validation_target),
            vocab_source=vocab_source,
            vocab_target=vocab_target,
            vocab_source_path=vocab_source_path,
            vocab_target_path=vocab_target_path,
            shared_vocab=shared_vocab,
            batch_size=args.batch_size,
            batch_by_words=batch_by_words,
            batch_num_devices=batch_num_devices,
            fill_up=args.fill_up,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            bucketing=not args.no_bucketing,
            bucket_width=args.bucket_width)
        return train_iter, validation_iter, config_data, vocab_source, vocab_target
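A hypothetical call site for the function above, exercising the raw-corpus branch (args.prepared_data is None). The attribute names mirror what the function reads from args; the values are placeholders, not recommended settings.

# Illustrative only: an argparse.Namespace carrying the attributes accessed above.
raw_args = argparse.Namespace(
    source="train.src", target="train.trg",
    validation_source="dev.src", validation_target="dev.trg",
    prepared_data=None, source_vocab=None, target_vocab=None,
    max_seq_len=(100, 100), num_words=(50000, 50000), word_min_count=(1, 1),
    use_cpu=True, device_ids=[], batch_type=C.BATCH_TYPE_SENTENCE,
    batch_size=64, fill_up="replicate", no_bucketing=False, bucket_width=10)

train_iter, validation_iter, config_data, vocab_source, vocab_target = \
    create_data_iters_and_vocab(args=raw_args, shared_vocab=False,
                                resume_training=False, output_folder="model_dir")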
Example #3
def test_config_missing_attributes_filled_with_default():
    # When we load a configuration object that does not contain all of the attributes of the current version of
    # the configuration class, we expect the missing attributes to be filled with the default values of the dataclass.

    config_obj = Config.load("test/data/config_with_missing_attributes.yaml")
    assert config_obj.new_attribute == "new_attribute"
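The behaviour under test could be reproduced with a config class along these lines. The class and field names below are hypothetical; only new_attribute and its default value are taken from the assertion above.

# Hypothetical sketch: a dataclass-based Config subclass in which a field that
# is absent from an older YAML file is filled from its declared default.
from dataclasses import dataclass

@dataclass
class ConfigWithNewAttribute(Config):
    old_attribute: str = "old_attribute"
    new_attribute: str = "new_attribute"  # missing from the old YAML, so the default is used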
Example #4
File: train.py  Project: kyocen/sockeye
def create_data_iters_and_vocabs(args: argparse.Namespace,
                                 shared_vocab: bool,
                                 resume_training: bool,
                                 output_folder: str) -> Tuple['data_io.BaseParallelSampleIter',
                                                              'data_io.BaseParallelSampleIter',
                                                              'data_io.DataConfig',
                                                              List[vocab.Vocab], vocab.Vocab]:
    """
    Create the data iterators and the vocabularies.

    :param args: Arguments as returned by argparse.
    :param shared_vocab: Whether to create a shared vocabulary.
    :param resume_training: Whether to resume training.
    :param output_folder: Output folder.
    :return: The data iterators (train, validation, config_data) as well as the source and target vocabularies.
    """
    max_seq_len_source, max_seq_len_target = args.max_seq_len
    num_words_source, num_words_target = args.num_words
    word_min_count_source, word_min_count_target = args.word_min_count
    batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
    batch_by_words = args.batch_type == C.BATCH_TYPE_WORD

    validation_sources = [args.validation_source] + args.validation_source_factors
    validation_sources = [str(os.path.abspath(source)) for source in validation_sources]

    either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \
                                       "with %s." % (C.TRAINING_ARG_SOURCE,
                                                     C.TRAINING_ARG_TARGET,
                                                     C.TRAINING_ARG_PREPARED_DATA)
    if args.prepared_data is not None:
        utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg)
        if not resume_training:
            utils.check_condition(args.source_vocab is None and args.target_vocab is None,
                                  "You are using a prepared data folder, which is tied to a vocabulary. "
                                  "To change it you need to rerun data preparation with a different vocabulary.")
        train_iter, validation_iter, data_config, source_vocabs, target_vocab = data_io.get_prepared_data_iters(
            prepared_data_dir=args.prepared_data,
            validation_sources=validation_sources,
            validation_target=str(os.path.abspath(args.validation_target)),
            shared_vocab=shared_vocab,
            batch_size=args.batch_size,
            batch_by_words=batch_by_words,
            batch_num_devices=batch_num_devices,
            fill_up=args.fill_up)

        check_condition(len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                        "Data was prepared with %d source factors, but only provided %d source factor dimensions." % (
                            len(source_vocabs), len(args.source_factors_num_embed) + 1))

        if resume_training:
            # Resuming training: make sure the vocabularies in the model and in the prepared data match up.
            model_source_vocabs = vocab.load_source_vocabs(output_folder)
            for i, (v, mv) in enumerate(zip(source_vocabs, model_source_vocabs)):
                utils.check_condition(vocab.are_identical(v, mv),
                                      "Prepared data and resumed model source vocab %d do not match." % i)
            model_target_vocab = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME))
            utils.check_condition(vocab.are_identical(target_vocab, model_target_vocab),
                                  "Prepared data and resumed model target vocabs do not match.")

            check_condition(len(args.source_factors) == len(args.validation_source_factors),
                            'Training and validation data must have the same number of factors: %d vs. %d.' % (
                                len(args.source_factors), len(args.validation_source_factors)))

        return train_iter, validation_iter, data_config, source_vocabs, target_vocab

    else:
        utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
                              either_raw_or_prepared_error_msg)

        if resume_training:
            # Load the existing vocabs created when starting the training run.
            source_vocabs = vocab.load_source_vocabs(output_folder)
            target_vocab = vocab.vocab_from_json(os.path.join(output_folder, C.VOCAB_TRG_NAME))

            # Recover the vocabulary path from the data info file:
            data_info = cast(data_io.DataInfo, Config.load(os.path.join(output_folder, C.DATA_INFO)))
            source_vocab_paths = data_info.source_vocabs
            target_vocab_path = data_info.target_vocab

        else:
            # Load or create vocabs
            source_vocab_paths = [args.source_vocab] + [None] * len(args.source_factors)
            target_vocab_path = args.target_vocab
            source_vocabs, target_vocab = vocab.load_or_create_vocabs(
                source_paths=[args.source] + args.source_factors,
                target_path=args.target,
                source_vocab_paths=source_vocab_paths,
                target_vocab_path=target_vocab_path,
                shared_vocab=shared_vocab,
                num_words_source=num_words_source,
                num_words_target=num_words_target,
                word_min_count_source=word_min_count_source,
                word_min_count_target=word_min_count_target)

        check_condition(len(args.source_factors) == len(args.source_factors_num_embed),
                        "Number of source factor data (%d) differs from provided source factor dimensions (%d)" % (
                            len(args.source_factors), len(args.source_factors_num_embed)))

        sources = [args.source] + args.source_factors
        sources = [str(os.path.abspath(source)) for source in sources]

        train_iter, validation_iter, config_data, data_info = data_io.get_training_data_iters(
            sources=sources,
            target=os.path.abspath(args.target),
            validation_sources=validation_sources,
            validation_target=os.path.abspath(args.validation_target),
            source_vocabs=source_vocabs,
            target_vocab=target_vocab,
            source_vocab_paths=source_vocab_paths,
            target_vocab_path=target_vocab_path,
            shared_vocab=shared_vocab,
            batch_size=args.batch_size,
            batch_by_words=batch_by_words,
            batch_num_devices=batch_num_devices,
            fill_up=args.fill_up,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            bucketing=not args.no_bucketing,
            bucket_width=args.bucket_width)

        data_info_fname = os.path.join(output_folder, C.DATA_INFO)
        logger.info("Writing data config to '%s'", data_info_fname)
        data_info.save(data_info_fname)

        return train_iter, validation_iter, config_data, source_vocabs, target_vocab
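This variant differs from Example #2 mainly in its handling of source factors and in persisting a data info config. A hypothetical fragment of the extra arguments it reads is shown below; the file names and embedding sizes are placeholders, and the list lengths must be consistent, as enforced by the check_condition calls above.

# Illustrative only: the factor-related attributes this version additionally expects.
factor_args = argparse.Namespace(
    source_factors=["train.factor1", "train.factor2"],         # factor files parallel to args.source
    validation_source_factors=["dev.factor1", "dev.factor2"],  # same number of factors as training
    source_factors_num_embed=[8, 8])                           # one embedding size per factor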