def forward(self, x): """Perform LSTM step. State from previous steps is used to compute output. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Apply linearity input_concat = lbann.Concatenation([x, self.last_output], name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=_str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply([f, self.last_cell], name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply([i, cell_update], name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add([cell_forget, cell_input], name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply([o, cell_act], name=name, data_layout=self.data_layout) # Update state and return output self.last_cell = cell self.last_output = output return output
def forward(self, x): """Do the VAE forward step :param x: list of tensors of longs, embed representation of input :return: float, kl term component of loss :return: float, recon component of loss """ x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims])) x = lbann.Identity(x) x_emb = lbann.Embedding(x, num_embeddings=self.dictionary_size, embedding_dim=self.embedding_size, name='emb', weights=self.emb_weights) # Encoder: x -> z, kl_loss z, kl_loss = self.forward_encoder(x_emb) # Decoder: x, z -> recon_loss pred = self.forward_decoder(x_emb, z) recon_loss = self.compute_loss(x, pred) # Hack to remove blocking GPU allreduce in evaluation layer kl_loss = lbann.Identity(kl_loss, device='CPU') recon_loss = lbann.Identity(recon_loss, device='CPU') return kl_loss, recon_loss
def make_model(
        motif_size,
        walk_length,
        num_vertices,
        embed_dim,
        learn_rate,
        num_epochs,
        embeddings_dir,
):

    # Layer graph
    input_ = lbann.Slice(
        lbann.Input(data_field='samples'),
        slice_points=str_list([0, motif_size, motif_size+walk_length]),
    )
    motif_indices = lbann.Identity(input_)
    walk_indices = lbann.Identity(input_)
    gan = model.gan.CommunityGAN(
        num_vertices,
        motif_size,
        embed_dim,
        learn_rate,
    )
    loss, real_disc_prob, fake_disc_prob, gen_prob = gan(
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
    )

    # Metrics
    metrics = [
        lbann.Metric(real_disc_prob, name='D(real)'),
        lbann.Metric(fake_disc_prob, name='D(fake)'),
        lbann.Metric(gen_prob, name='G'),
    ]

    # Callbacks
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackDumpWeights(directory=embeddings_dir,
                                  epoch_interval=num_epochs),
    ]

    # Perform computation at double precision
    for l in lbann.traverse_layer_graph(input_):
        l.datatype = lbann.DataType.DOUBLE
        for w in l.weights:
            w.datatype = lbann.DataType.DOUBLE

    # Construct model
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
def forward_encoder(self, x_emb):
    """Encoder step, emulating z ~ E(x) = q_E(z|x)

    :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
    :return: (n_batch, d_z) of floats, sample of latent vector z
    :return: float, kl term component of loss
    """

    # _, h = self.encoder_rnn(x, None)
    h = self.encoder_rnn(x_emb, None)
    h = lbann.Slice(
        h,
        slice_points=str_list(
            [self.input_feature_dims - 1, self.input_feature_dims]),
        axis=0,
    )
    h = lbann.Identity(h)

    mu, logvar = self.q_mu(h), self.q_logvar(h)

    # Set datatype of previous layers
    # Note: Depth-first search from mu and logvar to x_emb
    stack = [mu, logvar]
    in_stack = {l: True for l in stack}
    while stack:
        l = stack.pop()
        if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
            l.datatype = self.datatype
        for parent in l.parents:
            if parent not in in_stack and parent is not x_emb:
                stack.append(parent)
                in_stack[parent] = True

    # eps = torch.randn_like(mu)
    eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

    # z = mu + (logvar / 2).exp() * eps
    z = lbann.Add([
        mu,
        (lbann.Multiply([
            lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
            eps
        ]))
    ])

    # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
    kl_loss = lbann.Reduction(
        lbann.WeightedSum(
            lbann.Exp(logvar),
            lbann.Square(mu),
            self.constant(1, hint_layer=mu),
            logvar,
            scaling_factors='0.5 0.5 -0.5 -0.5',
        ),
        mode='sum',
    )

    return z, kl_loss
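# Sanity check (not from the source): the reparameterization and KL term
# above mirror the commented PyTorch expressions. A plain-NumPy sketch of
# the same formulas, independent of LBANN:
import numpy as np

rng = np.random.default_rng(0)
mu, logvar = rng.normal(size=8), rng.normal(size=8)

# z = mu + exp(logvar / 2) * eps, with eps ~ N(0, 1)
eps = rng.normal(size=8)
z = mu + np.exp(logvar / 2) * eps

# KL(N(mu, exp(logvar)) || N(0, 1)); matches scaling_factors '0.5 0.5 -0.5 -0.5'
kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1 - logvar)
assert kl >= 0  # KL divergence is non-negative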
def forward(self, _):
    # Sum of the L2 norms of the axis-0 slices of a (width, height)
    # weights matrix
    w = lbann.WeightsLayer(weights=self.weights,
                           dims=f'{self.width} {self.height}')
    slice = lbann.Slice(
        w,
        axis=0,
        slice_points=' '.join(str(i) for i in range(self.width + 1)),
    )
    cols = []
    for _ in range(self.width):
        cols.append(lbann.Sqrt(lbann.L2Norm2(slice)))
    return lbann.Sum(cols)
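# What the layer graph above computes, in NumPy terms (a sketch, not from
# the source): the axis-0 slices of a (width, height) tensor are rows, so
# despite the `cols` name this is a sum of row L2 norms.
import numpy as np

w = np.arange(12, dtype=float).reshape(4, 3)  # width=4, height=3
norm_sum = sum(np.sqrt(np.sum(row ** 2)) for row in w)
assert np.isclose(norm_sum, np.linalg.norm(w, axis=1).sum())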
def construct_model():
    """Model description
    """
    import lbann
    import lbann.modules

    fc = lbann.modules.FullyConnectedModule
    conv = lbann.modules.Convolution2dModule

    conv1 = conv(20, 3, stride=1, padding=1, name='conv1')
    conv2 = conv(20, 3, stride=1, padding=1, name='conv2')
    fc1 = fc(100, name='fc1')
    fc2 = fc(20, name='fc2')
    fc3 = fc(num_classes, name='fc3')

    # Layer graph
    input = lbann.Input(name='inp_tensor', target_mode='classification')
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, dims - 1, dims]),
                            name='inp_slice')
    xdata = lbann.Identity(inp_slice)
    ylabel = lbann.Identity(inp_slice, name='gt_y')

    # NHWC to NCHW
    x = lbann.Reshape(xdata, dims='14 13 13')
    x = conv2(conv1(x))
    x = lbann.Reshape(x, dims='3380')
    x = lbann.Dropout(lbann.Relu(fc1(x)), keep_prob=0.5)
    x = lbann.Dropout(fc2(x), keep_prob=0.5)
    pred = lbann.Softmax(fc3(x))
    gt_label = lbann.OneHot(ylabel, size=num_classes)
    loss = lbann.CrossEntropy([pred, gt_label], name='loss')
    acc = lbann.CategoricalAccuracy([pred, gt_label])

    layers = list(lbann.traverse_layer_graph(input))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    num_epochs = 10
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=[lbann.Metric(acc, name='accuracy', unit='%')],
                       objective_function=obj,
                       callbacks=callbacks)
def forward(self, x, z):
    """Do the WAE forward step

    :param x: list of tensors of longs, embed representation of input
    :param z: latent-space sample drawn from the prior
    :return: recon component of loss, discriminator outputs
        (real, fake, adversarial), and argmax predictions
    """
    x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
    x = lbann.Identity(x)
    x_emb = lbann.Embedding(x,
                            num_embeddings=self.dictionary_size,
                            embedding_dim=self.embedding_size,
                            name='emb',
                            weights=self.emb_weights)

    # Encoder: x -> z
    z_sample = self.forward_encoder(x_emb)

    eps = lbann.Gaussian(mean=self.gmean, stdev=self.gstd,
                         hint_layer=z_sample)
    z_sample = lbann.Add([z_sample, eps])

    # Decoder: x, z -> recon_loss
    #pred = self.forward_decoder(x_emb, z_sample)
    pred, arg_max = self.forward_decoder(x_emb, z_sample)
    recon_loss = self.compute_loss(x, pred)

    # Hack to remove blocking GPU allreduce in evaluation layer
    #kl_loss = lbann.Identity(kl_loss, device='CPU')
    recon_loss = lbann.Identity(recon_loss, device='CPU')

    z_prior = lbann.Tessellate(
        lbann.Reshape(z, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )

    d_real = self.discriminator0(
        lbann.Concatenation([x_emb, z_prior], axis=1))

    z_sample0 = lbann.Tessellate(
        lbann.Reshape(z_sample, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )
    y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)
    d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
    d_adv = self.discriminator1(y_z_sample)  # freeze

    return recon_loss, d_real, d_fake, d_adv, arg_max
def forward(
        self,
        motif_indices,
        motif_size,
        walk_indices,
        walk_length,
):

    # Apply generator
    fake_motif_indices, gen_prob, gen_log_prob = self.generator(
        walk_length,
        walk_indices,
        motif_size,
    )

    # Get discriminator embeddings in log-space
    all_motif_indices = lbann.Concatenation(motif_indices,
                                            fake_motif_indices)
    all_motif_log_embeddings = self.discriminator.get_log_embeddings(
        all_motif_indices)
    all_motif_log_embeddings = lbann.Slice(
        all_motif_log_embeddings,
        slice_points=str_list([0, motif_size, 2 * motif_size]),
    )
    real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)
    fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)

    # Apply discriminator
    real_disc_prob, real_disc_log_not_prob \
        = self.discriminator(motif_size, real_motif_log_embeddings)
    fake_disc_prob, fake_disc_log_not_prob \
        = self.discriminator(motif_size, fake_motif_log_embeddings)

    # Loss function
    # L_disc = - log(D(real)) - log(1-D(fake))
    # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
    real_disc_log_prob \
        = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1))
    disc_loss = lbann.WeightedSum(
        real_disc_log_prob,
        fake_disc_log_not_prob,
        scaling_factors=str_list([-1, -1]),
    )
    gen_loss = lbann.Multiply(
        gen_log_prob,
        lbann.StopGradient(fake_disc_log_not_prob),
    )
    loss = lbann.Add(disc_loss, gen_loss)

    return loss, real_disc_prob, fake_disc_prob, gen_prob
def forward(self, x):
    x_slice = lbann.Slice(x,
                          axis=0,
                          slice_points="0 921 4750 8579",
                          name='inp_slice')
    gene = self.geneT(lbann.Identity(x_slice))
    drug1 = self.drug1T(lbann.Identity(x_slice))
    drug2 = self.drug2T(lbann.Identity(x_slice))
    concat = self.concatT(
        lbann.Concatenation([gene, drug1, drug2],
                            name=self.name + 'concat'))
    response_fc = lbann.FullyConnected(concat, num_neurons=1, has_bias=True)
    return response_fc
def matrix_to_graph(cls, mat_layer, num_vertices, num_features):
    """Given a 2D matrix of shape (num_vertices, num_features), returns a
    GraphVertexData object with num_vertices number of nodes with
    num_features.
    """
    slice_points = str_list(
        [i for i in range(0, num_vertices * num_features + 1, num_features)])
    flattened_layer = lbann.Reshape(mat_layer,
                                    dims=str(num_vertices * num_features))
    sliced_mat_layer = lbann.Slice(flattened_layer,
                                   axis=0,
                                   slice_points=slice_points)

    list_of_layers = []
    for node in range(num_vertices):
        temp = lbann.Identity(sliced_mat_layer)
        list_of_layers.append(
            lbann.Reshape(temp, dims=str_list([1, num_features])))
    return cls(list_of_layers, num_features)
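# Hypothetical call site (not from the source): `GraphVertexData` is the
# enclosing class of the classmethod above, and the sample layout is made up.
import lbann
from lbann.util import str_list  # assumed helper

mat = lbann.Reshape(lbann.Input(data_field='samples'),
                    dims=str_list([4, 16]))  # 4 vertices, 16 features each
graph = GraphVertexData.matrix_to_graph(mat, num_vertices=4, num_features=16)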
def encoder_cnn(self, y):
    img_sca = lbann.Slice(y,
                          axis=0,
                          slice_points="0 16384 16399",
                          name=self.name + '_y_slice')
    # assume C first, is data C first?
    img = lbann.Reshape(img_sca,
                        dims='4 64 64',
                        name=self.name + 'enc_reshape0')
    x = self.enc_conv[2](self.enc_conv[1](self.enc_conv[0](img)))
    x = lbann.Reshape(x,
                      dims=str(16 * 8 * 8),
                      name=self.name + 'enc_reshape1')
    h_stack = lbann.Concatenation([x, img_sca], axis=0)
    z = self.enc_out(h_stack)
    return z
def Graph_Data_Parser(_lbann_input_,
                      num_nodes,
                      node_feature_size,
                      max_edges,
                      num_classes=1):
    """A parser for graph structured data with node features, source and
    target node indices in COO format, and a target.

    Args:
        _lbann_input_ (Layer): The input layer of the LBANN model
        num_nodes (int): The maximum number of nodes in the dataset
        node_feature_size (int): The dimensionality of the node features matrix
        max_edges (int): The maximum number of edges in the dataset
        num_classes (int): The number of classes in the target, or 1 for
            regression (default: 1)
    Returns:
        (dictionary) A dictionary with the keys: node_features,
        source_indices, target_indices, and target
    """
    slice_points = [
        0, num_nodes * node_feature_size, max_edges, max_edges, num_classes
    ]
    shifted_slice_points = list(accumulate(slice_points))
    sliced_input = lbann.Slice(_lbann_input_,
                               slice_points=str_list(shifted_slice_points),
                               name="Sliced_Graph_Input")
    node_features = lbann.Reshape(lbann.Identity(sliced_input),
                                  dims=str_list([num_nodes,
                                                 node_feature_size]),
                                  name="Node_Feature_Matrix")
    source_indices = lbann.Identity(sliced_input)
    target_indices = lbann.Identity(sliced_input)
    targets = lbann.Identity(sliced_input)

    graph_data = {
        "node_features": node_features,
        "source_indices": source_indices,
        "target_indices": target_indices,
        "target": targets
    }
    return graph_data
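# Usage sketch (not from the source): `accumulate` comes from itertools and
# `str_list` from lbann.util; the field sizes below are illustrative only.
from itertools import accumulate  # required by Graph_Data_Parser
import lbann
from lbann.util import str_list

input_ = lbann.Input(data_field='samples')
graph = Graph_Data_Parser(input_,
                          num_nodes=50,
                          node_feature_size=9,
                          max_edges=100,
                          num_classes=1)
node_features = graph["node_features"]  # (50, 9) node feature matrix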
def forward(self, hidden_states):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token.
    first_token_tensor = lbann.Slice(hidden_states,
                                     axis=1,
                                     slice_points=str_list([0, 1]))
    pooled_output = lbann.modules.PytorchLinear(
        first_token_tensor,
        (self.input_shape[0], self.input_shape[-1]),
        self.hidden_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "dense.weight")),
            ".".join((self.name, "dense.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "dense")),
    )
    pooled_output = lbann.Tanh(pooled_output,
                               name=".".join((self.name, "activation")))
    return pooled_output
def forward_encoder(self, x_emb):
    """Encoder step, emulating z ~ E(x) = q_E(z|x)

    :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
    :return: (n_batch, d_z) of floats, sample of latent vector z
    """

    # _, h = self.encoder_rnn(x, None)
    h = self.encoder_rnn(x_emb, None)
    h = lbann.Slice(
        h,
        slice_points=str_list(
            [self.input_feature_dims - 1, self.input_feature_dims]),
        axis=0,
    )
    h = lbann.Identity(h)

    z = self.q_mu(h)
    return z
def setup(num_patches=3,
          mini_batch_size=512,
          num_epochs=75,
          learning_rate=0.005,
          bn_statistics_group_size=2,
          fc_data_layout='model_parallel',
          warmup=True,
          checkpoint_interval=None):

    # Data dimensions
    patch_dims = patch_generator.patch_dims
    num_labels = patch_generator.num_labels(num_patches)

    # Extract tensors from data sample
    input = lbann.Input()
    slice_points = [0]
    for _ in range(num_patches):
        patch_size = functools.reduce(operator.mul, patch_dims)
        slice_points.append(slice_points[-1] + patch_size)
    slice_points.append(slice_points[-1] + num_labels)
    sample = lbann.Slice(input, slice_points=str_list(slice_points))
    patches = [
        lbann.Reshape(sample, dims=str_list(patch_dims))
        for _ in range(num_patches)
    ]
    labels = lbann.Identity(sample)

    # Siamese network
    head_cnn = modules.ResNet(
        bn_statistics_group_size=bn_statistics_group_size)
    heads = [head_cnn(patch) for patch in patches]
    heads_concat = lbann.Concatenation(heads)

    # Classification network
    class_fc1 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc1',
        data_layout=fc_data_layout)
    class_fc2 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc2',
        data_layout=fc_data_layout)
    class_fc3 = lbann.modules.FullyConnectedModule(
        num_labels,
        activation=lbann.Softmax,
        name='siamese_class_fc3',
        data_layout=fc_data_layout)
    x = class_fc1(heads_concat)
    x = class_fc2(x)
    probs = class_fc3(x)

    # Setup objective function
    cross_entropy = lbann.CrossEntropy([probs, labels])
    l2_reg_weights = set()
    for l in lbann.traverse_layer_graph(input):
        if type(l) == lbann.Convolution or type(l) == lbann.FullyConnected:
            l2_reg_weights.update(l.weights)
    l2_reg = lbann.L2WeightRegularization(weights=l2_reg_weights,
                                          scale=0.0002)
    obj = lbann.ObjectiveFunction([cross_entropy, l2_reg])

    # Setup model
    metrics = [
        lbann.Metric(lbann.CategoricalAccuracy([probs, labels]),
                     name='accuracy',
                     unit='%')
    ]
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    if checkpoint_interval:
        callbacks.append(
            lbann.CallbackCheckpoint(checkpoint_dir='ckpt',
                                     checkpoint_epochs=5))

    # Learning rate schedules
    if warmup:
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(
                target=learning_rate * mini_batch_size / 128, num_epochs=5))
    callbacks.append(
        lbann.CallbackDropFixedLearningRate(
            drop_epoch=list(range(0, 100, 15)), amt=0.25))

    # Construct model
    model = lbann.Model(num_epochs,
                        layers=lbann.traverse_layer_graph(input),
                        objective_function=obj,
                        metrics=metrics,
                        callbacks=callbacks)

    # Setup optimizer
    opt = lbann.SGD(learn_rate=learning_rate, momentum=0.9)
    # opt = lbann.Adam(learn_rate=learning_rate, beta1=0.9, beta2=0.999, eps=1e-8)

    # Setup data reader
    data_reader = make_data_reader(num_patches)

    # Return experiment objects
    return model, data_reader, opt
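# For concreteness (not from the source): the slice points built above are
# cumulative offsets over the flattened patches plus labels. With 3 patches
# of shape (3, 32, 32) and 12 labels (numbers illustrative):
import functools
import operator

patch_dims, num_patches, num_labels = (3, 32, 32), 3, 12
slice_points = [0]
for _ in range(num_patches):
    slice_points.append(slice_points[-1]
                        + functools.reduce(operator.mul, patch_dims))
slice_points.append(slice_points[-1] + num_labels)
print(slice_points)  # [0, 3072, 6144, 9216, 9228]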
def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the ith
            query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = lbann.ChannelwiseFullyConnected(
        queries,
        weights=self.query_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_queries_fc',
    )
    keys_fc = lbann.ChannelwiseFullyConnected(
        keys,
        weights=self.key_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_keys_fc',
    )
    values_fc = lbann.ChannelwiseFullyConnected(
        values,
        weights=self.value_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_values_fc',
    )

    # Slice embedding vectors for each head
    slice_points = str_list(self.head_dim * i
                            for i in range(self.num_heads + 1))
    queries_slice = lbann.Slice(
        queries_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_queries_slice',
    )
    keys_slice = lbann.Slice(
        keys_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_keys_slice',
    )
    values_slice = lbann.Slice(
        values_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_values_slice',
    )

    # Compute scaled dot-product attention for each head
    attentions = []
    for head in range(self.num_heads):
        head_name = f'{name}_head{head}'

        # Attention inputs
        q = lbann.Identity(queries_slice)
        k = lbann.Identity(keys_slice)
        v = lbann.Identity(values_slice)

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q,
            k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )
        if mask:
            y = lbann.Add(y, mask, name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        attentions.append(lbann.MatMul(y, v, name=head_name))

    # Concatenate heads and apply fully-connected layer
    attentions = lbann.Concatenation(attentions,
                                     axis=1,
                                     name=f'{name}_heads_concat')
    outputs_fc = lbann.ChannelwiseFullyConnected(
        attentions,
        weights=self.output_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}',
    )
    return outputs_fc
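# Usage sketch (not from the source): self-attention is the case where
# queries, keys, and values are the same sequence. `MultiheadAttention` and
# its constructor are assumed stand-ins for the module defining forward()
# above; how the additive mask is built is out of scope here.
import lbann

attn = MultiheadAttention(embed_dim=512, num_heads=8,
                          name='attn')  # hypothetical constructor
x = lbann.Identity(lbann.Input(data_field='samples'))
y = attn(x, x, x)                       # self-attention, no mask
# y = attn(x, x, x, mask=causal_mask)   # masked variant, mask built elsewhere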
    type=int,
    help='latent space dimensions (default: 128)',
    metavar='NUM')
args = parser.parse_args()

# ----------------------------------
# Construct layer graph
# ----------------------------------

# Dataset properties
vocab_size = dataset.corpus.vocab_size
sequence_length = dataset.sample_dims()[0]

# Input is a sequence of token IDs
input_ = lbann.Identity(lbann.Input())
input_slice = lbann.Slice(input_,
                          slice_points=str_list(range(sequence_length + 1)))
tokens_list = [lbann.Identity(input_slice) for _ in range(sequence_length)]

# Get sequence of embedding vectors
embeddings = lbann.Embedding(input_,
                             num_embeddings=vocab_size,
                             embedding_dim=args.latent_dim)
embeddings_slice = lbann.Slice(embeddings,
                               axis=0,
                               slice_points=str_list(range(sequence_length + 1)))
embeddings_list = [
    lbann.Reshape(embeddings_slice, dims='-1')
    for _ in range(sequence_length)
]

# Layer modules
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf for the model architecture
    and other details.
    """
    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + xdim]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  # param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = macc_network_architectures.MACCWAE(
        zdim, ydim, cf=wae_mcf, use_CNN=useCNN)  # pretrained, freeze
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)

    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)

    '''**** Train cycleGAN input params <--> latent space of (images, scalars) ****'''
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)

    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)
    L_cyc_y = lbann.MeanSquaredError(output_cyc, y_pred_fwd)

    # @todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)

    # L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    # loss_gen0 = L_l2_y + lambda_cyc*L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {lambda_cyc}')
    # loss_gen1 = L_l2_x + lambda_cyc*L_cyc_y
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {lambda_cyc}')

    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    # Freeze appropriate (pretrained) weights
    pretrained_models = ["wae"]  # add macc?
    for l in layers:
        for idx in range(len(pretrained_models)):
            if (l.weights and pretrained_models[idx] in l.name):
                for w in range(len(l.weights)):
                    l.weights[w].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    #d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)

    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])

    # Initialize check metric callback
    metrics = [
        lbann.Metric(img_sca_loss, name='fw_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer()
    ]

    if (ltfb_batch_interval > 0):
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='fw_loss',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
def construct_jag_wae_model(ydim, zdim, mcf, useCNN, dump_models,
                            ltfb_batch_interval, num_epochs):
    """Construct LBANN model.

    JAG Wasserstein autoencoder model
    """
    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    #inp_slice = lbann.Slice(input, axis=0, slice_points="0 16399 16404", name='inp_slice')
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + 5]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  # param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z_dim = 20  # Latent space dim

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    model = macc_network_architectures.MACCWAE(zdim,
                                               ydim,
                                               cf=mcf,
                                               use_CNN=useCNN)
    d1_real, d1_fake, d_adv, pred_y = model(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    rec_error = lbann.L2Norm2(
        lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1"))

    layers = list(lbann.traverse_layer_graph(input))
    # Setup objective function
    weights = set()
    src_layers = []
    dst_layers = []
    for l in layers:
        if (l.weights and "disc0" in l.name and "instance1" in l.name):
            src_layers.append(l.name)
        # freeze weights in disc2
        if (l.weights and "disc1" in l.name):
            dst_layers.append(l.name)
            for idx in range(len(l.weights)):
                l.weights[idx].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction(
        [d1_real_bce, d1_fake_bce, d_adv_bce, img_loss, rec_error, l2_reg])

    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]
    #pred_y = macc_models.MACCWAE.pred_y_name
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2)
    ]

    if (ltfb_batch_interval > 0):
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='recon_error',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
def make_model(
        num_epochs,
        embed_dim,
        num_heads,
        label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(
            standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
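# The label-smoothing mixture above is the standard convex combination of a
# one-hot label and a uniform distribution; a NumPy sketch (not from the
# source) of the same arithmetic:
import numpy as np

vocab_size, label_smoothing = 5, 0.1
onehot = np.eye(vocab_size)[2]                 # hard label for token ID 2
uniform = np.full(vocab_size, 1 / vocab_size)
smoothed = (1 - label_smoothing) * onehot + label_smoothing * uniform
assert np.isclose(smoothed.sum(), 1.0)         # still a valid distribution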
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE
    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script assumes a pretrained WAE model"

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"),
                            name='inp1')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = False
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    x = lbann.Slice(input_, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    latentz = waemodel.forward_encoder(x_emb)

    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    layers = list(lbann.traverse_layer_graph(input_))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)

    obj = lbann.ObjectiveFunction(fake_loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post processing
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x,
                    slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )
    # Note: Ideally we'd shift y by y.max(-1) for numerical stability
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
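# The MatMul gymnastics above implement token-wise cross entropy as
# log(sum(exp(y))) - y.x for a one-hot target x, with masked tokens dropped.
# A NumPy check of that identity (a sketch, not from the source):
import numpy as np

rng = np.random.default_rng(1)
y = rng.normal(size=10)  # unnormalized logits for one token
x = np.eye(10)[3]        # one-hot target

ce = -np.log(np.exp(y[3]) / np.exp(y).sum())            # softmax cross entropy
assert np.isclose(ce, np.log(np.exp(y).sum()) - y @ x)  # log-sum-exp form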
def forward(self, x, prev_state):
    """Apply GRU step channelwise.

    Args:
        x (Layer): Input (shape: (num_channels, *))
        prev_state (Layer): State from previous GRU step
            (shape: (num_channels, size))

    Returns:
        (Layer, Layer): The output (out) and state (hn). The state
            can be passed directly into the next GRU step.

    """
    self.step += 1
    name = f"{self.name}_step{self.step}"

    mat_size = self.num_channels * self.size

    prev_state = lbann.Reshape(prev_state,
                               dims=str_list([self.num_channels, self.size]),
                               name=name + "_prev_state_reshape")
    fc1 = self.ih_fc(x)
    fc2 = self.hh_fc(prev_state)

    fc1_slice = lbann.Slice(
        fc1,
        axis=1,
        slice_points=str_list([0, self.size, 2 * self.size, 3 * self.size]))

    Wir_x = lbann.Reshape(lbann.Identity(fc1_slice),
                          dims=str_list([self.num_channels, self.size]),
                          name=name + '_Wir_x')
    Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice),
                          dims=str_list([self.num_channels, self.size]),
                          name=name + '_Wiz_z')
    Win_x = lbann.Reshape(lbann.Identity(fc1_slice),
                          dims=str_list([self.num_channels, self.size]),
                          name=name + '_Win_x')

    fc2_slice = lbann.Slice(
        fc2,
        axis=1,
        slice_points=str_list([0, self.size, 2 * self.size, 3 * self.size]))

    Whr_x = lbann.Reshape(lbann.Identity(fc2_slice),
                          dims=str_list([self.num_channels, self.size]),
                          name=name + '_Whr_x')
    Whz_z = lbann.Reshape(lbann.Identity(fc2_slice),
                          dims=str_list([self.num_channels, self.size]),
                          name=name + '_Whz_z')
    Whn_x = lbann.Reshape(lbann.Identity(fc2_slice),
                          dims=str_list([self.num_channels, self.size]),
                          name=name + '_Whn_x')

    rt = \
        lbann.Sigmoid(
            lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout),
            name=name + '_reset_gate',
            data_layout=self.data_layout
        )

    zt = \
        lbann.Sigmoid(
            lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout),
            name=name + '_update_gate',
            data_layout=self.data_layout,
        )

    nt = \
        lbann.Tanh(
            lbann.Add(
                Win_x,
                lbann.Multiply(rt, Whn_x, data_layout=self.data_layout),
                data_layout=self.data_layout,
            ),
            name=name + '_new_gate',
            data_layout=self.data_layout,
        )

    ht = \
        lbann.Add(
            lbann.Multiply(
                lbann.WeightedSum(
                    self.ones,
                    zt,
                    scaling_factors='1 -1',
                    data_layout=self.data_layout
                ),
                nt,
                data_layout=self.data_layout
            ),
            lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
            name=name + '_output',
            data_layout=self.data_layout,
        )

    ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size]))

    return ht, ht
def construct_model():
    """Construct LBANN model.

    JAG Wasserstein autoencoder model
    """
    import lbann

    # Layer graph
    input = lbann.Input(target_mode='N/A', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points="0 16399 16404",
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  # param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    y_dim = 16399  # image+scalar shape
    z_dim = 20  # Latent space dim

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    d1_real, d1_fake, d_adv, pred_y = jag_models.WAE(z_dim, y_dim)(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    rec_error = lbann.L2Norm2(
        lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1"))

    layers = list(lbann.traverse_layer_graph(input))
    # Setup objective function
    weights = set()
    src_layers = []
    dst_layers = []
    for l in layers:
        if (l.weights and "disc0" in l.name and "instance1" in l.name):
            src_layers.append(l.name)
        # freeze weights in disc2
        if (l.weights and "disc1" in l.name):
            dst_layers.append(l.name)
            for idx in range(len(l.weights)):
                l.weights[idx].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction(
        [d1_real_bce, d1_fake_bce, d_adv_bce, img_loss, rec_error, l2_reg])

    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2)
    ]

    # Construct model
    num_epochs = 100
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
# ----------------------------------
# Construct layer graph
# ----------------------------------

# Properties of graph and random walk
num_graph_nodes = dataset.max_graph_node_id() + 1
walk_length = dataset.walk_context_length
num_negative_samples = dataset.num_negative_samples
input_size = dataset.sample_dims()[0]

# Embedding vectors, including negative sampling
# Note: Input is sequence of graph node IDs
input_ = lbann.Identity(lbann.Input())
input_slice = lbann.Slice(
    input_,
    slice_points=f'0 {num_negative_samples+1} {input_size}'
)
decoder_embeddings = lbann.Embedding(
    input_slice,
    weights=decoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)
encoder_embeddings = lbann.Embedding(
    input_slice,
    weights=encoder_embeddings_weights,
    num_embeddings=num_graph_nodes,
    embedding_dim=args.latent_dim,
)

# Skip-Gram with negative sampling
_reader.synth_dimensions = '1'
_reader.percent_of_data_to_use = 1.0
add_data_reader('train')
add_data_reader('test')
input_ = lbann.Input()

# Radial profile
x = lbann.WeightsLayer(
    weights=lbann.Weights(
        lbann.ValueInitializer(values=str_list(image.flatten())),
    ),
    dims=str_list(image.shape),
)
max_r = image.shape[-1] // 2
rprof = RadialProfile()(x, image.shape, max_r)
rprof_slice = lbann.Slice(rprof, slice_points=str_list([0, 1, 2, 3]))
red = lbann.Identity(rprof_slice, name='red')
green = lbann.Identity(rprof_slice, name='green')
blue = lbann.Identity(rprof_slice, name='blue')

# Construct model
callbacks = [
    lbann.CallbackDumpOutputs(layers=str_list(['red', 'green', 'blue'])),
]
model = lbann.Model(
    epochs=0,
    layers=lbann.traverse_layer_graph([input_, rprof]),
    callbacks=callbacks,
)

# Run LBANN
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE
    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None, 'should be training seq len + bos + eos'
    print("sequence length is {}, which is training sequence len + bos + eos".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    input_ = lbann.Input(data_field='samples', name='inp_data')
    # Note: the input is assumed to come from the encoder script's
    # concatenation of input SMILES + z
    inp_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list(
            [0, sequence_length, sequence_length + run_args.z_dim]),
        name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = True if run_args.dump_outputs_dir else False
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    # Uncomment below for random sampling
    #z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(run_args.z_dim))
    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index,
                             run_args.z_dim,
                             save_output=save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    pred, arg_max = waemodel.forward_decoder(x_emb, z)

    recon = waemodel.compute_loss(x, pred)

    wae_loss.append(recon)

    layers = list(lbann.traverse_layer_graph(input_))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)

    #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    #wae_loss.append(l2_reg)
    print("LEN wae loss ", len(wae_loss))

    obj = lbann.ObjectiveFunction(wae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activation) for post processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
def forward(self, x, prev_state):
    """Apply GRU step.

    Args:
        x (Layer): Input.
        prev_state: State from previous GRU step.

    Returns:
        (Layer, Layer): The output (out) and state (hn). The state
            can be passed directly into the next GRU step.

    """
    self.step += 1
    name = '{0}_step{1}'.format(self.name, self.step)

    fc1 = self.ih_fc(x)           # input_fc
    fc2 = self.hh_fc(prev_state)  # hidden_fc

    # Get gates and cell update
    fc1_slice = lbann.Slice(fc1,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_fc1_slice',
                            data_layout=self.data_layout)
    Wir_x = lbann.Identity(fc1_slice,
                           name=name + '_Wrx',
                           data_layout=self.data_layout)
    Wiz_x = lbann.Identity(fc1_slice,
                           name=name + '_Wzx',
                           data_layout=self.data_layout)
    Win_x = lbann.Identity(fc1_slice,
                           name=name + '_Wnx',
                           data_layout=self.data_layout)

    fc2_slice = lbann.Slice(fc2,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_fc2_slice',
                            data_layout=self.data_layout)
    Whr_prev = lbann.Identity(fc2_slice,
                              name=name + '_Wrh',
                              data_layout=self.data_layout)
    Whz_prev = lbann.Identity(fc2_slice,
                              name=name + '_Wzh',
                              data_layout=self.data_layout)
    Whn_prev = lbann.Identity(fc2_slice,
                              name=name + '_Wnh',
                              data_layout=self.data_layout)

    rt = \
        lbann.Sigmoid(
            lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout),
            name=name + '_reset_gate',
            data_layout=self.data_layout
        )

    zt = \
        lbann.Sigmoid(
            lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout),
            name=name + '_update_gate',
            data_layout=self.data_layout,
        )

    nt = \
        lbann.Tanh(
            lbann.Add(
                Win_x,
                lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout),
                data_layout=self.data_layout,
            ),
            name=name + '_new_gate',
            data_layout=self.data_layout,
        )

    ht = \
        lbann.Add(
            lbann.Multiply(
                lbann.WeightedSum(
                    self.ones,
                    zt,
                    scaling_factors='1 -1',
                    data_layout=self.data_layout
                ),
                nt,
                data_layout=self.data_layout
            ),
            lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
            name=name + '_output',
            data_layout=self.data_layout,
        )

    # Return output
    return ht, ht
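# Usage sketch (not from the source): unlike the stateful LSTM cell earlier,
# this GRU takes and returns its state explicitly. `GRUCell` and its
# constructor are assumed stand-ins; the initial state is a zero Constant.
import lbann
from lbann.util import str_list  # assumed helper

sequence_length, hidden_size = 16, 64
steps = lbann.Slice(lbann.Input(data_field='samples'),
                    slice_points=str_list(range(sequence_length + 1)))
gru = GRUCell(size=hidden_size, name='gru')                      # hypothetical
state = lbann.Constant(value=0.0, num_neurons=str(hidden_size))  # h_0 = 0
outputs = []
for _ in range(sequence_length):
    x_t = lbann.Identity(steps)
    out, state = gru(x_t, state)  # forward() returns (output, next state)
    outputs.append(out)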
            standard_deviation=1 / args.latent_dim,
        ),
        name='embeddings',
    )
    embeddings = lbann.Embedding(
        input_,
        weights=embeddings_weights,
        num_embeddings=num_vertices,
        embedding_dim=args.latent_dim,
    )
else:
    raise RuntimeError(
        f'unknown method to get embedding vectors ({args.embeddings})')
embeddings_slice = lbann.Slice(
    embeddings,
    axis=0,
    slice_points=utils.str_list(
        [0, num_negative_samples, num_negative_samples + walk_length]),
)
negative_samples_embeddings = lbann.Identity(embeddings_slice)
walk_embeddings = lbann.Identity(embeddings_slice)

# Skip-Gram objective function
positive_loss = model.skip_gram.positive_samples_loss(
    walk_length,
    lbann.Identity(walk_embeddings),
    lbann.Identity(walk_embeddings),
    scale_decay=0.8,
)
negative_loss = model.skip_gram.negative_samples_loss(
    walk_embeddings,
    negative_samples_embeddings,
def forward(self, x, prev_state):
    """Apply LSTM step.

    Args:
        x (Layer): Input.
        prev_state (tuple with two `Layer`s): State from previous
            LSTM step. Comprised of LSTM output and cell state.

    Returns:
        (Layer, (Layer, Layer)): The output and state (the output
            and cell state). The state can be passed directly into
            the next LSTM step.

    """
    self.step += 1
    name = '{0}_step{1}'.format(self.name, self.step)

    # Get output and cell state from previous step
    prev_output, prev_cell = prev_state

    # Apply linearity
    input_concat = lbann.Concatenation(x, prev_output,
                                       name=name + '_input',
                                       data_layout=self.data_layout)
    fc = self.fc(input_concat)

    # Get gates and cell update
    slice = lbann.Slice(fc,
                        slice_points=str_list([0, self.size, 4*self.size]),
                        name=name + '_fc_slice',
                        data_layout=self.data_layout)
    cell_update = lbann.Tanh(slice,
                             name=name + '_cell_update',
                             data_layout=self.data_layout)
    sigmoid = lbann.Sigmoid(slice,
                            name=name + '_sigmoid',
                            data_layout=self.data_layout)
    slice = lbann.Slice(sigmoid,
                        slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                        name=name + '_sigmoid_slice',
                        data_layout=self.data_layout)
    f = lbann.Identity(slice, name=name + '_forget_gate',
                       data_layout=self.data_layout)
    i = lbann.Identity(slice, name=name + '_input_gate',
                       data_layout=self.data_layout)
    o = lbann.Identity(slice, name=name + '_output_gate',
                       data_layout=self.data_layout)

    # Cell state
    cell_forget = lbann.Multiply(f, prev_cell,
                                 name=name + '_cell_forget',
                                 data_layout=self.data_layout)
    cell_input = lbann.Multiply(i, cell_update,
                                name=name + '_cell_input',
                                data_layout=self.data_layout)
    cell = lbann.Add(cell_forget, cell_input,
                     name=name + '_cell',
                     data_layout=self.data_layout)

    # Output
    cell_act = lbann.Tanh(cell,
                          name=name + '_cell_activation',
                          data_layout=self.data_layout)
    output = lbann.Multiply(o, cell_act,
                            name=name,
                            data_layout=self.data_layout)

    # Return output and state
    return output, (output, cell)