示例#1
0
 def _finalize_config(self, config, training=True):
     """Run ``config`` through the finalization steps and return the result.

     The configuration is passed, in order, through
     ``utility.resolve_environment_variables``, ``self._upgrade_data_config``
     and ``utility.resolve_remote_files``; each step receives the output of
     the previous one.

     Args:
       config: The configuration to finalize.
       training: Forwarded as the ``training`` keyword to the first two steps.

     Returns:
       The value returned by ``utility.resolve_remote_files``.
     """
     with_env = utility.resolve_environment_variables(
         config, training=training)
     upgraded = self._upgrade_data_config(with_env, training=training)
     # Remote files are resolved against this instance's shared directory
     # and storage client.
     return utility.resolve_remote_files(
         upgraded, self._shared_dir, self._storage)
示例#2
0
def test_resolve_remote_files(tmpdir):
    """Exercise ``utility.resolve_remote_files`` with two local storages.

    The test asserts that:
      * values "a" and "b" come back unchanged;
      * values "c" and "d" are replaced by paths under the local directory,
        and a file exists at each of those paths;
      * a ``tmp`` directory exists under the local directory afterwards
        (checked for the "tmp:" entry "f").

    The returned values for "e" and "f" themselves are not asserted.
    """
    # One file on the "remote" side, plus an initially empty local directory.
    remote_file = tmpdir.join("remote").join("dir").join("a.txt")
    remote_file.write("toto", ensure=True)
    local_dir = tmpdir.join("local")
    local_dir.ensure_dir()

    # "tmp" is rooted at tmpdir itself, "tmp2" at its "remote" subdirectory.
    storages = {
        "tmp": {"type": "local", "basedir": str(tmpdir)},
        "tmp2": {"type": "local", "basedir": str(tmpdir.join("remote"))},
    }
    client = StorageClient(config=storages)

    resolved = utility.resolve_remote_files(
        {
            "a": "/home/ubuntu/a.txt",
            "b": "non_storage:b.txt",
            "c": "tmp:remote/dir/a.txt",
            "d": "tmp2:/dir/a.txt",
            "e": True,
            "f": "tmp:",
        },
        str(local_dir),
        client)

    expected_c = local_dir.join("tmp/remote/dir/a.txt")
    expected_d = local_dir.join("tmp2/dir/a.txt")
    expected_f = local_dir.join("tmp")

    # Entries asserted to pass through unchanged.
    untouched = (("a", "/home/ubuntu/a.txt"), ("b", "non_storage:b.txt"))
    for key, value in untouched:
        assert resolved[key] == value

    # Entries asserted to be rewritten to local paths.
    assert resolved["c"] == str(expected_c)
    assert resolved["d"] == str(expected_d)

    # What must exist on disk after resolution.
    assert expected_c.check(file=1)
    assert expected_d.check(file=1)
    assert expected_f.check(dir=1)
    def _get_vocabs_info(self,
                         config,
                         local_config,
                         model_config=None,
                         tokens_to_add=None,
                         keep_previous=False):
        """Collect the source and target vocabulary information.

        Args:
          config: Configuration; its 'vocabulary' and 'tokenization' entries
            are read.
          local_config: Counterpart of ``config`` from which the same two
            entries are read; its 'vocabulary' entry decides whether the
            vocabulary is joint.
          model_config: Optional configuration of a previous model. When
            truthy, it is converted with ``config_util.old_to_new_config``
            and its 'vocabulary' entry is passed to ``_get_vocab_info``.
          tokens_to_add: Optional mapping with 'source' and/or 'target'
            collections of tokens. With a joint vocabulary both sides
            receive the union of the two.
          keep_previous: Forwarded to ``_get_vocab_info``. When true and
            ``model_config`` is given, the model vocabulary configuration is
            also passed to ``bundle_dependencies``.

        Returns:
          A tuple ``(src_info, tgt_info, parent_dependencies)``.

        Raises:
          ValueError: If the joint/split setting of the model vocabulary
            differs from that of ``local_config``.

        Note:
          When ``config`` has a non-empty 'vocabulary' entry, the
          'tokenization' key is removed from both ``config`` and
          ``local_config`` before returning.
        """
        extra_tokens = {} if tokens_to_add is None else tokens_to_add

        vocab_cfg = config.get('vocabulary', {})
        vocab_local_cfg = local_config.get('vocabulary', {})
        # For compatibility with old configurations
        tok_cfg = config.get('tokenization', {})
        tok_local_cfg = local_config.get('tokenization', {})
        joint_vocab = is_joint_vocab(vocab_local_cfg)

        parent_dependencies = {}
        # Without a previous model there is no model vocabulary to forward.
        model_vocab_cfg = None
        model_vocab_local_cfg = None
        if model_config:
            upgraded_model = config_util.old_to_new_config(model_config)
            model_vocab_cfg = upgraded_model.get('vocabulary', {})
            model_vocab_local_cfg = utility.resolve_remote_files(
                utility.resolve_environment_variables(model_vocab_cfg),
                self._shared_dir, self._storage)
            model_joint_vocab = is_joint_vocab(model_vocab_local_cfg)
            if joint_vocab != model_joint_vocab:
                raise ValueError(
                    "Changing joint vocabularies to split vocabularies "
                    "(or vice-versa) is currently not supported.")
            if keep_previous:
                # Deep copies are handed over so the originals stay untouched.
                bundle_dependencies(parent_dependencies,
                                    copy.deepcopy(model_vocab_cfg),
                                    copy.deepcopy(model_vocab_local_cfg))

        src_tokens = extra_tokens.get('source') or []
        tgt_tokens = extra_tokens.get('target') or []
        if joint_vocab:
            # A joint vocabulary shares one token set between both sides.
            src_tokens = set(list(src_tokens) + list(tgt_tokens))
            tgt_tokens = src_tokens

        def _side_info(side, side_tokens):
            # Both sides use the same arguments apart from name and tokens.
            return self._get_vocab_info(
                side,
                vocab_cfg,
                vocab_local_cfg,
                tok_cfg,
                tok_local_cfg,
                model_config=model_vocab_cfg,
                model_local_config=model_vocab_local_cfg,
                tokens_to_add=side_tokens,
                keep_previous=keep_previous,
                joint_vocab=joint_vocab)

        src_info = _side_info('source', src_tokens)
        tgt_info = _side_info('target', tgt_tokens)

        if vocab_cfg:
            for cfg in (config, local_config):
                cfg.pop('tokenization', None)

        return src_info, tgt_info, parent_dependencies