Example #1
def test_seq2seq_deriv_ref(batch_size, sequence_length_enc,
                           sequence_length_dec, input_size, hidden_size,
                           weight_initializer, bias_initializer,
                           transformer_factory):

    # TODO: are these assumptions true?
    assert batch_size == 1, "the seq2seq reference implementation only supports batch size 1"

    # Get input placeholders and numpy arrays
    input_placeholder_enc, input_value_enc = \
        make_placeholder(input_size, sequence_length_enc, batch_size)
    input_placeholder_dec, input_value_dec = \
        make_placeholder(input_size, sequence_length_dec, batch_size)

    # Construct encoder weights
    W_in_enc, W_rec_enc, b_enc, _, _ = make_weights(input_placeholder_enc,
                                                    hidden_size,
                                                    weight_initializer,
                                                    bias_initializer,
                                                    init_state=False)

    # Construct decoder weights
    W_in_dec, W_rec_dec, b_dec, _, _ = make_weights(input_placeholder_dec,
                                                    hidden_size,
                                                    weight_initializer,
                                                    bias_initializer,
                                                    init_state=False)

    # Reference numpy seq2seq
    seq2seq_ref = RefSeq2Seq(input_size,
                             hidden_size,
                             decoder_return_sequence=True)
    seq2seq_ref.set_weights(W_in_enc, W_rec_enc,
                            b_enc.reshape(seq2seq_ref.bh_enc.shape), W_in_dec,
                            W_rec_dec, b_dec.reshape(seq2seq_ref.bh_dec.shape))

    # Prepare deltas for gradient check
    output_shape = (hidden_size, sequence_length_dec, batch_size)

    # generate random deltas tensor
    deltas = np.random.randn(*output_shape)

    # the reference code expects these shapes:
    # input_shape: (seq_len, input_size, batch_size)
    # output_shape: (seq_len, hidden_size, batch_size)
    dW_in_enc, dW_rec_enc, db_enc, dW_in_dec, dW_rec_dec, db_dec, encoding_ref, hs_return_dec = \
        seq2seq_ref.lossFun(input_value_enc.transpose([1, 0, 2]),
                            input_value_dec.transpose([1, 0, 2]),
                            deltas.copy().transpose([1, 0, 2]))

    # Generate ngraph Seq2Seq
    rnn_enc_ng = Recurrent(hidden_size,
                           init=W_in_enc,
                           init_inner=W_rec_enc,
                           activation=Tanh(),
                           reset_cells=True,
                           return_sequence=False)
    rnn_dec_ng = Recurrent(hidden_size,
                           init=W_in_dec,
                           init_inner=W_rec_dec,
                           activation=Tanh(),
                           reset_cells=True,
                           return_sequence=True)

    # ngraph fprop graph
    encoding_ng = rnn_enc_ng(input_placeholder_enc, init_state=None)
    output_ng = rnn_dec_ng(input_placeholder_dec, init_state=encoding_ng)

    deltas_constant = ng.constant(deltas, axes=output_ng.axes)
    params = [(rnn_dec_ng.b, db_dec), (rnn_dec_ng.W_input, dW_in_dec),
              (rnn_dec_ng.W_recur, dW_rec_dec), (rnn_enc_ng.b, db_enc),
              (rnn_enc_ng.W_input, dW_in_enc),
              (rnn_enc_ng.W_recur, dW_rec_enc)]

    with ExecutorFactory() as ex:

        # fprop computations
        fprop_fun = ex.executor([encoding_ng, output_ng],
                                input_placeholder_enc, input_placeholder_dec)

        # gradient computations
        update_funs = []
        for px, _ in params:
            update = ng.deriv(output_ng, px, error=deltas_constant)
            update_funs.append(
                ex.executor(update, input_placeholder_enc,
                            input_placeholder_dec))

        # check forward pass
        encoding, output = fprop_fun(input_value_enc, input_value_dec)
        ng.testing.assert_allclose(encoding, encoding_ref)
        ng.testing.assert_allclose(np.squeeze(output),
                                   np.squeeze(hs_return_dec))

        # check gradient computations
        for update_fun, (_, deriv_ref_val) in zip(update_funs, params):
            grad_neon = update_fun(input_value_enc, input_value_dec)
            ng.testing.assert_allclose(grad_neon,
                                       deriv_ref_val.squeeze(),
                                       rtol=bprop_rtol,
                                       atol=1e-4)
Example #2
def test_rnn_deriv_ref(sequence_length, input_size, hidden_size, batch_size,
                       return_sequence, weight_initializer, bias_initializer,
                       init_state, transformer_factory):

    assert batch_size == 1, "the recurrent reference implementation only supports batch size 1"
    assert return_sequence is True, "the reference rnn only supports sequences for deriv"

    # Get input placeholder and numpy array
    input_placeholder, input_value = make_placeholder(input_size, sequence_length, batch_size)

    # Construct network weights and initial state, if desired
    W_in, W_rec, b, init_state, init_state_value = make_weights(input_placeholder, hidden_size,
                                                                weight_initializer,
                                                                bias_initializer,
                                                                init_state)

    # Compute reference numpy RNN
    rnn_ref = RefRecurrent(input_size, hidden_size, return_sequence=return_sequence)
    rnn_ref.set_weights(W_in, W_rec, b.reshape(rnn_ref.bh.shape))

    # Prepare deltas for gradient check
    output_shape = (hidden_size, sequence_length, batch_size)

    # generate random deltas tensor
    deltas = np.random.randn(*output_shape)

    # the reference code expects these shapes:
    # input_shape: (seq_len, input_size, batch_size)
    # output_shape: (seq_len, hidden_size, batch_size)
    dW_in, dW_rec, db = rnn_ref.lossFun(input_value.transpose([1, 0, 2]),
                                        deltas.copy().transpose([1, 0, 2]),
                                        init_states=init_state_value)[:3]

    # Generate ngraph RNN
    rnn_ng = RNNCell(hidden_size, init=W_in, init_h2h=W_rec, activation=Tanh(),
                     reset_cells=True)

    # fprop ngraph RNN
    num_steps = input_placeholder.axes.recurrent_axis().length
    init_states = {'h': init_state} if init_state is not None else None
    out_ng = unroll(rnn_ng, num_steps, input_placeholder, init_states=init_states,
                    return_sequence=return_sequence)
    deltas_constant = ng.constant(deltas, axes=out_ng.axes)
    params = [(rnn_ng.i2h.linear.W, W_in),
              (rnn_ng.h2h.W, W_rec),
              (rnn_ng.i2h.bias.W, b)]

    with ExecutorFactory() as ex:
        # Create derivative computations and execute
        param_updates = list()
        for px, _ in params:
            update = ng.deriv(out_ng, px, error=deltas_constant)
            if init_state is not None:
                param_updates.append(ex.executor(update, input_placeholder, init_state))
            else:
                param_updates.append(ex.executor(update, input_placeholder))

        for update_fun, ref_val in zip(param_updates, [dW_in, dW_rec, db]):
            if init_state is not None:
                grad_neon = update_fun(input_value, init_state_value)
            else:
                grad_neon = update_fun(input_value)
            ng.testing.assert_allclose(grad_neon,
                                       ref_val.squeeze(),
                                       rtol=bprop_rtol, atol=bprop_atol)
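`RefRecurrent.lossFun` supplies the reference gradients above via backprop through time. A rough numpy sketch of BPTT for a vanilla tanh RNN under the same (seq_len, feature, batch) layout, ignoring initial states and the return_sequence flag; illustrative only, not the actual reference code:

import numpy as np

def rnn_bptt(xs, deltas, W_in, W_rec, b):
    # xs: (seq_len, input_size, 1), deltas: (seq_len, hidden_size, 1)
    hidden_size = W_rec.shape[0]
    h = np.zeros((hidden_size, 1))
    hs, hs_prev = [], []
    # Forward pass, remembering hidden states for the backward pass
    for x in xs:
        hs_prev.append(h)
        h = np.tanh(np.dot(W_in, x) + np.dot(W_rec, h) + b)
        hs.append(h)
    dW_in = np.zeros_like(W_in)
    dW_rec = np.zeros_like(W_rec)
    db = np.zeros_like(b)
    dh_next = np.zeros_like(h)
    # Backward pass: accumulate gradients from the last step to the first
    for t in reversed(range(len(xs))):
        dh = deltas[t] + dh_next
        dpre = (1.0 - hs[t] ** 2) * dh        # tanh'(pre) = 1 - h^2
        dW_in += np.dot(dpre, xs[t].T)
        dW_rec += np.dot(dpre, hs_prev[t].T)
        db += dpre
        dh_next = np.dot(W_rec.T, dpre)
    return dW_in, dW_rec, db

The test compares gradients of this form against `ng.deriv` taken through the unrolled `RNNCell` graph.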
Example #3
def test_recurrent_batchnorm_bprop(RNN, recurrent_input, output_size,
                                   bn_params, transformer_factory):
    """Compare bprop gated RNN with batch norm to numpy batch norm followed by rnn without"""

    helper = RNNHelper(recurrent_input, output_size, RNN, bn_params)

    # Get rnn + batch norm bprop graph
    fprop = helper.rnn(recurrent_input)
    bprop_vars = [recurrent_input, helper.gamma, helper.beta]

    # Get bprop graph
    delta_placeholder = ng.placeholder(fprop.axes)
    bprops = [ng.deriv(fprop, var, delta_placeholder) for var in bprop_vars]

    # Get reference graphs
    reference_fprop = helper.reference_rnn(helper.reference_input)

    # Handle the case where we have gates in the RNN object
    bprop_vars = [helper.reference_input]
    if helper.has_gates:
        bprop_vars.append(helper.get_ancestor_op(reference_fprop))

    reference_delta_placeholder = ng.placeholder(reference_fprop.axes)
    reference_bprop = [
        ng.deriv(reference_fprop, var, reference_delta_placeholder)
        for var in bprop_vars
    ]

    # Begin execution
    with ExecutorFactory() as ex:
        bprop_function = ex.executor(bprops, recurrent_input,
                                     delta_placeholder)
        reference_function = ex.executor(reference_bprop,
                                         helper.reference_input,
                                         reference_delta_placeholder)

        # Create data
        input_value = rng.uniform(0, 1, recurrent_input.axes)
        delta = rng.uniform(-.1, .1, fprop.axes)

        # Compute reference weighted input
        weighted_input = np.dot(helper.W_in, input_value.swapaxes(0, 1))

        # Set the reduction axes used for reference
        bn_params['axis'] = (1, 2)

        # Get reference batch normed input
        batch_norm_reference = BatchNormReference(weighted_input, **bn_params)
        normed_input = batch_norm_reference.fprop[0]

        # Reference backprop through RNN
        reference_result = reference_function(normed_input, delta)
        # Work around a HETR bug where returned collections aren't handled properly
        if isinstance(reference_result, tuple):
            rnn_delta = reference_result[0]
        else:
            rnn_delta = reference_result

        # Reference backprop through BN
        dx_ref, dgamma_ref, dbeta_ref = batch_norm_reference.bprop(rnn_delta)

        # Backprop through reference batch norm for a single gate
        if helper.has_gates:
            rnn_gate_delta = reference_result[1]
            _, dgamma_ref, dbeta_ref = batch_norm_reference.bprop(
                rnn_gate_delta)

        # Backprop through weighted input
        dx_ref = np.dot(helper.W_in.T, dx_ref.swapaxes(0, 1))

        # Compute ngraph bprop
        dx, dgamma, dbeta = bprop_function(input_value, delta)

        ng.testing.assert_allclose(dx, dx_ref, rtol=rtol, atol=recurrent_atol)
        ng.testing.assert_allclose(dgamma,
                                   dgamma_ref,
                                   rtol=rtol,
                                   atol=recurrent_atol)
        ng.testing.assert_allclose(dbeta,
                                   dbeta_ref,
                                   rtol=rtol,
                                   atol=recurrent_atol)
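`RNNHelper` and `BatchNormReference` come from the same test module. The batch norm reference boils down to the standard normalization and its gradient; a compact numpy sketch over an arbitrary reduction axis (function names and the eps default are illustrative, not the fixture's API):

import numpy as np

def batch_norm_fprop(x, gamma, beta, axis, eps=1e-3):
    # Normalize over the given reduction axes, then scale and shift
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    xhat = (x - mean) / np.sqrt(var + eps)
    return gamma * xhat + beta, (xhat, var, eps)

def batch_norm_bprop(delta, gamma, cache, axis):
    xhat, var, eps = cache
    # Number of elements reduced over per statistic
    m = np.prod([delta.shape[a] for a in np.atleast_1d(axis)])
    dgamma = (delta * xhat).sum(axis=axis, keepdims=True)
    dbeta = delta.sum(axis=axis, keepdims=True)
    dxhat = delta * gamma
    # Standard batch norm input gradient
    dx = (1.0 / m) / np.sqrt(var + eps) * (
        m * dxhat
        - dxhat.sum(axis=axis, keepdims=True)
        - xhat * (dxhat * xhat).sum(axis=axis, keepdims=True))
    return dx, dgamma, dbeta

In the test above, `bn_params['axis'] = (1, 2)` makes the reference reduce over the time and batch dimensions of the weighted input before comparing against the ngraph gradients.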