コード例 #1
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_graph_parameters(self):
        np.random.seed(0)
        net1 = Vector(10) >> Linear(10)

        net2 = net1 >> Linear(10)

        self.assertEqual(net1.get_state(as_list=True),
                         net2.left.get_state(as_list=True))
コード例 #2
0
ファイル: vae.py プロジェクト: sharadmv/deep-trees
    def __init__(self, input_size, embedding_size, q_network,
                 likelihood_model):
        self.input_size = input_size
        self.embedding_size = embedding_size

        self.q_network = q_network >> (Linear(embedding_size),
                                       Linear(embedding_size))
        self.likelihood_model = likelihood_model
コード例 #3
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_shape_elementwise(self):
        part1 = Linear()
        part2 = Linear(20)

        self.assertEqual(part1.get_shape(), (None, None))
        self.assertEqual(part2.get_shape(), (None, [Shape(20)]))

        part2.chain(part1)

        self.assertEqual(part1.get_shape(), ([Shape(20)], [Shape(20)]))
        self.assertEqual(part2.get_shape(), (None, [Shape(20)]))
コード例 #4
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_shape_inference2(self):
        part1 = Linear(10, 100)
        part2 = Linear(100, 20)

        self.assertEqual(part1.get_shape(), ([Shape(10)], [Shape(100)]))
        self.assertEqual(part2.get_shape(), ([Shape(100)], [Shape(20)]))

        part1.chain(part2)

        self.assertEqual(part1.get_shape(), ([Shape(10)], [Shape(100)]))
        self.assertEqual(part2.get_shape(), ([Shape(100)], [Shape(20)]))
コード例 #5
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_graph_basic(self):
        np.random.seed(0)
        v1 = Vector(10)
        v2 = Linear(10)

        self.assertEqual((v1 >> v2).left, v1.chain(v2).left)
        self.assertEqual((v1 >> v2).right, v1.chain(v2).right)
コード例 #6
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_freeze_parameters(self):
        np.random.seed(0)
        net1 = Vector(10) >> Linear(10)
        net1.initialize()

        self.assertEqual(net1.freeze().get_state(as_list=True),
                         net1.get_state(as_list=True))
コード例 #7
0
ファイル: test_graph.py プロジェクト: sharadmv/deepx
    def test_shape_elementwise(self):
        part1 = Linear()
        part2 = Linear(20)

        self.assertEqual(part1.get_shape(), (None, None))
        self.assertEqual(part2.get_shape(), (None, [Shape(20)]))

        part2.chain(part1)

        self.assertEqual(part1.get_shape(), ([Shape(20)], [Shape(20)]))
        self.assertEqual(part2.get_shape(), (None, [Shape(20)]))
コード例 #8
0
ファイル: test_graph.py プロジェクト: sharadmv/deepx
    def test_shape_inference2(self):
        part1 = Linear(10, 100)
        part2 = Linear(100, 20)

        self.assertEqual(part1.get_shape(), ([Shape(10)], [Shape(100)]))
        self.assertEqual(part2.get_shape(), ([Shape(100)], [Shape(20)]))

        part1.chain(part2)

        self.assertEqual(part1.get_shape(), ([Shape(10)], [Shape(100)]))
        self.assertEqual(part2.get_shape(), ([Shape(100)], [Shape(20)]))
コード例 #9
0
ファイル: test_graph.py プロジェクト: sharadmv/deepx
    def test_shape_elementwise2(self):
        part1 = Linear(20)
        part2 = Linear()
        part3 = Linear(30, 40)

        self.assertEqual(part1.get_shape(), (None, [Shape(20)]))
        self.assertEqual(part2.get_shape(), (None, None))
        self.assertEqual(part3.get_shape(), ([Shape(30)], [Shape(40)]))

        part4 = part1.chain(part2)

        self.assertEqual(part1.get_shape(), (None, [Shape(20)]))
        self.assertEqual(part2.get_shape(), ([Shape(20)], [Shape(20)]))
        self.assertEqual(part4.get_shape(), (None, [Shape(20)]))

        with self.assertRaises(ShapeInError):
            part4.chain(part3)
コード例 #10
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_freeze_parameters2(self):
        np.random.seed(0)
        gan = (Vector(10) >> Linear(20)) >> (Linear(10) >> Linear(2))
        gan.initialize()

        self.assertEqual(gan.left.freeze().get_state(as_list=True),
                         gan.left.get_state(as_list=True))
        self.assertEqual(gan.right.freeze().get_state(as_list=True),
                         gan.right.get_state(as_list=True))

        self.assertEqual(gan.right.freeze().get_graph_parameters(), [])
        self.assertNotEqual(gan.right.get_graph_parameters(), [])

        self.assertEqual(
            (Vector(20) >> gan.right).freeze().get_graph_parameters(), [])
        self.assertNotEqual((Vector(20) >> gan.right).get_graph_parameters(),
                            [])

        self.assertEqual(
            (Vector(20) >> gan.right).freeze().get_state(as_list=True),
            (Vector(20) >> gan.right).get_state(as_list=True))

        self.assertEqual(gan.left.freeze().get_graph_parameters(), [])
        self.assertNotEqual(gan.left.get_graph_parameters(), [])
コード例 #11
0
ファイル: model.py プロジェクト: sharadmv/deep-trees
    T.set_default_device('/cpu:0')

    c = T.scalar(name='c')
    segments = T.matrix(dtype='int32', name='segments')

    a_idx = segments[:, 0]
    b_idx = segments[:, 1]
    leaf_segment = segments[:, 2]
    m = segments[:, 3]
    log_fac = segments[:, 4]

    x = T.matrix(name='x')
    e = T.matrix(name='e')
    q_network = Vector(X.shape[1], placeholder=x, is_input=False) >> Repeat(Tanh(200), 2)
    q_mu_network = q_network >> Linear(D)
    q_mu = q_mu_network.get_outputs()[0].get_placeholder()
    q_sigma_network = q_network >> Linear(D)
    q_sigma = tf.sqrt(tf.exp(q_sigma_network.get_outputs()[0].get_placeholder()))
    z = q_mu + e * q_sigma

    values, times = T.variable(values), T.variable(times)
    values = tf.concat(0, [z, values])
    harmonic = T.variable(create_harmonic(M))

    a_batch_values = T.gather(values, a_idx)
    a_batch_times = T.gather(times, a_idx)
    b_batch_values = T.gather(values, b_idx)
    b_batch_times = T.gather(times, b_idx)
    harmonic_m = T.gather(harmonic, m - 1)
コード例 #12
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_bad_shape2(self):
        part1 = Linear(100)
        part2 = Linear(20)

        self.assertEqual(part1.get_shape(), (None, [Shape(100)]))
        self.assertEqual(part2.get_shape(), (None, [Shape(20)]))

        part1.chain(part2)

        self.assertEqual(part1.get_shape(), (None, [Shape(100)]))
        self.assertEqual(part2.get_shape(), ([Shape(100)], [Shape(20)]))

        part3 = Linear(100)
        part2.chain(part3)

        part4 = Linear(90)
        with self.assertRaises(ShapeInError):
            part4.chain(part3)
コード例 #13
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_bad_shape(self):
        part1 = Linear(10, 100)
        part2 = Linear(90, 20)

        with self.assertRaises(ShapeInError):
            part1.chain(part2)
コード例 #14
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
 def test_freeze(self):
     net1 = Vector(10) >> Linear(10)
     self.assertEqual(Freeze(net1).get_graph_parameters(), [])
コード例 #15
0
ファイル: test_graph.py プロジェクト: zhangmarvin/deepx
    def test_shape_elementwise2(self):
        part1 = Linear(20)
        part2 = Linear()
        part3 = Linear(30, 40)

        self.assertEqual(part1.get_shape(), (None, [Shape(20)]))
        self.assertEqual(part2.get_shape(), (None, None))
        self.assertEqual(part3.get_shape(), ([Shape(30)], [Shape(40)]))

        part4 = part1.chain(part2)

        self.assertEqual(part1.get_shape(), (None, [Shape(20)]))
        self.assertEqual(part2.get_shape(), ([Shape(20)], [Shape(20)]))
        self.assertEqual(part4.get_shape(), (None, [Shape(20)]))

        with self.assertRaises(ShapeInError):
            part4.chain(part3)
コード例 #16
0
ファイル: test_graph.py プロジェクト: sharadmv/deepx
    def test_bad_shape2(self):
        part1 = Linear(100)
        part2 = Linear(20)

        self.assertEqual(part1.get_shape(), (None, [Shape(100)]))
        self.assertEqual(part2.get_shape(), (None, [Shape(20)]))

        part1.chain(part2)

        self.assertEqual(part1.get_shape(), (None, [Shape(100)]))
        self.assertEqual(part2.get_shape(), ([Shape(100)], [Shape(20)]))

        part3 = Linear(100)
        part2.chain(part3)

        part4 = Linear(90)
        with self.assertRaises(ShapeInError):
            part4.chain(part3)
コード例 #17
0
ファイル: test_graph.py プロジェクト: sharadmv/deepx
    def test_bad_shape(self):
        part1 = Linear(10, 100)
        part2 = Linear(90, 20)

        with self.assertRaises(ShapeInError):
            part1.chain(part2)