示例#1
0
def test_birnn_deriv_numerical(sequence_length, input_size, hidden_size, batch_size,
                               return_sequence, weight_initializer, bias_initializer,
                               sum_out, concat_out):
    """Check BiRNN symbolic derivatives against numerical derivatives.

    For every output of the BiRNN and every parameter paired with it
    (input weights, recurrent weights and bias of the forward and backward
    cells), a symbolic derivative (``ex.derivative``) and a numerical one
    (``ex.numeric_derivative``) are built, evaluated at the initial parameter
    value, and compared with ``ng.testing.assert_allclose``.

    All arguments are presumably supplied by pytest fixtures or
    parametrization defined outside this block (TODO confirm).
    ``sum_out`` / ``concat_out`` decide whether the layer output is treated
    as a single output or as one output per direction.
    """
    # Symbolic input plus the concrete array fed to the computations
    input_placeholder, input_value = make_placeholder(input_size, sequence_length, batch_size)

    # Initial parameter values; the initial-state entries are not used here
    W_in, W_rec, b, _, _ = make_weights(input_placeholder, hidden_size,
                                        weight_initializer,
                                        bias_initializer)

    # Build the ngraph layer and its training graph
    rnn_ng = BiRNN(hidden_size, init=W_in, init_inner=W_rec, activation=Tanh(),
                   reset_cells=True, return_sequence=return_sequence,
                   sum_out=sum_out, concat_out=concat_out)
    out_ng = rnn_ng.train_outputs(input_placeholder)

    # Pair each graph variable of a cell with the value it was initialised from;
    # the variables are read only after train_outputs() has been called.
    initial_values = (W_in, W_rec, b)
    fwd_cell = rnn_ng.fwd_rnn
    params_f = list(zip((fwd_cell.W_input, fwd_cell.W_recur, fwd_cell.b), initial_values))
    bwd_cell = rnn_ng.bwd_rnn
    params_b = list(zip((bwd_cell.W_input, bwd_cell.W_recur, bwd_cell.b), initial_values))

    if sum_out or concat_out:
        # Single output: pair it with the parameters of both directions
        outputs = [out_ng]
        params_birnn = [params_f + params_b]
    else:
        # out_ng is zipped against [params_f, params_b] below, so it is
        # expected to hold one output per direction
        outputs = out_ng
        params_birnn = [params_f, params_b]

    with ExecutorFactory() as ex:
        # Create every derivative computation first, execute afterwards
        checks = []
        for output, dependents in zip(outputs, params_birnn):
            for variable, value in dependents:
                symbolic = ex.derivative(output, variable, input_placeholder)
                numeric = ex.numeric_derivative(output, variable, delta, input_placeholder)
                checks.append((symbolic, numeric, value))

        for symbolic, numeric, value in checks:
            ng.testing.assert_allclose(symbolic(value, input_value),
                                       numeric(value, input_value),
                                       rtol=num_rtol,
                                       atol=num_atol)
示例#2
0
def test_birnn_fprop(sequence_length, input_size, hidden_size, batch_size,
                     return_sequence, weight_initializer, bias_initializer,
                     init_state, sum_out, concat_out, transformer_factory):
    """Compare the ngraph BiRNN forward pass with the reference implementation.

    The same weights are loaded into ``RefBidirectional`` and into the ngraph
    ``BiRNN``; the ngraph output(s) are then checked against the reference
    output(s) with ``ng.testing.assert_allclose``.

    All arguments are presumably supplied by pytest fixtures or
    parametrization defined outside this block (TODO confirm).
    ``transformer_factory`` is not referenced in the body. ``batch_size`` must
    be 1, and ``init_state`` is replaced by the value ``make_weights`` returns.
    """
    assert batch_size == 1, "the recurrent reference implementation only support batch size 1"

    # Symbolic input plus the concrete array fed to the computation
    input_placeholder, input_value = make_placeholder(input_size, sequence_length, batch_size)

    # Weights and, if requested, the initial state (graph object and its value)
    W_in, W_rec, b, init_state, init_state_value = make_weights(input_placeholder, hidden_size,
                                                                weight_initializer,
                                                                bias_initializer,
                                                                init_state)

    # Reference forward pass
    rnn_ref = RefBidirectional(input_size, hidden_size, return_sequence=return_sequence,
                               sum_out=sum_out, concat_out=concat_out)
    # The bias is reshaped to the shape of the reference forward cell's bias
    ref_bias = b.reshape(rnn_ref.fwd_rnn.bh.shape)
    rnn_ref.set_weights(W_in, W_rec, ref_bias)
    # The reference is fed the input with its first two axes swapped
    ref_input = input_value.transpose([1, 0, 2])
    h_ref_list = rnn_ref.fprop(ref_input, init_states=init_state_value)

    # ngraph layer and its training graph
    rnn_ng = BiRNN(hidden_size, init=W_in, init_inner=W_rec, activation=Tanh(),
                   reset_cells=True, return_sequence=return_sequence,
                   sum_out=sum_out, concat_out=concat_out)
    out_ng = rnn_ng.train_outputs(input_placeholder, init_state=init_state)

    with ExecutorFactory() as ex:
        # The initial state is an extra computation argument only when present
        if init_state is None:
            graph_args = (input_placeholder,)
            feed_values = (input_value,)
        else:
            graph_args = (input_placeholder, init_state)
            feed_values = (input_value, init_state_value)

        fprop_neon_fun = ex.executor(out_ng, *graph_args)
        fprop_neon = fprop_neon_fun(*feed_values)

        # A non-tuple result is a single output; wrap both sides so the
        # comparison loop can treat the two cases uniformly
        if isinstance(fprop_neon, tuple):
            actual_outputs, expected_outputs = fprop_neon, h_ref_list
        else:
            actual_outputs, expected_outputs = [fprop_neon], [h_ref_list]

        for idx, output in enumerate(actual_outputs):
            if return_sequence is True:
                # Keep index 0 of the third axis (batch_size is asserted to be 1)
                output = output[:, :, 0]
            ng.testing.assert_allclose(output, expected_outputs[idx],
                                       rtol=fprop_rtol, atol=fprop_atol)