def test_linear_keep_batch_axis():
    """Check the output axes of Linear when ``keep_axes`` holds the input axis.

    The input placeholder only has the axis named 'N'; the layer is built
    with the axis named 'A' as ``axes`` and the 'N' axis in ``keep_axes``.
    The result must carry exactly the axes ``[A, N]``.
    """
    axis_a = ng.make_axis(1, name='A')
    axis_n = ng.make_axis(2, name='N')

    inputs = ng.placeholder([axis_n])
    layer = Linear(axes=axis_a, keep_axes=[axis_n], init=UniformInit(1.0, 1.0))
    output = layer(inputs)

    expected_axes = ng.make_axes([axis_a, axis_n])
    assert output.axes == expected_axes
def test_linear_axes_nout():
    """Check the output axes of a Linear layer built with ``nout=3``.

    The input's 'A' axis must not appear in the output, the 'N' axis must be
    kept as a batch axis of length 2, and the sample axes must have lengths
    ``(3,)``.
    """
    in_feature = ng.make_axis(1, name='A')
    in_batch = ng.make_axis(2, name='N')
    inputs = ng.placeholder([in_feature, in_batch])

    layer = Linear(nout=3, init=UniformInit(1.0, 1.0))
    out_axes = layer(inputs).axes

    assert in_feature not in out_axes
    assert in_batch in out_axes
    assert out_axes.batch_axis().length == 2
    assert out_axes.sample_axes().lengths == (3, )
def test_linear_zeros(input_placeholder, output_size):
    """Sanity check: all-zero weights applied to random inputs give all zeros."""
    random_input = np.random.random(input_placeholder.axes.lengths)
    zero_layer = Linear(nout=output_size, init=UniformInit(0.0, 0.0))

    with ExecutorFactory() as factory:
        if factory.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        graph_out = zero_layer(input_placeholder)
        run = factory.executor(graph_out, input_placeholder)
        result = run(random_input)

    # Both extremes being exactly zero means every element is exactly zero.
    assert np.min(result) == 0.0
    assert np.max(result) == 0.0
def test_linear_W_axes_nout():
    """Check the weight axes of a Linear layer built with ``nout=3``.

    After the layer is applied, ``W`` must have no batch axis, must contain
    the input's 'A' axis, and must have exactly one other axis, of length 3.
    """
    in_feature = ng.make_axis(1, name='A')
    in_batch = ng.make_axis(2, name='N')
    inputs = ng.placeholder([in_feature, in_batch])

    layer = Linear(nout=3, init=UniformInit(1.0, 1.0))
    # Only called for its effect of making ``layer.W`` available; the
    # returned op is not needed here.
    layer(inputs)

    weight_axes = layer.W.axes
    assert weight_axes.batch_axis() is None
    assert in_feature in weight_axes

    non_feature_axes = weight_axes - in_feature
    assert len(non_feature_axes) == 1
    assert non_feature_axes[0].length == 3
def test_linear_ones(input_size, input_placeholder, output_size):
    """Sanity check with all-ones inputs and all-ones weights.

    Every output element is then the sum of the weights feeding it, so each
    must equal ``input_size`` exactly. This confirms that the correct number
    of operations is being run.
    """
    ones_input = np.ones(input_placeholder.axes.lengths)
    ones_layer = Linear(nout=output_size, init=UniformInit(1.0, 1.0))

    with ExecutorFactory() as factory:
        if factory.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        graph_out = ones_layer(input_placeholder)
        run = factory.executor([graph_out, ones_layer.W], input_placeholder)
        # The weights are fetched alongside the output but not inspected.
        result, _weights = run(ones_input)

    expected = np.ones(graph_out.axes.lengths) * input_size
    # Zero tolerances: the comparison is exact.
    ng.testing.assert_allclose(expected, result, atol=0.0, rtol=0.0)
def test_linear_invalid_batch_axes():
    """Linear.__init__ must raise ValueError when ``axes`` is an axis named 'N'."""
    # Build the arguments outside the ``raises`` block so that only a
    # ValueError coming from Linear itself can satisfy the test.
    batch_named_axis = ng.make_axis(1, name='N')
    weight_init = UniformInit(1.0, 1.0)

    with pytest.raises(ValueError):
        Linear(axes=batch_named_axis, init=weight_init)
def test_linear_invalid_shadow_axes():
    """Linear.__init__ must raise ValueError when ``axes`` is a shadow axis."""
    # Build the arguments outside the ``raises`` block so that only a
    # ValueError coming from Linear itself can satisfy the test.
    shadow_axis = make_shadow_axis(ng.make_axis(1, name='A'))
    weight_init = UniformInit(1.0, 1.0)

    with pytest.raises(ValueError):
        Linear(axes=shadow_axis, init=weight_init)
def test_linear_accepts_axes_axis():
    """Linear.__init__ must accept a single Axis passed as ``axes``."""
    single_axis = ng.make_axis(1)
    # Passing means construction completes without raising.
    Linear(axes=single_axis, init=UniformInit(1.0, 1.0))