class TestEmbedding(unittest.TestCase):
    """Unit tests for the ``Embedding`` layer built on a 7x3 weight matrix."""

    def setUp(self):
        # Weights are the integers 0..20 laid out as 7 rows of 3 columns.
        weights = np.arange(21).reshape(7, 3)
        self.embedding = Embedding(weights)
        # Row 0 is deliberately requested twice.
        self.index = np.array([0, 2, 0, 4])

    def test_params(self):
        """The single entry of ``params`` equals the weight matrix."""
        (weights,) = self.embedding.params
        expected = np.array([
            [0, 1, 2],
            [3, 4, 5],
            [6, 7, 8],
            [9, 10, 11],
            [12, 13, 14],
            [15, 16, 17],
            [18, 19, 20],
        ])
        assert_array_equal(expected, weights)

    def test_grads(self):
        """The single entry of ``grads`` is all zeros before any backward pass."""
        (gradient,) = self.embedding.grads
        expected = np.zeros((7, 3), dtype=int)
        assert_array_equal(expected, gradient)

    def test_forward(self):
        """``forward`` returns the weight rows named by the index, in order."""
        rows = self.embedding.forward(self.index)
        expected = np.array([
            [0, 1, 2],
            [6, 7, 8],
            [0, 1, 2],
            [12, 13, 14],
        ])
        assert_array_equal(expected, rows)

    def test_backward(self):
        """``backward`` fills ``grads`` from the upstream gradient.

        The forward output itself is fed back as the upstream gradient.
        Index 0 occurs twice, so the expected row 0 is twice ``[0, 1, 2]``;
        rows that were never looked up are expected to stay zero.
        """
        upstream = self.embedding.forward(self.index)
        self.embedding.backward(upstream)
        (gradient,) = self.embedding.grads
        expected = np.array([
            [0, 2, 4],
            [0, 0, 0],
            [6, 7, 8],
            [0, 0, 0],
            [12, 13, 14],
            [0, 0, 0],
            [0, 0, 0],
        ])
        assert_array_equal(expected, gradient)
class EmbeddingDot:
    """Embedding lookup followed by a row-wise dot product with ``h``.

    Wraps an ``Embedding`` built from ``W`` and exposes that layer's
    ``params`` and ``grads`` objects as its own.
    """

    def __init__(self, W):
        self.embedding = Embedding(W)
        # Same objects as the wrapped layer's, not copies.
        self.params = self.embedding.params
        self.grads = self.embedding.grads
        # Set by forward() to (h, looked-up rows); overwritten by backward().
        self.cache = None

    def forward(self, h, index):
        """Return, per row, the dot product of ``h`` with the looked-up row.

        ``index`` is passed to the wrapped embedding's ``forward``; the
        result is multiplied element-wise by ``h`` and summed along axis 1.
        The pair ``(h, looked-up rows)`` is stored in ``self.cache``.
        """
        rows = self.embedding.forward(index)
        scores = (rows * h).sum(axis=1)
        self.cache = (h, rows)
        return scores

    def backward(self, dout):
        """Propagate ``dout`` back through the dot product.

        Passes ``dout * h`` to the wrapped embedding's ``backward`` and
        returns ``dout * looked-up rows`` as the gradient with respect to
        ``h``. Requires a preceding ``forward`` call to populate the cache.
        """
        h, rows = self.cache
        # Turn (N,) into a column (N, 1) so it broadcasts across each row.
        column = dout.reshape((dout.shape[0], 1))
        grad_rows = column * h
        # NOTE(review): this replaces the (h, rows) forward cache with the
        # gradient, so a second backward() without a new forward() would
        # unpack this gradient instead - confirm this is intended.
        self.cache = grad_rows
        self.embedding.backward(grad_rows)
        return column * rows