Example no. 1
    def forward(self, x, mask=None):
        """Apply Transformer encoder layer.

        Args:
            x (lbann.Layer): Sequence of input vectors.
            mask (lbann.Layer, optional): Attention mask.

        Returns:
            lbann.Layer: Sequence of output vectors.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Self-attention with residual connection
        y = self.attention(x, x, x, mask=mask)
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop1',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum1')
        z = lbann.InstanceNorm(z, name=f'{name}_norm1')
        x = z

        # Feedforward network with residual connection
        y = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.fc1_weights,
            output_channel_dims=[self.feedforward_dim],
            name=f'{name}_fc1',
        )
        y = lbann.Relu(y, name=f'{name}_relu1')
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop2',
            )
        y = lbann.ChannelwiseFullyConnected(
            y,
            weights=self.fc2_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_fc2',
        )
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop3',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum2')
        z = lbann.InstanceNorm(z, name=f'{name}_norm2')
        return z
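The layer above follows a post-norm residual pattern: each sublayer's output is added back to its input and the sum is normalized. A minimal NumPy sketch of that control flow (dropout omitted; `attention`, `w1`, and `w2` are placeholder stand-ins, not LBANN objects):

import numpy as np

def layer_norm(x, eps=1e-5):
    # Per-vector normalization, a simplified stand-in for lbann.InstanceNorm
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

def encoder_layer(x, attention, w1, w2):
    # Self-attention sublayer with residual connection and normalization
    x = layer_norm(x + attention(x, x, x))
    # Position-wise feedforward sublayer (fc1 -> ReLU -> fc2) with residual
    y = np.maximum(x @ w1, 0) @ w2
    return layer_norm(x + y)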
Example no. 2
def PytorchLinear(x,
                  input_shape,
                  hidden_size,
                  weights=[],
                  name="",
                  return_dims=False):
    need_reshape = len(input_shape) > 2
    if need_reshape:
        new_in_shape = (np.prod(input_shape[:-1]), input_shape[-1])
        x = lbann.Reshape(x, dims=str_list(new_in_shape))

    if len(input_shape) == 1:
        y = lbann.FullyConnected(x,
                                 num_neurons=hidden_size,
                                 weights=weights,
                                 name=name)
    else:
        y = lbann.ChannelwiseFullyConnected(x,
                                            output_channel_dims=[hidden_size],
                                            weights=weights,
                                            name=name)

    if need_reshape:
        new_out_shape = input_shape[:-1] + (hidden_size, )
        y = lbann.Reshape(y, dims=str_list(new_out_shape))
    else:
        new_out_shape = (input_shape[0], hidden_size)

    if return_dims:
        return y, new_out_shape
    return y
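`PytorchLinear` is mostly shape bookkeeping: leading dimensions are collapsed so the linear map acts on the last axis, and the original leading dimensions are restored afterwards. A rough NumPy equivalent of that bookkeeping (`weight` is a hypothetical (in_dim, hidden_size) matrix, not an LBANN weights object):

import numpy as np

def linear_like(x, weight, input_shape, hidden_size):
    if len(input_shape) > 2:
        # Collapse leading dimensions so the matmul acts on the last axis
        x = x.reshape(int(np.prod(input_shape[:-1])), input_shape[-1])
    y = x @ weight                      # per-position linear map
    if len(input_shape) > 2:
        # Restore the original leading dimensions
        y = y.reshape(*input_shape[:-1], hidden_size)
    return y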
Example no. 3
    def encode(self, x):
        # Flatten input into channels of length `input_dim`
        x = lbann.Reshape(x, dims=utils.str_list([-1, self.input_dim]))
        for i, dim in enumerate(self.hidden_dims):
            x = lbann.ChannelwiseFullyConnected(
                x,
                weights=self.weights[i],
                output_channel_dims=dim,
                bias=False,
            )
            x = lbann.Relu(x)
        x = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.weights[-1],
            output_channel_dims=self.output_dim,
            bias=False,
        )
        return x
Example no. 4
    def decode(self, x):
        # Walk the encoder weights in reverse, applying each with
        # transpose=True (tied-weight decoder)
        x = lbann.Reshape(x, dims=utils.str_list([-1, self.output_dim]))
        for i in range(len(self.hidden_dims)):
            x = lbann.ChannelwiseFullyConnected(
                x,
                weights=self.weights[-i - 1],
                output_channel_dims=self.hidden_dims[-i - 1],
                transpose=True,
                bias=False,
            )
            x = lbann.Relu(x)
        x = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.weights[0],
            output_channel_dims=self.input_dim,
            transpose=True,
            bias=False,
        )
        return x
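Examples no. 3 and no. 4 together form a tied-weight autoencoder: the decoder reuses each encoder weight matrix in reverse order with transpose=True. A NumPy sketch of that weight-sharing pattern (assuming transpose=True applies the transposed weight matrix):

import numpy as np

def encode(x, weights):
    for W in weights[:-1]:
        x = np.maximum(x @ W, 0)        # hidden layers with ReLU
    return x @ weights[-1]              # final projection, no activation

def decode(z, weights):
    for W in reversed(weights[1:]):
        z = np.maximum(z @ W.T, 0)      # same weights, transposed, reverse order
    return z @ weights[0].T             # back to the input dimension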
Example no. 5
File: vae.py, Project: oyamay/lbann
    def forward_decoder(self, x_emb, z):
        """Decoder step, emulating x ~ G(z)

        :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
        :param z: (n_batch, d_z) of floats, latent vector z
        :return: float, recon component of loss
        :return: list of ints, reconstructed sentence
        """

        # z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
        # x_input = torch.cat([x_emb, z_0], dim=-1)
        z_0 = lbann.Tessellate(
            lbann.Reshape(z, dims=str_list([1, 128])),
            dims=str_list([self.input_feature_dims, 128]),
        )
        x_input = lbann.Concatenation(x_emb, z_0, axis=1)

        h_0 = self.decoder_lat(z)
        # h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
        h_0 = lbann.Reshape(h_0, dims=str_list([1, 512]))
        h_0 = lbann.Tessellate(h_0, dims=str_list((3, 512)))

        # output, _ = self.decoder_rnn(x_input, h_0)
        output = self.decoder_rnn(x_input, h_0)

        # y = self.decoder_fc(output)
        y = lbann.ChannelwiseFullyConnected(
            output,
            output_channel_dims=self.dictionary_size,
            bias=True,
            name=f'{self.decoder_fc.name}',
            weights=self.decoder_fc.weights,
        )

        # Set datatype of layers
        # Note: Depth-first search from y to x_emb and z
        stack = [y]
        in_stack = {l: True for l in stack}
        while stack:
            l = stack.pop()
            if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
                l.datatype = self.datatype
            for parent in l.parents:
                if parent not in in_stack and parent not in (x_emb, z):
                    stack.append(parent)
                    in_stack[parent] = True

        return y
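The Tessellate/Concatenation pair above mirrors the commented-out PyTorch unsqueeze/repeat/cat: the latent vector is replicated at every sequence position and appended to each token embedding before the decoder RNN. A NumPy sketch for a single sample:

import numpy as np

def broadcast_latent(x_emb, z):
    # x_emb: (seq_len, d_emb), z: (d_z,)
    seq_len = x_emb.shape[0]
    z_0 = np.tile(z[None, :], (seq_len, 1))       # repeat z at every position
    return np.concatenate([x_emb, z_0], axis=-1)  # (seq_len, d_emb + d_z)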
Example no. 6
    def forward(self, x):
        self.instance += 1
        name = '{0}_instance{1}'.format(self.name, self.instance)
        y = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.weights,
            name=(name + '_fc' if self.activation else name),
            data_layout=self.data_layout,
            output_channel_dims=self.size,
            bias=self.bias,
            transpose=self.transpose,
            parallel_strategy=self.parallel_strategy)
        if self.activation:
            return self.activation(y,
                                   name=name + '_activation',
                                   data_layout=self.data_layout,
                                   parallel_strategy=self.parallel_strategy)
        else:
            return y
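A hypothetical usage sketch for the module whose forward() is shown above, assuming it is exposed as lbann.modules.ChannelwiseFullyConnectedModule and that its constructor takes the output size first with optional bias/activation arguments; the reshape dimensions are arbitrary placeholders:

import lbann
import lbann.modules

x = lbann.Input(data_field='samples')
x = lbann.Reshape(x, dims='16 64')  # placeholder: 16 channels of 64 values
fc = lbann.modules.ChannelwiseFullyConnectedModule(128,
                                                   bias=True,
                                                   activation=lbann.Relu)
y = fc(x)  # invokes the forward() above, adding '_fc' and '_activation' layers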
Example no. 7
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(
            queries_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_queries_slice',
        )
        keys_slice = lbann.Slice(
            keys_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_keys_slice',
        )
        values_slice = lbann.Slice(
            values_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_values_slice',
        )

        # Compute scaled dot-product attention for each head
        attentions = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Attention inputs
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )
            if mask:
                y = lbann.Add(y, mask, name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat')
        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
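Each head above computes scaled dot-product attention; the additive mask described in the docstring works because a score near -1e9 contributes essentially zero after the softmax. A NumPy sketch of one head:

import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: (num_queries, head_dim), k and v: (num_keys, head_dim)
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask          # very negative entries vanish after softmax
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                  # (num_queries, head_dim)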
Example no. 8
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
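The per-token loss above targets a smoothed label: a weighted sum of the one-hot label and a uniform distribution over the vocabulary. A NumPy sketch of that target:

import numpy as np

def smoothed_label(token_id, vocab_size, label_smoothing):
    one_hot = np.zeros(vocab_size)
    one_hot[token_id] = 1.0
    uniform = np.full(vocab_size, 1.0 / vocab_size)
    # Matches the lbann.WeightedSum of the one-hot and uniform labels above
    return (1 - label_smoothing) * one_hot + label_smoothing * uniform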
Example no. 9
# ----------------------------------------------
with open("./config.json") as f:
    config = json.load(f, object_hook=lambda d: SimpleNamespace(**d))
config.input_shape = (16, 32)
config.load_weights = os.path.exists('./pretrained_weights')

# Construct the model
input_ = lbann.Slice(
    lbann.Input(data_field="samples"),
    slice_points=str_list([0, 1, 1 + np.prod(config.input_shape)]),
)
labels = lbann.Identity(input_)
sample = lbann.Reshape(input_, dims=str_list(config.input_shape))
roberta = RobertaModel(config, load_weights=config.load_weights)
out = roberta(sample)
out = lbann.ChannelwiseFullyConnected(out, output_channel_dims=[1000])
loss = CrossEntropyLoss(10, data_layout="model_parallel")
obj = loss(out, labels)
metrics = [lbann.Metric(obj, name="loss")]

model = lbann.Model(
    lbann_params.epochs,
    layers=lbann.traverse_layer_graph(input_),
    objective_function=obj,
    metrics=metrics,
    callbacks=[
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
    ],
)
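The object_hook in the config load above turns every JSON object into a SimpleNamespace, so settings are read as attributes rather than dictionary keys. A self-contained sketch of the same pattern (the keys here are hypothetical placeholders, not the real config.json contents):

import json
from types import SimpleNamespace

raw = '{"hidden_size": 768, "num_attention_heads": 12}'
config = json.loads(raw, object_hook=lambda d: SimpleNamespace(**d))
print(config.hidden_size, config.num_attention_heads)  # attribute access instead of config['key']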
Example no. 10
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(queries_fc,
                                    axis=1,
                                    slice_points=slice_points,
                                    name=f'{name}_queries_slice',
                                    parallel_strategy={
                                        'sub_branch_tag': 0,
                                        'enable_subgraph': ENABLE_SUBGRAPH
                                    })
        keys_slice = lbann.Slice(keys_fc,
                                 axis=1,
                                 slice_points=slice_points,
                                 name=f'{name}_keys_slice',
                                 parallel_strategy={
                                     'sub_branch_tag': 0,
                                     'enable_subgraph': ENABLE_SUBGRAPH
                                 })
        values_slice = lbann.Slice(values_fc,
                                   axis=1,
                                   slice_points=slice_points,
                                   name=f'{name}_values_slice',
                                   parallel_strategy={
                                       'sub_branch_tag': 0,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })

        # Compute scaled dot-product attention for each head
        attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs

            if (ENABLE_SUBGRAPH):
                if (head % int(self.num_heads / BRANCHES) == 0):
                    tag += 1

                q = lbann.Identity(queries_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                k = lbann.Identity(keys_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                v = lbann.Identity(values_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
            else:
                q = lbann.Identity(queries_slice)
                k = lbann.Identity(keys_slice)
                v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if (ENABLE_SUBGRAPH):
                if mask != None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim

            attentions.append(lbann.MatMul(y, v, name=head_name))

            #Strong scaling

        # Concatenate heads and apply fully-connected layer
        if (ENABLE_SUBGRAPH):
            attentions = lbann.Concatenation(attentions,
                                             axis=1,
                                             name=f'{name}_heads_concat',
                                             parallel_strategy={
                                                 'sub_branch_tag': 0,
                                                 'enable_subgraph':
                                                 ENABLE_SUBGRAPH
                                             })
        else:
            attentions = lbann.Concatenation(
                attentions,
                axis=1,
                name=f'{name}_heads_concat',
            )

        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
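The tag bookkeeping above assigns heads to sub-branches in contiguous blocks: with num_heads heads and BRANCHES subgrids, every block of num_heads // BRANCHES consecutive heads shares one sub-branch tag. A plain-Python sketch of that assignment:

def head_tags(num_heads, branches):
    tags = []
    tag = 0
    for head in range(num_heads):
        if head % (num_heads // branches) == 0:
            tag += 1                    # a new contiguous block starts here
        tags.append(tag)
    return tags

# e.g. head_tags(8, 2) -> [1, 1, 1, 1, 2, 2, 2, 2]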
Example no. 11
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if (ENABLE_SUBGRAPH):
            if (self.num_heads % BRANCHES != 0):
                raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = []
        keys_fc = []
        values_fc = []

        # Slice embedding vectors for each head
        slice_points = str_list(
            self.head_dim * i
            for i in range(int(self.num_heads / self.BRANCHES) + 1))

        #Queries strong scaling in CFC
        attentions = []
        for count, query in enumerate(queries):
            temp = lbann.ChannelwiseFullyConnected(
                query,
                weights=self.query_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_queries_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_queries_slice',
            )

            queries_fc.append(temp)

        #keys strong scaling in CFC

        attentions = []
        for count, key in enumerate(keys):
            temp = lbann.ChannelwiseFullyConnected(
                key,
                weights=self.key_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_keys_fc',
            )

            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):

            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_keys_slice',
            )

            keys_fc.append(temp)

        #Values strong scaling in CFC
        attentions = []

        for count, value in enumerate(values):
            temp = lbann.ChannelwiseFullyConnected(
                value,
                weights=self.value_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_values_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_values_slice',
            )
            values_fc.append(temp)

        queries_slice = []
        keys_slice = []
        values_slice = []

        for branch in range(self.BRANCHES):
            querie_slice = queries_fc[branch]
            key_slice = keys_fc[branch]
            value_slice = values_fc[branch]

            for head in range(int(self.num_heads / self.BRANCHES)):
                queries_slice.append(lbann.Identity(querie_slice))
                keys_slice.append(lbann.Identity(key_slice))
                values_slice.append(lbann.Identity(value_slice))

        # Compute scaled dot-product attention for each head
        attentions = []

        #variable to combine heads locally in sub-grids
        temp_attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            if (head % int(self.num_heads / BRANCHES) == 0):
                temp_attentions.append([])
                tag += 1

            q = lbann.Identity(queries_slice[head])
            k = lbann.Identity(keys_slice[head])
            v = lbann.Identity(values_slice[head])

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if (ENABLE_SUBGRAPH):
                if mask != None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            y = lbann.MatMul(y, v, name=head_name)
            # attentions.append(lbann.MatMul(y, v, name=head_name))

            temp_attentions[-1].append(y)

        for count, temp_attention in enumerate(temp_attentions):

            if (self.BRANCHES == self.num_heads):
                # No need to concat the heads at subgrid level
                # if number of subgrids is equal to number of heads
                attention_single_subgrid = temp_attentions[count][0]
            else:
                attention_single_subgrid = lbann.Concatenation(
                    temp_attention,
                    axis=1,
                    name=f'{name}_subgrid_heads_concat{count}',
                    parallel_strategy={
                        'sub_branch_tag': 0,
                        'enable_subgraph': False
                    })

            attention_single_subgrid = lbann.ChannelwiseFullyConnected(
                attention_single_subgrid,
                weights=self.output_weights[count],
                output_channel_dims=[self.embed_dim],
                name=f'{name}_cfc_{count}',
            )

            attentions.append(attention_single_subgrid)

        #Strong scaling

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        return attentions
Example no. 12
    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Apply Transformer decoder layer.

        Args:
            x (lbann.Layer): Sequence of input vectors.
            memory (lbann.Layer): Sequence of vectors produced by
                Transformer encoder stack.
            src_mask (lbann.Layer, optional): Attention mask for
                second attention module (attends to both `x` and
                `memory`).
            tgt_mask (lbann.Layer, optional): Attention mask for first
                attention module (attends only to `x`).

        Returns:
            lbann.Layer: Sequence of output vectors.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Self-attention with residual connection
        y = self.attention1(x, x, x, mask=tgt_mask)
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop1',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum1')
        z = lbann.InstanceNorm(z, name=f'{name}_norm1')
        x = z

        # Attention on encoder output with residual connection
        y = self.attention2(x, memory, memory, mask=src_mask)
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop2',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum2')
        z = lbann.InstanceNorm(z, name=f'{name}_norm2')
        x = z

        # Feedforward network with residual connection
        y = lbann.ChannelwiseFullyConnected(
            x,
            weights=self.fc1_weights,
            output_channel_dims=[self.feedforward_dim],
            name=f'{name}_fc1',
        )
        y = lbann.Relu(y, name=f'{name}_relu1')
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop3',
            )
        y = lbann.ChannelwiseFullyConnected(
            y,
            weights=self.fc2_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_fc2',
        )
        if self.dropout_prob > 0:
            y = lbann.Dropout(
                y,
                keep_prob=1 - self.dropout_prob,
                name=f'{name}_drop4',
            )
        z = lbann.Sum(x, y, name=f'{name}_sum3')
        z = lbann.InstanceNorm(z, name=f'{name}_norm3')
        return z