def test_graph_parameters(self):
    """Extending a chain with ``>>`` keeps the original chain's state on ``.left``.

    ``base`` is extended with one more ``Linear(10)``; the left side of the
    extended chain must report the same state list as ``base`` itself.
    """
    np.random.seed(0)
    base = Vector(10) >> Linear(10)
    extended = base >> Linear(10)
    expected = base.get_state(as_list=True)
    actual = extended.left.get_state(as_list=True)
    self.assertEqual(expected, actual)
def __init__(self, input_size, embedding_size, q_network, likelihood_model):
    """Store the model sizes and sub-networks.

    Args:
        input_size: stored unchanged as ``self.input_size``.
        embedding_size: stored as ``self.embedding_size``; also the output
            size of both ``Linear`` heads attached to ``q_network``.
        q_network: network onto which a pair of ``Linear(embedding_size)``
            heads is chained with ``>>``; the result is ``self.q_network``.
            (Presumably the two heads parameterise a distribution, e.g.
            mean and log-variance — confirm against the consumer.)
        likelihood_model: stored unchanged as ``self.likelihood_model``.
    """
    self.input_size = input_size
    self.embedding_size = embedding_size
    heads = (Linear(embedding_size), Linear(embedding_size))
    self.q_network = q_network >> heads
    self.likelihood_model = likelihood_model
def test_shape_elementwise(self):
    """An unsized ``Linear()`` takes its shape from the layer chained into it.

    Before chaining, only ``Linear(20)`` knows its output shape. After
    ``Linear(20)`` is chained into ``Linear()``, the unsized layer reports
    ``Shape(20)`` on both sides and the upstream layer is unaffected.
    """
    unsized = Linear()
    sized = Linear(20)

    # Before chaining.
    self.assertEqual(unsized.get_shape(), (None, None))
    self.assertEqual(sized.get_shape(), (None, [Shape(20)]))

    sized.chain(unsized)

    # After chaining.
    self.assertEqual(unsized.get_shape(), ([Shape(20)], [Shape(20)]))
    self.assertEqual(sized.get_shape(), (None, [Shape(20)]))
def test_shape_inference2(self):
    """Chaining two fully sized, compatible layers leaves both shapes as declared.

    ``Linear(10, 100)`` feeds ``Linear(100, 20)``; the shapes are checked
    before and after ``chain`` and must match in both cases.
    """
    first = Linear(10, 100)
    second = Linear(100, 20)

    self.assertEqual(first.get_shape(), ([Shape(10)], [Shape(100)]))
    self.assertEqual(second.get_shape(), ([Shape(100)], [Shape(20)]))

    first.chain(second)

    self.assertEqual(first.get_shape(), ([Shape(10)], [Shape(100)]))
    self.assertEqual(second.get_shape(), ([Shape(100)], [Shape(20)]))
def test_graph_basic(self):
    """``a >> b`` and ``a.chain(b)`` build chains whose sides compare equal.

    For each side (``left`` first, then ``right``), a fresh ``>>`` chain and
    a fresh ``chain(...)`` result are built and that side is compared.
    """
    np.random.seed(0)
    source = Vector(10)
    layer = Linear(10)
    for side in ('left', 'right'):
        via_operator = getattr(source >> layer, side)
        via_method = getattr(source.chain(layer), side)
        self.assertEqual(via_operator, via_method)
def test_freeze_parameters(self):
    """``freeze()`` on an initialized network reports the same state list."""
    np.random.seed(0)
    network = Vector(10) >> Linear(10)
    network.initialize()
    frozen_state = network.freeze().get_state(as_list=True)
    live_state = network.get_state(as_list=True)
    self.assertEqual(frozen_state, live_state)
def test_shape_elementwise2(self):
    """An unsized layer picks up Shape(20); the chain then rejects a 30-input layer.

    ``Linear(20)`` is chained into ``Linear()``. The resulting chain outputs
    ``Shape(20)``, so chaining it into ``Linear(30, 40)`` must raise
    ``ShapeInError``.
    """
    producer = Linear(20)
    unsized = Linear()
    mismatched = Linear(30, 40)

    self.assertEqual(producer.get_shape(), (None, [Shape(20)]))
    self.assertEqual(unsized.get_shape(), (None, None))
    self.assertEqual(mismatched.get_shape(), ([Shape(30)], [Shape(40)]))

    combined = producer.chain(unsized)

    self.assertEqual(producer.get_shape(), (None, [Shape(20)]))
    self.assertEqual(unsized.get_shape(), ([Shape(20)], [Shape(20)]))
    self.assertEqual(combined.get_shape(), (None, [Shape(20)]))

    with self.assertRaises(ShapeInError):
        combined.chain(mismatched)
def test_freeze_parameters2(self):
    """Freezing either half of a two-part chain hides its graph parameters.

    For each half, the frozen version must report the same state list as the
    live one, and ``get_graph_parameters()`` must be empty only when frozen.
    The right half is also checked when re-fed from a fresh ``Vector(20)``.
    """
    np.random.seed(0)
    left_half = Vector(10) >> Linear(20)
    right_half = Linear(10) >> Linear(2)
    gan = left_half >> right_half
    gan.initialize()

    # Frozen halves report the same state as the live halves.
    self.assertEqual(gan.left.freeze().get_state(as_list=True),
                     gan.left.get_state(as_list=True))
    self.assertEqual(gan.right.freeze().get_state(as_list=True),
                     gan.right.get_state(as_list=True))

    # Right half: no graph parameters once frozen, some while live.
    self.assertEqual(gan.right.freeze().get_graph_parameters(), [])
    self.assertNotEqual(gan.right.get_graph_parameters(), [])

    # Same checks with the right half fed from a new Vector(20) input.
    self.assertEqual(
        (Vector(20) >> gan.right).freeze().get_graph_parameters(), [])
    self.assertNotEqual(
        (Vector(20) >> gan.right).get_graph_parameters(), [])
    self.assertEqual(
        (Vector(20) >> gan.right).freeze().get_state(as_list=True),
        (Vector(20) >> gan.right).get_state(as_list=True))

    # Left half: no graph parameters once frozen, some while live.
    self.assertEqual(gan.left.freeze().get_graph_parameters(), [])
    self.assertNotEqual(gan.left.get_graph_parameters(), [])
# Build the inference (q) network and gather per-segment tensors.
# NOTE(review): this chunk relies on names defined outside this view
# (X, D, M, values, times, create_harmonic, T, tf, and the layer classes);
# TODO confirm their definitions.
T.set_default_device('/cpu:0')

# Symbolic inputs.
c = T.scalar(name='c')
# int32 matrix whose columns 0-4 are split into named index vectors below.
segments = T.matrix(dtype='int32', name='segments')
a_idx = segments[:, 0]
b_idx = segments[:, 1]
leaf_segment = segments[:, 2]
m = segments[:, 3]
log_fac = segments[:, 4]
x = T.matrix(name='x')
e = T.matrix(name='e')

# Shared trunk: input of width X.shape[1] fed from placeholder x, followed by
# Repeat(Tanh(200), 2) (presumably two stacked Tanh(200) layers; confirm).
q_network = Vector(X.shape[1], placeholder=x, is_input=False) >> Repeat(Tanh(200), 2)
# Both heads below are chained onto the same q_network object.
q_mu_network = q_network >> Linear(D)
q_mu = q_mu_network.get_outputs()[0].get_placeholder()
q_sigma_network = q_network >> Linear(D)
# sqrt(exp(v)) maps the raw head output to a positive scale; presumably the
# head predicts a log-variance (TODO confirm).
q_sigma = tf.sqrt(tf.exp(q_sigma_network.get_outputs()[0].get_placeholder()))
# Reparameterised sample z = mu + e * sigma. The noise e is fed through the
# `e` input, presumably standard normal; confirm against the feed code.
z = q_mu + e * q_sigma

values, times = T.variable(values), T.variable(times)
# NOTE(review): tf.concat(0, [...]) is the pre-1.0 TensorFlow signature
# concat(axis, values), i.e. z is placed before values along axis 0.
# TF >= 1.0 expects tf.concat(values, axis). Confirm the pinned TF version.
values = tf.concat(0, [z, values])
harmonic = T.variable(create_harmonic(M))
# Gather the a- and b-endpoint values/times for each segment row.
a_batch_values = T.gather(values, a_idx)
a_batch_times = T.gather(times, a_idx)
b_batch_values = T.gather(values, b_idx)
b_batch_times = T.gather(times, b_idx)
# m - 1 shifts to 0-based indexing into harmonic (presumably m counts from 1).
harmonic_m = T.gather(harmonic, m - 1)
def test_bad_shape2(self):
    """A layer chained from a 20-unit layer rejects a 90-unit upstream layer.

    ``Linear(100)`` feeds ``Linear(20)``, which in turn feeds a second
    ``Linear(100)``. Chaining ``Linear(90)`` into that last layer must raise
    ``ShapeInError``.
    """
    head = Linear(100)
    middle = Linear(20)

    self.assertEqual(head.get_shape(), (None, [Shape(100)]))
    self.assertEqual(middle.get_shape(), (None, [Shape(20)]))

    head.chain(middle)

    self.assertEqual(head.get_shape(), (None, [Shape(100)]))
    self.assertEqual(middle.get_shape(), ([Shape(100)], [Shape(20)]))

    tail = Linear(100)
    middle.chain(tail)

    intruder = Linear(90)
    with self.assertRaises(ShapeInError):
        intruder.chain(tail)
def test_bad_shape(self):
    """Chaining Linear(10, 100) into Linear(90, 20) raises ShapeInError (100 vs 90)."""
    producer = Linear(10, 100)
    consumer = Linear(90, 20)
    with self.assertRaises(ShapeInError):
        producer.chain(consumer)
def test_freeze(self):
    """Wrapping a network in ``Freeze`` yields no graph parameters."""
    network = Vector(10) >> Linear(10)
    frozen = Freeze(network)
    self.assertEqual(frozen.get_graph_parameters(), [])