def test_get_weights(tf_graph, mocker):
    """Check ``RNN.get_weights`` before building, after building, and after training.

    The test expects ``get_weights`` to raise ``UserWarning`` both on a freshly
    constructed model and on a model that has only been built, and to return a
    plain ``dict`` once ``train`` has been run.

    Args:
        tf_graph: Fixture requested for its setup only; it is not referenced in
            the body (presumably prepares the TensorFlow graph -- TODO confirm).
        mocker: pytest-mock fixture used to stub ``RNN.forward_pass``.

    NOTE(review): this file contains another definition named
    ``test_get_weights``; if both live at module level the later one replaces
    the earlier one on import -- confirm which is intended.
    """
    not_ready_msg = 'No weights to return yet -- model has not yet been initialized.'

    params = get_params()
    rnn = RNN(params)

    # Freshly constructed model: requesting weights must raise.
    with pytest.raises(UserWarning) as excinfo:
        rnn.get_weights()
    assert not_ready_msg in str(excinfo.value)

    # Stub forward_pass with a pair of NaN-filled tensors so build() does not
    # need a real forward pass.
    n_batch = params['N_batch']
    n_steps = params['N_steps']
    forward_stub = mocker.patch.object(RNN, 'forward_pass')
    forward_stub.return_value = (
        tf.fill([n_batch, n_steps, params['N_out']], float('nan')),
        tf.fill([n_batch, n_steps, params['N_rec']], float('nan')),
    )
    rnn.build()

    # Built but not trained: the same warning is still expected.
    with pytest.raises(UserWarning) as excinfo:
        rnn.get_weights()
    assert not_ready_msg in str(excinfo.value)

    # Train on batches from an RDM task; afterwards weights must be a dict.
    task = rd.RDM(
        dt=params['dt'],
        tau=params['tau'],
        T=2000,
        N_batch=params['N_batch'],
    )
    batches = task.batch_generator()
    rnn.train(batches)

    weights = rnn.get_weights()
    assert type(weights) is dict
def test_get_weights(tf_graph, mocker):
    """Check that ``RNN.get_weights`` returns a ``dict`` after training.

    The model is built and trained before the assertion. The sibling test of
    the same name in this file shows ``get_weights`` raising ``UserWarning``
    ("No weights to return yet ...") until ``train`` has been run, so asserting
    on an unbuilt, untrained model could not succeed.

    Args:
        tf_graph: Fixture requested for its setup only; it is not referenced in
            the body (presumably prepares the TensorFlow graph -- TODO confirm).
        mocker: pytest-mock fixture used to stub ``RNN.forward_pass``.

    NOTE(review): this file contains another definition named
    ``test_get_weights``; if both live at module level the later one replaces
    the earlier one on import -- confirm which is intended.
    """
    params = get_params()
    rnn = RNN(params)

    # Stub forward_pass with a pair of NaN-filled tensors so build() does not
    # need a real forward pass.
    mocker.patch.object(RNN, 'forward_pass')
    RNN.forward_pass.return_value = (
        tf.fill(
            [params['N_batch'], params['N_steps'], params['N_out']],
            float('nan'),
        ),
        tf.fill(
            [params['N_batch'], params['N_steps'], params['N_rec']],
            float('nan'),
        ),
    )

    # Build and train first: weights are only available afterwards.
    rnn.build()
    rdm = rd.RDM(
        dt=params['dt'],
        tau=params['tau'],
        T=2000,
        N_batch=params['N_batch'],
    )
    rnn.train(rdm.batch_generator())

    assert type(rnn.get_weights()) is dict