class TestEmbedding(unittest.TestCase):
    """Exercise the Embedding layer: initial params/grads, lookup, and backward."""

    def setUp(self):
        # Weight matrix of 7 rows x 3 columns; row i holds [3i, 3i + 1, 3i + 2].
        table = np.arange(21).reshape(7, 3)
        self.embedding = Embedding(table)
        # Row 0 is requested twice, so backward has to sum two gradients into it.
        self.index = np.array([0, 2, 0, 4])

    def test_params(self):
        # The single parameter is the weight matrix given to the constructor.
        (weights,) = self.embedding.params
        expected = np.array([
            [0, 1, 2],
            [3, 4, 5],
            [6, 7, 8],
            [9, 10, 11],
            [12, 13, 14],
            [15, 16, 17],
            [18, 19, 20],
        ])
        assert_array_equal(expected, weights)

    def test_grads(self):
        # Before any backward pass the single gradient is all zeros, shaped 7 x 3.
        (grad,) = self.embedding.grads
        expected = np.zeros((7, 3), dtype=int)
        assert_array_equal(expected, grad)

    def test_forward(self):
        # forward(index) returns the selected weight rows, in index order.
        looked_up = self.embedding.forward(self.index)
        expected = np.array([
            [0, 1, 2],
            [6, 7, 8],
            [0, 1, 2],
            [12, 13, 14],
        ])
        assert_array_equal(expected, looked_up)

    def test_backward(self):
        # Feed the forward output back as the upstream gradient; each of its
        # rows should be added into the weight row named by the matching index.
        upstream = self.embedding.forward(self.index)
        self.embedding.backward(upstream)
        (grad,) = self.embedding.grads
        expected = np.array([
            [0, 2, 4],  # index positions 0 and 2 both point at row 0
            [0, 0, 0],
            [6, 7, 8],
            [0, 0, 0],
            [12, 13, 14],
            [0, 0, 0],
            [0, 0, 0],
        ])
        assert_array_equal(expected, grad)


class EmbeddingDot:
    """Look up embedding rows by index and take a row-wise dot product with ``h``.

    Wraps an ``Embedding`` layer. ``params`` and ``grads`` are the very same
    list objects as the wrapped embedding's, so updates to one are visible
    through the other.
    """

    def __init__(self, W):
        self.embedding = Embedding(W)
        self.params = self.embedding.params
        self.grads = self.embedding.grads
        # (h, target_W) saved by forward for use in backward.
        self.cache = None

    def forward(self, h, index):
        """Return ``sum(W[index[i]] * h[i])`` for every batch row ``i``.

        Parameters
        ----------
        h : ndarray
            Hidden vectors; assumed (N, D), matching the looked-up rows.
        index : ndarray
            Row indices into the embedding matrix, one per batch row.

        Returns
        -------
        ndarray
            One dot-product value per batch row (sum over axis 1).
        """
        target_W = self.embedding.forward(index)
        out = np.sum(target_W * h, axis=1)
        self.cache = (h, target_W)
        return out

    def backward(self, dout):
        """Back-propagate through the row-wise dot product.

        Parameters
        ----------
        dout : ndarray
            Upstream gradient, one value per batch row (N elements).

        Returns
        -------
        ndarray
            Gradient with respect to ``h``: ``dout[:, None] * target_W``.

        The gradient with respect to the looked-up rows (``dout * h``) is
        passed to the wrapped embedding's ``backward``.
        """
        h, target_W = self.cache
        # (N,) -> (N, 1) so dout broadcasts across the embedding dimension.
        dout = dout.reshape(dout.shape[0], 1)
        dtarget_W = dout * h
        # Do not overwrite self.cache here: it must keep the (h, target_W)
        # pair from forward, otherwise the unpacking above breaks on reuse.
        self.embedding.backward(dtarget_W)
        dh = dout * target_W
        return dh