def load_checkpoint_model(model, args):
    """Load a model checkpoint."""
    iteration, release, success = get_checkpoint_iteration(args)
    if not success:
        return 0

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')

    if isinstance(model, torchDDP):
        model = model.module

    # Model.
    try:
        model.load_state_dict(sd['module'])
    except KeyError:
        print_rank_0('A metadata file exists but unable to load model '
                     'from checkpoint {}, exiting'.format(checkpoint_name))
        exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_data(args, data_type, tokenizer, ratio=1):
    data_path = args.data_dir

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Dataset.
    filename = os.path.join(data_path, data_type + '.json')
    dataset = CHIDDataset(args, filename, data_type, tokenizer, ratio=ratio)

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler, batch_size=global_batch_size,
                                            drop_last=True, rank=rank, world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler,
                                       num_workers=num_workers, pin_memory=True,
                                       collate_fn=dataset.collate), dataset
def get_model(args, version=None):
    """Build the model."""
    print_rank_0('building Bert model ...')
    if version is None:
        model = BertMixtureModel(num_layers=args.num_layers,
                                 vocab_size=args.vocab_size,
                                 hidden_size=args.hidden_size,
                                 num_attention_heads=args.num_attention_heads,
                                 embedding_dropout_prob=args.hidden_dropout,
                                 attention_dropout_prob=args.attention_dropout,
                                 output_dropout_prob=args.hidden_dropout,
                                 layernorm_epsilon=args.layernorm_epsilon,
                                 max_sequence_length=args.max_position_embeddings,
                                 checkpoint_activations=args.checkpoint_activations,
                                 checkpoint_num_layers=args.checkpoint_num_layers,
                                 parallel_output=True,
                                 num_experts=args.num_experts,
                                 type_vocab_size=2)
    elif version == "v0":
        model = BertMixtureModel_v0(num_layers=args.num_layers,
                                    vocab_size=args.vocab_size,
                                    hidden_size=args.hidden_size,
                                    num_attention_heads=args.num_attention_heads,
                                    embedding_dropout_prob=args.hidden_dropout,
                                    attention_dropout_prob=args.attention_dropout,
                                    output_dropout_prob=args.hidden_dropout,
                                    layernorm_epsilon=args.layernorm_epsilon,
                                    max_sequence_length=args.max_position_embeddings,
                                    checkpoint_activations=args.checkpoint_activations,
                                    checkpoint_num_layers=args.checkpoint_num_layers,
                                    parallel_output=True,
                                    num_experts=args.num_experts,
                                    type_vocab_size=2)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision.
    if args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
def get_model(args):
    """Build the model."""
    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=False)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    model = DDP(model)

    return model
def build_data_loader(dataset, batch_size, num_workers, drop_last, shuffle=True, only_rank0=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Sampler.
    if only_rank0:
        rank, world_size = 0, 1
    else:
        world_size = mpu.get_data_parallel_world_size()
        rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              sampler=sampler, shuffle=False,
                                              num_workers=num_workers, drop_last=drop_last,
                                              pin_memory=True, collate_fn=my_collate)
    return data_loader
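# A self-contained sketch (not tied to this repo) of how the DistributedSampler used in
# build_data_loader above shards one dataset across data-parallel ranks; the toy dataset,
# the world size of 4, and the batch size of 2 are assumptions made only for illustration.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

_toy = TensorDataset(torch.arange(16))
_seen = set()
for _rank in range(4):
    _sampler = DistributedSampler(_toy, num_replicas=4, rank=_rank, shuffle=False)
    for (_batch,) in DataLoader(_toy, batch_size=2, sampler=_sampler):
        _seen.update(_batch.tolist())
# Together the 4 ranks cover all 16 samples exactly once per epoch.
assert _seen == set(range(16))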
def load_pretrained(model, checkpoint_path, args, task_tokens=None):
    load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    if hasattr(model, "model"):
        model = model.model

    # Model.
    def extend_embedding_weights(state_weights, model_weights):
        original_length = state_weights.shape[0]
        assert original_length <= args.max_position_embeddings + 1
        new_weights = model_weights.clone()
        new_weights[:original_length] = state_weights
        return new_weights

    if args.block_lm:
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            position_weights = sd['module']["transformer.position_embeddings.weight"]
            if args.max_position_embeddings + 1 > position_weights.shape[0]:
                sd['module']["transformer.position_embeddings.weight"] = extend_embedding_weights(
                    position_weights,
                    model.state_dict()["transformer.position_embeddings.weight"].data)
                print_rank_0(f"Extend position embedding to {args.max_position_embeddings + 1}")
        if "transformer.block_position_embeddings.weight" in sd["module"]:
            block_position_weights = sd['module']["transformer.block_position_embeddings.weight"]
            if args.max_position_embeddings + 1 > block_position_weights.shape[0]:
                sd['module']["transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                    block_position_weights,
                    model.state_dict()["transformer.block_position_embeddings.weight"].data)
                print_rank_0(f"Extend block position embedding to {args.max_position_embeddings + 1}")

    missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
    if args.continuous_prompt and args.prompt_init:
        model.prompt_spell.init_embedding(model.word_embeddings.weight.data, task_tokens)
def save_checkpoint(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True):
    """Save a model checkpoint."""
    if tag is None:
        tag = str(iteration)
    if args.deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to the disk.
        if isinstance(model, torchDDP):
            model = model.module
        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.format(
                torch.distributed.get_rank(), iteration, checkpoint_name))

            sd = {}
            sd['iteration'] = iteration
            sd['module'] = model.state_dict()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary).
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)
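# A minimal sketch of the tracker-file round trip between save_checkpoint above and the
# checkpoint loaders in this file: rank 0 writes the latest tag as plain text, and loading
# parses it back as either an iteration number or the literal string 'release'. The file
# name used here is only an assumption for illustration.
import os
import tempfile

with tempfile.TemporaryDirectory() as _dir:
    _tracker = os.path.join(_dir, 'latest_checkpointed_iteration.txt')  # assumed name
    with open(_tracker, 'w') as _f:
        _f.write('5000')
    with open(_tracker, 'r') as _f:
        _meta = _f.read().strip()
    _iteration, _release = (0, True) if _meta == 'release' else (int(_meta), False)
    assert _iteration == 5000 and not _release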
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
    if release:
        d = 'release'
    else:
        d = 'iter_{:07d}'.format(iteration)
    if zero:
        dp_rank = mpu.get_data_parallel_rank()
        d += '_zero_dp_rank_{}'.format(dp_rank)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
                        'model_optim_rng.pt')
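# Illustration of the directory layout produced by get_checkpoint_name above, mirrored
# here without the mpu dependency; iteration 5000 and model-parallel rank 0 are assumed
# values chosen only for the example.
import os

def _example_checkpoint_path(checkpoints_path='checkpoints', iteration=5000, mp_rank=0):
    # Same formatting as get_checkpoint_name: iter_<7 digits>/mp_rank_<2 digits>/...
    return os.path.join(checkpoints_path,
                        'iter_{:07d}'.format(iteration),
                        'mp_rank_{:02d}'.format(mp_rank),
                        'model_optim_rng.pt')

assert _example_checkpoint_path() == os.path.join(
    'checkpoints', 'iter_0005000', 'mp_rank_00', 'model_optim_rng.pt')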
def test_boradcast_data(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing boradcast_data with model parallel size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    torch.manual_seed(1234 + mpu.get_data_parallel_rank())
    model_parallel_size = mpu.get_model_parallel_world_size()

    key_size_t = {'key1': [7, 11],
                  'key2': [8, 2, 1],
                  'key3': [13],
                  'key4': [5, 1, 2],
                  'key5': [5, 12]}
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if mpu.get_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def make_gpt2_dataloaders(args):

    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    def make_data_loader_(data_path):
        # Build the dataset.
        dataset = GPT2Dataset(data_path, input_data_sizes_file, seq_length, initial_seed)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = DistributedBatchSampler(sampler=sampler, batch_size=global_batch_size,
                                                drop_last=True, rank=rank, world_size=world_size)
        # Torch dataloader.
        return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler,
                                           num_workers=num_workers, pin_memory=True)

    train = make_data_loader_(args.train_data_path)
    valid = make_data_loader_(args.val_data_path)
    test = make_data_loader_(args.test_data_path)

    args.do_train = False
    args.do_valid = False
    args.do_test = False
    if train is not None:
        args.do_train = True
    if valid is not None:
        args.do_valid = True
    if test is not None:
        args.do_test = True

    # Tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
    eod_token = tokenizer.encoder['<|endoftext|>']
    num_tokens = eod_token + 1

    return (train, valid, test), num_tokens, eod_token
def get_model(args):
    """Build the model."""
    print_rank_0('building BERT model ...')
    model = BertModel(args)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)
        if args.fp32_embedding:
            model.module.model.bert.embeddings.word_embeddings.float()
            if args.ds_type == 'BERT':
                model.module.model.bert.embeddings.position_embeddings.float()
            else:
                model.module.model.bert.embeddings.token_position_embeddings.float()
                model.module.model.bert.embeddings.para_position_embeddings.float()
                model.module.model.bert.embeddings.sent_position_embeddings.float()
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for name, _module in model.named_modules():
                if 'LayerNorm' in name:
                    _module.float()

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
        model = args.DDP_type(model, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        args.DDP_type = LocalDDP
        model = args.DDP_type(model)
    else:
        print_rank_0('Unknown DDP implementation specified: {}. '
                     'Exiting.'.format(args.DDP_impl))
        exit()

    return model
def get_model(args):
    """Build the model."""
    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      max_memory_length=args.mem_length,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      relative_encoding=args.transformer_xl)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision.
    if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if not args.deepspeed:
        if USE_TORCH_DDP:
            i = torch.cuda.current_device()
            model = DDP(model, device_ids=[i], output_device=i,
                        process_group=mpu.get_data_parallel_group())
        else:
            model = DDP(model)

    return model
def get_model(args):
    """Build the model."""
    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
        model = args.DDP_type(model, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
    elif args.DDP_impl == 'local':
        args.DDP_type = LocalDDP
        model = args.DDP_type(model)
    else:
        print_rank_0('Unknown DDP implementation specified: {}. '
                     'Exiting.'.format(args.DDP_impl))
        exit()

    return model
def __init__(self, args, tokenizer, max_seq_length,
             bert_prob=1.0, gap_sentence_prob=0.0, gpt_infill_prob=0.5,
             gpt_min_ratio=0.5, bert_ratio=0.15, gap_sentence_ratio=0.15,
             average_block_length=3, max_block_length=40,
             block_mask_prob=0.0, context_mask_ratio=0.0, context_mask_range=3,
             short_seq_prob=0.0, single_span_prob=0.0,
             block_position_encoding=True, encoder_decoder=False,
             shuffle_blocks=True, sentinel_token=False, task_mask=False,
             random_position=False, masked_lm=False):
    self.eod_token = args.eod_token
    self.tokenizer = tokenizer
    self.count = 0
    self.max_seq_length = max_seq_length
    self.rank = mpu.get_data_parallel_rank()
    self.world_size = mpu.get_data_parallel_world_size()
    # self.rank = 0
    # self.world_size = 1
    assert 0.0 <= bert_prob <= 1.0
    self.bert_prob = bert_prob
    self.gap_sentence_prob = gap_sentence_prob
    self.gpt_prob = 1 - bert_prob - gap_sentence_prob
    assert self.gpt_prob >= -1e-10
    self.infill_prob = gpt_infill_prob
    self.gpt_min_ratio = gpt_min_ratio
    self.bert_ratio = bert_ratio
    self.gap_sentence_ratio = gap_sentence_ratio
    self.block_length_distribution = [poisson.pmf(i, average_block_length)
                                      for i in range(1, max_block_length)]
    self.block_mask_prob = block_mask_prob
    self.context_mask_ratio = context_mask_ratio
    self.context_mask_range = context_mask_range
    self.short_seq_prob = short_seq_prob
    self.single_span_prob = single_span_prob
    self.block_position_encoding = block_position_encoding
    self.encoder_decoder = encoder_decoder
    self.shuffle_blocks = shuffle_blocks
    self.sentinel_token = sentinel_token
    self.generation_mask = 'gMASK' if task_mask else 'MASK'
    self.generation_mask = self.tokenizer.get_command(self.generation_mask).Id
    self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK'
    self.gap_sentence_mask = self.tokenizer.get_command(self.gap_sentence_mask).Id
    self.random_position = random_position
    self.masked_lm = masked_lm
    print_rank_0(
        f"BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, "
        f"GPT prob {self.gpt_prob}, infill prob {self.infill_prob}")
    print_rank_0(
        f"generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, "
        f"gap sent ratio {self.gap_sentence_ratio}")
    print_rank_0(f"block length distribution {self.block_length_distribution}")
    print_rank_0(f"block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}")
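# A small sketch (hypothetical parameters) of the block-length distribution built in the
# constructor above: unnormalized Poisson(average_block_length) probabilities for span
# lengths 1 .. max_block_length - 1, presumably renormalized wherever spans are sampled.
from scipy.stats import poisson

_average_block_length, _max_block_length = 3, 40
_dist = [poisson.pmf(i, _average_block_length) for i in range(1, _max_block_length)]
# Mass at length 0 is excluded, so the list sums to just under 1.
assert abs(sum(_dist) - (1.0 - poisson.pmf(0, _average_block_length))) < 1e-6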
def get_model(args):
    """Build the model."""
    print_rank_0('building BERT model ...')
    model = BertModel(args)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)
        if args.fp32_embedding:
            model.module.model.bert.embeddings.word_embeddings.float()
            model.module.model.bert.embeddings.position_embeddings.float()
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_tokentypes:
            model.module.model.bert.embeddings.token_type_embeddings.float()
        if args.fp32_layernorm:
            for name, _module in model.named_modules():
                if 'LayerNorm' in name:
                    _module.float()

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
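# Illustration (hypothetical sizes) of the rank layout the checks above rely on: with a
# model-parallel size of 2 over 8 GPUs, global rank r maps to model-parallel rank r % 2
# and data-parallel rank r // 2.
_model_parallel_size, _world_size = 2, 8
for _global_rank in range(_world_size):
    _mp_rank = _global_rank % _model_parallel_size
    _dp_rank = _global_rank // _model_parallel_size
    assert 0 <= _mp_rank < _model_parallel_size
    assert 0 <= _dp_rank < _world_size // _model_parallel_size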
def get_model(args, config, do_fp16=False):
    """Build the model."""
    print_rank_0('building GPT2 model ...')
    model = GPT2Model(**config,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision.
    if args.deepspeed and do_fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if do_fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
def make_data_loader(dataset):
    """Build dataloader given an input dataset."""
    if dataset is None:
        return None
    args = get_args()

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a simple sampler with distributed batch sampler.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler, batch_size=global_batch_size,
                                            drop_last=True, rank=rank, world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler,
                                       num_workers=num_workers, pin_memory=True)
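# A sketch of the batch-size arithmetic used by make_data_loader above: the batch sampler
# is built with the global batch size, and each data-parallel rank is expected to receive
# its own slice of args.batch_size samples per step. The concrete numbers are assumptions.
_per_gpu_batch_size = 8        # args.batch_size
_data_parallel_world_size = 4  # mpu.get_data_parallel_world_size()
_global_batch_size = _per_gpu_batch_size * _data_parallel_world_size
assert _global_batch_size == 32
assert _global_batch_size // _data_parallel_world_size == _per_gpu_batch_size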
def load_checkpoint(load_path, model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""
    iteration, release, success = get_checkpoint_iteration(load_path)
    if not success:
        return 0

    if args.deepspeed:
        checkpoint_name, sd = model.load_checkpoint(load_path, iteration,
                                                    load_module_strict=False,
                                                    load_optimizer_states=False,
                                                    load_lr_scheduler_states=False)
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return iteration
    else:
        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_path, iteration, release)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:
            model.load_state_dict(sd['model'])
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer '
                             'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))
            exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_tnews_data(data_path, data_type, tokenizer, few_shot=False):
    args = get_args()

    filename = os.path.join(data_path, data_type + '.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    labels = []
    label_map = {}
    label_reverse = {}
    with open(os.path.join(data_path, 'labels.json')) as fin:
        for i, line in enumerate(fin):
            obj = json.loads(line.strip())
            labels.append(obj['label_desc'])
            label_map[obj['label_desc']] = i
            label_reverse[obj['label']] = obj['label_desc']

    all_tokens = []
    all_masks = []
    all_labels = []
    for _, obj in enumerate(objs):
        sentence = obj['sentence']
        tokenized_sentence = tokenizer.encode(sentence)[:args.seq_length - 20]
        obj['label_desc'] = label_reverse[obj['label']]

        if few_shot:
            cur_labels = random.sample(labels, 3)
            while obj['label_desc'] in cur_labels:
                cur_labels = random.sample(labels, 3)
            cur_labels.append(obj['label_desc'])
            cur_label = cur_labels.index(obj['label_desc'])
            assert cur_label != -1
        else:
            cur_labels = labels
            cur_label = label_map[obj['label_desc']]
        all_labels.append(cur_label)

        for _, label in enumerate(cur_labels):
            prompt = "这是关于{}的文章:".format(label)
            prompt_tokens = tokenizer.encode(prompt)
            prompt_len = len(prompt_tokens)
            tokens = prompt_tokens + tokenized_sentence
            second_mask = [0] * (args.seq_length - 1)
            for idx in range(prompt_len - 1, len(tokens) - 1):
                second_mask[idx] = 1
            all_masks.append(second_mask)
            token_length = len(tokens)
            assert token_length < args.seq_length
            tokens.extend([pad_id] * (args.seq_length - token_length))
            all_tokens.append(tokens)

    all_tokens = torch.tensor(all_tokens, dtype=torch.long)
    all_masks = torch.tensor(all_masks, dtype=torch.float)
    dataset = TensorDataset(all_tokens, all_masks)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler, batch_size=global_batch_size,
                                            drop_last=True, rank=rank, world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler,
                                       num_workers=num_workers, pin_memory=True), all_labels
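# A toy illustration (assumed lengths) of the second_mask constructed in the loaders above:
# position idx is set to 1 exactly when the token at idx + 1 belongs to the continuation
# after the prompt, so a language-model loss using this mask scores only the continuation.
_seq_length = 10
_prompt_len, _total_len = 4, 7
_second_mask = [0] * (_seq_length - 1)
for _idx in range(_prompt_len - 1, _total_len - 1):
    _second_mask[_idx] = 1
assert _second_mask == [0, 0, 0, 1, 1, 1, 0, 0, 0]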
def get_model(args, model_type=None, multi_token=True, num_labels=None):
    """Build the model."""
    print_rank_0('building GLM model ...')
    output_predict, parallel_output = True, True
    if (model_type == "multiple_choice" or model_type == "classification") and not args.cloze_eval:
        output_predict = False
    if model_type is not None:
        parallel_output = False
    model = GLMModel(num_layers=args.num_layers,
                     vocab_size=args.vocab_size,
                     hidden_size=args.hidden_size,
                     num_attention_heads=args.num_attention_heads,
                     embedding_dropout_prob=args.hidden_dropout,
                     attention_dropout_prob=args.attention_dropout,
                     output_dropout_prob=args.hidden_dropout,
                     max_sequence_length=args.max_position_embeddings,
                     max_memory_length=args.mem_length,
                     checkpoint_activations=args.checkpoint_activations,
                     checkpoint_num_layers=args.checkpoint_num_layers,
                     parallel_output=parallel_output,
                     relative_encoding=args.transformer_xl,
                     block_position_encoding=args.block_lm and not args.masked_lm,
                     output_predict=output_predict)

    if model_type is not None:
        if model_type == 'cloze':
            if multi_token:
                if args.fast_decode:
                    model = GLMForMultiTokenClozeFast(model, length_penalty=args.length_penalty)
                else:
                    model = GLMForMultiTokenCloze(model, length_penalty=args.length_penalty)
            else:
                model = GLMForSingleTokenCloze(model)
        elif model_type == 'classification':
            model = GLMForSequenceClassification(model, args.hidden_size, args.output_dropout,
                                                 args.pool_token, num_class=num_labels)
        elif model_type == 'generation':
            pass
        else:
            raise NotImplementedError(model_type)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision.
    if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if not args.deepspeed:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = TorchDDP(model, device_ids=[i], output_device=i,
                             process_group=mpu.get_data_parallel_group())
        else:
            model = LocalDDP(model)

    return model
def make_loaders(args, tokenizer):
    """makes training/val/test dataloaders"""
    if args.use_tfrecords:
        return make_tfrecord_loaders(args)
    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    if args.loader_scatter is not None:
        assert world_size % args.loader_scatter == 0
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size
    seq_length = args.seq_length
    if seq_length < 0:
        seq_length = seq_length * world_size
    eval_seq_length = args.eval_seq_length
    if eval_seq_length is not None and eval_seq_length < 0:
        eval_seq_length = eval_seq_length * world_size
    split = get_split(args)
    data_set_args = {
        'path': args.train_data,
        'seq_length': seq_length,
        'mem_length': args.mem_length,
        'delim': args.delim,
        'text_key': args.text_key,
        'label_key': 'label',
        'ds_type': args.data_set_type,
        'split': split,
        'loose': args.loose_json,
        'max_preds_per_seq': args.max_preds_per_seq,
        'presplit_sentences': args.presplit_sentences,
        'sample_one_document': args.sample_one_document,
        'filter_english': args.filter_english,
        'pre_tokenize': not args.no_pre_tokenize,
        'tokenizer': tokenizer,
        'save_splits': args.save_splits,
        'load_splits': args.load_splits,
        'save_test_data': args.save_test_data,
        'no_lazy_loader': args.no_lazy_loader,
        'loader_scatter': args.loader_scatter,
        'data_parallel_rank': mpu.get_data_parallel_rank(),
        "non_sentence_start": args.non_sentence_start,
        "half_lazy_loader": args.half_lazy_loader
    }

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
    # if optional eval args were set then replace their
    # equivalent values in the arg dict
    if eval_seq_length:
        eval_set_args['seq_length'] = eval_seq_length
    if args.eval_max_preds_per_seq:
        eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
    if args.eval_text_key is not None:
        eval_set_args['text_key'] = args.eval_text_key

    # make dataset splits and tokenizer
    train, valid, test = None, None, None

    if args.train_data is not None:
        train = data_utils.make_dataset(**data_set_args)
        if data_utils.should_split(split):
            train, valid, test = train
        eval_set_args['tokenizer'] = tokenizer

    # make validation and test datasets if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = data_utils.make_dataset(**eval_set_args)
        eval_set_args['tokenizer'] = tokenizer
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = data_utils.make_dataset(**eval_set_args)

    # wrap datasets with data loader
    use_block = args.block_lm or args.encoder_decoder

    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, tokenizer, batch_size, args.train_iters, args,
                                 shuffle=args.shuffle, block_collate=use_block)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, tokenizer, eval_batch_size, args.train_iters, args,
                                 shuffle=args.shuffle, block_collate=use_block)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, tokenizer, eval_batch_size,
                                len(test) // eval_batch_size + 1, args,
                                shuffle=args.shuffle, block_collate=use_block)
        args.do_test = True
    else:
        args.do_test = False

    return train, valid, test
def load_ocnli_data(data_path, data_type, tokenizer):
    args = get_args()

    filename = os.path.join(data_path, data_type + '.json')
    objs = []
    with open(filename) as fin:
        for line in fin:
            objs.append(json.loads(line.strip()))

    pad_id = tokenizer.encoder['<pad>']
    args.eod_token = tokenizer.encoder['<eod>']

    all_tokens_1 = []
    all_masks_1 = []
    all_tokens_2 = []
    all_masks_2 = []
    all_tokens_3 = []
    all_masks_3 = []
    all_labels = []
    for obj in objs:
        if obj['label'] == '-':
            continue

        prompt = "{}?对,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length - 1)
        for idx in range(prompt_len - 1, len(tokens) - 1):
            second_mask[idx] = 1
        all_masks_1.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_1.append(tokens)

        prompt = "{}?错,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length - 1)
        for idx in range(prompt_len - 1, len(tokens) - 1):
            second_mask[idx] = 1
        all_masks_2.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_2.append(tokens)

        prompt = "{}?也许,".format(obj['sentence1'])
        prompt_tokens = tokenizer.encode(prompt)
        prompt_len = len(prompt_tokens)
        tokens = prompt_tokens + tokenizer.encode(obj['sentence2'])
        second_mask = [0] * (args.seq_length - 1)
        for idx in range(prompt_len - 1, len(tokens) - 1):
            second_mask[idx] = 1
        all_masks_3.append(second_mask)
        token_length = len(tokens)
        assert token_length < args.seq_length
        tokens.extend([pad_id] * (args.seq_length - token_length))
        all_tokens_3.append(tokens)

        if obj['label'] == 'entailment':
            all_labels.append([0])
        elif obj['label'] == 'contradiction':
            all_labels.append([1])
        else:
            all_labels.append([2])

    all_tokens_1 = torch.tensor(all_tokens_1, dtype=torch.long)
    all_masks_1 = torch.tensor(all_masks_1, dtype=torch.float)
    all_tokens_2 = torch.tensor(all_tokens_2, dtype=torch.long)
    all_masks_2 = torch.tensor(all_masks_2, dtype=torch.float)
    all_tokens_3 = torch.tensor(all_tokens_3, dtype=torch.long)
    all_masks_3 = torch.tensor(all_masks_3, dtype=torch.float)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    dataset = TensorDataset(all_tokens_1, all_masks_1, all_tokens_2, all_masks_2,
                            all_tokens_3, all_masks_3, all_labels)

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Use a random sampler with distributed batch sampler.
    if data_type == 'train':
        sampler = RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler=sampler, batch_size=global_batch_size,
                                            drop_last=True, rank=rank, world_size=world_size)

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler,
                                       num_workers=num_workers, pin_memory=True)
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() +
                         mpu.get_data_parallel_rank())
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    # print_rank_0("data iterator")
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)
    timers('batch generator').stop()
    # print_rank_0("get batch")

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = (attention_mask.item() if torch.numel(attention_mask) == 1
               else attention_mask[batch_id].item())
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(tokenizer.DecodeIds(tokens[batch_id, last_index:i].tolist()).encode('utf-8'), "|",
                          tokenizer.DecodeIds(labels[batch_id, last_index:i].tolist()).encode('utf-8'),
                          position_ids_[batch_id, last_index:i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(tokenizer.DecodeIds(tokens[batch_id, last_index:].tolist()).encode('utf-8'), "|",
                  tokenizer.DecodeIds(labels[batch_id, last_index:].tolist()).encode('utf-8'),
                  position_ids_[batch_id, last_index:].tolist(),
                  block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()

    return loss, mems, mode
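# Minimal sketch (toy tensors) of the masked loss reduction at the end of forward_step
# above: per-token losses are zeroed where the mask is 0 and averaged over the number of
# unmasked tokens.
import torch

_losses = torch.tensor([0.5, 1.0, 2.0, 4.0])
_loss_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
_loss = torch.sum(_losses * _loss_mask) / _loss_mask.sum()
assert torch.isclose(_loss, torch.tensor((0.5 + 2.0 + 4.0) / 3))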
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""
    if isinstance(model, torchDDP):
        model = model.module

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return 0
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                exit()
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    sd = torch.load(checkpoint_name, map_location='cpu')

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # Model.
    try:
        model.load_state_dict(sd['model'])
    except KeyError:
        print_rank_0('A metadata file exists but unable to load model '
                     'from checkpoint {}, exiting'.format(checkpoint_name))
        exit()

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(sd['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(sd['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer '
                         'state.'.format(checkpoint_name))
            exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))
            exit()

    # torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_checkpoint(model, optimizer, lr_scheduler, args, no_deepspeed=False, no_load_optim=False):
    """Load a model checkpoint."""
    load_dir, tag, release, success = get_checkpoint_iteration(args.load)
    if not success:
        return 0

    if args.deepspeed and not no_deepspeed:
        checkpoint_name, sd = model.load_checkpoint(
            load_dir, tag,
            load_optimizer_states=not args.no_load_optim and not no_load_optim,
            load_lr_scheduler_states=not args.no_load_lr_scheduler)
        if not args.no_load_lr_scheduler and "client_lr_scheduler" in sd:
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag
    else:
        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        # Model.
        if args.deepspeed:
            model = model.module
        missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
        if missing_keys or unexpected_keys:
            print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim and not no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer '
                             'state.'.format(checkpoint_name))

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, starting from iteration 0'.format(
                                 checkpoint_name))
                iteration = 0

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""
    load_dir, tag, release, success = get_checkpoint_iteration(args)
    if not success:
        return 0

    if args.deepspeed:
        checkpoint_name, sd = model.load_checkpoint(
            load_dir, tag,
            load_optimizer_states=not args.no_load_optim,
            load_lr_scheduler_states=not args.no_load_optim)
        if "client_lr_scheduler" in sd:
            lr_scheduler.load_state_dict(sd["client_lr_scheduler"])
            print_rank_0("Load lr scheduler state")
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")
            return tag
    else:
        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')

        if isinstance(model, torchDDP):
            model = model.module

        # Model.
        try:
            def extend_embedding_weights(state_weights, model_weights):
                original_length = state_weights.shape[0]
                assert original_length <= args.max_position_embeddings + 1
                new_weights = model_weights.clone()
                new_weights[:original_length] = state_weights
                return new_weights

            if args.block_lm:
                if "transformer.block_position_embeddings.weight" in sd["module"]:
                    position_weights = sd['module']["transformer.position_embeddings.weight"]
                    if args.max_position_embeddings + 1 > position_weights.shape[0]:
                        sd['module']["transformer.position_embeddings.weight"] = extend_embedding_weights(
                            position_weights,
                            model.state_dict()["transformer.position_embeddings.weight"].data)
                        print_rank_0(f"Extend position embedding to {args.max_position_embeddings + 1}")
                if "transformer.block_position_embeddings.weight" in sd["module"]:
                    block_position_weights = sd['module']["transformer.block_position_embeddings.weight"]
                    if args.max_position_embeddings + 1 > block_position_weights.shape[0]:
                        sd['module']["transformer.block_position_embeddings.weight"] = extend_embedding_weights(
                            block_position_weights,
                            model.state_dict()["transformer.block_position_embeddings.weight"].data)
                        print_rank_0(f"Extend block position embedding to {args.max_position_embeddings + 1}")
            model.load_state_dict(sd['module'], strict=False)
        except KeyError:
            print_rank_0('A metadata file exists but unable to load model '
                         'from checkpoint {}, exiting'.format(checkpoint_name))
            exit()

        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer '
                             'state.'.format(checkpoint_name))
                exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random '
                         'state.'.format(checkpoint_name))
            exit()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration
def get_model(args, model_type=None, multi_token=True, num_labels=None, spell_length=None):
    """Build the model."""
    print_rank_0('building GPT2 model ...')
    if args.pretrained_bert:
        if model_type == "multiple_choice":
            model = BertForMultipleChoice.from_pretrained(
                args.tokenizer_model_type,
                cache_dir=args.cache_dir,
                fp32_layernorm=args.fp32_layernorm,
                fp32_embedding=args.fp32_embedding,
                layernorm_epsilon=args.layernorm_epsilon)
        elif model_type == "classification":
            model = BertForSequenceClassification.from_pretrained(
                args.tokenizer_model_type,
                cache_dir=args.cache_dir,
                fp32_layernorm=args.fp32_layernorm,
                fp32_embedding=args.fp32_embedding,
                layernorm_epsilon=args.layernorm_epsilon,
                num_labels=num_labels)
        else:
            raise NotImplementedError
    else:
        output_predict, parallel_output = True, True
        if (model_type == "multiple_choice" or model_type == "classification") and not args.cloze_eval:
            output_predict = False
        if model_type is not None:
            parallel_output = False
        if spell_length is not None:
            print_rank_0(f"Continuous spell length {spell_length}")
        model = GLMModel(num_layers=args.num_layers,
                         vocab_size=args.vocab_size,
                         hidden_size=args.hidden_size,
                         num_attention_heads=args.num_attention_heads,
                         embedding_dropout_prob=args.hidden_dropout,
                         attention_dropout_prob=args.attention_dropout,
                         output_dropout_prob=args.hidden_dropout,
                         max_sequence_length=args.max_position_embeddings,
                         max_memory_length=args.mem_length,
                         checkpoint_activations=args.checkpoint_activations,
                         checkpoint_num_layers=args.checkpoint_num_layers,
                         parallel_output=parallel_output,
                         relative_encoding=args.transformer_xl,
                         block_position_encoding=args.block_lm and not args.masked_lm,
                         output_predict=output_predict,
                         spell_length=spell_length,
                         spell_func=args.prompt_func,
                         attention_scale=args.attention_scale)
        if args.freeze_transformer:
            model.freeze_transformer(tune_prefix_layers=args.tune_prefix_layers)
        if model_type is not None:
            if model_type == 'multiple_choice':
                if args.cloze_eval:
                    if multi_token:
                        if args.fast_decode:
                            model = GLMForMultiTokenClozeFast(model, length_penalty=args.length_penalty)
                        else:
                            model = GLMForMultiTokenCloze(model, length_penalty=args.length_penalty)
                    else:
                        model = GLMForSingleTokenCloze(model, take_softmax=args.adapet)
                else:
                    model = GLMForSequenceClassification(model, args.hidden_size, args.output_dropout,
                                                         args.pool_token, num_class=num_labels)
            elif model_type == 'classification':
                model = GLMForSequenceClassification(model, args.hidden_size, args.output_dropout,
                                                     args.pool_token, num_class=num_labels)
            elif model_type == 'generation':
                pass
            else:
                raise NotImplementedError(model_type)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision.
    if args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if not args.deepspeed and (args.train_iters or args.epochs):
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = TorchDDP(model, device_ids=[i], output_device=i,
                             process_group=mpu.get_data_parallel_group())
        elif args.DDP_impl == 'local':
            model = LocalDDP(model)
        else:
            print_rank_0("Skip DDP model")

    return model
def load_checkpoint(model, optimizer, lr_scheduler, args):
    """Load a model checkpoint."""
    iteration, release, success = get_checkpoint_iteration(args)
    if not success:
        return 0

    if args.deepspeed:
        raise NotImplementedError("DeepSpeed is not installed")
    else:
        if args.load_openai:
            from utils import move_weights
            from model import DistributedDataParallel as DDP
            from fp16 import FP16_Module
            from transformers import GPT2LMHeadModel
            model_path = args.load
            print('global rank {} is loading openai weights {}'.format(
                torch.distributed.get_rank(), model_path))
            model.cpu()
            gpt2model = GPT2LMHeadModel.from_pretrained(model_path, cache_dir='gpt2_weights')
            model2fill = model
            while isinstance(model2fill, (DDP, FP16_Module)):
                model2fill = model2fill.module
            move_weights(model2fill, gpt2model)
            model.cuda(torch.cuda.current_device())
            sd = {}
        else:
            # Checkpoint.
            checkpoint_name = get_checkpoint_name(args.load, iteration, release)
            if mpu.get_data_parallel_rank() == 0:
                print('global rank {} is loading checkpoint {}'.format(
                    torch.distributed.get_rank(), checkpoint_name))
            sd = torch.load(checkpoint_name, map_location='cpu')

            if isinstance(model, torchDDP):
                model = model.module

            # Model.
            try:
                model.load_state_dict(sd['model'])
            except KeyError:
                print_rank_0('A metadata file exists but unable to load model '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
                exit()

            # Optimizer.
            if not release and not args.finetune and not args.no_load_optim:
                try:
                    if optimizer is not None:
                        optimizer.load_state_dict(sd['optimizer'])
                    if lr_scheduler is not None:
                        lr_scheduler.load_state_dict(sd['lr_scheduler'])
                except KeyError:
                    print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
                                 'Specify --no-load-optim or --finetune to prevent '
                                 'attempting to load the optimizer '
                                 'state.'.format(checkpoint_name))
                    exit()

    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint, exiting')
                exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load random state from checkpoint, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the random state.')
            exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded checkpoint from {}'.format(args.load))

    return iteration
def save_checkpoint(iteration, model, optimizer, lr_scheduler, args, tag=None, barrier=True,
                    only_changed_parameters=False, no_deepspeed=False, no_save_optim=False):
    """Save a model checkpoint."""
    if tag is None:
        tag = str(iteration)
    if args.deepspeed and not no_deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    else:
        # Only rank zero of the data parallel group writes to the disk.
        if mpu.get_data_parallel_rank() == 0:
            checkpoint_name = get_checkpoint_name(args.save, tag)
            print('global rank {} is saving checkpoint at iteration {:7d} to {}'.format(
                torch.distributed.get_rank(), iteration, checkpoint_name))

            sd = {'iteration': iteration}
            if args.deepspeed:
                model = model.module
            state_dict = model.state_dict()
            if only_changed_parameters:
                requires_grad_dict = {}
                for name, parameter in model.named_parameters():
                    requires_grad_dict[name] = parameter.requires_grad
                state_dict = {key: value for key, value in state_dict.items()
                              if requires_grad_dict[key]}
            sd['module'] = state_dict

            # Optimizer stuff.
            if not args.no_save_optim and not no_save_optim:
                if optimizer is not None:
                    sd['optimizer'] = optimizer.state_dict()
                if lr_scheduler is not None:
                    sd['lr_scheduler'] = lr_scheduler.state_dict()

            # rng states.
            if not args.no_save_rng:
                sd['random_rng_state'] = random.getstate()
                sd['np_rng_state'] = np.random.get_state()
                sd['torch_rng_state'] = torch.get_rng_state()
                sd['cuda_rng_state'] = torch.cuda.get_rng_state()
                sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

            ensure_directory_exists(checkpoint_name)
            torch.save(sd, checkpoint_name)
            print(' successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary).
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(tag)