Пример #1
0
 def recwmask(self, x_t, m_t, *states):   # m_t: (batsize, ), x_t: (batsize, dim), states: (batsize, **somedim**)
     """
     One step of a masked recurrence: where the mask is 0, the previous
     output and states are passed through unchanged instead of the new ones.

     :param x_t:    current input: (batsize, dim)
     :param m_t:    step mask: (batsize,); 1 = real token, 0 = padding
     :param states: previous states, each (batsize, <state_dim>); states[0]
                    is taken as the previous output y_tm1
                    (TODO(review): with stacked layers the bottom layer's
                    state comes first -- confirm states[0] is the output)
     :return: [y_t_out] + states_out, each masked like the inputs
     """
     recout = self.rec(x_t, *states)
     y_t = recout[0]
     newstates = recout[1:]
     # previous output; the original first assigned T.zeros_like(y_t) and
     # immediately overwrote it -- dead store removed
     y_tm1 = states[0]
     # transpose so the (batsize,) mask broadcasts over the feature axis
     y_t_out = (y_t.T * m_t + y_tm1.T * (1 - m_t)).T
     states_out = [(a.T * m_t + b.T * (1 - m_t)).T for a, b in zip(newstates, states)]
     return [y_t_out] + states_out
Пример #2
0
 def recwmask(
     self, x_t, m_t, *states
 ):  # m_t: (batsize, ), x_t: (batsize, dim), states: (batsize, **somedim**)
     """
     One masked recurrence step: where the mask is 0 the previous
     output/states pass through unchanged instead of the fresh ones.

     :param x_t:    current input: (batsize, dim)
     :param m_t:    step mask: (batsize,); 1 = real token, 0 = padding
     :param states: previous states, each (batsize, <state_dim>);
                    states[0] is taken as the previous output y_tm1
                    (TODO(review): with stacked layers the bottom layer's
                    state comes first -- confirm states[0] is the output)
     :return: [y_t_out] + states_out
     """
     recout = self.rec(x_t, *states)
     y_t = recout[0]
     newstates = recout[1:]
     # previous output (the original built T.zeros_like(y_t) and
     # immediately overwrote it -- dead store removed)
     y_tm1 = states[0]
     # transpose so the (batsize,) mask broadcasts across features
     y_t_out = (y_t.T * m_t + y_tm1.T * (1 - m_t)).T
     states_out = [(a.T * m_t + b.T * (1 - m_t)).T
                   for a, b in zip(newstates, states)]
     return [y_t_out] + states_out
Пример #3
0
    def rec(self, x_t, m_tm1, mem_tm1, h_tm1):
        """
        One step of a memory-augmented recurrent cell: read an attended
        memory vector, compute a gated state update, then write back.

        :param x_t:     current input vector: (batsize, inp_dim)
        :param m_tm1:   previous memory read vector: (batsize, state_dim)
        :param mem_tm1: previous memory state: (batsize, mem_size, state_dim)
        :param h_tm1:   previous state vector: (batsize, state_dim)
        :return: [y_t, m_t, mem_t, h_t] -- y_t is h_t; the trailing states
                 mirror the argument order (m, mem, h)
        """
        # read memory: build an address vector, then attend over slots
        memory_addr_gate1 =     self.gateactivation(T.dot(h_tm1, self.uma) + T.dot(x_t, self.wma) + T.dot(m_tm1, self.mma) + self.bma)
        memory_addr_gate2 =     self.gateactivation(T.dot(h_tm1, self.uma2) + T.dot(x_t, self.wma2) + T.dot(m_tm1, self.mma2) + self.bma2)
        memaddrcan         =    memory_addr_gate1 * h_tm1 +      (1 - memory_addr_gate1) * m_tm1
        # TODO(review): x_t is (batsize, inp_dim) but is blended with
        # (batsize, state_dim) quantities here -- shape-incompatible unless
        # inp_dim == state_dim (original author flagged the same error)
        memaddr =               memory_addr_gate2 * memaddrcan + (1 - memory_addr_gate2) * x_t
        memsel = self.attgen(memaddr, mem_tm1)      # attention weights over slots
        m_t = self.attcon(mem_tm1, memsel)          # attended memory read

        # gates for the state update
        state_filter_gate =     self.gateactivation(T.dot(h_tm1, self.usf) + T.dot(x_t, self.wsf) + T.dot(m_t, self.msf) + self.bsf)
        memory_filter_gate =    self.gateactivation(T.dot(h_tm1, self.umf) + T.dot(x_t, self.wmf) + T.dot(m_t, self.mmf) + self.bmf)
        input_filter_gate =     self.gateactivation(T.dot(h_tm1, self.uif) + T.dot(x_t, self.wif) + T.dot(m_t, self.mif) + self.bif)
        update_gate     =       self.gateactivation(T.dot(h_tm1, self.uug) + T.dot(x_t, self.wug) + T.dot(m_t, self.mug) + self.bug)

        # compute candidate state from the gated inputs, then leaky update
        h_tm1_filtered = T.dot(state_filter_gate * h_tm1, self.u)
        x_t_filtered =   T.dot(input_filter_gate * x_t, self.w)
        m_t_filtered = T.dot(memory_filter_gate * m_t, self.m)
        h_t_can = self.outpactivation(h_tm1_filtered + x_t_filtered + m_t_filtered + self.b)
        h_t = update_gate * h_tm1 + (1 - update_gate) * h_t_can

        # write memory
        memory_write_filter=    self.gateactivation(T.dot(h_tm1, self.uwf) + T.dot(x_t, self.wwf) + T.dot(m_t, self.mwf) + self.bwf)    # (batsize, state_dim)
        if self.discrete:       # memsel: (batsize, mem_size)
            # Hard (one-hot) write address. Symbolic tensors do not support
            # in-place item assignment (the original's indexed write was
            # flagged "doesn't work"), so build the one-hot with
            # T.set_subtensor. NOTE: argmax is non-differentiable, so no
            # gradient flows through the hard address.
            slot = T.argmax(memsel, axis=1)
            memsel = T.set_subtensor(
                T.zeros_like(memsel)[T.arange(slot.shape[0]), slot], 1.0)

        # per-batch outer product: (batsize, mem_size) x (batsize, state_dim)
        # -> (batsize, mem_size, state_dim)
        memwritesel = T.batched_tensordot(memsel, memory_write_filter, axes=0)
        h_t_rep = h_t.reshape((h_t.shape[0], 1, h_t.shape[1])).repeat(mem_tm1.shape[1], axis=1)
        mem_t = memwritesel * mem_tm1 + (1 - memwritesel) * h_t_rep
        return [h_t, m_t, mem_t, h_t]
Пример #4
0
 def apply(self, x, mask=None):  # (batsize, seqlen, dim)
     """
     Pool x over its sequence (second-to-last) axis.

     :param x:    input tensor; may carry a ``.mask`` attribute (project
                  convention) used when ``mask`` is not given
     :param mask: optional mask of shape x.shape[:-1]; 1 = real, 0 = padding
     :return: pooled tensor (batsize, dim); rows whose mask is all zero are
              zeroed, and the reduced row mask is attached as ``ret.mask``
     """
     mask = x.mask if mask is None else mask
     if mask is not None:
         assert (mask.ndim == x.ndim - 1)
         # broadcast mask over the feature axis: (..., seqlen) -> (..., seqlen, dim)
         realm = T.tensordot(mask, T.ones((x.shape[-1], )), 0)
         if self.mode == "max":
             # masked positions become -inf so they never win the max;
             # np.inf replaces np.infty, an alias removed in NumPy 2.0
             x = T.switch(realm, x, np.inf * (realm - 1))
         else:
             x = x * realm
     if self.mode == "max":
         ret = T.max(x, axis=-2)
     elif self.mode == "sum":
         ret = T.sum(x, axis=-2)
     elif self.mode == "avg":
         # NOTE(review): divides by the full seqlen, not the number of
         # unmasked steps -- confirm this is intended for masked input
         ret = T.sum(x, axis=-2) / x.shape[-2]
     else:
         raise Exception("unknown pooling mode: {:3s}".format(self.mode))
     # ret: (batsize, dim)
     if mask is not None:
         # zero out rows that were entirely masked and propagate a row mask
         mask = 1 * (T.sum(mask, axis=-1) > 0)
         ret = T.switch(T.tensordot(mask, T.ones((x.shape[-1], )), 0), ret,
                        T.zeros_like(ret))
         ret.mask = mask
     return ret
Пример #5
0
    def rec(self, x_t, m_tm1, mem_tm1, h_tm1):
        """
        One step of a memory-augmented recurrent cell: read an attended
        memory vector, compute a gated state update, then write back.

        :param x_t:     current input vector: (batsize, inp_dim)
        :param m_tm1:   previous memory read vector: (batsize, state_dim)
        :param mem_tm1: previous memory state: (batsize, mem_size, state_dim)
        :param h_tm1:   previous state vector: (batsize, state_dim)
        :return: [y_t, m_t, mem_t, h_t] -- y_t is h_t; the trailing states
                 mirror the argument order (m, mem, h)
        """
        # read memory: build an address vector, then attend over slots
        memory_addr_gate1 = self.gateactivation(
            T.dot(h_tm1, self.uma) + T.dot(x_t, self.wma) +
            T.dot(m_tm1, self.mma) + self.bma)
        memory_addr_gate2 = self.gateactivation(
            T.dot(h_tm1, self.uma2) + T.dot(x_t, self.wma2) +
            T.dot(m_tm1, self.mma2) + self.bma2)
        memaddrcan = memory_addr_gate1 * h_tm1 + (1 -
                                                  memory_addr_gate1) * m_tm1
        # TODO(review): x_t is (batsize, inp_dim) but is blended with
        # (batsize, state_dim) quantities here -- shape-incompatible unless
        # inp_dim == state_dim (original author flagged the same error)
        memaddr = memory_addr_gate2 * memaddrcan + (
            1 - memory_addr_gate2) * x_t
        memsel = self.attgen(memaddr, mem_tm1)  # attention weights over slots
        m_t = self.attcon(mem_tm1, memsel)  # attended memory read

        # gates for the state update
        state_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.usf) + T.dot(x_t, self.wsf) +
            T.dot(m_t, self.msf) + self.bsf)
        memory_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.umf) + T.dot(x_t, self.wmf) +
            T.dot(m_t, self.mmf) + self.bmf)
        input_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.uif) + T.dot(x_t, self.wif) +
            T.dot(m_t, self.mif) + self.bif)
        update_gate = self.gateactivation(
            T.dot(h_tm1, self.uug) + T.dot(x_t, self.wug) +
            T.dot(m_t, self.mug) + self.bug)

        # compute candidate state from the gated inputs, then leaky update
        h_tm1_filtered = T.dot(state_filter_gate * h_tm1, self.u)
        x_t_filtered = T.dot(input_filter_gate * x_t, self.w)
        m_t_filtered = T.dot(memory_filter_gate * m_t, self.m)
        h_t_can = self.outpactivation(h_tm1_filtered + x_t_filtered +
                                      m_t_filtered + self.b)
        h_t = update_gate * h_tm1 + (1 - update_gate) * h_t_can

        # write memory
        memory_write_filter = self.gateactivation(
            T.dot(h_tm1, self.uwf) + T.dot(x_t, self.wwf) +
            T.dot(m_t, self.mwf) + self.bwf)  # (batsize, state_dim)
        if self.discrete:  # memsel: (batsize, mem_size)
            # Hard (one-hot) write address. Symbolic tensors do not support
            # in-place item assignment (the original's indexed write was
            # flagged "doesn't work"), so build the one-hot with
            # T.set_subtensor. NOTE: argmax is non-differentiable, so no
            # gradient flows through the hard address.
            slot = T.argmax(memsel, axis=1)
            memsel = T.set_subtensor(
                T.zeros_like(memsel)[T.arange(slot.shape[0]), slot], 1.0)

        # per-batch outer product: (batsize, mem_size) x (batsize, state_dim)
        # -> (batsize, mem_size, state_dim)
        memwritesel = T.batched_tensordot(
            memsel, memory_write_filter, axes=0)
        h_t_rep = h_t.reshape(
            (h_t.shape[0], 1, h_t.shape[1])).repeat(mem_tm1.shape[1], axis=1)
        mem_t = memwritesel * mem_tm1 + (1 - memwritesel) * h_t_rep
        return [h_t, m_t, mem_t, h_t]