Ejemplo n.º 1
0
def test_layer_mode_setting(input_placeholder):
    """
    Check that calling a layer inside ``Layer.inference_mode_on()`` records
    an "inference" mode on the layer, and that every op yielded by iterating
    the layer is listed in that mode's ``ops``.
    """
    simple = SimpleLayer()
    with Layer.inference_mode_on():
        simple(input_placeholder)

    assert "inference" in simple.modes
    # Each op of the layer's subgraph must have been captured by the mode.
    for node in simple:
        assert node in simple.modes["inference"].ops
Ejemplo n.º 2
0
 def make_network(scope1=None, scope2=None):
     """
     Build a two-layer network whose layers each live in their own
     variable scope.

     scope1 / scope2 are passed to ``Layer.variable_scope`` for the first
     (Rectlin) and second (Logistic, nout=1) Affine layer respectively.
     Weights, biases, ``axes``, ``N`` and ``nout1`` come from the enclosing
     scope.

     Returns (sequential model, input placeholder, target placeholder,
     binary cross-entropy loss op).
     """
     inputs = ng.placeholder(axes)
     target_axes = ng.make_axes([ng.make_axis(length=1), N])
     targets = ng.placeholder(target_axes)

     with Layer.variable_scope(scope1):
         hidden = Affine(ConstantInit(val=Wlin1),
                         nout=nout1,
                         bias_init=ConstantInit(val=Wbias1),
                         activation=Rectlin(),
                         batch_norm=False)

     with Layer.variable_scope(scope2):
         output = Affine(ConstantInit(val=Wlin2),
                         nout=1,
                         bias_init=ConstantInit(val=Wbias2),
                         activation=Logistic(),
                         batch_norm=False)

     model = Sequential([hidden, output])
     prediction = model(inputs)
     # Targets are recast onto the prediction's axes before computing the loss.
     # TODO: how can this be avoided?
     targets_cast = ng.cast_axes(targets, prediction.axes)
     loss = ng.cross_entropy_binary(prediction, targets_cast)
     return model, inputs, targets, loss
Ejemplo n.º 3
0
def test_dropout_inference(nin, batch_size, transformer_factory):
    """
    Check that a Dropout layer built under ``Layer.inference_mode_on()``
    returns its input scaled by the keep probability, and returns the same
    result when the compiled computation is run twice on the same data.
    """
    keep_prob = 0.5

    # Axes: batch axis is created before the feature axis, as in the original.
    batch_axis = ng.make_axis(batch_size, name='N')
    feature_axis = ng.make_axis(nin, name='F')

    placeholder = ng.placeholder([feature_axis, batch_axis])
    dropout = Dropout(keep=keep_prob)
    with Layer.inference_mode_on():
        output_op = dropout(placeholder)

    data = np.random.uniform(size=(nin, batch_size))

    with ExecutorFactory() as factory:
        compute = factory.executor(output_op, placeholder)
        first = compute(data)
        # The expected inference output is the input scaled by keep_prob.
        ng.testing.assert_allclose(first, data * keep_prob, atol=atol, rtol=rtol)
        # Snapshot the first result before re-running; presumably the executor
        # may reuse its output buffer between calls (unconfirmed).
        first_snapshot = first.copy()
        second = compute(data)
        ng.testing.assert_allclose(first_snapshot, second, atol=atol, rtol=rtol)