예제 #1
0
def test_save(tf_graph, mocker, tmpdir):
    """Check ``RNN.save`` before building and after training.

    The test first verifies that saving an unbuilt model raises
    ``UserWarning`` with the message "No weights to return".  It then
    patches ``RNN.forward_pass``, builds and trains the model, and checks
    that ``save`` creates the weights file.

    Args:
        tf_graph: Fixture requested for its setup; not used directly here.
        mocker: pytest-mock fixture used to patch ``RNN.forward_pass``.
        tmpdir: pytest temporary-directory fixture (``py.path.local``).
    """
    # NOTE: dirpath() yields a path *next to* tmpdir, not inside it, so the
    # file is not isolated per test and must be removed explicitly.
    weights_file = tmpdir.dirpath("save_weights.npz")
    save_weights_path = str(weights_file)
    params = get_params()
    rnn = RNN(params)

    # Saving before the model has been built must be rejected.
    with pytest.raises(UserWarning) as excinfo:
        rnn.save(save_weights_path)
    assert "No weights to return" in str(excinfo.value)

    # Stub the forward pass with a pair of NaN-filled tensors of shapes
    # [N_batch, N_steps, N_out] and [N_batch, N_steps, N_rec].
    mocker.patch.object(RNN, 'forward_pass')
    RNN.forward_pass.return_value = (
        tf.fill(
            [params['N_batch'], params['N_steps'], params['N_out']],
            float('nan')),
        tf.fill(
            [params['N_batch'], params['N_steps'], params['N_rec']],
            float('nan')),
    )
    rnn.build()
    rdm1 = rd.RDM(dt=params['dt'],
                  tau=params['tau'],
                  T=2000,
                  N_batch=params['N_batch'])
    gen1 = rdm1.batch_generator()
    rnn.train(gen1)

    try:
        assert not weights_file.check(exists=1)
        rnn.save(save_weights_path)
        assert weights_file.check(exists=1)
    finally:
        # Clean up even when an assertion or save() fails, so a leftover
        # file cannot break other tests that expect it to be absent.
        if weights_file.check(exists=1):
            weights_file.remove()
예제 #2
0
def test_save(tf_graph, mocker, tmpdir):
    """Check that ``RNN.save`` writes the weights file after training.

    ``RNN.forward_pass`` is patched to return NaN-filled tensors, the model
    is trained on a ``PerceptualDiscrimination`` batch generator, and the
    test asserts that the weights file is absent before ``save`` and
    present afterwards.

    Args:
        tf_graph: Fixture requested for its setup; not used directly here.
        mocker: pytest-mock fixture used to patch ``RNN.forward_pass``.
        tmpdir: pytest temporary-directory fixture (``py.path.local``).
    """
    # NOTE: dirpath() yields a path *next to* tmpdir, not inside it, so the
    # file is not isolated per test and must be removed explicitly.
    weights_file = tmpdir.dirpath("save_weights.npz")
    save_weights_path = str(weights_file)
    params = get_params()
    rnn = RNN(params)

    # Stub the forward pass with a pair of NaN-filled tensors of shapes
    # [N_batch, N_steps, N_out] and [N_batch, N_steps, N_rec].
    mocker.patch.object(RNN, 'forward_pass')
    RNN.forward_pass.return_value = (
        tf.fill(
            [params['N_batch'], params['N_steps'], params['N_out']],
            float('nan')),
        tf.fill(
            [params['N_batch'], params['N_steps'], params['N_rec']],
            float('nan')),
    )

    pd1 = PerceptualDiscrimination(dt=params['dt'],
                                   tau=params['tau'],
                                   T=2000,
                                   N_batch=params['N_batch'])
    gen1 = pd1.batch_generator()
    rnn.train(gen1)

    try:
        assert not weights_file.check(exists=1)
        rnn.save(save_weights_path)
        assert weights_file.check(exists=1)
    finally:
        # Clean up even when an assertion or save() fails, so a leftover
        # file cannot break other tests that expect it to be absent.
        if weights_file.check(exists=1):
            weights_file.remove()