def slice_state_dict(config, loaded_state_dict):
    sliced_state_dict = OrderedDict()
    start_layer_id = (
        config.n_total_layers // mpu.get_pipeline_parallel_world_size()
        * mpu.get_pipeline_parallel_group_rank()
        + min(mpu.get_pipeline_parallel_group_rank(),
              config.n_total_layers % mpu.get_pipeline_parallel_world_size()))
    end_layer_id = start_layer_id + config.n_layers
    for key, value in loaded_state_dict.items():
        keys = key.split('.')
        global_layer_id = int(keys[2])
        if start_layer_id <= global_layer_id < end_layer_id:
            local_layer_id = global_layer_id - start_layer_id
            new_key = '.'.join(keys[:2] + [str(local_layer_id)] + keys[3:])
            if keys[3] == 'attn' and keys[4] == 'in_proj':
                in_size = mpu.divide(value.size(0),
                                     mpu.get_model_parallel_world_size())
                if keys[5] in ('weight', 'bias'):
                    new_value = value[mpu.get_model_parallel_rank() * in_size:
                                      (mpu.get_model_parallel_rank() + 1) * in_size]
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            elif keys[3] == 'attn' and keys[4] == 'out_proj':
                if keys[5] == 'weight':
                    out_size = mpu.divide(value.size(1),
                                          mpu.get_model_parallel_world_size())
                    new_value = value[:, mpu.get_model_parallel_rank() * out_size:
                                      (mpu.get_model_parallel_rank() + 1) * out_size]
                elif keys[5] == 'bias':
                    new_value = value
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            elif keys[3] == 'fc1':
                in_size = mpu.divide(value.size(0),
                                     mpu.get_model_parallel_world_size())
                if keys[4] in ('weight', 'bias'):
                    new_value = value[mpu.get_model_parallel_rank() * in_size:
                                      (mpu.get_model_parallel_rank() + 1) * in_size]
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            elif keys[3] == 'fc2':
                if keys[4] == 'weight':
                    out_size = mpu.divide(value.size(1),
                                          mpu.get_model_parallel_world_size())
                    new_value = value[:, mpu.get_model_parallel_rank() * out_size:
                                      (mpu.get_model_parallel_rank() + 1) * out_size]
                elif keys[4] == 'bias':
                    new_value = value
                else:
                    raise NotImplementedError(f"Unknown key {key}")
            else:
                new_value = value
            sliced_state_dict[new_key] = new_value
    return sliced_state_dict

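# A minimal, self-contained sketch (not part of the original code) of the layer
# partitioning rule used in slice_state_dict above: each pipeline stage owns
# n_total_layers // P layers, and the first n_total_layers % P stages get one
# extra layer. The helper name below is hypothetical and only illustrates the
# start/end arithmetic.
def _stage_layer_range(n_total_layers, pipeline_world_size, stage_rank,
                       n_layers_on_stage):
    """Return (start, end) global layer ids owned by one pipeline stage."""
    start = (n_total_layers // pipeline_world_size * stage_rank
             + min(stage_rank, n_total_layers % pipeline_world_size))
    return start, start + n_layers_on_stage


# Example: 10 layers over 4 stages -> stages own layers [0,3), [3,6), [6,8), [8,10).
assert _stage_layer_range(10, 4, 0, 3) == (0, 3)
assert _stage_layer_range(10, 4, 2, 2) == (6, 8)
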
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        attention_layer, identity_layer

def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
        mpu.initialize_model_parallel(args.model_parallel_size)

    # create model
    conf_dict = EasyDict(yaml.load(open(args.cfg, "r"), Loader=yaml.Loader))
    conf_dict.world_size = mpu.get_model_parallel_world_size()
    conf_dict.gpu = args.gpu
    conf_dict.device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")

    solver = Solver(conf_dict)
    solver.train()

def prepare_tokenizer(args):
    tokenizer_args = {
        'tokenizer_type': args.tokenizer_type,
        'corpus': None,
        'model_path': args.tokenizer_path,
        'vocab_size': args.vocab_size,
        'model_type': args.tokenizer_model_type,
        'cache_dir': args.cache_dir}
    tokenizer = make_tokenizer(**tokenizer_args)

    num_tokens = tokenizer.num_tokens
    before = num_tokens
    after = before
    multiple = args.make_vocab_size_divisible_by * \
        mpu.get_model_parallel_world_size()
    while (after % multiple) != 0:
        after += 1
    print_rank_0('> padded vocab (size: {}) with {} dummy '
                 'tokens (new size: {})'.format(before, after - before, after))

    args.tokenizer_num_tokens = after
    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
    args.eod_token = tokenizer.get_command('eos').Id

    # after = tokenizer.num_tokens
    # while after % mpu.get_model_parallel_world_size() != 0:
    #     after += 1

    args.vocab_size = after
    print("prepare tokenizer done", flush=True)

    return tokenizer

def prepare_tokenizer(args):
    tokenizer_args = {
        'tokenizer_type': args.tokenizer_type,
        'corpus': None,
        'model_path': args.tokenizer_path,
        'vocab_size': args.vocab_size,
        'model_type': args.tokenizer_model_type,
        'cache_dir': args.cache_dir
    }
    tokenizer = make_tokenizer(**tokenizer_args)

    args.tokenizer_num_tokens = tokenizer.num_tokens
    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
    args.eod_token = tokenizer.get_command('eos').Id

    after = tokenizer.num_tokens
    multiple = args.make_vocab_size_divisible_by * \
        mpu.get_model_parallel_world_size()
    if multiple != 0:
        while (after % multiple) != 0:
            after += 1

    args.vocab_size = after
    print("prepare tokenizer done", flush=True)

    return tokenizer

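# A small sketch (not from the original source) of the vocab padding performed
# in the prepare_tokenizer variants above: the incremental while loop rounds
# the vocabulary size up to the next multiple of
# make_vocab_size_divisible_by * model_parallel_world_size, which can also be
# written in closed form. The helper name is hypothetical.
def _pad_vocab_size(orig_vocab_size, make_vocab_size_divisible_by,
                    model_parallel_world_size):
    multiple = make_vocab_size_divisible_by * model_parallel_world_size
    if multiple == 0:
        return orig_vocab_size
    # Round up to the nearest multiple; equivalent to the incremental loop above.
    return ((orig_vocab_size + multiple - 1) // multiple) * multiple


# Example: 50257 tokens, divisible-by 128, model parallel size 2 -> 50432.
assert _pad_vocab_size(50257, 128, 2) == 50432
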
def test_boradcast_data(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    torch.manual_seed(1234 + mpu.get_data_parallel_rank())
    model_parallel_size = mpu.get_model_parallel_world_size()

    key_size_t = {
        'key1': [7, 11],
        'key2': [8, 2, 1],
        'key3': [13],
        'key4': [5, 1, 2],
        'key5': [5, 12]
    }
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if mpu.get_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

def __init__(self, embedding_dim, ffn_embedding_dim, num_attention_heads,
             device='cpu', checkpoint_gradients=False):
    nn.Module.__init__(self)
    self.model_parallel_size = mpu.get_model_parallel_world_size()
    self.checkpoint_gradients = checkpoint_gradients
    assert ffn_embedding_dim % self.model_parallel_size == 0
    # TODO: write a custom inplace LayerNorm layer
    self.attn_ln = nn.LayerNorm(embedding_dim).to(device)
    self.attn = ModelParallelMultiheadLMAttentionWithCache(embedding_dim,
                                                           num_attention_heads,
                                                           device=device)
    self.fc_ln = nn.LayerNorm(embedding_dim).to(device)
    self.fc1 = mpu.ColumnParallelLinear(embedding_dim, ffn_embedding_dim,
                                        gather_output=False, device=device)
    self.fc2 = mpu.RowParallelLinear(ffn_embedding_dim, embedding_dim,
                                     input_is_parallel=True, device=device)

def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.use_npy_data_loader:
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        else:
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data, test_data), tokenizer = data_config.apply(args)
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before, after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([
            after, eod_token,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token

def evaluate_tnews(args, model, dataloader, device, mode="dev"):
    model.eval()
    all_truth, all_preds = [], []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(dataloader,
                                          desc="Evaluating {}".format(mode),
                                          disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            output = torch.sum(output * no_model_batch["loss_mask"].unsqueeze(-1), 1) / torch.sum(
                no_model_batch["loss_mask"], -1).unsqueeze(-1)

            # gather the output logits from other gpus
            tensor_list = [torch.zeros_like(output)
                           for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list, output,
                                         mpu.get_data_parallel_group())

            # gather the truth labels from other gpus
            tensor_list_truth = [torch.zeros_like(no_model_batch["truth"], dtype=torch.long)
                                 for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather(tensor_list_truth, no_model_batch["truth"],
                                         mpu.get_data_parallel_group())

            if args.model_parallel_size == 1:
                scores = torch.stack(tensor_list, 0).view(-1, 30000)
            else:
                assert args.model_parallel_size == 2, "Now, we only support model parallel <= 2"
                # For convenience of implementation. Note that the truth labels only
                # appear in the first 15000 entries of the logits, e.g. on rank 0, 2, 4, ...
                scores = torch.stack(tensor_list, 0).view(-1, 15000)

            truth = torch.stack(tensor_list_truth, 0)
            truth = truth.view(-1)

            # scores = scores[:, cand_ids]
            preds = torch.argmax(scores, dim=-1)

            all_truth.extend(truth.detach().cpu().tolist())
            all_preds.extend(preds.detach().cpu().tolist())

    acc = sum([int(p == l) for p, l in zip(all_preds, all_truth)]) / len(all_truth)
    acc = torch.tensor(acc).to(device)

    acc_list = [torch.zeros_like(acc)
                for _ in range(mpu.get_model_parallel_world_size())]
    torch.distributed.all_gather(acc_list, acc, mpu.get_model_parallel_group())

    return acc_list[0].item(), all_truth, all_preds

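# A shape-only sketch (not from the original code) of what the all_gather +
# stack + view in evaluate_tnews produces: with a data-parallel world size D
# and a per-rank output of shape (b, V), gathering and stacking yields D
# tensors that flatten to (D * b, V) rows of logits, one row per example.
# D, b, and V below are hypothetical example values.
import torch

D, b, V = 2, 3, 5                                            # world size, per-rank batch, logit width
per_rank_outputs = [torch.randn(b, V) for _ in range(D)]     # stand-in for what all_gather fills in
gathered = torch.stack(per_rank_outputs, 0).view(-1, V)
assert gathered.shape == (D * b, V)
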
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        ds_type = 'BERT'
        data_config.set_defaults(data_set_type=ds_type, transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
        before = tokenizer.num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before, after))
        # Need to broadcast num_tokens and num_type_tokens.
        token_counts = torch.cuda.LongTensor([
            after, tokenizer.num_type_tokens,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test)
        ])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, num_type_tokens

def prepare_tokenizer(args):
    add_sentinel_token = 0
    if args.sentinel_token:
        add_sentinel_token = args.max_position_embeddings
    tokenizer = make_tokenizer(args.tokenizer_type, None, args.tokenizer_path,
                               args.vocab_size, args.tokenizer_model_type,
                               add_block_symbols=args.block_lm,
                               cache_dir=args.cache_dir,
                               add_sentinel_token=add_sentinel_token,
                               add_task_mask=args.task_mask,
                               add_decoder_mask=args.block_mask_prob > 0.0
                               or args.context_mask_ratio > 0.0)
    if mpu.get_model_parallel_rank() == 0:
        num_tokens = tokenizer.num_tokens
        eod_token = tokenizer.get_command('eos').Id
        assert eod_token == tokenizer.get_command('pad').Id
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(before, after - before, after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([after, eod_token])
    else:
        token_counts = torch.cuda.LongTensor([0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.vocab_size, args.eod_token = num_tokens, eod_token
    return tokenizer

def test_cross_entropy(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size,
                                           logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print(' max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print(' max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

def __init__(self, embed_dim, num_heads, bias=True, device='cpu'):
    nn.Module.__init__(self)
    self.embed_dim = embed_dim

    self.in_proj = mpu.ColumnParallelLinear(embed_dim, 3 * embed_dim,
                                            bias=bias, gather_output=False,
                                            device=device)
    self.out_proj = mpu.RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                          input_is_parallel=True,
                                          device=device)

    self.model_parallel_size = mpu.get_model_parallel_world_size()
    self.num_total_heads = num_heads
    self.num_heads = self.num_total_heads // self.model_parallel_size
    assert (
        self.num_heads * self.model_parallel_size == num_heads
    ), "Number of heads must be divisible by model parallel size"

    self.head_dim = embed_dim // num_heads
    assert (
        self.head_dim * num_heads == self.embed_dim
    ), "embed_dim must be divisible by num_heads"
    self.scaling = self.head_dim ** -0.5

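# A small sketch (not part of the original module) of the head partitioning
# checked in the constructor above: attention heads are split evenly across
# model-parallel ranks, and embed_dim must be divisible by the total head
# count. The helper name is hypothetical.
def _attention_partition(embed_dim, num_heads, model_parallel_size):
    assert num_heads % model_parallel_size == 0, \
        "Number of heads must be divisible by model parallel size"
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    heads_per_rank = num_heads // model_parallel_size
    head_dim = embed_dim // num_heads
    return heads_per_rank, head_dim


# Example: 1024-dim model, 16 heads, model parallel size 4 -> 4 heads of dim 64 per rank.
assert _attention_partition(1024, 16, 4) == (4, 64)
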
def test_initialize_model_parallel(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            model_parallel_size))
    model_parallel_size_ = min(model_parallel_size,
                               torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

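# A minimal sketch (not part of the test above) of the rank layout it checks:
# with a model parallel size of M, consecutive global ranks form one model
# parallel group, so the model parallel rank is rank % M and the data parallel
# rank is rank // M. The helper name is hypothetical.
def _rank_layout(global_rank, model_parallel_size):
    return global_rank % model_parallel_size, global_rank // model_parallel_size


# Example: 8 GPUs with model parallel size 2 -> ranks 6 and 7 share a model
# parallel group and both sit in data parallel rank 3.
assert _rank_layout(6, 2) == (0, 3)
assert _rank_layout(7, 2) == (1, 3)
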
def test_model_parallel_cuda_manual_seed(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 +
                                             mpu.get_model_parallel_rank())

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        transformer_layer, identity_layer

def test_initialize_affine_weight(model_parallel_size):

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    mpu.layers._initialize_affine_weight(weight, output_size, input_size,
                                         input_size_coeff, 1,
                                         torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def test_row_parallel_linear(model_parallel_size):

    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def test_parallel_embedding(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(
        0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(' error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(' error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   hidden_size // model_parallel_size,
                                   1)[mpu.get_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(' error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   vocab_size // model_parallel_size,
                                   0)[mpu.get_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print(' error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

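# A small sketch (not part of the test above) of the two partitioning schemes
# it compares: ParallelEmbedding splits the hidden dimension across ranks,
# while VocabParallelEmbedding splits the vocabulary dimension, so each rank
# holds a contiguous slice of the full (vocab_size, hidden_size) table. The
# sizes below mirror the hypothetical test configuration on rank 1 of 2.
import torch

vocab_size, hidden_size, world_size, rank = 48, 16, 2, 1
master = torch.randn(vocab_size, hidden_size)

hidden_shard = torch.split(master, hidden_size // world_size, dim=1)[rank]  # ParallelEmbedding slice
vocab_shard = torch.split(master, vocab_size // world_size, dim=0)[rank]    # VocabParallelEmbedding slice
assert hidden_shard.shape == (48, 8) and vocab_shard.shape == (24, 16)
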
def test_cuda_rng_tracker(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(
        result_11.sub(target_11).abs().max(),
        result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

def test_set_cuda_rng_state(model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.format(
            model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')

def verify_one_step(args):
    if args.verify == "save":
        assert dist.get_world_size() == 1
        assert mpu.get_pipeline_parallel_world_size() == 1
        assert mpu.get_model_parallel_world_size() == 1
        assert args.n_input_slices == 1
        assert args.n_batch_slices == 1
        os.makedirs(args.verify_path, exist_ok=True)
        config, layers, pipelined_layers = initialize_model(args)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            x = layers.create_inputs(config.batch_size, config.seq_len,
                                     random=True)
            torch.save(x, os.path.join(args.verify_path, 'input.pt'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank()
                    == mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        torch.save(pipelined_layers.state_dict(),
                   os.path.join(args.verify_path, 'model.ckpt'))
        grad_dic = OrderedDict(
            (x[0], x[1].grad) for x in pipelined_layers.named_parameters())
        torch.save(grad_dic, os.path.join(args.verify_path, 'model.grad.ckpt'))
    else:
        assert args.verify == "load"
        config, layers, pipelined_layers = initialize_model(args)
        with FileLock(os.path.join(args.verify_path, 'model.ckpt.lock')):
            loaded_state_dict = torch.load(
                os.path.join(args.verify_path, 'model.ckpt'),
                map_location=torch.device('cuda'))
        sliced_state_dict = slice_state_dict(config, loaded_state_dict)
        pipelined_layers.load_state_dict(sliced_state_dict)
        if mpu.get_pipeline_parallel_group_rank() == 0:
            with FileLock(os.path.join(args.verify_path, 'input.pt.lock')):
                x = torch.load(os.path.join(args.verify_path, 'input.pt'),
                               map_location=torch.device('cuda'))
        else:
            x = None
        try:
            y = pipelined_layers(x)
            if (mpu.get_pipeline_parallel_group_rank()
                    == mpu.get_pipeline_parallel_world_size() - 1):
                loss = loss_func(y)
                loss.backward()
            else:
                y.backward()
        except:
            print(f"rank={args.rank}", traceback.format_exc())
            raise
        grad_dic = OrderedDict(
            (x[0], x[1].grad) for x in pipelined_layers.named_parameters())
        with FileLock(os.path.join(args.verify_path, 'model.grad.ckpt.lock')):
            loaded_grad_dic = torch.load(
                os.path.join(args.verify_path, 'model.grad.ckpt'),
                map_location=torch.device('cuda'))
        sliced_grad_dic = slice_state_dict(config, loaded_grad_dic)
        assert grad_dic.keys() == sliced_grad_dic.keys()
        for k in grad_dic.keys():
            assert torch.allclose(grad_dic[k], sliced_grad_dic[k])

def get_eval_data(args):
    val_dataloader = None
    if mpu.get_model_parallel_rank() == 0:
        eval_batch_size = args.eval_batch_size
        eval_batch_size = args.batch_size if eval_batch_size is None else eval_batch_size
        seq_len = args.seq_length
        valid_data = args.valid_data
        valid_data = valid_data[0] if isinstance(valid_data, list) else valid_data

        tokenizer = get_tokenizer(args)

        if not args.cloze_eval:
            with open(valid_data, "rb") as reader:
                entire_data = reader.read().decode('utf-8')
            num_original_tokens = len(entire_data.strip().split(" "))
            entire_data = get_detokenizer(valid_data)(entire_data)
            tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization
            num_tokenized_tokens = len(tokenized_data)
            string = 'Original Tokens: %d, Detokenized tokens: %d' % (
                num_original_tokens, num_tokenized_tokens)
            print_rank_0(string)

            eod_token = tokenizer.get_command('pad').Id
            val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token,
                                          args.overlapping_eval)
        else:
            val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len)
            num_tokenized_tokens = 0
            num_original_tokens = 0

        val_dataloader = torch.utils.data.DataLoader(
            val_dataset, batch_size=eval_batch_size, drop_last=False)

        before = tokenizer.num_tokens
        after = before
        while after % mpu.get_model_parallel_world_size() != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'.
                     format(before, after - before, after))

        eod_token = tokenizer.get_command('pad').Id
        num_examples = len(val_dataset)
        token_counts = torch.cuda.LongTensor([after, eod_token, num_examples,
                                              num_original_tokens,
                                              num_tokenized_tokens])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    args.vocab_size = token_counts[0].item()
    args.eod_token = token_counts[1].item()
    args.num_examples = token_counts[2].item()
    args.num_original_tokens = token_counts[3].item()
    args.num_tokenized_tokens = token_counts[4].item()

    print('global rank: {} | vocab size: {} | eod token: {} | '
          'num_examples: {} | num_original_tokens: {} | '
          'num_tokenized_tokens: {}'.format(
              torch.distributed.get_rank(), args.vocab_size, args.eod_token,
              args.num_examples, args.num_original_tokens,
              args.num_tokenized_tokens))
    return val_dataloader