Example #1
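The snippet uses several names defined elsewhere in its test module. Below is a sketch of the imports and constants it appears to assume; the module paths and tolerance values are assumptions based on the Intel ngraph project layout and are not shown in the snippet, while make_placeholder, make_weights, and RefRecurrent are helpers from the surrounding test suite.

import numpy as np
import ngraph as ng
# Assumed module paths, based on the Intel ngraph project layout:
from ngraph.frontends.neon import Recurrent, Tanh
from ngraph.testing import ExecutorFactory
# make_placeholder, make_weights, and RefRecurrent are helpers defined
# in the surrounding test suite; RefRecurrent's recurrence is sketched
# after Example #1 below.
# Tolerances for the assert_allclose checks (assumed values):
fprop_rtol, fprop_atol = 1e-5, 1e-5
bprop_rtol, bprop_atol = 1e-5, 1e-5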
def test_rnn_fprop(sequence_length, input_size, hidden_size, batch_size,
                   return_sequence, weight_initializer, bias_initializer,
                   init_state, extra_axes, backward, transformer_factory):

    assert batch_size == 1, "the recurrent reference implementation only supports batch size 1"

    # Get input placeholder and numpy array
    input_placeholder, input_value = make_placeholder(input_size,
                                                      sequence_length,
                                                      batch_size,
                                                      extra_axes=extra_axes)

    # Construct network weights and initial state, if desired
    W_in, W_rec, b, init_state, init_state_value = make_weights(
        input_placeholder, hidden_size, weight_initializer, bias_initializer,
        init_state)

    # Construct the reference numpy RNN and set its weights
    rnn_ref = RefRecurrent(input_size,
                           hidden_size,
                           return_sequence=return_sequence)
    rnn_ref.set_weights(W_in.reshape(rnn_ref.Wxh.shape), W_rec,
                        b.reshape(rnn_ref.bh.shape))

    # Run the reference numpy RNN fprop
    input_shape = (input_size, sequence_length, batch_size)
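    # transpose (input_size, seq_len, batch) -> (seq_len, input_size, batch),
    # the layout the reference implementation expects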
    h_ref_list = rnn_ref.fprop_only(
        input_value.reshape(input_shape).transpose([1, 0, 2]),
        init_states=init_state_value,
        backward=backward)

    # Generate ngraph RNN
    rnn_ng = Recurrent(hidden_size,
                       init=W_in,
                       init_inner=W_rec,
                       activation=Tanh(),
                       reset_cells=True,
                       return_sequence=return_sequence,
                       backward=backward)

    # fprop ngraph RNN
    out_ng = rnn_ng(input_placeholder, init_state=init_state)

    with ExecutorFactory() as ex:
        # Create computation and execute
        if init_state is not None:
            fprop_neon_fun = ex.executor(out_ng, input_placeholder, init_state)
            fprop_neon = fprop_neon_fun(input_value, init_state_value)

        else:
            fprop_neon_fun = ex.executor(out_ng, input_placeholder)
            fprop_neon = fprop_neon_fun(input_value)

        # Compare output with reference implementation
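        # batch_size == 1, so drop the unit batch axis to match the
        # reference output shape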
        if return_sequence is True:
            fprop_neon = fprop_neon[:, :, 0]
        ng.testing.assert_allclose(fprop_neon,
                                   h_ref_list,
                                   rtol=fprop_rtol,
                                   atol=fprop_atol)
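For reference, the step the numpy implementation checks is the standard Elman recurrence h_t = tanh(W_in x_t + W_rec h_(t-1) + b). Below is a minimal sketch of such an fprop, assuming the (seq_len, input_size, batch) input layout noted above; this is an illustration of the technique, not the actual RefRecurrent code.

import numpy as np

def rnn_fprop_sketch(x, W_in, W_rec, b, h0=None, backward=False,
                     return_sequence=True):
    # x: (seq_len, input_size, batch); weight shapes follow the test:
    # W_in (hidden, input), W_rec (hidden, hidden), b (hidden, 1)
    seq_len, _, batch = x.shape
    hidden = W_in.shape[0]
    h = np.zeros((hidden, batch)) if h0 is None else h0
    steps = reversed(range(seq_len)) if backward else range(seq_len)
    hs = [None] * seq_len
    for t in steps:
        # one recurrence step: h_t = tanh(W_in x_t + W_rec h_prev + b)
        h = np.tanh(W_in.dot(x[t]) + W_rec.dot(h) + b)
        hs[t] = h
    # full sequence (hidden, seq_len, batch) or just the final state
    return np.stack(hs, axis=1) if return_sequence else h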
Example #2
def test_rnn_deriv_ref(sequence_length, input_size, hidden_size, batch_size,
                       return_sequence, weight_initializer, bias_initializer,
                       transformer_factory):

    assert batch_size == 1, "the recurrent reference implementation only supports batch size 1"
    assert return_sequence is True, "the reference RNN only supports full sequences for deriv"

    # Get input placeholder and numpy array
    input_placeholder, input_value = make_placeholder(input_size, sequence_length, batch_size)

    # Construct network weights and initial state, if desired
    W_in, W_rec, b, init_state, init_state_value = make_weights(input_placeholder, hidden_size,
                                                                weight_initializer,
                                                                bias_initializer)

    # Compute reference numpy RNN
    rnn_ref = RefRecurrent(input_size, hidden_size, return_sequence=return_sequence)
    rnn_ref.set_weights(W_in, W_rec, b.reshape(rnn_ref.bh.shape))

    # Prepare deltas for gradient check
    output_shape = (hidden_size, sequence_length, batch_size)

    # generate random deltas tensor
    deltas = np.random.randn(*output_shape)

    # the reference code expects these shapes:
    # input_shape: (seq_len, input_size, batch_size)
    # output_shape: (seq_len, hidden_size, batch_size)
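    # lossFun also returns additional outputs; keep only the three
    # weight gradients (dW_in, dW_rec, db)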
    dW_in, dW_rec, db = rnn_ref.lossFun(input_value.transpose([1, 0, 2]),
                                        deltas.copy().transpose([1, 0, 2]),
                                        init_states=init_state_value)[:3]

    # Generate ngraph RNN
    rnn_ng = Recurrent(hidden_size, init=W_in, init_inner=W_rec, activation=Tanh(),
                       reset_cells=True, return_sequence=return_sequence)

    # fprop ngraph RNN
    out_ng = rnn_ng.train_outputs(input_placeholder)

    deltas_constant = ng.constant(deltas, axes=out_ng.axes)
    params = [(rnn_ng.W_input, W_in),
              (rnn_ng.W_recur, W_rec),
              (rnn_ng.b, b)]

    with ExecutorFactory() as ex:
        # Create derivative computations and execute
        param_updates = []
        for px, _ in params:
            update = ng.deriv(out_ng, px, error=deltas_constant)
            param_updates.append(ex.executor(update, input_placeholder))

        for update_fun, ref_val in zip(param_updates, [dW_in, dW_rec, db]):
            ng.testing.assert_allclose(update_fun(input_value),
                                       ref_val.squeeze(),
                                       rtol=bprop_rtol, atol=bprop_atol)
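The gradients compared against ng.deriv correspond to backpropagation through time. Below is a minimal sketch of how dW_in, dW_rec, and db can be accumulated for the tanh recurrence, under the same layout assumptions as the fprop sketch above; again an illustration, not the actual lossFun.

import numpy as np

def rnn_bptt_sketch(x, deltas, W_in, W_rec, b, h0=None):
    # x: (seq_len, input_size, batch), deltas: (seq_len, hidden, batch)
    seq_len = x.shape[0]
    hidden = W_in.shape[0]
    # forward pass, caching hidden states for the backward sweep
    hs = {-1: np.zeros((hidden, x.shape[2])) if h0 is None else h0}
    for t in range(seq_len):
        hs[t] = np.tanh(W_in.dot(x[t]) + W_rec.dot(hs[t - 1]) + b)
    # backward pass: accumulate gradients through time
    dW_in = np.zeros_like(W_in)
    dW_rec = np.zeros_like(W_rec)
    db = np.zeros_like(b)
    dh_next = np.zeros_like(hs[0])
    for t in reversed(range(seq_len)):
        dh = deltas[t] + dh_next                # error from output + next step
        dz = dh * (1.0 - hs[t] ** 2)            # backprop through tanh
        dW_in += dz.dot(x[t].T)
        dW_rec += dz.dot(hs[t - 1].T)
        db += dz.sum(axis=1, keepdims=True)
        dh_next = W_rec.T.dot(dz)               # propagate to previous step
    return dW_in, dW_rec, db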