예제 #1
0
def test_apply_node():
    """ApplyNode with ``fn=T.sum`` should reduce its whole input to a scalar.

    A (3, 4, 5) input is fed through a sequential network whose only
    non-input node applies ``T.sum``. The compiled output is compared
    against NumPy's sum of the same array.
    """
    input_node = tn.InputNode("in", shape=(3, 4, 5))
    # shape_fn returns () because a full sum produces a 0-d result.
    sum_node = tn.ApplyNode("a", fn=T.sum, shape_fn=lambda x: ())
    network = tn.SequentialNode("s", [input_node, sum_node]).network()

    compiled = network.function(["in"], ["s"])
    data = np.random.randn(3, 4, 5).astype(fX)
    result = compiled(data)[0]
    np.testing.assert_allclose(result, data.sum(), rtol=1e-5)
예제 #2
0
 def tmp(include_batch_pad):
     """Run a chunked network on 16 rows of zeros and return its outputs.

     The network adds its own batch size (``x.shape[0]``) to every element.
     It is wrapped in ``chunk_variables(3, ["i"])``. Note that 16 is not a
     multiple of the chunk size 3.

     :param include_batch_pad: when true, ``batch_pad(3, ["x"])`` is
         placed in front of the chunking handler. Presumably this pads
         input "x" up to a multiple of 3; confirm against canopy.
     :returns: whatever the handled function returns for ``{"x": zeros}``.
     """
     network = tn.SequentialNode(
         "seq",
         [
             tn.InputNode("i", shape=(None, 2)),
             tn.ApplyNode(
                 "a",
                 fn=(lambda x: x.shape[0].astype(fX) + x),
                 shape_fn=(lambda s: s),
             ),
         ],
     ).network()

     chunker = canopy.handlers.chunk_variables(3, ["i"])
     if include_batch_pad:
         # Padding must run before chunking, so it goes first in the list.
         handlers = [canopy.handlers.batch_pad(3, ["x"]), chunker]
     else:
         handlers = [chunker]

     run = canopy.handlers.handled_fn(network,
                                      handlers,
                                      {"x": "i"},
                                      {"out": "seq"})
     zeros = np.zeros((16, 2), dtype=fX)
     return run({"x": zeros})
예제 #3
0
def test_chunk_variables():
    """Check that chunk_variables changes the batch size the network sees.

    The network adds its own batch dimension (``x.shape[0]``) to every
    element. Feeding zeros therefore reveals how many rows were processed
    in a single pass.
    """
    network = tn.SequentialNode(
        "seq",
        [
            tn.InputNode("i", shape=(None, 2)),
            tn.ApplyNode(
                "a",
                fn=(lambda x: x.shape[0].astype(fX) + x),
                shape_fn=(lambda s: s),
            ),
        ],
    ).network()

    def build(handlers):
        # Fresh mapping dicts on every call, matching the original literals.
        return canopy.handlers.handled_fn(network,
                                          handlers,
                                          {"x": "i"},
                                          {"out": "seq"})

    def run(compiled):
        return compiled({"x": np.zeros((18, 2), dtype=fX)})["out"]

    # No handlers: all 18 rows go through together, so every value is 18.
    unchunked = build([])
    np.testing.assert_equal(run(unchunked), np.ones((18, 2), dtype=fX) * 18)

    # Chunks of 3: each pass sees only 3 rows, so every value is 3.
    chunked = build([canopy.handlers.chunk_variables(3, ["i"])])
    np.testing.assert_equal(run(chunked), np.ones((18, 2), dtype=fX) * 3)
예제 #4
0
            shape=shape,
            tags={"parameter"},
            default_inits=[],
        )


def reward_fn(x):
    """Return 100 minus the squared distance of ``x`` from 3.5 along axis 1.

    The reward is maximal (exactly 100) when every entry summed over axis 1
    equals 3.5, and it decreases quadratically as entries move away from
    3.5. For a 2-D ``x`` this yields one reward per row; other ranks are not
    shown here, so confirm the expected rank against callers.
    """
    squared_distance = T.sqr(x - 3.5).sum(axis=1)
    return -squared_distance + 100


graph = tn.GraphNode("graph", [[
    tn.ConstantNode("state", value=T.zeros((1, 1))),
    ConstantStateNode("mu", shape=(1, 1)),
    tn.ConstantNode("sigma", value=1.),
    REINFORCE.NormalSampleNode("sampled"),
    tn.ApplyNode("reward", fn=reward_fn, shape_fn=lambda x: x[:1]),
    REINFORCE.NormalREINFORCECostNode("REINFORCE")
],
                               [{
                                   "from": "mu",
                                   "to": "sampled",
                                   "to_key": "mu"
                               }, {
                                   "from": "sigma",
                                   "to": "sampled",
                                   "to_key": "sigma"
                               }, {
                                   "from": "sampled",
                                   "to": "reward"
                               }, {
                                   "from": "state",
예제 #5
0
def test_apply_node_serialization():
    """Run ``tn.check_serialization`` on a bare ApplyNode named "a"."""
    node = tn.ApplyNode("a")
    tn.check_serialization(node)