def train_step_non_teacher(inp, tar):
    loss = 0.
    # start the decoder input with the <start> token of every target sequence
    output = tf.expand_dims(tar[:, 0], axis=1)
    step_logits = []

    with tf.GradientTape() as tape:
        for t in range(1, tf.shape(tar)[1]):
            predictions, _, _ = speech_model.model(inp, output, True)
            loss += loss_function(tar[:, t], predictions[:, -1, :])

            tar_weight = tf.cast(
                tf.logical_not(tf.math.equal(tar[:, t], 0)), tf.int32)
            train_acc(tar[:, t], predictions[:, -1, :], sample_weight=tar_weight)

            # select the last word from the seq_len dimension
            predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
            step_logits.append(predictions)
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

            # concatenate the predicted_id to the output, which is fed back to
            # the decoder as its input at the next step
            output = tf.concat([output, predicted_id], axis=-1)

        batch_loss = loss / tf.cast(tf.shape(tar)[1] - 1, dtype=tf.float32)

    gradients = tape.gradient(batch_loss, speech_model.model.trainable_variables)
    opt.apply_gradients(
        zip(gradients, speech_model.model.trainable_variables))

    tar_len = tf.reduce_sum(tf.cast(
        tf.logical_not(tf.math.equal(tar[:, 1:], 0)), tf.int32), axis=-1)
    # use the logits collected at every step so the label error rate covers
    # the whole generated sequence, not just the final decoding step
    predictions = tf.concat(step_logits, axis=1)  # (batch_size, seq_len-1, vocab_size)
    ler = label_error_rate(tar[:, 1:], predictions, tar_len)
    train_loss(batch_loss)
    train_ler(ler)
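# The loss_function helper is not shown in this section; below is a minimal
# sketch of what it might look like, assuming the usual masked sparse
# categorical crossentropy over logits with padding id 0. The name
# `loss_object` is illustrative and not necessarily the author's.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # exclude padding positions (token id 0) from the averaged loss
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)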
def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = speech_model.model(inp, tar_inp, True,
                                            enc_padding_mask,
                                            combined_mask,
                                            dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, speech_model.model.trainable_variables)
    opt.apply_gradients(
        zip(gradients, speech_model.model.trainable_variables))

    tar_weight = tf.cast(tf.logical_not(tf.math.equal(tar_real, 0)), tf.int32)
    tar_len = tf.reduce_sum(tar_weight, axis=-1)
    ler = label_error_rate(tar_real, predictions, tar_len)
    train_loss(loss)
    train_ler(ler)
    train_acc(tar_real, predictions, sample_weight=tar_weight)
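# label_error_rate is likewise an external helper. A rough sketch of one
# possible implementation follows, assuming it returns the mean edit distance
# between argmax predictions and the reference labels, normalized by reference
# length; the body below is illustrative, not the author's code.
def label_error_rate(real, pred_logits, real_len):
    pred_ids = tf.cast(tf.argmax(pred_logits, axis=-1), real.dtype)
    # zero out predictions where the reference is padding (id 0);
    # tf.sparse.from_dense then drops those positions automatically
    mask = tf.cast(tf.logical_not(tf.equal(real, 0)), real.dtype)
    pred_ids *= mask
    hypothesis = tf.sparse.from_dense(pred_ids)
    truth = tf.sparse.from_dense(real)
    # normalize=True divides each edit distance by the reference length,
    # so real_len is kept only for signature compatibility here
    return tf.reduce_mean(tf.edit_distance(hypothesis, truth, normalize=True))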
def valid_step_teacher_forcing(inp, tar, model, loss_tb, ler_tb, acc_tb):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    predictions, _, attention_weights = model.model(inp, tar_inp, False)
    loss = loss_function(tar_real, predictions)

    tar_weight = tf.cast(tf.logical_not(tf.math.equal(tar_real, 0)), tf.int32)
    tar_len = tf.reduce_sum(tar_weight, axis=-1)
    ler = label_error_rate(tar_real, predictions, tar_len)
    loss_tb(loss)
    ler_tb(ler)
    acc_tb(tar_real, predictions, sample_weight=tar_weight)

    return predictions.numpy()[0], tf.stack(attention_weights, axis=-1).numpy()[0]
def valid_step_teacher_forcing(inp, tar, model, loss_tb, ler_tb, acc_tb):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        inp, tar_inp)

    predictions, attention_weights = model.model(inp, tar_inp, False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

    tar_weight = tf.cast(tf.logical_not(tf.math.equal(tar_real, 0)), tf.int32)
    tar_len = tf.reduce_sum(tar_weight, axis=-1)
    ler = label_error_rate(tar_real, predictions, tar_len)
    loss_tb(loss)
    ler_tb(ler)
    acc_tb(tar_real, predictions, sample_weight=tar_weight)

    return predictions.numpy()[0], attention_weights[
        'decoder_layer4_block2'].numpy()[0, :, :tar_len[0], :]
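# A minimal driver loop sketch showing how the step functions above might be
# called, assuming `train_ds` / `valid_ds` are tf.data datasets yielding
# (spectrogram, token_id) batches and that the loss/LER/accuracy trackers are
# Keras metrics. EPOCHS, valid_loss, valid_ler and valid_acc are illustrative
# names, not taken from the original code.
for epoch in range(EPOCHS):
    train_loss.reset_states()
    train_ler.reset_states()
    train_acc.reset_states()

    for inp, tar in train_ds:
        train_step(inp, tar)                  # teacher-forcing update
        # train_step_non_teacher(inp, tar)    # or the autoregressive variant

    for inp, tar in valid_ds:
        valid_step_teacher_forcing(inp, tar, speech_model,
                                   valid_loss, valid_ler, valid_acc)

    print(f'Epoch {epoch + 1}: loss {train_loss.result():.4f}, '
          f'LER {train_ler.result():.4f}, acc {train_acc.result():.4f}')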