Code Example #1
File: xl_wrapper.py  Project: gorkemgoknar/tr-gpt3
    @classmethod
    def from_pretrained(cls,
                        model_name_or_path=None,
                        seq_len=512,
                        weights_path=None,
                        deepspeed_config_path=None):
        # Build the torch.distributed init URL from MASTER_ADDR / MASTER_PORT,
        # falling back to localhost:6000 when the variables are not set.
        init_method = 'tcp://' + os.getenv('MASTER_ADDR',
                                           'localhost') + ':' + os.getenv(
                                               'MASTER_PORT', '6000')
        try:
            # Single-process NCCL process group with model parallelism of size 1.
            torch.distributed.init_process_group(backend='nccl',
                                                 world_size=1,
                                                 rank=0,
                                                 init_method=init_method)
            mpu.initialize_model_parallel(1)
        except RuntimeError:
            logger.info("The default process group has already initialized...")

        # Fix all RNGs (Python, NumPy, PyTorch, model-parallel CUDA) so that
        # generation is reproducible.
        seed = 1234
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        mpu.model_parallel_cuda_manual_seed(seed)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        logger.info("Check cached model files...")
        # Download the checkpoint and DeepSpeed config if no local paths were given.
        if weights_path is None:
            weights_path, deepspeed_config_path = download_model_files(
                model_name_or_path)
        model = setup_model(weights_path, deepspeed_config_path)
        model.cuda()
        model = model.eval()
        return cls(model,
                   tokenizer=tokenizer,
                   seq_len=seq_len,
                   model_path=model_name_or_path)
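
A minimal usage sketch of the classmethod above, assuming the enclosing wrapper class is importable as RuGPT3XL (the class name is not visible in the snippet) and using a placeholder model identifier; the distributed init reads MASTER_ADDR and MASTER_PORT from the environment, defaulting to localhost:6000:

import os

os.environ.setdefault("MASTER_ADDR", "localhost")  # used to build the tcp:// init URL
os.environ.setdefault("MASTER_PORT", "6000")

# Hypothetical names: the class and model id below are illustrative assumptions.
gpt = RuGPT3XL.from_pretrained(
    model_name_or_path="path-or-model-id",  # placeholder
    seq_len=512,
)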
Code Example #2
def set_random_seed(seed):
    """Set random seed for reproducibility."""
    # Only seed when an explicit positive seed is given; otherwise leave the
    # generators (Python, NumPy, PyTorch, model-parallel CUDA) untouched.
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        mpu.model_parallel_cuda_manual_seed(seed)
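
A short usage sketch for set_random_seed: a positive seed fixes the Python, NumPy, PyTorch, and model-parallel CUDA generators, while None or a non-positive value is a no-op:

set_random_seed(1234)  # seeds random, numpy, torch and mpu CUDA RNG state
set_random_seed(None)  # no-op: the guard skips seeding
set_random_seed(0)     # also a no-op, since the guard requires seed > 0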