Example #1
def prepare_task(
    config: PyTextConfig,
    dist_init_url: Optional[str] = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: Optional[CommonMetadata] = None,
) -> Task:

    if dist_init_url and world_size > 1:
        assert metadata is not None
        dist_init(rank, world_size, dist_init_url)

    print("\nParameters: {}\n".format(config))
    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16 and world_size == 1)
    if config.random_seed is not None:
        set_random_seeds(config.random_seed)

    if config.load_snapshot_path and os.path.isfile(config.load_snapshot_path):
        task = load(config.load_snapshot_path)
    else:
        task = create_task(config.task, metadata=metadata)

    for mc in metric_channels or []:
        task.metric_reporter.add_channel(mc)

    return task
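
A minimal usage sketch, assuming config is an already-built PyTextConfig and metadata a precomputed CommonMetadata; the distributed values shown are illustrative placeholders, not anything fixed by PyText.

# Single process: with the defaults rank=0 and world_size=1 the
# distributed branch is skipped and no process group is created.
task = prepare_task(config)

# Distributed: every worker calls prepare_task with its own rank and the
# shared rendezvous URL, so dist_init runs once per process.
task = prepare_task(
    config,
    dist_init_url="tcp://127.0.0.1:23456",  # illustrative placeholder
    device_id=0,
    rank=0,          # this worker's rank, 0 .. world_size - 1
    world_size=2,    # illustrative: two workers
    metadata=metadata,
)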
Example #2
def _set_distributed(
    rank: int, world_size: int, dist_init_url: str, device_id: int, gpu_streams: int = 1
) -> None:
    if dist_init_url and world_size > 1:
        distributed.dist_init(
            rank, world_size, dist_init_url, device_id, gpu_streams=gpu_streams
        )
Example #3
def _set_distributed(
    rank: int,
    world_size: int,
    dist_init_url: str,
    device_id: int,
) -> None:
    if dist_init_url and world_size > 1:
        distributed.dist_init(rank, world_size, dist_init_url, device_id)
Example #4
def _set_distributed(
    rank: int,
    world_size: int,
    dist_init_url: str,
    device_id: int,
    metadata: CommonMetadata,
) -> None:
    if dist_init_url and world_size > 1:
        assert metadata is not None
        distributed.dist_init(rank, world_size, dist_init_url, device_id)
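
Examples #2-#4 share the same guard: the process group is only initialized when a rendezvous URL is given and more than one worker participates. Below is a hedged sketch of how such a helper is typically driven, one process per GPU via torch.multiprocessing.spawn; the _worker function and the URL are illustrative assumptions, not part of PyText's API.

import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, dist_init_url: str) -> None:
    # Each spawned process initializes the group for its own rank; with an
    # empty URL or world_size == 1 the helper is a no-op.
    _set_distributed(rank, world_size, dist_init_url, device_id=rank)
    # ... build the task and run training on GPU `rank` ...


if __name__ == "__main__":
    world_size = 2  # illustrative: two GPUs
    mp.spawn(
        _worker,
        args=(world_size, "tcp://127.0.0.1:23456"),
        nprocs=world_size,
    )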
Example #5
    def train(
        self, config: PyTextConfig, rank: int = 0, world_size: int = 1, dist_init_url=""
    ):
        # TODO: move dist_init back to prepare_task in pytext/workflow.py
        # when processing time between dist_init and first loss.backward() is short
        if dist_init_url and world_size > 1:
            distributed.dist_init(rank, world_size, dist_init_url)

        return self.trainer.train(
            self.data.batches(Stage.TRAIN),
            self.data.batches(Stage.EVAL),
            self.model,
            self.metric_reporter,
            config,
            rank=rank,
        )
Example #6
    def train_single_model(
        self, train_config, model_id, rank=0, world_size=1, dist_init_url=""
    ):
        assert 0 <= model_id < len(self.model.models)

        train_iter = self.data_handler.get_train_iter(rank, world_size)
        eval_iter = self.data_handler.get_eval_iter()
        if dist_init_url and world_size > 1:
            distributed.dist_init(rank, world_size, dist_init_url)

        return self.trainer.train_single_model(
            train_iter,
            eval_iter,
            self.model.models[model_id],
            self.metric_reporter,
            train_config,
        )
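
A brief usage sketch, assuming task is an instance of the ensemble-style class above and train_config an already-built PyTextConfig; looping over the sub-models is an illustrative pattern, not something the snippet itself prescribes.

# Train each member of the ensemble in turn on a single process.
for model_id in range(len(task.model.models)):
    task.train_single_model(train_config, model_id)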
Example #7
    def train(self, train_config, rank=0, world_size=1, dist_init_url=""):
        """
        Wrapper method to train the model using :class:`~Trainer` object.

        Args:
            train_config (PyTextConfig): config for training
            rank (int): for distributed training only, rank of the gpu, default is 0
            world_size (int): for distributed training only, total gpu to use, default
                is 1
        """
        train_iter = self.data_handler.get_train_iter(rank, world_size)
        eval_iter = self.data_handler.get_eval_iter()
        if dist_init_url and world_size > 1:
            distributed.dist_init(rank, world_size, dist_init_url)

        result = self.trainer.train(
            train_iter,
            eval_iter,
            self.model,
            self.metric_reporter,
            train_config,
            rank=rank,
        )
        return result
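
A minimal usage sketch, assuming task is an instance of the class above and train_config an already-built PyTextConfig; the URL and world size are illustrative placeholders.

# Single process / single GPU: the distributed branch is never taken.
result = task.train(train_config)

# Distributed: each of the world_size worker processes calls train with its
# own rank and the same rendezvous URL, so dist_init runs in every process.
result = task.train(
    train_config,
    rank=0,                                 # this worker's rank
    world_size=2,                           # illustrative: two workers
    dist_init_url="tcp://127.0.0.1:23456",  # illustrative placeholder
)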