def test_birnn_deriv_numerical(sequence_length, input_size, hidden_size,
                               batch_size, return_sequence, weight_initializer,
                               bias_initializer, sum_out, concat_out):
    """Check BiRNN analytic parameter gradients against numeric differentiation.

    Builds a bidirectional RNN, then for every weight/bias of the forward and
    backward cells compares ``ex.derivative`` (autodiff) with
    ``ex.numeric_derivative`` (finite differences) at the same initial values.
    """
    # Get input placeholder and numpy array
    input_placeholder, input_value = make_placeholder(input_size,
                                                      sequence_length,
                                                      batch_size)

    # Construct network weights and initial state, if desired
    W_in, W_rec, b, init_state, init_state_value = make_weights(
        input_placeholder, hidden_size, weight_initializer, bias_initializer)

    # Generate ngraph RNN
    rnn_ng = BiRNN(hidden_size, init=W_in, init_inner=W_rec,
                   activation=Tanh(), reset_cells=True,
                   return_sequence=return_sequence,
                   sum_out=sum_out, concat_out=concat_out)

    # fprop ngraph RNN
    out_ng = rnn_ng.train_outputs(input_placeholder)

    # Parameters of the forward and backward cells; both directions share the
    # same initial numpy values (W_in, W_rec, b).
    w_in_f = rnn_ng.fwd_rnn.W_input
    w_rec_f = rnn_ng.fwd_rnn.W_recur
    b_f = rnn_ng.fwd_rnn.b
    w_in_b = rnn_ng.bwd_rnn.W_input
    w_rec_b = rnn_ng.bwd_rnn.W_recur
    b_b = rnn_ng.bwd_rnn.b
    params_f = [(w_in_f, W_in), (w_rec_f, W_rec), (b_f, b)]
    params_b = [(w_in_b, W_in), (w_rec_b, W_rec), (b_b, b)]

    if sum_out or concat_out:
        # Single merged output: every parameter feeds the same output op.
        out_ng = [out_ng]
        params_birnn = [params_f + params_b]
    else:
        # in this case out_ng will be a list (one output per direction)
        params_birnn = [params_f, params_b]

    with ExecutorFactory() as ex:
        # Create derivative computations and execute
        param_updates = list()
        dep_list = list()
        for output, dependents in zip(out_ng, params_birnn):
            for px, _ in dependents:
                update = (ex.derivative(output, px, input_placeholder),
                          ex.numeric_derivative(output, px, delta,
                                                input_placeholder))
                param_updates.append(update)
            dep_list += dependents

        # dep_list was built in the same order as param_updates, so each
        # (symbolic, numeric) pair lines up with its parameter's initial value.
        for (deriv_s, deriv_n), (_, val) in zip(param_updates, dep_list):
            ng.testing.assert_allclose(deriv_s(val, input_value),
                                       deriv_n(val, input_value),
                                       rtol=num_rtol,
                                       atol=num_atol)
def test_birnn_fprop(sequence_length, input_size, hidden_size, batch_size,
                     return_sequence, weight_initializer, bias_initializer,
                     init_state, sum_out, concat_out, transformer_factory):
    """Compare BiRNN forward propagation against the numpy reference model."""
    assert batch_size == 1, "the recurrent reference implementation only support batch size 1"

    # Input placeholder plus the matching numpy tensor
    input_placeholder, input_value = make_placeholder(input_size,
                                                      sequence_length,
                                                      batch_size)

    # Shared weights and (optional) initial hidden state
    W_in, W_rec, b, init_state, init_state_value = make_weights(
        input_placeholder, hidden_size, weight_initializer, bias_initializer,
        init_state)

    # Reference numpy bidirectional RNN
    rnn_ref = RefBidirectional(input_size, hidden_size,
                               return_sequence=return_sequence,
                               sum_out=sum_out, concat_out=concat_out)
    rnn_ref.set_weights(W_in, W_rec, b.reshape(rnn_ref.fwd_rnn.bh.shape))
    h_ref_list = rnn_ref.fprop(input_value.transpose([1, 0, 2]),
                               init_states=init_state_value)

    # ngraph BiRNN under test
    rnn_ng = BiRNN(hidden_size, init=W_in, init_inner=W_rec,
                   activation=Tanh(), reset_cells=True,
                   return_sequence=return_sequence,
                   sum_out=sum_out, concat_out=concat_out)
    out_ng = rnn_ng.train_outputs(input_placeholder, init_state=init_state)

    with ExecutorFactory() as ex:
        # Build the forward computation, feeding the initial state only when
        # one was requested, then execute it.
        if init_state is None:
            fprop_neon = ex.executor(out_ng, input_placeholder)(input_value)
        else:
            fprop_neon = ex.executor(out_ng, input_placeholder,
                                     init_state)(input_value,
                                                 init_state_value)

        # Normalize both sides to parallel sequences so outputs pair up
        if not isinstance(fprop_neon, tuple):
            fprop_neon = [fprop_neon]
            h_ref_list = [h_ref_list]

        # Compare each ngraph output with its reference counterpart
        for neon_out, ref_out in zip(fprop_neon, h_ref_list):
            if return_sequence is True:
                neon_out = neon_out[:, :, 0]
            ng.testing.assert_allclose(neon_out, ref_out,
                                       rtol=fprop_rtol, atol=fprop_atol)