def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """Selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints.

    Args:
        model: the ICT model; a ``torchDDP`` wrapper is stripped before loading.
        only_query_model: if True, drop the ``'context_model'`` entry before loading.
        only_block_model: if True, drop the ``'question_model'`` entry before loading.
        from_realm_chkpt: if True, read from ``args.load`` and use the sub-dict
            stored under ``['retriever']['ict_model']``. Otherwise read from
            ``args.ict_load``.

    Returns:
        The (unwrapped) model with the checkpoint weights loaded.
    """
    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    # The tracker file holds the iteration number of the latest checkpoint.
    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        # Only the log line is restricted to data-parallel rank 0. The
        # extraction must run on EVERY rank. Otherwise non-zero ranks would
        # try to load the full REALM state dict into the ICT model.
        if mpu.get_data_parallel_rank() == 0:
            print(" loading ICT state dict from REALM", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    # Wait until every rank has finished loading before reporting success.
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer.

    Builds the (wrapped) model, creates the optimizer on the unwrapped module,
    and builds the LR scheduler. Resumes from ``args.load`` when it is set.

    Args:
        model_provider_func: zero-argument callable passed to ``get_model``.

    Returns:
        ``(model, optimizer, lr_scheduler)``. ``model`` is still wrapped.
        Also sets ``args.iteration`` as a side effect.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # The optimizer must see the bare module, not the DDP/FP16 wrappers.
    bare_model = model
    while isinstance(bare_model, (torchDDP, LocalDDP, FP16Module)):
        bare_model = bare_model.module
    optimizer = get_megatron_optimizer(bare_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is None:
        args.iteration = 0
    else:
        timers = get_timers()
        # The barrier before starting the timer makes all ranks report the
        # max load time.
        torch.distributed.barrier()
        timers('load checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load checkpoint').stop()
        timers.log(['load checkpoint'])

    # Multiple micro-batches are only supported with the local DDP implementation.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'

    # Strip every wrapper that exposes a `.module` attribute.
    innermost = model
    while hasattr(innermost, 'module'):
        innermost = innermost.module

    # A fresh run (iteration 0) may seed ICT weights from a pretrained BERT.
    if args.iteration == 0 and hasattr(innermost, 'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        innermost.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler
def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler.

    Warmup and decay lengths are expressed in samples. For iteration-based
    training, iteration counts are multiplied by ``args.global_batch_size``.
    Missing decay lengths default to the total training length. This default
    is written back onto ``args``.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples`` is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        warmup_steps = (args.lr_warmup_fraction * decay_steps
                        if args.lr_warmup_fraction is not None
                        else args.lr_warmup_iters * args.global_batch_size)
    elif args.train_samples:
        # Sample-based training. Training iters are derived for later use.
        # The samples are left unadjusted even though the last batch may be
        # incomplete.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        warmup_steps = (args.lr_warmup_fraction * decay_steps
                        if args.lr_warmup_fraction is not None
                        else args.lr_warmup_samples)
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)
def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Builds the full (master) weight on every process and copies this rank's
    chunks into ``weight``.

    Args:
        weight: destination tensor for this rank's partition (filled in place).
        output_size, input_size: shape of the full master weight.
        per_partition_size: size of one rank's partition along ``partition_dim``.
        partition_dim: dimension along which the weight is partitioned.
        init_method: callable that initializes a tensor in place.
        stride: number of interleaved chunks per partition.
        return_master_weight: if True, return the full master weight.

    Returns:
        The master weight (cast to ``args.params_dtype``) if
        ``return_master_weight`` is True, else None.
    """
    set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True,
                                         dim=partition_dim, stride=stride)

    # Initialise the full weight in fp32, then cast to the parameter dtype.
    full_weight = torch.empty(output_size, input_size, dtype=torch.float,
                              requires_grad=False)
    init_method(full_weight)
    params_dtype = get_args().params_dtype
    full_weight = full_weight.to(dtype=params_dtype)

    # Split into stride-sized chunks. This rank owns every tp_world-th chunk,
    # starting at its own rank.
    chunk_size = divide(per_partition_size, stride)
    chunks = torch.split(full_weight, chunk_size, dim=partition_dim)
    tp_rank = get_tensor_model_parallel_rank()
    tp_world = get_tensor_model_parallel_world_size()
    owned_chunks = chunks[tp_rank::tp_world]

    with torch.no_grad():
        torch.cat(owned_chunks, dim=partition_dim, out=weight)

    return full_weight if return_master_weight else None
def main():
    """Main program: initialize Megatron, load weights, then generate samples.

    If ``args.num_samples`` is non-zero, samples are generated
    unconditionally. Otherwise samples are generated one at a time, either
    from ``args.sample_input_file`` or interactively.
    """
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    args = get_args()
    if args.load is not None:
        load_checkpoint(model, None, None)

    if args.num_samples != 0:
        generate_and_write_samples_unconditional(model)
        return

    # Conditional generation processes one sample at a time.
    args.batch_size = 1
    if args.sample_input_file != "":
        generate_samples_input_from_file(model)
    else:
        generate_samples_interactive(model)
def forward_step(data_iterator, model):
    """Forward step for the encoder-decoder (T5-style) model.

    Returns:
        ``(output_tensor, loss_fn)``, where ``loss_fn`` is ``loss_func`` with
        ``loss_mask`` bound as its first argument.
    """
    args = get_args()
    timers = get_timers()

    # Fetch the batch, timing only the data-loading part.
    timers('batch generator').start()
    (tokens_enc, tokens_dec, loss_mask, lm_labels,
     enc_mask, dec_mask, enc_dec_mask) = get_batch(data_iterator)
    timers('batch generator').stop()

    # Run the model, passing the LM labels so it can compute the loss.
    output_tensor = model(tokens_enc, tokens_dec,
                          enc_mask, dec_mask, enc_dec_mask,
                          tokentype_ids=None,
                          lm_labels=lm_labels)

    loss_fn = partial(loss_func, loss_mask)
    return output_tensor, loss_fn
def model_provider():
    """Build the classification model, or this rank's pipeline stage of it.

    ``num_classes`` is not defined in this block. It is presumably captured
    from an enclosing scope; verify against the caller.
    """
    args = get_args()
    print_rank_0('building classification model for {} ...'.format(args.task))

    # Without pipeline parallelism the whole model lives on this rank.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return Classification(num_classes=num_classes, num_tokentypes=2)

    # Otherwise pick the class that matches this rank's pipeline stage.
    if mpu.is_pipeline_first_stage():
        return ClassificationFirstStage(num_classes=num_classes,
                                        num_tokentypes=2)
    if mpu.is_pipeline_last_stage():
        return ClassificationLastStage(num_classes=num_classes,
                                       num_tokentypes=2)
    return ClassificationIntermediateStage(num_classes=num_classes,
                                           num_tokentypes=2)
def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Backward step.

    Runs autograd from ``output_tensor``. When no ``output_tensor_grad`` is
    supplied, the output is first scaled with ``optimizer.scale_loss``.

    Returns:
        The gradient accumulated on ``input_tensor``, or None if
        ``input_tensor`` is None.
    """
    args = get_args()
    timers = get_timers()

    has_input = input_tensor is not None

    # Keep the grad on the (non-leaf) input so it can be returned below.
    if has_input:
        input_tensor.retain_grad()

    # No incoming gradient: output_tensor is presumably the loss, so apply
    # the optimizer's loss scaling.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    return input_tensor.grad if has_input else None
def __init__(self, mpu_vocab_size, hidden_size, init_method,
             layernorm_epsilon, parallel_output):
    """Build the BERT LM head: dense + layernorm + activation, plus output bias.

    Args:
        mpu_vocab_size: length of the output bias vector.
        hidden_size: input and output width of the dense transform.
        init_method: initializer for the dense layer.
        layernorm_epsilon: epsilon for the layernorm.
        parallel_output: stored as ``self.parallel_output``.
    """
    super(BertLMHead, self).__init__()

    args = get_args()

    # Output bias, marked as tensor-model-parallel and partitioned along dim 0.
    self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
    self.bias.tensor_model_parallel = True
    self.bias.partition_dim = 0
    self.bias.stride = 1
    self.parallel_output = parallel_output

    self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
    layer_norm_cls = import_layernorm(args.fp32_residual_connection)
    self.layernorm = layer_norm_cls(hidden_size, eps=layernorm_epsilon)

    # Activation choice: openai_gelu takes precedence over the ONNX-safe
    # erf_gelu. Otherwise use torch's gelu.
    if args.openai_gelu:
        self.gelu = openai_gelu
    elif args.onnx_safe:
        self.gelu = erf_gelu
    else:
        self.gelu = torch.nn.functional.gelu
def __init__(self, mlp_activation_func, init_method, output_layer_init_method):
    """Build the tensor-parallel MLP: h -> 4h -> h with dropout.

    Args:
        mlp_activation_func: activation stored as ``self.activation_func``.
        init_method: initializer for the h -> 4h projection.
        output_layer_init_method: initializer for the 4h -> h projection.
    """
    super(ParallelMLP, self).__init__()
    args = get_args()

    hidden = args.hidden_size
    ffn_hidden = 4 * hidden

    # Column-parallel expansion. Outputs are left partitioned (no gather).
    self.dense_h_to_4h = mpu.ColumnParallelLinear(hidden, ffn_hidden,
                                                  gather_output=False,
                                                  init_method=init_method)

    self.activation_func = mlp_activation_func

    # Row-parallel projection back to h, consuming the partitioned input.
    self.dense_4h_to_h = mpu.RowParallelLinear(
        ffn_hidden,
        hidden,
        input_is_parallel=True,
        init_method=output_layer_init_method)

    self.dropout = torch.nn.Dropout(args.hidden_dropout)
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets for BERT ICT.

    Args:
        train_val_test_num_samples: requested sample counts per split,
            forwarded as ``train_valid_test_num_samples``.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` from ``build_train_valid_test_datasets``.
    """
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    dataset_kwargs = dict(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=False,
        dataset_type='ict',
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_kwargs)

    print_rank_0("> finished creating BERT ICT datasets ...")
    return train_ds, valid_ds, test_ds
def __init__(self, num_classes, num_tokentypes=2):
    """Build a pooled language model with a dropout + linear classification head.

    Args:
        num_classes: output width of the classification head.
        num_tokentypes: number of token types passed to the language model.
    """
    super(Classification, self).__init__()
    args = get_args()

    self.num_classes = num_classes
    init_method = init_method_normal(args.init_method_std)
    scaled_init = scaled_init_method_normal(args.init_method_std,
                                            args.num_layers)

    self.language_model, self._language_model_key = get_language_model(
        attention_mask_func=bert_attention_mask_func,
        num_tokentypes=num_tokentypes,
        add_pooler=True,
        init_method=init_method,
        scaled_init_method=scaled_init)

    # Classification head components (dropout and a linear layer to num_classes).
    self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
    self.classification_head = get_linear_layer(args.hidden_size,
                                                self.num_classes,
                                                init_method)
    self._classification_head_key = 'classification_head'
def __init__(self, task_name, dataset_name, datapath, tokenizer, max_seq_length):
    """Load samples from ``datapath``, optionally subsampling them.

    Args:
        task_name, dataset_name: used for logging and stored on the instance.
        datapath: path handed to ``process_samples_from_single_path``.
        tokenizer: stored as ``self.tokenizer``.
        max_seq_length: stored as ``self.max_seq_length``.

    If ``args.sample_rate < 1``, a random subset of
    ``int(len(samples) * sample_rate)`` samples is kept. The subset is drawn
    without replacement via ``random.sample``.
    """
    self.task_name = task_name
    self.dataset_name = dataset_name
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                         self.dataset_name))

    print_rank_0(datapath)
    self.samples, self.id2text = self.process_samples_from_single_path(datapath)

    sample_rate = get_args().sample_rate
    if sample_rate < 1:
        keep = int(len(self.samples) * sample_rate)
        self.samples = random.sample(self.samples, keep)

    print_rank_0(' >> total number of samples: {}'.format(len(self.samples)))
def __init__(self, attention_mask_func, init_method, output_layer_init_method, layer_number):
    """Build the first part of a transformer layer: input layernorm + self-attention.

    Args:
        attention_mask_func: forwarded to ``ParallelSelfAttention``.
        init_method, output_layer_init_method: attention initializers.
        layer_number: index of this layer, stored and forwarded to the attention.
    """
    args = get_args()
    super(ParallelTransformerLayerPart1, self).__init__()

    self.layer_number = layer_number
    self.apply_residual_connection_post_layernorm = \
        args.apply_residual_connection_post_layernorm

    hidden_size = args.hidden_size
    self.input_layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)

    self.attention = ParallelSelfAttention(attention_mask_func,
                                           init_method,
                                           output_layer_init_method,
                                           layer_number)

    # Dropout settings, presumably consumed by the forward pass (not shown).
    self.hidden_dropout = args.hidden_dropout
    self.bias_dropout_fusion = args.bias_dropout_fusion
def _build_wikitext103_dataset():
    """Build the WikiText-103 evaluation dataset from ``args.valid_data``.

    Reads the single validation file as UTF-8 and counts its
    space-separated tokens (``num_original_tokens``). It then detokenizes the
    text, retokenizes it with the Megatron tokenizer, and wraps the result in
    an ``_LMDataset`` using ``args.seq_length`` and ``args.overlapping_eval``.

    Returns:
        The ``_LMDataset`` built over the tokenized validation text.

    Raises:
        AssertionError: if ``args.valid_data`` does not hold exactly one path.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    assert len(args.valid_data) == 1
    # Read bytes and decode explicitly, so no newline translation happens.
    with open(args.valid_data[0], "rb") as reader:
        entire_data = reader.read().decode('utf-8')
    # Token count of the raw text, kept alongside the tokenizer's count.
    num_original_tokens = len(entire_data.strip().split(" "))
    entire_data = get_detokenizer(args.valid_data[0])(entire_data)
    tokenized_data = tokenizer.tokenize(entire_data)
    num_tokenized_tokens = len(tokenized_data)

    val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
                             num_original_tokens, num_tokenized_tokens,
                             args.overlapping_eval)
    print_rank_0(' > number of original tokens: {}, number of detokenized '
                 'tokens: {}'.format(num_original_tokens, num_tokenized_tokens))

    return val_dataset
def forward_step(data_iterator, model):
    """Forward step for BERT.

    Token-type ids are only passed to the model when ``args.bert_binary_head``
    is set.

    Returns:
        ``(output_tensor, loss_fn)``, where ``loss_fn`` is ``loss_func`` with
        ``loss_mask`` and ``sentence_order`` bound.
    """
    args = get_args()
    timers = get_timers()

    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = \
        get_batch(data_iterator)
    timers('batch-generator').stop()

    tokentype_ids = types if args.bert_binary_head else None

    output_tensor = model(tokens, padding_mask,
                          tokentype_ids=tokentype_ids,
                          lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)
def forward_step(data_iterator, model):
    """Forward step for the GPT language model, with optional curriculum truncation.

    Returns:
        ``(loss, {'lm loss': reduced_loss})``. ``loss`` is the per-token loss
        averaged over positions where ``loss_mask`` is non-zero.
    """
    args = get_args()
    timers = get_timers()

    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch generator').stop()

    losses = model(tokens, position_ids, attention_mask, labels=labels)

    # Under curriculum learning only the first curriculum_seqlen positions
    # are kept, presumably to match a model/loss output of that length.
    truncate = (args.curriculum_learning
                and args.curriculum_seqlen < args.seq_length)
    if truncate:
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()

    flat_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * flat_mask) / flat_mask.sum()

    # Reduced copy for logging.
    reduced_loss = reduce_losses([loss])
    return loss, {'lm loss': reduced_loss[0]}
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Build an ICTDataset for block indexing rather than for training.

    The dataset is built with a single-epoch sample mapping (``num_epochs=1``)
    and a fixed seed of 1. It is therefore suited to indexing ICT/block data
    via ``get_block()``.

    Args:
        use_titles: forwarded to ``ICTDataset``.
        query_in_block_prob: forwarded to ``ICTDataset``.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    return ICTDataset(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs,
    )
def check_checkpoint_args(checkpoint_args):
    """Ensure the fixed model arguments match between this run and a checkpoint.

    Compares architecture-defining arguments from the current run with the
    values retrieved from a checkpoint.

    Args:
        checkpoint_args: namespace-like object holding the checkpoint's arguments.

    Raises:
        AssertionError: if any compared argument differs. The message names
            the argument and both values.
    """
    args = get_args()

    def _compare(arg_name):
        checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    # Checked in this order; the first mismatch aborts.
    for arg_name in ('num_layers',
                     'hidden_size',
                     'num_attention_heads',
                     'max_position_embeddings',
                     'make_vocab_size_divisible_by',
                     'padded_vocab_size',
                     'tokenizer_type',
                     'model_parallel_size'):
        _compare(arg_name)
def __init__(self, attention_mask_func, mlp_activation_func,
             init_method, output_layer_init_method):
    """Build a transformer stack whose layers may share parameters.

    ``args.num_layers`` logical layers are served by ``num_unique_layers``
    distinct modules (defaults to ``num_layers``). ``num_layers`` must be
    divisible by ``num_unique_layers``. The logical -> unique mapping comes
    from ``self._get_layer_index`` (defined elsewhere).
    """
    super(ParallelTransformer, self).__init__()
    args = get_args()

    # Activation checkpointing settings.
    self.checkpoint_activations = args.checkpoint_activations
    self.checkpoint_num_layers = args.checkpoint_num_layers

    self.num_layers = args.num_layers
    self.num_unique_layers = args.num_unique_layers
    if self.num_unique_layers is None:
        self.num_unique_layers = self.num_layers
    assert self.num_layers % self.num_unique_layers == 0, \
        'number of layers should be divisible by number of unique layers'
    self.param_sharing_style = args.param_sharing_style

    # Distinct transformer layers, numbered from 1.
    self.layers = torch.nn.ModuleList([
        ParallelTransformerLayer(attention_mask_func, mlp_activation_func,
                                 init_method, output_layer_init_method,
                                 layer_number)
        for layer_number in range(1, self.num_unique_layers + 1)
    ])

    # When layers are shared, global rank 0 logs the layer ordering.
    if (self.num_layers != self.num_unique_layers
            and torch.distributed.get_rank() == 0):
        print('> will be using the following layer ordering:')
        for logical_idx in range(self.num_layers):
            print(' layer id: {:3d} --> unique layer id: '
                  '{:3d}'.format(logical_idx,
                                 self._get_layer_index(logical_idx)),
                  flush=True)

    # Final layer norm before output.
    self.final_layernorm = LayerNorm(args.hidden_size,
                                     eps=args.layernorm_epsilon)
def get_model(model_provider_func):
    """Build the model, move it to GPU, and wrap it for FP16 / data parallelism.

    Args:
        model_provider_func: zero-argument callable returning the CPU model.

    Returns:
        The model wrapped in ``torchDDP`` or ``LocalDDP``, per ``args.DDP_impl``.
        With ``args.fp16`` set, ``FP16Module`` sits underneath that wrapper.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Give every parameter default tensor-model-parallel attributes, so the
    # optimizer can rely on them even for non-parallel params.
    for param in model.parameters():
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    model.cuda(torch.cuda.current_device())

    if args.fp16:
        model = FP16Module(model)

    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if args.DDP_impl == 'local':
        return LocalDDP(model)

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
def forward_step(batch, model, eval_metric):
    """Evaluation forward step with pipeline send/receive.

    Args:
        batch: raw batch passed to ``process_batch``.
        model: possibly wrapped model; the unwrapped module receives the
            pipeline input tensor.
        eval_metric: ``'loss'`` or ``'accuracy'``.

    Returns:
        On the last pipeline stage, the summed masked loss (``'loss'``) or the
        number of fully-correct samples (``'accuracy'``). On other stages, None.

    Raises:
        NotImplementedError: on the last stage, for an unknown ``eval_metric``.
    """
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)

    # Tell the model what our actual batch size will be
    args = get_args()
    args.micro_batch_size = len(labels)

    input_tensor = recv_forward()

    unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output = model(tokens, position_ids, attention_mask)

    send_forward(output)

    if not mpu.is_pipeline_last_stage():
        return None

    if eval_metric == 'loss':
        # Unreduced (summed) masked loss.
        losses = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous())
        flat_mask = loss_mask.contiguous().view(-1).float()
        return torch.sum(losses.view(-1) * flat_mask)

    if eval_metric == 'accuracy':
        predictions = torch.argmax(output, -1)
        correct = (predictions == labels).float()
        # Masked-out positions are set to 1 so they don't affect the product.
        correct[(1 - loss_mask).bool()] = 1
        # A sample counts only if every unmasked position is correct.
        return correct.prod(-1).sum()

    raise NotImplementedError('forward method for evaluation metric {} '
                              'is not implemented.'.format(eval_metric))
def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """Selectively load retrieval models for indexing/retrieving from saved checkpoints.

    Args:
        model: passed through ``utils.unwrap_model``. The result must be a
            sequence containing exactly one module.
        only_query_model: if True, drop ``'context_model'`` before loading.
        only_context_model: if True, drop ``'query_model'`` before loading.
        custom_load_path: checkpoint directory; defaults to ``args.load``.

    Returns:
        The unwrapped model sequence, with weights loaded into ``model[0]``.
    """
    args = get_args()
    model = utils.unwrap_model(model)

    load_path = args.load if custom_load_path is None else custom_load_path

    # The tracker file holds the iteration of the checkpoint to load.
    with open(get_checkpoint_tracker_filename(load_path), 'r') as tracker:
        iteration = int(tracker.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    ret_state_dict = torch.load(checkpoint_name, map_location='cpu')['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def get_model(model_provider_func):
    """Build the model and, unless DeepSpeed is used, wrap it for training.

    With ``args.deepspeed`` the CPU model is returned as-is, because DeepSpeed
    handles CUDA placement, FP16 and DDP. Otherwise the model is moved to the
    current GPU and optionally wrapped in ``FP16_Module``. It is then wrapped
    in ``torchDDP`` or ``LocalDDP``.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        num_params = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, num_params), flush=True)

    if args.deepspeed:
        return model

    model.cuda(torch.cuda.current_device())

    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if args.DDP_impl == 'local':
        return LocalDDP(model)

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       init_method, scaled_init_method):
    """Build language model and return along with the key to save.

    The MLP activation is ``openai_gelu`` when ``args.openai_gelu`` is set,
    otherwise ``F.gelu``.

    Returns:
        ``(language_model, 'language_model')``. The second element is the
        checkpoint key.
    """
    args = get_args()

    mlp_activation = openai_gelu if args.openai_gelu else F.gelu

    language_model = TransformerLanguageModel(
        attention_mask_func=attention_mask_func,
        mlp_activation_func=mlp_activation,
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        num_tokentypes=num_tokentypes,
        add_pooler=add_pooler)

    return language_model, 'language_model'
def model_provider():
    """Build the BERT model, or this rank's pipeline stage of it.

    All variants use two token types. The full model and the last stage also
    get a binary head and parallel output.
    """
    print_rank_0('building BERT model ...')

    args = get_args()

    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return BertModel(num_tokentypes=2,
                         add_binary_head=True,
                         parallel_output=True)

    # Pipeline parallel: pick the stage class for this rank.
    if mpu.is_pipeline_first_stage():
        return BertModelFirstStage(num_tokentypes=2)
    if mpu.is_pipeline_last_stage():
        return BertModelLastStage(num_tokentypes=2,
                                  add_binary_head=True,
                                  parallel_output=True)
    return BertModelIntermediateStage(num_tokentypes=2)
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Build a sequential, non-distributed data loader over ``dataset``.

    NOTE: this dataloader is not distributed. Every rank iterates the whole
    dataset in order.

    Args:
        dataset: dataset to iterate.
        micro_batch_size: local (per-GPU) batch size; defaults to
            ``args.micro_batch_size``.
    """
    args = get_args()

    batch_size = (args.micro_batch_size if micro_batch_size is None
                  else micro_batch_size)

    # drop_last must be False so the final partial batch is kept.
    batch_sampler = BatchSampler(torch.utils.data.SequentialSampler(dataset),
                                 batch_size=batch_size,
                                 drop_last=False)

    return CustomDataLoader(dataset,
                            batch_sampler=batch_sampler,
                            num_workers=args.num_workers,
                            pin_memory=True)
def __init__(self, attention_mask_func, mlp_activation_func,
             init_method, output_layer_init_method):
    """Build ``args.num_layers`` transformer layers plus a final layernorm."""
    super(ParallelTransformer, self).__init__()
    args = get_args()

    # Activation checkpointing settings.
    self.checkpoint_activations = args.checkpoint_activations
    self.checkpoint_num_layers = args.checkpoint_num_layers

    # Transformer layers, numbered from 1.
    self.layers = torch.nn.ModuleList([
        ParallelTransformerLayer(attention_mask_func, mlp_activation_func,
                                 init_method, output_layer_init_method,
                                 layer_number)
        for layer_number in range(1, args.num_layers + 1)
    ])

    # Final layer norm before output.
    self.final_layernorm = LayerNorm(args.hidden_size,
                                     eps=args.layernorm_epsilon)
def model_provider():
    """Build the GPT-2 model, under ``deepspeed.zero.Init`` when ZeRO stage is 3.

    ``args.remote_device`` is forwarded to ``deepspeed.zero.Init``, with the
    string ``'none'`` mapped to ``None``. Memory usage is reported before and
    after construction. Data-parallel rank 0 logs the parameter count in
    billions.
    """
    print_rank_0('building GPT2 model ...')
    see_memory_usage("Before Building Model", force=True)

    args = get_args()
    remote_device = None if args.remote_device == 'none' else args.remote_device
    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=remote_device,
                             config=args.deepspeed_config,
                             enabled=args.zero_stage == 3):
        model = GPT2Model(num_tokentypes=0, parallel_output=True)
    see_memory_usage("After Building Model", force=True)

    if mpu.get_data_parallel_rank() == 0:
        billion_params = get_parameters_in_billions(model)
        # Implicit concatenation instead of a backslash inside the literal.
        # The old `\ ` was an invalid escape that printed a stray backslash.
        print(f' > number of parameters on model parallel rank '
              f'{mpu.get_model_parallel_rank()}: '
              f'{round(billion_params, 3)} Billion',
              flush=True)

    return model
def __init__(self, num_tokentypes=0, parallel_output=True):
    """Build the T5 encoder-decoder language model and its LM head.

    Args:
        num_tokentypes: forwarded to ``get_language_model``.
        parallel_output: stored on the instance and passed to ``T5LMHead``.
    """
    super(T5Model, self).__init__()
    args = get_args()

    self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
    self.parallel_output = parallel_output

    init_method = init_method_normal(args.init_method_std)
    scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                   args.num_layers)

    self.language_model, self._language_model_key = get_language_model(
        num_tokentypes=num_tokentypes,
        add_pooler=False,
        add_decoder=True,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method,
        scaled_init_method=scaled_init_method)

    # Head width is size(0) of the word-embedding weight, presumably the
    # (local) vocabulary size.
    vocab_dim = self.language_model.embedding.word_embeddings.weight.size(0)
    self.lm_head = T5LMHead(vocab_dim, parallel_output)
    self._lm_head_key = 'lm_head'