Example #1
def _get_config_dict(cls, path, **kwargs):
    # Classmethod-style helper: resolves the configuration file for `path`
    # (a local file, a local directory, or a model id on the Hugging Face Hub),
    # fetches it if necessary, and returns the parsed dict plus the kwargs it
    # did not consume. Relies on helpers from the surrounding transformers
    # module (is_offline_mode, is_remote_url, hf_bucket_url, cached_path,
    # CONFIG_NAME) and on the Hub error classes imported alongside it.
    local_files_only = kwargs.pop("local_files_only", False)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    user_agent = {
        "file_type": "config",
        "from_auto_class": kwargs.pop("_from_auto", False),
    }
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    path = str(path)
    if os.path.isfile(path) or is_remote_url(path):
        # `path` already points directly at a config file or URL.
        config_file = path
    else:
        configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
        if os.path.isdir(path):
            config_file = os.path.join(path, configuration_file)
        else:
            # Treat `path` as a model id and build the matching Hub URL.
            config_file = hf_bucket_url(
                path,
                filename=configuration_file,
                revision=kwargs.pop("revision", None),
                mirror=None,
            )

    try:
        # Download the file (or reuse the cached copy) and get its local path.
        resolved_config_file = cached_path(
            config_file,
            cache_dir=kwargs.pop("cache_dir", None),
            force_download=kwargs.pop("force_download", False),
            proxies=kwargs.pop("proxies", None),
            resume_download=kwargs.pop("resume_download", False),
            local_files_only=local_files_only,
            use_auth_token=kwargs.pop("use_auth_token", None),
            user_agent=user_agent,
        )
    except RepositoryNotFoundError as e:
        raise OSError(f"'{path}' is not a valid repository on the Hub.") from e
    except RevisionNotFoundError as e:
        raise OSError(f"The requested revision does not exist for '{path}'.") from e
    except EntryNotFoundError as e:
        raise OSError(f"No configuration file was found for '{path}'.") from e
    except HTTPError as e:
        raise OSError(f"Error while fetching {config_file}: {e}") from e
    except OSError:
        # Local file-system errors (e.g. a missing local file) are re-raised as-is.
        raise

    try:
        config_dict = cls._dict_from_json_file(resolved_config_file)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise OSError(f"'{resolved_config_file}' is not a valid JSON file.") from e

    if resolved_config_file == config_file:
        logger.info(f"loading {config_file}")
    else:
        logger.info(f"loading {config_file} from cache at {resolved_config_file}")

    return config_dict, kwargs
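
A minimal usage sketch for context, with hypothetical names: MyConfig stands in for a PretrainedConfig-like class that defines _get_config_dict as a classmethod; the model id and keyword arguments below are purely illustrative.

# Hypothetical usage (not part of the example above):
config_dict, unused_kwargs = MyConfig._get_config_dict(
    "bert-base-uncased",        # model id, local directory, or path to a config file
    cache_dir="/tmp/hf-cache",  # popped from kwargs and forwarded to cached_path
    revision="main",            # only consumed when the file is resolved via the Hub
)
print(config_dict.get("model_type"), unused_kwargs)
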
Example #2
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.6.0.dev0")

logger = logging.getLogger(__name__)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # Download the punkt tokenizer data once, under a file lock so that
    # concurrent (e.g. distributed) processes don't race to write the same files.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={