Example 1
def BondEncoder(edge_feature_columns, EDGE_EMBEDDING_DIM):
    """Embeds the edge features into a vector
	Args:
		edge_feature_columns (list(Layers)): A list of layers with edge feaures with shape (NUM_EDGES)
		EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the edge feature vector
	Returns:
		(Layer): A layer containing the embedded edge feature matrix of shape (NUM_EDGES, EDGE_EMBEDDING_DIM)
		"""
    # Courtesy of OGB
    bond_feature_dims = [5, 6, 2]
    _fan_in = bond_feature_dims[0]
    _fan_out = EDGE_EMBEDDING_DIM
    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="bond_encoder_weights_{}".format(0))

    temp = lbann.Embedding(edge_feature_columns[0],
                           num_embeddings=bond_feature_dims[0],
                           embedding_dim=EDGE_EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Bond_Embedding_0")

    for i in range(1, 3):
        _fan_in = bond_feature_dims[i]
        _fan_out = EDGE_EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="bond_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(edge_feature_columns[i],
                                 num_embeddings=bond_feature_dims[i],
                                 embedding_dim=EDGE_EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Bond_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
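The helper _xavier_uniform_init is referenced but not included in this excerpt. A minimal sketch of what it might look like, assuming LBANN's UniformInitializer and the usual Glorot/Xavier bounds (the implementation in the original project may differ):

import math
import lbann

def _xavier_uniform_init(fan_in, fan_out):
    # Hypothetical helper: uniform initializer with Glorot/Xavier bounds
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    return lbann.UniformInitializer(min=-bound, max=bound)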
Example 2
    def forward(self, x):
        """Do the VAE forward step

        :param x: list of tensors of longs, embed representation of input
        :return: float, kl term component of loss
        :return: float, recon component of loss
        """

        x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
        x = lbann.Identity(x)
        x_emb = lbann.Embedding(x,
                                num_embeddings=self.dictionary_size,
                                embedding_dim=self.embedding_size,
                                name='emb',
                                weights=self.emb_weights)

        # Encoder: x -> z, kl_loss
        z, kl_loss = self.forward_encoder(x_emb)

        # Decoder: x, z -> recon_loss
        pred = self.forward_decoder(x_emb, z)
        recon_loss = self.compute_loss(x, pred)

        # Hack to remove blocking GPU allreduce in evaluation layer
        kl_loss = lbann.Identity(kl_loss, device='CPU')
        recon_loss = lbann.Identity(recon_loss, device='CPU')

        return kl_loss, recon_loss
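A hedged usage sketch of how the two returned terms might be combined into an objective, following the LayerTerm/ObjectiveFunction pattern used elsewhere in these examples (the names vae and input_ are assumptions, not part of the excerpt):

kl_loss, recon_loss = vae.forward(input_)   # hypothetical model instance
obj = lbann.ObjectiveFunction([
    lbann.LayerTerm(kl_loss, scale=1.0),
    lbann.LayerTerm(recon_loss, scale=1.0),
])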
Example 3
    def forward(self, x, z):
        """Do the WAE forward step

        :param x: list of tensors of longs, embed representation of input
        :return: float, kl term component of loss
        :return: float, recon component of loss
        """

        x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
        x = lbann.Identity(x)
        x_emb = lbann.Embedding(x,
                                num_embeddings=self.dictionary_size,
                                embedding_dim=self.embedding_size,
                                name='emb',
                                weights=self.emb_weights)

        # Encoder: x -> z_sample
        z_sample = self.forward_encoder(x_emb)

        eps = lbann.Gaussian(mean=self.gmean,
                             stdev=self.gstd,
                             hint_layer=z_sample)
        z_sample = lbann.Add([z_sample, eps])

        # Decoder: x, z -> recon_loss
        #pred = self.forward_decoder(x_emb, z_sample)
        pred, arg_max = self.forward_decoder(x_emb, z_sample)
        recon_loss = self.compute_loss(x, pred)

        # Hack to remove blocking GPU allreduce in evaluation layer
        #kl_loss = lbann.Identity(kl_loss, device='CPU')
        recon_loss = lbann.Identity(recon_loss, device='CPU')

        z_prior = lbann.Tessellate(
            lbann.Reshape(z, dims=str_list([1, self.zdim])),
            dims=str_list([self.input_feature_dims, self.zdim]),
        )

        d_real = self.discriminator0(
            lbann.Concatenation([x_emb, z_prior], axis=1))

        z_sample0 = lbann.Tessellate(
            lbann.Reshape(z_sample, dims=str_list([1, self.zdim])),
            dims=str_list([self.input_feature_dims, self.zdim]),
        )
        y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)

        d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
        d_adv = self.discriminator1(y_z_sample)  #freeze

        return recon_loss, d_real, d_fake, d_adv, arg_max
Example 4
def AtomEncoder(node_feature_columns, EMBEDDING_DIM):
    """Embeds the node features into a vector

    Args:
        node_feature_columns (list(Layer)): A list of layers with node features, each of shape (NUM_NODES)
        EMBEDDING_DIM (int): The embedding dimensionality of the node feature vector
    Returns:
        (Layer): A layer containing the embedded node feature matrix of shape (NUM_NODES, EMBEDDING_DIM)
    """
    # Courtesy of OGB
    atom_feature_dims = [119, 4, 12, 12, 10, 6, 6, 2, 2]

    _fan_in = atom_feature_dims[0]
    _fan_out = EMBEDDING_DIM

    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="atom_encoder_weights_{}".format(0))

    temp = lbann.Embedding(node_feature_columns[0],
                           num_embeddings=atom_feature_dims[0],
                           embedding_dim=EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Atom_Embedding_0")
    for i in range(1, 9):
        _fan_in = atom_feature_dims[i]
        _fan_out = EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="atom_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(node_feature_columns[i],
                                 num_embeddings=atom_feature_dims[i],
                                 embedding_dim=EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Atom_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
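For illustration, a hedged sketch of how the encoder might be fed. It assumes the nine node-feature columns arrive as one flat input that is sliced apart first; graph_input, NUM_NODES, and EMBEDDING_DIM are assumptions, not part of the excerpt:

NUM_NODES = 51        # assumed maximum number of nodes per graph
EMBEDDING_DIM = 100   # assumed embedding size

graph_input = lbann.Input(data_field='samples')  # hypothetical input layer
node_feature_slice = lbann.Slice(
    graph_input,
    axis=0,
    slice_points=str_list(range(0, 9 * NUM_NODES + 1, NUM_NODES)))
node_feature_columns = [lbann.Identity(node_feature_slice) for _ in range(9)]

node_embeddings = AtomEncoder(node_feature_columns, EMBEDDING_DIM)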
Example 5
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
        sparse_sgd=True,
        learning_rate=args.learning_rate,
    )
elif args.embeddings == 'replicated':
    embeddings_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(
            mean=0,
            standard_deviation=1 / args.latent_dim,
        ),
        name='embeddings',
    )
    embeddings = lbann.Embedding(
        input_,
        weights=embeddings_weights,
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
    )
else:
    raise RuntimeError(
        f'unknown method to get embedding vectors ({args.embeddings})')
embeddings_slice = lbann.Slice(
    embeddings,
    axis=0,
    slice_points=utils.str_list(
        [0, num_negative_samples, num_negative_samples + walk_length]),
)
negative_samples_embeddings = lbann.Identity(embeddings_slice)
walk_embeddings = lbann.Identity(embeddings_slice)

# Skip-Gram objective function
Example 6
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE

    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None, 'should be training seq len + bos + eos'

    print("sequence length is {}, which is training sequence len + bos + eos".format(sequence_length))
    data_layout = "data_parallel"
    # Layer graph
    input_ = lbann.Input(data_field='samples', name='inp_data')
    # Note: the input is assumed to come from the encoder script as a
    # concatenation of the input SMILES and z
    inp_slice = lbann.Slice(input_, axis=0,
                            slice_points=str_list([0, sequence_length, sequence_length + run_args.z_dim]),
                            name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)

    print("save output?", save_output, "out dir", run_args.dump_outputs_dir)
    #uncomment below for random sampling
    #z = lbann.Gaussian(mean=0.0,stdev=1.0, neuron_dims=str(run_args.z_dim))
    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             run_args.z_dim,
                             save_output=save_output)
    x_emb = lbann.Embedding(
            x,
            num_embeddings=waemodel.dictionary_size,
            embedding_dim=waemodel.embedding_size,
            name='emb',
            weights=waemodel.emb_weights
    )


    pred, arg_max = waemodel.forward_decoder(x_emb, z)

    recon = waemodel.compute_loss(x, pred)

    wae_loss.append(recon)

    layers = list(lbann.traverse_layer_graph(input_))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)

    #wae_loss.append(l2_reg)
    print("LEN wae loss ", len(wae_loss))

    obj = lbann.ObjectiveFunction(wae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(),
                 lbann.CallbackTimer()]


    # Dump output (activation) for post-processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_,pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(batch_interval=run_args.dump_outputs_interval,
                                  execution_modes='test',
                                  directory=run_args.dump_outputs_dir,
                                  layers=f'{conc_out.name}'))
    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Example 7
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE

    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script assumes a pretrained WAE model"
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"
    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"),
                            name='inp1')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = False

    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    x = lbann.Slice(input_, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    latentz = waemodel.forward_encoder(x_emb)

    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    layers = list(lbann.traverse_layer_graph(input_))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)

    obj = lbann.ObjectiveFunction(fake_loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post-processing
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))
    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
Example 8
# Properties of graph and random walk
num_graph_nodes = dataset.max_graph_node_id() + 1
walk_length = dataset.walk_context_length
num_negative_samples = dataset.num_negative_samples
input_size = dataset.sample_dims()[0]

# Embedding vectors, including negative sampling
# Note: Input is sequence of graph node IDs
input_ = lbann.Identity(lbann.Input())
input_slice = lbann.Slice(
    input_,
    slice_points=f'0 {num_negative_samples+1} {input_size}'
)
decoder_embeddings = lbann.Embedding(
    input_slice,
    weights=decoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)
encoder_embeddings = lbann.Embedding(
    input_slice,
    weights=encoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)

# Skip-Gram with negative sampling
preds = lbann.MatMul(decoder_embeddings, encoder_embeddings, transpose_b=True)
preds_slice = lbann.Slice(
    preds,
    axis=0,
    slice_points=f'0 {num_negative_samples} {num_negative_samples+1}')
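The excerpt ends before the objective. A hedged sketch of how the sliced scores are typically turned into a Skip-Gram negative-sampling loss; the layer choices below are assumptions based on the other examples, not taken from this file:

preds_negative = lbann.Identity(preds_slice)   # scores for negative samples
preds_positive = lbann.Identity(preds_slice)   # scores for observed pairs

# Maximize log(sigmoid(positive)) and log(sigmoid(-negative))
obj_positive = lbann.Log(lbann.Sigmoid(preds_positive))
obj_negative = lbann.Log(
    lbann.Sigmoid(lbann.WeightedSum(preds_negative, scaling_factors='-1')))
obj = [
    lbann.LayerTerm(lbann.Reduction(obj_positive, mode='sum'), scale=-1),
    lbann.LayerTerm(lbann.Reduction(obj_negative, mode='sum'), scale=-1),
]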
Example 9
File: main.py Project: oyamay/lbann
# Construct layer graph
# ----------------------------------

# Dataset properties
vocab_size = dataset.corpus.vocab_size
sequence_length = dataset.sample_dims()[0]

# Input is a sequence of token IDs
input_ = lbann.Identity(lbann.Input())
input_slice = lbann.Slice(input_,
                          slice_points=str_list(range(sequence_length + 1)))
tokens_list = [lbann.Identity(input_slice) for _ in range(sequence_length)]

# Get sequence of embedding vectors
embeddings = lbann.Embedding(input_,
                             num_embeddings=vocab_size,
                             embedding_dim=args.latent_dim)
embeddings_slice = lbann.Slice(embeddings,
                               axis=0,
                               slice_points=str_list(range(sequence_length +
                                                           1)))
embeddings_list = [
    lbann.Reshape(embeddings_slice, dims='-1') for _ in range(sequence_length)
]

# Layer modules
lstm = lbann.modules.LSTMCell(args.latent_dim)
lstm_state = [
    lbann.Constant(value=0, num_neurons=str_list(args.latent_dim)),
    lbann.Constant(value=0, num_neurons=str_list(args.latent_dim))
]
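The excerpt stops before the recurrence. A minimal sketch of how the unrolled loop might continue, following the GRU unroll pattern in Example 13 (the prediction head and loss are assumptions):

# Unroll the LSTM over the sequence and predict the next token at each step
fc = lbann.modules.FullyConnectedModule(size=vocab_size)
loss = []
for i in range(sequence_length - 1):
    x, lstm_state = lstm(embeddings_list[i], lstm_state)
    pred = lbann.Softmax(fc(x))
    label = lbann.OneHot(tokens_list[i + 1], size=vocab_size)
    loss.append(lbann.CrossEntropy(pred, label))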
Example 10
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
Example 11
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE

    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    #sequence_length = run_args.sequence_length
    sequence_length = 102
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"
    # Layer graph
    input_ = lbann.Input(target_mode='N/A', name='inp_data')
    inp_slice = lbann.Slice(input_, axis=0, slice_points="0 102 230", name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')  # param not used
    #input_ = lbann.Identity(lbann.Input(name='inp',target_mode="N/A"), name='inp1')
    vae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = True if run_args.dump_outputs_dir else False

    #print("Inp smile len ", len(inp_smile), "z len ",  len(z))
    print("save output? ", save_output, "out dir ",  run_args.dump_outputs_dir)
    #z = lbann.Gaussian(mean=0.0,stdev=1.0, neuron_dims="128")
    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             save_output)
    x_emb = lbann.Embedding(
            x,
            num_embeddings=waemodel.dictionary_size,
            embedding_dim=waemodel.embedding_size,
            name='emb',
            weights=waemodel.emb_weights
    )

    
    pred, arg_max = waemodel.forward_decoder(x_emb, z)

    recon = waemodel.compute_loss(x, pred)

    vae_loss.append(recon)

    layers = list(lbann.traverse_layer_graph(input_))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)

    #vae_loss.append(l2_reg)
    print("LEN vae loss ", len(vae_loss))

    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(),
                 lbann.CallbackTimer()]


    # Dump output (activation) for post-processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_,pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(batch_interval=run_args.dump_outputs_interval,
                                  execution_modes='test',
                                  directory=run_args.dump_outputs_dir,
                                  layers=f'{conc_out.name}'))
    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Example 12
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):

        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(
                    input_ids,
                    self.input_shape,
                    self.padding_idx,
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(
                    inputs_embeds)

        if token_type_ids is None:
            token_type_ids = lbann.Constant(value=0,
                                            num_neurons=str_list(
                                                self.input_shape))

        if inputs_embeds is None:
            inputs_embeds = lbann.Embedding(
                input_ids,
                num_embeddings=self.vocab_size,
                embedding_dim=self.hidden_size,
                padding_idx=self.pad_token_id,
                weights=_load_pretrained_weights(
                    ".".join((self.name, "word_embeddings.weight")),
                    load_weights=self.load_weights,
                ),
                name=".".join((self.name, "word_embeddings")),
            )
        token_type_embeddings = lbann.Embedding(
            token_type_ids,
            num_embeddings=self.type_vocab_size,
            embedding_dim=self.hidden_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "token_type_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "token_type_embeddings")),
        )

        embeddings = lbann.Add(inputs_embeds, token_type_embeddings)
        if self.position_embedding_type == "absolute":
            position_embeddings = lbann.Embedding(
                position_ids,
                num_embeddings=self.max_position_embeddings,
                embedding_dim=self.hidden_size,
                padding_idx=self.pad_token_id,
                weights=_load_pretrained_weights(
                    ".".join((self.name, "position_embeddings.weight")),
                    load_weights=self.load_weights,
                ),
                name=".".join((self.name, "position_embeddings")),
            )
            embeddings = lbann.Add(embeddings, position_embeddings)

        embeddings = lbann.modules.PytorchLayerNorm(
            embeddings,
            self.layer_norm_eps,
            self.input_shape + (self.hidden_size, ),
            weights=_load_pretrained_weights(
                ".".join((self.name, "layernorm.weightbias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "LayerNorm")),
        )
        embeddings = lbann.Dropout(embeddings,
                                   keep_prob=self.hidden_dropout_prob)
        return embeddings
Example 13
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation
    Network architecture and training hyperparameters from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn

    """

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    # Note: this is a GRU cell despite the "lstm" variable names used below
    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )

    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):

        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")])
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)