def forward(self, x, label):
    """Compute cross-entropy loss.

    Args:
        x (lbann.Layer): Input vector.
        label (lbann.Layer): Label. Should have one entry, which
            will be cast to an integer.

    Returns:
        lbann.Layer: Loss function value.

    """
    log_probs = self.fc(x)
    label_onehot = lbann.OneHot(
        label,
        size=self.num_classes,
        data_layout=self.data_layout,
    )
    loss = lbann.Multiply(
        log_probs,
        label_onehot,
        data_layout=self.data_layout,
    )
    loss = lbann.Reduction(
        loss,
        mode="sum",
        data_layout=self.data_layout,
    )
    loss = lbann.Negative(loss, data_layout=self.data_layout)
    return loss
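# For reference, the scalar assembled by the layer graph above is the negative
# dot product of the classifier output with a one-hot encoding of the label.
# A minimal NumPy sketch of the same quantity (illustrative, not part of the
# module); it equals the usual cross-entropy only if `fc_out` already holds
# log-probabilities, since the graph itself applies no log-softmax.
import numpy as np

def cross_entropy_reference(fc_out, label, num_classes):
    """NumPy sketch of the graph above: -sum(fc_out * onehot(label))."""
    onehot = np.zeros(num_classes)
    onehot[int(label)] = 1.0          # lbann.OneHot
    return -np.sum(fc_out * onehot)   # Multiply -> Reduction(sum) -> Negative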
def construct_model():
    """Construct LBANN model (small CNN classifier)."""
    import lbann
    import lbann.modules
    from lbann.util import str_list

    fc = lbann.modules.FullyConnectedModule
    conv = lbann.modules.Convolution2dModule

    conv1 = conv(20, 3, stride=1, padding=1, name='conv1')
    conv2 = conv(20, 3, stride=1, padding=1, name='conv2')
    fc1 = fc(100, name='fc1')
    fc2 = fc(20, name='fc2')
    fc3 = fc(num_classes, name='fc3')

    # Layer graph
    input = lbann.Input(name='inp_tensor', target_mode='classification')
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, dims - 1, dims]),
                            name='inp_slice')
    xdata = lbann.Identity(inp_slice)
    ylabel = lbann.Identity(inp_slice, name='gt_y')
    # NHWC to NCHW
    x = lbann.Reshape(xdata, dims='14 13 13')
    x = conv2(conv1(x))
    x = lbann.Reshape(x, dims='3380')
    x = lbann.Dropout(lbann.Relu(fc1(x)), keep_prob=0.5)
    x = lbann.Dropout(fc2(x), keep_prob=0.5)
    pred = lbann.Softmax(fc3(x))
    gt_label = lbann.OneHot(ylabel, size=num_classes)
    loss = lbann.CrossEntropy([pred, gt_label], name='loss')
    acc = lbann.CategoricalAccuracy([pred, gt_label])

    layers = list(lbann.traverse_layer_graph(input))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    num_epochs = 10
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=[lbann.Metric(acc, name='accuracy', unit='%')],
                       objective_function=obj,
                       callbacks=callbacks)
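# A hedged usage sketch (not from the original file): how a model returned by
# construct_model() is typically combined with a trainer, optimizer, and data
# reader and then launched. The batch size and SGD hyperparameters are
# illustrative, and a real data reader must be configured (e.g. from a
# prototext or the Python reader shown further below).
import lbann
import lbann.contrib.launcher

model = construct_model()
trainer = lbann.Trainer(mini_batch_size=64)        # illustrative batch size
opt = lbann.SGD(learn_rate=0.01, momentum=0.9)     # illustrative hyperparameters
data_reader = lbann.reader_pb2.DataReader()        # placeholder; fill in a real reader
lbann.contrib.launcher.run(trainer, model, data_reader, opt,
                           job_name='example_classifier')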
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x,
                    slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Note: Ideally we'd shift y by y.max(-1) for numerical stability
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
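# For clarity, a NumPy sketch (illustrative, not part of the model) of the
# quantity the layer graph above assembles: per kept token it is
# -log softmax(y_i)[x_i], with ignored positions contributing nothing, and the
# total divided by the number of kept tokens. The shift by the row max below
# plays the same stabilizing role as the scaled positive-part shift in the
# graph and cancels in the final difference.
import numpy as np

def recon_loss_reference(y, x, label_to_ignore):
    """y: (steps, dictionary_size) logits; x: (steps,) integer labels."""
    keep = (x != label_to_ignore)
    length = max(keep.sum(), 1)
    y = y - y.max(axis=-1, keepdims=True)
    log_z = np.log(np.exp(y).sum(axis=-1))                 # log-sum-exp per step
    picked = y[np.arange(len(x)), np.where(keep, x, 0)]    # y_i[x_i] for kept rows
    return float(((log_z - picked) * keep).sum() / length)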
]
pred_fc = lbann.modules.FullyConnectedModule(vocab_size,
                                             data_layout='model_parallel')

# Iterate through RNN steps
loss = []
for step in range(sequence_length - 1):

    # Predict next token with RNN
    x = embeddings_list[step]
    x, lstm_state = lstm(x, lstm_state)
    x = pred_fc(x)
    pred = lbann.Softmax(x)

    # Evaluate prediction with cross entropy
    ground_truth = lbann.OneHot(tokens_list[step + 1], size=vocab_size)
    cross_entropy = lbann.CrossEntropy([pred, ground_truth])
    loss.append(lbann.LayerTerm(cross_entropy,
                                scale=1 / (sequence_length - 1)))

# ----------------------------------
# Create data reader
# ----------------------------------

reader = lbann.reader_pb2.DataReader()
_reader = reader.reader.add()
_reader.name = 'python'
_reader.role = 'train'
_reader.shuffle = True
_reader.percent_of_data_to_use = 1.0
_reader.python.module = 'dataset'
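# The reader above points at a Python module named 'dataset'. A minimal sketch
# of what such a module usually provides for LBANN's Python data reader; the
# function names are typically wired up via the reader's sample_function,
# num_samples_function, and sample_dims_function fields, and the names and fake
# data below are illustrative assumptions, not taken from the original script.
# dataset.py (illustrative)
import numpy as np

_data = np.random.randint(0, 10, size=(1000, 16))  # fake token sequences

def get_sample(index):
    """Return one sample as a flat float vector."""
    return _data[index].astype(np.float32)

def num_samples():
    return _data.shape[0]

def sample_dims():
    return (_data.shape[1],)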
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
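# A small NumPy sketch (illustrative, not from the original) of the
# label-smoothing target built per position above: a weighted sum of the
# one-hot label and a uniform distribution over the vocabulary, matching the
# lbann.WeightedSum with scaling factors [1 - label_smoothing, label_smoothing].
import numpy as np

def smoothed_label(token, vocab_size, label_smoothing):
    """(1 - eps) * onehot(token) + eps * uniform."""
    onehot = np.zeros(vocab_size)
    onehot[token] = 1.0
    uniform = np.full(vocab_size, 1.0 / vocab_size)
    return (1 - label_smoothing) * onehot + label_smoothing * uniform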
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation.
    Network architecture and training hyperparameters from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn

    """
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )
    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):
        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")],
        )
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))

    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)
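# construct_model(run_args) only assumes run_args exposes the attributes read
# above (pad_index, sequence_length, embedding_dim, num_embeddings, hidden,
# num_epochs, step_size, gamma, dump_weights_dir). A hedged argparse sketch
# that would satisfy it; the flag names and defaults are illustrative, not
# taken from the original application.
import argparse

def make_run_args(argv=None):
    parser = argparse.ArgumentParser(
        description="ATOM char-RNN arguments (illustrative)")
    parser.add_argument("--pad-index", type=int, default=0, dest="pad_index")
    parser.add_argument("--sequence-length", type=int, default=100,
                        dest="sequence_length")
    parser.add_argument("--embedding-dim", type=int, default=30,
                        dest="embedding_dim")
    parser.add_argument("--num-embeddings", type=int, default=30,
                        dest="num_embeddings")
    parser.add_argument("--hidden", type=int, default=768)
    parser.add_argument("--num-epochs", type=int, default=10, dest="num_epochs")
    parser.add_argument("--step-size", type=int, default=10, dest="step_size")
    parser.add_argument("--gamma", type=float, default=0.5)
    parser.add_argument("--dump-weights-dir", default="weights",
                        dest="dump_weights_dir")
    return parser.parse_args(argv)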