def forward(self, x): """Perform LSTM step. State from previous steps is used to compute output. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Apply linearity input_concat = lbann.Concatenation([x, self.last_output], name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=_str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply([f, self.last_cell], name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply([i, cell_update], name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add([cell_forget, cell_input], name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply([o, cell_act], name=name, data_layout=self.data_layout) # Update state and return output self.last_cell = cell self.last_output = output return output
def forward(self, x):
    self.instance += 1
    y1 = self.branch1(x) if self.branch1 else x
    y2 = self.branch2c(self.branch2b(self.branch2a(x)))
    z = lbann.Add([y1, y2],
                  name='{0}_sum_instance{1}'.format(self.name,
                                                    self.instance))
    return lbann.Relu(z,
                      name='{0}_relu_instance{1}'.format(self.name,
                                                         self.instance))
def f_invtransform(y, scale=4.0):  # Transform to original space
    '''The inverse of the transformation function that scales the data
    before training.'''
    inv_transform = lbann.WeightedSum(
        lbann.SafeDivide(
            lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                      lbann.Identity(y)),
            lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))),
        scaling_factors=str(scale))
    return inv_transform
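# A minimal NumPy sketch (not part of the LBANN graph) to sanity-check the
# algebra above. f_invtransform computes scale * (1 + y) / (1 - y); the
# forward transform is not shown in this file, but inverting that expression
# gives y = (x - scale) / (x + scale), which is assumed here.
import numpy as np

def _np_transform(x, scale=4.0):
    return (x - scale) / (x + scale)      # assumed forward transform

def _np_invtransform(y, scale=4.0):
    return scale * (1.0 + y) / (1.0 - y)  # mirrors the layer graph above

_x = np.array([0.1, 1.0, 10.0, 100.0])
assert np.allclose(_np_invtransform(_np_transform(_x)), _x)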
def inv_transform(self, y):  # Original transformation
    '''The inverse of the transformation function that scales the data
    before training.'''
    inv_transform = lbann.WeightedSum(
        lbann.SafeDivide(
            lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                      lbann.Identity(y)),
            lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))),
        scaling_factors=str(self.datascale))
    return inv_transform
def forward_encoder(self, x_emb):
    """Encoder step, emulating z ~ E(x) = q_E(z|x)

    :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for
        input sentence x
    :return: (n_batch, d_z) of floats, sample of latent vector z
    :return: float, kl term component of loss

    """

    # _, h = self.encoder_rnn(x, None)
    h = self.encoder_rnn(x_emb, None)
    h = lbann.Slice(
        h,
        slice_points=str_list([self.input_feature_dims - 1,
                               self.input_feature_dims]),
        axis=0,
    )
    h = lbann.Identity(h)

    mu, logvar = self.q_mu(h), self.q_logvar(h)

    # Set datatype of previous layers
    # Note: Depth-first search from mu and logvar to x_emb
    stack = [mu, logvar]
    in_stack = {l: True for l in stack}
    while stack:
        l = stack.pop()
        if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
            l.datatype = self.datatype
        for parent in l.parents:
            if parent not in in_stack and parent is not x_emb:
                stack.append(parent)
                in_stack[parent] = True

    # eps = torch.randn_like(mu)
    eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

    # z = mu + (logvar / 2).exp() * eps
    z = lbann.Add([
        mu,
        lbann.Multiply([
            lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
            eps,
        ]),
    ])

    # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
    kl_loss = lbann.Reduction(
        lbann.WeightedSum(
            lbann.Exp(logvar),
            lbann.Square(mu),
            self.constant(1, hint_layer=mu),
            logvar,
            scaling_factors='0.5 0.5 -0.5 -0.5',
        ),
        mode='sum',
    )

    return z, kl_loss
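# A minimal NumPy sketch (outside the LBANN graph) of the two formulas
# mirrored above, useful for checking the scaling factors by hand:
#   z  = mu + exp(logvar / 2) * eps
#   KL = 0.5 * sum(exp(logvar) + mu**2 - 1 - logvar)
import numpy as np

_rng = np.random.default_rng(0)
_mu, _logvar = np.zeros(8), np.zeros(8)
_eps = _rng.standard_normal(8)
_z = _mu + np.exp(_logvar / 2) * _eps
_kl = 0.5 * np.sum(np.exp(_logvar) + _mu**2 - 1 - _logvar)
assert _kl == 0.0  # a standard-normal posterior has zero KL to the prior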
def forward(self, inputs):
    raise NotImplementedError  # Requires log-gamma function
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    ones = lbann.Constant(hint_layer=pred, value=1.0)
    term1 = pred
    term2 = lbann.Multiply([label, lbann.Log(pred)])
    term3 = lbann.LogGamma(lbann.Add([label, ones]))
    full = lbann.WeightedSum([term1, term2, term3],
                             scaling_factors='1.0 -1.0 1.0')
    return lbann.Reduction(full)
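# Reference for the (currently unimplemented) layer graph above: the Poisson
# negative log-likelihood is
#   -log P(k; lambda) = lambda - k * log(lambda) + log Gamma(k + 1),
# matching term1 - term2 + term3. A scalar check with the math module:
import math

def _poisson_nll(pred, label):
    return pred - label * math.log(pred) + math.lgamma(label + 1.0)

# P(2; 1.0) = e^-1 / 2!, so the NLL is 1 + log(2)
assert math.isclose(_poisson_nll(1.0, 2.0), 1.0 + math.log(2.0))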
def Gelu_approx(x):
    # This approximates GELU and may be more performant:
    #   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    # Based on: https://paperswithcode.com/method/gelu
    sqrt_2_over_pi = math.sqrt(2 / math.pi)
    b_coef = 0.044715
    x_cubed = lbann.Multiply(lbann.Multiply(lbann.Identity(x), x), x)
    inner_tanh_x_comp = lbann.Add(x, lbann.Scale(x_cubed, constant=b_coef))
    tanh_x = lbann.Tanh(lbann.Scale(inner_tanh_x_comp,
                                    constant=sqrt_2_over_pi))
    return lbann.Scale(lbann.Multiply(x,
                                      lbann.AddConstant(tanh_x, constant=1)),
                       constant=0.5)
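# A minimal check (outside the LBANN graph) of the tanh approximation above
# against the exact GELU, x * Phi(x), where Phi is the standard normal CDF:
import math

def _gelu_tanh(x):
    c = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + math.tanh(c * (x + 0.044715 * x**3)))

def _gelu_exact(x):
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))

for _xv in (-3.0, -1.0, 0.0, 1.0, 3.0):
    assert abs(_gelu_tanh(_xv) - _gelu_exact(_xv)) < 1e-2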
def forward(self, x, z):
    """Do the WAE forward step

    :param x: list of tensors of longs, embed representation of input
    :return: float, kl term component of loss
    :return: float, recon component of loss

    """
    x = lbann.Slice(x, slice_points=str_list([0, self.input_feature_dims]))
    x = lbann.Identity(x)
    x_emb = lbann.Embedding(
        x,
        num_embeddings=self.dictionary_size,
        embedding_dim=self.embedding_size,
        name='emb',
        weights=self.emb_weights,
    )

    # Encoder: x -> z, kl_loss
    z_sample = self.forward_encoder(x_emb)

    eps = lbann.Gaussian(mean=self.gmean, stdev=self.gstd,
                         hint_layer=z_sample)
    z_sample = lbann.Add([z_sample, eps])

    # Decoder: x, z -> recon_loss
    # pred = self.forward_decoder(x_emb, z_sample)
    pred, arg_max = self.forward_decoder(x_emb, z_sample)
    recon_loss = self.compute_loss(x, pred)

    # Hack to remove blocking GPU allreduce in evaluation layer
    # kl_loss = lbann.Identity(kl_loss, device='CPU')
    recon_loss = lbann.Identity(recon_loss, device='CPU')

    z_prior = lbann.Tessellate(
        lbann.Reshape(z, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )

    d_real = self.discriminator0(
        lbann.Concatenation([x_emb, z_prior], axis=1))

    z_sample0 = lbann.Tessellate(
        lbann.Reshape(z_sample, dims=str_list([1, self.zdim])),
        dims=str_list([self.input_feature_dims, self.zdim]),
    )
    y_z_sample = lbann.Concatenation([x_emb, z_sample0], axis=1)

    d_fake = self.discriminator0(lbann.StopGradient(y_z_sample))
    d_adv = self.discriminator1(y_z_sample)  # freeze

    return recon_loss, d_real, d_fake, d_adv, arg_max
def inv_transform(self, y):
    '''The inverse of the transformation function that scales the data
    before training.'''
    inv_transform = lbann.WeightedSum(
        lbann.SafeDivide(
            lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                      lbann.Identity(y)),
            lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))),
        scaling_factors=str(self.datascale))
    # linear_scale = 1 / self.linear_scaler
    # CH2 = lbann.Tanh(lbann.WeightedSum(inv_transform,
    #                                    scaling_factors=str(linear_scale)))
    # return CH2
    return inv_transform
def forward(
    self,
    motif_indices,
    motif_size,
    walk_indices,
    walk_length,
):

    # Apply generator
    fake_motif_indices, gen_prob, gen_log_prob = self.generator(
        walk_length,
        walk_indices,
        motif_size,
    )

    # Get discriminator embeddings in log-space
    all_motif_indices = lbann.Concatenation(motif_indices,
                                            fake_motif_indices)
    all_motif_log_embeddings = self.discriminator.get_log_embeddings(
        all_motif_indices)
    all_motif_log_embeddings = lbann.Slice(
        all_motif_log_embeddings,
        slice_points=str_list([0, motif_size, 2 * motif_size]),
    )
    real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)
    fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)

    # Apply discriminator
    real_disc_prob, real_disc_log_not_prob \
        = self.discriminator(motif_size, real_motif_log_embeddings)
    fake_disc_prob, fake_disc_log_not_prob \
        = self.discriminator(motif_size, fake_motif_log_embeddings)

    # Loss function
    # L_disc = - log(D(real)) - log(1-D(fake))
    # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
    real_disc_log_prob \
        = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1))
    disc_loss = lbann.WeightedSum(
        real_disc_log_prob,
        fake_disc_log_not_prob,
        scaling_factors=str_list([-1, -1]),
    )
    gen_loss = lbann.Multiply(
        gen_log_prob,
        lbann.StopGradient(fake_disc_log_not_prob),
    )
    loss = lbann.Add(disc_loss, gen_loss)

    return loss, real_disc_prob, fake_disc_prob, gen_prob
def create_position_ids_from_input_ids(input_ids,
                                       input_shape,
                                       padding_idx,
                                       past_key_values_length=0):
    padding_idx = lbann.Constant(value=padding_idx,
                                 num_neurons=str_list(input_shape))
    mask = lbann.NotEqual(input_ids, padding_idx)
    incremental_indices = lbann.Multiply(
        lbann.AddConstant(
            lbann.modules.Cumsum(mask, input_shape, axis=1),
            constant=past_key_values_length,
        ),
        mask,
    )
    incremental_indices = lbann.Add(incremental_indices, padding_idx)
    return incremental_indices
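# A minimal NumPy sketch of the cumsum/mask trick above (it mirrors the
# helper of the same name in HuggingFace's RoBERTa implementation): padding
# tokens keep position_id == padding_idx, while real tokens are numbered
# padding_idx + 1, padding_idx + 2, ... from the start of the sequence.
import numpy as np

def _np_position_ids(input_ids, padding_idx, past_key_values_length=0):
    mask = (input_ids != padding_idx).astype(np.int64)
    incremental = (np.cumsum(mask, axis=1) + past_key_values_length) * mask
    return incremental + padding_idx

_ids = np.array([[5, 6, 7, 1, 1]])  # 1 is the padding index
assert (_np_position_ids(_ids, padding_idx=1)
        == np.array([[2, 3, 4, 1, 1]])).all()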
def forward(self, inputs):
    raise NotImplementedError  # Requires log-gamma function
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    count = lbann.Reduction(label)
    alpha_sum = lbann.Reduction(pred)
    lgamma_alpha_sum = lbann.Reduction(lbann.LogGamma(pred))
    lgamma_alpha_level_count_sum = lbann.Reduction(
        lbann.LogGamma(lbann.Add([pred, label])))
    return lbann.WeightedSum(
        [lbann.LogGamma(alpha_sum),
         lbann.LogGamma(lbann.Sum([count, alpha_sum])),
         lgamma_alpha_level_count_sum,
         lgamma_alpha_sum],
        scaling_factors='-1.0 1.0 -1.0 1.0')
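# Reference for the (currently unimplemented) layer graph above: with
# alpha = pred, x = label, A = sum(alpha), and N = sum(x), the weighted sum
# appears to compute the Dirichlet-multinomial negative log-likelihood, up
# to the x-only multinomial coefficient:
#   -log P = -lgamma(A) + lgamma(N + A)
#            - sum(lgamma(alpha + x)) + sum(lgamma(alpha))
# A scalar check with SciPy:
import numpy as np
from scipy.special import gammaln

def _dirmult_nll(alpha, x):
    A, N = alpha.sum(), x.sum()
    return (-gammaln(A) + gammaln(N + A)
            - gammaln(alpha + x).sum() + gammaln(alpha).sum())

# Including the multinomial coefficient, all N+1 outcomes of N draws from a
# Dirichlet([1, 1])-multinomial are equally likely: P = 1 / (N + 1).
_alpha, _x = np.ones(2), np.array([3.0, 1.0])
_log_coef = gammaln(_x.sum() + 1) - gammaln(_x + 1).sum()
assert np.isclose(_dirmult_nll(_alpha, _x) - _log_coef,
                  np.log(_x.sum() + 1))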
def inv_transform(self, y):  # Original transformation
    '''The inverse of the transformation function that scales the data
    before training.'''
    inv_transform = lbann.WeightedSum(
        lbann.SafeDivide(
            lbann.Add(lbann.Constant(value=1.0, hint_layer=y),
                      lbann.Identity(y)),
            lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),
                           lbann.Identity(y))),
        scaling_factors=str(self.datascale))
    return inv_transform

# def inv_transform(self, y):  # New transformation: log-linear
#     threshold = lbann.Constant(value=0.5, hint_layer=y)
#     is_above_threshold = lbann.Greater(y, threshold)
#     is_below_threshold = lbann.LogicalNot(is_above_threshold)
#     below = lbann.SafeDivide(
#         lbann.Subtract(y, lbann.Constant(value=1, hint_layer=y)),
#         lbann.Constant(value=0.03, hint_layer=y),
#     )
#     above = lbann.Exp(lbann.SafeDivide(
#         lbann.Subtract(
#             y,
#             lbann.Constant(value=0.5 - 0.5/math.log(300)*math.log(50),
#                            hint_layer=y)),
#         lbann.Constant(value=0.5/math.log(300), hint_layer=y),
#     ))
#     return lbann.Add(
#         lbann.Multiply(is_above_threshold, above),
#         lbann.Multiply(is_below_threshold, below),
#     )

# def f_invtransform_new(y):
#     if y <= 0.5:
#         a = 0.03; b = -1.0
#         return (y - b)/a
#     elif y > 0.5:
#         a = 0.5/np.log(300)
#         b = 0.5 - a*np.log(50)
#         return np.exp((y - b)/a)
def construct_model(): """Construct LBANN model. Pilot1 Combo model """ import lbann # Layer graph data = lbann.Input(data_field='samples') responses = lbann.Input(data_field='responses') pred = combo.Combo()(data) mse = lbann.MeanSquaredError([responses, pred]) SS_res = lbann.Reduction(lbann.Square(lbann.Subtract(responses, pred)), mode='sum') #SS_tot = var(x) = mean((x-mean(x))^2) mini_batch_size = lbann.MiniBatchSize() mean = lbann.Divide(lbann.BatchwiseReduceSum(responses), mini_batch_size) SS_tot = lbann.Divide( lbann.BatchwiseReduceSum(lbann.Square(lbann.Subtract(responses, mean))), mini_batch_size) eps = lbann.Constant(value=1e-07, hint_layer=SS_tot) r2 = lbann.Subtract(lbann.Constant(value=1, num_neurons='1'), lbann.Divide(SS_res, lbann.Add(SS_tot, eps))) metrics = [lbann.Metric(mse, name='mse')] metrics.append(lbann.Metric(r2, name='r2')) callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] # Construct model num_epochs = 100 layers = list(lbann.traverse_layer_graph([data, responses])) return lbann.Model(num_epochs, layers=layers, metrics=metrics, objective_function=mse, callbacks=callbacks)
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x,
                    slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Note: Ideally we'd shift y by y.max(-1) for numerical stability
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
meansq = lbann.Square(mu, name="meansq")
kldiv_plus_half = lbann.WeightedSum([meansq, var, logsd],
                                    name="kldiv_plus_half",
                                    scaling_factors='0.5 0.5 -1')
kldiv_full = lbann.Rsqrt(kldiv_plus_half, name="kldiv_full")
kldiv = lbann.Reduction(kldiv_full, name="kldiv", mode="sum")

# Generate sample
noise = lbann.Gaussian(name="noise", mean=0, stdev=1, hint_layer=mu)
sdnoise = lbann.Hadamard([noise, sd], name="sdnoise")
sample = lbann.Add([mu, sdnoise], name="sample")

# Decoder
decode4 = lbann.FullyConnected(sample,
                               name="decode4",
                               has_bias=True,
                               hint_layer=encode3)
decode4neuron = lbann.Relu(decode4, name="decode4neuron")
decode3 = lbann.FullyConnected(decode4neuron,
                               name="decode3",
                               has_bias=True,
                               hint_layer=encode2)
decode3neuron = lbann.Relu(decode3, name="decode3neuron")
def forward(self, queries, keys, values, mask=None):
    """Apply multi-head attention.

    The input and output tensors are interpreted as sequences of
    vectors, where the first tensor dimension is the sequence
    dimension.

    Args:
        queries (lbann.Layer): Sequence of query vectors.
        keys (lbann.Layer): Sequence of key vectors.
        values (lbann.Layer): Sequence of value vectors.
        mask (lbann.Layer, optional): Additive attention mask. If
            the (i,j) entry is very negative (e.g. -1e9), then the
            ith query does not attend to the jth key/value pair.

    Returns:
        lbann.Layer: Sequence of output vectors. The sequence length
            is the same as `queries`.

    """
    self.instance += 1
    name = f'{self.name}_instance{self.instance}'

    # Apply fully-connected layers to input sequences
    queries_fc = lbann.ChannelwiseFullyConnected(
        queries,
        weights=self.query_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_queries_fc',
    )
    keys_fc = lbann.ChannelwiseFullyConnected(
        keys,
        weights=self.key_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_keys_fc',
    )
    values_fc = lbann.ChannelwiseFullyConnected(
        values,
        weights=self.value_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}_values_fc',
    )

    # Slice embedding vectors for each head
    slice_points = str_list(self.head_dim * i
                            for i in range(self.num_heads + 1))
    queries_slice = lbann.Slice(
        queries_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_queries_slice',
    )
    keys_slice = lbann.Slice(
        keys_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_keys_slice',
    )
    values_slice = lbann.Slice(
        values_fc,
        axis=1,
        slice_points=slice_points,
        name=f'{name}_values_slice',
    )

    # Compute scaled dot-product attention for each head
    attentions = []
    for head in range(self.num_heads):
        head_name = f'{name}_head{head}'

        # Attention inputs
        q = lbann.Identity(queries_slice)
        k = lbann.Identity(keys_slice)
        v = lbann.Identity(values_slice)

        # Multiply queries and keys
        # Note: num_queries x num_keys
        y = lbann.MatMul(
            q,
            k,
            transpose_b=True,
            name=f'{head_name}_matmul',
        )
        y = lbann.WeightedSum(
            y,
            scaling_factors=str(1 / math.sqrt(self.head_dim)),
            name=f'{head_name}_scale',
        )
        if mask:
            y = lbann.Add(y, mask, name=f'{head_name}_mask')
        y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

        # Attention output
        # Note: num_queries x head_dim
        attentions.append(lbann.MatMul(y, v, name=head_name))

    # Concatenate heads and apply fully-connected layer
    attentions = lbann.Concatenation(attentions,
                                     axis=1,
                                     name=f'{name}_heads_concat')
    outputs_fc = lbann.ChannelwiseFullyConnected(
        attentions,
        weights=self.output_weights,
        output_channel_dims=[self.embed_dim],
        name=f'{name}',
    )
    return outputs_fc
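# A minimal NumPy sketch of the per-head computation above: scaled
# dot-product attention, softmax(Q K^T / sqrt(d)) V, for a single head.
import math
import numpy as np

def _np_attention(q, k, v, mask=None):
    scores = q @ k.T / math.sqrt(q.shape[-1])  # num_queries x num_keys
    if mask is not None:
        scores = scores + mask                  # additive mask, e.g. -1e9
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = scores / scores.sum(axis=-1, keepdims=True)
    return weights @ v                          # num_queries x head_dim

_rng = np.random.default_rng(0)
_q, _k, _v = (_rng.standard_normal((5, 8)) for _ in range(3))
assert _np_attention(_q, _k, _v).shape == (5, 8)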
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf for model architecture
    and other details.

    """
    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + xdim]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  # param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = macc_network_architectures.MACCWAE(
        zdim, ydim, cf=wae_mcf, use_CNN=useCNN)  # pretrained, freeze
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)
    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)

    '''**** Train cycleGAN input params <--> latent space of (images, scalars) ****'''
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)
    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)
    L_cyc_y = lbann.MeanSquaredError(output_cyc, y_pred_fwd)

    # @todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)

    # L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    # loss_gen0 = L_l2_y + lambda_cyc * L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {lambda_cyc}')
    # loss_gen1 = L_l2_x + lambda_cyc * L_cyc_y
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {lambda_cyc}')

    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    # Freeze appropriate (pretrained) weights
    pretrained_models = ["wae"]  # add macc?
    for l in layers:
        for idx in range(len(pretrained_models)):
            if (l.weights and pretrained_models[idx] in l.name):
                for w in range(len(l.weights)):
                    l.weights[w].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    # d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)

    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])

    # Initialize check metric callback
    metrics = [
        lbann.Metric(img_sca_loss, name='fw_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer()
    ]

    if (ltfb_batch_interval > 0):
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='fw_loss',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
def forward(self, x, prev_state): """Apply GRU step. Args: x (Layer): Input. prev_state: State from previous GRU step. Returns: (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) fc1 = self.ih_fc(x) #input_fc fc2 = self.hh_fc(prev_state) #hidden_fc # Get gates and cell update fc1_slice = lbann.Slice(fc1, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_fc1_slice', data_layout=self.data_layout) Wir_x = lbann.Identity(fc1_slice, name=name + '_Wrx', data_layout=self.data_layout) Wiz_x = lbann.Identity(fc1_slice, name=name + '_Wzx', data_layout=self.data_layout) Win_x = lbann.Identity(fc1_slice, name=name + '_Wnx', data_layout=self.data_layout) fc2_slice = lbann.Slice(fc2, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_fc2_slice', data_layout=self.data_layout) Whr_prev = lbann.Identity(fc2_slice, name=name + '_Wrh', data_layout=self.data_layout) Whz_prev = lbann.Identity(fc2_slice, name=name + '_Wzh', data_layout=self.data_layout) Whn_prev = lbann.Identity(fc2_slice, name=name + '_Wnh', data_layout=self.data_layout) rt = \ lbann.Sigmoid( lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout), name=name + '_reset_gate', data_layout=self.data_layout ) zt = \ lbann.Sigmoid( lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout), name=name + '_update_gate', data_layout=self.data_layout, ) nt = \ lbann.Tanh( lbann.Add( Win_x, lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout), data_layout=self.data_layout, ), name=name + '_new_gate', data_layout=self.data_layout, ) ht = \ lbann.Add( lbann.Multiply( lbann.WeightedSum( self.ones, zt, scaling_factors='1 -1', data_layout=self.data_layout ), nt, data_layout=self.data_layout ), lbann.Multiply(zt, prev_state, data_layout=self.data_layout), name=name+ '_output', data_layout=self.data_layout, ) # Return output return ht, ht
def construct_model(): """Construct MACC surrogate model. See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details """ import lbann # Layer graph input = lbann.Input(data_field='samples', name='inp_data') # data is 64*64*4 images + 15 scalar + 5 param inp_slice = lbann.Slice(input, axis=0, slice_points=str_list( [0, args.ydim, args.ydim + args.xdim]), name='inp_slice') gt_y = lbann.Identity(inp_slice, name='gt_y') gt_x = lbann.Identity(inp_slice, name='gt_x') #param not used zero = lbann.Constant(value=0.0, num_neurons='1', name='zero') one = lbann.Constant(value=1.0, num_neurons='1', name='one') z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20") wae = macc_models.MACCWAE(args.zdim, args.ydim, cf=args.wae_mcf, use_CNN=args.useCNN) #pretrained, freeze inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf) fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf) y_pred_fwd = wae.encoder(gt_y) param_pred_ = wae.encoder(gt_y) input_fake = inv(param_pred_) output_cyc = fwd(input_fake) y_image_re2 = wae.decoder(output_cyc) '''**** Train cycleGAN input params <--> latent space of (images, scalars) ****''' output_fake = fwd(gt_x) y_image_re = wae.decoder(output_fake) y_out = wae.decoder(y_pred_fwd) param_pred2_ = wae.encoder(y_image_re) input_cyc = inv(param_pred2_) L_l2_x = lbann.MeanSquaredError( input_fake, gt_x) #(x,inv(enc(y)), (encoder+)inverse loss L_cyc_x = lbann.MeanSquaredError( input_cyc, gt_x) #param, x cycle loss, from latent space L_l2_y = lbann.MeanSquaredError( output_fake, y_pred_fwd) #pred error into latent space (enc(y),fw(x)) L_cyc_y = lbann.MeanSquaredError( output_cyc, y_pred_fwd) # pred error into latent space (enc(y), fw(inv(enc(y)))) #@todo slice here to separate scalar from image img_sca_loss = lbann.MeanSquaredError( y_image_re, gt_y) # (y,dec(fw(x))) #forward model to decoder, no latent space dec_fw_inv_enc_y = lbann.MeanSquaredError( y_image_re2, gt_y) #(y, dec(fw(inv(enc(y))))) y->enc_z->x'->fw_z->y' wae_loss = lbann.MeanSquaredError(y_out, gt_y) #(y, dec(enc(y)) ' #L_cyc = L_cyc_y + L_cyc_x L_cyc = lbann.Add(L_cyc_y, L_cyc_x) #loss_gen0 = L_l2_y + lamda_cyc*L_cyc loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc], scaling_factors=f'1 {args.lamda_cyc}') loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y], scaling_factors=f'1 {args.lamda_cyc}') #loss_gen1 = L_l2_x + lamda_cyc*L_cyc_y conc_out = lbann.Concatenation( [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, L_l2_x], name='x_errors') layers = list(lbann.traverse_layer_graph(input)) weights = set() for l in layers: weights.update(l.weights) # Setup objective function obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1]) # Initialize check metric callback metrics = [ lbann.Metric(img_sca_loss, name='img_re1'), lbann.Metric(dec_fw_inv_enc_y, name='img_re2'), lbann.Metric(wae_loss, name='wae_loss'), lbann.Metric(L_l2_x, name='inverse loss'), lbann.Metric(L_cyc_y, name='output cycle loss'), lbann.Metric(L_cyc_x, name='param cycle loss') ] callbacks = [ lbann.CallbackPrint(), lbann.CallbackDumpOutputs(layers=f'{conc_out.name}', execution_modes='test', directory=args.dump_outputs, batch_interval=1, format='npy'), lbann.CallbackTimer() ] # Construct model num_epochs = 1 return lbann.Model(num_epochs, weights=weights, layers=layers, serialize_io=True, metrics=metrics, objective_function=obj, callbacks=callbacks)
def forward(
    self,
    input_ids=None,
    token_type_ids=None,
    position_ids=None,
    inputs_embeds=None,
):
    if position_ids is None:
        if input_ids is not None:
            position_ids = create_position_ids_from_input_ids(
                input_ids,
                self.input_shape,
                self.padding_idx,
            )
        else:
            position_ids = self.create_position_ids_from_inputs_embeds(
                inputs_embeds)

    if token_type_ids is None:
        token_type_ids = lbann.Constant(value=0,
                                        num_neurons=str_list(
                                            self.input_shape))

    if inputs_embeds is None:
        inputs_embeds = lbann.Embedding(
            input_ids,
            num_embeddings=self.vocab_size,
            embedding_dim=self.hidden_size,
            padding_idx=self.pad_token_id,
            weights=_load_pretrained_weights(
                ".".join((self.name, "word_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "word_embeddings")),
        )

    token_type_embeddings = lbann.Embedding(
        token_type_ids,
        num_embeddings=self.type_vocab_size,
        embedding_dim=self.hidden_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "token_type_embeddings.weight")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "token_type_embeddings")),
    )

    embeddings = lbann.Add(inputs_embeds, token_type_embeddings)

    if self.position_embedding_type == "absolute":
        position_embeddings = lbann.Embedding(
            position_ids,
            num_embeddings=self.max_position_embeddings,
            embedding_dim=self.hidden_size,
            padding_idx=self.pad_token_id,
            weights=_load_pretrained_weights(
                ".".join((self.name, "position_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "position_embeddings")),
        )
        embeddings = lbann.Add(embeddings, position_embeddings)

    embeddings = lbann.modules.PytorchLayerNorm(
        embeddings,
        self.layer_norm_eps,
        self.input_shape + (self.hidden_size, ),
        weights=_load_pretrained_weights(
            ".".join((self.name, "layernorm.weightbias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "LayerNorm")),
    )
    embeddings = lbann.Dropout(embeddings,
                               keep_prob=self.hidden_dropout_prob)

    return embeddings
def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
):
    mixed_query_layer, query_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "query.weight")),
            ".".join((self.name, "query.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "query")),
        return_dims=True,
    )
    query_layer, query_shape = self.transpose_for_scores(
        mixed_query_layer, query_shape)

    key_layer, key_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "key.weight")),
            ".".join((self.name, "key.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "key")),
        return_dims=True,
    )
    key_layer, key_shape = self.transpose_for_scores(key_layer, key_shape)

    value_layer, value_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "value.weight")),
            ".".join((self.name, "value.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "value")),
        return_dims=True,
    )
    value_layer, value_shape = self.transpose_for_scores(
        value_layer, value_shape)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    key_layer, key_shape = lbann.modules.Permute(key_layer,
                                                 key_shape,
                                                 axes=(0, 1, -1, -2),
                                                 return_dims=True)
    attention_scores, attention_shape = lbann.modules.PytorchMatmul(
        query_layer,
        query_shape,
        key_layer,
        key_shape,
        return_dims=True,
    )
    attention_scores = lbann.Scale(
        attention_scores,
        constant=1 / math.sqrt(self.attention_head_size))

    if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in the
        # RobertaModel forward() function)
        attention_scores = lbann.Add(attention_scores, attention_mask)

    # Normalize the attention scores to probabilities.
    attention_scores = lbann.Reshape(
        attention_scores,
        dims=str_list([np.prod(attention_shape[:-1]),
                       attention_shape[-1]]),
    )
    attention_probs = lbann.ChannelwiseSoftmax(attention_scores)
    attention_probs = lbann.Reshape(attention_probs,
                                    dims=str_list(attention_shape))

    # This is actually dropping out entire tokens to attend to, which
    # might seem a bit unusual, but is taken from the original
    # Transformer paper.
    attention_probs = lbann.Dropout(
        attention_probs,
        keep_prob=self.attention_probs_dropout_prob,
    )

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = lbann.Multiply(attention_probs, head_mask)

    context_layer, context_shape = lbann.modules.PytorchMatmul(
        attention_probs,
        attention_shape,
        value_layer,
        value_shape,
        return_dims=True,
    )
    context_layer, context_shape = lbann.modules.Permute(
        context_layer,
        context_shape,
        axes=(0, 2, 1, 3),
        return_dims=True,
    )
    new_context_layer_shape = context_shape[:-2] + (self.all_head_size, )
    context_layer = lbann.Reshape(context_layer,
                                  dims=str_list(self.input_shape))

    return context_layer
def forward(self, x, prev_state): """Apply LSTM step. Args: x (Layer): Input. prev_state (tuple with two `Layer`s): State from previous LSTM step. Comprised of LSTM output and cell state. Returns: (Layer, (Layer, Layer)): The output and state (the output and cell state). The state can be passed directly into the next LSTM step. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Get output and cell state from previous step prev_output, prev_cell = prev_state # Apply linearity input_concat = lbann.Concatenation(x, prev_output, name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply(f, prev_cell, name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply(i, cell_update, name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add(cell_forget, cell_input, name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply(o, cell_act, name=name, data_layout=self.data_layout) # Return output and state return output, (output, cell)
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Figure out entries in x to ignore
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1.
    offsets = [
        row * self.dictionary_size
        for row in range(self.input_feature_dims - 1)
    ]
    offsets = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(offsets)),
        optimizer=lbann.NoOptimizer(),
    )
    offsets = lbann.WeightsLayer(
        dims=str_list([self.input_feature_dims - 1]),
        weights=offsets,
    )
    y_inds = lbann.Add(x, offsets)
    y_inds = lbann.Add(
        lbann.Multiply(keep_mask, y_inds),
        lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        ),
    )

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    z = lbann.Reshape(z, dims=str_list([1]))

    # Compute cross entropy
    recon_loss = lbann.Gather(
        lbann.Reshape(y, dims=str_list([-1])),
        y_inds,
    )
    recon_loss = lbann.Reduction(recon_loss, mode='sum')
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
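# A minimal NumPy sketch of the loss assembled above: length-normalized
# cross entropy, sum_t [ log(sum_k exp(y_tk)) - y_t[x_t] ] over kept
# positions. For simplicity this uses the usual per-row max shift rather
# than the mean-based shift in the layer graph above.
import numpy as np

def _np_recon_loss(y, x, ignore_index=-1):
    keep = (x != ignore_index)
    y = y - y.max(axis=-1, keepdims=True)       # stability shift
    log_z = np.log(np.exp(y).sum(axis=-1))      # log softmax denominator
    picked = y[np.arange(len(x)), np.where(keep, x, 0)]
    loss = (log_z - picked)[keep].sum()
    return loss / max(keep.sum(), 1)

_y = np.log(np.full((3, 4), 0.25))              # uniform predictions
_x = np.array([0, 2, -1])                       # last position ignored
assert np.isclose(_np_recon_loss(_y, _x), np.log(4))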
# Pearson Correlation
# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )
pearson_r_cov = lbann.Covariance([pred, responses],
                                 name="pearson_r_cov")
pearson_r_var1 = lbann.Variance(responses,
                                name="pearson_r_var1")
pearson_r_var2 = lbann.Variance(pred,
                                name="pearson_r_var2")
pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult")
pearson_r_sqrt = lbann.Sqrt(pearson_r_mult,
                            name="pearson_r_sqrt")
eps = lbann.Constant(value=1e-07, hint_layer=pearson_r_sqrt)
pearson_r = lbann.Divide([pearson_r_cov,
                          lbann.Add(pearson_r_sqrt, eps)],
                         name="pearson_r")

metrics = [lbann.Metric(mse, name='mse')]
metrics.append(lbann.Metric(pearson_r, name='pearson_r'))

callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

layers = list(lbann.traverse_layer_graph([images, responses]))
model = lbann.Model(args.num_epochs,
                    layers=layers,
                    metrics=metrics,
                    objective_function=mse,
                    callbacks=callbacks)

# Load data reader from prototext
data_reader_proto = lbann.lbann_pb2.LbannPB()
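# A minimal NumPy sketch of the metric assembled above:
#   rho(x, y) = cov(x, y) / (sqrt(var(x) * var(y)) + eps)
import numpy as np

def _np_pearson_r(x, y, eps=1e-07):
    cov = np.mean((x - x.mean()) * (y - y.mean()))
    return cov / (np.sqrt(x.var() * y.var()) + eps)

_x = np.array([1.0, 2.0, 3.0, 4.0])
assert np.isclose(_np_pearson_r(_x, 2 * _x + 1), 1.0, atol=1e-4)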
def forward(self, x, prev_state): """ Apply GRU step channelwise Args: x (Layer): Input (shape: (num_channels, *)) prev_state (Layer): Sate from previous GRU step (shape: (num_channels, size)) Returns: (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step """ self.step += 1 name = f"{self.name}_step{self.step}" mat_size = self.num_channels * self.size prev_state = lbann.Reshape(prev_state, dims=str_list( [self.num_channels, self.size]), name=name + "_prev_state_reshape") fc1 = self.ih_fc(x) fc2 = self.hh_fc(prev_state) fc1_slice = lbann.Slice( fc1, axis=1, slice_points=str_list([0, self.size, 2 * self.size, 3 * self.size])) Wir_x = lbann.Reshape(lbann.Identity(fc1_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Wir_x') Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Wiz_z') Win_x = lbann.Reshape(lbann.Identity(fc1_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Win_x') fc2_slice = lbann.Slice( fc2, axis=1, slice_points=str_list([0, self.size, 2 * self.size, 3 * self.size])) Whr_x = lbann.Reshape(lbann.Identity(fc2_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Whr_x') Whz_z = lbann.Reshape(lbann.Identity(fc2_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Whz_z') Whn_x = lbann.Reshape(lbann.Identity(fc2_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Whn_x') rt = \ lbann.Sigmoid( lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout), name=name + '_reset_gate', data_layout=self.data_layout ) zt = \ lbann.Sigmoid( lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout), name=name + '_update_gate', data_layout=self.data_layout, ) nt = \ lbann.Tanh( lbann.Add( Win_x, lbann.Multiply(rt, Whn_x, data_layout=self.data_layout), data_layout=self.data_layout, ), name=name + '_new_gate', data_layout=self.data_layout, ) ht = \ lbann.Add( lbann.Multiply( lbann.WeightedSum( self.ones, zt, scaling_factors='1 -1', data_layout=self.data_layout ), nt, data_layout=self.data_layout ), lbann.Multiply(zt, prev_state, data_layout=self.data_layout), name=name+ '_output', data_layout=self.data_layout, ) ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size])) return ht, ht