def construct_model():
    """Construct LBANN model.

    Pilot1 Combo model

    """
    import lbann

    # Layer graph
    data = lbann.Input(data_field='samples')
    responses = lbann.Input(data_field='responses')

    pred = combo.Combo()(data)
    mse = lbann.MeanSquaredError([responses, pred])

    SS_res = lbann.Reduction(lbann.Square(lbann.Subtract(responses, pred)), mode='sum')

    # SS_tot = var(x) = mean((x-mean(x))^2)
    mini_batch_size = lbann.MiniBatchSize()
    mean = lbann.Divide(lbann.BatchwiseReduceSum(responses), mini_batch_size)
    SS_tot = lbann.Divide(
        lbann.BatchwiseReduceSum(lbann.Square(lbann.Subtract(responses, mean))),
        mini_batch_size)

    eps = lbann.Constant(value=1e-07, hint_layer=SS_tot)
    r2 = lbann.Subtract(lbann.Constant(value=1, num_neurons='1'),
                        lbann.Divide(SS_res, lbann.Add(SS_tot, eps)))

    metrics = [lbann.Metric(mse, name='mse')]
    metrics.append(lbann.Metric(r2, name='r2'))

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    num_epochs = 100
    layers = list(lbann.traverse_layer_graph([data, responses]))
    return lbann.Model(num_epochs,
                       layers=layers,
                       metrics=metrics,
                       objective_function=mse,
                       callbacks=callbacks)
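
# A minimal NumPy sketch (not part of the LBANN graph above, and ignoring
# LBANN's per-sample vs. batchwise reduction details) of the R^2 metric the
# r2 layers assemble: R^2 = 1 - SS_res / (SS_tot + eps), with SS_tot taken
# as the batch variance of the targets, per the comment above. Names here
# are illustrative only.
def r2_metric_sketch(responses, pred, eps=1e-07):
    import numpy as np
    ss_res = np.sum((responses - pred) ** 2)                 # sum of squared residuals
    ss_tot = np.mean((responses - np.mean(responses)) ** 2)  # batch variance of targets
    return 1.0 - ss_res / (ss_tot + eps)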
# Pearson Correlation
# rho(x,y) = covariance(x,y) / sqrt( variance(x) * variance(y) )
pearson_r_cov = lbann.Covariance([pred, responses], name="pearson_r_cov")
pearson_r_var1 = lbann.Variance(responses, name="pearson_r_var1")
pearson_r_var2 = lbann.Variance(pred, name="pearson_r_var2")
pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2],
                                name="pearson_r_mult")
pearson_r_sqrt = lbann.Sqrt(pearson_r_mult, name="pearson_r_sqrt")
eps = lbann.Constant(value=1e-07, hint_layer=pearson_r_sqrt)
pearson_r = lbann.Divide([pearson_r_cov, lbann.Add(pearson_r_sqrt, eps)],
                         name="pearson_r")

metrics = [lbann.Metric(mse, name='mse')]
metrics.append(lbann.Metric(pearson_r, name='pearson_r'))

callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

layers = list(lbann.traverse_layer_graph([images, responses]))
model = lbann.Model(args.num_epochs,
                    layers=layers,
                    metrics=metrics,
                    objective_function=mse,
                    callbacks=callbacks)

# Load data reader from prototext
data_reader_proto = lbann.lbann_pb2.LbannPB()
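
# A minimal NumPy sketch (not part of the LBANN graph above) of the Pearson
# correlation the block computes: rho = cov(pred, responses) /
# (sqrt(var(pred) * var(responses)) + eps), over 1-D mini-batch arrays.
# Names are illustrative only.
def pearson_r_sketch(pred, responses, eps=1e-07):
    import numpy as np
    cov = np.mean((pred - pred.mean()) * (responses - responses.mean()))
    return cov / (np.sqrt(pred.var() * responses.var()) + eps)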
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x, slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )
    # Note: Ideally we'd shift y by y.max(-1) for numerical stability
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)
    return recon_loss
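
# A minimal NumPy sketch (not part of the LBANN graph above) of the masked
# cross entropy that compute_loss reproduces, cf. the commented
# F.cross_entropy call: shift the logits, take the log-sum-exp denominator,
# subtract the logit of each target, and average over non-ignored positions.
# Unlike the graph above, this sketch shifts by the true row maximum, which
# the note says would be ideal. `logits` is (seq_len, dict_size), `targets`
# holds integer indices; names are illustrative only.
def masked_cross_entropy_sketch(logits, targets, ignore_index):
    import numpy as np
    keep = targets != ignore_index
    length = max(keep.sum(), 1)
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_z = np.log(np.exp(shifted).sum(axis=-1))           # log softmax denominator
    picked = shifted[np.arange(len(targets)), np.where(keep, targets, 0)]
    return np.where(keep, log_z - picked, 0.0).sum() / length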
data_layout="model_parallel") pearson_r_var2 = lbann.Variance(reconstruction, name="pearson_r_var1", data_layout="model_parallel") pearson_r_mult = lbann.Multiply([pearson_r_var1, pearson_r_var2], name="pearson_r_mult", data_layout="model_parallel") pearson_r_sqrt = lbann.Sqrt(pearson_r_mult, name="pearson_r_sqrt", data_layout="model_parallel") pearson_r = lbann.Divide([pearson_r_cov, pearson_r_sqrt], name="pearson_r", data_layout="model_parallel") layer_list = list(lbann.traverse_layer_graph(input_)) # Set up objective function layer_term = lbann.LayerTerm(mean_squared_error) obj = lbann.ObjectiveFunction(layer_term) # Metrics metrics = [lbann.Metric(pearson_r, name="Pearson correlation")] # Callbacks callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()] # Setup Model
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds, axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(
            value=1 / vocab_size,
            num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list([1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
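
# A minimal NumPy sketch (not part of the LBANN graph above) of the smoothed,
# pad-masked cross entropy built in the loop: each one-hot label is mixed
# with a uniform distribution, and per-token losses are averaged over
# non-pad tokens. `probs` is (sequence_length - 1, vocab_size) of softmax
# outputs; names are illustrative only.
def smoothed_cross_entropy_sketch(probs, targets, pad_index, vocab_size,
                                  label_smoothing):
    import numpy as np
    onehot = np.eye(vocab_size)[targets]
    labels = (1 - label_smoothing) * onehot + label_smoothing / vocab_size
    per_token = -(labels * np.log(probs)).sum(axis=-1)
    not_pad = targets != pad_index
    return (per_token * not_pad).sum() / max(not_pad.sum(), 1)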
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims-1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Figure out entries in x to ignore
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1.
    offsets = [
        row*self.dictionary_size
        for row in range(self.input_feature_dims-1)
    ]
    offsets = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(offsets)),
        optimizer=lbann.NoOptimizer(),
    )
    offsets = lbann.WeightsLayer(
        dims=str_list([self.input_feature_dims-1]),
        weights=offsets,
    )
    y_inds = lbann.Add(x, offsets)
    y_inds = lbann.Add(
        lbann.Multiply(keep_mask, y_inds),
        lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        ),
    )

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    z = lbann.Reshape(z, dims=str_list([1]))

    # Compute cross entropy
    recon_loss = lbann.Gather(
        lbann.Reshape(y, dims=str_list([-1])),
        y_inds,
    )
    recon_loss = lbann.Reduction(recon_loss, mode='sum')
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Divide(recon_loss, length)
    return recon_loss
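
# A minimal NumPy sketch (not part of the LBANN graph above) of the gather
# step that replaces the one-hot/MatMul contraction used in the earlier
# compute_loss: flatten y, then pick y[row, target[row]] via the offset index
# row * dictionary_size + target[row]. The graph marks ignored rows with
# index -1 so they do not contribute to the gathered sum; here they are
# masked out explicitly. Names are illustrative only.
def gather_target_logits_sketch(y, targets, keep_mask, dictionary_size):
    import numpy as np
    rows = np.arange(len(targets))
    safe_targets = np.where(keep_mask, targets, 0)      # avoid out-of-range indices
    flat_inds = rows * dictionary_size + safe_targets   # per-row offsets, as in `offsets`
    picked = y.reshape(-1)[flat_inds]                   # gather from flattened logits
    return np.where(keep_mask, picked, 0.0).sum()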