def BondEncoder(edge_feature_columns, EDGE_EMBEDDING_DIM):
    """Embeds the edge features into a vector

    Args:
        edge_feature_columns (list(Layer)): A list of layers with edge
            features, each with shape (NUM_EDGES)
        EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the edge
            feature vector
    Returns:
        (Layer): A layer containing the embedded edge feature matrix of
            shape (NUM_EDGES, EDGE_EMBEDDING_DIM)
    """
    # Courtesy of OGB
    bond_feature_dims = [5, 6, 2]

    _fan_in = bond_feature_dims[0]
    _fan_out = EDGE_EMBEDDING_DIM

    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="bond_encoder_weights_{}".format(0))

    temp = lbann.Embedding(edge_feature_columns[0],
                           num_embeddings=bond_feature_dims[0],
                           embedding_dim=EDGE_EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Bond_Embedding_0")
    # Embed each remaining edge-feature column and accumulate the sum
    for i in range(1, 3):
        _fan_in = bond_feature_dims[i]
        _fan_out = EDGE_EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="bond_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(edge_feature_columns[i],
                                 num_embeddings=bond_feature_dims[i],
                                 embedding_dim=EDGE_EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Bond_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
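# BondEncoder above (and AtomEncoder below) reference _xavier_uniform_init,
# which is not defined in this section. A minimal sketch of a plausible
# definition, assuming lbann.UniformInitializer and the standard
# Xavier/Glorot-uniform bound sqrt(6 / (fan_in + fan_out)); the original
# helper may differ.
import math

def _xavier_uniform_init(fan_in, fan_out):
    """Uniform initializer with Xavier/Glorot bounds."""
    a = math.sqrt(6 / (fan_in + fan_out))
    return lbann.UniformInitializer(min=-a, max=a)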
def forward(self, x):
    """Do the VAE forward step

    :param x: list of tensors of longs, embed representation of input
    :return: float, kl term component of loss
    :return: float, recon component of loss
    """
    x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
    x = lbann.Identity(x)
    x_emb = lbann.Embedding(x,
                            num_embeddings=self.dictionary_size,
                            embedding_dim=self.embedding_size,
                            name='emb',
                            weights=self.emb_weights)

    # Encoder: x -> z, kl_loss
    z, kl_loss = self.forward_encoder(x_emb)

    # Decoder: x, z -> recon_loss
    pred = self.forward_decoder(x_emb, z)
    recon_loss = self.compute_loss(x, pred)

    # Hack to remove blocking GPU allreduce in evaluation layer
    kl_loss = lbann.Identity(kl_loss, device='CPU')
    recon_loss = lbann.Identity(recon_loss, device='CPU')

    return kl_loss, recon_loss
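# A minimal sketch of wiring the two loss terms returned above into an LBANN
# objective. `vae` (a constructed MolVAE-style module), `input_`, and the KL
# weight are illustrative assumptions, not this repository's exact code.
kl_loss, recon_loss = vae.forward(input_)
obj = lbann.ObjectiveFunction([
    lbann.LayerTerm(kl_loss, scale=0.1),    # assumed KL weight
    lbann.LayerTerm(recon_loss, scale=1.0),
])
metrics = [lbann.Metric(recon_loss, name='recon')]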
def forward(self, x, z):
    """Do the WAE forward step

    :param x: list of tensors of longs, embed representation of input
    :param z: tensor, sample from the latent prior
    :return: float, recon component of loss
    :return: discriminator outputs on real, fake, and adversarial inputs,
        plus the decoder's arg-max predictions
    """
    x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
    x = lbann.Identity(x)
    x_emb = lbann.Embedding(x,
                            num_embeddings=self.dictionary_size,
                            embedding_dim=self.embedding_size,
                            name='emb',
                            weights=self.emb_weights)

    # Encoder: x -> z_sample, with additive Gaussian noise
    z_sample = self.forward_encoder(x_emb)

    eps = lbann.Gaussian(mean=self.gmean, stdev=self.gstd,
                         hint_layer=z_sample)
    z_sample = lbann.Add([z_sample, eps])

    # Decoder: x, z -> recon_loss
    pred, arg_max = self.forward_decoder(x_emb, z_sample)
    recon_loss = self.compute_loss(x, pred)

    # Hack to remove blocking GPU allreduce in evaluation layer
    recon_loss = lbann.Identity(recon_loss, device='CPU')

    # Discriminator on the prior sample (scored as real)
    z_prior = lbann.Tessellate(
        lbann.Reshape(z, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )
    d_real = self.discriminator0(
        lbann.Concatenation([x_emb, z_prior], axis=1))

    # Discriminator on the encoder sample (scored as fake); StopGradient
    # keeps the discriminator update from backpropagating into the encoder
    z_sample0 = lbann.Tessellate(
        lbann.Reshape(z_sample, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )
    y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)
    d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
    d_adv = self.discriminator1(y_z_sample)  # discriminator1 is frozen

    return recon_loss, d_real, d_fake, d_adv, arg_max
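# Hedged sketch: one plausible way to turn the discriminator outputs above
# into WAE-GAN loss terms via binary cross-entropy. The label pairing and
# layer names are illustrative assumptions, not this repository's exact
# loss wiring.
one = lbann.Constant(value=1.0, num_neurons='1', name='one')
zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
d_loss_real = lbann.SigmoidBinaryCrossEntropy([d_real, one])   # prior sample scored as real
d_loss_fake = lbann.SigmoidBinaryCrossEntropy([d_fake, zero])  # encoder sample scored as fake
g_loss_adv = lbann.SigmoidBinaryCrossEntropy([d_adv, one])     # encoder tries to fool the frozen copy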
def AtomEncoder(node_feature_columns, EMBEDDING_DIM):
    """Embeds the node features into a vector

    Args:
        node_feature_columns (list(Layer)): A list of layers with node
            features, each with shape (NUM_NODES)
        EMBEDDING_DIM (int): The embedding dimensionality of the node
            feature vector
    Returns:
        (Layer): A layer containing the embedded node feature matrix of
            shape (NUM_NODES, EMBEDDING_DIM)
    """
    # Courtesy of OGB
    atom_feature_dims = [119, 4, 12, 12, 10, 6, 6, 2, 2]

    _fan_in = atom_feature_dims[0]
    _fan_out = EMBEDDING_DIM

    _embedding_weights = lbann.Weights(
        initializer=_xavier_uniform_init(_fan_in, _fan_out),
        name="atom_encoder_weights_{}".format(0))

    temp = lbann.Embedding(node_feature_columns[0],
                           num_embeddings=atom_feature_dims[0],
                           embedding_dim=EMBEDDING_DIM,
                           weights=_embedding_weights,
                           name="Atom_Embedding_0")
    # Embed each remaining node-feature column and accumulate the sum
    for i in range(1, 9):
        _fan_in = atom_feature_dims[i]
        _fan_out = EMBEDDING_DIM
        _embedding_weights = lbann.Weights(
            initializer=_xavier_uniform_init(_fan_in, _fan_out),
            name="atom_encoder_weights_{}".format(i))
        _temp2 = lbann.Embedding(node_feature_columns[i],
                                 num_embeddings=atom_feature_dims[i],
                                 embedding_dim=EMBEDDING_DIM,
                                 weights=_embedding_weights,
                                 name="Atom_Embedding_{}".format(i))
        temp = lbann.Sum(temp, _temp2)
    return temp
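# Minimal usage sketch for AtomEncoder, assuming the nine node-feature
# columns arrive as one (NUM_NODES, 9) matrix that is sliced column-wise;
# `node_feature_mat` and the embedding size are illustrative assumptions.
node_fts = lbann.Slice(node_feature_mat,
                       axis=1,
                       slice_points=str_list(range(10)))
node_feature_columns = [
    lbann.Reshape(lbann.Identity(node_fts), dims='-1') for _ in range(9)
]
node_emb = AtomEncoder(node_feature_columns, EMBEDDING_DIM=100)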
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
        sparse_sgd=True,
        learning_rate=args.learning_rate,
    )
elif args.embeddings == 'replicated':
    embeddings_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(
            mean=0,
            standard_deviation=1 / args.latent_dim,
        ),
        name='embeddings',
    )
    embeddings = lbann.Embedding(
        input_,
        weights=embeddings_weights,
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
    )
else:
    raise RuntimeError(
        f'unknown method to get embedding vectors ({args.embeddings})')

embeddings_slice = lbann.Slice(
    embeddings,
    axis=0,
    slice_points=utils.str_list(
        [0, num_negative_samples, num_negative_samples + walk_length]),
)
negative_samples_embeddings = lbann.Identity(embeddings_slice)
walk_embeddings = lbann.Identity(embeddings_slice)

# Skip-Gram objective function
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE
    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None, 'should be training seq len + bos + eos'
    print("sequence length is {}, which is training sequence len + bos + eos".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    input_ = lbann.Input(data_field='samples', name='inp_data')
    # Note: input is assumed to be the encoder script's concatenation of the
    # input SMILES string and z
    inp_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list([0, sequence_length,
                               sequence_length + run_args.z_dim]),
        name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    # Uncomment below for random sampling
    #z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(run_args.z_dim))
    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             run_args.z_dim,
                             save_output=save_output)
    x_emb = lbann.Embedding(
        x,
        num_embeddings=waemodel.dictionary_size,
        embedding_dim=waemodel.embedding_size,
        name='emb',
        weights=waemodel.emb_weights
    )

    pred, arg_max = waemodel.forward_decoder(x_emb, z)
    recon = waemodel.compute_loss(x, pred)

    wae_loss.append(recon)
    layers = list(lbann.traverse_layer_graph(input_))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    #wae_loss.append(l2_reg)
    print("LEN wae loss ", len(wae_loss))

    obj = lbann.ObjectiveFunction(wae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post-processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE
    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script assumes a pretrained WAE model"
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None
    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"),
                            name='inp1')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = False
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    x = lbann.Slice(input_, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)

    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             save_output)

    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    latentz = waemodel.forward_encoder(x_emb)

    # Dummy objective: the evaluate script only needs the encoder activations
    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    layers = list(lbann.traverse_layer_graph(input_))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)

    obj = lbann.ObjectiveFunction(fake_loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post-processing
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
# Properties of graph and random walk
num_graph_nodes = dataset.max_graph_node_id() + 1
walk_length = dataset.walk_context_length
num_negative_samples = dataset.num_negative_samples
input_size = dataset.sample_dims()[0]

# Embedding vectors, including negative sampling
# Note: Input is sequence of graph node IDs
input_ = lbann.Identity(lbann.Input())
input_slice = lbann.Slice(
    input_,
    slice_points=f'0 {num_negative_samples+1} {input_size}'
)
decoder_embeddings = lbann.Embedding(
    input_slice,
    weights=decoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)
encoder_embeddings = lbann.Embedding(
    input_slice,
    weights=encoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)

# Skip-Gram with negative sampling
preds = lbann.MatMul(decoder_embeddings,
                     encoder_embeddings,
                     transpose_b=True)
preds_slice = lbann.Slice(
    preds,
    axis=0,
    slice_points=f'0 {num_negative_samples} {num_negative_samples+1}')
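# Hedged continuation sketch: how the sliced predictions above might feed a
# skip-gram negative-sampling objective (log-sigmoid on positive scores,
# log-sigmoid on negated negative scores). Names and scales are illustrative
# assumptions, not this example's exact code.
preds_negative = lbann.Identity(preds_slice)
preds_positive = lbann.Identity(preds_slice)
obj_positive = lbann.Reduction(lbann.LogSigmoid(preds_positive), mode='sum')
obj_negative = lbann.Reduction(
    lbann.LogSigmoid(lbann.Negative(preds_negative)), mode='sum')
obj = [
    lbann.LayerTerm(obj_positive, scale=-1),  # minimize negative log-likelihood
    lbann.LayerTerm(obj_negative, scale=-1),
]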
# Construct layer graph
# ----------------------------------

# Dataset properties
vocab_size = dataset.corpus.vocab_size
sequence_length = dataset.sample_dims()[0]

# Input is a sequence of token IDs
input_ = lbann.Identity(lbann.Input())
input_slice = lbann.Slice(input_,
                          slice_points=str_list(range(sequence_length + 1)))
tokens_list = [lbann.Identity(input_slice) for _ in range(sequence_length)]

# Get sequence of embedding vectors
embeddings = lbann.Embedding(input_,
                             num_embeddings=vocab_size,
                             embedding_dim=args.latent_dim)
embeddings_slice = lbann.Slice(embeddings,
                               axis=0,
                               slice_points=str_list(range(sequence_length + 1)))
embeddings_list = [
    lbann.Reshape(embeddings_slice, dims='-1') for _ in range(sequence_length)
]

# Layer modules
lstm = lbann.modules.LSTMCell(args.latent_dim)
lstm_state = [
    lbann.Constant(value=0, num_neurons=str_list(args.latent_dim)),
    lbann.Constant(value=0, num_neurons=str_list(args.latent_dim))
]
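# Hedged sketch of the unrolled prediction loop that would typically follow:
# each step feeds one embedding through the LSTM cell and classifies the next
# token. The fully-connected module `pred_fc` and the loss wiring are
# illustrative assumptions, not this example's exact continuation.
pred_fc = lbann.modules.FullyConnectedModule(vocab_size)
loss = []
for step in range(sequence_length - 1):
    x, lstm_state = lstm(embeddings_list[step], lstm_state)
    pred = lbann.Softmax(pred_fc(x))
    label = lbann.OneHot(tokens_list[step + 1], size=vocab_size)
    loss.append(lbann.CrossEntropy([pred, label]))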
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
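# Minimal usage sketch for make_model. The argument values are illustrative
# assumptions; the function also relies on module-level vocab_size,
# sequence_length, and pad_index being set by the surrounding script.
model = make_model(
    num_epochs=20,
    embed_dim=512,
    num_heads=8,
    label_smoothing=0.1,
)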
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE
    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    #sequence_length = run_args.sequence_length
    sequence_length = 102  # hardcoded; the slice points below assume 102 + 128
    assert sequence_length is not None
    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    input_ = lbann.Input(target_mode='N/A', name='inp_data')
    inp_slice = lbann.Slice(input_, axis=0, slice_points="0 102 230",
                            name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    #param not used
    #input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"), name='inp1')
    vae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    #print("Inp smile len ", len(inp_smile), "z len ", len(z))
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    #z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")
    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             save_output)
    x_emb = lbann.Embedding(
        x,
        num_embeddings=waemodel.dictionary_size,
        embedding_dim=waemodel.embedding_size,
        name='emb',
        weights=waemodel.emb_weights
    )

    pred, arg_max = waemodel.forward_decoder(x_emb, z)
    recon = waemodel.compute_loss(x, pred)

    vae_loss.append(recon)
    layers = list(lbann.traverse_layer_graph(input_))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    #vae_loss.append(l2_reg)
    print("LEN vae loss ", len(vae_loss))

    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post-processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
def forward(
    self,
    input_ids=None,
    token_type_ids=None,
    position_ids=None,
    inputs_embeds=None,
):
    if position_ids is None:
        if input_ids is not None:
            position_ids = create_position_ids_from_input_ids(
                input_ids,
                self.input_shape,
                self.padding_idx,
            )
        else:
            position_ids = self.create_position_ids_from_inputs_embeds(
                inputs_embeds)

    if token_type_ids is None:
        token_type_ids = lbann.Constant(value=0,
                                        num_neurons=str_list(self.input_shape))

    if inputs_embeds is None:
        inputs_embeds = lbann.Embedding(
            input_ids,
            num_embeddings=self.vocab_size,
            embedding_dim=self.hidden_size,
            padding_idx=self.pad_token_id,
            weights=_load_pretrained_weights(
                ".".join((self.name, "word_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "word_embeddings")),
        )

    token_type_embeddings = lbann.Embedding(
        token_type_ids,
        num_embeddings=self.type_vocab_size,
        embedding_dim=self.hidden_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "token_type_embeddings.weight")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "token_type_embeddings")),
    )

    embeddings = lbann.Add(inputs_embeds, token_type_embeddings)

    if self.position_embedding_type == "absolute":
        position_embeddings = lbann.Embedding(
            position_ids,
            num_embeddings=self.max_position_embeddings,
            embedding_dim=self.hidden_size,
            padding_idx=self.pad_token_id,
            weights=_load_pretrained_weights(
                ".".join((self.name, "position_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "position_embeddings")),
        )
        embeddings = lbann.Add(embeddings, position_embeddings)

    embeddings = lbann.modules.PytorchLayerNorm(
        embeddings,
        self.layer_norm_eps,
        self.input_shape + (self.hidden_size, ),
        weights=_load_pretrained_weights(
            ".".join((self.name, "layernorm.weightbias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "LayerNorm")),
    )
    embeddings = lbann.Dropout(embeddings,
                               keep_prob=self.hidden_dropout_prob)

    return embeddings
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation
    Network architecture and training hyperparameters from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn
    """
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None
    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # Embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    # Note: despite the name, this recurrent cell is a GRU
    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )
    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):
        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # Mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")],
        )
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)
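# Hedged sketch of driving the constructed model. The optimizer
# hyperparameters, the run_args attribute names (lr, batch_size), and the
# data_reader protobuf message are illustrative assumptions in the style of
# other LBANN application scripts.
model = construct_model(run_args)
opt = lbann.SGD(learn_rate=run_args.lr, momentum=0.9)         # assumed hyperparameters
trainer = lbann.Trainer(mini_batch_size=run_args.batch_size)  # assumed attribute
lbann.run(trainer, model, data_reader, opt,                   # data_reader: assumed message
          job_name='atom_char_rnn')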