コード例 #1
0
def test_get_weights(tf_graph, mocker):
    """Check that ``get_weights`` warns until the model is trained, then returns a dict.

    Three stages are exercised: right after construction, after ``build()``
    with ``RNN.forward_pass`` mocked out, and after a ``train()`` run on an
    ``rd.RDM`` task generator.
    """
    # Message carried by the UserWarning that get_weights raises before
    # any weights exist.
    not_initialized_msg = (
        'No weights to return yet -- model has not yet been initialized.')

    params = get_params()
    rnn = RNN(params)

    # Freshly constructed model: no weights available yet.
    with pytest.raises(UserWarning) as excinfo:
        rnn.get_weights()
    assert not_initialized_msg in str(excinfo.value)

    # Replace forward_pass with a mock returning NaN-filled tensors so that
    # build() does not run the real forward computation. The pair looks like
    # (outputs, recurrent states) given the N_out / N_rec shapes -- confirm
    # against RNN.forward_pass.
    nan_outputs = tf.fill(
        [params['N_batch'], params['N_steps'], params['N_out']], float('nan'))
    nan_states = tf.fill(
        [params['N_batch'], params['N_steps'], params['N_rec']], float('nan'))
    mocker.patch.object(RNN, 'forward_pass',
                        return_value=(nan_outputs, nan_states))
    rnn.build()

    # Built but not yet trained: still no weights.
    with pytest.raises(UserWarning) as excinfo:
        rnn.get_weights()
    assert not_initialized_msg in str(excinfo.value)

    # Train briefly on an RDM task (presumably random-dot-motion); afterwards
    # get_weights is expected to succeed.
    rdm1 = rd.RDM(dt=params['dt'],
                  tau=params['tau'],
                  T=2000,
                  N_batch=params['N_batch'])
    gen1 = rdm1.batch_generator()
    rnn.train(gen1)
    assert isinstance(rnn.get_weights(), dict)
コード例 #2
0
ファイル: test_rnn.py プロジェクト: yulkang/PsychRNN
def test_get_weights(tf_graph, mocker):
    """Check that ``get_weights`` on a fresh RNN returns exactly a ``dict``.

    ``RNN.forward_pass`` is patched with a mock returning NaN-filled tensors
    before the call.
    """
    params = get_params()
    rnn = RNN(params)

    # Two NaN tensors shaped with N_out and N_rec respectively, handed back
    # by the mocked forward_pass.
    batch, steps = params['N_batch'], params['N_steps']
    fake_forward = (
        tf.fill([batch, steps, params['N_out']], float('nan')),
        tf.fill([batch, steps, params['N_rec']], float('nan')),
    )
    mocker.patch.object(RNN, 'forward_pass', return_value=fake_forward)

    # NOTE(review): no build()/train() happens here, so this relies on
    # get_weights succeeding on a freshly constructed RNN -- confirm against
    # the RNN version under test.
    weights = rnn.get_weights()
    assert type(weights) is dict