def optimize(grammar, productions, stack_shape, d_S, is_phi, output, steps):
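    """Fit the soft production sequence to a target output.

    A sketch of intent, assuming (as the surrounding code suggests) that
    `train_step_wrapped` returns a tf.function which generates a string
    from `productions`, computes a loss against the target `output`, and
    applies a gradient update to `productions`.
    """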
    train_step = train_step_wrapped()

    for i in range(steps):
        loss, output_, stack_ = train_step(grammar, productions, stack_shape,
                                           d_S, is_phi, output)
        if i % 10 == 0:  # log progress every 10 steps
            p_output = tokens_pretty_print(output_)
            p_stack = tokens_pretty_print(stack_)

            tf.print(loss, p_output, p_stack, tf.argmax(productions, axis=-1))


def dump_step_info(grammar, production, stack, output):
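    """Print one production step: the chosen rule index, the grammar
    entries it selects, and the resulting stack/output.

    Assumes `grammar` is a `(G_s, G_o)` pair of per-rule token tables,
    and that `stack` and `output` are `(tokens, pointer)` pairs, as the
    indexing below suggests.
    """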
    gs, go = grammar
    top = stack_peek(stack)
    i = tf.argmax(production)
    j = tf.argmax(top)
    tf.print('p\t', i)
    tf.print('G_s\t', tokens_pretty_print(gs[i][j]), (i, j))
    tf.print('G_o\t', tokens_pretty_print(go[i][j]), (i, j))
    tf.print('S_i+1\t', tokens_pretty_print(stack[0]), tf.argmax(stack[1]))
    tf.print('O_i+1\t', tokens_pretty_print(output[0]), tf.argmax(output[1]))
    tf.print('-' * 80)


def test_train_step():
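    """Run a single training step eagerly and print the resulting loss,
    generated output, and final stack next to the target output."""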
    train_step = train_step_wrapped()
    # MAX_PRODUCTIONS = 5
    # productions = tf.Variable(tf.one_hot([0] * MAX_PRODUCTIONS, PRODUCTION_DIM), dtype=tf.float32)
    productions = tf.Variable(tf.one_hot([2, 3, 0, 0, 0], PRODUCTION_DIM, dtype=tf.float32))
    # productions = tf.Variable(tf.one_hot([2, 0, 0, 0, 0], PRODUCTION_DIM), dtype=tf.float32)
    stack_shape = (STACK_SIZE, TOKEN_DIM)
    d_S = tf.constant(S, dtype=tf.float32)
    output = encode_to_tokens('x + x', TOKEN_DIM, STACK_SIZE)

    tf.config.run_functions_eagerly(True)  # run this step eagerly for easier debugging
    loss, output_, stack_ = train_step(grammar, productions, stack_shape,
                                       d_S, is_phi, output, True)
    tf.config.run_functions_eagerly(False)
    tf.print(loss)
    tf.print(tokens_pretty_print(output_), tokens_pretty_print(output))
    tf.print(tokens_pretty_print(stack_))


def test_production_step():
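    """Push the start symbol, apply one production step, and check that
    gradients flow from the new stack/output back to the grammar, the
    production choice, and the previous stack/output."""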
    stack = new_stack((STACK_SIZE, TOKEN_DIM))
    output = new_stack((STACK_SIZE, TOKEN_DIM))

    stack = safe_push(stack, tf.constant(S, dtype=tf.float32), is_phi)
    production = tf.one_hot(2, PRODUCTION_DIM)
    phi = tf.one_hot(0, TOKEN_DIM, dtype=tf.float32)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(grammar)
        tape.watch(production)
        tape.watch(stack)
        tape.watch(output)

        new_s, new_o = production_step(grammar, production, stack, output, phi,
                                       is_phi)

    tf.print(tokens_pretty_print(new_s[0]))
    # Gradients should flow from the new stack/output back to every input.
    tf.print(tape.gradient(new_o, output))
    tf.print(tape.gradient(new_s, stack))
    tf.print(tape.gradient(new_s[0], grammar[0]).shape)
    tf.print(tape.gradient(new_s[1], grammar[0]).shape)
    tf.print(tape.gradient(new_o[0], grammar[1]).shape)
    tf.print(tape.gradient(new_o[1], grammar[1]).shape)
    tf.print(tape.gradient(new_s, production))


def test_generate():
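    """Generate a string from a fixed production sequence and print the
    result, the final stack, and the gradient of the output with respect
    to the productions."""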
    tf.config.run_functions_eagerly(True)

    productions = tf.one_hot([2, 3, 0, 1, 0], PRODUCTION_DIM)

    stack_shape = (STACK_SIZE, TOKEN_DIM)
    d_S = tf.constant(S, dtype=tf.float32)
    d_phi = tf.one_hot([0], TOKEN_DIM, dtype=tf.float32)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(productions)
        output, final_stack = generate(grammar, productions, stack_shape, d_S,
                                       d_phi, is_phi, True)

    tf.config.run_functions_eagerly(False)
    tf.print('Final result:')
    tf.print(tokens_pretty_print(output))
    tf.print('-' * 80)
    tf.print('Final stack:')
    tf.print(tokens_pretty_print(final_stack))
    tf.print('-' * 80)
    tf.print(tape.gradient(output, productions))


def test_tokens_pretty_print():
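    """Smoke-test the pretty-printer on an identity one-hot token matrix."""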
    TOKEN_DIM = 6  # local override so the example stays small and readable
    tokens = tf.transpose(
        tf.one_hot([0, 1, 2, 3, 4, 5], TOKEN_DIM, dtype=tf.float32))
    tf.print(tokens_pretty_print(tokens))
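

# Minimal test driver sketch: runs the smoke tests above. Assumes the
# module-level globals they rely on (grammar, S, is_phi, STACK_SIZE,
# TOKEN_DIM, PRODUCTION_DIM) are defined earlier in this file.
if __name__ == '__main__':
    test_tokens_pretty_print()
    test_production_step()
    test_generate()
    test_train_step()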