Example #1
def train_model(
    params: Params,
    serialization_dir: Union[str, PathLike],
    recover: bool = False,
    force: bool = False,
    node_rank: int = 0,
    include_package: Optional[List[str]] = None,
    dry_run: bool = False,
    file_friendly_logging: bool = False,
) -> Optional[Model]:
    """
    Trains the model specified in the given [`Params`](../common/params.md#params) object, using the data
    and training parameters also specified in that object, and saves the results in `serialization_dir`.

    # Parameters

    params : `Params`
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : `str`
        The directory in which to save results and logs.
    recover : `bool`, optional (default=`False`)
        If `True`, we will try to recover a training run from an existing serialization
        directory.  This is only intended for use when something actually crashed during the middle
        of a run.  For continuing training a model on new data, see `Model.from_archive`.
    force : `bool`, optional (default=`False`)
        If `True`, we will overwrite the serialization directory if it already exists.
    node_rank : `int`, optional (default=`0`)
        Rank of the current node in distributed training.
    include_package : `List[str]`, optional
        In distributed mode, extra packages mentioned will be imported in trainer workers.
    dry_run : `bool`, optional (default=`False`)
        Do not train a model, but create a vocabulary, show dataset statistics and other training
        information.
    file_friendly_logging : `bool`, optional (default=`False`)
        If `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.

    # Returns

    best_model : `Optional[Model]`
        The model with the best epoch weights or `None` if in dry run.
    """
    common_logging.FILE_FRIENDLY_LOGGING = file_friendly_logging

    training_util.create_serialization_dir(params, serialization_dir, recover, force)
    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    include_in_archive = params.pop("include_in_archive", None)
    verify_include_in_archive(include_in_archive)

    distributed_params = params.params.pop("distributed", None)
    # If distributed isn't in the config and the config contains strictly
    # one cuda device, we just run a single training process.
    if distributed_params is None:
        model = _train_worker(
            process_rank=0,
            params=params,
            serialization_dir=serialization_dir,
            include_package=include_package,
            dry_run=dry_run,
            file_friendly_logging=file_friendly_logging,
        )

        if not dry_run:
            archive_model(serialization_dir, include_in_archive=include_in_archive)
        return model

    # Otherwise, we are running multiple processes for training.
    else:
        common_logging.prepare_global_logging(
            serialization_dir,
            rank=0,
            world_size=1,
        )

        # We are careful here so that we can raise a good error if someone
        # passed the wrong thing - cuda_devices are required.
        device_ids = distributed_params.pop("cuda_devices", None)
        multi_device = isinstance(device_ids, list) and len(device_ids) > 1
        num_nodes = distributed_params.pop("num_nodes", 1)

        if not (multi_device or num_nodes > 1):
            raise ConfigurationError(
                "Multiple cuda devices/nodes need to be configured to run distributed training."
            )
        check_for_gpu(device_ids)

        master_addr = distributed_params.pop("master_address", "127.0.0.1")
        if master_addr in ("127.0.0.1", "0.0.0.0", "localhost"):
            # If running locally, we can automatically find an open port if one is not specified.
            master_port = (
                distributed_params.pop("master_port", None) or common_util.find_open_port()
            )
        else:
            # Otherwise we require that the port be specified.
            master_port = distributed_params.pop("master_port")

        num_procs = len(device_ids)
        world_size = num_nodes * num_procs

        # Creating `Vocabulary` objects from workers could be problematic since
        # the data loaders in each worker will yield only `rank` specific
        # instances. Hence it is safe to construct the vocabulary and write it
        # to disk before initializing the distributed context. The workers will
        # load the vocabulary from the path specified.
        vocab_dir = os.path.join(serialization_dir, "vocabulary")
        if recover:
            vocab = Vocabulary.from_files(vocab_dir)
        else:
            vocab = training_util.make_vocab_from_params(
                params.duplicate(), serialization_dir, print_statistics=dry_run
            )
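        # Point the worker configs at the vocabulary we just wrote to disk so
        # every worker loads the same vocabulary instead of rebuilding its own.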
        params["vocabulary"] = {
            "type": "from_files",
            "directory": vocab_dir,
            "padding_token": vocab._padding_token,
            "oov_token": vocab._oov_token,
        }

        logging.info(
            "Switching to distributed training mode since multiple GPUs are configured | "
            f"Master is at: {master_addr}:{master_port} | Rank of this node: {node_rank} | "
            f"Number of workers in this node: {num_procs} | Number of nodes: {num_nodes} | "
            f"World size: {world_size}"
        )

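        # `mp.spawn` launches `num_procs` worker processes; each one calls
        # `_train_worker` with its process rank prepended to `args`.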
        mp.spawn(
            _train_worker,
            args=(
                params.duplicate(),
                serialization_dir,
                include_package,
                dry_run,
                node_rank,
                master_addr,
                master_port,
                world_size,
                device_ids,
                file_friendly_logging,
                include_in_archive,
            ),
            nprocs=num_procs,
        )
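        # Back in the parent process: unless this was a dry run, archive the
        # results in the serialization directory and reload the trained model.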
        if dry_run:
            return None
        else:
            archive_model(serialization_dir, include_in_archive=include_in_archive)
            model = Model.load(params, serialization_dir)
            return model
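
A minimal sketch of how this function might be called; the config path and output directory below are hypothetical placeholders, and `Params.from_file` is the standard way to load an AllenNLP experiment config:

from allennlp.common import Params
from allennlp.commands.train import train_model

params = Params.from_file("experiment.jsonnet")  # hypothetical config file
model = train_model(
    params,
    serialization_dir="/tmp/example_run",  # hypothetical output directory
    force=True,  # overwrite the directory if it already exists
)
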
Example #2
    def __init__(
        self,
        serialization_dir: str,
        summary_interval: int = 100,
        distribution_interval: Optional[int] = None,
        batch_size_interval: Optional[int] = None,
        should_log_parameter_statistics: bool = True,
        should_log_learning_rate: bool = False,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        tags: Optional[Union[str, List[str]]] = None,
        watch_model: bool = True,
        files_to_save: List[str] = ["config.json", "out.log"],
        files_to_save_at_end: Optional[List[str]] = None,
        include_in_archive: Optional[List[str]] = None,
        save_model_archive: bool = True,
        wandb_kwargs: Optional[Dict[str, Any]] = None,
        finish_on_end: bool = False,
        sub_callbacks: Optional[List[AllennlpWandbSubCallback]] = None,
    ) -> None:
        logger.debug("Wandb related varaibles")
        logger.debug(
            "%s |   %s  |   %s",
            "variable".ljust(15),
            "value from env".ljust(50),
            "value in constructor".ljust(50),
        )

        for e, a in [("PROJECT", project), ("ENTITY", entity)]:
            logger.debug(
                "%s |   %s  |   %s",
                str(e).lower()[:15].ljust(15),
                str(read_from_env("WANDB_" + e))[:50].ljust(50),
                str(a)[:50].ljust(50),
            )
        logger.debug("All wandb related envirnment varaibles")
        logger.debug("%s |   %s  ", "ENV VAR.".ljust(15), "VALUE".ljust(50))

        for k, v in os.environ.items():
            if "WANDB" in k or "ALLENNLP" in k:
                logger.debug(
                    "%s |   %s  ",
                    str(k)[:15].ljust(15),
                    str(v)[:50].ljust(50),
                )
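        # Tags set through the environment take precedence over the constructor
        # argument; a comma-separated string is split into a list of tags.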
        t = read_from_env("WANDB_TAGS") or tags

        if isinstance(t, str):
            tags = t.split(",")
        else:
            tags = t
        super().__init__(
            serialization_dir,
            summary_interval=summary_interval,
            distribution_interval=distribution_interval,
            batch_size_interval=batch_size_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
            # prefer env variables: when set, they override the constructor arguments
            project=read_from_env("WANDB_PROJECT") or project,
            entity=read_from_env("WANDB_ENTITY") or entity,
            group=read_from_env("WANDB_GROUP") or group,
            name=read_from_env("WANDB_NAME") or name,
            notes=read_from_env("WANDB_NOTES") or notes,
            tags=tags,
            watch_model=watch_model,
            files_to_save=tuple(files_to_save),
            wandb_kwargs=wandb_kwargs,
        )
        self.finish_on_end = finish_on_end
        self._files_to_save_at_end = files_to_save_at_end or []
        self.include_in_archive = include_in_archive
        verify_include_in_archive(include_in_archive)
        self.save_model_archive = save_model_archive
        self.priority = 100
        self.sub_callbacks = sorted(
            sub_callbacks or [], key=lambda x: x.priority, reverse=True
        )

        if save_model_archive:
            self._files_to_save_at_end.append("model.tar.gz")
        # do not set wandb dir to be inside the serialization directory.

        if "dir" in self._wandb_kwargs:
            self._wandb_kwargs["dir"] = None

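        # Flatten a nested config dict into flat key/value pairs before it is
        # handed on to wandb.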
        if "config" in self._wandb_kwargs:
            self._wandb_kwargs["config"] = flatten_dict(
                self._wandb_kwargs["config"]
            )
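
Because each W&B setting is first read from the environment, the callback can be configured entirely through environment variables. A rough sketch, where `MyWandbCallback` is a hypothetical stand-in for whatever name this class actually has:

import os

os.environ["WANDB_PROJECT"] = "my-project"  # overrides the `project` constructor argument
os.environ["WANDB_TAGS"] = "baseline,bert"  # comma-separated string is split into ["baseline", "bert"]

callback = MyWandbCallback(  # hypothetical class name for the __init__ shown above
    serialization_dir="/tmp/example_run",  # hypothetical output directory
    save_model_archive=True,  # also saves model.tar.gz at the end of the run
)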