def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the
            ith query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = lbann.ChannelwiseFullyConnected(
        queries,
        weights=self.query_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_queries_fc',
    )
    keys_fc = lbann.ChannelwiseFullyConnected(
        keys,
        weights=self.key_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_keys_fc',
    )
    values_fc = lbann.ChannelwiseFullyConnected(
        values,
        weights=self.value_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_values_fc',
    )

    # Slice embedding vectors for each head
    slice_points = str_list(self.head_dim * i
                            for i in range(self.num_heads + 1))
    queries_slice = lbann.Slice(
        queries_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_queries_slice',
    )
    keys_slice = lbann.Slice(
        keys_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_keys_slice',
    )
    values_slice = lbann.Slice(
        values_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_values_slice',
    )

    # Compute scaled dot-product attention for each head
    attentions = []
    for head in range(self.num_heads):
        head_name = f'{name}_head{head}'

        # Attention inputs
        q = lbann.Identity(queries_slice)
        k = lbann.Identity(keys_slice)
        v = lbann.Identity(values_slice)

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q, k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )
        if mask:
            y = lbann.Add(y, mask, name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        attentions.append(lbann.MatMul(y, v, name=head_name))

    # Concatenate heads and apply fully-connected layer
    attentions = lbann.Concatenation(attentions,
                                     axis=1,
                                     name=f'{name}_heads_concat')
    outputs_fc = lbann.ChannelwiseFullyConnected(
        attentions,
        weights=self.output_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}',
    )
    return outputs_fc
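# The additive mask described in the docstring is easiest to see in plain
# NumPy: entries of 0 leave a score unchanged, while very negative entries
# drive the corresponding softmax weight to (approximately) zero. The sketch
# below is illustrative only and not part of the module above; it assumes a
# square num_queries x num_keys score matrix and shows a causal
# ("subsequent") mask built with that convention.
import numpy as np

def causal_additive_mask(seq_len, neg_value=-1e9):
    """Return a (seq_len, seq_len) additive mask.

    Entry (i, j) is 0 when query i may attend to key j (j <= i) and a very
    negative value otherwise, matching the convention in the docstring.
    """
    mask = np.zeros((seq_len, seq_len))
    mask[np.triu_indices(seq_len, k=1)] = neg_value
    return mask

# Example: adding the mask to raw scores before the softmax suppresses
# attention to future positions.
scores = np.random.randn(4, 4)
masked_scores = scores + causal_additive_mask(4)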
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
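# The label-smoothing branch above mixes the one-hot target with a uniform
# distribution via a WeightedSum: label = (1 - eps) * one_hot + eps / vocab_size.
# A minimal NumPy sketch of that arithmetic (illustrative only; the
# vocab_size and eps values here are placeholders, not taken from this script):
import numpy as np

def smoothed_label(token_id, vocab_size, eps):
    """Return a smoothed target distribution over the vocabulary."""
    one_hot = np.zeros(vocab_size)
    one_hot[token_id] = 1.0
    uniform = np.full(vocab_size, 1.0 / vocab_size)
    return (1 - eps) * one_hot + eps * uniform

# e.g. smoothed_label(3, vocab_size=8, eps=0.1) puts 0.9125 on token 3 and
# 0.0125 on every other token; the vector still sums to 1, so the
# cross entropy above remains well defined.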
def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the
            ith query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
    BRANCHES = self.BRANCHES
    if ENABLE_SUBGRAPH:
        if self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = lbann.ChannelwiseFullyConnected(
        queries,
        weights=self.query_weights,
        output_channel_dims=[self.inner_dim],
        name=f'{name}_queries_fc',
    )
    keys_fc = lbann.ChannelwiseFullyConnected(
        keys,
        weights=self.key_weights,
        output_channel_dims=[self.inner_dim],
        name=f'{name}_keys_fc',
    )
    values_fc = lbann.ChannelwiseFullyConnected(
        values,
        weights=self.value_weights,
        output_channel_dims=[self.inner_dim],
        name=f'{name}_values_fc',
    )

    # Slice embedding vectors for each head
    slice_points = str_list(self.head_dim * i
                            for i in range(self.num_heads + 1))
    queries_slice = lbann.Slice(queries_fc,
                                axis=1,
                                slice_points=slice_points,
                                name=f'{name}_queries_slice',
                                parallel_strategy={
                                    'sub_branch_tag': 0,
                                    'enable_subgraph': ENABLE_SUBGRAPH
                                })
    keys_slice = lbann.Slice(keys_fc,
                             axis=1,
                             slice_points=slice_points,
                             name=f'{name}_keys_slice',
                             parallel_strategy={
                                 'sub_branch_tag': 0,
                                 'enable_subgraph': ENABLE_SUBGRAPH
                             })
    values_slice = lbann.Slice(values_fc,
                               axis=1,
                               slice_points=slice_points,
                               name=f'{name}_values_slice',
                               parallel_strategy={
                                   'sub_branch_tag': 0,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })

    # Compute scaled dot-product attention for each head
    attentions = []
    tag = 0
    for head in range(self.num_heads):
        head_name = f'{name}_myattention_head{head}'

        # Attention inputs
        if ENABLE_SUBGRAPH:
            if head % int(self.num_heads / BRANCHES) == 0:
                tag += 1
            q = lbann.Identity(queries_slice,
                               parallel_strategy={
                                   'sub_branch_tag': tag,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })
            k = lbann.Identity(keys_slice,
                               parallel_strategy={
                                   'sub_branch_tag': tag,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })
            v = lbann.Identity(values_slice,
                               parallel_strategy={
                                   'sub_branch_tag': tag,
                                   'enable_subgraph': ENABLE_SUBGRAPH
                               })
        else:
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q, k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )

        if ENABLE_SUBGRAPH:
            if mask is not None:
                y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
        else:
            if mask:
                y = lbann.Sum([y, mask], name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        attentions.append(lbann.MatMul(y, v, name=head_name))

    # Strong scaling:
    # concatenate heads and apply fully-connected layer
    if ENABLE_SUBGRAPH:
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat',
                                         parallel_strategy={
                                             'sub_branch_tag': 0,
                                             'enable_subgraph': ENABLE_SUBGRAPH
                                         })
    else:
        attentions = lbann.Concatenation(
            attentions,
            axis=1,
            name=f'{name}_heads_concat',
        )
    outputs_fc = lbann.ChannelwiseFullyConnected(
        attentions,
        weights=self.output_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}',
    )
    return outputs_fc
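# In the subgraph-parallel path above, heads are assigned to sub-grids in
# contiguous blocks: the running `tag` increments every num_heads / BRANCHES
# heads. A minimal pure-Python sketch of that mapping (illustrative only, not
# part of the module):

def head_to_branch_tags(num_heads, branches):
    """Reproduce the sub_branch_tag sequence used in the loop above."""
    if num_heads % branches != 0:
        raise ValueError('Num heads should be divisible by BRANCHES')
    heads_per_branch = num_heads // branches
    tags = []
    tag = 0
    for head in range(num_heads):
        if head % heads_per_branch == 0:
            tag += 1
        tags.append(tag)
    return tags

# e.g. head_to_branch_tags(8, 4) == [1, 1, 2, 2, 3, 3, 4, 4], so each
# sub-grid (tag) owns two consecutive heads.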
def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
):
    mixed_query_layer, query_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "query.weight")),
            ".".join((self.name, "query.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "query")),
        return_dims=True,
    )
    query_layer, query_shape = self.transpose_for_scores(
        mixed_query_layer, query_shape)

    key_layer, key_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "key.weight")),
            ".".join((self.name, "key.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "key")),
        return_dims=True,
    )
    key_layer, key_shape = self.transpose_for_scores(key_layer, key_shape)

    value_layer, value_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "value.weight")),
            ".".join((self.name, "value.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "value")),
        return_dims=True,
    )
    value_layer, value_shape = self.transpose_for_scores(
        value_layer, value_shape)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    key_layer, key_shape = lbann.modules.Permute(key_layer,
                                                 key_shape,
                                                 axes=(0, 1, -1, -2),
                                                 return_dims=True)
    attention_scores, attention_shape = lbann.modules.PytorchMatmul(
        query_layer,
        query_shape,
        key_layer,
        key_shape,
        return_dims=True,
    )
    attention_scores = lbann.Scale(
        attention_scores,
        constant=1 / math.sqrt(self.attention_head_size))

    if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in the
        # RobertaModel forward() function)
        attention_scores = lbann.Add(attention_scores, attention_mask)

    # Normalize the attention scores to probabilities.
    attention_scores = lbann.Reshape(
        attention_scores,
        dims=str_list([np.prod(attention_shape[:-1]), attention_shape[-1]]),
    )
    attention_probs = lbann.ChannelwiseSoftmax(attention_scores)
    attention_probs = lbann.Reshape(attention_probs,
                                    dims=str_list(attention_shape))

    # This is actually dropping out entire tokens to attend to, which
    # might seem a bit unusual, but is taken from the original
    # Transformer paper.
    attention_probs = lbann.Dropout(
        attention_probs,
        keep_prob=self.attention_probs_dropout_prob,
    )

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = lbann.Multiply(attention_probs, head_mask)

    context_layer, context_shape = lbann.modules.PytorchMatmul(
        attention_probs,
        attention_shape,
        value_layer,
        value_shape,
        return_dims=True,
    )
    context_layer, context_shape = lbann.modules.Permute(
        context_layer,
        context_shape,
        axes=(0, 2, 1, 3),
        return_dims=True,
    )
    new_context_layer_shape = context_shape[:-2] + (self.all_head_size, )
    context_layer = lbann.Reshape(context_layer,
                                  dims=str_list(self.input_shape))

    return context_layer
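# `transpose_for_scores` is defined elsewhere in this module and is not shown
# here. The sketch below only illustrates the shape manipulation it is
# assumed to perform (mirroring the analogous Hugging Face helper): split the
# hidden dimension into heads and move the head axis ahead of the sequence
# axis. Names and shapes in this snippet are illustrative assumptions.
import numpy as np

def transpose_for_scores_sketch(x, num_heads, head_size):
    """(batch, seq, num_heads * head_size) -> (batch, num_heads, seq, head_size)."""
    batch, seq, _ = x.shape
    x = x.reshape(batch, seq, num_heads, head_size)
    return x.transpose(0, 2, 1, 3)

# e.g. a (2, 16, 768) activation with 12 heads of size 64 becomes
# (2, 12, 16, 64), so the query/key matmul above yields per-head
# (2, 12, 16, 16) attention scores.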
def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the
            ith query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
    BRANCHES = self.BRANCHES
    if ENABLE_SUBGRAPH:
        if self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = []
    keys_fc = []
    values_fc = []

    # Slice embedding vectors for each head
    slice_points = str_list(
        self.head_dim * i
        for i in range(int(self.num_heads / self.BRANCHES) + 1))

    # Queries strong scaling in CFC
    attentions = []
    for count, query in enumerate(queries):
        temp = lbann.ChannelwiseFullyConnected(
            query,
            weights=self.query_weights[count],
            output_channel_dims=[self.inner_dim],
            name=f'{name}_subgrid{count}_queries_fc',
        )
        attentions.append(temp)

    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    for head in range(self.BRANCHES):
        temp = lbann.Slice(
            attentions[head],
            axis=1,
            slice_points=slice_points,
            name=f'{name}_subgrid{head}_queries_slice',
        )
        queries_fc.append(temp)

    # Keys strong scaling in CFC
    attentions = []
    for count, key in enumerate(keys):
        temp = lbann.ChannelwiseFullyConnected(
            key,
            weights=self.key_weights[count],
            output_channel_dims=[self.inner_dim],
            name=f'{name}_subgrid{count}_keys_fc',
        )
        attentions.append(temp)

    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    for head in range(self.BRANCHES):
        temp = lbann.Slice(
            attentions[head],
            axis=1,
            slice_points=slice_points,
            name=f'{name}_subgrid{head}_keys_slice',
        )
        keys_fc.append(temp)

    # Values strong scaling in CFC
    attentions = []
    for count, value in enumerate(values):
        temp = lbann.ChannelwiseFullyConnected(
            value,
            weights=self.value_weights[count],
            output_channel_dims=[self.inner_dim],
            name=f'{name}_subgrid{count}_values_fc',
        )
        attentions.append(temp)

    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    for head in range(self.BRANCHES):
        temp = lbann.Slice(
            attentions[head],
            axis=1,
            slice_points=slice_points,
            name=f'{name}_subgrid{head}_values_slice',
        )
        values_fc.append(temp)

    queries_slice = []
    keys_slice = []
    values_slice = []

    for branch in range(self.BRANCHES):
        query_slice = queries_fc[branch]
        key_slice = keys_fc[branch]
        value_slice = values_fc[branch]
        for head in range(int(self.num_heads / self.BRANCHES)):
            queries_slice.append(lbann.Identity(query_slice))
            keys_slice.append(lbann.Identity(key_slice))
            values_slice.append(lbann.Identity(value_slice))

    # Compute scaled dot-product attention for each head
    attentions = []

    # Variable to combine heads locally in sub-grids
    temp_attentions = []
    tag = 0
    for head in range(self.num_heads):
        head_name = f'{name}_myattention_head{head}'

        # Attention inputs
        if head % int(self.num_heads / BRANCHES) == 0:
            temp_attentions.append([])
            tag += 1

        q = lbann.Identity(queries_slice[head])
        k = lbann.Identity(keys_slice[head])
        v = lbann.Identity(values_slice[head])

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q, k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )

        if ENABLE_SUBGRAPH:
            if mask is not None:
                y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
        else:
            if mask:
                y = lbann.Sum([y, mask], name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        y = lbann.MatMul(y, v, name=head_name)
        # attentions.append(lbann.MatMul(y, v, name=head_name))
        temp_attentions[-1].append(y)

    for count, temp_attention in enumerate(temp_attentions):

        if self.BRANCHES == self.num_heads:
            # No need to concat the heads at subgrid level
            # if number of subgrids is equal to number of heads
            attention_single_subgrid = temp_attentions[count][0]
        else:
            attention_single_subgrid = lbann.Concatenation(
                temp_attention,
                axis=1,
                name=f'{name}_subgrid_heads_concat{count}',
                parallel_strategy={
                    'sub_branch_tag': 0,
                    'enable_subgraph': False
                })

        attention_single_subgrid = lbann.ChannelwiseFullyConnected(
            attention_single_subgrid,
            weights=self.output_weights[count],
            output_channel_dims=[self.embed_dim],
            name=f'{name}_cfc_{count}',
        )

        attentions.append(attention_single_subgrid)

    # Strong scaling
    grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)
    attentions = []
    for head in range(self.BRANCHES):
        attentions.append(lbann.Identity(grid_sum_slice))

    return attentions
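# In this strong-scaling variant each sub-grid only holds
# num_heads / BRANCHES heads, so the slice points computed above cover a
# fraction of the inner dimension compared with the single-grid forward().
# A minimal sketch of the two slicing schemes (illustrative only):

def single_grid_slice_points(head_dim, num_heads):
    # Slice points used by the non-distributed forward(): one segment per head.
    return [head_dim * i for i in range(num_heads + 1)]

def per_subgrid_slice_points(head_dim, num_heads, branches):
    # Slice points used above: one segment per head owned by each sub-grid.
    return [head_dim * i for i in range(num_heads // branches + 1)]

# e.g. with head_dim=64, num_heads=8, BRANCHES=4:
#   single grid : [0, 64, 128, ..., 512]
#   per sub-grid: [0, 64, 128]   (each sub-grid slices its 2 local heads)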