Example #1
    def __init__(self, input_feature_dims, dictionary_size, embedding_size, ignore_label, name=None):
        """Initialize Molecular VAE.

        Args:
            input_feature_dims (int): analogous to sequence length.
            dictionary_size (int): vocabulary size
            embedding_size (int): embedding size
            ignore_label (int): padding index
            name (str, optional): Module name
                (default: 'molvae_module<index>').

        """
        MolVAE.global_count += 1
        self.instance = 0
        self.name = (name if name
                     else 'molvae_module{0}'.format(MolVAE.global_count))

        self.input_feature_dims = input_feature_dims
        self.embedding_size = embedding_size
        self.dictionary_size = dictionary_size
        self.label_to_ignore = ignore_label
        self.datatype = lbann.DataType.FLOAT
        self.weights_datatype = lbann.DataType.FLOAT

        fc = lbann.modules.FullyConnectedModule
        gru = GRUModule

        #Encoder
        self.encoder_rnn = gru(
            hidden_size=256,
            name=self.name+'_encoder_rnn',
            datatype=self.datatype,
            weights_datatype=self.weights_datatype,
        )
        self.q_mu = fc(128, name=self.name+'_encoder_qmu')
        self.q_logvar = fc(128, name=self.name+'_encoder_qlogvar')
        for w in self.q_mu.weights + self.q_logvar.weights:
            w.datatype = self.weights_datatype

        #Decoder
        self.decoder_rnn = gru(
            hidden_size=512,
            num_layers=3,
            name=self.name+'_decoder_rnn',
            datatype=self.datatype,
            weights_datatype=self.weights_datatype,
        )
        self.decoder_lat = fc(512, name=self.name+'_decoder_lat')
        self.decoder_fc = fc(self.dictionary_size, name=self.name+'_decoder_fc')
        for w in self.decoder_lat.weights + self.decoder_fc.weights:
            w.datatype = self.weights_datatype
        self.decoder_fc.weights[0].initializer = lbann.NormalInitializer(
            mean=0, standard_deviation=1/math.sqrt(512))

        #shared encoder/decoder weights
        self.emb_weights = lbann.Weights(
            initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
            name='emb_matrix',
            datatype=self.weights_datatype,
        )
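A minimal usage sketch for the constructor above. It assumes the enclosing class is MolVAE (as the MolVAE.global_count reference suggests) and that lbann and the GRUModule helper are importable; the argument values are placeholders chosen to match the docstring, not values from the source.

# Hypothetical instantiation; only the constructor shown above is exercised.
vae = MolVAE(
    input_feature_dims=100,  # sequence length
    dictionary_size=40,      # vocabulary size
    embedding_size=30,       # embedding dimension
    ignore_label=39,         # padding index
)
print(vae.name)  # e.g. 'molvae_module1'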
Example #2
    def __init__(self, mcr, name=None):
        
        self.instance = 0
        self.name = (name if name else 'ExaGAN{0}'.format(CosmoGAN.global_count))
        
        ## Gathering the CNN modules into variables
        convbnrelu = lbann.models.resnet.ConvBNRelu
        fc = lbann.modules.FullyConnectedModule
        conv = lbann.modules.Convolution2dModule
        
        #bn_stats_grp_sz = 0 #0 global, 1 local
        bn_stats_grp_sz = -1 #0 global, 1 local
        self.datascale = 4.0
        self.linear_scaler = 1000.0
        self.inits = {'dense': lbann.NormalInitializer(mean=0, standard_deviation=0.02),
                      'conv': lbann.NormalInitializer(mean=0, standard_deviation=0.02),  #should be truncated Normal
                      'convT': lbann.NormalInitializer(mean=0, standard_deviation=0.02)}
        
        #########################
        ##### Discriminator
        d_neurons = [64,128,256,512]
        d_kernel_size, d_stride, d_padding = 5, 2, 2
        
        ### Implementing convolution, bnorm using convbrelu
        ##self, out_channels, kernel_size, stride, padding, bn_zero_init, bn_statistics_group_size, relu, name
        self.d1_conv = [convbnrelu(layer, kernel_size=d_kernel_size, stride=d_stride,
                                   padding=d_padding, bn_zero_init=False,
                                   bn_statistics_group_size=bn_stats_grp_sz, relu=False,
                                   name=self.name+'_disc1_conv'+str(i))
                        for i, layer in enumerate(d_neurons)]
        
        ## Trying without convbrelu
#         self.d1_conv = [conv(layer,d_kernel_size, stride=d_stride, padding=d_padding, transpose=False, bias= False, weights=[lbann.Weights(initializer=self.inits['conv'])], name=self.name+'_disc1_conv'+str(i)) for i,layer in enumerate(d_neurons)]
        
        ### Fully connected layer
        ##self,size,bias=True,transpose=False,weights=[],activation=None,name=None,data_layout='data_parallel',parallel_strategy={}): 
        self.d1_fc = fc(1,name=self.name+'_disc1_fc', weights=[lbann.Weights(initializer=self.inits['dense'])])
        
        #stacked_discriminator, this will be frozen, no optimizer, 
        #layer has to be named for callback
        self.d2_conv = [convbnrelu(layer, d_kernel_size, d_stride, d_padding, False,
                                   bn_stats_grp_sz, False,
                                   name=self.name+'_disc2_conv'+str(i))
                        for i, layer in enumerate(d_neurons)]
        
#         self.d2_conv = [conv(layer,d_kernel_size, stride=d_stride, padding=d_padding, transpose=False, bias=False, weights=[lbann.Weights(initializer=self.inits['conv'])], name=self.name+'_disc2_conv'+str(i)) for i,layer in enumerate(d_neurons)]

        self.d2_fc = fc(1,name=self.name+'_disc2_fc', weights=[lbann.Weights(initializer=self.inits['dense'])])
        
        #########################
        ##### Generator
        g_neurons = [256,128,64]
        g_kernel_size, g_stride, g_padding = 5, 2, 2

        ### Transpose convolution
        ##(self, num_dims,out_channels,kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,weights=[],activation=None,name=None,transpose=False,parallel_strategy={})
        self.g_convT = [conv(layer, g_kernel_size, stride=g_stride, padding=g_padding,
                             transpose=True,
                             weights=[lbann.Weights(initializer=self.inits['convT'])])
                        for i, layer in enumerate(g_neurons)]
        
        ### Fully connected
        fc_size = 524288  ### NB: the annotated product 8 * 8 * 2 * 256 equals 32768, not 524288
        self.g_fc1 = fc(fc_size,name=self.name+'_gen_fc1', weights=[lbann.Weights(initializer=self.inits['dense'])])
        
        ### Final conv transpose
        self.g_convT3 = conv(1, g_kernel_size, stride=g_stride, padding=g_padding, activation=lbann.Tanh,name='gen_img',transpose=True,
                       weights=[lbann.Weights(initializer=self.inits['convT'])])
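A hedged instantiation sketch for the constructor above. The class name CosmoGAN is inferred from the global_count reference, and the meaning of mcr (a multi-channel toggle handled by the surrounding script) is not shown in the snippet.

# Hypothetical usage: build the GAN modules and inspect their structure.
gan = CosmoGAN(mcr=False)
print(len(gan.d1_conv), len(gan.d2_conv), len(gan.g_convT))  # 4 4 3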
Example #3
    def __init__(self, mcr, name=None):

        self.instance = 0
        self.name = (name
                     if name else 'ExaGAN{0}'.format(CosmoGAN.global_count))

        ## Gathering the CNN modules into variables
        convbnrelu = lbann.models.resnet.ConvBNRelu
        fc = lbann.modules.FullyConnectedModule
        conv = lbann.modules.Convolution2dModule

        #bn_stats_grp_sz = 0 #0 global, 1 local
        bn_stats_grp_sz = -1  #0 global, 1 local
        self.datascale = 4.0
        self.linear_scaler = 1000.0
        self.inits = {
            'dense': lbann.NormalInitializer(mean=0, standard_deviation=0.02),
            'conv': lbann.NormalInitializer(
                mean=0, standard_deviation=0.02),  #should be truncated Normal
            'convT': lbann.NormalInitializer(mean=0, standard_deviation=0.02)
        }

        #########################
        ##### Generator
        g_neurons = [256, 128, 64]
        g_kernel_size, g_stride, g_padding = 5, 2, 2

        ### Transpose convolution
        ##(self, num_dims,out_channels,kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,weights=[],activation=None,name=None,transpose=False,parallel_strategy={})
        self.g_convT = [
            conv(layer,
                 g_kernel_size,
                 stride=g_stride,
                 padding=g_padding,
                 transpose=True,
                 weights=[lbann.Weights(initializer=self.inits['convT'])])
            for i, layer in enumerate(g_neurons)
        ]

        ### Fully connected
        fc_size = 32768  ### (8 * 8 * 2 * 256)
        self.g_fc1 = fc(
            fc_size,
            name=self.name + '_gen_fc1',
            weights=[lbann.Weights(initializer=self.inits['dense'])])

        ### Final conv transpose
        self.g_convT3 = conv(
            1,
            g_kernel_size,
            stride=g_stride,
            padding=g_padding,
            activation=lbann.Tanh,
            name='gen_img',
            transpose=True,
            weights=[lbann.Weights(initializer=self.inits['convT'])])
Example #4
    def __init__(self, name=None):
        self.instance = 0
        self.name = (name if name
                     else 'ExaGAN{0}'.format(CosmoGAN.global_count))

        convbnrelu = lbann.models.resnet.ConvBNRelu
        fc = lbann.modules.FullyConnectedModule
        conv = lbann.modules.Convolution2dModule
        #bn_stats_grp_sz = 0 #0 global, 1 local
        bn_stats_grp_sz = -1  #0 global, 1 local

        ##MCR properties #@todo: make multichannel optional
        self.datascale = 4
        self.linear_scaler = 1000.

        self.inits = {'dense': lbann.NormalInitializer(mean=0, standard_deviation=0.02),
                      'conv': lbann.NormalInitializer(mean=0, standard_deviation=0.02),  #should be truncated Normal
                      'convT': lbann.NormalInitializer(mean=0, standard_deviation=0.02)}

        d_neurons = [64, 128, 256, 512]
        self.d1_conv = [convbnrelu(d_neurons[i], 4, 2, 1, False, bn_stats_grp_sz, False,
                                   name=self.name+'_disc1_conv'+str(i))
                        for i in range(len(d_neurons))]
        self.d1_fc = fc(1, name=self.name+'_disc1_fc',
                        weights=[lbann.Weights(initializer=self.inits['dense'])])

        #stacked_discriminator, this will be frozen, no optimizer,
        #layer has to be named for callback
        self.d2_conv = [convbnrelu(d_neurons[i], 4, 2, 1, False, bn_stats_grp_sz, False,
                                   name=self.name+'_disc2_conv'+str(i))
                        for i in range(len(d_neurons))]
        self.d2_fc = fc(1, name=self.name+'_disc2_fc',
                        weights=[lbann.Weights(initializer=self.inits['dense'])])

        #generator
        g_neurons = [256, 128, 64]

        self.g_convT = [conv(g_neurons[i], 5, stride=2, padding=2, transpose=True,
                             weights=[lbann.Weights(initializer=self.inits['convT'])])
                        for i in range(len(g_neurons))]

        self.g_fc1 = fc(32768, name=self.name+'_gen_fc1',
                        weights=[lbann.Weights(initializer=self.inits['dense'])])
        self.g_convT3 = conv(1, 5, stride=2, padding=2, activation=lbann.Tanh,
                             name='gen_img', transpose=True,
                             weights=[lbann.Weights(initializer=self.inits['convT'])])
Example #5
        num_negative_samples=num_negative_samples,
    )

# ----------------------------------
# Construct layer graph
# ----------------------------------
obj = []
metrics = []

# Embedding vectors, including negative sampling
# Note: Input is sequence of vertex IDs
input_ = lbann.Input(data_field='samples')
if args.embeddings == 'distributed':
    embeddings_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(
            mean=0,
            standard_deviation=1 / args.latent_dim,
        ),
        name='embeddings',
    )
    embeddings = lbann.DistEmbedding(
        input_,
        weights=embeddings_weights,
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
        sparse_sgd=True,
        learning_rate=args.learning_rate,
    )
elif args.embeddings == 'replicated':
    embeddings_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(
            mean=0,
Example #6
    def __init__(self,
                 input_width=64,
                 input_channel=1,
                 gen_device='GPU',
                 disc_ps=None,
                 gen_ps=None,
                 use_bn=False,
                 name=None):

        self.instance = 0
        self.name = (name
                     if name else 'Exa3DGAN{0}'.format(Exa3DGAN.global_count))

        convbnrelu = ConvBNRelu
        fc = lbann.modules.FullyConnectedModule
        conv = lbann.modules.Convolution3dModule
        bn_stats_grp_sz = -1  #0 global, 1 local
        self.input_width = input_width
        self.input_channel = input_channel

        self.g_device = gen_device
        #Set parallel strategy
        self.d_ps = disc_ps
        self.g_ps = gen_ps
        self.use_bn = use_bn

        assert self.input_width in [64, 128, 256, 512]

        w = [int(self.input_width / 16)
             ] * 3  #filter size in last disc conv and first gen conv
        w.insert(0, 512)  ##num filters in last disc conv and first gen conv
        self.outc_dims = w

        self.inits = {
            'dense': lbann.NormalInitializer(mean=0, standard_deviation=0.02),
            'conv': lbann.NormalInitializer(mean=0, standard_deviation=0.02),
            'convT': lbann.NormalInitializer(mean=0, standard_deviation=0.02)
        }

        #Discriminator
        d_channels = [64, 128, 256, 512]
        kernel_size = 5
        padding = 2
        stride = 2
        self.d1_conv = [
            convbnrelu(
                d_channels[i],
                kernel_size,
                stride,
                padding,
                self.use_bn,
                bn_stats_grp_sz,
                False,
                name=self.name + '_disc1_conv' + str(i),
                activation=lbann.Relu,
                parallel_strategy=self.d_ps,
                conv_weights=[lbann.Weights(initializer=self.inits['conv'])])
            for i in range(len(d_channels))
        ]
        self.d1_fc = fc(
            1,
            name=self.name + '_disc1_fc',
            weights=[lbann.Weights(initializer=self.inits['dense'])])

        #stacked_discriminator, this will be frozen, no optimizer,
        #layer has to be named for callback
        self.d2_conv = [
            convbnrelu(
                d_channels[i],
                kernel_size,
                stride,
                padding,
                self.use_bn,
                bn_stats_grp_sz,
                False,
                name=self.name + '_disc2_conv' + str(i),
                activation=lbann.Relu,
                parallel_strategy=self.d_ps,
                conv_weights=[lbann.Weights(initializer=self.inits['conv'])])
            for i in range(len(d_channels))
        ]

        self.d2_fc = fc(
            1,
            name=self.name + '_disc2_fc',
            weights=[lbann.Weights(initializer=self.inits['dense'])])

        g_channels = [256, 128, 64]
        kernel_size = 2
        padding = 0
        self.g_convT = [
            conv(g_channels[i],
                 kernel_size,
                 stride,
                 padding,
                 transpose=True,
                 name=self.name + '_gen' + str(i),
                 parallel_strategy=self.g_ps,
                 weights=[lbann.Weights(initializer=self.inits['convT'])])
            for i in range(len(g_channels))
        ]

        self.g_convT3 = conv(
            input_channel,
            kernel_size,
            stride,
            padding,
            activation=lbann.Tanh,
            parallel_strategy=self.g_ps,
            name='gen_img',
            transpose=True,
            weights=[lbann.Weights(initializer=self.inits['convT'])])
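A small sketch instantiating the 3D GAN module above, assuming Exa3DGAN.global_count is defined on the class (it is referenced but not shown). Passing disc_ps=None and gen_ps=None simply leaves out any parallel strategy, and the printed dimensions follow directly from input_width=64.

# Hypothetical instantiation of Exa3DGAN with no layer-parallel strategy.
gan3d = Exa3DGAN(input_width=64, input_channel=1,
                 gen_device='GPU', disc_ps=None, gen_ps=None, use_bn=False)
print(gan3d.outc_dims)  # [512, 4, 4, 4]: 512 filters over a 4x4x4 volume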
Example #7
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
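A hypothetical driver for make_model. Note that vocab_size, sequence_length, pad_index and str_list are module-level globals the function reads rather than parameters, so they must already be defined; the argument values below are illustrative only.

# Hypothetical call; the module-level globals mentioned above must already be set.
model = make_model(
    num_epochs=20,
    embed_dim=512,
    num_heads=8,
    label_smoothing=0.1,
)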
Example #8
    def __init__(self,
                 input_feature_dims,
                 dictionary_size,
                 embedding_size,
                 ignore_label,
                 num_decoder_layers=3,
                 save_output=False,
                 name=None):
        """Initialize Molecular WAE.

        Args:
            input_feature_dims (int): analogous to sequence length.
            dictionary_size (int): vocabulary size
            embedding_size (int): embedding size
            ignore_label (int): padding index
            num_decoder_layers (int, optional) : Number of decoder layers
                (default: 3)
            save_output (bool, optional): save or not save predictions
                (default: False).
            name (str, optional): Module name
                (default: 'molvae_module<index>').

        """
        MolWAE.global_count += 1
        self.instance = 0
        self.name = (name if name else 'molvae_module{0}'.format(
            MolWAE.global_count))

        self.input_feature_dims = input_feature_dims
        self.embedding_size = embedding_size
        self.dictionary_size = dictionary_size
        self.label_to_ignore = ignore_label
        self.num_decoder_layers = num_decoder_layers
        self.save_output = save_output
        self.datatype = lbann.DataType.FLOAT
        self.weights_datatype = lbann.DataType.FLOAT

        fc = lbann.modules.FullyConnectedModule
        gru = GRUModule

        disc_neurons = [128, 64, 1]
        #Encoder
        self.encoder_rnn = gru(
            hidden_size=256,
            name=self.name + '_encoder_rnn',
            datatype=self.datatype,
            weights_datatype=self.weights_datatype,
        )
        self.q_mu = fc(128, name='encoder_qmu')
        self.q_logvar = fc(128, name='encoder_qlogvar')
        for w in self.q_mu.weights + self.q_logvar.weights:
            w.datatype = self.weights_datatype

        #Decoder
        self.decoder_rnn = gru(
            hidden_size=512,
            num_layers=self.num_decoder_layers,
            name=self.name + '_decoder_rnn',
            datatype=self.datatype,
            weights_datatype=self.weights_datatype,
        )
        self.decoder_lat = fc(512, name=self.name + '_decoder_lat')
        self.decoder_fc = fc(self.dictionary_size,
                             name=self.name + '_decoder_fc')
        for w in self.decoder_lat.weights + self.decoder_fc.weights:
            w.datatype = self.weights_datatype
        self.decoder_fc.weights[0].initializer = lbann.NormalInitializer(
            mean=0, standard_deviation=1 / math.sqrt(512))

        #shared encoder/decoder weights
        self.emb_weights = lbann.Weights(
            initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
            name='emb_matrix',
            datatype=self.weights_datatype,
        )

        #Discriminator1
        self.d0_fc0 = fc(disc_neurons[0],
                         activation=lbann.Relu,
                         name=self.name + '_disc0_fc0')
        self.d0_fc1 = fc(disc_neurons[1],
                         activation=lbann.Relu,
                         name=self.name + '_disc0_fc1')
        self.d0_fc2 = fc(disc_neurons[2], name=self.name + '_disc0_fc2')

        #Discriminator2
        #stacked_discriminator, this will be frozen, no optimizer,
        #layer has to be named for replace layer callback
        self.d1_fc0 = fc(disc_neurons[0],
                         activation=lbann.Relu,
                         name=self.name + '_disc1_fc0')
        self.d1_fc1 = fc(disc_neurons[1],
                         activation=lbann.Relu,
                         name=self.name + '_disc1_fc1')
        self.d1_fc2 = fc(disc_neurons[2], name=self.name + '_disc1_fc2')
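As with Example #1, a minimal constructor sketch; the class name MolWAE comes from the global_count reference, and the values are placeholders matching the docstring rather than values from the source.

# Hypothetical instantiation; the extra arguments control decoder depth and prediction dumping.
wae = MolWAE(
    input_feature_dims=100,
    dictionary_size=40,
    embedding_size=30,
    ignore_label=39,
    num_decoder_layers=3,
    save_output=False,
)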
Example #9
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation
    Network architecture and training hyperparameters from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn

    """

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    # note: despite the variable name, this is a GRU cell, not an LSTM
    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )

    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):

        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")])
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)
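A hypothetical run_args stub for construct_model. A SimpleNamespace works here because the function only reads attributes; the attribute values are placeholders, not the ATOM defaults, and lbann plus the str_list helper must be importable at module level.

# Hypothetical arguments; construct_model only reads these attributes.
from types import SimpleNamespace

run_args = SimpleNamespace(
    pad_index=39, sequence_length=100,
    embedding_dim=30, num_embeddings=40, hidden=768,
    step_size=10, gamma=0.9,
    dump_weights_dir='weights', num_epochs=10,
)
model = construct_model(run_args)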