def forward(self, x): """Perform LSTM step. State from previous steps is used to compute output. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Apply linearity input_concat = lbann.Concatenation([x, self.last_output], name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=_str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=_str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply([f, self.last_cell], name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply([i, cell_update], name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add([cell_forget, cell_input], name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply([o, cell_act], name=name, data_layout=self.data_layout) # Update state and return output self.last_cell = cell self.last_output = output return output
def Gelu_approx(x):
    # This approximates gelu and may be more performant
    # return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    # Based on: https://paperswithcode.com/method/gelu
    sqrt_2_over_pi = math.sqrt(2 / math.pi)
    b_coef = 0.044715
    x_cubed = lbann.Multiply(lbann.Multiply(lbann.Identity(x), x), x)
    inner_tanh_x_comp = lbann.Add(x, lbann.Scale(x_cubed, constant=b_coef))
    tanh_x = lbann.Tanh(lbann.Scale(inner_tanh_x_comp, constant=sqrt_2_over_pi))
    return lbann.Scale(lbann.Multiply(x, lbann.AddConstant(tanh_x, constant=1)),
                       constant=0.5)

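# Sanity-check sketch for the approximation above (an illustrative assumption,
# not part of the original source): in plain Python, the tanh form tracks the
# exact erf-based GELU to within roughly 1e-3 for moderate inputs.
#
#   import math
#   def _gelu_exact(v):
#       return 0.5 * v * (1 + math.erf(v / math.sqrt(2)))
#   def _gelu_tanh(v):
#       return 0.5 * v * (1 + math.tanh(math.sqrt(2 / math.pi) * (v + 0.044715 * v ** 3)))
#   assert abs(_gelu_exact(1.0) - _gelu_tanh(1.0)) < 1e-3
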
def forward(self, x, label):
    """Compute cross-entropy loss.

    Args:
        x (lbann.Layer): Input vector.
        label (lbann.Layer): Label. Should have one entry, which will
            be cast to an integer.

    Returns:
        lbann.Layer: Loss function value.

    """
    log_probs = self.fc(x)
    label_onehot = lbann.OneHot(
        label,
        size=self.num_classes,
        data_layout=self.data_layout,
    )
    loss = lbann.Multiply(
        log_probs,
        label_onehot,
        data_layout=self.data_layout,
    )
    loss = lbann.Reduction(
        loss,
        mode="sum",
        data_layout=self.data_layout,
    )
    loss = lbann.Negative(loss, data_layout=self.data_layout)
    return loss

def forward_encoder(self, x_emb):
    """Encoder step, emulating z ~ E(x) = q_E(z|x)

    :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for input sentence x
    :return: (n_batch, d_z) of floats, sample of latent vector z
    :return: float, kl term component of loss
    """

    # _, h = self.encoder_rnn(x, None)
    h = self.encoder_rnn(x_emb, None)

    h = lbann.Slice(
        h,
        slice_points=str_list([self.input_feature_dims - 1,
                               self.input_feature_dims]),
        axis=0,
    )
    h = lbann.Identity(h)

    mu, logvar = self.q_mu(h), self.q_logvar(h)

    # Set datatype of previous layers
    # Note: Depth-first search from mu and logvar to x_emb
    stack = [mu, logvar]
    in_stack = {l: True for l in stack}
    while stack:
        l = stack.pop()
        if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
            l.datatype = self.datatype
        for parent in l.parents:
            if parent not in in_stack and parent is not x_emb:
                stack.append(parent)
                in_stack[parent] = True

    # eps = torch.randn_like(mu)
    eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

    # z = mu + (logvar / 2).exp() * eps
    z = lbann.Add([
        mu,
        (lbann.Multiply([
            lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
            eps
        ]))
    ])

    # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
    kl_loss = lbann.Reduction(
        lbann.WeightedSum(
            lbann.Exp(logvar),
            lbann.Square(mu),
            self.constant(1, hint_layer=mu),
            logvar,
            scaling_factors='0.5 0.5 -0.5 -0.5',
        ),
        mode='sum',
    )

    return z, kl_loss

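# Derivation note (added for clarity, not from the original source): for a
# diagonal Gaussian q = N(mu, exp(logvar)) and a standard normal prior,
#     KL(q || p) = 0.5 * sum( exp(logvar) + mu^2 - 1 - logvar ),
# which is the WeightedSum above with scaling factors 0.5, 0.5, -0.5, -0.5
# applied to exp(logvar), mu^2, the constant 1, and logvar, followed by a
# sum reduction.
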
def forward(self, inputs):
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]   # Assumed to be Boolean
    masked_pred = lbann.Multiply([pred, label])
    pred_sum = lbann.Reduction(masked_pred)
    return lbann.Negative(lbann.Log(pred_sum))

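# Hedged usage sketch (an assumption, not from the original source): with a
# one-hot/Boolean label, Multiply followed by Reduction picks out the
# predicted probability of the true class, so this module returns
# -log(pred[target]), i.e. categorical cross entropy on probabilities.
# The names below are hypothetical.
#
#   probs = lbann.Softmax(logits)                    # hypothetical logits layer
#   onehot = lbann.OneHot(label, size=num_classes)   # hypothetical label layer
#   loss = CrossEntropyLoss()([probs, onehot])       # hypothetical module name
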
def forward(self, inputs):
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    ones = lbann.Constant(hint_layer=pred, value=1.0)
    term1 = lbann.Multiply(
        [label, lbann.Log(lbann.Subtract([ones, pred]))])
    term2 = lbann.Log(pred)
    full = lbann.WeightedSum([term1, term2], scaling_factors='-1.0 -1.0')
    return lbann.Reduction(full)

def forward(self, inputs):
    raise NotImplementedError   # Requires log-gamma function
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    ones = lbann.Constant(hint_layer=pred, value=1.0)
    term1 = pred
    term2 = lbann.Multiply([label, lbann.Log(pred)])
    term3 = lbann.LogGamma(lbann.Add([label, ones]))
    full = lbann.WeightedSum([term1, term2, term3],
                             scaling_factors='1.0 -1.0 1.0')
    return lbann.Reduction(full)

def forward(
    self,
    motif_indices,
    motif_size,
    walk_indices,
    walk_length,
):

    # Apply generator
    fake_motif_indices, gen_prob, gen_log_prob = self.generator(
        walk_length,
        walk_indices,
        motif_size,
    )

    # Get discriminator embeddings in log-space
    all_motif_indices = lbann.Concatenation(motif_indices, fake_motif_indices)
    all_motif_log_embeddings = self.discriminator.get_log_embeddings(
        all_motif_indices)
    all_motif_log_embeddings = lbann.Slice(
        all_motif_log_embeddings,
        slice_points=str_list([0, motif_size, 2 * motif_size]),
    )
    real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)
    fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)

    # Apply discriminator
    real_disc_prob, real_disc_log_not_prob \
        = self.discriminator(motif_size, real_motif_log_embeddings)
    fake_disc_prob, fake_disc_log_not_prob \
        = self.discriminator(motif_size, fake_motif_log_embeddings)

    # Loss function
    # L_disc = - log(D(real)) - log(1-D(fake))
    # L_gen = - log(G) * stop_gradient(log(1-D(fake)))
    real_disc_log_prob \
        = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1))
    disc_loss = lbann.WeightedSum(
        real_disc_log_prob,
        fake_disc_log_not_prob,
        scaling_factors=str_list([-1, -1]),
    )
    gen_loss = lbann.Multiply(
        gen_log_prob,
        lbann.StopGradient(fake_disc_log_not_prob),
    )
    loss = lbann.Add(disc_loss, gen_loss)

    return loss, real_disc_prob, fake_disc_prob, gen_prob

def forward(self, image, dims, max_r):
    """Compute radial profile.

    Args:
        image (lbann.Layer): Image
        dims (tuple of int): Image dimensions (dim 0 corresponds to channel)
        max_r (int): Maximum radial distance. Pixels outside this
            distance are ignored.

    Returns:
        Layer: num_channels x max_r radial profile

    """

    # Bin spatial positions
    r, r_counts = self._find_radial_bins(dims[1:], max_r)

    # Reciprocal of bin counts
    # Note: If a count is 0, its reciprocal is 0.
    r_counts_recip = [0 if c == 0 else 1 / c for c in r_counts]

    # Get scatter indices and scaling factors
    # Note: Independent binning for each channel (dim 0)
    tile_dims = [dims[0]] + [1] * r.ndim
    inds_vals = np.tile(r, tile_dims)
    inds_vals += np.arange(0, dims[0] * max_r, max_r).reshape(tile_dims)
    inds_vals[:, r >= max_r] = -1
    inds_vals = inds_vals.flatten()
    scales_vals = r_counts_recip * dims[0]

    # Construct LBANN layer graph
    image = lbann.Reshape(image, dims=str_list([np.prod(dims)]))
    inds = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(inds_vals)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list([len(inds_vals)]),
    )
    r_sums = lbann.Scatter(image, inds, dims=str_list([dims[0] * max_r]))
    scales = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(scales_vals)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list([len(scales_vals)]),
    )
    r_means = lbann.Multiply(scales, r_sums)
    return lbann.Reshape(r_means, dims=str_list([dims[0], max_r]))

def create_position_ids_from_input_ids(input_ids,
                                       input_shape,
                                       padding_idx,
                                       past_key_values_length=0):
    padding_idx = lbann.Constant(value=padding_idx,
                                 num_neurons=str_list(input_shape))
    mask = lbann.NotEqual(input_ids, padding_idx)
    incremental_indices = lbann.Multiply(
        lbann.AddConstant(
            lbann.modules.Cumsum(mask, input_shape, axis=1),
            constant=past_key_values_length,
        ),
        mask,
    )
    incremental_indices = lbann.Add(incremental_indices, padding_idx)
    return incremental_indices

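# Worked example (added for clarity, not from the original source): with
# padding_idx = 1, past_key_values_length = 0, and a row of input_ids
# [5, 6, 1, 1]:
#   mask                 = [1, 1, 0, 0]
#   cumsum(mask) * mask  = [1, 2, 0, 0]
#   + padding_idx        = [2, 3, 1, 1]
# so non-pad tokens are numbered starting at padding_idx + 1 and pad tokens
# keep position padding_idx, matching the RoBERTa position-id convention.
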
def forward(self,
            node_feature_mat,
            source_indices,
            target_indices,
            activation=lbann.Relu):
    """Apply GIN Layer.

    Args:
        node_feature_mat (Layer): Node feature matrix with the shape of
            (num_nodes, input_channels)
        source_indices (Layer): Source node indices of the edges with
            shape (num_edges)
        target_indices (Layer): Target node indices of the edges with
            shape (num_edges)
        activation (Layer): Activation layer for the node features. If
            None, then no activation is applied. (default: lbann.Relu)

    Returns:
        (Layer): The output after the kernel ops. The output can be
            passed into another Graph Conv layer directly
    """
    eps = lbann.Constant(value=(1 + self.eps),
                         num_neurons=str_list(
                             [self.num_nodes, self.input_channel_size]))

    eps_node_features = lbann.Multiply(node_feature_mat,
                                       eps,
                                       name=self.name + "_epl_mult")

    node_feature_mat = lbann.Sum(eps_node_features, node_feature_mat)

    # Transform with the sequence of linear layers
    for layer in self.nn:
        node_feature_mat = layer(node_feature_mat)

    neighborhoods = GraphExpand(node_feature_mat, target_indices)

    neighborhoods = lbann.Reshape(
        neighborhoods,
        dims=str_list([self.num_edges, self.output_channel_size]))

    aggregated_node_features = GraphReduce(
        neighborhoods, source_indices,
        [self.num_nodes, self.output_channel_size])

    # Apply activation
    if activation:
        aggregated_node_features = activation(aggregated_node_features)

    return aggregated_node_features

def forward(self, X, A, activation=lbann.Relu):
    """Apply GIN Layer.

    Args:
        X (GraphVertexData): LBANN Data object, which is a collection of
            Layers. Each Layer is of the shape (1, input_channels)
        A (Layer): Adjacency matrix input with shape (num_nodes, num_nodes)
        activation (Layer): Activation layer for the node features. If None,
            then no activation is applied. (default: lbann.Relu)

    Returns:
        (GraphVertexData): The output after the GIN layer. The output can
            be passed into another Graph Conv layer directly
    """
    in_channel = X.shape[1]

    # Accumulate messages from neighboring nodes
    out = X.get_mat()
    out = lbann.MatMul(A, out, name=self.name + "_GIN_MATMUL")
    message = GraphVertexData.matrix_to_graph(out, X.shape[0], in_channel)

    # Aggregate messages into node features
    eps = lbann.Constant(value=(1 + self.eps),
                         num_neurons=str_list([1, in_channel]))
    for node_feature in range(X.shape[0]):
        eps_val = lbann.Multiply(eps, X[node_feature])
        X[node_feature] = lbann.Sum(message[node_feature], eps_val)

    # Transform with the sequence of linear layers
    for layer in self.nn:
        for node_feature in range(X.shape[0]):
            X[node_feature] = layer(X[node_feature])

    # Apply activation
    if activation:
        for node_feature in range(X.shape[0]):
            X[node_feature] = activation(X[node_feature])

    X.update_num_features(self.output_channels)
    return X

def Silu(x):
    return lbann.Multiply(x, lbann.Sigmoid(x))

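# Note (added, not from the original source): this is SiLU/Swish,
# silu(x) = x * sigmoid(x), built from two elementwise LBANN layers, so it
# can be dropped in wherever a built-in activation layer would be used, e.g.
#
#   out = Silu(hidden)   # `hidden` is any LBANN layer (hypothetical name)
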
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x,
                    slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )
    # Note: Ideally we'd shift y by y.max(-1) for numerical stability
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss

def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Figure out entries in x to ignore
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1.
    offsets = [
        row * self.dictionary_size
        for row in range(self.input_feature_dims - 1)
    ]
    offsets = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(offsets)),
        optimizer=lbann.NoOptimizer(),
    )
    offsets = lbann.WeightsLayer(
        dims=str_list([self.input_feature_dims - 1]),
        weights=offsets,
    )
    y_inds = lbann.Add(x, offsets)
    y_inds = lbann.Add(
        lbann.Multiply(keep_mask, y_inds),
        lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        ),
    )

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    z = lbann.Reshape(z, dims=str_list([1]))

    # Compute cross entropy
    recon_loss = lbann.Gather(
        lbann.Reshape(y, dims=str_list([-1])),
        y_inds,
    )
    recon_loss = lbann.Reduction(recon_loss, mode='sum')
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss

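# Derivation note (added for clarity, not from the original source): for each
# kept position with target index t, softmax cross entropy over a logit row y is
#     -log softmax(y)[t] = log(sum_j exp(y_j)) - y_t,
# so the code above accumulates z = sum of log-sum-exp(y) over the kept rows,
# subtracts the gathered logits y[t], and divides by the number of
# non-ignored positions.
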
# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )
pearson_r_cov = lbann.Covariance([reconstruction, data],
                                 name="pearson_r_cov",
                                 data_layout="model_parallel")

pearson_r_var1 = lbann.Variance(data,
                                name="pearson_r_var1",
                                data_layout="model_parallel")

pearson_r_var2 = lbann.Variance(reconstruction,
                                name="pearson_r_var2",
                                data_layout="model_parallel")

pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult",
                                data_layout="model_parallel")

pearson_r_sqrt = lbann.Sqrt(pearson_r_mult,
                            name="pearson_r_sqrt",
                            data_layout="model_parallel")

pearson_r = lbann.Divide([pearson_r_cov, pearson_r_sqrt],
                         name="pearson_r",
                         data_layout="model_parallel")

layer_list = list(lbann.traverse_layer_graph(input_))

# Set up objective function
layer_term = lbann.LayerTerm(mean_squared_error)
obj = lbann.ObjectiveFunction(layer_term)

def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation.
    Network architecture and training hyperparameters from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn

    """

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )

    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):

        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")],
        )
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)

def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
):
    mixed_query_layer, query_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "query.weight")),
            ".".join((self.name, "query.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "query")),
        return_dims=True,
    )
    query_layer, query_shape = self.transpose_for_scores(
        mixed_query_layer, query_shape)

    key_layer, key_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "key.weight")),
            ".".join((self.name, "key.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "key")),
        return_dims=True,
    )
    key_layer, key_shape = self.transpose_for_scores(key_layer, key_shape)

    value_layer, value_shape = lbann.modules.PytorchLinear(
        hidden_states,
        self.input_shape,
        self.all_head_size,
        weights=_load_pretrained_weights(
            ".".join((self.name, "value.weight")),
            ".".join((self.name, "value.bias")),
            load_weights=self.load_weights,
        ),
        name=".".join((self.name, "value")),
        return_dims=True,
    )
    value_layer, value_shape = self.transpose_for_scores(
        value_layer, value_shape)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    key_layer, key_shape = lbann.modules.Permute(key_layer,
                                                 key_shape,
                                                 axes=(0, 1, -1, -2),
                                                 return_dims=True)
    attention_scores, attention_shape = lbann.modules.PytorchMatmul(
        query_layer,
        query_shape,
        key_layer,
        key_shape,
        return_dims=True,
    )
    attention_scores = lbann.Scale(attention_scores,
                                   constant=1 /
                                   math.sqrt(self.attention_head_size))

    if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
        attention_scores = lbann.Add(attention_scores, attention_mask)

    # Normalize the attention scores to probabilities.
    attention_scores = lbann.Reshape(
        attention_scores,
        dims=str_list([np.prod(attention_shape[:-1]), attention_shape[-1]]),
    )
    attention_probs = lbann.ChannelwiseSoftmax(attention_scores)
    attention_probs = lbann.Reshape(attention_probs,
                                    dims=str_list(attention_shape))

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = lbann.Dropout(
        attention_probs,
        keep_prob=self.attention_probs_dropout_prob,
    )

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = lbann.Multiply(attention_probs, head_mask)

    context_layer, context_shape = lbann.modules.PytorchMatmul(
        attention_probs,
        attention_shape,
        value_layer,
        value_shape,
        return_dims=True,
    )
    context_layer, context_shape = lbann.modules.Permute(
        context_layer,
        context_shape,
        axes=(0, 2, 1, 3),
        return_dims=True,
    )
    new_context_layer_shape = context_shape[:-2] + (self.all_head_size, )
    context_layer = lbann.Reshape(context_layer,
                                  dims=str_list(self.input_shape))

    return context_layer

def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )

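# Worked example for the label-smoothing branch (added for clarity, not from
# the original source): with smoothing eps and vocabulary size V, each target
# distribution becomes
#     (1 - eps) * onehot + eps * (1/V, ..., 1/V),
# which is the WeightedSum of the one-hot label and uniform_label with
# scaling factors [1 - eps, eps]; e.g. eps = 0.1 and V = 4 turn a hard label
# into [0.925, 0.025, 0.025, 0.025].
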
def forward(self, x, prev_state): """ Apply GRU step channelwise Args: x (Layer): Input (shape: (num_channels, *)) prev_state (Layer): Sate from previous GRU step (shape: (num_channels, size)) Returns: (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step """ self.step += 1 name = f"{self.name}_step{self.step}" mat_size = self.num_channels * self.size prev_state = lbann.Reshape(prev_state, dims=str_list( [self.num_channels, self.size]), name=name + "_prev_state_reshape") fc1 = self.ih_fc(x) fc2 = self.hh_fc(prev_state) fc1_slice = lbann.Slice( fc1, axis=1, slice_points=str_list([0, self.size, 2 * self.size, 3 * self.size])) Wir_x = lbann.Reshape(lbann.Identity(fc1_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Wir_x') Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Wiz_z') Win_x = lbann.Reshape(lbann.Identity(fc1_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Win_x') fc2_slice = lbann.Slice( fc2, axis=1, slice_points=str_list([0, self.size, 2 * self.size, 3 * self.size])) Whr_x = lbann.Reshape(lbann.Identity(fc2_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Whr_x') Whz_z = lbann.Reshape(lbann.Identity(fc2_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Whz_z') Whn_x = lbann.Reshape(lbann.Identity(fc2_slice), dims=str_list([self.num_channels, self.size]), name=name + '_Whn_x') rt = \ lbann.Sigmoid( lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout), name=name + '_reset_gate', data_layout=self.data_layout ) zt = \ lbann.Sigmoid( lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout), name=name + '_update_gate', data_layout=self.data_layout, ) nt = \ lbann.Tanh( lbann.Add( Win_x, lbann.Multiply(rt, Whn_x, data_layout=self.data_layout), data_layout=self.data_layout, ), name=name + '_new_gate', data_layout=self.data_layout, ) ht = \ lbann.Add( lbann.Multiply( lbann.WeightedSum( self.ones, zt, scaling_factors='1 -1', data_layout=self.data_layout ), nt, data_layout=self.data_layout ), lbann.Multiply(zt, prev_state, data_layout=self.data_layout), name=name+ '_output', data_layout=self.data_layout, ) ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size])) return ht, ht
def forward(self, x, prev_state): """Apply LSTM step. Args: x (Layer): Input. prev_state (tuple with two `Layer`s): State from previous LSTM step. Comprised of LSTM output and cell state. Returns: (Layer, (Layer, Layer)): The output and state (the output and cell state). The state can be passed directly into the next LSTM step. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) # Get output and cell state from previous step prev_output, prev_cell = prev_state # Apply linearity input_concat = lbann.Concatenation(x, prev_output, name=name + '_input', data_layout=self.data_layout) fc = self.fc(input_concat) # Get gates and cell update slice = lbann.Slice(fc, slice_points=str_list([0, self.size, 4*self.size]), name=name + '_fc_slice', data_layout=self.data_layout) cell_update = lbann.Tanh(slice, name=name + '_cell_update', data_layout=self.data_layout) sigmoid = lbann.Sigmoid(slice, name=name + '_sigmoid', data_layout=self.data_layout) slice = lbann.Slice(sigmoid, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_sigmoid_slice', data_layout=self.data_layout) f = lbann.Identity(slice, name=name + '_forget_gate', data_layout=self.data_layout) i = lbann.Identity(slice, name=name + '_input_gate', data_layout=self.data_layout) o = lbann.Identity(slice, name=name + '_output_gate', data_layout=self.data_layout) # Cell state cell_forget = lbann.Multiply(f, prev_cell, name=name + '_cell_forget', data_layout=self.data_layout) cell_input = lbann.Multiply(i, cell_update, name=name + '_cell_input', data_layout=self.data_layout) cell = lbann.Add(cell_forget, cell_input, name=name + '_cell', data_layout=self.data_layout) # Output cell_act = lbann.Tanh(cell, name=name + '_cell_activation', data_layout=self.data_layout) output = lbann.Multiply(o, cell_act, name=name, data_layout=self.data_layout) # Return output and state return output, (output, cell)
def Gelu(x):
    x_erf = lbann.Erf(lbann.Scale(x, constant=(1 / math.sqrt(2))))
    return lbann.Multiply(
        x, lbann.Scale(lbann.AddConstant(x_erf, constant=1), constant=0.5))

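# Note (added, not from the original source): this is the exact GELU,
# gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). Gelu_approx, defined earlier,
# uses the cheaper tanh approximation; the two agree closely for typical
# activation values, so the choice is a speed/accuracy trade-off.
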
images = lbann.Reshape(images, dims='1 300 300')

pred = model.PROBIESNet(num_labels)(images)

mse = lbann.MeanSquaredError([responses, pred])

# Pearson Correlation
# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )
pearson_r_cov = lbann.Covariance([pred, responses], name="pearson_r_cov")

pearson_r_var1 = lbann.Variance(responses, name="pearson_r_var1")

pearson_r_var2 = lbann.Variance(pred, name="pearson_r_var2")

pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult")

pearson_r_sqrt = lbann.Sqrt(pearson_r_mult, name="pearson_r_sqrt")

eps = lbann.Constant(value=1e-07, hint_layer=pearson_r_sqrt)
pearson_r = lbann.Divide([pearson_r_cov,
                          lbann.Add(pearson_r_sqrt, eps)],
                         name="pearson_r")

metrics = [lbann.Metric(mse, name='mse')]
metrics.append(lbann.Metric(pearson_r, name='pearson_r'))

callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

layers = list(lbann.traverse_layer_graph([images, responses]))
model = lbann.Model(args.num_epochs,
                    layers=layers,

def forward(self, x, prev_state): """Apply GRU step. Args: x (Layer): Input. prev_state: State from previous GRU step. Returns: (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step. """ self.step += 1 name = '{0}_step{1}'.format(self.name, self.step) fc1 = self.ih_fc(x) #input_fc fc2 = self.hh_fc(prev_state) #hidden_fc # Get gates and cell update fc1_slice = lbann.Slice(fc1, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_fc1_slice', data_layout=self.data_layout) Wir_x = lbann.Identity(fc1_slice, name=name + '_Wrx', data_layout=self.data_layout) Wiz_x = lbann.Identity(fc1_slice, name=name + '_Wzx', data_layout=self.data_layout) Win_x = lbann.Identity(fc1_slice, name=name + '_Wnx', data_layout=self.data_layout) fc2_slice = lbann.Slice(fc2, slice_points=str_list([0, self.size, 2*self.size, 3*self.size]), name=name + '_fc2_slice', data_layout=self.data_layout) Whr_prev = lbann.Identity(fc2_slice, name=name + '_Wrh', data_layout=self.data_layout) Whz_prev = lbann.Identity(fc2_slice, name=name + '_Wzh', data_layout=self.data_layout) Whn_prev = lbann.Identity(fc2_slice, name=name + '_Wnh', data_layout=self.data_layout) rt = \ lbann.Sigmoid( lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout), name=name + '_reset_gate', data_layout=self.data_layout ) zt = \ lbann.Sigmoid( lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout), name=name + '_update_gate', data_layout=self.data_layout, ) nt = \ lbann.Tanh( lbann.Add( Win_x, lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout), data_layout=self.data_layout, ), name=name + '_new_gate', data_layout=self.data_layout, ) ht = \ lbann.Add( lbann.Multiply( lbann.WeightedSum( self.ones, zt, scaling_factors='1 -1', data_layout=self.data_layout ), nt, data_layout=self.data_layout ), lbann.Multiply(zt, prev_state, data_layout=self.data_layout), name=name+ '_output', data_layout=self.data_layout, ) # Return output return ht, ht