Example #1
    def forward(self, x):
        """Perform LSTM step.

        State from previous steps is used to compute output.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Apply linearity
        input_concat = lbann.Concatenation([x, self.last_output],
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update
        slice = lbann.Slice(fc,
                            slice_points=_str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        slice = lbann.Slice(sigmoid,
                            slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state
        cell_forget = lbann.Multiply([f, self.last_cell],
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply([i, cell_update],
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add([cell_forget, cell_input], name=name + '_cell',
                         data_layout=self.data_layout)

        # Output
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply([o, cell_act], name=name,
                                data_layout=self.data_layout)

        # Update state and return output
        self.last_cell = cell
        self.last_output = output
        return output
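
The graph above wires up the standard LSTM cell equations. As a sanity check, a minimal NumPy sketch of the same math (assuming a fused weight matrix W and bias b laid out as [cell update, forget, input, output], matching the slice points used above):

import numpy as np

def lstm_step_reference(x, last_output, last_cell, W, b):
    # One LSTM step mirroring the slice layout above.
    size = last_cell.shape[0]
    fc = W @ np.concatenate([x, last_output]) + b
    cell_update = np.tanh(fc[:size])              # slice [0, size)
    gates = 1.0 / (1.0 + np.exp(-fc[size:]))      # sigmoid over [size, 4*size)
    f, i, o = gates[:size], gates[size:2*size], gates[2*size:]
    cell = f * last_cell + i * cell_update
    output = o * np.tanh(cell)
    return output, cell
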
Example #2
 def forward(self, x):
     self.instance += 1
     y1 = self.branch1(x) if self.branch1 else x
     y2 = self.branch2c(self.branch2b(self.branch2a(x)))
     z = lbann.Add([y1, y2],
                   name='{0}_sum_instance{1}'.format(
                       self.name, self.instance))
     return lbann.Relu(z,
                       name='{0}_relu_instance{1}'.format(
                           self.name, self.instance))
Example #3
def f_invtransform(y, scale=4.0):  ### Transform to original space
    '''
    The inverse of the transformation function that scales the data before training
    '''
    inv_transform = lbann.WeightedSum(lbann.SafeDivide(
        lbann.Add(lbann.Constant(value=1.0, hint_layer=y), lbann.Identity(y)),
        lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                       lbann.Identity(y))),
                                      scaling_factors=str(scale))

    return inv_transform
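
In closed form the graph computes y -> scale * (1 + y) / (1 - y), the inverse of x -> (x - scale) / (x + scale). A NumPy check (reference helper, not part of the source):

import numpy as np

def f_invtransform_reference(y, scale=4.0):
    # scale * (1 + y) / (1 - y), same as the WeightedSum/SafeDivide graph
    return scale * (1.0 + y) / (1.0 - y)

# Round trip with the matching forward transform:
x = np.array([0.5, 2.0, 10.0])
y = (x - 4.0) / (x + 4.0)
assert np.allclose(f_invtransform_reference(y), x)
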
Example #4
    def inv_transform(self, y):  ### Original transformation
        '''
        The inverse of the transformation function that scales the data before training
        '''
        inv_transform = lbann.WeightedSum(lbann.SafeDivide(
            lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                      lbann.Identity(y)),
            lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))),
                                          scaling_factors=str(self.datascale))

        return inv_transform
Example #5
    def forward_encoder(self, x_emb):
        """Encoder step, emulating z ~ E(x) = q_E(z|x)

        :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
        :return: (n_batch, d_z) of floats, sample of latent vector z
        :return: float, kl term component of loss
        """

        # _, h = self.encoder_rnn(x, None)
        h = self.encoder_rnn(x_emb, None)

        h = lbann.Slice(
            h,
            slice_points=str_list([self.input_feature_dims-1,
                                   self.input_feature_dims]),
            axis=0,
        )
        h = lbann.Identity(h)

        mu, logvar = self.q_mu(h), self.q_logvar(h)

        # Set datatype of previous layers
        # Note: Depth-first search from mu and logvar to x_emb
        stack = [mu, logvar]
        in_stack = {l : True for l in stack}
        while stack:
            l = stack.pop()
            if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
                l.datatype = self.datatype
            for parent in l.parents:
                if parent not in in_stack and parent is not x_emb:
                    stack.append(parent)
                    in_stack[parent] = True

        # eps = torch.randn_like(mu)
        eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

        # z = mu + (logvar / 2).exp() * eps
        z = lbann.Add([
            mu,
            lbann.Multiply([
                lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
                eps,
            ]),
        ])

        # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
        kl_loss = lbann.Reduction(
            lbann.WeightedSum(
                lbann.Exp(logvar),
                lbann.Square(mu),
                self.constant(1, hint_layer=mu),
                logvar,
                scaling_factors='0.5 0.5 -0.5 -0.5',
            ),
            mode='sum',
        )

        return z, kl_loss
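
The reparameterization and KL term mirror the commented PyTorch lines. A per-sample NumPy sketch of the same math, assuming a diagonal Gaussian posterior:

import numpy as np

def reparameterize_reference(mu, logvar, rng=np.random.default_rng()):
    # z = mu + exp(logvar / 2) * eps, eps ~ N(0, I)
    eps = rng.standard_normal(mu.shape)
    z = mu + np.exp(0.5 * logvar) * eps
    # KL(q(z|x) || N(0, I)), summed over latent dimensions
    kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar)
    return z, kl
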
Example #6
 def forward(self, inputs):
     raise NotImplementedError  # Requires log-gamma function
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]
     ones = lbann.Constant(hint_layer=pred, value=1.0)
     term1 = pred
     term2 = lbann.Multiply([label, lbann.Log(pred)])
     term3 = lbann.LogGamma(lbann.Add([label, ones]))
     full = lbann.WeightedSum([term1, term2, term3],
                              scaling_factors='1.0 -1.0 1.0')
     return lbann.Reduction(full)
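
The weighted sum is the Poisson negative log-likelihood, -log p(k | lambda) = lambda - k*log(lambda) + log Gamma(k + 1), which is why a log-gamma layer is required. A SciPy reference of what the graph would compute once such a layer exists:

import numpy as np
from scipy.special import gammaln

def poisson_nll_reference(pred, label):
    # -log Poisson(label; rate=pred), summed over all entries
    return np.sum(pred - label * np.log(pred) + gammaln(label + 1.0))
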
Example #7
def Gelu_approx(x):
    # This approximates gelu and may be more performant
    # return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    # Based on: https://paperswithcode.com/method/gelu
    sqrt_2_over_pi = math.sqrt(2 / math.pi)
    b_coef = 0.044715
    x_cubed = lbann.Multiply(lbann.Multiply(lbann.Identity(x), x), x)
    inner_tanh_x_comp = lbann.Add(x, lbann.Scale(x_cubed, constant=b_coef))
    tanh_x = lbann.Tanh(lbann.Scale(inner_tanh_x_comp,
                                    constant=sqrt_2_over_pi))
    return lbann.Scale(lbann.Multiply(x, lbann.AddConstant(tanh_x,
                                                           constant=1)),
                       constant=0.5)
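
For reference, the tanh approximation being assembled, in plain NumPy:

import math
import numpy as np

def gelu_approx_reference(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi)
                                    * (x + 0.044715 * x**3)))
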
Example #8
    def forward(self, x, z):
        """Do the WAE forward step

        :param x: list of tensors of longs, embed representation of input
        :return: float, kl term component of loss
        :return: float, recon component of loss
        """

        x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
        x = lbann.Identity(x)
        x_emb = lbann.Embedding(x,
                                num_embeddings=self.dictionary_size,
                                embedding_dim=self.embedding_size,
                                name='emb',
                                weights=self.emb_weights)

        # Encoder: x -> z, kl_loss
        z_sample = self.forward_encoder(x_emb)

        eps = lbann.Gaussian(mean=self.gmean,
                             stdev=self.gstd,
                             hint_layer=z_sample)
        z_sample = lbann.Add([z_sample, eps])

        # Decoder: x, z -> recon_loss
        #pred = self.forward_decoder(x_emb, z_sample)
        pred, arg_max = self.forward_decoder(x_emb, z_sample)
        recon_loss = self.compute_loss(x, pred)

        # Hack to remove blocking GPU allreduce in evaluation layer
        #kl_loss = lbann.Identity(kl_loss, device='CPU')
        recon_loss = lbann.Identity(recon_loss, device='CPU')

        z_prior = lbann.Tessellate(
            lbann.Reshape(z, dims=str_list([1, self.zdim])),
            dims=str_list([self.input_feature_dims, self.zdim]),
        )

        d_real = self.discriminator0(
            lbann.Concatenation([x_emb, z_prior], axis=1))

        z_sample0 = lbann.Tessellate(
            lbann.Reshape(z_sample, dims=str_list([1, self.zdim])),
            dims=str_list([self.input_feature_dims, self.zdim]),
        )
        y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)

        d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
        d_adv = self.discriminator1(y_z_sample)  #freeze

        return recon_loss, d_real, d_fake, d_adv, arg_max
Example #9
 def inv_transform(self, y):
     '''
     The inverse of the transformation function that scales the data before training
     '''
     inv_transform = lbann.WeightedSum(lbann.SafeDivide(
         lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                   lbann.Identity(y)),
         lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                        lbann.Identity(y))),
                                       scaling_factors=str(self.datascale))
     #linear_scale = 1/self.linear_scaler
     #CH2 = lbann.Tanh(lbann.WeightedSum(inv_transform,scaling_factors=str(linear_scale)))
     #return CH2
     return inv_transform
Example #10
    def forward(
        self,
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    ):

        # Apply generator
        fake_motif_indices, gen_prob, gen_log_prob = self.generator(
            walk_length,
            walk_indices,
            motif_size,
        )

        # Get discriminator embeddings in log-space
        all_motif_indices = lbann.Concatenation(motif_indices,
                                                fake_motif_indices)
        all_motif_log_embeddings = self.discriminator.get_log_embeddings(
            all_motif_indices)
        all_motif_log_embeddings = lbann.Slice(
            all_motif_log_embeddings,
            slice_points=str_list([0, motif_size, 2 * motif_size]),
        )
        real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)
        fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)

        # Apply discriminator
        real_disc_prob, real_disc_log_not_prob \
            = self.discriminator(motif_size, real_motif_log_embeddings)
        fake_disc_prob, fake_disc_log_not_prob \
            = self.discriminator(motif_size, fake_motif_log_embeddings)

        # Loss function
        # L_disc = - log(D(real)) - log(1-D(fake))
        # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
        real_disc_log_prob \
            = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1))
        disc_loss = lbann.WeightedSum(
            real_disc_log_prob,
            fake_disc_log_not_prob,
            scaling_factors=str_list([-1, -1]),
        )
        gen_loss = lbann.Multiply(
            gen_log_prob,
            lbann.StopGradient(fake_disc_log_not_prob),
        )
        loss = lbann.Add(disc_loss, gen_loss)

        return loss, real_disc_prob, fake_disc_prob, gen_prob
Example #11
def create_position_ids_from_input_ids(input_ids,
                                       input_shape,
                                       padding_idx,
                                       past_key_values_length=0):
    padding_idx = lbann.Constant(value=padding_idx,
                                 num_neurons=str_list(input_shape))
    mask = lbann.NotEqual(input_ids, padding_idx)
    incremental_indices = lbann.Multiply(
        lbann.AddConstant(
            lbann.modules.Cumsum(mask, input_shape, axis=1),
            constant=past_key_values_length,
        ),
        mask,
    )
    incremental_indices = lbann.Add(incremental_indices, padding_idx)

    return incremental_indices
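
This mirrors the cumsum trick from HuggingFace: non-padding tokens receive positions padding_idx + 1, padding_idx + 2, ..., while padding positions stay at padding_idx. A NumPy sketch:

import numpy as np

def position_ids_reference(input_ids, padding_idx, past_key_values_length=0):
    mask = (input_ids != padding_idx).astype(np.int64)
    incremental = (np.cumsum(mask, axis=1) + past_key_values_length) * mask
    return incremental + padding_idx
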
Example #12
 def forward(self, inputs):
     raise NotImplementedError  # Requires log-gamma function
     if len(inputs) != 2:
         raise ValueError('expected two inputs: predictions and labels')
     pred = inputs[0]
     label = inputs[1]
     count = lbann.Reduction(label)
     alpha_sum = lbann.Reduction(pred)
     lgamma_alpha_sum = lbann.Reduction(lbann.LogGamma(pred))
     lgamma_alpha_level_count_sum = lbann.Reduction(
         lbann.LogGamma(lbann.Add([pred, label])))
     return lbann.WeightedSum([
         lbann.LogGamma(alpha_sum),
         lbann.LogGamma(lbann.Sum([count, alpha_sum])),
          lgamma_alpha_level_count_sum, lgamma_alpha_sum
     ],
                              scaling_factors='-1.0 1.0 -1.0 1.0')
Example #13
    def inv_transform(self, y):  ### Original transformation
        '''
        The inverse of the transformation function that scales the data before training
        '''
        inv_transform = lbann.WeightedSum(lbann.SafeDivide(
            lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                      lbann.Identity(y)),
            lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))),
                                          scaling_factors=str(self.datascale))

        return inv_transform

#      def inv_transform(self, y):  ### New transformation: log-linear


#         threshold = lbann.Constant(value=0.5, hint_layer=y)
#         is_above_threshold = lbann.Greater(y, threshold)
#         is_below_threshold = lbann.LogicalNot(is_above_threshold)

#         below = lbann.SafeDivide(
#             lbann.Subtract(y, lbann.Constant(value=1, hint_layer=y)),
#             lbann.Constant(value=0.03, hint_layer=y),
#         )
#         above = lbann.Exp(lbann.SafeDivide(
#             lbann.Subtract(
#                 y,
#                 lbann.Constant(value=0.5-0.5/math.log(300)*math.log(50), hint_layer=y)),
#             lbann.Constant(value=0.5/math.log(300), hint_layer=y),
#         ))
#         return lbann.Add(
#             lbann.Multiply(is_above_threshold, above),
#             lbann.Multiply(is_below_threshold, below),
#         )

# def f_invtransform_new(y):
#     if y<=0.5:
#         a=0.03;b=-1.0
#         return (y-b)/a
#     elif y>0.5:
#         a=0.5/np.log(300)
#         b=0.5-a*np.log(50)
#         return np.exp((y-b)/a)
Example #14
def construct_model():
    """Construct LBANN model.

    Pilot1 Combo model

    """
    import lbann

    # Layer graph
    data = lbann.Input(data_field='samples')
    responses = lbann.Input(data_field='responses')

    pred = combo.Combo()(data)
    mse = lbann.MeanSquaredError([responses, pred])

    SS_res = lbann.Reduction(lbann.Square(lbann.Subtract(responses, pred)),
                             mode='sum')

    #SS_tot = var(x) = mean((x-mean(x))^2)
    mini_batch_size = lbann.MiniBatchSize()
    mean = lbann.Divide(lbann.BatchwiseReduceSum(responses), mini_batch_size)
    SS_tot = lbann.Divide(
        lbann.BatchwiseReduceSum(lbann.Square(lbann.Subtract(responses,
                                                             mean))),
        mini_batch_size)
    eps = lbann.Constant(value=1e-07, hint_layer=SS_tot)
    r2 = lbann.Subtract(lbann.Constant(value=1, num_neurons='1'),
                        lbann.Divide(SS_res, lbann.Add(SS_tot, eps)))

    metrics = [lbann.Metric(mse, name='mse')]
    metrics.append(lbann.Metric(r2, name='r2'))

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    num_epochs = 100
    layers = list(lbann.traverse_layer_graph([data, responses]))
    return lbann.Model(num_epochs,
                       layers=layers,
                       metrics=metrics,
                       objective_function=mse,
                       callbacks=callbacks)
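
The metric follows R^2 = 1 - SS_res / (SS_tot + eps). Note that the graph divides SS_tot by the mini-batch size but leaves SS_res as a sum; the textbook definition, for comparison:

import numpy as np

def r2_reference(responses, pred, eps=1e-7):
    ss_res = np.sum((responses - pred) ** 2)
    ss_tot = np.sum((responses - responses.mean()) ** 2)
    return 1.0 - ss_res / (ss_tot + eps)
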
Example #15
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims - 1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Convert indices in x to one-hot representation
        # Note: Ignored indices result in zero vectors
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))
        x = lbann.Add(
            lbann.Multiply(keep_mask, x),
            lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
        )
        x = lbann.Slice(x,
                        slice_points=str_list(range(self.input_feature_dims)))
        x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
        x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
        x = [
            lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
            for xi in x
        ]
        x = lbann.Concatenation(x, axis=0)

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )
        # Note: Ideally we'd shift y by y.max(-1) for numerical stability
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        recon_loss = lbann.MatMul(
            lbann.Reshape(y, dims=str_list([1, -1])),
            lbann.Reshape(x, dims=str_list([1, -1])),
            transpose_b=True,
        )
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
Example #16
meansq = lbann.Square(mu, name="meansq")

kldiv_plus_half = lbann.WeightedSum([meansq, var, logsd],
                                    name="kldiv_plus_half",
                                    scaling_factors='0.5 0.5 -1')

kldiv_full = lbann.Rsqrt(kldiv_plus_half, name="kldiv_full")

kldiv = lbann.Reduction(kldiv_full, name="kldiv", mode="sum")

# Generate sample
noise = lbann.Gaussian(name="noise", mean=0, stdev=1, hint_layer=mu)

sdnoise = lbann.Hadamard([noise, sd], name="sdnoise")

sample = lbann.Add([mu, sdnoise], name="sample")

# Decoder
decode4 = lbann.FullyConnected(sample,
                               name="decode4",
                               has_bias=True,
                               hint_layer=encode3)

decode4neuron = lbann.Relu(decode4, name="decode4neuron")

decode3 = lbann.FullyConnected(decode4neuron,
                               name="decode3",
                               has_bias=True,
                               hint_layer=encode2)

decode3neuron = lbann.Relu(decode3, name="decode3neuron")
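
For comparison, the textbook diagonal-Gaussian KL term that kldiv_plus_half builds toward (per dimension, 0.5*mu^2 + 0.5*sigma^2 - log(sigma) - 0.5; note the graph applies Rsqrt rather than subtracting the constant) and the reparameterized sample, in NumPy:

import numpy as np

def vae_sample_and_kl_reference(mu, sd, rng=np.random.default_rng()):
    # Standard Gaussian KL term, summed over latent dimensions
    kl = np.sum(0.5 * mu**2 + 0.5 * sd**2 - np.log(sd) - 0.5)
    sample = mu + sd * rng.standard_normal(mu.shape)  # mu + sd * noise
    return sample, kl
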
Example #17
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(
            queries_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_queries_slice',
        )
        keys_slice = lbann.Slice(
            keys_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_keys_slice',
        )
        values_slice = lbann.Slice(
            values_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_values_slice',
        )

        # Compute scaled dot-product attention for each head
        attentions = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Attention inputs
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )
            if mask:
                y = lbann.Add(y, mask, name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat')
        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
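
Per head, the graph computes standard scaled dot-product attention. A NumPy sketch for a single head, with q, k, v of shape (seq_len, head_dim) and an optional additive mask:

import numpy as np

def attention_head_reference(q, k, v, mask=None):
    # softmax(q @ k.T / sqrt(head_dim) + mask) @ v, row-wise softmax
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v  # (num_queries, head_dim)
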
Example #18
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details

    """
    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + xdim]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  #param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = macc_network_architectures.MACCWAE(
        zdim, ydim, cf=wae_mcf, use_CNN=useCNN)  #pretrained, freeze
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)

    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)
    '''**** Train cycleGAN input params <--> latent space of (images, scalars) ****'''
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)

    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)
    L_cyc_y = lbann.MeanSquaredError(output_cyc, y_pred_fwd)

    #@todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)
    #L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    #loss_gen0 = L_l2_y + lambda_cyc*L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {lambda_cyc}')
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {lambda_cyc}')
    #loss_gen1 = L_l2_x + lambda_cyc*L_cyc_y

    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    #Freeze appropriate (pretrained) weights
    pretrained_models = ["wae"]  #add macc?
    for l in layers:
        for idx in range(len(pretrained_models)):
            if (l.weights and pretrained_models[idx] in l.name):
                for w in range(len(l.weights)):
                    l.weights[w].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    #d_adv_bce = lbann.LayerTerm(d_adv_bce,scale=0.01)
    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])
    # Initialize check metric callback
    metrics = [
        lbann.Metric(img_sca_loss, name='fw_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer()
    ]

    if (ltfb_batch_interval > 0):
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='fw_loss',
                               low_score_wins=True,
                               exchange_hyperparameters=True))
    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Example #19
    def forward(self, x, prev_state):
        """Apply GRU step.

        Args:
            x (Layer): Input.
            prev_state: State from previous GRU step.

        Returns:
            (Layer, Layer): The output (out) and state (hn). The state
                can be passed directly into the next GRU step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        fc1 = self.ih_fc(x)  # input_fc
        fc2 = self.hh_fc(prev_state)  # hidden_fc

        # Get gates and cell update
        fc1_slice = lbann.Slice(fc1,
                                slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                                name=name + '_fc1_slice',
                                data_layout=self.data_layout)
        Wir_x = lbann.Identity(fc1_slice, name=name + '_Wrx',
                               data_layout=self.data_layout)
        Wiz_x = lbann.Identity(fc1_slice, name=name + '_Wzx',
                               data_layout=self.data_layout)
        Win_x = lbann.Identity(fc1_slice, name=name + '_Wnx',
                               data_layout=self.data_layout)

        fc2_slice = lbann.Slice(fc2,
                                slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                                name=name + '_fc2_slice',
                                data_layout=self.data_layout)
        Whr_prev = lbann.Identity(fc2_slice, name=name + '_Wrh',
                                  data_layout=self.data_layout)
        Whz_prev = lbann.Identity(fc2_slice, name=name + '_Wzh',
                                  data_layout=self.data_layout)
        Whn_prev = lbann.Identity(fc2_slice, name=name + '_Wnh',
                                  data_layout=self.data_layout)

        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name + '_output', data_layout=self.data_layout,
            )

        # Return output
        return ht, ht
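
The graph follows the standard GRU equations with the fused fc outputs sliced as [r, z, n]. A NumPy reference of one step:

import numpy as np

def gru_step_reference(x, h_prev, W_ih, b_ih, W_hh, b_hh, size):
    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))
    fc1 = W_ih @ x + b_ih          # input_fc
    fc2 = W_hh @ h_prev + b_hh     # hidden_fc
    r = sigmoid(fc1[:size] + fc2[:size])                # reset gate
    z = sigmoid(fc1[size:2*size] + fc2[size:2*size])    # update gate
    n = np.tanh(fc1[2*size:] + r * fc2[2*size:])        # new gate
    return (1.0 - z) * n + z * h_prev                   # output ht
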
Example #20
def construct_model():
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details

    """
    import lbann

    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list(
                                [0, args.ydim, args.ydim + args.xdim]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  #param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = macc_models.MACCWAE(args.zdim,
                              args.ydim,
                              cf=args.wae_mcf,
                              use_CNN=args.useCNN)  #pretrained, freeze
    inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf)
    fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)

    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)
    '''**** Train cycleGAN input params <--> latent space of (images, scalars) ****'''
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    y_out = wae.decoder(y_pred_fwd)

    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(
        input_fake, gt_x)  #(x,inv(enc(y)), (encoder+)inverse loss
    L_cyc_x = lbann.MeanSquaredError(
        input_cyc, gt_x)  #param, x cycle loss, from latent space

    L_l2_y = lbann.MeanSquaredError(
        output_fake, y_pred_fwd)  #pred error into latent space (enc(y),fw(x))
    L_cyc_y = lbann.MeanSquaredError(
        output_cyc,
        y_pred_fwd)  # pred error into latent space (enc(y), fw(inv(enc(y))))

    #@todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(
        y_image_re,
        gt_y)  # (y,dec(fw(x))) #forward model to decoder, no latent space
    dec_fw_inv_enc_y = lbann.MeanSquaredError(
        y_image_re2, gt_y)  #(y, dec(fw(inv(enc(y))))) y->enc_z->x'->fw_z->y'
    wae_loss = lbann.MeanSquaredError(y_out, gt_y)  #(y, dec(enc(y)))
    #L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    #loss_gen0  = L_l2_y + lamda_cyc*L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {args.lamda_cyc}')
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {args.lamda_cyc}')
    #loss_gen1  =  L_l2_x + lamda_cyc*L_cyc_y

    conc_out = lbann.Concatenation(
        [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, L_l2_x],
        name='x_errors')
    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    for l in layers:
        weights.update(l.weights)

    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1])
    # Initialize check metric callback
    metrics = [
        lbann.Metric(img_sca_loss, name='img_re1'),
        lbann.Metric(dec_fw_inv_enc_y, name='img_re2'),
        lbann.Metric(wae_loss, name='wae_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackDumpOutputs(layers=f'{conc_out.name}',
                                  execution_modes='test',
                                  directory=args.dump_outputs,
                                  batch_interval=1,
                                  format='npy'),
        lbann.CallbackTimer()
    ]

    # Construct model
    num_epochs = 1
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       serialize_io=True,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Example #21
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):

        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(
                    input_ids,
                    self.input_shape,
                    self.padding_idx,
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(
                    inputs_embeds)

        if token_type_ids is None:
            token_type_ids = lbann.Constant(value=0,
                                            num_neurons=str_list(
                                                self.input_shape))

        if inputs_embeds is None:
            inputs_embeds = lbann.Embedding(
                input_ids,
                num_embeddings=self.vocab_size,
                embedding_dim=self.hidden_size,
                padding_idx=self.pad_token_id,
                weights=_load_pretrained_weights(
                    ".".join((self.name, "word_embeddings.weight")),
                    load_weights=self.load_weights,
                ),
                name=".".join((self.name, "word_embeddings")),
            )
        token_type_embeddings = lbann.Embedding(
            token_type_ids,
            num_embeddings=self.type_vocab_size,
            embedding_dim=self.hidden_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "token_type_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "token_type_embeddings")),
        )

        embeddings = lbann.Add(inputs_embeds, token_type_embeddings)
        if self.position_embedding_type == "absolute":
            position_embeddings = lbann.Embedding(
                position_ids,
                num_embeddings=self.max_position_embeddings,
                embedding_dim=self.hidden_size,
                padding_idx=self.pad_token_id,
                weights=_load_pretrained_weights(
                    ".".join((self.name, "position_embeddings.weight")),
                    load_weights=self.load_weights,
                ),
                name=".".join((self.name, "position_embeddings")),
            )
            embeddings = lbann.Add(embeddings, position_embeddings)

        embeddings = lbann.modules.PytorchLayerNorm(
            embeddings,
            self.layer_norm_eps,
            self.input_shape + (self.hidden_size, ),
            weights=_load_pretrained_weights(
                ".".join((self.name, "layernorm.weightbias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "LayerNorm")),
        )
        embeddings = lbann.Dropout(embeddings,
                                   keep_prob=self.hidden_dropout_prob)
        return embeddings
Example #22
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
    ):
        mixed_query_layer, query_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "query.weight")),
                ".".join((self.name, "query.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "query")),
            return_dims=True,
        )
        query_layer, query_shape = self.transpose_for_scores(
            mixed_query_layer, query_shape)

        key_layer, key_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "key.weight")),
                ".".join((self.name, "key.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "key")),
            return_dims=True,
        )
        key_layer, key_shape = self.transpose_for_scores(key_layer, key_shape)

        value_layer, value_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "value.weight")),
                ".".join((self.name, "value.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "value")),
            return_dims=True,
        )
        value_layer, value_shape = self.transpose_for_scores(
            value_layer, value_shape)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        key_layer, key_shape = lbann.modules.Permute(key_layer,
                                                     key_shape,
                                                     axes=(0, 1, -1, -2),
                                                     return_dims=True)
        attention_scores, attention_shape = lbann.modules.PytorchMatmul(
            query_layer,
            query_shape,
            key_layer,
            key_shape,
            return_dims=True,
        )

        attention_scores = lbann.Scale(
            attention_scores,
            constant=1 / math.sqrt(self.attention_head_size))

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
            attention_scores = lbann.Add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_scores = lbann.Reshape(
            attention_scores,
            dims=str_list([np.prod(attention_shape[:-1]),
                           attention_shape[-1]]),
        )
        attention_probs = lbann.ChannelwiseSoftmax(attention_scores)
        attention_probs = lbann.Reshape(attention_probs,
                                        dims=str_list(attention_shape))

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = lbann.Dropout(
            attention_probs,
            keep_prob=self.attention_probs_dropout_prob,
        )

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = lbann.Multiply(attention_probs, head_mask)

        context_layer, context_shape = lbann.modules.PytorchMatmul(
            attention_probs,
            attention_shape,
            value_layer,
            value_shape,
            return_dims=True,
        )
        context_layer, context_shape = lbann.modules.Permute(
            context_layer,
            context_shape,
            axes=(0, 2, 1, 3),
            return_dims=True,
        )
        new_context_layer_shape = context_shape[:-2] + (self.all_head_size, )
        context_layer = lbann.Reshape(context_layer,
                                      dims=str_list(self.input_shape))

        return context_layer
Example #23
    def forward(self, x, prev_state):
        """Apply LSTM step.

        Args:
            x (Layer): Input.
            prev_state (tuple with two `Layer`s): State from previous
                LSTM step. Comprised of LSTM output and cell state.

        Returns:
            (Layer, (Layer, Layer)): The output and state (the output
                and cell state). The state can be passed directly into
                the next LSTM step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Get output and cell state from previous step
        prev_output, prev_cell = prev_state

        # Apply linearity
        input_concat = lbann.Concatenation(x, prev_output,
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update
        slice = lbann.Slice(fc,
                            slice_points=str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        slice = lbann.Slice(sigmoid,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state
        cell_forget = lbann.Multiply(f, prev_cell,
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply(i, cell_update,
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add(cell_forget, cell_input, name=name + '_cell',
                         data_layout=self.data_layout)

        # Output
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply(o, cell_act, name=name,
                                data_layout=self.data_layout)

        # Return output and state
        return output, (output, cell)
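
Because the state is passed explicitly, the module unrolls functionally over a sequence. A hypothetical usage sketch (the names lstm, xs, and size, plus the zero-state initialization, are assumptions for illustration):

# Unroll over a list of per-step input layers xs.
prev_output = lbann.Constant(value=0.0, num_neurons=str(size))
prev_cell = lbann.Constant(value=0.0, num_neurons=str(size))
state = (prev_output, prev_cell)
outputs = []
for x in xs:
    out, state = lstm(x, state)  # modules are callable; this invokes forward()
    outputs.append(out)
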
Example #24
    def compute_loss(self, x, y):

        # y[:, :-1]
        y = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims-1]),
        )
        y = lbann.Identity(y)

        # x[:, 1:]
        x = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x)

        # Figure out entries in x to ignore
        ignore_mask = lbann.Equal(
            x,
            self.constant(self.label_to_ignore, hint_layer=x),
        )
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))

        # Convert entries in x to indices in y
        # Note: Ignored entries correspond to an index of -1.
        offsets = [
            row*self.dictionary_size
            for row in range(self.input_feature_dims-1)
        ]
        offsets = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(offsets)),
            optimizer=lbann.NoOptimizer(),
        )
        offsets = lbann.WeightsLayer(
            dims=str_list([self.input_feature_dims-1]),
            weights=offsets,
        )
        y_inds = lbann.Add(x, offsets)
        y_inds = lbann.Add(
            lbann.Multiply(keep_mask, y_inds),
            lbann.Multiply(
                ignore_mask,
                self.constant(-1, hint_layer=y_inds),
            ),
        )

        # recon_loss = F.cross_entropy(
        #     y[:, :-1].contiguous().view(-1, y.size(-1)),
        #     x[:, 1:].contiguous().view(-1),
        #     ignore_index=self.pad
        # )

        # Shift y for numerical stability
        # Note: We'd prefer to shift by y.max(-1)
        shifts = lbann.MatMul(
            lbann.Max(y, self.constant(0, hint_layer=y)),
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)

        # Compute log of softmax denominator and sum
        z = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        z = lbann.Log(z)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            z,
        )
        z = lbann.Reshape(z, dims=str_list([1]))

        # Compute cross entropy
        recon_loss = lbann.Gather(
            lbann.Reshape(y, dims=str_list([-1])),
            y_inds,
        )
        recon_loss = lbann.Reduction(recon_loss, mode='sum')
        recon_loss = lbann.Subtract(z, recon_loss)
        recon_loss = lbann.Divide(recon_loss, length)

        return recon_loss
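
The matrix products implement a masked cross-entropy: for each kept position, log sum_k exp(y_k) minus the logit at the label index, averaged over the unmasked length. A NumPy reference (ignoring the stabilizing shift), with y of shape (seq_len - 1, dictionary_size) and integer labels x:

import numpy as np

def masked_cross_entropy_reference(y, x, label_to_ignore):
    keep = x != label_to_ignore
    length = max(keep.sum(), 1)
    z = np.log(np.exp(y).sum(axis=-1))                   # log-sum-exp per row
    picked = y[np.arange(len(x)), np.where(keep, x, 0)]  # logit at label index
    return np.sum(keep * (z - picked)) / length
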
Example #25
# Pearson Correlation
# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )
pearson_r_cov = lbann.Covariance([pred, responses], name="pearson_r_cov")

pearson_r_var1 = lbann.Variance(responses, name="pearson_r_var1")

pearson_r_var2 = lbann.Variance(pred, name="pearson_r_var2")

pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult")

pearson_r_sqrt = lbann.Sqrt(pearson_r_mult, name="pearson_r_sqrt")

eps = lbann.Constant(value=1e-07, hint_layer=pearson_r_sqrt)
pearson_r = lbann.Divide(
    [pearson_r_cov, lbann.Add(pearson_r_sqrt, eps)], name="pearson_r")

metrics = [lbann.Metric(mse, name='mse')]
metrics.append(lbann.Metric(pearson_r, name='pearson_r'))

callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

layers = list(lbann.traverse_layer_graph([images, responses]))
model = lbann.Model(args.num_epochs,
                    layers=layers,
                    metrics=metrics,
                    objective_function=mse,
                    callbacks=callbacks)

# Load data reader from prototext
data_reader_proto = lbann.lbann_pb2.LbannPB()
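
The graph above computes rho(x, y) = cov(x, y) / (sqrt(var(x) * var(y)) + eps). A NumPy check:

import numpy as np

def pearson_r_reference(pred, responses, eps=1e-7):
    cov = np.mean((pred - pred.mean()) * (responses - responses.mean()))
    return cov / (np.sqrt(pred.var() * responses.var()) + eps)
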
Example #26
    def forward(self, x, prev_state):
        """ Apply GRU step channelwise 
        Args: 
            x (Layer): Input (shape: (num_channels, *))
            prev_state (Layer): Sate from previous GRU step  (shape: (num_channels, size))
        Returns:
            (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step
        """

        self.step += 1

        name = f"{self.name}_step{self.step}"

        mat_size = self.num_channels * self.size

        prev_state = lbann.Reshape(prev_state,
                                   dims=str_list(
                                       [self.num_channels, self.size]),
                                   name=name + "_prev_state_reshape")

        fc1 = self.ih_fc(x)
        fc2 = self.hh_fc(prev_state)

        fc1_slice = lbann.Slice(
            fc1,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Wir_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wir_x')
        Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wiz_z')
        Win_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Win_x')
        fc2_slice = lbann.Slice(
            fc2,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Whr_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whr_x')
        Whz_z = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whz_z')
        Whn_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whn_x')

        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_x, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name + '_output', data_layout=self.data_layout,
            )

        ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size]))

        return ht, ht