def train_step(inp, tar):
    # Teacher forcing: the decoder input is the target shifted right,
    # and the real target is the same sequence shifted left.
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        # predictions.shape == (batch_size, seq_len, vocab_size)
        predictions, _ = transformer(inp, tar_inp, True,
                                     enc_padding_mask,
                                     combined_mask,
                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    train_acc(tar_real, predictions)
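train_step relies on two helpers that are not shown above: create_masks, which builds the encoder padding mask, the decoder's combined look-ahead/padding mask, and the decoder padding mask, and loss_function, which computes a cross-entropy loss that ignores padded positions. The following is a minimal sketch of what these typically look like, following the standard TensorFlow Transformer recipe that this training loop mirrors; the exact implementations used in this project may differ.

import tensorflow as tf

def create_padding_mask(seq):
    # 1 where the token is padding (id 0), 0 elsewhere.
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # Add extra dimensions so the mask broadcasts over the attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
    # Upper-triangular mask that hides future positions from the decoder.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)  # (seq_len, seq_len)

def create_masks(inp, tar):
    # Padding mask for the encoder's self-attention.
    enc_padding_mask = create_padding_mask(inp)
    # Padding mask used in the decoder's second attention block (over encoder outputs).
    dec_padding_mask = create_padding_mask(inp)
    # Combined mask for the decoder's first attention block: hide padding and future tokens.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # Mask out padded positions so they do not contribute to the loss.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)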
def evaluate(inp_sentence):
    start_token = [tokenizer_txt.vocab_size]
    end_token = [tokenizer_txt.vocab_size + 1]

    # The input is an MWP, so add the start and end tokens.
    inp_sentence = start_token + tokenizer_txt.encode(inp_sentence) + end_token
    encoder_input = tf.expand_dims(inp_sentence, 0)

    # The target is an equation, so the first token fed to the decoder is the
    # equation start token.
    decoder_input = [tokenizer_eq.vocab_size]
    output = tf.expand_dims(decoder_input, 0)

    for i in range(MAX_LENGTH):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, output)

        # predictions.shape == (batch_size, seq_len, vocab_size)
        predictions, attention_weights = transformer(encoder_input,
                                                     output,
                                                     False,
                                                     enc_padding_mask,
                                                     combined_mask,
                                                     dec_padding_mask)

        # Select the last token from the seq_len dimension.
        predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # Return the result if the predicted_id is equal to the end token.
        if tf.equal(predicted_id, tokenizer_eq.vocab_size + 1):
            return tf.squeeze(output, axis=0), attention_weights

        # Concatenate the predicted_id to the output, which is fed back to the
        # decoder as its input.
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights
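evaluate returns the predicted equation as a sequence of subword ids, so in practice it is wrapped in a small helper that strips the start/end tokens and detokenizes the result. The sketch below assumes eager execution and that tokenizer_eq is a subword tokenizer exposing a decode method (as used for encode above); the helper name predict_equation is illustrative, not part of the original code.

def predict_equation(mwp_sentence):
    # Greedily decode an equation for a math word problem and detokenize it.
    result, attention_weights = evaluate(mwp_sentence)

    # Drop the equation start/end tokens (ids >= vocab_size) before decoding.
    equation_ids = [int(i) for i in result.numpy() if i < tokenizer_eq.vocab_size]
    predicted_equation = tokenizer_eq.decode(equation_ids)

    print('Input MWP:          {}'.format(mwp_sentence))
    print('Predicted equation: {}'.format(predicted_equation))
    return predicted_equation

# Example usage (illustrative input):
# predict_equation('John has 3 apples and buys 2 more. How many apples does he have now?')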