Пример #1
0
class InputlessLstmBlock(object):
    """
    A long short-term memory (LSTM) block that has no external input.

    Computes one time step from the previous states only::

        zifo = tanh_sigm(h[t-1] * R + b)
        c[t] = i[t] .* z[t] + f[t] .* c[t-1]
        h[t] = o[t] .* tanh(c[t])

    where ``zifo`` holds the four gate blocks along its columns in the order
    z (cell input), i (input gate), f (forget gate), o (output gate).

    Every argument except ``grad_clipping``, ``mask`` and ``device_id`` is a
    connector-like object exposing ``bpropagable`` and ``register_usage``;
    ``mask`` only needs ``register_usage``.

    Parameters
    ----------
    R
        Recurrent weights. ``R.nrows`` defines the hidden dimension ``dim``
        and ``h[t-1] * R`` must produce ``4 * dim`` columns.
    b
        Bias added to the gate pre-activations (presumably a single row of
        width ``4 * dim`` broadcast over the batch -- TODO confirm).
    grad_clipping
        If truthy, the gradient w.r.t. the gate pre-activations is clipped to
        ``[-grad_clipping, grad_clipping]`` during ``bprop``.
    mask
        Optional (``None`` disables masking). When given, the new states are
        blended with the previous ones:
        ``s[t] = mask .* s[t] + (1 - mask) .* s[t-1]``.
    prev_c
        Cell state of the previous step; its row count defines the batch
        size.
    prev_h
        Hidden state of the previous step.
    device_id : int
        Defines the device's id on which the computation will take place.

    Attributes
    ----------
    c, h : Connector
        New cell and hidden states produced by ``fprop``.
    """
    def __init__(self,
                 R,
                 b,
                 grad_clipping,
                 mask,
                 prev_c,
                 prev_h,
                 device_id=None):
        self.f_context = Context(device_id)
        # Use the device actually chosen by the forward context, since the
        # caller may have passed device_id=None.
        device_id = self.f_context.device_id
        # For bpropagable inputs register_usage also returns the gradient
        # accumulator (dL_d*) that bprop adds into.
        if R.bpropagable:
            self.R, self.dL_dR = R.register_usage(device_id, device_id)
            self.R_b_context = Context(device_id)
        else:
            self.R = R.register_usage(device_id)
        if b.bpropagable:
            self.b, self.dL_db = b.register_usage(device_id, device_id)
            self.b_b_context = Context(device_id)
        else:
            self.b = b.register_usage(device_id)
        self.grad_clipping = grad_clipping
        # Explicit None check: a connector's truthiness is not a reliable
        # signal of whether a mask was supplied.
        if mask is not None:
            self.mask = mask.register_usage(device_id)
        if prev_c.bpropagable:
            self.prev_c, self.dL_dprev_c = prev_c.register_usage(
                device_id, device_id)
        else:
            self.prev_c = prev_c.register_usage(device_id)
        if prev_h.bpropagable:
            self.prev_h, self.dL_dprev_h = prev_h.register_usage(
                device_id, device_id)
        else:
            self.prev_h = prev_h.register_usage(device_id)
        # b is included so that a trainable bias alone still triggers bprop;
        # otherwise dL_db would silently never be accumulated.
        self.learning = R.bpropagable or b.bpropagable or \
                        prev_c.bpropagable or prev_h.bpropagable
        if self.learning:
            self.b_context = Context(device_id)

        dim = self.R.nrows
        batch_size = self.prev_c.nrows

        # zifo stores the four gates side by side; z, i, f, o are column
        # views into it.
        self.zifo = Matrix.empty(batch_size, 4 * dim, device_id=device_id)
        self.z = self.zifo[:, 0 * dim:1 * dim]
        self.i = self.zifo[:, 1 * dim:2 * dim]
        self.f = self.zifo[:, 2 * dim:3 * dim]
        self.o = self.zifo[:, 3 * dim:4 * dim]
        self.c = Matrix.empty_like(self.prev_c, device_id)
        self.c = Connector(self.c, device_id if self.learning else None)
        self.tanh_c = Matrix.empty_like(self.c, device_id)
        self.h = Matrix.empty_like(self.c, device_id)
        self.h = Connector(self.h, device_id if self.learning else None)

        if self.learning:
            # The local derivatives d gate / d pre-activation are later
            # overwritten in place with dL/d pre-activation, so both names
            # alias the same buffers.
            self._dzifo_dpre_zifo = Matrix.empty_like(self.zifo)
            self.dz_dpre_z = self._dzifo_dpre_zifo[:, 0 * dim:1 * dim]
            self.di_dpre_i = self._dzifo_dpre_zifo[:, 1 * dim:2 * dim]
            self.df_dpre_f = self._dzifo_dpre_zifo[:, 2 * dim:3 * dim]
            self.do_dpre_o = self._dzifo_dpre_zifo[:, 3 * dim:4 * dim]
            self.dL_dpre_zifo = self._dzifo_dpre_zifo
            self.dL_dpre_z = self.dz_dpre_z
            self.dL_dpre_i = self.di_dpre_i
            self.dL_dpre_f = self.df_dpre_f
            self.dL_dpre_o = self.do_dpre_o
            self._dtanh_c_dc = Matrix.empty_like(self.c)

    @property
    def dzifo_dpre_zifo(self):
        """Buffer for the gate derivatives, or None when not learning."""
        if self.learning:
            return self._dzifo_dpre_zifo
        return None

    @property
    def dtanh_c_dc(self):
        """Buffer for d tanh(c) / dc, or None when not learning."""
        if self.learning:
            return self._dtanh_c_dc
        return None

    def fprop(self):
        """Compute c[t] and h[t] in f_context and forward them."""
        # zifo = tanh_sigm(h[t-1] * R + b)
        # tanh_sigm presumably applies tanh to the z block and a sigmoid to
        # i, f, o, also writing the local derivatives when learning -- TODO
        # confirm.
        self.zifo.assign_dot(self.f_context, self.prev_h, self.R)
        self.zifo.add(self.f_context, self.b)
        self.zifo.tanh_sigm(self.f_context,
                            self.zifo,
                            self.dzifo_dpre_zifo,
                            axis=1)

        # c[t] = i[t] .* z[t] + f[t] .* c[t-1]
        # h[t] = o[t] .* tanh(c[t])
        self.c.assign_sum_hprod(self.f_context, self.i, self.z, self.f,
                                self.prev_c)
        self.c.tanh(self.f_context, self.tanh_c, self.dtanh_c_dc)
        self.h.assign_hprod(self.f_context, self.o, self.tanh_c)
        if hasattr(self, 'mask'):
            # s[t] = mask .* s[t] + (1 - mask) .* s[t-1]
            self.c.assign_masked_addition(self.f_context, self.mask, self.c,
                                          self.prev_c)
            self.h.assign_masked_addition(self.f_context, self.mask, self.h,
                                          self.prev_h)
        self.c.fprop()
        self.h.fprop()

    def bprop(self):
        """Accumulate gradients into every registered dL_d* buffer."""
        if not self.learning:
            return
        # The gate views (z, i, f, o, d*_dpre_*) and dtanh_c_dc were written
        # in f_context and the views do not carry a proper
        # last_modif_context, so b_context must explicitly wait for the
        # forward pass before reading any of them.
        self.b_context.wait(self.f_context)
        dL_dc = self.c.backward_matrix
        dL_dh = self.h.backward_matrix
        if hasattr(self, 'mask'):
            # dL/ds[t-1] = (1 - mask) .* dL/ds[t]
            # dL/ds[t] = mask .* dL/ds[t]
            if hasattr(self, 'dL_dprev_c'):
                self.dL_dprev_c.add_hprod_one_minus_mask(
                    self.b_context, self.mask, dL_dc)
            dL_dc.hprod(self.b_context, self.mask)
            if hasattr(self, 'dL_dprev_h'):
                self.dL_dprev_h.add_hprod_one_minus_mask(
                    self.b_context, self.mask, dL_dh)
            dL_dh.hprod(self.b_context, self.mask)
        # dL/dc[t] = dL[t+1]/dc[t] + dL/dh[t] .* o[t] .* dtanh(c[t])/dc[t]
        dL_dc.add_hprod(self.b_context, dL_dh, self.o, self.dtanh_c_dc)

        # dL/dpre_o[t] = dL/dh[t] .* tanh(c[t]) .* do[t]/dpre_o[t]
        # dL/dpre_f[t] = dL/dc[t] .* c[t-1] .* df[t]/dpre_f[t]
        # dL/dpre_i[t] = dL/dc[t] .* z[t] .* di[t]/dpre_i[t]
        # dL/dpre_z[t] = dL/dc[t] .* i[t] .* dz[t]/dpre_z[t]
        self.dL_dpre_o.assign_hprod(self.b_context, dL_dh, self.tanh_c,
                                    self.do_dpre_o)
        self.dL_dpre_f.assign_hprod(self.b_context, dL_dc, self.prev_c,
                                    self.df_dpre_f)
        self.dL_dpre_i.assign_hprod(self.b_context, dL_dc, self.z,
                                    self.di_dpre_i)
        self.dL_dpre_z.assign_hprod(self.b_context, dL_dc, self.i,
                                    self.dz_dpre_z)
        # The slices were written through views, so mark the parent buffer
        # as modified in b_context for consumers in other contexts.
        self.dL_dpre_zifo.last_modif_context = self.b_context

        if self.grad_clipping:
            self.dL_dpre_zifo.clip(self.b_context, -self.grad_clipping,
                                   self.grad_clipping)

        if hasattr(self, 'dL_dR'):
            # dL_dR += h[t-1].T * dL/dpre_zifo[t]
            self.dL_dR.add_dot(self.R_b_context, self.prev_h,
                               self.dL_dpre_zifo, 'T')
        if hasattr(self, 'dL_db'):
            # dL_db += sum(dL/dpre_zifo[t], axis=0)
            self.dL_db.add_repeat_derivative(self.b_b_context,
                                             self.dL_dpre_zifo,
                                             self.dL_dpre_zifo.nrows,
                                             axis=0)
        if hasattr(self, 'dL_dprev_c'):
            # dL/dc[t-1] = f[t] .* dL/dc[t]
            self.dL_dprev_c.add_hprod(self.b_context, self.f, dL_dc)
        if hasattr(self, 'dL_dprev_h'):
            # dL/dh[t-1] = dL/dpre_zifo[t] * R.T
            self.dL_dprev_h.add_dot(self.b_context, self.dL_dpre_zifo, self.R,
                                    'N', 'T')
Пример #2
0
class LstmBlock(object):
    """
    A long short-term memory (LSTM) block.

    Computes one time step::

        zifo = tanh_sigm(x[t] * W + h[t-1] * R + b)
        c[t] = i[t] .* z[t] + f[t] .* c[t-1]
        h[t] = o[t] .* tanh(c[t])

    where ``zifo`` holds the four gate blocks along its columns in the order
    z (cell input), i (input gate), f (forget gate), o (output gate).

    Every argument except ``grad_clipping``, ``mask`` and ``device_id`` is a
    connector-like object exposing ``bpropagable`` and ``register_usage``;
    ``mask`` only needs ``register_usage``.

    Parameters
    ----------
    W
        Input weights; ``x[t] * W`` must produce ``4 * dim`` columns.
    R
        Recurrent weights. ``R.nrows`` defines the hidden dimension ``dim``
        and ``h[t-1] * R`` must produce ``4 * dim`` columns.
    b
        Bias added to the gate pre-activations (presumably a single row of
        width ``4 * dim`` broadcast over the batch -- TODO confirm).
    grad_clipping
        If truthy, the gradient w.r.t. the gate pre-activations is clipped to
        ``[-grad_clipping, grad_clipping]`` during ``bprop``.
    x
        Input of the current step; its row count defines the batch size.
    mask
        Optional (``None`` disables masking). When given, the new states are
        blended with the previous ones:
        ``s[t] = mask .* s[t] + (1 - mask) .* s[t-1]``.
    prev_c
        Cell state of the previous step.
    prev_h
        Hidden state of the previous step.
    device_id : int
        Defines the device's id on which the computation will take place.

    Attributes
    ----------
    c, h : Connector
        New cell and hidden states produced by ``fprop``.
    """
    def __init__(self, W, R, b, grad_clipping, x, mask, prev_c, prev_h,
                 device_id=None):
        self.f_context = Context(device_id)
        # Use the device actually chosen by the forward context, since the
        # caller may have passed device_id=None.
        device_id = self.f_context.device_id
        # For bpropagable inputs register_usage also returns the gradient
        # accumulator (dL_d*) that bprop adds into, and each accumulator gets
        # its own backward context.
        if W.bpropagable:
            self.W, self.dL_dW = W.register_usage(device_id, device_id)
            self.W_b_context = Context(device_id)
        else:
            self.W = W.register_usage(device_id)
        if R.bpropagable:
            self.R, self.dL_dR = R.register_usage(device_id, device_id)
            self.R_b_context = Context(device_id)
        else:
            self.R = R.register_usage(device_id)
        if b.bpropagable:
            self.b, self.dL_db = b.register_usage(device_id, device_id)
            self.b_b_context = Context(device_id)
        else:
            self.b = b.register_usage(device_id)
        self.grad_clipping = grad_clipping
        if x.bpropagable:
            self.x, self.dL_dx = x.register_usage(device_id, device_id)
            self.x_b_context = Context(device_id)
        else:
            self.x = x.register_usage(device_id)
        # Explicit None check: a connector's truthiness is not a reliable
        # signal of whether a mask was supplied.
        if mask is not None:
            self.mask = mask.register_usage(device_id)
        if prev_c.bpropagable:
            self.prev_c, self.dL_dprev_c = prev_c.register_usage(device_id,
                                                                 device_id)
            self.prev_c_b_context = Context(device_id)
        else:
            self.prev_c = prev_c.register_usage(device_id)
        if prev_h.bpropagable:
            self.prev_h, self.dL_dprev_h = prev_h.register_usage(device_id,
                                                                 device_id)
            self.prev_h_b_context = Context(device_id)
        else:
            self.prev_h = prev_h.register_usage(device_id)
        # b is included so that a trainable bias alone still triggers bprop;
        # otherwise dL_db would never be accumulated.
        self.learning = W.bpropagable or R.bpropagable or b.bpropagable or \
                        x.bpropagable or prev_c.bpropagable or \
                        prev_h.bpropagable
        if self.learning:
            self.b_context = Context(device_id)

        dim = self.R.nrows
        batch_size = self.x.nrows

        # zifo stores the four gates side by side; z, i, f, o are column
        # views into it.
        self.zifo = Matrix.empty(batch_size, 4 * dim, device_id=device_id)
        self.z = self.zifo[:, 0*dim:1*dim]
        self.i = self.zifo[:, 1*dim:2*dim]
        self.f = self.zifo[:, 2*dim:3*dim]
        self.o = self.zifo[:, 3*dim:4*dim]
        self.c = Matrix.empty_like(self.prev_c, device_id)
        self.c = Connector(self.c, device_id if self.learning else None)
        self.tanh_c = Matrix.empty_like(self.c, device_id)
        self.h = Matrix.empty_like(self.c, device_id)
        self.h = Connector(self.h, device_id if self.learning else None)

        if self.learning:
            # The local derivatives d gate / d pre-activation are later
            # overwritten in place with dL/d pre-activation, so both names
            # alias the same buffers.
            self._dzifo_dpre_zifo = Matrix.empty_like(self.zifo)
            self.dz_dpre_z = self._dzifo_dpre_zifo[:, 0*dim:1*dim]
            self.di_dpre_i = self._dzifo_dpre_zifo[:, 1*dim:2*dim]
            self.df_dpre_f = self._dzifo_dpre_zifo[:, 2*dim:3*dim]
            self.do_dpre_o = self._dzifo_dpre_zifo[:, 3*dim:4*dim]
            self.dL_dpre_zifo = self._dzifo_dpre_zifo
            self.dL_dpre_z = self.dz_dpre_z
            self.dL_dpre_i = self.di_dpre_i
            self.dL_dpre_f = self.df_dpre_f
            self.dL_dpre_o = self.do_dpre_o
            self._dtanh_c_dc = Matrix.empty_like(self.c)

    @property
    def dzifo_dpre_zifo(self):
        """Buffer for the gate derivatives, or None when not learning."""
        if self.learning:
            return self._dzifo_dpre_zifo
        return None

    @property
    def dtanh_c_dc(self):
        """Buffer for d tanh(c) / dc, or None when not learning."""
        if self.learning:
            return self._dtanh_c_dc
        return None

    def fprop(self):
        """Compute c[t] and h[t] in f_context and forward them."""
        # zifo = tanh_sigm(x[t] * W + h[t-1] * R + b)
        # tanh_sigm presumably applies tanh to the z block and a sigmoid to
        # i, f, o, also writing the local derivatives when learning -- TODO
        # confirm.
        self.zifo.assign_dot(self.f_context, self.x, self.W)
        self.zifo.add_dot(self.f_context, self.prev_h, self.R)
        self.zifo.add(self.f_context, self.b)
        self.zifo.tanh_sigm(self.f_context, self.zifo, self.dzifo_dpre_zifo,
                            axis=1)

        # c[t] = i[t] .* z[t] + f[t] .* c[t-1]
        # h[t] = o[t] .* tanh(c[t])
        self.c.assign_sum_hprod(self.f_context, self.i, self.z, self.f,
                                self.prev_c)
        self.c.tanh(self.f_context, self.tanh_c, self.dtanh_c_dc)
        self.h.assign_hprod(self.f_context, self.o, self.tanh_c)
        if hasattr(self, 'mask'):
            # s[t] = mask .* s[t] + (1 - mask) .* s[t-1]
            self.c.assign_masked_addition(self.f_context, self.mask, self.c,
                                          self.prev_c)
            self.h.assign_masked_addition(self.f_context, self.mask, self.h,
                                          self.prev_h)
        self.c.fprop()
        self.h.fprop()

    def bprop(self):
        """Accumulate gradients into every registered dL_d* buffer."""
        # Without learning there is no b_context and no derivative buffers.
        if not self.learning:
            return
        # self.dzifo_dpre_zifo (and the gate views o, f, ...) were calculated
        # in self.f_context, and the views do not have a proper
        # last_modif_context, so b_context explicitly waits for the forward
        # pass before ANY read of them (including o and dtanh_c_dc below).
        # Doing it first also makes every later modification of dL_dc in
        # b_context transitively ordered after the forward pass.
        self.b_context.wait(self.f_context)
        dL_dc = self.c.backward_matrix
        dL_dh = self.h.backward_matrix
        if hasattr(self, 'mask'):
            # dL/ds[t-1] = (1 - mask) .* dL/ds[t]
            # dL/ds[t] = mask .* dL/ds[t]
            # The in-place masking of dL/ds[t] runs in the same context that
            # read it for dL/ds[t-1], keeping the read ordered before the
            # write. prev_*_b_context exists only when prev_* is
            # bpropagable, so fall back to b_context otherwise.
            c_context = getattr(self, 'prev_c_b_context', self.b_context)
            if hasattr(self, 'dL_dprev_c'):
                self.dL_dprev_c.add_hprod_one_minus_mask(c_context, self.mask,
                                                         dL_dc)
            dL_dc.hprod(c_context, self.mask)
            h_context = getattr(self, 'prev_h_b_context', self.b_context)
            if hasattr(self, 'dL_dprev_h'):
                self.dL_dprev_h.add_hprod_one_minus_mask(h_context, self.mask,
                                                         dL_dh)
            dL_dh.hprod(h_context, self.mask)
        # dL/dc[t] = dL[t+1]/dc[t] + dL/dh[t] .* o[t] .* dtanh(c[t])/dc[t]
        dL_dc.add_hprod(self.b_context, dL_dh, self.o, self.dtanh_c_dc)

        # dL/dpre_o[t] = dL/dh[t] .* tanh(c[t]) .* do[t]/dpre_o[t]
        # dL/dpre_f[t] = dL/dc[t] .* c[t-1] .* df[t]/dpre_f[t]
        # dL/dpre_i[t] = dL/dc[t] .* z[t] .* di[t]/dpre_i[t]
        # dL/dpre_z[t] = dL/dc[t] .* i[t] .* dz[t]/dpre_z[t]
        self.dL_dpre_o.assign_hprod(self.b_context, dL_dh, self.tanh_c,
                                    self.do_dpre_o)
        self.dL_dpre_f.assign_hprod(self.b_context, dL_dc, self.prev_c,
                                    self.df_dpre_f)
        self.dL_dpre_i.assign_hprod(self.b_context, dL_dc, self.z,
                                    self.di_dpre_i)
        self.dL_dpre_z.assign_hprod(self.b_context, dL_dc, self.i,
                                    self.dz_dpre_z)
        if self.grad_clipping:
            self.dL_dpre_zifo.clip(self.b_context, -self.grad_clipping,
                                   self.grad_clipping)
        else:
            # The slices were written through views; mark the parent buffer
            # as modified in b_context for consumers in other contexts.
            self.dL_dpre_zifo.last_modif_context = self.b_context
        if hasattr(self, 'dL_dW'):
            # dL_dW += x[t].T * dL/dpre_zifo[t]
            self.dL_dW.add_dot(self.W_b_context, self.x, self.dL_dpre_zifo,
                               'T')
        if hasattr(self, 'dL_dR'):
            # dL_dR += h[t-1].T * dL/dpre_zifo[t]
            self.dL_dR.add_dot(self.R_b_context, self.prev_h,
                               self.dL_dpre_zifo, 'T')
        if hasattr(self, 'dL_db'):
            # dL_db += sum(dL/dpre_zifo[t], axis=0)
            self.dL_db.add_repeat_derivative(self.b_b_context,
                                             self.dL_dpre_zifo,
                                             self.dL_dpre_zifo.nrows,
                                             axis=0)
        if hasattr(self, 'dL_dx'):
            # dL/dx[t] = dL/dpre_zifo[t] * W.T
            self.dL_dx.add_dot(self.x_b_context, self.dL_dpre_zifo, self.W,
                               'N', 'T')
        if hasattr(self, 'dL_dprev_c'):
            # dL/dc[t-1] = f[t] .* dL/dc[t]
            self.dL_dprev_c.add_hprod(self.prev_c_b_context, self.f, dL_dc)
        if hasattr(self, 'dL_dprev_h'):
            # dL/dh[t-1] = dL/dpre_zifo[t] * R.T
            self.dL_dprev_h.add_dot(self.prev_h_b_context, self.dL_dpre_zifo,
                                    self.R, 'N', 'T')