Example #1
    def fprop(params):
      """
      Forward pass of the NTM.
      """

      W = params # aliasing for brevity

      xs, hs, ys, ps, ts, os = {}, {}, {}, {}, {}, {}

      def l():
        """
        Silly utility function that should be called in init.
        """
        return [{} for _ in xrange(self.heads)]

      rs = l()
      k_rs, beta_rs, g_rs, s_rs, gamma_rs = l(),l(),l(),l(),l()
      k_ws, beta_ws, g_ws, s_ws, gamma_ws = l(),l(),l(),l(),l()
      adds, erases = l(),l()
      w_ws, w_rs = l(),l() # write weights and read weights
      for idx in range(self.heads):
        rs[idx][-1] = self.W['rsInit' + str(idx)] # stores values read from memory
        w_ws[idx][-1] = softmax(self.W['w_wsInit' + str(idx)])
        w_rs[idx][-1] = softmax(self.W['w_rsInit' + str(idx)])

      mems = {} # the state of the memory at every timestep
      mems[-1] = self.W['memsInit']
      loss = 0

      for t in xrange(len(inputs)):

        xs[t] = np.reshape(np.array(inputs[t]),inputs[t].shape[::-1])

        rsum = 0
        for idx in range(self.heads):
          rsum = rsum + np.dot(W['rh' + str(idx)], np.reshape(rs[idx][t-1],(self.M,1)))
        hs[t] = np.tanh(np.dot(W['xh'], xs[t]) + rsum + W['bh'])

        os[t] = np.tanh(np.dot(W['ho'], hs[t]) + W['bo'])


        for idx in range(self.heads):
          # parameters to the read head
          k_rs[idx][t] = np.tanh(np.dot(W['ok_r' + str(idx)],os[t]) + W['bk_r' + str(idx)])
          beta_rs[idx][t] = softplus(np.dot(W['obeta_r' + str(idx)],os[t])
                                     + W['bbeta_r' + str(idx)])
          g_rs[idx][t] = sigmoid(np.dot(W['og_r' + str(idx)],os[t]) + W['bg_r' + str(idx)])
          s_rs[idx][t] = softmax(np.dot(W['os_r' + str(idx)],os[t]) + W['bs_r' + str(idx)])
          gamma_rs[idx][t] = 1 + sigmoid(np.dot(W['ogamma_r' + str(idx)], os[t])
                                         + W['bgamma_r' + str(idx)])

          # parameters to the write head
          k_ws[idx][t] = np.tanh(np.dot(W['ok_w' + str(idx)],os[t]) + W['bk_w' + str(idx)])
          beta_ws[idx][t] = softplus(np.dot(W['obeta_w' + str(idx)], os[t])
                                     + W['bbeta_w' + str(idx)])
          g_ws[idx][t] = sigmoid(np.dot(W['og_w' + str(idx)],os[t]) + W['bg_w' + str(idx)])
          s_ws[idx][t] = softmax(np.dot(W['os_w' + str(idx)],os[t]) + W['bs_w' + str(idx)])
          gamma_ws[idx][t] = 1 + sigmoid(np.dot(W['ogamma_w' + str(idx)], os[t])
                                         + W['bgamma_w' + str(idx)])

          # the erase and add vectors
          # these are also parameters to the write head
          # but they describe "what" is to be written rather than "where"
          adds[idx][t] = np.tanh(np.dot(W['oadds' + str(idx)], os[t]) + W['badds' + str(idx)])
          erases[idx][t] = sigmoid(np.dot(W['oerases' + str(idx)], os[t]) + W['berases' + str(idx)])

          w_ws[idx][t] = addressing.create_weights(   k_ws[idx][t]
                                                    , beta_ws[idx][t]
                                                    , g_ws[idx][t]
                                                    , s_ws[idx][t]
                                                    , gamma_ws[idx][t]
                                                    , w_ws[idx][t-1]
                                                    , mems[t-1])

          w_rs[idx][t] = addressing.create_weights(   k_rs[idx][t]
                                                    , beta_rs[idx][t]
                                                    , g_rs[idx][t]
                                                    , s_rs[idx][t]
                                                    , gamma_rs[idx][t]
                                                    , w_rs[idx][t-1]
                                                    , mems[t-1])

        ys[t] = np.dot(W['oy'], os[t]) + W['by']
        ps[t] = sigmoid(ys[t])

        one = np.ones(ps[t].shape)
        ts[t] = np.reshape(np.array(targets[t]),(self.out_size,1))

        epsilon = 2**-23 # to prevent log(0)
        a = np.multiply(ts[t] , np.log2(ps[t] + epsilon))
        b = np.multiply(one - ts[t], np.log2(one-ps[t] + epsilon))
        loss = loss - (a + b)

        for idx in range(self.heads):
          # read from the memory
          rs[idx][t] = memory.read(mems[t-1],w_rs[idx][t])

          # write into the memory
          mems[t] = memory.write(mems[t-1],w_ws[idx][t],erases[idx][t],adds[idx][t])

      self.stats = [loss, ps, w_rs, w_ws, adds, erases]
      return np.sum(loss)
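
Note: the snippets in this listing call several helpers defined elsewhere in the project. Below is a minimal, hypothetical sketch of the elementwise activations and of the `memory` module, assuming the usual definitions and the shapes used above (an N x M memory matrix, length-N weightings, length-M erase/add vectors); the bodies are reconstructions, not the repo's actual code.

    import numpy as np

    # Assumed elementwise activations (standard definitions).
    def sigmoid(z):
      return 1.0 / (1.0 + np.exp(-z))

    def softplus(z):
      return np.log(1.0 + np.exp(z))

    def softmax(z):
      e = np.exp(z - np.max(z))
      return e / np.sum(e)

    # Hypothetical stand-in for the repo's memory module.
    def read(mem, w):
      # r = M^T w: a weighted sum of the N memory rows -> (M,)
      return np.dot(mem.T, np.reshape(w, (-1,)))

    def write(mem, w, erase, add):
      # NTM write: erase, then add.
      #   Mtilde = M * (1 - w e^T)
      #   M'     = Mtilde + w a^T
      W = np.reshape(w, (-1, 1))
      mem_tilde = mem * (np.ones(mem.shape) - np.dot(W, np.reshape(erase, (1, -1))))
      return mem_tilde + np.dot(W, np.reshape(add, (1, -1)))

With these definitions, the write step matches the KEEP = 1 - w e^T factor that manual_grads (examples #2 and #4) multiplies into dmemtilde when backpropagating through the erase.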
Example #2
File: ntm.py Project: vyraun/diffmem
        def manual_grads(params):
            """
      Compute the gradient of the loss WRT the parameters
      Ordering of the operations is reverse of that in fprop()
      """
            deltas = {}
            for key, val in params.iteritems():
                deltas[key] = np.zeros_like(val)

            [
                loss, mems, ps, ys, os, zos, hs, zhs, xs, rs, w_rs, w_ws, adds,
                erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws, zbeta_rs,
                zbeta_ws, zs_rs, zs_ws, wg_rs, wg_ws
            ] = self.stats
            dd = {}
            drs = {}
            dzh = {}
            dmem = {}  # might not need this, since we have dmemtilde
            dmemtilde = {}
            du_r = {}
            du_w = {}
            dwg_r = {}
            dwg_w = {}
            for t in reversed(xrange(len(targets))):

                dy = np.copy(ps[t])
                dy -= targets[t].T  # backprop into y

                deltas['oy'] += np.dot(dy, os[t].T)
                deltas['by'] += dy

                if t < len(targets) - 1:
                    # r[t] affects cost through zh[t+1] via Wrh
                    drs[t] = np.dot(self.W['rh'].T, dzh[t + 1])

                    # right now, mems[t] influences cost through rs[t+1], via w_rs[t+1]
                    dmem[t] = np.dot(w_rs[t + 1], drs[t + 1].reshape(
                        (self.M, 1)).T)
                    # and also through mems at next step
                    W = np.reshape(w_ws[t + 1], (w_ws[t + 1].shape[0], 1))
                    E = np.reshape(erases[t + 1], (erases[t + 1].shape[0], 1))
                    WTE = np.dot(W, E.T)
                    KEEP = np.ones(mems[0].shape) - WTE
                    dmem[t] += np.multiply(dmemtilde[t + 1], KEEP)
                    # and also through its influence on the content weighting next step
                    dmem[t] += du_r[t + 1] + du_w[t + 1]

                    dmemtilde[t] = dmem[t]

                    # erases[t] affects cost through mems[t], via w_ws[t]
                    derase = np.dot(
                        np.multiply(dmemtilde[t], -mems[t - 1]).T, w_ws[t])

                    # zerase affects just erases through a sigmoid
                    dzerase = derase * (erases[t] * (1 - erases[t]))

                    # adds[t] affects costs through mems[t], via w_ws
                    dadd = np.dot(dmem[t].T, w_ws[t])

                    # zadds affects just adds through a tanh
                    dzadd = dadd * (1 - adds[t] * adds[t])

                    # dbadds is just dzadds
                    deltas['badds'] += dzadd

                    deltas['oadds'] += np.dot(dzadd, os[t].T)

                    deltas['berases'] += dzerase

                    deltas['oerases'] += np.dot(dzerase, os[t].T)

                    # # read weights affect what is read, via what's in mems[t-1]
                    # dwc_r = np.dot(mems[t-1], drs[t])

                    # # write weights affect mem[t] through adding
                    # dwc_w = np.dot(dmem[t], adds[t])
                    # # they also affect memtilde[t] through erasing
                    # dwc_w += np.dot(np.multiply(dmemtilde[t], -mems[t-1]), erases[t])

                    dw_r = np.dot(mems[t - 1], drs[t])
                    dw_r += dwg_r[t + 1] * (1 - g_rs[t + 1])

                    # write weights affect mem[t] through adding
                    dw_w = np.dot(dmem[t], adds[t])
                    # they also affect memtilde[t] through erasing
                    dw_w += np.dot(np.multiply(dmemtilde[t], -mems[t - 1]),
                                   erases[t])
                    dw_w += dwg_w[t + 1] * (1 - g_ws[t + 1])

                    sgwr = np.zeros((self.N, self.N))
                    sgww = np.zeros((self.N, self.N))
                    for i in range(self.N):
                        sgwr[i, i] = softmax(zs_rs[t])[0]
                        sgwr[i, (i + 1) % self.N] = softmax(zs_rs[t])[2]
                        sgwr[i, (i - 1) % self.N] = softmax(zs_rs[t])[1]

                        sgww[i, i] = softmax(zs_ws[t])[0]
                        sgww[i, (i + 1) % self.N] = softmax(zs_ws[t])[2]
                        sgww[i, (i - 1) % self.N] = softmax(zs_ws[t])[1]

                    # for now, the shifted weights are the final weights
                    dws_r = dw_r
                    dws_w = dw_w

                    dwg_r[t] = np.dot(sgwr.T, dws_r)
                    dwg_w[t] = np.dot(sgww.T, dws_w)

                    dwc_r = dwg_r[t] * g_rs[t]
                    dwc_w = dwg_w[t] * g_ws[t]
                    """
          We need dw/dK
          now w has N elts and K has N elts
          and we want, for every elt of W, the grad of that elt w.r.t. each
          of the N elts of K. that gives us N * N things
          """
                    # first, we must build up the K values (should be taken from fprop)
                    K_rs = []
                    K_ws = []
                    for i in range(self.N):
                        K_rs.append(cosine_sim(mems[t - 1][i, :], k_rs[t]))
                        K_ws.append(cosine_sim(mems[t - 1][i, :], k_ws[t]))

                    # then, we populate the grads
                    dwdK_r = np.zeros((self.N, self.N))
                    dwdK_w = np.zeros((self.N, self.N))
                    # for every row in the memory
                    for i in range(self.N):
                        # for every element in the weighting
                        for j in range(self.N):
                            dwdK_r[i, j] += softmax_grads(K_rs, softplus(zbeta_rs[t]), i, j)
                            dwdK_w[i, j] += softmax_grads(K_ws, softplus(zbeta_ws[t]), i, j)

                    # compute dK for all i in N
                    # K is the evaluated cosine similarity for the i-th row of mem matrix
                    dK_r = np.zeros_like(w_rs[0])
                    dK_w = np.zeros_like(w_ws[0])

                    # for all i in N (for every row that we've simmed)
                    for i in range(self.N):
                        # for every j in N (for every elt of the weighting)
                        for j in range(self.N):
                            # specifically, dwdK_r will change, and for write as well
                            dK_r[i] += dwc_r[j] * dwdK_r[i, j]
                            dK_w[i] += dwc_w[j] * dwdK_w[i, j]
                    """
          dK_r_dk_rs is a list of N things
          each elt of the list corresponds to grads of K_idx
          w.r.t. the key k_t
          so it should be a length N list of M by 1 vectors
          """

                    dK_r_dk_rs = []
                    dK_r_dmem = []
                    for i in range(self.N):
                        # let k_rs be u, Mem[i] be v
                        u = np.reshape(k_rs[t], (self.M, ))
                        v = mems[t - 1][i, :]
                        dK_r_dk_rs.append(dKdu(u, v))
                        dK_r_dmem.append(dKdu(v, u))

                    dK_w_dk_ws = []
                    dK_w_dmem = []
                    for i in range(self.N):
                        # let k_ws be u, Mem[i] be v
                        u = np.reshape(k_ws[t], (self.M, ))
                        v = mems[t - 1][i, :]
                        dK_w_dk_ws.append(dKdu(u, v))
                        dK_w_dmem.append(dKdu(v, u))

                    # compute delta for keys
                    dk_r = np.zeros_like(k_rs[0])
                    dk_w = np.zeros_like(k_ws[0])
                    # for each of the M elts of dk_r
                    for i in range(self.M):
                        # for every one of the N Ks
                        for j in range(self.N):
                            # add delta K_r[j] * dK_r[j] / dk_r[i]
                            # add influence on through K_r[j]
                            dk_r[i] += dK_r[j] * dK_r_dk_rs[j][i]
                            dk_w[i] += dK_w[j] * dK_w_dk_ws[j][i]

                    # these represent influence of mem on next K
                    """
          Let's let du_r[t] represent the
          influence of mems[t-1] on the cost through the K values
          this is analogous to dk_w, but, k only every affects that
          whereas mems[t-1] will also affect what is read at time t+1
          and through memtilde at time t+1
          """
                    du_r[t] = np.zeros_like(mems[0])
                    du_w[t] = np.zeros_like(mems[0])
                    # for every row in mems[t-1]
                    for i in range(self.N):
                        # for every elt of this row (one of M)
                        for j in range(self.M):
                            du_r[t][i, j] = dK_r[i] * dK_r_dmem[i][j]
                            du_w[t][i, j] = dK_w[i] * dK_w_dmem[i][j]

                    # key values are activated as tanh
                    dzk_r = dk_r * (1 - k_rs[t] * k_rs[t])
                    dzk_w = dk_w * (1 - k_ws[t] * k_ws[t])

                    deltas['ok_r'] += np.dot(dzk_r, os[t].T)
                    deltas['ok_w'] += np.dot(dzk_w, os[t].T)

                    deltas['bk_r'] += dzk_r
                    deltas['bk_w'] += dzk_w

                    dg_r = np.dot(dwg_r[t].T, (wc_rs[t] - w_rs[t - 1]))
                    dg_w = np.dot(dwg_w[t].T, (wc_ws[t] - w_ws[t - 1]))

                    # compute dzg_r, dzg_w
                    dzg_r = dg_r * (g_rs[t] * (1 - g_rs[t]))
                    dzg_w = dg_w * (g_ws[t] * (1 - g_ws[t]))

                    deltas['og_r'] += np.dot(dzg_r, os[t].T)
                    deltas['og_w'] += np.dot(dzg_w, os[t].T)

                    deltas['bg_r'] += dzg_r
                    deltas['bg_w'] += dzg_w

                    # compute dbeta, which affects w_content through interaction with Ks

                    dwcdbeta_r = np.zeros_like(w_rs[0])
                    dwcdbeta_w = np.zeros_like(w_ws[0])
                    for i in range(self.N):
                        dwcdbeta_r[i] = beta_grads(K_rs, softplus(zbeta_rs[t]),
                                                   i)
                        dwcdbeta_w[i] = beta_grads(K_ws, softplus(zbeta_ws[t]),
                                                   i)

                    dbeta_r = np.zeros_like(zbeta_rs[0])
                    dbeta_w = np.zeros_like(zbeta_ws[0])
                    for i in range(self.N):
                        dbeta_r[0] += dwc_r[i] * dwcdbeta_r[i]
                        dbeta_w[0] += dwc_w[i] * dwcdbeta_w[i]

                    # beta is activated from zbeta by softplus, grad of which is sigmoid
                    dzbeta_r = dbeta_r * sigmoid(zbeta_rs[t])
                    dzbeta_w = dbeta_w * sigmoid(zbeta_ws[t])

                    deltas['obeta_r'] += np.dot(dzbeta_r, os[t].T)
                    deltas['obeta_w'] += np.dot(dzbeta_w, os[t].T)

                    deltas['bbeta_r'] += dzbeta_r
                    deltas['bbeta_w'] += dzbeta_w

                    sgsr = np.zeros((self.N, 3))
                    sgsw = np.zeros((self.N, 3))
                    for i in range(self.N):
                        sgsr[i, 1] = wg_rs[t][(i - 1) % self.N]
                        sgsr[i, 0] = wg_rs[t][i]
                        sgsr[i, 2] = wg_rs[t][(i + 1) % self.N]

                        sgsw[i, 1] = wg_ws[t][(i - 1) % self.N]
                        sgsw[i, 0] = wg_ws[t][i]
                        sgsw[i, 2] = wg_ws[t][(i + 1) % self.N]

                    ds_r = np.dot(sgsr.T, dws_r)
                    ds_w = np.dot(sgsw.T, dws_w)

                    shift_act_jac_r = np.zeros((3, 3))
                    shift_act_jac_w = np.zeros((3, 3))
                    bf = np.array([[1.0]])
                    for i in range(3):
                        for j in range(3):
                            shift_act_jac_r[i, j] = softmax_grads(
                                zs_rs[t], bf, i, j)
                            shift_act_jac_w[i, j] = softmax_grads(
                                zs_ws[t], bf, i, j)

                    dzs_r = np.dot(shift_act_jac_r.T, ds_r)
                    dzs_w = np.dot(shift_act_jac_w.T, ds_w)

                    deltas['os_r'] += np.dot(dzs_r, os[t].T)
                    deltas['os_w'] += np.dot(dzs_w, os[t].T)

                    deltas['bs_r'] += dzs_r
                    deltas['bs_w'] += dzs_w

                else:
                    drs[t] = np.zeros_like(rs[0])
                    dmemtilde[t] = np.zeros_like(mems[0])
                    du_r[t] = np.zeros_like(mems[0])
                    du_w[t] = np.zeros_like(mems[0])
                    dwg_r[t] = np.zeros_like(w_rs[0])
                    dwg_w[t] = np.zeros_like(w_ws[0])

                # o affects y through Woy
                do = np.dot(params['oy'].T, dy)
                if t < len(targets) - 1:
                    # and also zadd through Woadds
                    do += np.dot(params['oadds'].T, dzadd)
                    do += np.dot(params['oerases'].T, dzerase)
                    # and also through the keys
                    do += np.dot(params['ok_r'].T, dzk_r)
                    do += np.dot(params['ok_w'].T, dzk_w)
                    # and also through the interpolators
                    do += np.dot(params['og_r'].T, dzg_r)
                    do += np.dot(params['og_w'].T, dzg_w)
                    # and also through beta
                    do += np.dot(params['obeta_r'].T, dzbeta_r)
                    do += np.dot(params['obeta_w'].T, dzbeta_w)
                    # and also through the shift values
                    do += np.dot(params['os_r'].T, dzs_r)
                    do += np.dot(params['os_w'].T, dzs_w)

                # compute deriv w.r.t. pre-activation of o
                dzo = do * (1 - os[t] * os[t])

                deltas['ho'] += np.dot(dzo, hs[t].T)
                deltas['bo'] += dzo

                # compute hidden dh
                dh = np.dot(params['ho'].T, dzo)

                # compute deriv w.r.t. pre-activation of h
                dzh[t] = dh * (1 - hs[t] * hs[t])

                deltas['xh'] += np.dot(dzh[t], xs[t].T)
                deltas['bh'] += dzh[t]

                # Wrh affects zh via rs[t-1]
                deltas['rh'] += np.dot(dzh[t], rs[t - 1].reshape(
                    (self.M, 1)).T)

            return deltas
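
cosine_sim and dKdu are likewise external to this listing. Assuming cosine_sim is plain cosine similarity and dKdu(u, v) is its gradient with respect to the first argument (which is how the loops above consume it, calling dKdu(v, u) for the memory-side gradient), a sketch:

    import numpy as np

    def cosine_sim(u, v):
      # K(u, v) = (u . v) / (|u| |v|)
      return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

    def dKdu(u, v):
      # Gradient of cosine similarity w.r.t. its first argument:
      #   dK/du = v / (|u||v|) - (u . v) u / (|u|^3 |v|)
      nu, nv = np.linalg.norm(u), np.linalg.norm(v)
      return v / (nu * nv) - np.dot(u, v) * u / (nu ** 3 * nv)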
Example #3
    def fprop(params):
      """
      Forward pass of the NTM.
      """

      W = params # aliasing for brevity

      xs, zhs, hs, ys, ps, ts, zos, os = {}, {}, {}, {}, {}, {}, {}, {}

      def l():
        """
        Silly utility function that should be called in init.
        """
        return [{} for _ in xrange(self.heads)]

      rs = l()
      zk_rs = l()
      k_rs, beta_rs, g_rs, s_rs, gamma_rs = l(),l(),l(),l(),l()
      k_ws, beta_ws, g_ws, s_ws, gamma_ws = l(),l(),l(),l(),l()
      adds, erases = l(),l()
      w_ws, w_rs = l(),l() # write weights and read weights
      for idx in range(self.heads):
        rs[idx][-1] = self.W['rsInit' + str(idx)] # stores values read from memory
        w_ws[idx][-1] = softmax(self.W['w_wsInit' + str(idx)])
        w_rs[idx][-1] = softmax(self.W['w_rsInit' + str(idx)])

      mems = {} # the state of the memory at every timestep
      mems[-1] = self.W['memsInit']
      loss = 0

      for t in xrange(len(inputs)):

        xs[t] = np.reshape(np.array(inputs[t]),inputs[t].shape[::-1])

        rsum = 0
        for idx in range(self.heads):
          rsum = rsum + np.dot(W['rh' + str(idx)], np.reshape(rs[idx][t-1],(self.M,1)))
        zhs[t] = np.dot(W['xh'], xs[t]) + rsum + W['bh']
        hs[t] = np.tanh(zhs[t])

        zos[t] = np.dot(W['ho'], hs[t]) + W['bo']
        os[t] = np.tanh(zos[t])

        for idx in range(self.heads):
          # parameters to the read head
          zk_rs[idx][t] = np.dot(W['ok_r' + str(idx)],os[t]) + W['bk_r' + str(idx)]
          k_rs[idx][t] = np.tanh(zk_rs[idx][t])
          beta_rs[idx][t] = softplus(np.dot(W['obeta_r' + str(idx)],os[t])
                                     + W['bbeta_r' + str(idx)])
          g_rs[idx][t] = sigmoid(np.dot(W['og_r' + str(idx)],os[t]) + W['bg_r' + str(idx)])
          s_rs[idx][t] = softmax(np.dot(W['os_r' + str(idx)],os[t]) + W['bs_r' + str(idx)])
          gamma_rs[idx][t] = 1 + sigmoid(np.dot(W['ogamma_r' + str(idx)], os[t])
                                         + W['bgamma_r' + str(idx)])

          # parameters to the write head
          k_ws[idx][t] = np.tanh(np.dot(W['ok_w' + str(idx)],os[t]) + W['bk_w' + str(idx)])
          beta_ws[idx][t] = softplus(np.dot(W['obeta_w' + str(idx)], os[t])
                                     + W['bbeta_w' + str(idx)])
          g_ws[idx][t] = sigmoid(np.dot(W['og_w' + str(idx)],os[t]) + W['bg_w' + str(idx)])
          s_ws[idx][t] = softmax(np.dot(W['os_w' + str(idx)],os[t]) + W['bs_w' + str(idx)])
          gamma_ws[idx][t] = 1 + sigmoid(np.dot(W['ogamma_w' + str(idx)], os[t])
                                         + W['bgamma_w' + str(idx)])

          # the erase and add vectors
          # these are also parameters to the write head
          # but they describe "what" is to be written rather than "where"
          adds[idx][t] = np.tanh(np.dot(W['oadds' + str(idx)], os[t]) + W['badds' + str(idx)])
          erases[idx][t] = sigmoid(np.dot(W['oerases' + str(idx)], os[t]) + W['berases' + str(idx)])

          w_ws[idx][t] = addressing.create_weights(   k_ws[idx][t]
                                                    , beta_ws[idx][t]
                                                    , g_ws[idx][t]
                                                    , s_ws[idx][t]
                                                    , gamma_ws[idx][t]
                                                    , w_ws[idx][t-1]
                                                    , mems[t-1])

          w_rs[idx][t] = addressing.create_weights(   k_rs[idx][t]
                                                    , beta_rs[idx][t]
                                                    , g_rs[idx][t]
                                                    , s_rs[idx][t]
                                                    , gamma_rs[idx][t]
                                                    , w_rs[idx][t-1]
                                                    , mems[t-1])

        ys[t] = np.dot(W['oy'], os[t]) + W['by']
        ps[t] = sigmoid(ys[t])

        one = np.ones(ps[t].shape)
        ts[t] = np.reshape(np.array(targets[t]),(self.out_size,1))

        epsilon = 2**-23 # to prevent log(0)
        a = np.multiply(ts[t] , np.log2(ps[t] + epsilon))
        b = np.multiply(one - ts[t], np.log2(one-ps[t] + epsilon))
        loss = loss - (a + b)

        for idx in range(self.heads):
          # read from the memory
          rs[idx][t] = memory.read(mems[t-1],w_rs[idx][t])

          # write into the memory
          mems[t] = memory.write(mems[t-1],w_ws[idx][t],erases[idx][t],adds[idx][t])

      self.stats = [loss, mems, ps, ys, os, zos, hs, zhs, xs, rs, w_rs, w_ws, adds, erases]
      return np.sum(loss)
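
addressing.create_weights is presumably the standard NTM addressing pipeline (content weighting, interpolation, circular shift, sharpening), which is what the per-head parameters computed above parameterize. A hypothetical sketch under that assumption, reusing the cosine_sim helper sketched after example #2:

    import numpy as np

    def create_weights(k, beta, g, s, gamma, w_prev, mem):
      # the head emits beta, g, gamma as 1x1 arrays; treat them as scalars
      beta, g, gamma = [float(np.ravel(x)[0]) for x in (beta, g, gamma)]
      k, s, w_prev = [np.reshape(x, (-1,)) for x in (k, s, w_prev)]
      n = mem.shape[0]

      # content addressing: wc = softmax(beta * K(k, M_i))
      K = np.array([float(np.ravel(cosine_sim(mem[i, :], k))) for i in range(n)])
      e = np.exp(beta * K - np.max(beta * K))
      wc = e / np.sum(e)

      # interpolate with the previous weighting
      wg = g * wc + (1.0 - g) * w_prev

      # circular convolution with the 3-element shift distribution
      # (same convention as the sgs/sgw Jacobians in manual_grads)
      ws = np.array([s[0] * wg[i] + s[1] * wg[(i - 1) % n] + s[2] * wg[(i + 1) % n]
                     for i in range(n)])

      # sharpen and renormalize
      wp = ws ** gamma
      return np.reshape(wp / np.sum(wp), (n, 1))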
Example #4
File: ntm.py Project: DoctorTeeth/diffmem
    def manual_grads(params):
      """
      Compute the gradient of the loss WRT the parameters
      Ordering of the operations is reverse of that in fprop()
      """
      deltas = {}
      for key, val in params.iteritems():
        deltas[key] = np.zeros_like(val)

      [loss, mems, ps, ys, os, zos, hs, zhs, xs, rs, w_rs,
       w_ws, adds, erases, k_rs, k_ws, g_rs, g_ws, wc_rs, wc_ws,
       zbeta_rs, zbeta_ws, zs_rs, zs_ws, wg_rs, wg_ws] = self.stats
      dd = {}
      drs = {}
      dzh = {}
      dmem = {} # might not need this, since we have dmemtilde
      dmemtilde = {}
      du_r = {}
      du_w = {}
      dwg_r = {}
      dwg_w = {}
      for t in reversed(xrange(len(targets))):

        dy = np.copy(ps[t])
        dy -= targets[t].T # backprop into y

        deltas['oy'] += np.dot(dy, os[t].T)
        deltas['by'] += dy

        if t < len(targets) - 1:
          # r[t] affects cost through zh[t+1] via Wrh
          drs[t] = np.dot(self.W['rh'].T, dzh[t + 1])

          # right now, mems[t] influences cost through rs[t+1], via w_rs[t+1]
          dmem[t] = np.dot( w_rs[t + 1], drs[t + 1].reshape((self.M,1)).T )
          # and also through mems at next step
          W = np.reshape(w_ws[t+1], (w_ws[t+1].shape[0], 1))
          E = np.reshape(erases[t+1], (erases[t+1].shape[0], 1))
          WTE = np.dot(W, E.T)
          KEEP = np.ones(mems[0].shape) - WTE
          dmem[t] += np.multiply(dmemtilde[t+1], KEEP)
          # and also through its influence on the content weighting next step
          dmem[t] += du_r[t+1] + du_w[t+1]

          dmemtilde[t] = dmem[t]

          # erases[t] affects cost through mems[t], via w_ws[t]
          derase = np.dot(np.multiply(dmemtilde[t], -mems[t-1]).T, w_ws[t])

          # zerase affects just erases through a sigmoid
          dzerase = derase * (erases[t] * (1 - erases[t]))

          # adds[t] affects costs through mems[t], via w_ws
          dadd = np.dot(dmem[t].T, w_ws[t])

          # zadds affects just adds through a tanh
          dzadd = dadd * (1 - adds[t] * adds[t])

          # dbadds is just dzadds
          deltas['badds'] += dzadd

          deltas['oadds'] += np.dot(dzadd, os[t].T)

          deltas['berases'] += dzerase

          deltas['oerases'] += np.dot(dzerase, os[t].T)

          # # read weights affect what is read, via what's in mems[t-1]
          # dwc_r = np.dot(mems[t-1], drs[t])

          # # write weights affect mem[t] through adding
          # dwc_w = np.dot(dmem[t], adds[t])
          # # they also affect memtilde[t] through erasing
          # dwc_w += np.dot(np.multiply(dmemtilde[t], -mems[t-1]), erases[t])

          dw_r = np.dot(mems[t-1], drs[t])
          dw_r += dwg_r[t+1] * (1 - g_rs[t+1])

          # write weights affect mem[t] through adding
          dw_w = np.dot(dmem[t], adds[t])
          # they also affect memtilde[t] through erasing
          dw_w += np.dot(np.multiply(dmemtilde[t], -mems[t-1]), erases[t])
          dw_w += dwg_w[t+1] * (1 - g_ws[t+1])

          sgwr = np.zeros((self.N, self.N))
          sgww = np.zeros((self.N, self.N))
          for i in range(self.N):
            sgwr[i,i] = softmax(zs_rs[t])[0]
            sgwr[i,(i+1) % self.N] = softmax(zs_rs[t])[2]
            sgwr[i,(i-1) % self.N] = softmax(zs_rs[t])[1]

            sgww[i,i] = softmax(zs_ws[t])[0]
            sgww[i,(i+1) % self.N] = softmax(zs_ws[t])[2]
            sgww[i,(i-1) % self.N] = softmax(zs_ws[t])[1]

          # for now, the shifted weights are the final weights
          dws_r = dw_r
          dws_w = dw_w

          dwg_r[t] = np.dot(sgwr.T, dws_r)
          dwg_w[t] = np.dot(sgww.T, dws_w)

          dwc_r = dwg_r[t] * g_rs[t]
          dwc_w = dwg_w[t] * g_ws[t]


          """
          We need dw/dK
          now w has N elts and K has N elts
          and we want, for every elt of W, the grad of that elt w.r.t. each
          of the N elts of K. that gives us N * N things
          """
          # first, we must build up the K values (should be taken from fprop)
          K_rs = []
          K_ws = []
          for i in range(self.N):
            K_rs.append(cosine_sim(mems[t-1][i, :], k_rs[t]))
            K_ws.append(cosine_sim(mems[t-1][i, :], k_ws[t]))

          # then, we populate the grads
          dwdK_r = np.zeros((self.N, self.N))
          dwdK_w = np.zeros((self.N, self.N))
          # for every row in the memory
          for i in range(self.N):
            # for every element in the weighting
            for j in range(self.N):
              dwdK_r[i,j] += softmax_grads(K_rs, softplus(zbeta_rs[t]), i, j)
              dwdK_w[i,j] += softmax_grads(K_ws, softplus(zbeta_ws[t]), i, j)

          # compute dK for all i in N
          # K is the evaluated cosine similarity for the i-th row of mem matrix
          dK_r = np.zeros_like(w_rs[0])
          dK_w = np.zeros_like(w_ws[0])

          # for all i in N (for every row that we've simmed)
          for i in range(self.N):
            # for every j in N (for every elt of the weighting)
            for j in range(self.N):
              # specifically, dwdK_r will change, and for write as well
              dK_r[i] += dwc_r[j] * dwdK_r[i,j] 
              dK_w[i] += dwc_w[j] * dwdK_w[i,j]

          """
          dK_r_dk_rs is a list of N things
          each elt of the list corresponds to grads of K_idx
          w.r.t. the key k_t
          so it should be a length N list of M by 1 vectors
          """

          dK_r_dk_rs = []
          dK_r_dmem = []
          for i in range(self.N):
            # let k_rs be u, Mem[i] be v
            u = np.reshape(k_rs[t], (self.M,))
            v = mems[t-1][i, :]
            dK_r_dk_rs.append( dKdu(u,v) )
            dK_r_dmem.append( dKdu(v,u))

          dK_w_dk_ws = []
          dK_w_dmem = []
          for i in range(self.N):
            # let k_ws be u, Mem[i] be v
            u = np.reshape(k_ws[t], (self.M,))
            v = mems[t-1][i, :]
            dK_w_dk_ws.append( dKdu(u,v) )
            dK_w_dmem.append( dKdu(v,u))

          # compute delta for keys
          dk_r = np.zeros_like(k_rs[0])
          dk_w = np.zeros_like(k_ws[0])
          # for each of the M elts of dk_r
          for i in range(self.M):
            # for every one of the N Ks
            for j in range(self.N):
              # add delta K_r[j] * dK_r[j] / dk_r[i]
              # add influence on through K_r[j]
              dk_r[i] += dK_r[j] * dK_r_dk_rs[j][i]
              dk_w[i] += dK_w[j] * dK_w_dk_ws[j][i]

          # these represent influence of mem on next K
          """
          Let's let du_r[t] represent the
          influence of mems[t-1] on the cost through the K values
          this is analogous to dk_w, but k only ever affects that
          whereas mems[t-1] will also affect what is read at time t+1
          and through memtilde at time t+1
          """
          du_r[t] = np.zeros_like(mems[0])
          du_w[t] = np.zeros_like(mems[0])
          # for every row in mems[t-1]
          for i in range(self.N):
            # for every elt of this row (one of M)
            for j in range(self.M):
              du_r[t][i,j] = dK_r[i] * dK_r_dmem[i][j]
              du_w[t][i,j] = dK_w[i] * dK_w_dmem[i][j]

          # key values are activated as tanh
          dzk_r = dk_r * (1 - k_rs[t] * k_rs[t])
          dzk_w = dk_w * (1 - k_ws[t] * k_ws[t])

          deltas['ok_r'] += np.dot(dzk_r, os[t].T)
          deltas['ok_w'] += np.dot(dzk_w, os[t].T)

          deltas['bk_r'] += dzk_r
          deltas['bk_w'] += dzk_w

          dg_r = np.dot(dwg_r[t].T, (wc_rs[t] - w_rs[t-1]) )
          dg_w = np.dot(dwg_w[t].T, (wc_ws[t] - w_ws[t-1]) )

          # compute dzg_r, dzg_w
          dzg_r = dg_r * (g_rs[t] * (1 - g_rs[t]))
          dzg_w = dg_w * (g_ws[t] * (1 - g_ws[t]))

          deltas['og_r'] += np.dot(dzg_r, os[t].T)
          deltas['og_w'] += np.dot(dzg_w, os[t].T)

          deltas['bg_r'] += dzg_r
          deltas['bg_w'] += dzg_w

          # compute dbeta, which affects w_content through interaction with Ks

          dwcdbeta_r = np.zeros_like(w_rs[0])
          dwcdbeta_w = np.zeros_like(w_ws[0])
          for i in range(self.N):
            dwcdbeta_r[i] = beta_grads(K_rs, softplus(zbeta_rs[t]), i)
            dwcdbeta_w[i] = beta_grads(K_ws, softplus(zbeta_ws[t]), i)

          dbeta_r = np.zeros_like(zbeta_rs[0])
          dbeta_w = np.zeros_like(zbeta_ws[0])
          for i in range(self.N):
            dbeta_r[0] += dwc_r[i] * dwcdbeta_r[i]
            dbeta_w[0] += dwc_w[i] * dwcdbeta_w[i]

          # beta is activated from zbeta by softplus, grad of which is sigmoid
          dzbeta_r = dbeta_r * sigmoid(zbeta_rs[t])
          dzbeta_w = dbeta_w * sigmoid(zbeta_ws[t])

          deltas['obeta_r'] += np.dot(dzbeta_r, os[t].T)
          deltas['obeta_w'] += np.dot(dzbeta_w, os[t].T)

          deltas['bbeta_r'] += dzbeta_r
          deltas['bbeta_w'] += dzbeta_w

          sgsr = np.zeros((self.N, 3))
          sgsw = np.zeros((self.N, 3))
          for i in range(self.N):
            sgsr[i,1] = wg_rs[t][(i - 1) % self.N]
            sgsr[i,0] = wg_rs[t][i]
            sgsr[i,2] = wg_rs[t][(i + 1) % self.N]

            sgsw[i,1] = wg_ws[t][(i - 1) % self.N]
            sgsw[i,0] = wg_ws[t][i]
            sgsw[i,2] = wg_ws[t][(i + 1) % self.N]

          ds_r = np.dot(sgsr.T, dws_r)
          ds_w = np.dot(sgsw.T, dws_w)

          shift_act_jac_r = np.zeros((3,3))
          shift_act_jac_w = np.zeros((3,3))
          bf = np.array([[1.0]])
          for i in range(3):
            for j in range(3):
              shift_act_jac_r[i,j] = softmax_grads(zs_rs[t], bf, i, j)
              shift_act_jac_w[i,j] = softmax_grads(zs_ws[t], bf, i, j)

          dzs_r = np.dot(shift_act_jac_r.T, ds_r)
          dzs_w = np.dot(shift_act_jac_w.T, ds_w)

          deltas['os_r'] += np.dot(dzs_r, os[t].T)
          deltas['os_w'] += np.dot(dzs_w, os[t].T)

          deltas['bs_r'] += dzs_r
          deltas['bs_w'] += dzs_w

        else:
          drs[t] = np.zeros_like(rs[0])
          dmemtilde[t] = np.zeros_like(mems[0])
          du_r[t] = np.zeros_like(mems[0])
          du_w[t] = np.zeros_like(mems[0])
          dwg_r[t] = np.zeros_like(w_rs[0])
          dwg_w[t] = np.zeros_like(w_ws[0])

        # o affects y through Woy
        do = np.dot(params['oy'].T, dy)
        if t < len(targets) - 1:
          # and also zadd through Woadds
          do += np.dot(params['oadds'].T, dzadd)
          do += np.dot(params['oerases'].T, dzerase)
          # and also through the keys
          do += np.dot(params['ok_r'].T, dzk_r)
          do += np.dot(params['ok_w'].T, dzk_w)
          # and also through the interpolators
          do += np.dot(params['og_r'].T, dzg_r)
          do += np.dot(params['og_w'].T, dzg_w)
          # and also through beta
          do += np.dot(params['obeta_r'].T, dzbeta_r)
          do += np.dot(params['obeta_w'].T, dzbeta_w)
          # and also through the shift values
          do += np.dot(params['os_r'].T, dzs_r)
          do += np.dot(params['os_w'].T, dzs_w)


        # compute deriv w.r.t. pre-activation of o
        dzo = do * (1 - os[t] * os[t])

        deltas['ho'] += np.dot(dzo, hs[t].T)
        deltas['bo'] += dzo

        # compute hidden dh
        dh = np.dot(params['ho'].T, dzo)

        # compute deriv w.r.t. pre-activation of h
        dzh[t] = dh * (1 - hs[t] * hs[t])

        deltas['xh'] += np.dot(dzh[t], xs[t].T)
        deltas['bh'] += dzh[t]

        # Wrh affects zh via rs[t-1]
        deltas['rh'] += np.dot(dzh[t], rs[t-1].reshape((self.M, 1)).T)

      return deltas
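
Finally, softmax_grads and beta_grads supply the Jacobian entries of the content weighting wc = softmax(beta * K). The calls above read softmax_grads(K, beta, i, j) as d wc[j] / d K[i] (the softmax Jacobian is symmetric in i and j, which is also why the same function doubles as the 3x3 shift-activation Jacobian with beta = 1) and beta_grads(K, beta, i) as d wc[i] / d beta. A sketch under those assumed definitions:

    import numpy as np

    def _softmax(z):
      e = np.exp(z - np.max(z))
      return e / np.sum(e)

    def softmax_grads(z, beta, i, j):
      # For s = softmax(beta * z): d s[i] / d z[j] = beta * s[i] * (1{i==j} - s[j])
      beta = float(np.ravel(beta)[0])
      s = _softmax(beta * np.reshape(np.asarray(z, dtype=float), (-1,)))
      return beta * s[i] * ((1.0 if i == j else 0.0) - s[j])

    def beta_grads(z, beta, i):
      # For s = softmax(beta * z): d s[i] / d beta = s[i] * (z[i] - sum_j s[j] * z[j])
      beta = float(np.ravel(beta)[0])
      zf = np.reshape(np.asarray(z, dtype=float), (-1,))
      s = _softmax(beta * zf)
      return s[i] * (zf[i] - np.dot(s, zf))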