Example #1
def load_checkpoint_model(model, args):
    """Load a model checkpoint."""

    iteration, release, success = get_checkpoint_iteration(args)

    if not success:
        return 0

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)

    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')

    if isinstance(model, torchDDP):
        model = model.module

    # Model.
    try:
        model.load_state_dict(sd['module'])
    except KeyError:
        print_rank_0('A metadata file exists but unable to load model '
                     'from checkpoint {}, exiting'.format(checkpoint_name))
        exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
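The checkpoint here is just a pickled dict whose 'module' entry holds the state dict. A minimal, self-contained sketch of that save/load round trip (the path and the tiny model are made up for illustration; the real name comes from get_checkpoint_name):

import torch

model = torch.nn.Linear(2, 2)
torch.save({'module': model.state_dict(), 'iteration': 100}, '/tmp/ckpt_demo.pt')

sd = torch.load('/tmp/ckpt_demo.pt', map_location='cpu')  # load onto CPU first
model.load_state_dict(sd['module'])
print(sd['iteration'])  # 100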
Example #2
def load_data(args, data_type, tokenizer, ratio=1):
    data_path = args.data_dir
    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Dataset
    filename = os.path.join(data_path, data_type + '.json')
    dataset = CHIDDataset(args, filename, data_type, tokenizer, ratio=ratio)

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       collate_fn=dataset.collate), dataset
def get_model(args, version=None):
    """Build the model."""
    
    print_rank_0('building Bert model ...')
    if version is None:
        model = BertMixtureModel(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      layernorm_epsilon=args.layernorm_epsilon,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      num_experts=args.num_experts,
                      type_vocab_size=2)
    elif version == "v0":
        model = BertMixtureModel_v0(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      layernorm_epsilon=args.layernorm_epsilon,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      num_experts=args.num_experts,
                      type_vocab_size=2)
    
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
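The parameter count printed above is a plain sum of nelement() over model.parameters(). A tiny stand-in check (a hypothetical module, not the Megatron model):

import torch

tiny = torch.nn.Linear(4, 3)  # 4*3 weights + 3 biases
print(sum(p.nelement() for p in tiny.parameters()))  # 15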
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=False)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    model = DDP(model)

    return model
Example #5
def build_data_loader(dataset,
                      batch_size,
                      num_workers,
                      drop_last,
                      shuffle=True,
                      only_rank0=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    if only_rank0:
        rank, world_size = 0, 1
    else:
        world_size = mpu.get_data_parallel_world_size()
        rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=my_collate)

    return data_loader
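When DistributedSampler is used with shuffle=True, callers usually also call set_epoch once per epoch so the shuffle order changes; otherwise every epoch replays the same permutation. A self-contained sketch (num_replicas and rank are hard-coded so it runs without torch.distributed being initialized):

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseed the shuffle for this epoch
    for (batch,) in loader:
        print(epoch, batch.tolist())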
Example #6
def load_pretrained(model, checkpoint_path, args, task_tokens=None):
    load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    if hasattr(model, "model"):
        model = model.model

    # Model.
    def extend_embedding_weights(state_weights, model_weights):
        original_length = state_weights.shape[0]
        assert original_length <= args.max_position_embeddings + 1
        new_weights = model_weights.clone()
        new_weights[:original_length] = state_weights
        return new_weights

    if args.block_lm:
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            position_weights = sd['module'][
                "transformer.position_embeddings.weight"]
            if args.max_position_embeddings + 1 > position_weights.shape[0]:
                sd['module'][
                    "transformer.position_embeddings.weight"] = extend_embedding_weights(
                        position_weights,
                        model.state_dict()
                        ["transformer.position_embeddings.weight"].data)
                print_rank_0(
                    f"Extend position embedding to {args.max_position_embeddings + 1}"
                )
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            block_position_weights = sd['module'][
                "transformer.block_position_embeddings.weight"]
            if args.max_position_embeddings + 1 > block_position_weights.shape[
                    0]:
                sd['module'][
                    "transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                        block_position_weights,
                        model.state_dict()
                        ["transformer.block_position_embeddings.weight"].data)
                print_rank_0(
                    f"Extend block position embedding to {args.max_position_embeddings + 1}"
                )
    missing_keys, unexpected_keys = model.load_state_dict(sd['module'],
                                                          strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(
            f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
    if args.continuous_prompt and args.prompt_init:
        model.prompt_spell.init_embedding(model.word_embeddings.weight.data,
                                          task_tokens)
Example #7
def save_checkpoint(iteration,
                    model,
                    optimizer,
                    lr_scheduler,
                    args,
                    tag=None,
                    barrier=True):
    """Save a model checkpoint."""
    if tag is None:
        tag = str(iteration)
    if args.deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to disk.
        if isinstance(model, torchDDP):
            model = model.module

        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print(
                'global rank {} is saving checkpoint at iteration {:7d} to {}'.
                format(torch.distributed.get_rank(), iteration,
                       checkpoint_name))

            sd = {}
            sd['iteration'] = iteration
            sd['module'] = model.state_dict()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker(
                ).get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print('  successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)
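save_checkpoint ends with global rank 0 writing the tag to a tracker file, which the loaders later read back to find the newest checkpoint. A sketch of that convention with a made-up filename (the real name comes from get_checkpoint_tracker_filename(args.save)):

tag = '1000'
tracker_filename = '/tmp/latest_checkpointed_iteration.txt'  # hypothetical path
with open(tracker_filename, 'w') as f:
    f.write(tag)
with open(tracker_filename) as f:
    print(f.read().strip())  # 1000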
Example #8
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
    if release:
        d = 'release'
    else:
        d = 'iter_{:07d}'.format(iteration)
    if zero:
        dp_rank = mpu.get_data_parallel_rank()
        d += '_zero_dp_rank_{}'.format(dp_rank)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
                        'model_optim_rng.pt')
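The directory layout this produces can be previewed standalone by substituting constants for the mpu rank calls (illustration only, not the repository function):

import os

def checkpoint_name_example(checkpoints_path, iteration, release=False,
                            zero=False, dp_rank=0, mp_rank=0):
    d = 'release' if release else 'iter_{:07d}'.format(iteration)
    if zero:
        d += '_zero_dp_rank_{}'.format(dp_rank)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(mp_rank),
                        'model_optim_rng.pt')

print(checkpoint_name_example('checkpoints', 1000))
# checkpoints/iter_0001000/mp_rank_00/model_optim_rng.pt
print(checkpoint_name_example('checkpoints', 1000, zero=True, dp_rank=2))
# checkpoints/iter_0001000_zero_dp_rank_2/mp_rank_00/model_optim_rng.pt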
Example #9
def test_broadcast_data(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print(
            '> testing broadcast_data with model parallel size {} ...'.format(
                model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    torch.manual_seed(1234 + mpu.get_data_parallel_rank())
    model_parallel_size = mpu.get_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12]
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if mpu.get_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
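The size/numel bookkeeping the test asserts can be checked standalone: the numel of each key is the product of its dims and total_numel is their sum.

import functools, operator

key_size_t = {'key1': [7, 11], 'key2': [8, 2, 1], 'key3': [13],
              'key4': [5, 1, 2], 'key5': [5, 12]}
numels = {k: functools.reduce(operator.mul, v, 1) for k, v in key_size_t.items()}
print(numels)                # {'key1': 77, 'key2': 16, 'key3': 13, 'key4': 10, 'key5': 60}
print(sum(numels.values()))  # 176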
Example #10
def make_gpt2_dataloaders(args):

    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    def make_data_loader_(data_path):
        # Build the dataset.
        dataset = GPT2Dataset(data_path, input_data_sizes_file, seq_length,
                              initial_seed)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = DistributedBatchSampler(sampler=sampler,
                                                batch_size=global_batch_size,
                                                drop_last=True,
                                                rank=rank,
                                                world_size=world_size)
        # Torch dataloader.
        return torch.utils.data.DataLoader(dataset,
                                           batch_sampler=batch_sampler,
                                           num_workers=num_workers,
                                           pin_memory=True)

    train = make_data_loader_(args.train_data_path)
    valid = make_data_loader_(args.val_data_path)
    test = make_data_loader_(args.test_data_path)

    args.do_train = False
    args.do_valid = False
    args.do_test = False

    if train is not None:
        args.do_train = True
    if valid is not None:
        args.do_valid = True
    if test is not None:
        args.do_test = True

    # Tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
    eod_token = tokenizer.encoder['<|endoftext|>']
    num_tokens = eod_token + 1

    return (train, valid, test), num_tokens, eod_token
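The eod-token arithmetic at the end assumes the standard GPT-2 vocabulary, in which '<|endoftext|>' is the last id. A quick check with the Hugging Face tokenizer (assumes the transformers package and downloads the stock 'gpt2' vocab on first use):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
eod_token = tokenizer.encoder['<|endoftext|>']
print(eod_token, eod_token + 1)  # 50256 50257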
Example #11
def get_model(args):
    """Build the model."""

    print_rank_0('building BERT model ...')
    model = BertModel(args)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)
        if args.fp32_embedding:
            model.module.model.bert.embeddings.word_embeddings.float()
            if args.ds_type == 'BERT':
                model.module.model.bert.embeddings.position_embeddings.float()
            else:
                model.module.model.bert.embeddings.token_position_embeddings.float()
                model.module.model.bert.embeddings.para_position_embeddings.float()
                model.module.model.bert.embeddings.sent_position_embeddings.float()
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for name, _module in model.named_modules():
                if 'LayerNorm' in name:
                    _module.float()

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
        model = args.DDP_type(model, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        args.DDP_type = LocalDDP
        model = args.DDP_type(model)
    else:
        print_rank_0('Unknown DDP implementation specified: {}. '
                     'Exiting.'.format(args.DDP_impl))
        exit()

    return model
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      max_memory_length=args.mem_length,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      relative_encoding=args.transformer_xl)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if not args.deepspeed:
        if USE_TORCH_DDP:
            i = torch.cuda.current_device()
            model = DDP(model,
                        device_ids=[i],
                        output_device=i,
                        process_group=mpu.get_data_parallel_group())
        else:
            model = DDP(model)

    return model
Example #13
def get_model(args):
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
        model = args.DDP_type(model,
                              device_ids=[i],
                              output_device=i,
                              process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        args.DDP_type = LocalDDP
        model = args.DDP_type(model)
    else:
        print_rank_0('Unknown DDP implementation specified: {}. '
                     'Exiting.'.format(args.DDP_impl))
        exit()

    return model
Example #14
def __init__(self, args, tokenizer, max_seq_length, bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
             gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15, average_block_length=3,
             max_block_length=40, block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
             short_seq_prob=0.0, single_span_prob=0.0, block_position_encoding=True, encoder_decoder=False,
             shuffle_blocks=True, sentinel_token=False, task_mask=False, random_position=False, masked_lm=False):
    self.eod_token = args.eod_token
    self.tokenizer = tokenizer
    self.count = 0
    self.max_seq_length = max_seq_length
    self.rank = mpu.get_data_parallel_rank()
    self.world_size = mpu.get_data_parallel_world_size()
    # self.rank = 0
    # self.world_size = 1
    assert 0.0 <= bert_prob <= 1.0
    self.bert_prob = bert_prob
    self.gap_sentence_prob = gap_sentence_prob
    self.gpt_prob = 1 - bert_prob - gap_sentence_prob
    assert self.gpt_prob >= -1e-10
    self.infill_prob = gpt_infill_prob
    self.gpt_min_ratio = gpt_min_ratio
    self.bert_ratio = bert_ratio
    self.gap_sentence_ratio = gap_sentence_ratio
    self.block_length_distribution = [poisson.pmf(i, average_block_length) for i in range(1, max_block_length)]
    self.block_mask_prob = block_mask_prob
    self.context_mask_ratio = context_mask_ratio
    self.context_mask_range = context_mask_range
    self.short_seq_prob = short_seq_prob
    self.single_span_prob = single_span_prob
    self.block_position_encoding = block_position_encoding
    self.encoder_decoder = encoder_decoder
    self.shuffle_blocks = shuffle_blocks
    self.sentinel_token = sentinel_token
    self.generation_mask = 'gMASK' if task_mask else 'MASK'
    self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id
    self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK'
    self.gap_sentence_mask = self.tokenizer.get_command(self.gap_sentence_mask).Id
    self.random_position = random_position
    self.masked_lm = masked_lm
    print_rank_0(
        f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}")
    print_rank_0(
        f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}")
    print_rank_0(f"block length distribution {self.block_length_distribution}")
    print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}")
Example #15
def get_model(args):
    """Build the model."""

    print_rank_0('building BERT model ...')
    model = BertModel(args)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)
        if args.fp32_embedding:
            model.module.model.bert.embeddings.word_embeddings.float()
            model.module.model.bert.embeddings.position_embeddings.float()
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for name, _module in model.named_modules():
                if 'LayerNorm' in name:
                    _module.float()

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model,
                    device_ids=[i],
                    output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
Example #16
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)


    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
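The rank arithmetic the test checks, written out standalone (assuming the usual Megatron-style grouping where consecutive global ranks form one model-parallel group):

world_size, model_parallel_size = 8, 2
for global_rank in range(world_size):
    mp_rank = global_rank % model_parallel_size   # model-parallel rank
    dp_rank = global_rank // model_parallel_size  # data-parallel rank
    print(global_rank, mp_rank, dp_rank)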
Example #17
def get_model(args, config, do_fp16=False):
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(**config,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if args.deepspeed and do_fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if do_fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model,
                    device_ids=[i],
                    output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
Example #18
def make_data_loader(dataset):
    """Buld dataloader given an input dataset."""
    if dataset is None:
        return None
    args = get_args()

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a simple sampler with distributed batch sampler.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
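These loaders all follow the same pattern: the batch sampler is built with the global batch size (per-GPU batch size times data-parallel world size), and DistributedBatchSampler is expected to hand each data-parallel rank its per-GPU shard of every global batch. A hypothetical sketch of that bookkeeping, not the class's actual implementation:

def shard_global_batch(global_batch, rank, world_size):
    per_rank = len(global_batch) // world_size
    start = rank * per_rank
    return global_batch[start:start + per_rank]

global_batch = list(range(32))  # batch_size=4 * world_size=8 sample indices
print(shard_global_batch(global_batch, rank=3, world_size=8))  # [12, 13, 14, 15]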
Example #19
def load_checkpoint(load_path, model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""

    iteration, release, success = get_checkpoint_iteration(load_path)

    if not success:
        return 0

    if args.deepspeed:

        checkpoint_name, sd = model.load_checkpoint(
            load_path,
            iteration,
            load_module_strict=False,
            load_optimizer_states=False,
            load_lr_scheduler_states=False)

        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return iteration

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_path, iteration, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:
            model.load_state_dict(sd['model'])
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
            exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
Example #20
def load_tnews_data(data_path, data_type, tokenizer, few_shot=False):
    args = get_args()

    filename = os.path.join(data_path, data_type+'.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    labels = []
    label_map = {}
    label_reverse = {}
    with open(os.path.join(data_path, 'labels.json')) as fin:
        for i, line in enumerate(fin):
            obj = json.loads(line.strip())
            labels.append(obj['label_desc'])
            label_map[obj['label_desc']] = i
            label_reverse[obj['label']] = obj['label_desc']

    all_tokens = []
    all_masks = []
    all_labels = []
    for _, obj in enumerate(objs):
        sentence = obj['sentence']
        tokenized_sentence = tokenizer.encode(sentence)[:args.seq_length-20]
        obj['label_desc'] = label_reverse[obj['label']]

        if few_shot:
            cur_labels = random.sample(labels, 3)
            while obj['label_desc'] in cur_labels:
                cur_labels = random.sample(labels, 3)
            cur_labels.append(obj['label_desc'])
            cur_label = cur_labels.index(obj['label_desc'])
            assert cur_label != -1
        else:
            cur_labels = labels
            cur_label = label_map[obj['label_desc']]

        all_labels.append(cur_label)

        for _, label in enumerate(cur_labels):
            prompt = "这是关于{}的文章:".format(label)
            prompt_tokens = tokenizer.encode(prompt)
            prompt_len = len(prompt_tokens)
            tokens = prompt_tokens + tokenized_sentence
            second_mask = [0] * (args.seq_length-1)
            for idx in range(prompt_len-1, len(tokens)-1):
                second_mask[idx] = 1
            all_masks.append(second_mask)
            token_length = len(tokens)
            assert token_length < args.seq_length
            tokens.extend([pad_id] * (args.seq_length - token_length))
            all_tokens.append(tokens)
    
    all_tokens = torch.tensor(all_tokens, dtype=torch.long)
    all_masks = torch.tensor(all_masks, dtype=torch.float)
    dataset = TensorDataset(all_tokens, all_masks)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True), all_labels
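The second_mask construction used here (and again in the OCNLI loader below) marks the positions whose next-token targets fall in the sentence that follows the prompt, so the loss ignores both the prompt and the padding. A standalone illustration with made-up lengths:

seq_length, prompt_len, total_len = 12, 4, 9  # hypothetical sizes
second_mask = [0] * (seq_length - 1)
for idx in range(prompt_len - 1, total_len - 1):
    second_mask[idx] = 1
print(second_mask)  # [0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0]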
Example #21
def get_model(args, model_type=None, multi_token=True, num_labels=None):
    """Build the model."""
    print_rank_0('building GLM model ...')

    output_predict, parallel_output = True, True
    if (model_type == "multiple_choice" or model_type == "classification") and not args.cloze_eval:
        output_predict = False
    if model_type is not None:
        parallel_output = False

    model = GLMModel(num_layers=args.num_layers,
                     vocab_size=args.vocab_size,
                     hidden_size=args.hidden_size,
                     num_attention_heads=args.num_attention_heads,
                     embedding_dropout_prob=args.hidden_dropout,
                     attention_dropout_prob=args.attention_dropout,
                     output_dropout_prob=args.hidden_dropout,
                     max_sequence_length=args.max_position_embeddings,
                     max_memory_length=args.mem_length,
                     checkpoint_activations=args.checkpoint_activations,
                     checkpoint_num_layers=args.checkpoint_num_layers,
                     parallel_output=parallel_output,
                     relative_encoding=args.transformer_xl,
                     block_position_encoding=args.block_lm and not args.masked_lm,
                     output_predict=output_predict)

    if model_type is not None:
        if model_type == 'cloze':
            if multi_token:
                if args.fast_decode:
                    model = GLMForMultiTokenClozeFast(model, length_penalty=args.length_penalty)
                else:
                    model = GLMForMultiTokenCloze(model, length_penalty=args.length_penalty)
            else:
                model = GLMForSingleTokenCloze(model)
        elif model_type == 'classification':
            model = GLMForSequenceClassification(model, args.hidden_size, args.output_dropout, args.pool_token,
                                                 num_class=num_labels)
        elif model_type == 'generation':
            pass
        else:
            raise NotImplementedError(model_type)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if not args.deepspeed:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = TorchDDP(model, device_ids=[i], output_device=i,
                             process_group=mpu.get_data_parallel_group())
        else:
            model = LocalDDP(model)

    return model
Example #22
def make_loaders(args, tokenizer):
    """makes training/val/test"""

    if args.use_tfrecords:
        return make_tfrecord_loaders(args)
    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    if args.loader_scatter is not None:
        assert world_size % args.loader_scatter == 0
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size
    seq_length = args.seq_length
    if seq_length < 0:
        seq_length = seq_length * world_size
    eval_seq_length = args.eval_seq_length
    if eval_seq_length is not None and eval_seq_length < 0:
        eval_seq_length = eval_seq_length * world_size
    split = get_split(args)
    data_set_args = {
        'path': args.train_data,
        'seq_length': seq_length,
        'mem_length': args.mem_length,
        'delim': args.delim,
        'text_key': args.text_key,
        'label_key': 'label',
        'ds_type': args.data_set_type,
        'split': split,
        'loose': args.loose_json,
        'max_preds_per_seq': args.max_preds_per_seq,
        'presplit_sentences': args.presplit_sentences,
        'sample_one_document': args.sample_one_document,
        'filter_english': args.filter_english,
        'pre_tokenize': not args.no_pre_tokenize,
        'tokenizer': tokenizer,
        'save_splits': args.save_splits,
        'load_splits': args.load_splits,
        'save_test_data': args.save_test_data,
        'no_lazy_loader': args.no_lazy_loader,
        'loader_scatter': args.loader_scatter,
        'data_parallel_rank': mpu.get_data_parallel_rank(),
        "non_sentence_start": args.non_sentence_start,
        "half_lazy_loader": args.half_lazy_loader
    }

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
    # if optional eval args were set then replace their
    # equivalent values in the arg dict
    if eval_seq_length:
        eval_set_args['seq_length'] = eval_seq_length
    if args.eval_max_preds_per_seq:
        eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
    if args.eval_text_key is not None:
        eval_set_args['text_key'] = args.eval_text_key

    # make datasets splits and tokenizer
    train, valid, test = None, None, None

    if args.train_data is not None:
        train = data_utils.make_dataset(**data_set_args)
        if data_utils.should_split(split):
            train, valid, test = train
        eval_set_args['tokenizer'] = tokenizer

    # make training and val dataset if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = data_utils.make_dataset(**eval_set_args)
        eval_set_args['tokenizer'] = tokenizer
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = data_utils.make_dataset(**eval_set_args)

    # wrap datasets with data loader
    use_block = args.block_lm or args.encoder_decoder

    if train is not None and args.batch_size > 0:
        train = make_data_loader(train,
                                 tokenizer,
                                 batch_size,
                                 args.train_iters,
                                 args,
                                 shuffle=args.shuffle,
                                 block_collate=use_block)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid,
                                 tokenizer,
                                 eval_batch_size,
                                 args.train_iters,
                                 args,
                                 shuffle=args.shuffle,
                                 block_collate=use_block)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test,
                                tokenizer,
                                eval_batch_size,
                                len(test) // eval_batch_size + 1,
                                args,
                                shuffle=args.shuffle,
                                block_collate=use_block)
        args.do_test = True
    else:
        args.do_test = False

    return train, valid, test
Example #23
def load_ocnli_data(data_path, data_type, tokenizer):
    args = get_args()

    filename = os.path.join(data_path, data_type+'.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    all_tokens_1 = []
    all_masks_1 = []
    all_tokens_2 = []
    all_masks_2 = []    
    all_tokens_3 = []
    all_masks_3 = [] 
    all_labels = []
    for obj in objs:

        if obj['label'] == '-':
            continue

        prompt = "{}?对,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length-1)
        for idx in range(prompt_len-1, len(tokens)-1):
            second_mask[idx] = 1
        all_masks_1.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_1.append(tokens)

        prompt = "{}?错,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length-1)
        for idx in range(prompt_len-1, len(tokens)-1):
            second_mask[idx] = 1
        all_masks_2.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_2.append(tokens)

        prompt = "{}?也许,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length-1)
        for idx in range(prompt_len-1, len(tokens)-1):
            second_mask[idx] = 1
        all_masks_3.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_3.append(tokens)

        if obj['label'] == 'entailment':
            all_labels.append([0])
        elif obj['label'] == 'contradiction':
            all_labels.append([1])
        else:
            all_labels.append([2])

    all_tokens_1 = torch.tensor(all_tokens_1, dtype=torch.long)
    all_masks_1 = torch.tensor(all_masks_1, dtype=torch.float)
    all_tokens_2 = torch.tensor(all_tokens_2, dtype=torch.long)
    all_masks_2 = torch.tensor(all_masks_2, dtype=torch.float)
    all_tokens_3 = torch.tensor(all_tokens_3, dtype=torch.long)
    all_masks_3 = torch.tensor(all_masks_3, dtype=torch.float)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    dataset = TensorDataset(all_tokens_1, all_masks_1, all_tokens_2, all_masks_2, all_tokens_3, all_masks_3, all_labels)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
Example #24
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() +
                         mpu.get_data_parallel_rank())
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    # print_rank_0("data iterator")
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)
    timers('batch generator').stop()

    # print_rank_0("get batch")

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = attention_mask.item() if torch.numel(
            attention_mask) == 1 else attention_mask[batch_id].item()
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(
                    tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(
                        tokenizer.DecodeIds(
                            tokens[batch_id,
                                   last_index:i].tolist()).encode('utf-8'),
                        "|",
                        tokenizer.DecodeIds(
                            labels[batch_id,
                                   last_index:i].tolist()).encode('utf-8'),
                        position_ids_[batch_id, last_index:i].tolist(),
                        block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(
                tokenizer.DecodeIds(
                    tokens[batch_id,
                           last_index:].tolist()).encode('utf-8'), "|",
                tokenizer.DecodeIds(
                    labels[batch_id, last_index:].tolist()).encode('utf-8'),
                position_ids_[batch_id, last_index:].tolist(),
                block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()

    return loss, mems, mode
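The loss at the end of forward_step is a masked mean over token positions; a tiny standalone check with made-up tensors:

import torch

losses = torch.tensor([[0.5, 1.0], [2.0, 4.0]])
loss_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]]).view(-1)
loss = torch.sum(losses.view(-1) * loss_mask)
if loss_mask.sum().item() > 0:
    loss = loss / loss_mask.sum()
print(loss.item())  # 6.5 / 3 = 2.1666...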
Example #25
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""
    if isinstance(model, torchDDP):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try: # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # Model.
    try:
        model.load_state_dict(sd['model'])
    except KeyError:
        print_rank_0('A metadata file exists but unable to load model '
                     'from checkpoint {}, exiting'.format(checkpoint_name))
        exit()

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(sd['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(sd['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer '
                         'state.'.format(checkpoint_name))
            exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))
            exit()

    #torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
Example #26
def load_checkpoint(model,
                    optimizer,
                    lr_scheduler,
                    args,
                    no_deepspeed=False,
                    no_load_optim=False):
    """Load a model checkpoint."""

    load_dir, tag, release, success = get_checkpoint_iteration(args.load)

    if not success:
        return 0

    if args.deepspeed and not no_deepspeed:

        checkpoint_name, sd = model.load_checkpoint(
            load_dir,
            tag,
            load_optimizer_states=not args.no_load_optim and not no_load_optim,
            load_lr_scheduler_states=not args.no_load_lr_scheduler)
        if not args.no_load_lr_scheduler and "client_lr_scheduler" in sd:
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        # Model.
        if args.deepspeed:
            model = model.module
        missing_keys, unexpected_keys = model.load_state_dict(sd['module'],
                                                              strict=False)
        if missing_keys or unexpected_keys:
            print_rank_0(
                f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}"
            )

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim and not no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, starting from iteration 0'.format(
                        checkpoint_name))
                iteration = 0

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))

    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
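Note: a standalone sketch of the CPU-side RNG round trip performed in the "rng states" blocks above, assuming only the standard random, NumPy, and PyTorch generators. The CUDA and mpu rng-tracker states are omitted so the snippet runs without a GPU or model-parallel setup.

import random
import numpy as np
import torch

saved = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
}

first_draw = (random.random(), np.random.rand(), torch.rand(1).item())

# Restoring the saved states reproduces exactly the same draws.
random.setstate(saved['random_rng_state'])
np.random.set_state(saved['np_rng_state'])
torch.set_rng_state(saved['torch_rng_state'])

second_draw = (random.random(), np.random.rand(), torch.rand(1).item())
assert first_draw == second_draw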
Example No. 27
0
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""

    load_dir, tag, release, success = get_checkpoint_iteration(args)

    if not success:
        return 0

    if args.deepspeed:

        checkpoint_name, sd = model.load_checkpoint(
            load_dir,
            tag,
            load_optimizer_states=not args.no_load_optim,
            load_lr_scheduler_states=not args.no_load_optim)
        if "client_lr_scheduler" in sd:
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag

    else:

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)

        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:

            def extend_embedding_weights(state_weights, model_weights):
                original_length = state_weights.shape[0]
                assert original_length <= args.max_position_embeddings + 1
                new_weights = model_weights.clone()
                new_weights[:original_length] = state_weights
                return new_weights

            if args.block_lm:
                if "transformer.block_position_embeddings.weight" in sd["module"]:
                    position_weights = sd['module']["transformer.position_embeddings.weight"]
                    if args.max_position_embeddings + 1 > position_weights.shape[0]:
                        sd['module']["transformer.position_embeddings.weight"] = \
                            extend_embedding_weights(
                                position_weights,
                                model.state_dict()["transformer.position_embeddings.weight"].data)
                        print_rank_0(
                            f"Extend position embedding to {args.max_position_embeddings + 1}")
                if "transformer.block_position_embeddings.weight" in sd["module"]:
                    block_position_weights = sd['module']["transformer.block_position_embeddings.weight"]
                    if args.max_position_embeddings + 1 > block_position_weights.shape[0]:
                        sd['module']["transformer.block_position_embeddings.weight"] = \
                            extend_embedding_weights(
                                block_position_weights,
                                model.state_dict()["transformer.block_position_embeddings.weight"].data)
                        print_rank_0(
                            f"Extend block position embedding to {args.max_position_embeddings + 1}")

            model.load_state_dict(sd['module'], strict=False)
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
            exit()

    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
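Note: a standalone sketch of the position-embedding extension performed by extend_embedding_weights above. The checkpoint weight is copied into the front of a larger, freshly initialized model weight; the sizes below are illustrative and not taken from any particular configuration.

import torch

def extend_embedding_weights(state_weights, model_weights):
    # Copy the shorter checkpoint embedding into the front of the larger one.
    original_length = state_weights.shape[0]
    assert original_length <= model_weights.shape[0]
    new_weights = model_weights.clone()
    new_weights[:original_length] = state_weights
    return new_weights

checkpoint_emb = torch.randn(512, 64)   # embedding table stored in the checkpoint
model_emb = torch.randn(1025, 64)       # freshly initialized, longer table
extended = extend_embedding_weights(checkpoint_emb, model_emb)
assert extended.shape == (1025, 64)
assert torch.equal(extended[:512], checkpoint_emb)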
Example No. 28
0
def get_model(args,
              model_type=None,
              multi_token=True,
              num_labels=None,
              spell_length=None):
    """Build the model."""
    print_rank_0('building GPT2 model ...')
    if args.pretrained_bert:
        if model_type == "multiple_choice":
            model = BertForMultipleChoice.from_pretrained(
                args.tokenizer_model_type,
                cache_dir=args.cache_dir,
                fp32_layernorm=args.fp32_layernorm,
                fp32_embedding=args.fp32_embedding,
                layernorm_epsilon=args.layernorm_epsilon)
        elif model_type == "classification":
            model = BertForSequenceClassification.from_pretrained(
                args.tokenizer_model_type,
                cache_dir=args.cache_dir,
                fp32_layernorm=args.fp32_layernorm,
                fp32_embedding=args.fp32_embedding,
                layernorm_epsilon=args.layernorm_epsilon,
                num_labels=num_labels)
        else:
            raise NotImplementedError
    else:
        output_predict, parallel_output = True, True
        if (model_type == "multiple_choice"
                or model_type == "classification") and not args.cloze_eval:
            output_predict = False
        if model_type is not None:
            parallel_output = False
        if spell_length is not None:
            print_rank_0(f"Continuous spell length {spell_length}")
        model = GLMModel(num_layers=args.num_layers,
                         vocab_size=args.vocab_size,
                         hidden_size=args.hidden_size,
                         num_attention_heads=args.num_attention_heads,
                         embedding_dropout_prob=args.hidden_dropout,
                         attention_dropout_prob=args.attention_dropout,
                         output_dropout_prob=args.hidden_dropout,
                         max_sequence_length=args.max_position_embeddings,
                         max_memory_length=args.mem_length,
                         checkpoint_activations=args.checkpoint_activations,
                         checkpoint_num_layers=args.checkpoint_num_layers,
                         parallel_output=parallel_output,
                         relative_encoding=args.transformer_xl,
                         block_position_encoding=args.block_lm
                         and not args.masked_lm,
                         output_predict=output_predict,
                         spell_length=spell_length,
                         spell_func=args.prompt_func,
                         attention_scale=args.attention_scale)
        if args.freeze_transformer:
            model.freeze_transformer(
                tune_prefix_layers=args.tune_prefix_layers)
        if model_type is not None:
            if model_type == 'multiple_choice':
                if args.cloze_eval:
                    if multi_token:
                        if args.fast_decode:
                            model = GLMForMultiTokenClozeFast(
                                model, length_penalty=args.length_penalty)
                        else:
                            model = GLMForMultiTokenCloze(
                                model, length_penalty=args.length_penalty)
                    else:
                        model = GLMForSingleTokenCloze(
                            model, take_softmax=args.adapet)
                else:
                    model = GLMForSequenceClassification(model,
                                                         args.hidden_size,
                                                         args.output_dropout,
                                                         args.pool_token,
                                                         num_class=num_labels)
            elif model_type == 'classification':
                model = GLMForSequenceClassification(model,
                                                     args.hidden_size,
                                                     args.output_dropout,
                                                     args.pool_token,
                                                     num_class=num_labels)
            elif model_type == 'generation':
                pass
            else:
                raise NotImplementedError(model_type)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])),
              flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if not args.deepspeed and (args.train_iters or args.epochs):
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = TorchDDP(model,
                             device_ids=[i],
                             output_device=i,
                             process_group=mpu.get_data_parallel_group())
        elif args.DDP_impl == 'local':
            model = LocalDDP(model)
        else:
            print_rank_0("Skip DDP model")
    return model
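Note: a standalone sketch of the parameter count and fp16 conversion done at the end of get_model, using a toy module on CPU so it runs without a GPU (the real code also calls model.cuda() and wraps the model with FP16_Module and DDP).

import torch

model = torch.nn.Linear(1024, 1024)
num_params = sum(p.nelement() for p in model.parameters())
print(' > number of parameters: {}'.format(num_params))  # 1024 * 1024 + 1024

model.half()  # convert weights to fp16 before (optionally) moving them to the GPU
assert next(model.parameters()).dtype == torch.float16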
Example No. 29
0
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""

    iteration, release, success = get_checkpoint_iteration(args)

    if not success:
        return 0

    if args.deepspeed:
        raise NotImplementedError("DeepSpeed is not installed")

    else:

        if args.load_openai:
            from utils import move_weights
            from model import DistributedDataParallel as DDP
            from fp16 import FP16_Module
            model_path = args.load
            from transformers import GPT2LMHeadModel
            print('global rank {} is loading openai weights {}'.format(
                torch.distributed.get_rank(), model_path))
            model.cpu()
            gpt2model = GPT2LMHeadModel.from_pretrained(
                model_path, cache_dir='gpt2_weights')
            model2fill = model
            while isinstance(model2fill, (DDP, FP16_Module)):
                model2fill = model2fill.module
            move_weights(model2fill, gpt2model)
            model.cuda(torch.cuda.current_device())
            sd = {}
        else:
            # Checkpoint.
            checkpoint_name = get_checkpoint_name(args.load, iteration,
                                                  release)

            if mpu.get_data_parallel_rank() == 0:
                print('global rank {} is loading checkpoint {}'.format(
                    torch.distributed.get_rank(), checkpoint_name))
            sd = torch.load(checkpoint_name, map_location='cpu')

            if isinstance(model, torchDDP):
                model = model.module

            # Model.
            try:
                model.load_state_dict(sd['model'])
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load model '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

            # Optimizer.
            if not release and not args.finetune and not args.no_load_optim:
                try:
                    if optimizer is not None:
                        optimizer.load_state_dict(sd['optimizer'])
                    if lr_scheduler is not None:
                        lr_scheduler.load_state_dict(sd['lr_scheduler'])
                except KeyError:
                    print_rank_0(
                        'Unable to load optimizer from checkpoint {}, exiting. '
                        'Specify --no-load-optim or --finetune to prevent '
                        'attempting to load the optimizer '
                        'state.'.format(checkpoint_name))
                    exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but unable to load iteration '
                    'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
            exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
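Note: a standalone sketch contrasting the KeyError-based loading above with the non-strict loading used in Example No. 26: with strict=False, missing and unexpected keys are reported instead of raising. The toy modules below are illustrative only.

import torch

src = torch.nn.Linear(8, 8)
dst = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))

# Build a partial state dict that only covers the first layer of dst.
partial_sd = {'0.' + key: value for key, value in src.state_dict().items()}
missing_keys, unexpected_keys = dst.load_state_dict(partial_sd, strict=False)
print('Missing keys {}, unexpected keys {}'.format(missing_keys, unexpected_keys))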
Example No. 30
0
def save_checkpoint(iteration,
                    model,
                    optimizer,
                    lr_scheduler,
                    args,
                    tag=None,
                    barrier=True,
                    only_changed_parameters=False,
                    no_deepspeed=False,
                    no_save_optim=False):
    """Save a model checkpoint."""
    if tag is None:
        tag = str(iteration)
    if args.deepspeed and not no_deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to disk.

        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print(
                'global rank {} is saving checkpoint at iteration {:7d} to {}'.
                format(torch.distributed.get_rank(), iteration,
                       checkpoint_name))
            sd = {'iteration': iteration}
            if args.deepspeed:
                model = model.module
            state_dict = model.state_dict()
            if only_changed_parameters:
                requires_grad_dict = {}
                for name, parameter in model.named_parameters():
                    requires_grad_dict[name] = parameter.requires_grad
                state_dict = {
                    key: value
                    for key, value in state_dict.items()
                    if requires_grad_dict[key]
                }
            sd['module'] = state_dict

            # Optimizer stuff.
            if not args.no_save_optim and not no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker(
                ).get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print('  successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)
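Note: a standalone sketch of the tracker-file round trip implied by the last lines of save_checkpoint: the latest tag is written to a small text file so a later load can discover which checkpoint to read. The 'latest_checkpointed_iteration.txt' filename is an assumption about what get_checkpoint_tracker_filename returns, not something shown in the code above.

import os
import tempfile

save_dir = tempfile.mkdtemp()
tracker_filename = os.path.join(save_dir, 'latest_checkpointed_iteration.txt')

tag = str(15000)
with open(tracker_filename, 'w') as f:
    f.write(tag)

with open(tracker_filename, 'r') as f:
    latest_tag = f.read().strip()
print('latest checkpoint tag:', latest_tag)  # '15000'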