def test_linear_keep_batch_axis():
    """Applying Linear with ``keep_axes`` yields axes equal to (feature axis, batch axis)."""
    out_axis = ng.make_axis(1, name='A')
    n_axis = ng.make_axis(2, name='N')
    inputs = ng.placeholder([n_axis])
    layer = Linear(axes=out_axis, keep_axes=[n_axis], init=UniformInit(1.0, 1.0))
    result = layer(inputs)
    expected_axes = ng.make_axes([out_axis, n_axis])
    assert result.axes == expected_axes
def test_linear_axes_nout():
    """Linear(nout=3) drops the input feature axis, keeps the batch axis, and adds a length-3 sample axis."""
    in_axis = ng.make_axis(1, name='A')
    n_axis = ng.make_axis(2, name='N')
    inputs = ng.placeholder([in_axis, n_axis])
    out = Linear(nout=3, init=UniformInit(1.0, 1.0))(inputs)
    # Input feature axis is consumed; batch axis passes through unchanged.
    assert in_axis not in out.axes
    assert n_axis in out.axes
    assert out.axes.batch_axis().length == 2
    # The only remaining sample axis has the requested output length.
    assert out.axes.sample_axes().lengths == (3,)
def test_linear_zeros(input_placeholder, output_size):
    """Sanity check: with weights initialized to zero, random inputs produce an all-zero output."""
    inputs = np.random.random(input_placeholder.axes.lengths)
    zero_layer = Linear(nout=output_size, init=UniformInit(0.0, 0.0))
    with ExecutorFactory() as factory:
        if factory.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        computation = factory.executor(zero_layer(input_placeholder), input_placeholder)
        result = computation(inputs)
        # Both extremes must be exactly zero, so every element is zero.
        assert np.min(result) == 0.0 and np.max(result) == 0.0
def test_linear_W_axes_nout():
    """Linear(nout=3) weights W have no batch axis: the input feature axis plus one length-3 axis."""
    in_axis = ng.make_axis(1, name='A')
    n_axis = ng.make_axis(2, name='N')
    inputs = ng.placeholder([in_axis, n_axis])
    layer = Linear(nout=3, init=UniformInit(1.0, 1.0))
    # Apply the layer before inspecting W (presumably W is created on first call).
    layer(inputs)
    weight_axes = layer.W.axes
    assert weight_axes.batch_axis() is None
    assert in_axis in weight_axes
    remaining = weight_axes - in_axis
    assert len(remaining) == 1
    assert remaining[0].length == 3
def test_linear_ones(input_size, input_placeholder, output_size):
    """Sanity-check Linear with all-ones inputs and all-ones weights.

    Each output element should equal the sum of the weights feeding it,
    which is ``input_size`` when every weight is 1. This confirms that the
    correct number of operations is being run.
    """
    x = np.ones(input_placeholder.axes.lengths)
    layer = Linear(nout=output_size, init=UniformInit(1.0, 1.0))
    with ExecutorFactory() as ex:
        if ex.transformer.transformer_name == 'hetr':
            pytest.xfail("hetr fork-safe issue on mac")
        out = layer(input_placeholder)
        comp = ex.executor([out, layer.W], input_placeholder)
        output_values, w = comp(x)
        # The exact output check below is only meaningful if the weights are
        # exactly 1; verify that premise of UniformInit(1.0, 1.0) explicitly.
        ng.testing.assert_allclose(np.ones_like(w), w, atol=0.0, rtol=0.0)
        ng.testing.assert_allclose(np.ones(out.axes.lengths) * input_size,
                                   output_values, atol=0.0, rtol=0.0)
def test_linear_invalid_batch_axes():
    """Passing an axis named 'N' (treated here as a batch axis) as Linear's axes raises ValueError."""
    with pytest.raises(ValueError):
        batch_like = ng.make_axis(1, name='N')
        Linear(axes=batch_like, init=UniformInit(1.0, 1.0))
def test_linear_invalid_shadow_axes():
    """Passing a shadow axis as Linear's axes raises ValueError."""
    with pytest.raises(ValueError):
        shadow = make_shadow_axis(ng.make_axis(1, name='A'))
        Linear(axes=shadow, init=UniformInit(1.0, 1.0))
def test_linear_accepts_axes_axis():
    """Constructing Linear with a lone Axis as ``axes`` must not raise."""
    single_axis = ng.make_axis(1)
    Linear(axes=single_axis, init=UniformInit(1.0, 1.0))