def test_seq2seq_deriv_ref(batch_size, sequence_length_enc, sequence_length_dec,
                           input_size, hidden_size, weight_initializer, bias_initializer,
                           transformer_factory):

    # TODO: are these assumptions true?
    assert batch_size == 1, "the seq2seq reference implementation only supports batch size 1"

    # Get input placeholders and numpy arrays
    input_placeholder_enc, input_value_enc = \
        make_placeholder(input_size, sequence_length_enc, batch_size)
    input_placeholder_dec, input_value_dec = \
        make_placeholder(input_size, sequence_length_dec, batch_size)

    # Construct encoder weights
    W_in_enc, W_rec_enc, b_enc, _, _ = make_weights(input_placeholder_enc, hidden_size,
                                                    weight_initializer, bias_initializer,
                                                    init_state=False)

    # Construct decoder weights
    W_in_dec, W_rec_dec, b_dec, _, _ = make_weights(input_placeholder_dec, hidden_size,
                                                    weight_initializer, bias_initializer,
                                                    init_state=False)

    # Reference numpy seq2seq
    seq2seq_ref = RefSeq2Seq(input_size, hidden_size, decoder_return_sequence=True)
    seq2seq_ref.set_weights(W_in_enc, W_rec_enc, b_enc.reshape(seq2seq_ref.bh_enc.shape),
                            W_in_dec, W_rec_dec, b_dec.reshape(seq2seq_ref.bh_dec.shape))

    # Prepare deltas for gradient check
    output_shape = (hidden_size, sequence_length_dec, batch_size)

    # generate random deltas tensor
    deltas = np.random.randn(*output_shape)

    # the reference code expects these shapes:
    #   input_shape: (seq_len, input_size, batch_size)
    #   output_shape: (seq_len, hidden_size, batch_size)
    dW_in_enc, dW_rec_enc, db_enc, dW_in_dec, dW_rec_dec, db_dec, encoding_ref, hs_return_dec = \
        seq2seq_ref.lossFun(input_value_enc.transpose([1, 0, 2]),
                            input_value_dec.transpose([1, 0, 2]),
                            deltas.copy().transpose([1, 0, 2]))

    # Generate ngraph Seq2Seq
    rnn_enc_ng = Recurrent(hidden_size, init=W_in_enc, init_inner=W_rec_enc,
                           activation=Tanh(), reset_cells=True, return_sequence=False)
    rnn_dec_ng = Recurrent(hidden_size, init=W_in_dec, init_inner=W_rec_dec,
                           activation=Tanh(), reset_cells=True, return_sequence=True)

    # ngraph fprop graph: the decoder is initialized with the encoder's final hidden state
    encoding_ng = rnn_enc_ng(input_placeholder_enc, init_state=None)
    output_ng = rnn_dec_ng(input_placeholder_dec, init_state=encoding_ng)

    deltas_constant = ng.constant(deltas, axes=output_ng.axes)

    params = [(rnn_dec_ng.b, db_dec),
              (rnn_dec_ng.W_input, dW_in_dec),
              (rnn_dec_ng.W_recur, dW_rec_dec),
              (rnn_enc_ng.b, db_enc),
              (rnn_enc_ng.W_input, dW_in_enc),
              (rnn_enc_ng.W_recur, dW_rec_enc)]

    with ExecutorFactory() as ex:
        # fprop computations
        fprop_fun = ex.executor([encoding_ng, output_ng],
                                input_placeholder_enc, input_placeholder_dec)

        # gradient computations
        update_funs = []
        for px, _ in params:
            update = ng.deriv(output_ng, px, error=deltas_constant)
            update_funs.append(
                ex.executor(update, input_placeholder_enc, input_placeholder_dec))

        # check forward pass
        encoding, output = fprop_fun(input_value_enc, input_value_dec)
        ng.testing.assert_allclose(encoding, encoding_ref)
        ng.testing.assert_allclose(np.squeeze(output), np.squeeze(hs_return_dec))

        # check gradient computations
        for update_fun, (_, deriv_ref_val) in zip(update_funs, params):
            grad_neon = update_fun(input_value_enc, input_value_dec)
            ng.testing.assert_allclose(grad_neon,
                                       deriv_ref_val.squeeze(),
                                       rtol=bprop_rtol, atol=1e-4)
def test_rnn_deriv_ref(sequence_length, input_size, hidden_size, batch_size,
                       return_sequence, weight_initializer, bias_initializer,
                       init_state, transformer_factory):

    assert batch_size == 1, "the recurrent reference implementation only supports batch size 1"
    assert return_sequence is True, "the reference rnn only supports sequences for deriv"

    # Get input placeholder and numpy array
    input_placeholder, input_value = make_placeholder(input_size, sequence_length, batch_size)

    # Construct network weights and initial state, if desired
    W_in, W_rec, b, init_state, init_state_value = make_weights(input_placeholder, hidden_size,
                                                                weight_initializer,
                                                                bias_initializer,
                                                                init_state)

    # Compute reference numpy RNN
    rnn_ref = RefRecurrent(input_size, hidden_size, return_sequence=return_sequence)
    rnn_ref.set_weights(W_in, W_rec, b.reshape(rnn_ref.bh.shape))

    # Prepare deltas for gradient check
    output_shape = (hidden_size, sequence_length, batch_size)

    # generate random deltas tensor
    deltas = np.random.randn(*output_shape)

    # the reference code expects these shapes:
    #   input_shape: (seq_len, input_size, batch_size)
    #   output_shape: (seq_len, hidden_size, batch_size)
    dW_in, dW_rec, db = rnn_ref.lossFun(input_value.transpose([1, 0, 2]),
                                        deltas.copy().transpose([1, 0, 2]),
                                        init_states=init_state_value)[:3]

    # Generate ngraph RNN
    rnn_ng = RNNCell(hidden_size, init=W_in, init_h2h=W_rec,
                     activation=Tanh(), reset_cells=True)

    # fprop ngraph RNN
    num_steps = input_placeholder.axes.recurrent_axis().length
    init_states = {'h': init_state} if init_state is not None else init_state
    out_ng = unroll(rnn_ng, num_steps, input_placeholder,
                    init_states=init_states, return_sequence=return_sequence)

    deltas_constant = ng.constant(deltas, axes=out_ng.axes)

    params = [(rnn_ng.i2h.linear.W, W_in),
              (rnn_ng.h2h.W, W_rec),
              (rnn_ng.i2h.bias.W, b)]

    with ExecutorFactory() as ex:
        # Create derivative computations and execute
        param_updates = list()
        for px, _ in params:
            update = ng.deriv(out_ng, px, error=deltas_constant)
            if init_state is not None:
                param_updates.append(ex.executor(update, input_placeholder, init_state))
            else:
                param_updates.append(ex.executor(update, input_placeholder))

        # Compare the ngraph gradients against the numpy reference gradients
        for update_fun, ref_val in zip(param_updates, [dW_in, dW_rec, db]):
            if init_state is not None:
                grad_neon = update_fun(input_value, init_state_value)
            else:
                grad_neon = update_fun(input_value)
            ng.testing.assert_allclose(grad_neon,
                                       ref_val.squeeze(),
                                       rtol=bprop_rtol,
                                       atol=bprop_atol)
def test_recurrent_batchnorm_bprop(RNN, recurrent_input, output_size,
                                   bn_params, transformer_factory):
    """Compare bprop of an RNN with batch norm against a numpy batch norm
    followed by an RNN without batch norm."""

    helper = RNNHelper(recurrent_input, output_size, RNN, bn_params)

    # Get rnn + batch norm fprop graph
    fprop = helper.rnn(recurrent_input)
    bprop_vars = [recurrent_input, helper.gamma, helper.beta]

    # Get bprop graph
    delta_placeholder = ng.placeholder(fprop.axes)
    bprops = [ng.deriv(fprop, var, delta_placeholder) for var in bprop_vars]

    # Get reference graphs
    reference_fprop = helper.reference_rnn(helper.reference_input)

    # Handle the case where we have gates in the RNN object
    bprop_vars = [helper.reference_input]
    if helper.has_gates:
        bprop_vars.append(helper.get_ancestor_op(reference_fprop))

    reference_delta_placeholder = ng.placeholder(reference_fprop.axes)
    reference_bprop = [ng.deriv(reference_fprop, var, reference_delta_placeholder)
                       for var in bprop_vars]

    # Begin execution
    with ExecutorFactory() as ex:
        bprop_function = ex.executor(bprops, recurrent_input, delta_placeholder)
        reference_function = ex.executor(reference_bprop, helper.reference_input,
                                         reference_delta_placeholder)

        # Create data
        input_value = rng.uniform(0, 1, recurrent_input.axes)
        delta = rng.uniform(-.1, .1, fprop.axes)

        # Compute reference weighted input
        weighted_input = np.dot(helper.W_in, input_value.swapaxes(0, 1))

        # Set the reduction axes used for reference
        bn_params['axis'] = (1, 2)

        # Get reference batch normed input
        batch_norm_reference = BatchNormReference(weighted_input, **bn_params)
        normed_input = batch_norm_reference.fprop[0]

        # Reference backprop through RNN
        reference_result = reference_function(normed_input, delta)
        # This is because of a HETR bug where return collections aren't handled properly
        if isinstance(reference_result, tuple):
            rnn_delta = reference_result[0]
        else:
            rnn_delta = reference_result

        # Reference backprop through BN
        dx_ref, dgamma_ref, dbeta_ref = batch_norm_reference.bprop(rnn_delta)

        # Backprop through reference batch norm for a single gate
        if helper.has_gates:
            rnn_gate_delta = reference_result[1]
            _, dgamma_ref, dbeta_ref = batch_norm_reference.bprop(rnn_gate_delta)

        # Backprop through weighted input
        dx_ref = np.dot(helper.W_in.T, dx_ref.swapaxes(0, 1))

        # Compute ngraph bprop
        dx, dgamma, dbeta = bprop_function(input_value, delta)

        ng.testing.assert_allclose(dx, dx_ref, rtol=rtol, atol=recurrent_atol)
        ng.testing.assert_allclose(dgamma, dgamma_ref, rtol=rtol, atol=recurrent_atol)
        ng.testing.assert_allclose(dbeta, dbeta_ref, rtol=rtol, atol=recurrent_atol)