Beispiel #1
0
def test_train_iters(tf_graph, mocker, capfd):
    """Training stops after the epoch that covers ``training_iters``.

    Runs training twice, with ``training_iters`` equal to one batch and to one
    batch plus one. The verbose progress output must mention the iteration
    count of the last epoch (a multiple of ``N_batch``) and must not mention
    the iteration count of the epoch after it.
    """
    params = get_params()
    batch = params['N_batch']

    task = PerceptualDiscrimination(dt=params['dt'],
                                    tau=params['tau'],
                                    T=2000,
                                    N_batch=batch)
    generator = task.batch_generator()

    for extra in [0, 1]:
        train_params = {
            # Number of iterations to train for. Default: 10000
            'training_iters': batch + extra,
            # Compute and record loss every 'loss_epoch' epochs. Default: 10
            'loss_epoch': 1,
            'verbosity': True,
        }

        model = RNN(params)

        # Mock the forward pass so training runs without a real network.
        # It returns NaN-filled tensors of shape (N_batch, N_steps, N_out)
        # and (N_batch, N_steps, N_rec).
        out_shape = [batch, params['N_steps'], params['N_out']]
        rec_shape = [batch, params['N_steps'], params['N_rec']]
        mocker.patch.object(RNN, 'forward_pass')
        RNN.forward_pass.return_value = (tf.fill(out_shape, float('nan')),
                                         tf.fill(rec_shape, float('nan')))

        model.train(generator, train_params)
        model.destruct()

        printed, _ = capfd.readouterr()
        # Epochs actually run: training_iters rounded up to whole batches.
        epochs = int(ceil(train_params['training_iters'] / (batch * 1.0)))
        last_iter = epochs * batch
        next_iter = (epochs + 1) * batch
        assert "Iter " + str(last_iter) in printed
        assert "Iter " + str(next_iter) not in printed


def test_custom_loss(tf_graph, mocker):
    """A custom loss named in ``params`` must also be supplied in ``params``.

    Setting ``loss_function`` to ``'my_mean_squared_error'`` without also
    providing a callable under that key must make ``RNN.build`` raise a
    ``UserWarning`` that mentions the missing name. Once the callable is
    supplied, ``build`` must succeed.
    """
    params = get_params()
    params['loss_function'] = 'my_mean_squared_error'

    def _mock_forward_pass():
        # Replace RNN.forward_pass with a mock that returns NaN-filled tensors
        # of shape (N_batch, N_steps, N_out) and (N_batch, N_steps, N_rec).
        # This is called once per RNN so the tensors are created in the graph
        # that is current at that moment. NOTE(review): this presumably
        # matters because destruct() may reset the default graph; confirm.
        mocker.patch.object(RNN, 'forward_pass')
        RNN.forward_pass.return_value = (
            tf.fill([params['N_batch'], params['N_steps'], params['N_out']],
                    float('nan')),
            tf.fill([params['N_batch'], params['N_steps'], params['N_rec']],
                    float('nan')),
        )

    # Without the callable, building must fail and name the missing loss.
    rnn = RNN(params)
    _mock_forward_pass()
    with pytest.raises(UserWarning) as excinfo:
        rnn.build()
    assert 'my_mean_squared_error' in str(excinfo.value)
    rnn.destruct()

    # With the callable supplied under the same key, building succeeds.
    params['my_mean_squared_error'] = mean_squared_error
    rnn = RNN(params)
    _mock_forward_pass()
    rnn.build()
    # Release the model like the other tests do, so no state leaks.
    rnn.destruct()
Beispiel #3
0
def test_destruct(tf_graph, mocker):
    """RNN objects can be created, built and destructed one after another.

    Each new ``RNN`` is constructed only after the previous one has been
    destructed, so construction must keep working after ``destruct``.
    """
    params = get_params()

    # An RNN that was never built can be destructed.
    rnn1 = RNN(params)
    rnn1.destruct()

    # A built RNN can be destructed. forward_pass is mocked so that it
    # returns NaN-filled tensors instead of running a real network.
    rnn2 = RNN(params)
    mocker.patch.object(RNN, 'forward_pass')
    RNN.forward_pass.return_value = (
        tf.fill([params['N_batch'], params['N_steps'], params['N_out']],
                float('nan')),
        tf.fill([params['N_batch'], params['N_steps'], params['N_rec']],
                float('nan')),
    )
    rnn2.build()
    rnn2.destruct()

    # Construction still works after a built model was destructed.
    # Clean this one up too, so no state leaks into later tests.
    rnn3 = RNN(params)
    rnn3.destruct()
Beispiel #4
0
def test_load_weights_path_rnn(tf_graph, mocker, tmpdir, capfd):
    """Round-trip weights through ``save_weights_path`` / ``load_weights_path``.

    Steps:
    1. Train a model with a mocked forward pass; training saves its weights
       to an ``.npz`` file and prints nothing when ``verbosity`` is False.
    2. Constructing an ``RNN`` with a nonexistent ``load_weights_path`` must
       raise ``EnvironmentError``.
    3. Constructing an ``RNN`` from the saved file must succeed.
    """
    params = get_params()

    pd1 = PerceptualDiscrimination(dt=params['dt'],
                                   tau=params['tau'],
                                   T=2000,
                                   N_batch=params['N_batch'])
    gen1 = pd1.batch_generator()

    # Mock the forward pass so it returns NaN-filled tensors of shape
    # (N_batch, N_steps, N_out) and (N_batch, N_steps, N_rec).
    mocker.patch.object(RNN, 'forward_pass')
    RNN.forward_pass.return_value = tf.fill(
        [params['N_batch'], params['N_steps'], params['N_out']],
        float('nan')), tf.fill(
            [params['N_batch'], params['N_steps'], params['N_rec']],
            float('nan'))

    rnn = RNN(params)
    rnn.build()

    # Keep the weights file inside this test's own temporary directory.
    # tmpdir.dirpath(...) would resolve against the parent directory, which
    # every test in the session shares, so leftovers could leak between tests.
    weights_file = tmpdir.join("save_weights.npz")

    train_params = {}
    # Where to save the model after training. Default: None
    train_params['save_weights_path'] = str(weights_file)
    train_params['verbosity'] = False

    ### Save out some weights to test with and destroy the rnn that created them
    assert not weights_file.check(exists=1)
    rnn.train(gen1, train_params)

    assert rnn.is_initialized is True
    out, _ = capfd.readouterr()
    # With verbosity off, training must not print anything.
    assert out == ""
    assert weights_file.check(exists=1)
    rnn.destruct()

    ### Make sure loading weights fails on nonexistent file
    params['load_weights_path'] = "nonexistent"
    with pytest.raises(EnvironmentError) as excinfo:
        rnn = RNN(params)
    assert "No such file" in str(excinfo.value)
    # The constructor raised, so `rnn` still refers to the model destructed
    # above. This call is presumably meant to clean up whatever the failed
    # construction left behind. NOTE(review): relies on destruct() being
    # safe to call twice on the same object; confirm.
    rnn.destruct()

    ### Ensure success when loading weights created previously
    params['load_weights_path'] = str(weights_file)
    rnn = RNN(params)
    rnn.destruct()

    weights_file.remove()