Exemplo n.º 1
0
    def from_pretrained(cls,
                        model_name_or_path=None,
                        seq_len=512,
                        weights_path=None,
                        deepspeed_config_path=None,
                        seed=1234):
        """Build an instance around a pretrained model and its tokenizer.

        If no default ``torch.distributed`` process group exists yet, create
        a single-process one (NCCL backend, ``world_size=1``, ``rank=0``) and
        initialize model parallelism with size 1. Then seed every RNG, load
        the tokenizer, and (when no weights path is given) resolve the model
        files. Finally build the model, move it to CUDA and switch it to eval
        mode.

        Args:
            model_name_or_path: Name or path passed to
                ``GPT2Tokenizer.from_pretrained`` and, when ``weights_path``
                is None, to ``download_model_files``. It is also stored as
                ``model_path`` on the returned instance.
            seq_len: Forwarded to the constructor as ``seq_len``.
            weights_path: Path to the model weights. When None, both this and
                ``deepspeed_config_path`` are taken from
                ``download_model_files(model_name_or_path)``.
            deepspeed_config_path: Config path passed to ``setup_model``.
                It is overwritten when ``weights_path`` is None.
            seed: Seed applied to ``random``, NumPy, torch and
                ``mpu.model_parallel_cuda_manual_seed``. Defaults to 1234,
                the value that was previously hard-coded.

        Returns:
            A ``cls`` instance wrapping the eval-mode model and tokenizer.
        """
        # Check explicitly instead of relying on the exception raised by a
        # second init_process_group call. That exception type differs across
        # PyTorch versions (RuntimeError in older releases, ValueError in
        # newer ones). Catching it also hid genuine init failures such as a
        # port that is already in use.
        if torch.distributed.is_initialized():
            logger.info("The default process group has already initialized...")
        else:
            # Environment variables override the local rendezvous defaults.
            init_method = 'tcp://{}:{}'.format(
                os.getenv('MASTER_ADDR', 'localhost'),
                os.getenv('MASTER_PORT', '6000'))
            torch.distributed.init_process_group(backend='nccl',
                                                 world_size=1,
                                                 rank=0,
                                                 init_method=init_method)
            mpu.initialize_model_parallel(1)

        # Seed every RNG source the pipeline may touch, for reproducibility.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        mpu.model_parallel_cuda_manual_seed(seed)

        tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
        logger.info("Check cached model files...")
        if weights_path is None:
            weights_path, deepspeed_config_path = download_model_files(
                model_name_or_path)
        model = setup_model(weights_path, deepspeed_config_path)
        model.cuda()
        model = model.eval()
        return cls(model,
                   tokenizer=tokenizer,
                   seq_len=seq_len,
                   model_path=model_name_or_path)
Exemplo n.º 2
0
def initialize_distributed(args):
    """Bind this process to a CUDA device and join the distributed job.

    Steps:
      1. Select the CUDA device. ``args.local_rank`` is used when it is not
         None; otherwise ``args.rank`` modulo the visible device count.
      2. Initialize the default ``torch.distributed`` process group over TCP,
         using ``args.distributed_backend``, ``args.world_size`` and
         ``args.rank``. The address comes from ``MASTER_ADDR`` (default
         ``localhost``) and the port from ``MASTER_PORT`` (default
         ``args.master_port``).
      3. Create the model-parallel / data-parallel groups via
         ``mpu.initialize_model_parallel(args.model_parallel_size)``.
      4. Enable DeepSpeed activation checkpointing when ``DEEPSPEED_WRAP``,
         ``args.deepspeed`` and ``args.deepspeed_activation_checkpointing``
         are all truthy.
    """
    # The rank-based fallback is computed unconditionally, even when
    # local_rank is provided.
    rank_based_device = args.rank % torch.cuda.device_count()
    chosen_device = (rank_based_device if args.local_rank is None
                     else args.local_rank)
    torch.cuda.set_device(chosen_device)

    # TCP rendezvous; environment variables take precedence over args.
    host = os.getenv('MASTER_ADDR', 'localhost')
    port = os.getenv('MASTER_PORT', str(args.master_port))
    rendezvous_url = 'tcp://{}:{}'.format(host, port)
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=rendezvous_url,
    )

    # Communicators for model parallelism and the complementary data
    # parallelism.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Optional DeepSpeed activation-checkpointing setup.
    use_activation_checkpointing = (
        DEEPSPEED_WRAP
        and args.deepspeed
        and args.deepspeed_activation_checkpointing
    )
    if use_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)