def __init__(self, embedding_dim, ffn_embedding_dim, num_attention_heads,
             device='cpu', checkpoint_gradients=False):
    nn.Module.__init__(self)
    self.model_parallel_size = mpu.get_model_parallel_world_size()
    self.checkpoint_gradients = checkpoint_gradients
    # The feed-forward hidden dimension is sharded across model-parallel ranks.
    assert ffn_embedding_dim % self.model_parallel_size == 0

    # TODO: write a custom inplace LayerNorm layer
    self.attn_ln = nn.LayerNorm(embedding_dim).to(device)
    self.attn = ModelParallelMultiheadLMAttentionWithCache(
        embedding_dim, num_attention_heads, device=device)
    self.fc_ln = nn.LayerNorm(embedding_dim).to(device)
    # fc1 is column-parallel and keeps its output sharded (gather_output=False);
    # fc2 is row-parallel and consumes the sharded activations directly
    # (input_is_parallel=True), all-reducing only the final output.
    self.fc1 = mpu.ColumnParallelLinear(
        embedding_dim, ffn_embedding_dim, gather_output=False, device=device)
    self.fc2 = mpu.RowParallelLinear(
        ffn_embedding_dim, embedding_dim, input_is_parallel=True, device=device)
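# --- Illustrative sketch (not from the original source) ---------------------
# One plausible forward() for the transformer layer built above, assuming
# pre-LayerNorm residual blocks and a GELU between fc1 and fc2, and assuming
# the mpu linear layers and the attention module return plain tensors.
# Cache handling and gradient checkpointing are omitted.
def _transformer_layer_forward_sketch(layer, x):
    # Self-attention block with residual connection.
    residual = x
    h = layer.attn_ln(x)
    h = layer.attn(h)  # hypothetical call: the real module may also take/return a cache
    x = residual + h

    # Feed-forward block: column-parallel fc1 -> GELU -> row-parallel fc2.
    residual = x
    h = layer.fc_ln(x)
    h = torch.nn.functional.gelu(layer.fc1(h))
    h = layer.fc2(h)
    return residual + h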
def __init__(self, embed_dim, num_heads, bias=True, device='cpu'):
    nn.Module.__init__(self)
    self.embed_dim = embed_dim
    # Fused q/k/v projection: column-parallel, so each rank keeps its own
    # shard of the 3 * embed_dim output features (gather_output=False).
    self.in_proj = mpu.ColumnParallelLinear(
        embed_dim, 3 * embed_dim, bias=bias, gather_output=False, device=device)
    # Output projection: row-parallel, consuming the sharded attention output
    # and all-reducing across model-parallel ranks.
    self.out_proj = mpu.RowParallelLinear(
        embed_dim, embed_dim, bias=bias, input_is_parallel=True, device=device)

    self.model_parallel_size = mpu.get_model_parallel_world_size()
    self.num_total_heads = num_heads
    # Attention heads are split evenly across model-parallel ranks.
    self.num_heads = self.num_total_heads // self.model_parallel_size
    assert (
        self.num_heads * self.model_parallel_size == num_heads
    ), "Number of heads must be divisible by model parallel size"

    self.head_dim = embed_dim // num_heads
    assert (
        self.head_dim * num_heads == self.embed_dim
    ), "embed_dim must be divisible by num_heads"
    self.scaling = self.head_dim ** -0.5
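# --- Illustrative sketch (not from the original source) ---------------------
# How the fused in_proj output is commonly split into per-rank q/k/v and
# scaled, assuming input of shape (seq_len, batch, embed_dim), that in_proj
# returns a plain tensor, and that its weight is laid out per partition as
# [q_local; k_local; v_local]. The module's real forward() is not shown here.
def _qkv_split_sketch(attn, x):
    seq_len, bsz, _ = x.shape
    # Each rank holds 3 * embed_dim / model_parallel_size fused output features.
    qkv = attn.in_proj(x)
    q, k, v = qkv.chunk(3, dim=-1)

    # Fold the local heads into the batch dimension for batched matmuls:
    # (seq_len, bsz, local_heads * head_dim) -> (bsz * local_heads, seq_len, head_dim)
    def _shape(t):
        return (t.contiguous()
                 .view(seq_len, bsz * attn.num_heads, attn.head_dim)
                 .transpose(0, 1))

    return _shape(q) * attn.scaling, _shape(k), _shape(v)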
def test_row_parallel_linear(model_parallel_size):
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size,
        keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print('   error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
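# --- Illustrative sketch (not from the original source) ---------------------
# One way such a test is typically driven: one process per GPU (e.g. launched
# with torchrun), an NCCL process group, then a sweep of the model-parallel
# size in powers of two up to the world size. The repository's actual test
# harness may differ.
if __name__ == '__main__':
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
    world_size = torch.distributed.get_world_size()
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_row_parallel_linear(model_parallel_size)
        model_parallel_size *= 2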