def forward(self, x, z):
    """Do the WAE forward step

    :param x: list of tensors of longs, token indices of the input sequence
    :param z: tensor of floats, sample from the prior distribution
    :return: recon_loss, d_real, d_fake, d_adv, arg_max -- reconstruction loss,
        discriminator outputs on the real (x_emb, z_prior) and fake
        (x_emb, z_sample) pairs, adversarial discriminator output, and
        predicted token indices
    """
    x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
    x = lbann.Identity(x)
    x_emb = lbann.Embedding(x,
                            num_embeddings=self.dictionary_size,
                            embedding_dim=self.embedding_size,
                            name='emb',
                            weights=self.emb_weights)

    # Encoder: x -> z_sample
    z_sample = self.forward_encoder(x_emb)

    eps = lbann.Gaussian(mean=self.gmean, stdev=self.gstd, hint_layer=z_sample)
    z_sample = lbann.Add([z_sample, eps])

    # Decoder: x, z -> recon_loss
    #pred = self.forward_decoder(x_emb, z_sample)
    pred, arg_max = self.forward_decoder(x_emb, z_sample)
    recon_loss = self.compute_loss(x, pred)

    # Hack to remove blocking GPU allreduce in evaluation layer
    #kl_loss = lbann.Identity(kl_loss, device='CPU')
    recon_loss = lbann.Identity(recon_loss, device='CPU')

    z_prior = lbann.Tessellate(
        lbann.Reshape(z, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )

    d_real = self.discriminator0(
        lbann.Concatenation([x_emb, z_prior], axis=1))

    z_sample0 = lbann.Tessellate(
        lbann.Reshape(z_sample, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )
    y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)

    d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
    d_adv = self.discriminator1(y_z_sample)  # frozen copy (weights replaced via callback)

    return recon_loss, d_real, d_fake, d_adv, arg_max
def forward(self, z, y):
    z_sample = self.encoder(y)
    y_recon = self.decoder(z_sample)

    # d_real/d_fake share weights; the shared weights are copied to d_adv
    # (through the replace-weights callback) and frozen
    d_real = self.discriminator0(lbann.Concatenation([y, z], axis=0))
    y_z_sample = lbann.Concatenation([y, z_sample], axis=0)
    d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
    d_adv = self.discriminator1(y_z_sample)  # frozen

    return d_real, d_fake, d_adv, y_recon
def dense_layer(statistics_group_size, current_block_num, current_layer_num, cumulative_layer_num, parent_nodes, batch_norm_size, growth_rate): concatenation_node = lbann.Concatenation(parent_nodes) cumulative_layer_num += 1 log('dense_block={b} dense_layer={l} Concatenation. cumulative_layer_num={n}' .format(b=current_block_num, l=current_layer_num, n=cumulative_layer_num)) conv_block_1_node, cumulative_layer_num = conv_block( statistics_group_size, current_block_num, current_layer_num, cumulative_layer_num, concatenation_node, conv_dims_i=1, conv_pads_i=0, num_output_channels=batch_norm_size * growth_rate) conv_block_2_node, cumulative_layer_num = conv_block( statistics_group_size, current_block_num, current_layer_num, cumulative_layer_num, conv_block_1_node, conv_dims_i=3, conv_pads_i=1, num_output_channels=growth_rate) return conv_block_2_node, cumulative_layer_num
def forward(self, img, z, mcr): ''' Steps: - Modify image if using mcr - D1 + imgs -> d1_real - G + noise -> gen_imgs - D1 + gen_imgs -> d1_fake - Adv (D2) + gen_imgs Return D outputs and gen_imgs ''' print('MCR in forward', mcr) if mcr: ### Multi-channel rescaling. Add extra channel for real images. Generated images are rescaled inside generator linear_scale = 1 / self.linear_scaler ch2 = lbann.Tanh( lbann.WeightedSum(self.inv_transform(lbann.Identity(img)), scaling_factors=str(linear_scale))) y = lbann.Concatenation(lbann.Identity(img), ch2, axis=0) img = lbann.Reshape(y, dims='2 128 128') else: img = lbann.Reshape(img, dims='1 128 128') d1_real = self.forward_discriminator1(img) #instance1 gen_img = self.forward_generator(z, mcr=mcr) d1_fake = self.forward_discriminator1( lbann.StopGradient(gen_img)) #instance2 d_adv = self.forward_discriminator2( gen_img) #instance 3 //need to freeze #d1s share weights, d1_w is copied to d_adv (through replace weight callback) and freeze return d1_real, d1_fake, d_adv, gen_img, img
def forward(self, node_feature_mat, source_indices, target_indices):
    """Call the GatedGraphConv

    Args:
        node_feature_mat (Layer): Node feature matrix with the shape of
            (num_nodes, input_channels)
        source_indices (Layer): Source node indices of the edges with shape
            (num_edges)
        target_indices (Layer): Target node indices of the edges with shape
            (num_edges)
    Returns:
        (Layer): The output after kernel ops. The output can be passed into
            another graph conv layer directly
    """
    if self.input_channel_size < self.output_channel_size:
        num_zeros = self.output_channel_size - self.input_channel_size
        zeros = lbann.Constant(value=0,
                               num_neurons=str_list([self.num_nodes, num_zeros]),
                               name=self.name + '_padded')
        node_feature_mat = lbann.Concatenation(node_feature_mat, zeros, axis=1)
    elif self.input_channel_size > self.output_channel_size:
        raise ValueError(
            'The feature size of the nodes {} cannot be greater than the '
            'output dimension {}'.format(self.input_channel_size,
                                         self.output_channel_size))

    for layer in range(self.num_layers):
        messages = self.nns[layer](node_feature_mat)
        neighborhoods = GraphExpand(messages, target_indices)
        aggregate = GraphReduce(neighborhoods, source_indices,
                                [self.num_nodes, self.output_channel_size])
        node_feature_mat = self.rnn(aggregate, node_feature_mat)

    return node_feature_mat
def forward_generator(self, z, mcr): ''' Build the Generator ''' x = lbann.Relu( lbann.BatchNormalization(self.g_fc1(z), decay=0.9, scale_init=1.0, epsilon=1e-5)) dims = '512 8 8' x = lbann.Reshape(x, dims=dims) #channel first for count, lyr in enumerate(self.g_convT): x = lbann.Relu( lbann.BatchNormalization(lyr(x), decay=0.9, scale_init=1.0, epsilon=1e-5)) img = self.g_convT3(x) if mcr: ### For multi-channel rescaling, add extra channel to output image linear_scale = 1 / self.linear_scaler # linear_scale=lbann.Constant(value=0.001) ch2 = lbann.Tanh( lbann.WeightedSum(self.inv_transform(img), scaling_factors=str(linear_scale))) y = lbann.Concatenation(img, ch2, axis=0) img = lbann.Reshape(y, dims='2 128 128') else: img = lbann.Reshape(img, dims='1 128 128') return img
def densenet(version, cumulative_layer_num, images_node):
    if version == 121:
        growth_rate = 32  # k in the paper
        layers_per_block = (6, 12, 24, 16)
        num_initial_features = 64
    elif version == 161:
        growth_rate = 48  # k in the paper
        # Standard DenseNet-161 block configuration (Huang et al., 2017).
        layers_per_block = (6, 12, 36, 24)
        num_initial_features = 96
    else:
        raise Exception('Invalid version={v}.'.format(v=version))
    batch_norm_size = 4

    parent_node, cumulative_layer_num = initial_layer(
        cumulative_layer_num, images_node, num_initial_features)
    num_features = num_initial_features
    # Start counting dense blocks at 1.
    for current_block_num, num_layers in enumerate(layers_per_block, 1):
        parent_nodes, cumulative_layer_num = dense_block(
            cumulative_layer_num,
            parent_node,
            batch_norm_size=batch_norm_size,
            current_block_num=current_block_num,
            growth_rate=growth_rate,
            num_layers=num_layers,
            num_initial_channels=num_initial_features
        )
        # num_features += num_layers * growth_rate
        for node in parent_nodes[1:]:
            num_features += node.num_output_channels
        parent_node = lbann.Concatenation(parent_nodes)
        cumulative_layer_num += 1
        log('densenet Concatenation. cumulative_layer_num={n}'.format(
            n=cumulative_layer_num))
        if current_block_num != len(layers_per_block):
            parent_node, cumulative_layer_num = transition_layer(
                current_block_num,
                cumulative_layer_num,
                parent_node,
                # In Python 3, this is integer division.
                num_output_channels=num_features // 2,
            )
            num_features //= 2

    batch_normalization_node = standard_batchnorm(parent_node)
    cumulative_layer_num += 1
    log('densenet BatchNormalization. cumulative_layer_num={n}'.format(
        n=cumulative_layer_num))

    probs = classification_layer(
        cumulative_layer_num,
        batch_normalization_node
    )
    return probs
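# The channel bookkeeping in densenet() can be checked offline: each dense
# layer adds growth_rate channels and every transition layer halves the running
# total. A minimal standalone sketch (plain Python, no LBANN; the standard
# DenseNet-121/161 configurations from Huang et al., 2017 are assumed):
def count_features(layers_per_block, growth_rate, num_initial_features):
    num_features = num_initial_features
    for block_idx, num_layers in enumerate(layers_per_block, 1):
        num_features += num_layers * growth_rate  # dense block
        if block_idx != len(layers_per_block):
            num_features //= 2                    # transition layer halves channels
    return num_features

print(count_features((6, 12, 24, 16), 32, 64))  # DenseNet-121 -> 1024
print(count_features((6, 12, 36, 24), 48, 96))  # DenseNet-161 -> 2208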
def forward(self, x): """Perform LSTM step. State from previous steps is used to compute output. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Apply linearity input_concat = lbann.Concatenation([x, self.last_output], name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=_str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply([f, self.last_cell], name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply([i, cell_update], name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add([cell_forget, cell_input], name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply([o, cell_act], name=name, data_layout=self.data_layout) # Update state and return output self.last_cell = cell self.last_output = output return output
def forward_discriminator2(self,y): ch2 = self.inv_transform(lbann.Identity(y)) y = lbann.Concatenation(lbann.Identity(y),ch2,axis=0) img = lbann.Reshape(y, dims='2 128 128') x = lbann.LeakyRelu(self.d2_conv[0](img), negative_slope=0.2) x = lbann.LeakyRelu(self.d2_conv[1](x), negative_slope=0.2) x = lbann.LeakyRelu(self.d2_conv[2](x), negative_slope=0.2) x = lbann.LeakyRelu(self.d2_conv[3](x), negative_slope=0.2) return self.d2_fc(lbann.Reshape(x,dims='32768'))
def forward(self, x, x_concat=None): self.instance += 1 if x_concat is not None: x = lbann.Concatenation([x, x_concat], axis=0) for c in self.convs: x = c(x) return x
def forward( self, motif_indices, motif_size, walk_indices, walk_length, ): # Apply generator fake_motif_indices, gen_prob, gen_log_prob = self.generator( walk_length, walk_indices, motif_size, ) # Get discriminator embeddings in log-space all_motif_indices = lbann.Concatenation(motif_indices, fake_motif_indices) all_motif_log_embeddings = self.discriminator.get_log_embeddings( all_motif_indices) all_motif_log_embeddings = lbann.Slice( all_motif_log_embeddings, slice_points=str_list([0, motif_size, 2 * motif_size]), ) real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings) fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings) # Apply discriminator real_disc_prob, real_disc_log_not_prob \ = self.discriminator(motif_size, real_motif_log_embeddings) fake_disc_prob, fake_disc_log_not_prob \ = self.discriminator(motif_size, fake_motif_log_embeddings) # Loss function # L_disc = - log(D(real)) - log(1-D(fake)) # L_gen = - log(G) * stop_gradient(log(1-D(fake))) real_disc_log_prob \ = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1)) disc_loss = lbann.WeightedSum( real_disc_log_prob, fake_disc_log_not_prob, scaling_factors=str_list([-1, -1]), ) gen_loss = lbann.Multiply( gen_log_prob, lbann.StopGradient(fake_disc_log_not_prob), ) loss = lbann.Add(disc_loss, gen_loss) return loss, real_disc_prob, fake_disc_prob, gen_prob
def get_mat(self, cols=None):
    """Generates a matrix representation of the graph data.

    args:
        cols (int, optional): number of columns of the generated matrix. If
            not given, the column count stored in self.shape is used.
    """
    mat = lbann.Concatenation(self.layers)

    if cols:
        mat = lbann.Reshape(mat, dims=str_list([self.shape[0], cols]))
    else:
        mat = lbann.Reshape(mat, dims=str_list([self.shape[0], self.shape[1]]))

    return mat
def decoder_cnn(self, z): x = self.dec_cnn_fc(z) sca = self.dec_fc_sca(lbann.Identity(x)) img = lbann.Reshape(lbann.Identity(x), dims="16 8 8", name=self.name + 'dec_reshape0') img = self.dec_convT[2](lbann.Relu(self.dec_convT[1](lbann.Relu( self.dec_convT[0](img))))) #concat for common interface, slice in output img = lbann.Reshape(img, dims=str(64 * 64 * 4), name=self.name + 'dec_reshape1') #?? check tensor shape #todo check that concat size == dec_out_dim return lbann.Concatenation([img, sca], axis=0)
def forward(self, x): x_slice = lbann.Slice(x, axis=0, slice_points="0 921 4750 8579", name='inp_slice') gene = self.geneT(lbann.Identity(x_slice)) drug1 = self.drug1T(lbann.Identity(x_slice)) drug2 = self.drug2T(lbann.Identity(x_slice)) concat = self.concatT( lbann.Concatenation([gene, drug1, drug2], name=self.name + 'concat')) response_fc = lbann.FullyConnected(concat, num_neurons=1, has_bias=True) return response_fc
def forward_decoder(self, x_emb, z):
    """Decoder step, emulating x ~ G(z)

    :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
    :param z: (n_batch, d_z) of floats, latent vector z
    :return: lbann.Layer, per-position token scores for the reconstructed sentence
    """

    # z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
    # x_input = torch.cat([x_emb, z_0], dim=-1)
    z_0 = lbann.Tessellate(
        lbann.Reshape(z, dims=str_list([1, 128])),
        dims=str_list([self.input_feature_dims, 128]),
    )
    x_input = lbann.Concatenation(x_emb, z_0, axis=1)

    h_0 = self.decoder_lat(z)
    # h_0 = h_0.unsqueeze(0).repeat(self.decoder_rnn.num_layers, 1, 1)
    h_0 = lbann.Reshape(h_0, dims=str_list([1, 512]))
    h_0 = lbann.Tessellate(h_0, dims=str_list((3, 512)))

    # output, _ = self.decoder_rnn(x_input, h_0)
    output = self.decoder_rnn(x_input, h_0)

    # y = self.decoder_fc(output)
    y = lbann.ChannelwiseFullyConnected(
        output,
        output_channel_dims=self.dictionary_size,
        bias=True,
        name=f'{self.decoder_fc.name}',
        weights=self.decoder_fc.weights,
    )

    # Set datatype of layers
    # Note: Depth-first search from y to x_emb and z
    stack = [y]
    in_stack = {l: True for l in stack}
    while stack:
        l = stack.pop()
        if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
            l.datatype = self.datatype
        for parent in l.parents:
            if parent not in in_stack and parent not in (x_emb, z):
                stack.append(parent)
                in_stack[parent] = True

    return y
def encoder_cnn(self, y): img_sca = lbann.Slice(y, axis=0, slice_points="0 16384 16399", name=self.name + '_y_slice') #assume C first, is data C first? img = lbann.Reshape(img_sca, dims='4 64 64', name=self.name + 'enc_reshape0') x = self.enc_conv[2](self.enc_conv[1](self.enc_conv[0](img))) x = lbann.Reshape(x, dims=str(16 * 8 * 8), name=self.name + 'enc_reshape1') h_stack = lbann.Concatenation([x, img_sca], axis=0) z = self.enc_out(h_stack) return z
def forward(self, X, A):
    """Call the GatedGraphConv

    Args:
        X (GraphVertexData): LBANN Data object, which is a collection of
            Layers. Each Layer is of the shape (1, input_channels)
        A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)
    Returns:
        (GraphVertexData): The output after the gated graph kernel. The output
            can be passed into another graph conv layer directly
    """

    input_features = X.size(1)
    num_nodes = X.size(0)

    if input_features < self.output_channels:
        for i in range(num_nodes):
            num_zeros = self.output_channels - input_features
            zeros = lbann.Constant(value=0,
                                   num_neurons=str_list([1, num_zeros]),
                                   name=self.name + '_zero_' + str(i))
            X[i] = lbann.Concatenation(X[i], zeros, axis=1)
    elif input_features > self.output_channels:
        raise ValueError(
            'The feature size of the nodes {} cannot be greater than the '
            'output dimension {}'.format(input_features, self.output_channels))

    X.update_num_features(self.output_channels)

    for layer in range(self.num_layers):
        X_mat = X.get_mat()
        messages = lbann.MatMul(X_mat, self.weights[layer])
        aggregate = lbann.MatMul(A, messages)

        M = GraphVertexData.matrix_to_graph(aggregate, num_nodes,
                                            self.output_channels)

        for i in range(num_nodes):
            X[i] = lbann.Reshape(X[i], dims=str(self.output_channels))
            X[i] = lbann.Reshape(self.rnn(M[i], X[i])[1],
                                 dims=str_list([1, self.output_channels]))

    return X
def forward_generator(self, z, mcr): ''' Build the Generator ''' x = lbann.Relu( lbann.BatchNormalization(self.g_fc1(z), decay=0.9, scale_init=1.0, epsilon=1e-5)) dims = '512 8 8' print("dims", dims) x = lbann.Reshape(x, dims=dims) #channel first x = lbann.Relu( lbann.BatchNormalization(self.g_convT[0](x), decay=0.9, scale_init=1.0, epsilon=1e-5)) x = lbann.Relu( lbann.BatchNormalization(self.g_convT[1](x), decay=0.9, scale_init=1.0, epsilon=1e-5)) x = lbann.Relu( lbann.BatchNormalization(self.g_convT[2](x), decay=0.9, scale_init=1.0, epsilon=1e-5)) img = self.g_convT3(x) if mcr: ### For multi-channel rescaling, add extra channel to output image linear_scale = 1 / self.linear_scaler #ch2 = lbann.Tanh(self.inv_transform(img)/linear_scalar) ch2 = lbann.Tanh( lbann.WeightedSum(self.inv_transform(img), scaling_factors=str(linear_scale))) y = lbann.Concatenation(img, ch2, axis=0) img = lbann.Reshape(y, dims='2 128 128') else: img = lbann.Reshape(img, dims='1 128 128') print('Gen Img in GAN', img.__dict__) return img
def forward(self, x, prev_state): """Apply LSTM step. Args: x (Layer): Input. prev_state (tuple with two `Layer`s): State from previous LSTM step. Comprised of LSTM output and cell state. Returns: (Layer, (Layer, Layer)): The output and state (the output and cell state). The state can be passed directly into the next LSTM step. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Get output and cell state from previous step prev_output, prev_cell = prev_state # Apply linearity input_concat = lbann.Concatenation(x, prev_output, name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply(f, prev_cell, name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply(i, cell_update, name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add(cell_forget, cell_input, name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply(o, cell_act, name=name, data_layout=self.data_layout) # Return output and state return output, (output, cell)
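# Minimal usage sketch for the LSTM step above, threading the returned state
# through an unrolled sequence. The class name `LSTMCell` and its constructor
# are assumptions for illustration; only the forward() contract
# (x, (prev_output, prev_cell)) -> (output, (output, cell)) comes from the
# code above.
import lbann
from lbann.util import str_list

seq_len, input_dim, hidden_dim = 4, 8, 16
lstm = LSTMCell(size=hidden_dim, name='lstm')  # hypothetical constructor

# Interpret a flat sample as seq_len steps of input_dim features each.
x = lbann.Input(data_field='samples')
steps = lbann.Slice(x, slice_points=str_list(range(0, seq_len * input_dim + 1, input_dim)))
steps = [lbann.Identity(steps) for _ in range(seq_len)]

# Initial state: zero output and zero cell state.
state = (lbann.Constant(value=0, num_neurons=str(hidden_dim)),
         lbann.Constant(value=0, num_neurons=str(hidden_dim)))

outputs = []
for step in steps:
    out, state = lstm(step, state)  # forward() returns (output, (output, cell))
    outputs.append(out)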
def construct_model(run_args): """Construct LBANN model. Initial model for ATOM molecular VAE """ import lbann print("Dump model dir ", run_args.dump_model_dir) assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model" pad_index = run_args.pad_index assert pad_index is not None sequence_length = run_args.sequence_length assert sequence_length is not None print("sequence length is {}".format(sequence_length)) data_layout = "data_parallel" # Layer graph input_ = lbann.Identity(lbann.Input(name='inp', data_field='samples'), name='inp1') wae_loss = [] input_feature_dims = sequence_length embedding_size = run_args.embedding_dim dictionary_size = run_args.num_embeddings assert embedding_size is not None assert dictionary_size is not None save_output = True if run_args.dump_outputs_dir else False print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir) z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=run_args.z_dim) waemodel = molwae.MolWAE(input_feature_dims, dictionary_size, embedding_size, pad_index, run_args.z_dim, save_output) recon, d1_real, d1_fake, d_adv, arg_max = waemodel(input_, z) zero = lbann.Constant(value=0.0, num_neurons='1', name='zero') one = lbann.Constant(value=1.0, num_neurons='1', name='one') d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one], name='d1_real_bce') d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero], name='d1_fake_bce') d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce') wae_loss.append(recon) layers = list(lbann.traverse_layer_graph(input_)) # Setup objective function weights = set() src_layers = [] dst_layers = [] for l in layers: if (l.weights and "disc0" in l.name and "instance1" in l.name): src_layers.append(l.name) #freeze weights in disc2 if (l.weights and "disc1" in l.name): dst_layers.append(l.name) for idx in range(len(l.weights)): l.weights[idx].optimizer = lbann.NoOptimizer() weights.update(l.weights) l2_weights = [ w for w in weights if not isinstance(w.optimizer, lbann.NoOptimizer) ] l2_reg = lbann.L2WeightRegularization(weights=l2_weights, scale=1e-4) wae_loss.append(d1_real_bce) wae_loss.append(d_adv_bce) wae_loss.append(d1_fake_bce) wae_loss.append(l2_reg) print("LEN wae loss ", len(wae_loss)) obj = lbann.ObjectiveFunction(wae_loss) # Initialize check metric callback metrics = [ lbann.Metric(d_adv_bce, name='adv_loss'), lbann.Metric(recon, name='recon') ] callbacks = [ lbann.CallbackPrint(), #lbann.CallbackStepLearningRate(step=10, amt=0.5), lbann.CallbackTimer() ] callbacks.append( lbann.CallbackReplaceWeights(source_layers=list2str(src_layers), destination_layers=list2str(dst_layers), batch_interval=2)) #Dump output (activation) for post processing if (run_args.dump_outputs_dir): pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor') callbacks.append( lbann.CallbackDumpOutputs( batch_interval=run_args.dump_outputs_interval, execution_modes='test', directory=run_args.dump_outputs_dir, layers=f'inp pred_tensor {waemodel.q_mu.name}')) # Construct model return lbann.Model(run_args.num_epochs, weights=weights, layers=layers, objective_function=obj, metrics=metrics, callbacks=callbacks)
def construct_model(run_args): """Construct LBANN model. Initial model for ATOM molecular VAE """ import lbann pad_index = run_args.pad_index assert pad_index is not None sequence_length = run_args.sequence_length assert sequence_length is not None, 'should be training seq len + bos + eos' print("sequence length is {}, which is training sequence len + bos + eos".format(sequence_length)) data_layout = "data_parallel" # Layer graph input_ = lbann.Input(data_field='samples',name='inp_data') #Note input assumes to come from encoder script concatenation of input smiles + z inp_slice = lbann.Slice(input_, axis=0, slice_points=str_list([0, sequence_length, sequence_length+run_args.z_dim]), name='inp_slice') inp_smile = lbann.Identity(inp_slice,name='inp_smile') z = lbann.Identity(inp_slice, name='z') wae_loss= [] input_feature_dims = sequence_length embedding_size = run_args.embedding_dim dictionary_size = run_args.num_embeddings assert embedding_size is not None assert dictionary_size is not None save_output = True if run_args.dump_outputs_dir else False print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir) #uncomment below for random sampling #z = lbann.Gaussian(mean=0.0,stdev=1.0, neuron_dims=str(run_args.z_dim)) x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims])) x = lbann.Identity(x) waemodel = molwae.MolWAE(input_feature_dims, dictionary_size, embedding_size, pad_index,run_args.z_dim,save_output=save_output) x_emb = lbann.Embedding( x, num_embeddings=waemodel.dictionary_size, embedding_dim=waemodel.embedding_size, name='emb', weights=waemodel.emb_weights ) pred, arg_max = waemodel.forward_decoder(x_emb,z) recon = waemodel.compute_loss(x, pred) wae_loss.append(recon) layers = list(lbann.traverse_layer_graph(input_)) # Setup objective function weights = set() for l in layers: weights.update(l.weights) #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4) #wae_loss.append(l2_reg) print("LEN wae loss ", len(wae_loss)) obj = lbann.ObjectiveFunction(wae_loss) # Initialize check metric callback metrics = [lbann.Metric(recon, name='recon')] callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] #Dump output (activation) for post processing pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor') conc_out = lbann.Concatenation([input_,pred_tensor], name='conc_out') callbacks.append(lbann.CallbackDumpOutputs(batch_interval=run_args.dump_outputs_interval, execution_modes='test', directory=run_args.dump_outputs_dir, layers=f'{conc_out.name}')) # Construct model return lbann.Model(run_args.num_epochs, weights=weights, layers=layers, objective_function=obj, metrics=metrics, callbacks=callbacks)
def forward(self, queries, keys, values, mask=None): """Apply multi-head attention. The input and output tensors are interpreted as sequences of vectors, where the first tensor dimension is the sequence dimension. Args: queries (lbann.Layer): Sequence of query vectors. keys (lbann.Layer): Sequence of key vectors. values (lbann.Layer): Sequence of value vectors. mask (lbann.Layer, optional): Additive attention mask. If the (i,j) entry is very negative (e.g. -1e9), then the ith query does not attend to the jth key/value pair. Returns: lbann.Layer: Sequence of output vectors. The sequence length is the same as `queries`. """ ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH BRANCHES = self.BRANCHES if (ENABLE_SUBGRAPH): if (self.num_heads % BRANCHES != 0): raise ValueError('Num heads should be divisible by BRANCHES') self.instance += 1 name = f'{self.name}_instance{self.instance}' # Apply fully-connected layers to input sequences queries_fc = [] keys_fc = [] values_fc = [] # Slice embedding vectors for each head slice_points = str_list( self.head_dim * i for i in range(int(self.num_heads / self.BRANCHES) + 1)) #Queries strong scaling in CFC attentions = [] for count, query in enumerate(queries): temp = lbann.ChannelwiseFullyConnected( query, weights=self.query_weights[count], output_channel_dims=[self.inner_dim], name=f'{name}_subgrid{count}_queries_fc', ) attentions.append(temp) grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions) attentions = [] for head in range(self.BRANCHES): attentions.append(lbann.Identity(grid_sum_slice)) for head in range(self.BRANCHES): temp = lbann.Slice( attentions[head], axis=1, slice_points=slice_points, name=f'{name}_subgrid{head}_queries_slice', ) queries_fc.append(temp) #keys strong scaling in CFC attentions = [] for count, key in enumerate(keys): temp = lbann.ChannelwiseFullyConnected( key, weights=self.key_weights[count], output_channel_dims=[self.inner_dim], name=f'{name}_subgrid{count}_keys_fc', ) attentions.append(temp) grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions) attentions = [] for head in range(self.BRANCHES): attentions.append(lbann.Identity(grid_sum_slice)) for head in range(self.BRANCHES): temp = lbann.Slice( attentions[head], axis=1, slice_points=slice_points, name=f'{name}_subgrid{head}_keys_slice', ) keys_fc.append(temp) #Values strong scaling in CFC attentions = [] for count, value in enumerate(values): temp = lbann.ChannelwiseFullyConnected( value, weights=self.value_weights[count], output_channel_dims=[self.inner_dim], name=f'{name}_subgrid{count}_values_fc', ) attentions.append(temp) grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions) attentions = [] for head in range(self.BRANCHES): attentions.append(lbann.Identity(grid_sum_slice)) for head in range(self.BRANCHES): temp = lbann.Slice( attentions[head], axis=1, slice_points=slice_points, name=f'{name}_subgrid{head}_values_slice', ) values_fc.append(temp) queries_slice = [] keys_slice = [] values_slice = [] for branch in range(self.BRANCHES): querie_slice = queries_fc[branch] key_slice = keys_fc[branch] value_slice = values_fc[branch] for head in range(int(self.num_heads / self.BRANCHES)): queries_slice.append(lbann.Identity(querie_slice)) keys_slice.append(lbann.Identity(key_slice)) values_slice.append(lbann.Identity(value_slice)) # Compute scaled dot-product attention for each head attentions = [] #variable to combine heads locally in sub-grids temp_attentions = [] tag = 0 for head in range(self.num_heads): head_name = f'{name}_myattention_head{head}' # Attention inputs if (head % 
int(self.num_heads / BRANCHES) == 0): temp_attentions.append([]) tag += 1 q = lbann.Identity(queries_slice[head]) k = lbann.Identity(keys_slice[head]) v = lbann.Identity(values_slice[head]) # Multiply queries and keys # Note: num_queries x num_keys y = lbann.MatMul( q, k, transpose_b=True, name=f'{head_name}_matmul', ) y = lbann.WeightedSum( y, scaling_factors=str(1 / math.sqrt(self.head_dim)), name=f'{head_name}_scale', ) if (ENABLE_SUBGRAPH): if mask != None: y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask') else: if mask: y = lbann.Sum([y, mask], name=f'{head_name}_mask') y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax') # Attention output # Note: num_queries x head_dim y = lbann.MatMul(y, v, name=head_name) # attentions.append(lbann.MatMul(y, v, name=head_name)) temp_attentions[-1].append(y) for count, temp_attention in enumerate(temp_attentions): if (self.BRANCHES == self.num_heads): # No need to concat the heads at subgrid level # if number of subgrids is equal to number of heads attention_single_subgrid = temp_attentions[count][0] else: attention_single_subgrid = lbann.Concatenation( temp_attention, axis=1, name=f'{name}_subgrid_heads_concat{count}', parallel_strategy={ 'sub_branch_tag': 0, 'enable_subgraph': False }) attention_single_subgrid = lbann.ChannelwiseFullyConnected( attention_single_subgrid, weights=self.output_weights[count], output_channel_dims=[self.embed_dim], name=f'{name}_cfc_{count}', ) attentions.append(attention_single_subgrid) #Strong scaling grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions) attentions = [] for head in range(self.BRANCHES): attentions.append(lbann.Identity(grid_sum_slice)) return attentions
def construct_model(run_args): """Construct LBANN model. Initial model for ATOM molecular VAE """ import lbann print("Dump model dir ", run_args.dump_model_dir) assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model" pad_index = run_args.pad_index assert pad_index is not None sequence_length = run_args.sequence_length assert sequence_length is not None print("sequence length is {}".format(sequence_length)) data_layout = "data_parallel" # Layer graph input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"), name='inp1') wae_loss = [] input_feature_dims = sequence_length embedding_size = run_args.embedding_dim dictionary_size = run_args.num_embeddings assert embedding_size is not None assert dictionary_size is not None save_output = False print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir) z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128") x = lbann.Slice(input_, slice_points=str_list([0, input_feature_dims])) x = lbann.Identity(x) waemodel = molwae.MolWAE(input_feature_dims, dictionary_size, embedding_size, pad_index, save_output) x_emb = lbann.Embedding(x, num_embeddings=waemodel.dictionary_size, embedding_dim=waemodel.embedding_size, name='emb', weights=waemodel.emb_weights) latentz = waemodel.forward_encoder(x_emb) fake_loss = lbann.MeanAbsoluteError(latentz, z) layers = list(lbann.traverse_layer_graph(input_)) # Setup objective function weights = set() for l in layers: weights.update(l.weights) obj = lbann.ObjectiveFunction(fake_loss) callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] #Dump output (activation) for post processing conc_out = lbann.Concatenation([input_, latentz], name='conc_out') callbacks.append( lbann.CallbackDumpOutputs( batch_interval=run_args.dump_outputs_interval, execution_modes='test', directory=run_args.dump_outputs_dir, layers=f'{conc_out.name}')) # Construct model return lbann.Model(run_args.num_epochs, weights=weights, layers=layers, objective_function=obj, callbacks=callbacks)
def compute_loss(self, x, y): # y[:, :-1] y = lbann.Slice( y, axis=0, slice_points=str_list([0, self.input_feature_dims - 1]), ) y = lbann.Identity(y) # x[:, 1:] x = lbann.Slice( x, slice_points=str_list([1, self.input_feature_dims]), ) x = lbann.Identity(x) # Convert indices in x to one-hot representation # Note: Ignored indices result in zero vectors ignore_mask = lbann.Equal( x, self.constant(self.label_to_ignore, hint_layer=x), ) keep_mask = lbann.LogicalNot(ignore_mask) length = lbann.Reduction(keep_mask, mode='sum') length = lbann.Max(length, self.constant(1, [1])) x = lbann.Add( lbann.Multiply(keep_mask, x), lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)), ) x = lbann.Slice(x, slice_points=str_list(range(self.input_feature_dims))) x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)] x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x] x = [ lbann.Reshape(xi, dims=str_list([1, self.dictionary_size])) for xi in x ] x = lbann.Concatenation(x, axis=0) # recon_loss = F.cross_entropy( # y[:, :-1].contiguous().view(-1, y.size(-1)), # x[:, 1:].contiguous().view(-1), # ignore_index=self.pad # ) # Note: Ideally we'd shift y by y.max(-1) for numerical stability shifts = lbann.MatMul( lbann.Max(y, self.constant(0, hint_layer=y)), self.constant( 1 / math.sqrt(self.dictionary_size), [self.dictionary_size, self.dictionary_size], ), ) y = lbann.Subtract(y, shifts) z = lbann.MatMul( lbann.Exp(y), self.constant(1, [self.dictionary_size, 1]), ) z = lbann.Log(z) z = lbann.MatMul( lbann.Reshape(keep_mask, dims=str_list([1, -1])), z, ) recon_loss = lbann.MatMul( lbann.Reshape(y, dims=str_list([1, -1])), lbann.Reshape(x, dims=str_list([1, -1])), transpose_b=True, ) recon_loss = lbann.Subtract(z, recon_loss) recon_loss = lbann.Reshape(recon_loss, dims=str_list([1])) recon_loss = lbann.Divide(recon_loss, length) return recon_loss
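# The masking / log-sum-exp algebra in compute_loss() is easier to see outside
# the layer graph. A NumPy sketch of the same quantity (values are
# illustrative; the LBANN code above shifts y with a matmul-based estimate
# rather than the true row max used here):
import numpy as np

V, T, pad = 5, 4, 0                   # vocab size, sequence length, pad index
y = np.random.randn(T, V)             # unnormalized per-position token scores
x = np.array([2, 1, pad, 3])          # target indices; pad positions are ignored

keep = (x != pad).astype(float)       # keep_mask
length = max(keep.sum(), 1.0)

y_shift = y - y.max(axis=1, keepdims=True)    # stabilize
log_z = np.log(np.exp(y_shift).sum(axis=1))   # per-position log partition
picked = y_shift[np.arange(T), x]             # score of the target token

recon_loss = float((keep * (log_z - picked)).sum() / length)
print(recon_loss)  # matches the cross_entropy(..., ignore_index=pad) in the reference comment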
def forward(self, queries, keys, values, mask=None): """Apply multi-head attention. The input and output tensors are interpreted as sequences of vectors, where the first tensor dimension is the sequence dimension. Args: queries (lbann.Layer): Sequence of query vectors. keys (lbann.Layer): Sequence of key vectors. values (lbann.Layer): Sequence of value vectors. mask (lbann.Layer, optional): Additive attention mask. If the (i,j) entry is very negative (e.g. -1e9), then the ith query does not attend to the jth key/value pair. Returns: lbann.Layer: Sequence of output vectors. The sequence length is the same as `queries`. """ ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH BRANCHES = self.BRANCHES if (ENABLE_SUBGRAPH): if (self.num_heads % BRANCHES != 0): raise ValueError('Num heads should be divisible by BRANCHES') self.instance += 1 name = f'{self.name}_instance{self.instance}' # Apply fully-connected layers to input sequences queries_fc = lbann.ChannelwiseFullyConnected( queries, weights=self.query_weights, output_channel_dims=[self.inner_dim], name=f'{name}_queries_fc', ) keys_fc = lbann.ChannelwiseFullyConnected( keys, weights=self.key_weights, output_channel_dims=[self.inner_dim], name=f'{name}_keys_fc', ) values_fc = lbann.ChannelwiseFullyConnected( values, weights=self.value_weights, output_channel_dims=[self.inner_dim], name=f'{name}_values_fc', ) # Slice embedding vectors for each head slice_points = str_list(self.head_dim * i for i in range(self.num_heads + 1)) queries_slice = lbann.Slice(queries_fc, axis=1, slice_points=slice_points, name=f'{name}_queries_slice', parallel_strategy={ 'sub_branch_tag': 0, 'enable_subgraph': ENABLE_SUBGRAPH }) keys_slice = lbann.Slice(keys_fc, axis=1, slice_points=slice_points, name=f'{name}_keys_slice', parallel_strategy={ 'sub_branch_tag': 0, 'enable_subgraph': ENABLE_SUBGRAPH }) values_slice = lbann.Slice(values_fc, axis=1, slice_points=slice_points, name=f'{name}_values_slice', parallel_strategy={ 'sub_branch_tag': 0, 'enable_subgraph': ENABLE_SUBGRAPH }) # Compute scaled dot-product attention for each head attentions = [] tag = 0 for head in range(self.num_heads): head_name = f'{name}_myattention_head{head}' # Attention inputs if (ENABLE_SUBGRAPH): if (head % int(self.num_heads / BRANCHES) == 0): tag += 1 q = lbann.Identity(queries_slice, parallel_strategy={ 'sub_branch_tag': tag, 'enable_subgraph': ENABLE_SUBGRAPH }) k = lbann.Identity(keys_slice, parallel_strategy={ 'sub_branch_tag': tag, 'enable_subgraph': ENABLE_SUBGRAPH }) v = lbann.Identity(values_slice, parallel_strategy={ 'sub_branch_tag': tag, 'enable_subgraph': ENABLE_SUBGRAPH }) else: q = lbann.Identity(queries_slice) k = lbann.Identity(keys_slice) v = lbann.Identity(values_slice) # Multiply queries and keys # Note: num_queries x num_keys y = lbann.MatMul( q, k, transpose_b=True, name=f'{head_name}_matmul', ) y = lbann.WeightedSum( y, scaling_factors=str(1 / math.sqrt(self.head_dim)), name=f'{head_name}_scale', ) if (ENABLE_SUBGRAPH): if mask != None: y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask') else: if mask: y = lbann.Sum([y, mask], name=f'{head_name}_mask') y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax') # Attention output # Note: num_queries x head_dim attentions.append(lbann.MatMul(y, v, name=head_name)) #Strong scaling # Concatenate heads and apply fully-connected layer if (ENABLE_SUBGRAPH): attentions = lbann.Concatenation(attentions, axis=1, name=f'{name}_heads_concat', parallel_strategy={ 'sub_branch_tag': 0, 'enable_subgraph': ENABLE_SUBGRAPH }) else: 
attentions = lbann.Concatenation( attentions, axis=1, name=f'{name}_heads_concat', ) outputs_fc = lbann.ChannelwiseFullyConnected( attentions, weights=self.output_weights, output_channel_dims=[self.embed_dim], name=f'{name}', ) return outputs_fc
def forward(self, queries, keys, values, mask=None): """Apply multi-head attention. The input and output tensors are interpreted as sequences of vectors, where the first tensor dimension is the sequence dimension. Args: queries (lbann.Layer): Sequence of query vectors. keys (lbann.Layer): Sequence of key vectors. values (lbann.Layer): Sequence of value vectors. mask (lbann.Layer, optional): Additive attention mask. If the (i,j) entry is very negative (e.g. -1e9), then the ith query does not attend to the jth key/value pair. Returns: lbann.Layer: Sequence of output vectors. The sequence length is the same as `queries`. """ self.instance += 1 name = f'{self.name}_instance{self.instance}' # Apply fully-connected layers to input sequences queries_fc = lbann.ChannelwiseFullyConnected( queries, weights=self.query_weights, output_channel_dims=[self.embed_dim], name=f'{name}_queries_fc', ) keys_fc = lbann.ChannelwiseFullyConnected( keys, weights=self.key_weights, output_channel_dims=[self.embed_dim], name=f'{name}_keys_fc', ) values_fc = lbann.ChannelwiseFullyConnected( values, weights=self.value_weights, output_channel_dims=[self.embed_dim], name=f'{name}_values_fc', ) # Slice embedding vectors for each head slice_points = str_list(self.head_dim * i for i in range(self.num_heads + 1)) queries_slice = lbann.Slice( queries_fc, axis=1, slice_points=slice_points, name=f'{name}_queries_slice', ) keys_slice = lbann.Slice( keys_fc, axis=1, slice_points=slice_points, name=f'{name}_keys_slice', ) values_slice = lbann.Slice( values_fc, axis=1, slice_points=slice_points, name=f'{name}_values_slice', ) # Compute scaled dot-product attention for each head attentions = [] for head in range(self.num_heads): head_name = f'{name}_head{head}' # Attention inputs q = lbann.Identity(queries_slice) k = lbann.Identity(keys_slice) v = lbann.Identity(values_slice) # Multiply queries and keys # Note: num_queries x num_keys y = lbann.MatMul( q, k, transpose_b=True, name=f'{head_name}_matmul', ) y = lbann.WeightedSum( y, scaling_factors=str(1 / math.sqrt(self.head_dim)), name=f'{head_name}_scale', ) if mask: y = lbann.Add(y, mask, name=f'{head_name}_mask') y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax') # Attention output # Note: num_queries x head_dim attentions.append(lbann.MatMul(y, v, name=head_name)) # Concatenate heads and apply fully-connected layer attentions = lbann.Concatenation(attentions, axis=1, name=f'{name}_heads_concat') outputs_fc = lbann.ChannelwiseFullyConnected( attentions, weights=self.output_weights, output_channel_dims=[self.embed_dim], name=f'{name}', ) return outputs_fc
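# Minimal usage sketch for this attention forward(). The constructor
# `MultiheadAttention(embed_dim, num_heads)` is an assumption inferred from the
# attributes used above (embed_dim, num_heads, head_dim, *_weights), not taken
# from the snippet itself.
import lbann
from lbann.util import str_list

seq_len, embed_dim, num_heads = 16, 64, 8
attn = MultiheadAttention(embed_dim, num_heads, name='attn')  # hypothetical ctor

# Interpret a flat sample as a (seq_len, embed_dim) sequence of vectors.
x = lbann.Input(data_field='samples')
x = lbann.Reshape(x, dims=str_list([seq_len, embed_dim]))

# Self-attention: queries, keys, and values are the same sequence.
y = attn(x, x, x)  # (seq_len, embed_dim) output sequence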
def construct_model(): """Construct MACC surrogate model. See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details """ import lbann # Layer graph input = lbann.Input(data_field='samples', name='inp_data') # data is 64*64*4 images + 15 scalar + 5 param inp_slice = lbann.Slice(input, axis=0, slice_points=str_list( [0, args.ydim, args.ydim + args.xdim]), name='inp_slice') gt_y = lbann.Identity(inp_slice, name='gt_y') gt_x = lbann.Identity(inp_slice, name='gt_x') #param not used zero = lbann.Constant(value=0.0, num_neurons='1', name='zero') one = lbann.Constant(value=1.0, num_neurons='1', name='one') z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20") wae = macc_models.MACCWAE(args.zdim, args.ydim, cf=args.wae_mcf, use_CNN=args.useCNN) #pretrained, freeze inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf) fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf) y_pred_fwd = wae.encoder(gt_y) param_pred_ = wae.encoder(gt_y) input_fake = inv(param_pred_) output_cyc = fwd(input_fake) y_image_re2 = wae.decoder(output_cyc) '''**** Train cycleGAN input params <--> latent space of (images, scalars) ****''' output_fake = fwd(gt_x) y_image_re = wae.decoder(output_fake) y_out = wae.decoder(y_pred_fwd) param_pred2_ = wae.encoder(y_image_re) input_cyc = inv(param_pred2_) L_l2_x = lbann.MeanSquaredError( input_fake, gt_x) #(x,inv(enc(y)), (encoder+)inverse loss L_cyc_x = lbann.MeanSquaredError( input_cyc, gt_x) #param, x cycle loss, from latent space L_l2_y = lbann.MeanSquaredError( output_fake, y_pred_fwd) #pred error into latent space (enc(y),fw(x)) L_cyc_y = lbann.MeanSquaredError( output_cyc, y_pred_fwd) # pred error into latent space (enc(y), fw(inv(enc(y)))) #@todo slice here to separate scalar from image img_sca_loss = lbann.MeanSquaredError( y_image_re, gt_y) # (y,dec(fw(x))) #forward model to decoder, no latent space dec_fw_inv_enc_y = lbann.MeanSquaredError( y_image_re2, gt_y) #(y, dec(fw(inv(enc(y))))) y->enc_z->x'->fw_z->y' wae_loss = lbann.MeanSquaredError(y_out, gt_y) #(y, dec(enc(y)) ' #L_cyc = L_cyc_y + L_cyc_x L_cyc = lbann.Add(L_cyc_y, L_cyc_x) #loss_gen0 = L_l2_y + lamda_cyc*L_cyc loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc], scaling_factors=f'1 {args.lamda_cyc}') loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y], scaling_factors=f'1 {args.lamda_cyc}') #loss_gen1 = L_l2_x + lamda_cyc*L_cyc_y conc_out = lbann.Concatenation( [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, L_l2_x], name='x_errors') layers = list(lbann.traverse_layer_graph(input)) weights = set() for l in layers: weights.update(l.weights) # Setup objective function obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1]) # Initialize check metric callback metrics = [ lbann.Metric(img_sca_loss, name='img_re1'), lbann.Metric(dec_fw_inv_enc_y, name='img_re2'), lbann.Metric(wae_loss, name='wae_loss'), lbann.Metric(L_l2_x, name='inverse loss'), lbann.Metric(L_cyc_y, name='output cycle loss'), lbann.Metric(L_cyc_x, name='param cycle loss') ] callbacks = [ lbann.CallbackPrint(), lbann.CallbackDumpOutputs(layers=f'{conc_out.name}', execution_modes='test', directory=args.dump_outputs, batch_interval=1, format='npy'), lbann.CallbackTimer() ] # Construct model num_epochs = 1 return lbann.Model(num_epochs, weights=weights, layers=layers, serialize_io=True, metrics=metrics, objective_function=obj, callbacks=callbacks)
def make_model( num_epochs, embed_dim, num_heads, label_smoothing, ): # Embedding weights var = 2 / (embed_dim + vocab_size) # Glorot initialization embedding_weights = lbann.Weights( name='embeddings', initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)), ) # Input is two sequences of token IDs input_ = lbann.Input(data_field='samples') # Get sequences of embedding vectors # Note: Scale embeddings by sqrt(embed_dim). # Note: Decoder input is shifted right, so embedding for last # token isn't needed. embeddings_tokens = lbann.Identity( lbann.Slice( input_, axis=0, slice_points=str_list([0, 2 * sequence_length - 1]), )) embeddings = lbann.Embedding( embeddings_tokens, weights=embedding_weights, num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=pad_index, ) embeddings = lbann.WeightedSum( embeddings, scaling_factors=str(math.sqrt(embed_dim)), ) embeddings_slice = lbann.Slice( embeddings, axis=0, slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]), ) encoder_input = lbann.Identity(embeddings_slice) decoder_input = lbann.Identity(embeddings_slice) # Apply transformer model transformer = lbann.models.Transformer( hidden_size=embed_dim, num_heads=num_heads, name='transformer', ) result = transformer( encoder_input, sequence_length, decoder_input, sequence_length - 1, ) # Reconstruct decoder input preds = lbann.ChannelwiseFullyConnected( result, weights=embedding_weights, output_channel_dims=[vocab_size], bias=False, transpose=True, ) preds = lbann.ChannelwiseSoftmax(preds) preds = lbann.Slice(preds, axis=0, slice_points=str_list(range(sequence_length))) preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)] # Count number of non-pad tokens label_tokens = lbann.Identity( lbann.Slice( input_, slice_points=str_list([sequence_length + 1, 2 * sequence_length]), )) pads = lbann.Constant(value=pad_index, num_neurons=str(sequence_length - 1)) is_not_pad = lbann.NotEqual(label_tokens, pads) num_not_pad = lbann.Reduction(is_not_pad, mode='sum') # Cross entropy loss with label smoothing label_tokens = lbann.Slice( label_tokens, slice_points=str_list(range(sequence_length)), ) label_tokens = [ lbann.Identity(label_tokens) for _ in range(sequence_length - 1) ] if label_smoothing > 0: uniform_label = lbann.Constant(value=1 / vocab_size, num_neurons=str_list([1, vocab_size])) loss = [] for i in range(sequence_length - 1): label = lbann.OneHot(label_tokens[i], size=vocab_size) label = lbann.Reshape(label, dims=str_list([1, vocab_size])) if label_smoothing > 0: label = lbann.WeightedSum( label, uniform_label, scaling_factors=str_list( [1 - label_smoothing, label_smoothing]), ) loss.append(lbann.CrossEntropy(preds[i], label)) loss = lbann.Concatenation(loss) # Average cross entropy over non-pad tokens loss_scales = lbann.Divide( is_not_pad, lbann.Tessellate(num_not_pad, hint_layer=is_not_pad), ) loss = lbann.Multiply(loss, loss_scales) loss = lbann.Reduction(loss, mode='sum') # Construct model metrics = [] callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] return lbann.Model( num_epochs, layers=lbann.traverse_layer_graph(input_), objective_function=loss, metrics=metrics, callbacks=callbacks, )
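# The label-smoothing mix in make_model() is a plain convex combination; a
# standalone NumPy check (vocab size and smoothing value are illustrative):
import numpy as np

vocab_size, label_smoothing, token = 5, 0.1, 2

one_hot = np.eye(vocab_size)[token]
uniform = np.full(vocab_size, 1.0 / vocab_size)

# Same weighted sum as the lbann.WeightedSum call above.
smoothed = (1 - label_smoothing) * one_hot + label_smoothing * uniform
print(smoothed)        # [0.02 0.02 0.92 0.02 0.02]
print(smoothed.sum())  # 1.0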
def construct_model(run_args): """Construct LBANN model. Initial model for ATOM molecular VAE """ import lbann pad_index = run_args.pad_index assert pad_index is not None #sequence_length = run_args.sequence_length sequence_length = 102 assert sequence_length is not None print("sequence length is {}".format(sequence_length)) data_layout = "data_parallel" # Layer graph input_ = lbann.Input(target_mode='N/A',name='inp_data') inp_slice = lbann.Slice(input_, axis=0, slice_points="0 102 230",name='inp_slice') inp_smile = lbann.Identity(inp_slice,name='inp_smile') z = lbann.Identity(inp_slice, name='z') #param not used #input_ = lbann.Identity(lbann.Input(name='inp',target_mode="N/A"), name='inp1') vae_loss= [] input_feature_dims = sequence_length embedding_size = run_args.embedding_dim dictionary_size = run_args.num_embeddings assert embedding_size is not None assert dictionary_size is not None save_output = True if run_args.dump_outputs_dir else False #print("Inp smile len ", len(inp_smile), "z len ", len(z)) print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir) #z = lbann.Gaussian(mean=0.0,stdev=1.0, neuron_dims="128") x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims])) x = lbann.Identity(x) waemodel = molwae.MolWAE(input_feature_dims, dictionary_size, embedding_size, pad_index,save_output) x_emb = lbann.Embedding( x, num_embeddings=waemodel.dictionary_size, embedding_dim=waemodel.embedding_size, name='emb', weights=waemodel.emb_weights ) pred, arg_max = waemodel.forward_decoder(x_emb,z) recon = waemodel.compute_loss(x, pred) vae_loss.append(recon) layers = list(lbann.traverse_layer_graph(input_)) # Setup objective function weights = set() for l in layers: weights.update(l.weights) #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4) #vae_loss.append(l2_reg) print("LEN vae loss ", len(vae_loss)) obj = lbann.ObjectiveFunction(vae_loss) # Initialize check metric callback metrics = [lbann.Metric(recon, name='recon')] callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] #Dump output (activation) for post processing pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor') conc_out = lbann.Concatenation([input_,pred_tensor], name='conc_out') callbacks.append(lbann.CallbackDumpOutputs(batch_interval=run_args.dump_outputs_interval, execution_modes='test', directory=run_args.dump_outputs_dir, layers=f'{conc_out.name}')) # Construct model return lbann.Model(run_args.num_epochs, weights=weights, layers=layers, objective_function=obj, metrics=metrics, callbacks=callbacks)
def setup(num_patches=3, mini_batch_size=512, num_epochs=75, learning_rate=0.005, bn_statistics_group_size=2, fc_data_layout='model_parallel', warmup=True, checkpoint_interval=None): # Data dimensions patch_dims = patch_generator.patch_dims num_labels = patch_generator.num_labels(num_patches) # Extract tensors from data sample input = lbann.Input() slice_points = [0] for _ in range(num_patches): patch_size = functools.reduce(operator.mul, patch_dims) slice_points.append(slice_points[-1] + patch_size) slice_points.append(slice_points[-1] + num_labels) sample = lbann.Slice(input, slice_points=str_list(slice_points)) patches = [ lbann.Reshape(sample, dims=str_list(patch_dims)) for _ in range(num_patches) ] labels = lbann.Identity(sample) # Siamese network head_cnn = modules.ResNet( bn_statistics_group_size=bn_statistics_group_size) heads = [head_cnn(patch) for patch in patches] heads_concat = lbann.Concatenation(heads) # Classification network class_fc1 = modules.FcBnRelu( 4096, statistics_group_size=bn_statistics_group_size, name='siamese_class_fc1', data_layout=fc_data_layout) class_fc2 = modules.FcBnRelu( 4096, statistics_group_size=bn_statistics_group_size, name='siamese_class_fc2', data_layout=fc_data_layout) class_fc3 = lbann.modules.FullyConnectedModule(num_labels, activation=lbann.Softmax, name='siamese_class_fc3', data_layout=fc_data_layout) x = class_fc1(heads_concat) x = class_fc2(x) probs = class_fc3(x) # Setup objective function cross_entropy = lbann.CrossEntropy([probs, labels]) l2_reg_weights = set() for l in lbann.traverse_layer_graph(input): if type(l) == lbann.Convolution or type(l) == lbann.FullyConnected: l2_reg_weights.update(l.weights) l2_reg = lbann.L2WeightRegularization(weights=l2_reg_weights, scale=0.0002) obj = lbann.ObjectiveFunction([cross_entropy, l2_reg]) # Setup model metrics = [ lbann.Metric(lbann.CategoricalAccuracy([probs, labels]), name='accuracy', unit='%') ] callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] if checkpoint_interval: callbacks.append( lbann.CallbackCheckpoint(checkpoint_dir='ckpt', checkpoint_epochs=5)) # Learning rate schedules if warmup: callbacks.append( lbann.CallbackLinearGrowthLearningRate(target=learning_rate * mini_batch_size / 128, num_epochs=5)) callbacks.append( lbann.CallbackDropFixedLearningRate(drop_epoch=list(range(0, 100, 15)), amt=0.25)) # Construct model model = lbann.Model(num_epochs, layers=lbann.traverse_layer_graph(input), objective_function=obj, metrics=metrics, callbacks=callbacks) # Setup optimizer opt = lbann.SGD(learn_rate=learning_rate, momentum=0.9) # opt = lbann.Adam(learn_rate=learning_rate, beta1=0.9, beta2=0.999, eps=1e-8) # Setup data reader data_reader = make_data_reader(num_patches) # Return experiment objects return model, data_reader, opt
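# The slice-point bookkeeping in setup() can be sanity-checked on its own.
# The patch dimensions and label count below are assumed for illustration; the
# real values come from patch_generator.
import functools
import operator

patch_dims = (3, 96, 96)
num_patches = 3
num_labels = 8

patch_size = functools.reduce(operator.mul, patch_dims)
slice_points = [0]
for _ in range(num_patches):
    slice_points.append(slice_points[-1] + patch_size)
slice_points.append(slice_points[-1] + num_labels)

print(patch_size)    # 27648
print(slice_points)  # [0, 27648, 55296, 82944, 82952]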