예제 #1
0
def prepare_task(
    config: PyTextConfig,
    dist_init_url: str = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: CommonMetadata = None,
) -> Task:
    """Apply the runtime settings from ``config`` and return a ready task.

    The task is restored with ``load`` when ``config.load_snapshot_path``
    names an existing file; otherwise it is built with ``create_task``.

    Args:
        config: Source of the CUDA/FP16 flags, the optional random seed,
            the snapshot path and the task definition.
        dist_init_url: Only consulted for the metadata sanity check below.
        device_id: Forwarded to ``_set_cuda``.
        rank: Accepted for signature compatibility; not read in this body.
        world_size: Forwarded to ``_set_cuda`` and used in the sanity check.
        metric_channels: Extra channels added to the task's metric reporter.
        metadata: Passed to ``create_task`` when no snapshot is loaded.

    Returns:
        The loaded or newly created task.
    """
    distributed_run = bool(dist_init_url) and world_size > 1
    if distributed_run:
        # A multi-worker run is expected to receive metadata from the caller.
        assert metadata is not None

    banner = "\nParameters: {}\n".format(config)
    print(banner)

    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16)

    seed = config.random_seed
    if seed is not None:
        set_random_seeds(seed)

    snapshot_path = config.load_snapshot_path
    if not (snapshot_path and os.path.isfile(snapshot_path)):
        # No usable snapshot on disk: build the task from its config.
        task = create_task(config.task, metadata=metadata)
    else:
        task = load(snapshot_path)

    if metric_channels:
        for channel in metric_channels:
            task.metric_reporter.add_channel(channel)

    return task
예제 #2
0
def prepare_task(
    config: PyTextConfig,
    dist_init_url: str = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: CommonMetadata = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    """Apply runtime settings from ``config`` and return the task and state.

    The task comes from a snapshot when ``config.load_snapshot_path`` is set
    (optionally replaced by the latest checkpoint when
    ``config.auto_resume_from_snapshot`` is on), otherwise from
    ``create_task``.

    Args:
        config: Source of runtime flags, seed, snapshot paths and the task
            definition. ``config.load_snapshot_path`` is overwritten in place
            when auto-resume finds a checkpoint.
        dist_init_url: Forwarded to ``_set_distributed``.
        device_id: Forwarded to ``_set_cuda`` and ``_set_distributed``.
        rank: Worker rank; only rank 0 prints the config. Also stored on the
            restored training state and passed to ``create_task``.
        world_size: Number of workers; a seed is mandatory when it exceeds 1.
        metric_channels: Extra channels added to the task's metric reporter.
        metadata: Passed to ``create_task`` when no snapshot is loaded.

    Returns:
        ``(task, training_state)``; ``training_state`` is ``None`` when the
        task was created rather than loaded.

    Raises:
        ValueError: If ``world_size > 1`` and ``config.random_seed`` is None.
    """
    seed = config.random_seed
    if seed is None and world_size > 1:
        raise ValueError(
            "Must set random seed when using world_size > 1, so that parameters have "
            "same initialization across workers."
        )

    is_primary_worker = rank == 0
    if is_primary_worker:
        print("\nParameters: {}\n".format(config), flush=True)

    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16, rank)
    _set_distributed(
        rank,
        world_size,
        dist_init_url,
        device_id,
        config.gpu_streams_for_distributed_training,
    )

    if seed is not None:
        set_random_seeds(seed, config.use_deterministic_cudnn)

    if config.auto_resume_from_snapshot:
        # Prefer the newest checkpoint next to the save path, if one exists.
        checkpoint_dir = os.path.dirname(config.save_snapshot_path)
        newest_checkpoint = get_latest_checkpoint_path(checkpoint_dir)
        if newest_checkpoint:
            config.load_snapshot_path = newest_checkpoint

    training_state = None
    snapshot_path = config.load_snapshot_path
    if not snapshot_path:
        task = create_task(
            config.task, metadata=metadata, rank=rank, world_size=world_size
        )
    else:
        assert PathManager.isfile(snapshot_path)
        # Pass the current config as an override unless the snapshot's own
        # config was requested.
        load_kwargs = (
            {} if config.use_config_from_snapshot else {"overwrite_config": config}
        )
        task, _, training_state = load(snapshot_path, **load_kwargs)
        if training_state:
            training_state.rank = rank

    if metric_channels:
        for channel in metric_channels:
            task.metric_reporter.add_channel(channel)

    return task, training_state
예제 #3
0
def prepare_task(
    config: PyTextConfig,
    dist_init_url: str = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: CommonMetadata = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    """Apply runtime settings from ``config`` and return the task and state.

    The task is restored with ``load`` when ``config.load_snapshot_path``
    names an existing file; otherwise it is built with ``create_task``.

    Args:
        config: Source of runtime flags, seed, snapshot path and the task
            definition.
        dist_init_url: Forwarded to ``_set_distributed``.
        device_id: Forwarded to ``_set_cuda`` and ``_set_distributed``.
        rank: Forwarded to ``_set_distributed`` and ``create_task``.
        world_size: Forwarded to ``_set_cuda``, ``_set_distributed`` and
            ``create_task``.
        metric_channels: Extra channels added to the task's metric reporter.
        metadata: Passed to ``create_task`` when no snapshot is loaded.

    Returns:
        ``(task, training_state)``; ``training_state`` is ``None`` when the
        task was created rather than loaded.
    """
    banner = "\nParameters: {}\n".format(config)
    print(banner)

    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16)
    _set_distributed(rank, world_size, dist_init_url, device_id)

    seed = config.random_seed
    if seed is not None:
        set_random_seeds(seed, config.use_deterministic_cudnn)

    training_state = None
    snapshot_path = config.load_snapshot_path
    if not (snapshot_path and os.path.isfile(snapshot_path)):
        # No usable snapshot on disk: build the task from its config.
        task = create_task(
            config.task, metadata=metadata, rank=rank, world_size=world_size
        )
    else:
        task, _, training_state = load(snapshot_path)
        # A restored state without a model borrows the one held by the task.
        state_lacks_model = (
            training_state and training_state.model is None and task.model
        )
        if state_lacks_model:
            training_state.model = task.model

    if metric_channels:
        for channel in metric_channels:
            task.metric_reporter.add_channel(channel)

    return task, training_state
예제 #4
0
    def test_set_random_seeds(self):
        """Check that seeding with 456 fixes the first draw of each generator.

        After ``set_random_seeds(456, False)`` the first value drawn from
        Python's ``random``, from ``np.random`` and from ``torch`` is compared
        against a hard-coded expectation. The draws must stay in this order,
        one per generator, for the expected values to hold.
        """
        # NOTE(review): the second positional argument is presumably
        # use_deterministic_cudnn -- confirm against set_random_seeds.
        set_random_seeds(456, False)

        self.assertEqual(random.randint(23, 57), 51)
        self.assertEqual(np.random.randint(93, 177), 120)
        # Compare the drawn number directly so a failure reports the actual
        # value instead of an opaque "False is not true".
        self.assertEqual(torch.randint(23, 57, (1,)).item(), 24)
예제 #5
0
 def __init__(self):
     """Call ``set_random_seeds`` with a fixed seed for reproducible tests."""
     fixed_seed = 0
     # A constant seed keeps results reproducible across test runs.
     set_random_seeds(seed=fixed_seed, use_deterministic_cudnn=True)
예제 #6
0
 def __init__(self):
     """Fix the random seed for reproducible tests and create the loss."""
     fixed_seed = 0
     # A constant seed keeps results reproducible across test runs.
     set_random_seeds(seed=fixed_seed, use_deterministic_cudnn=True)
     # Cross-entropy criterion stored on the instance.
     self.loss = torch.nn.CrossEntropyLoss()