@tf.custom_gradient
def BA_logits(psp, W_out, BA_out):
    # Readout logits with a custom backward pass: the output error dy is
    # broadcast back to the network through BA_out rather than through W_out.
    logits = einsum_bij_jk_to_bik(psp, W_out)

    def grad(dy):
        # Exact gradient for the readout weights.
        dloss_dw_out = tf.einsum('btj,btk->jk', psp, dy)
        # Adaptive e-prop (Kolen-Pollack): BA_out receives the same update as
        # W_out; otherwise the broadcast weights stay fixed (zero gradient).
        dloss_dba_out = tf.einsum('btj,btk->jk', psp, dy) if FLAGS.eprop == 'adaptive' else tf.zeros_like(BA_out)
        # Error reaching the network is computed with BA_out, not W_out.
        dloss_dpsp = tf.einsum('bik,jk->bij', dy, BA_out)
        return [dloss_dpsp, dloss_dw_out, dloss_dba_out]

    return logits, grad
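# Note on the backward pass above: with psp of shape [batch, time, n_outputs]
# and dy of shape [batch, time, N_output_classes_with_blank],
#     tf.einsum('bik,jk->bij', dy, BA_out)
# is equivalent to dy @ tf.transpose(BA_out), i.e. the usual backpropagated
# readout error with W_out.T replaced by the broadcast matrix BA_out.T.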

# Define the graph for the loss function and the definition of the error
with tf.name_scope('Loss'):
    if FLAGS.eprop in ['adaptive', 'random']:
        if FLAGS.BAglobal:
            BA_out = tf.constant(np.ones((n_outputs, N_output_classes_with_blank)) / np.sqrt(n_outputs),
                                 dtype=tf.float32, name='BroadcastWeights')
        else:
            if FLAGS.eprop == 'adaptive':
                init_w_out = rd.randn(n_outputs, N_output_classes_with_blank) / np.sqrt(n_outputs)
                BA_out = tf.Variable(init_w_out, dtype=tf.float32, name='BroadcastWeights')
            else:
                init_w_out = rd.randn(n_outputs, N_output_classes_with_blank)
                BA_out = tf.constant(init_w_out, dtype=tf.float32, name='BroadcastWeights')

        phn_logits = BA_logits(lsnn_out, w_out, BA_out)
    else:
        print("Broadcast alignment disabled!")
        phn_logits = einsum_bij_jk_to_bik(lsnn_out, w_out)

    if FLAGS.readout_bias:
        b_out = tf.Variable(np.zeros(N_output_classes_with_blank), dtype=tf.float32, name="OutBias")
        phn_logits += b_out

    # CTC-style losses expect time-major logits [time, batch, classes];
    # transpose only after the logits (and bias) are fully defined.
    phn_logits_time_major = tf.transpose(phn_logits, [1, 0, 2])

if FLAGS.eprop == 'adaptive':
    weight_decay = tf.constant(FLAGS.readout_decay, dtype=tf.float32)
    w_out_decay = tf.assign(w_out, w_out - weight_decay * w_out)
    BA_decay = tf.assign(BA_out, BA_out - weight_decay * BA_out)
    KolenPollackDecay = [BA_decay, w_out_decay]
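# Usage sketch (assumption; sess and train_step are not defined in this
# excerpt): Kolen-Pollack relies on applying the same update and the same
# weight decay to both w_out and BA_out, which drives the two matrices to
# converge toward each other. The decay ops would therefore be run after
# every optimizer step, e.g.:
#
#     for step in range(n_training_steps):
#         sess.run(train_step, feed_dict=...)
#         if FLAGS.eprop == 'adaptive':
#             sess.run(KolenPollackDecay)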

# Firing rate regularization
with tf.name_scope('RegularizationLoss'):
    # Mean firing rate of each output neuron, averaged over batch and time.
    av = tf.reduce_mean(outputs, axis=(0, 1)) / FLAGS.dt
    # Per-neuron regularization coefficients (assumption: the original
    # statement was truncated here; completed as a non-trainable variable,
    # with the loss term mirroring the squared-rate penalty used by this
    # script's scalar variant of the regularizer).
    regularization_coeff = tf.Variable(np.ones(n_outputs) * FLAGS.reg,
                                       dtype=tf.float32, trainable=False)
    # Penalize squared deviation of each neuron's rate from the target rate.
    loss_reg = tf.reduce_sum(tf.square(av - FLAGS.reg_rate / 1000) * regularization_coeff)
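# Sketch (assumption: loss_pred and the optimizer setup live elsewhere in the
# script): the regularization term is typically added to the task loss before
# building the training op, e.g.:
#
#     total_loss = loss_pred + loss_reg
#     train_step = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(total_loss)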