def test_output_shape(self):
    """Check the shape of what ``GRU.predict`` returns for a 3-D input.

    A random float32 array of shape (batch, time, features) is passed to a
    freshly constructed GRU, and the prediction is asserted to have shape
    (batch, time, innerdim).
    """
    input_dim, hidden_dim = 20, 50
    n_examples, n_steps = 200, 5
    # Inputs are drawn before the GRU is built, preserving the original order.
    inputs = np.random.random((n_examples, n_steps, input_dim)).astype("float32")
    layer = GRU(innerdim=hidden_dim, dim=input_dim)
    expected_shape = (n_examples, n_steps, hidden_dim)
    self.assertEqual(layer.predict(inputs).shape, expected_shape)
def test_output_shape(self):
    """Assert that ``GRU.predict`` maps (batch, time, dim) to (batch, time, innerdim).

    The input is random float32 data; only the shape of the prediction is
    checked, not its values.
    """
    sizes = {"batch": 200, "time": 5, "in": 20, "hidden": 50}
    sample = np.random.random(
        (sizes["batch"], sizes["time"], sizes["in"])
    ).astype("float32")
    model = GRU(innerdim=sizes["hidden"], dim=sizes["in"])
    prediction = model.predict(sample)
    self.assertEqual(
        prediction.shape,
        (sizes["batch"], sizes["time"], sizes["hidden"]),
    )
def test_if_prediction_is_equivalent_to_manually_constructed_theano_graph(self):
    """Compare the GRU's last-step prediction with a hand-built reference graph.

    Random float32 data of shape (batsize, seqlen, indim) is run through
    ``GRU.predict`` and the final time step is kept. The same data is then
    evaluated through the graph returned by ``self.build_theano_gru`` and the
    two results are required to agree under ``np.allclose`` default tolerances.
    """
    indim = 20
    innerdim = 50
    batsize = 200
    seqlen = 5
    data = np.random.random((batsize, seqlen, indim)).astype("float32")
    gru = GRU(innerdim=innerdim, dim=indim)
    # Only the output at the last time step is compared with the reference.
    grupred = gru.predict(data)[:, -1, :]
    # build_theano_gru is given the GRU instance itself; presumably so the
    # reference graph can reuse its parameters -- TODO confirm in the helper.
    tgru_in, tgru_out = self.build_theano_gru(innerdim, indim, batsize, gru)
    # `data` is already float32, so no further cast is needed before eval.
    tgrupred = tgru_out.eval({tgru_in: data})
    total_abs_diff = np.sum(np.abs(tgrupred - grupred))
    # Call form of print: same output on Python 2, valid syntax on Python 3.
    print(total_abs_diff)
    self.assertTrue(
        np.allclose(grupred, tgrupred),
        "GRU prediction differs from reference graph "
        "(summed absolute difference: %s)" % total_abs_diff,
    )
def test_gru_with_mask(self):
    """Run ``GRU.predict`` with a mask and check shape and non-zero outputs.

    The mask is 1 for every time step of the first example and for the first
    time step of every example, and 0 elsewhere. The test asserts that the
    prediction has shape (batsize, seqlen, innerdim) and that every output at
    an unmasked position checked here (example 0, and time step 0 of all
    examples) is non-zero.
    """
    indim = 2
    innerdim = 5
    batsize = 4
    seqlen = 3
    data = np.random.random((batsize, seqlen, indim)).astype("float32")
    mask = np.zeros((batsize, seqlen)).astype("float32")
    mask[:, 0] = 1.   # first time step is unmasked for every example
    mask[0, :] = 1.   # first example is unmasked at every time step
    gru = GRU(innerdim=innerdim, dim=indim)
    grupred = gru.predict(data, mask)
    # Call form of print: same output on Python 2, valid syntax on Python 3.
    print(grupred)
    self.assertEqual(grupred.shape, (batsize, seqlen, innerdim))
    # NOTE(review): the check below was already disabled in the original; the
    # reason is not visible here -- presumably masked positions are not
    # zeroed by the GRU. Confirm before re-enabling.
    # self.assertTrue(np.allclose(grupred[1:, 1:, :], np.zeros_like(grupred[1:, 1:, :])))
    self.assertTrue(np.all(abs(grupred[0, ...]) > 0))
    self.assertTrue(np.all(abs(grupred[:, 0, :]) > 0))
def test_gru_with_mask(self):
    """Run ``GRU.predict`` with a mask and check shape and non-zero outputs.

    The mask is 1 for every time step of the first example and for the first
    time step of every example, and 0 elsewhere. The test asserts that the
    prediction has shape (batsize, seqlen, innerdim) and that every output at
    an unmasked position checked here (example 0, and time step 0 of all
    examples) is non-zero.
    """
    indim = 2
    innerdim = 5
    batsize = 4
    seqlen = 3
    data = np.random.random((batsize, seqlen, indim)).astype("float32")
    mask = np.zeros((batsize, seqlen)).astype("float32")
    mask[:, 0] = 1.   # first time step is unmasked for every example
    mask[0, :] = 1.   # first example is unmasked at every time step
    gru = GRU(innerdim=innerdim, dim=indim)
    grupred = gru.predict(data, mask)
    # Call form of print: same output on Python 2, valid syntax on Python 3.
    print(grupred)
    self.assertEqual(grupred.shape, (batsize, seqlen, innerdim))
    # NOTE(review): the check below was already disabled in the original; the
    # reason is not visible here -- presumably masked positions are not
    # zeroed by the GRU. Confirm before re-enabling.
    # self.assertTrue(np.allclose(grupred[1:, 1:, :], np.zeros_like(grupred[1:, 1:, :])))
    self.assertTrue(np.all(abs(grupred[0, ...]) > 0))
    self.assertTrue(np.all(abs(grupred[:, 0, :]) > 0))