示例#1
0
    def rec(self, x_t, m_tm1, mem_tm1, h_tm1):
        """
        One recurrence step: read from memory, update the hidden state, then
        write the new state back into memory.

        :param x_t:     current input vector: (batsize, inp_dim)
        :param m_tm1:   previous memory content vector: (batsize, state_dim)
        :param mem_tm1: previous memory state: (batsize, mem_size, state_dim)
        :param h_tm1:   previous state vector: (batsize, state_dim)
        :return:    [y_t, m_t, mem_t, h_t], where the output y_t is h_t itself
        """
        # read memory: mix h_tm1, m_tm1 and x_t into an address vector, then
        # let attgen/attcon select and read memory content with it
        memory_addr_gate1 = self.gateactivation(
            T.dot(h_tm1, self.uma) + T.dot(x_t, self.wma) + T.dot(m_tm1, self.mma) + self.bma)
        memory_addr_gate2 = self.gateactivation(
            T.dot(h_tm1, self.uma2) + T.dot(x_t, self.wma2) + T.dot(m_tm1, self.mma2) + self.bma2)
        memaddrcan = memory_addr_gate1 * h_tm1 + (1 - memory_addr_gate1) * m_tm1
        # TODO: x_t is (batsize, inp_dim) while memaddrcan is (batsize, state_dim);
        # this elementwise mix only works when inp_dim == state_dim.
        memaddr = memory_addr_gate2 * memaddrcan + (1 - memory_addr_gate2) * x_t
        memsel = self.attgen(memaddr, mem_tm1)      # presumably (batsize, mem_size) address weights
        m_t = self.attcon(mem_tm1, memsel)

        # update inner stuff: all gates see h_tm1, x_t and the freshly read m_t
        state_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.usf) + T.dot(x_t, self.wsf) + T.dot(m_t, self.msf) + self.bsf)
        memory_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.umf) + T.dot(x_t, self.wmf) + T.dot(m_t, self.mmf) + self.bmf)
        input_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.uif) + T.dot(x_t, self.wif) + T.dot(m_t, self.mif) + self.bif)
        update_gate = self.gateactivation(
            T.dot(h_tm1, self.uug) + T.dot(x_t, self.wug) + T.dot(m_t, self.mug) + self.bug)

        # compute new state: candidate from filtered inputs, then interpolate
        # between the old state and the candidate with the update gate
        h_tm1_filtered = T.dot(state_filter_gate * h_tm1, self.u)
        x_t_filtered = T.dot(input_filter_gate * x_t, self.w)
        m_t_filtered = T.dot(memory_filter_gate * m_t, self.m)
        h_t_can = self.outpactivation(h_tm1_filtered + x_t_filtered + m_t_filtered + self.b)
        h_t = update_gate * h_tm1 + (1 - update_gate) * h_t_can

        # write memory
        memory_write_filter = self.gateactivation(
            T.dot(h_tm1, self.uwf) + T.dot(x_t, self.wwf) + T.dot(m_t, self.mwf) + self.bwf)    # (batsize, state_dim)
        if self.discrete:       # memsel: (batsize, mem_size)
            # Replace the soft selection by a one-hot on the arg-max slot.
            # Theano tensors are immutable symbolic variables and do not
            # support item assignment, so build the one-hot with set_subtensor.
            best_slot = T.argmax(memsel, axis=1)
            memsel = T.set_subtensor(
                T.zeros_like(memsel)[T.arange(best_slot.shape[0]), best_slot], 1.0)

        memwritesel = T.batched_tensordot(memsel, memory_write_filter, axes=0)  # (batsize, mem_size, state_dim)
        # copy h_t into every memory slot so it can be blended with mem_tm1
        h_t_rep = h_t.reshape((h_t.shape[0], 1, h_t.shape[1])).repeat(mem_tm1.shape[1], axis=1)
        mem_t = memwritesel * mem_tm1 + (1 - memwritesel) * h_t_rep
        return [h_t, m_t, mem_t, h_t]
示例#2
0
    def rec(self, x_t, m_tm1, mem_tm1, h_tm1):
        """
        One recurrence step: read from memory, update the hidden state, then
        write the new state back into memory.

        :param x_t:     current input vector: (batsize, inp_dim)
        :param m_tm1:   previous memory content vector: (batsize, state_dim)
        :param mem_tm1: previous memory state: (batsize, mem_size, state_dim)
        :param h_tm1:   previous state vector: (batsize, state_dim)
        :return:    [y_t, m_t, mem_t, h_t], where the output y_t is h_t itself
        """
        # read memory: mix h_tm1, m_tm1 and x_t into an address vector, then
        # let attgen/attcon select and read memory content with it
        memory_addr_gate1 = self.gateactivation(
            T.dot(h_tm1, self.uma) + T.dot(x_t, self.wma) +
            T.dot(m_tm1, self.mma) + self.bma)
        memory_addr_gate2 = self.gateactivation(
            T.dot(h_tm1, self.uma2) + T.dot(x_t, self.wma2) +
            T.dot(m_tm1, self.mma2) + self.bma2)
        memaddrcan = (memory_addr_gate1 * h_tm1 +
                      (1 - memory_addr_gate1) * m_tm1)
        # TODO: x_t is (batsize, inp_dim) while memaddrcan is
        # (batsize, state_dim); this elementwise mix only works when
        # inp_dim == state_dim.
        memaddr = (memory_addr_gate2 * memaddrcan +
                   (1 - memory_addr_gate2) * x_t)
        # presumably (batsize, mem_size) address weights
        memsel = self.attgen(memaddr, mem_tm1)
        m_t = self.attcon(mem_tm1, memsel)

        # update inner stuff: all gates see h_tm1, x_t and the freshly read m_t
        state_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.usf) + T.dot(x_t, self.wsf) +
            T.dot(m_t, self.msf) + self.bsf)
        memory_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.umf) + T.dot(x_t, self.wmf) +
            T.dot(m_t, self.mmf) + self.bmf)
        input_filter_gate = self.gateactivation(
            T.dot(h_tm1, self.uif) + T.dot(x_t, self.wif) +
            T.dot(m_t, self.mif) + self.bif)
        update_gate = self.gateactivation(
            T.dot(h_tm1, self.uug) + T.dot(x_t, self.wug) +
            T.dot(m_t, self.mug) + self.bug)

        # compute new state: candidate from filtered inputs, then interpolate
        # between the old state and the candidate with the update gate
        h_tm1_filtered = T.dot(state_filter_gate * h_tm1, self.u)
        x_t_filtered = T.dot(input_filter_gate * x_t, self.w)
        m_t_filtered = T.dot(memory_filter_gate * m_t, self.m)
        h_t_can = self.outpactivation(h_tm1_filtered + x_t_filtered +
                                      m_t_filtered + self.b)
        h_t = update_gate * h_tm1 + (1 - update_gate) * h_t_can

        # write memory
        memory_write_filter = self.gateactivation(
            T.dot(h_tm1, self.uwf) + T.dot(x_t, self.wwf) +
            T.dot(m_t, self.mwf) + self.bwf)  # (batsize, state_dim)
        if self.discrete:  # memsel: (batsize, mem_size)
            # Replace the soft selection by a one-hot on the arg-max slot.
            # Theano tensors are immutable symbolic variables and do not
            # support item assignment, so build the one-hot with set_subtensor.
            best_slot = T.argmax(memsel, axis=1)
            memsel = T.set_subtensor(
                T.zeros_like(memsel)[T.arange(best_slot.shape[0]), best_slot],
                1.0)

        memwritesel = T.batched_tensordot(
            memsel, memory_write_filter,
            axes=0)  # (batsize, mem_size, state_dim)
        # copy h_t into every memory slot so it can be blended with mem_tm1
        h_t_rep = h_t.reshape(
            (h_t.shape[0], 1, h_t.shape[1])).repeat(mem_tm1.shape[1], axis=1)
        mem_t = memwritesel * mem_tm1 + (1 - memwritesel) * h_t_rep
        return [h_t, m_t, mem_t, h_t]
示例#3
0
 def apply(self, data, weights):
     """Hard selection: for each row b, return data[b, argmax(weights[b]), :]."""
     winner = T.argmax(weights, axis=1)
     batch_rows = T.arange(winner.shape[0])
     return data[batch_rows, winner, :]
示例#4
0
 def apply(self, data, weights):
     """Pick one slice of `data` per row, at the position of the largest weight.

     The result is data[b, argmax(weights[b]), :] stacked over rows b.
     """
     top_positions = T.argmax(weights, axis=1)
     return data[T.arange(top_positions.shape[0]), top_positions, :]