Example #1
    def forward(self, x):
        """Perform LSTM step.

        State from previous steps is used to compute output.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Apply linearity
        input_concat = lbann.Concatenation([x, self.last_output],
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update
        slice = lbann.Slice(fc,
                            slice_points=_str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        slice = lbann.Slice(sigmoid,
                            slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state
        cell_forget = lbann.Multiply([f, self.last_cell],
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply([i, cell_update],
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add([cell_forget, cell_input], name=name + '_cell',
                         data_layout=self.data_layout)

        # Output
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply([o, cell_act], name=name,
                                data_layout=self.data_layout)

        # Update state and return output
        self.last_cell = cell
        self.last_output = output
        return output
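
A minimal usage sketch for the stateful variant above (it keeps its own last_output and last_cell); `lstm` is assumed to be an instance of this module and `xs` a hypothetical list of per-step input layers, e.g. Identity views of an lbann.Slice over the sequence axis:

    outputs = []
    for x_t in xs:
        outputs.append(lstm(x_t))               # Module.__call__ is assumed to dispatch to forward(x)
    sequence = lbann.Concatenation(outputs)     # stack per-step outputs back into one tensor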
Example #2
def Gelu_approx(x):
    # This approximates gelu and may be more performant
    # return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    # Based on: https://paperswithcode.com/method/gelu
    sqrt_2_over_pi = math.sqrt(2 / math.pi)
    b_coef = 0.044715
    x_cubed = lbann.Multiply(lbann.Multiply(lbann.Identity(x), x), x)
    inner_tanh_x_comp = lbann.Add(x, lbann.Scale(x_cubed, constant=b_coef))
    tanh_x = lbann.Tanh(lbann.Scale(inner_tanh_x_comp,
                                    constant=sqrt_2_over_pi))
    return lbann.Scale(lbann.Multiply(x, lbann.AddConstant(tanh_x,
                                                           constant=1)),
                       constant=0.5)
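
A minimal, hedged usage sketch for the approximation above; the input layer and the hidden size are hypothetical, and Gelu_approx itself needs `import math` and `import lbann` in scope:

    x = lbann.Input(data_field='samples')
    h = lbann.FullyConnected(x, num_neurons=512, has_bias=True)   # hypothetical hidden layer
    h = Gelu_approx(h)                                            # element-wise tanh-based GELU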
Example #3
    def forward(self, x, label):
        """Compute cross-entropy loss.

        Args:
          x (lbann.Layer): Input vector.
          label (lbann.Layer): Label. Should have one entry, which
            will be cast to an integer.

        Returns:
          lbann.Layer: Loss function value.

        """
        log_probs = self.fc(x)
        label_onehot = lbann.OneHot(
            label,
            size=self.num_classes,
            data_layout=self.data_layout,
        )
        loss = lbann.Multiply(
            log_probs,
            label_onehot,
            data_layout=self.data_layout,
        )
        loss = lbann.Reduction(
            loss,
            mode="sum",
            data_layout=self.data_layout,
        )
        loss = lbann.Negative(loss, data_layout=self.data_layout)
        return loss
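
A hedged usage sketch for the loss above; the class name ClassifierLoss, its constructor arguments, and the data fields are hypothetical stand-ins for however the enclosing module is actually built:

    images = lbann.Input(data_field='samples')
    labels = lbann.Input(data_field='labels')
    classifier = ClassifierLoss(num_classes=10)              # hypothetical constructor
    loss = classifier(images, labels)                        # Module.__call__ dispatches to forward(x, label)
    obj = lbann.ObjectiveFunction(lbann.LayerTerm(loss))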
Example #4
File: vae.py  Project: szaman19/lbann
    def forward_encoder(self, x_emb):
        """Encoder step, emulating z ~ E(x) = q_E(z|x)

        :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
        :return: (n_batch, d_z) of floats, sample of latent vector z
        :return: float, kl term component of loss
        """

        # _, h = self.encoder_rnn(x, None)
        h = self.encoder_rnn(x_emb, None)

        h = lbann.Slice(
            h,
            slice_points=str_list(
                [self.input_feature_dims - 1, self.input_feature_dims]),
            axis=0,
        )
        h = lbann.Identity(h)

        mu, logvar = self.q_mu(h), self.q_logvar(h)

        # Set datatype of previous layers
        # Note: Depth-first search from mu and logvar to x_emb
        stack = [mu, logvar]
        in_stack = {l: True for l in stack}
        while stack:
            l = stack.pop()
            if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
                l.datatype = self.datatype
            for parent in l.parents:
                if parent not in in_stack and parent is not x_emb:
                    stack.append(parent)
                    in_stack[parent] = True

        # eps = torch.randn_like(mu)
        eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

        # z = mu + (logvar / 2).exp() * eps
        z = lbann.Add([
            mu,
            (lbann.Multiply([
                lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
                eps
            ]))
        ])

        # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
        kl_loss = lbann.Reduction(
            lbann.WeightedSum(
                lbann.Exp(logvar),
                lbann.Square(mu),
                self.constant(1, hint_layer=mu),
                logvar,
                scaling_factors='0.5 0.5 -0.5 -0.5',
            ),
            mode='sum',
        )

        return z, kl_loss
Example #5
 def forward(self, inputs):
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]  # Assumed to be Boolean
     masked_pred = lbann.Multiply([pred, label])
     pred_sum = lbann.Reduction(masked_pred)
     return lbann.Negative(lbann.Log(pred_sum))
Example #6
 def forward(self, inputs):
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]
     ones = lbann.Constant(hint_layer=pred, value=1.0)
     term1 = lbann.Multiply(
         [label, lbann.Log(lbann.Subtract([ones, pred]))])
     term2 = lbann.Log(pred)
     full = lbann.WeightedSum([term1, term2], scaling_factors='-1.0 -1.0')
     return lbann.Reduction(full)
Example #7
 def forward(self, inputs):
     raise NotImplementedError  # Requires log-gamma function
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]
     ones = lbann.Constant(hint_layer=pred, value=1.0)
     term1 = pred
     term2 = lbann.Multiply([label, lbann.Log(pred)])
     term3 = lbann.LogGamma(lbann.Add([label, ones]))
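     # Hedged note: with the scaling factors below, this evaluates to
     # pred - label*log(pred) + log(Gamma(label + 1)), i.e. the Poisson
     # negative log-likelihood of `label` under rate `pred`.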
     full = lbann.WeightedSum([term1, term2, term3],
                              scaling_factors='1.0 -1.0 1.0')
     return lbann.Reduction(full)
Example #8
    def forward(
        self,
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    ):

        # Apply generator
        fake_motif_indices, gen_prob, gen_log_prob = self.generator(
            walk_length,
            walk_indices,
            motif_size,
        )

        # Get discriminator embeddings in log-space
        all_motif_indices = lbann.Concatenation(motif_indices,
                                                fake_motif_indices)
        all_motif_log_embeddings = self.discriminator.get_log_embeddings(
            all_motif_indices)
        all_motif_log_embeddings = lbann.Slice(
            all_motif_log_embeddings,
            slice_points=str_list([0, motif_size, 2 * motif_size]),
        )
        real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)
        fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)

        # Apply discriminator
        real_disc_prob, real_disc_log_not_prob \
            = self.discriminator(motif_size, real_motif_log_embeddings)
        fake_disc_prob, fake_disc_log_not_prob \
            = self.discriminator(motif_size, fake_motif_log_embeddings)

        # Loss function
        # L_disc = - log(D(real)) - log(1-D(fake))
        # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
        real_disc_log_prob \
            = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1))
        disc_loss = lbann.WeightedSum(
            real_disc_log_prob,
            fake_disc_log_not_prob,
            scaling_factors=str_list([-1, -1]),
        )
        gen_loss = lbann.Multiply(
            gen_log_prob,
            lbann.StopGradient(fake_disc_log_not_prob),
        )
        loss = lbann.Add(disc_loss, gen_loss)

        return loss, real_disc_prob, fake_disc_prob, gen_prob
Example #9
    def forward(self, image, dims, max_r):
        """Compute radial profile.

        Args:
            image (lbann.Layer): Image
            dims (tuple of int): Image dimensions (dim 0 corresponds
                to channel)
            max_r (int): Maximum radial distance. Pixels outside this
                distance are ignored.

        Returns:
            Layer: num_channels x max_r radial profile

        """

        # Bin spatial positions
        r, r_counts = self._find_radial_bins(dims[1:], max_r)

        # Reciprocal of bin counts
        # Note: If a count is 0, its reciprocal is 0.
        r_counts_recip = [0 if c == 0 else 1 / c for c in r_counts]

        # Get scatter indices and scaling factors
        # Note: Independent binning for each channel (dim 0)
        tile_dims = [dims[0]] + [1] * r.ndim
        inds_vals = np.tile(r, tile_dims)
        inds_vals += np.arange(0, dims[0] * max_r, max_r).reshape(tile_dims)
        inds_vals[:, r >= max_r] = -1
        inds_vals = inds_vals.flatten()
        scales_vals = r_counts_recip * dims[0]

        # Construct LBANN layer graph
        image = lbann.Reshape(image, dims=str_list([np.prod(dims)]))
        inds = lbann.WeightsLayer(
            weights=lbann.Weights(
                lbann.ValueInitializer(values=str_list(inds_vals)),
                optimizer=lbann.NoOptimizer(),
            ),
            dims=str_list([len(inds_vals)]),
        )
        r_sums = lbann.Scatter(image, inds, dims=str_list([dims[0] * max_r]))
        scales = lbann.WeightsLayer(
            weights=lbann.Weights(
                lbann.ValueInitializer(values=str_list(scales_vals)),
                optimizer=lbann.NoOptimizer(),
            ),
            dims=str_list([len(scales_vals)]),
        )
        r_means = lbann.Multiply(scales, r_sums)
        return lbann.Reshape(r_means, dims=str_list([dims[0], max_r]))
Example #10
def create_position_ids_from_input_ids(input_ids,
                                       input_shape,
                                       padding_idx,
                                       past_key_values_length=0):
    padding_idx = lbann.Constant(value=padding_idx,
                                 num_neurons=str_list(input_shape))
    mask = lbann.NotEqual(input_ids, padding_idx)
    incremental_indices = lbann.Multiply(
        lbann.AddConstant(
            lbann.modules.Cumsum(mask, input_shape, axis=1),
            constant=past_key_values_length,
        ),
        mask,
    )
    incremental_indices = lbann.Add(incremental_indices, padding_idx)

    return incremental_indices
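
A hedged worked example of what this computes (values chosen purely for illustration, with past_key_values_length = 0): given padding_idx = 1 and a token row [5, 6, 1, 7], the mask is [1, 1, 0, 1], its cumulative sum along axis 1 is [1, 2, 2, 3], multiplying by the mask zeroes the pad position to give [1, 2, 0, 3], and adding padding_idx back yields position ids [2, 3, 1, 4]; real tokens count upward from padding_idx + 1 while pad positions stay at padding_idx.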
Example #11
File: GINConv.py  Project: benson31/lbann
    def forward(self,
                node_feature_mat,
                source_indices,
                target_indices,
                activation=lbann.Relu):
        """Apply GIN  Layer. 
        
        Args:
            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes,input_channels) 
            source_indices (Layer): Source node indices of the edges with shape (num_nodes)
            target_indices (Layer): Target node indices of the edges with shape (num_nodes
            activation (Layer): Activation layer for the node features. If None, then no activation is 
                                applied. (default: lbann.Relu) 
        Returns: 
            (Layer) : The output after kernel ops. The output can passed into another Graph Conv layer
                          directly
        """
        eps = lbann.Constant(value=(1 + self.eps),
                             num_neurons=str_list(
                                 [self.num_nodes, self.input_channel_size]))

        eps_node_features = lbann.Multiply(node_feature_mat,
                                           eps,
                                           name=self.name + "_epl_mult")

        node_feature_mat = lbann.Sum(eps_node_features, node_feature_mat)

        # Transform with the sequence of linear layers
        for layer in self.nn:
            node_feature_mat = layer(node_feature_mat)

        neighborhoods = GraphExpand(node_feature_mat, target_indices)

        neighborhoods = lbann.Reshape(
            neighborhoods,
            dims=str_list([self.num_edges, self.output_channel_size]))

        aggregated_node_features = GraphReduce(
            neighborhoods, source_indices,
            [self.num_nodes, self.output_channel_size])
        ## Apply activation
        if activation:
            aggregated_node_features = activation(aggregated_node_features)

        return aggregated_node_features
Example #12
File: GINConv.py  Project: oyamay/lbann
    def forward(self, X, A, activation = lbann.Relu):
        """Apply GIN  Layer. 
        
        Args:
            X (GraphVertexData): LBANN Data object, which is a collection of Layers. Each Layer is of
                                 the shape (1,input_channels) 

            A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)

            activation (Layer): Activation layer for the node features. If None, then no activation is 
                                applied. (default: lbann.Relu) 
        Returns: 
            
            (GraphVertexData): The output after GCN. The output can passed into another Graph Conv layer
                          directly
        """
        in_channel = X.shape[1]

        # Accumulate Messages from Neighboring Nodes
        out = X.get_mat()
        out = lbann.MatMul(A,out, name = self.name+"_GIN_MATMUL")
        message = GraphVertexData.matrix_to_graph(out, X.shape[0], in_channel)

        # Aggregate Messages into node features  
        eps = lbann.Constant(value=(1+self.eps),num_neurons = str_list([1, in_channel]))
        for node_feature in range(X.shape[0]):
            eps_val = lbann.Multiply(eps, X[node_feature])
            X[node_feature] = lbann.Sum(message[node_feature], eps_val)
        
        # Transform with the sequence of linear layers
        for layer in self.nn:
            for node_feature in range(X.shape[0]):
                X[node_feature] = layer(X[node_feature])
        
        ## Apply activation 
        if activation:
            for node_feature in range(X.shape[0]):
                X[node_feature] = activation(X[node_feature])
        X.update_num_features(self.output_channels) 
        return X
Example #13
def Silu(x):
    return lbann.Multiply(x, lbann.Sigmoid(x))
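
SiLU (swish) simply gates the input by its own sigmoid, so it can be dropped in wherever an activation layer is expected; a hedged one-line sketch with a hypothetical hidden layer:

    h = Silu(lbann.FullyConnected(prev, num_neurons=256, has_bias=True))   # `prev` and the size are hypothetical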
Example #14
File: vae.py  Project: szaman19/lbann
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims - 1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Convert indices in x to one-hot representation
        # Note: Ignored indices result in zero vectors
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))
        x = lbann.Add(
            lbann.Multiply(keep_mask, x),
            lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
        )
        x = lbann.Slice(x,
                        slice_points=str_list(range(self.input_feature_dims)))
        x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
        x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
        x = [
            lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
            for xi in x
        ]
        x = lbann.Concatenation(x, axis=0)

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )
        # Note: Ideally we'd shift y by y.max(-1) for numerical stability
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        recon_loss = lbann.MatMul(
            lbann.Reshape(y, dims=str_list([1, -1])),
            lbann.Reshape(x, dims=str_list([1, -1])),
            transpose_b=True,
        )
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
Example #15
File: vae.py  Project: benson31/lbann
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims-1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Figure out entries in x to ignore
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))

        # Convert entries in x to indices in y
        # Note: Ignored entries correspond to an index of -1.
        offsets = [
            row*self.dictionary_size
            for row in range(self.input_feature_dims-1)
        ]
        offsets = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(offsets)),
            optimizer=lbann.NoOptimizer(),
        )
        offsets = lbann.WeightsLayer(
            dims=str_list([self.input_feature_dims-1]),
            weights=offsets,
        )
        y_inds = lbann.Add(x, offsets)
        y_inds = lbann.Add(
            lbann.Multiply(keep_mask, y_inds),
            lbann.Multiply(
                ignore_mask,
                self.constant(-1, hint_layer=y_inds),
            ),
        )

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )

        # Shift y for numerical stability
        # Note: We'd prefer to shift by y.max(-1)
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)

        # Compute log of softmax denominator and sum
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        z = lbann.Reshape(z, dims=str_list([1]))

        # Compute cross entropy
        recon_loss = lbann.Gather(
            lbann.Reshape(y, dims=str_list([-1])),
            y_inds,
        )
        recon_loss = lbann.Reduction(recon_loss, mode='sum')
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
Example #16

# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )

pearson_r_cov = lbann.Covariance([reconstruction, data],
                                 name="pearson_r_cov",
                                 data_layout="model_parallel")

pearson_r_var1 = lbann.Variance(data,
                                name="pearson_r_var1",
                                data_layout="model_parallel")

pearson_r_var2 = lbann.Variance(reconstruction,
                                name="pearson_r_var2",
                                data_layout="model_parallel")

pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult",
                                data_layout="model_parallel")

pearson_r_sqrt = lbann.Sqrt(pearson_r_mult,
                            name="pearson_r_sqrt",
                            data_layout="model_parallel")

pearson_r = lbann.Divide([pearson_r_cov, pearson_r_sqrt],
                         name="pearson_r",
                         data_layout="model_parallel")

layer_list = list(lbann.traverse_layer_graph(input_))

# Set up objective function
layer_term = lbann.LayerTerm(mean_squared_error)
obj = lbann.ObjectiveFunction(layer_term)
Example #17
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation.
    Network architecture and training hyperparameters are from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn

    """

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)  # note: a GRU cell, despite the "lstm" naming
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )

    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):

        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")], )
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)
Example #18
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
    ):
        mixed_query_layer, query_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "query.weight")),
                ".".join((self.name, "query.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "query")),
            return_dims=True,
        )
        query_layer, query_shape = self.transpose_for_scores(
            mixed_query_layer, query_shape)

        key_layer, key_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "key.weight")),
                ".".join((self.name, "key.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "key")),
            return_dims=True,
        )
        key_layer, key_shape = self.transpose_for_scores(key_layer, key_shape)

        value_layer, value_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "value.weight")),
                ".".join((self.name, "value.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "value")),
            return_dims=True,
        )
        value_layer, value_shape = self.transpose_for_scores(
            value_layer, value_shape)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        key_layer, key_shape = lbann.modules.Permute(key_layer,
                                                     key_shape,
                                                     axes=(0, 1, -1, -2),
                                                     return_dims=True)
        attention_scores, attention_shape = lbann.modules.PytorchMatmul(
            query_layer,
            query_shape,
            key_layer,
            key_shape,
            return_dims=True,
        )

        attention_scores = lbann.Scale(
            attention_scores,
            constant=1 / math.sqrt(self.attention_head_size))

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the RobertaModel forward() function)
            attention_scores = lbann.Add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_scores = lbann.Reshape(
            attention_scores,
            dims=str_list([np.prod(attention_shape[:-1]),
                           attention_shape[-1]]),
        )
        attention_probs = lbann.ChannelwiseSoftmax(attention_scores)
        attention_probs = lbann.Reshape(attention_probs,
                                        dims=str_list(attention_shape))

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = lbann.Dropout(
            attention_probs,
            keep_prob=self.attention_probs_dropout_prob,
        )

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = lbann.Multiply(attention_probs, head_mask)

        context_layer, context_shape = lbann.modules.PytorchMatmul(
            attention_probs,
            attention_shape,
            value_layer,
            value_shape,
            return_dims=True,
        )
        context_layer, context_shape = lbann.modules.Permute(
            context_layer,
            context_shape,
            axes=(0, 2, 1, 3),
            return_dims=True,
        )
        new_context_layer_shape = context_shape[:-2] + (self.all_head_size, )
        context_layer = lbann.Reshape(context_layer,
                                      dims=str_list(self.input_shape))

        return context_layer
Example #19
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
Example #20
    def forward(self, x, prev_state):
        """ Apply GRU step channelwise 
        Args: 
            x (Layer): Input (shape: (num_channels, *))
            prev_state (Layer): Sate from previous GRU step  (shape: (num_channels, size))
        Returns:
            (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step
        """

        self.step += 1

        name = f"{self.name}_step{self.step}"

        mat_size = self.num_channels * self.size

        prev_state = lbann.Reshape(prev_state,
                                   dims=str_list(
                                       [self.num_channels, self.size]),
                                   name=name + "_prev_state_reshape")

        fc1 = self.ih_fc(x)
        fc2 = self.hh_fc(prev_state)

        fc1_slice = lbann.Slice(
            fc1,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Wir_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wir_x')
        Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wiz_z')
        Win_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Win_x')
        fc2_slice = lbann.Slice(
            fc2,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Whr_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whr_x')
        Whz_z = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whz_z')
        Whn_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whn_x')

        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_x, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name+ '_output', data_layout=self.data_layout,
            )

        ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size]))

        return ht, ht
Example #21
    def forward(self, x, prev_state):
        """Apply LSTM step.

        Args:
            x (Layer): Input.
            prev_state (tuple with two `Layer`s): State from previous
                LSTM step. Comprised of LSTM output and cell state.

        Returns:
            (Layer, (Layer, Layer)): The output and state (the output
                and cell state). The state can be passed directly into
                the next LSTM step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Get output and cell state from previous step
        prev_output, prev_cell = prev_state

        # Apply linearity
        input_concat = lbann.Concatenation(x, prev_output,
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update
        slice = lbann.Slice(fc,
                            slice_points=str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        slice = lbann.Slice(sigmoid,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state
        cell_forget = lbann.Multiply(f, prev_cell,
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply(i, cell_update,
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add(cell_forget, cell_input, name=name + '_cell',
                         data_layout=self.data_layout)

        # Output
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply(o, cell_act, name=name,
                                data_layout=self.data_layout)

        # Return output and state
        return output, (output, cell)
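
A hedged sketch of unrolling the step above over a sequence, initializing the state with zero Constant layers in the style used elsewhere in these examples; `lstm`, `hidden_size`, and the per-step inputs `xs` are hypothetical:

    h0 = lbann.Constant(value=0.0, num_neurons=str(hidden_size))
    c0 = lbann.Constant(value=0.0, num_neurons=str(hidden_size))
    state = (h0, c0)
    outputs = []
    for x_t in xs:                       # per-step input layers
        y_t, state = lstm(x_t, state)    # forward returns (output, (output, cell))
        outputs.append(y_t)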
Example #22
def Gelu(x):
    x_erf = lbann.Erf(lbann.Scale(x, constant=(1 / math.sqrt(2))))
    return lbann.Multiply(
        x, lbann.Scale(lbann.AddConstant(x_erf, constant=1), constant=0.5))
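
This is the exact erf-based form, 0.5 * x * (1 + erf(x / sqrt(2))), in contrast to the tanh approximation in Gelu_approx above; either drops into a layer graph the same way:

    h = Gelu(lbann.FullyConnected(prev, num_neurons=512, has_bias=True))   # `prev` and the size are hypothetical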
Example #23
images = lbann.Reshape(images, dims='1 300 300')

pred = model.PROBIESNet(num_labels)(images)

mse = lbann.MeanSquaredError([responses, pred])

# Pearson Correlation
# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )
pearson_r_cov = lbann.Covariance([pred, responses], name="pearson_r_cov")

pearson_r_var1 = lbann.Variance(responses, name="pearson_r_var1")

pearson_r_var2 = lbann.Variance(pred, name="pearson_r_var2")

pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult")

pearson_r_sqrt = lbann.Sqrt(pearson_r_mult, name="pearson_r_sqrt")

eps = lbann.Constant(value=1e-07, hint_layer=pearson_r_sqrt)
pearson_r = lbann.Divide(
    [pearson_r_cov, lbann.Add(pearson_r_sqrt, eps)], name="pearson_r")

metrics = [lbann.Metric(mse, name='mse')]
metrics.append(lbann.Metric(pearson_r, name='pearson_r'))

callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

layers = list(lbann.traverse_layer_graph([images, responses]))
model = lbann.Model(args.num_epochs,
                    layers=layers,
Example #24
    def forward(self, x, prev_state):
        """Apply GRU step.

        Args:
            x (Layer): Input.
            prev_state: State from previous GRU step.

        Returns:
            (Layer, Layer): The output (out) and state (hn). The state
                can be passed directly into the next GRU step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)


        fc1 = self.ih_fc(x)   #input_fc
        fc2 = self.hh_fc(prev_state)  #hidden_fc


        # Get gates and cell update
        fc1_slice = lbann.Slice(fc1,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_fc1_slice',
                            data_layout=self.data_layout)
        Wir_x = lbann.Identity(fc1_slice, name=name + '_Wrx',
                           data_layout=self.data_layout)
        Wiz_x = lbann.Identity(fc1_slice, name=name + '_Wzx',
                           data_layout=self.data_layout)
        Win_x = lbann.Identity(fc1_slice, name=name + '_Wnx',
                           data_layout=self.data_layout)

        fc2_slice = lbann.Slice(fc2,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_fc2_slice',
                            data_layout=self.data_layout)
        Whr_prev = lbann.Identity(fc2_slice, name=name + '_Wrh',
                           data_layout=self.data_layout)
        Whz_prev = lbann.Identity(fc2_slice, name=name + '_Wzh',
                           data_layout=self.data_layout)
        Whn_prev = lbann.Identity(fc2_slice, name=name + '_Wnh',
                           data_layout=self.data_layout)

        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name+ '_output', data_layout=self.data_layout,
            )

        # Return output
        return ht, ht