コード例 #1
0
ファイル: mac.py プロジェクト: ririw/mac
    def process_qn(self, questions, qn_lens, batch_size):
        """Embed a padded batch of questions and run it through the LSTM.

        Args:
            questions: tensor validated as ``(batch_size, seq_len)``; it is
                fed straight to ``self.embedding``.
            qn_lens: per-question lengths, either a Python list (converted to
                an int64 tensor) or a tensor validated as ``(batch_size,)``.
            batch_size: number of questions in the batch.

        Returns:
            ``(proj_lstm, hn_concat)`` where ``proj_lstm`` is the projected
            per-step LSTM output, validated as
            ``(batch_size, seq_len, ctrl_dim)``, and ``hn_concat`` is the
            concatenation of the two final hidden states, validated as
            ``(batch_size, 2 * ctrl_dim)``.
        """
        if isinstance(qn_lens, list):
            lens_array = np.array(qn_lens, dtype=np.int64)
            qn_lens = torch.from_numpy(lens_array)

        check = debug_helpers.check_shape
        check(questions, (batch_size, None))
        check(qn_lens, (batch_size,))

        embedded = self.embedding(questions)
        check(embedded, (batch_size, None, self.ctrl_dim))
        packed = rnnutils.pack_padded_sequence(embedded, qn_lens, batch_first=True)

        # The learned initial states are broadcast to one copy per direction
        # and per batch element before being handed to the LSTM.
        state_shape = (2, batch_size, self.ctrl_dim)
        initial_hidden = self.lstm_h0.expand(state_shape).contiguous()
        initial_cell = self.lstm_c0.expand(state_shape).contiguous()

        packed_out, (final_hidden, _) = self.lstm_processor(
            packed, (initial_hidden, initial_cell)
        )
        unpacked, _ = rnnutils.pad_packed_sequence(packed_out, batch_first=True)
        projected = self.lstm_proj(unpacked)

        final_concat = torch.cat((final_hidden[0], final_hidden[1]), -1)
        check(projected, (batch_size, None, self.ctrl_dim))
        check(final_concat, (batch_size, self.ctrl_dim * 2))
        return projected, final_concat
コード例 #2
0
ファイル: mac.py プロジェクト: ririw/mac
    def forward(self, question_words, image_vec, context_words):
        """Run the recurrent control/read/write cells and produce the output.

        Args:
            question_words: validated as ``(batch, 2 * ctrl_dim)``.
            image_vec: validated as ``(batch, ctrl_dim, 14, 14)``.
            context_words: validated as ``(batch, seq_len, ctrl_dim)``; its
                leading dimension defines the batch size.

        Returns:
            The result of ``self.output_cell``, validated as ``(batch, 28)``.
        """
        n_batch = context_words.shape[0]
        check = debug_helpers.check_shape

        check(context_words, (n_batch, None, self.ctrl_dim))
        check(question_words, (n_batch, self.ctrl_dim * 2))
        check(image_vec, (n_batch, self.ctrl_dim, 14, 14))

        # Broadcast the learned initial states across the batch.
        ctrl = self.initial_control.expand(n_batch, self.ctrl_dim)
        mem = self.initial_mem.expand(n_batch, self.ctrl_dim)

        for step in range(self.recurrence_length):
            # Fetch all three cells for this step before running any of them.
            control_unit, read_unit, write_unit = (
                self.cu_cells[step],
                self.ru_cells[step],
                self.wu_cells[step],
            )
            ctrl = control_unit(step, ctrl, context_words, question_words)
            retrieved = read_unit(mem, image_vec, ctrl)
            mem = write_unit(mem, retrieved, ctrl)
            for state in (mem, retrieved, ctrl):
                check(state, (n_batch, self.ctrl_dim))

        result = self.output_cell(question_words, mem)
        check(result, (n_batch, 28))
        return result
コード例 #3
0
ファイル: mac.py プロジェクト: ririw/mac
    def forward(self, h, mem):
        """Concatenate ``h`` and ``mem`` and pass them through ``self.layers``.

        Args:
            h: validated as ``(batch, 2 * ctrl_dim)``.
            mem: validated as ``(batch, ctrl_dim)``; its leading dimension
                defines the batch size.

        Returns:
            Whatever ``self.layers`` returns for the concatenated input.
        """
        n_rows = mem.shape[0]
        check_shape(h, (n_rows, self.ctrl_dim * 2))
        check_shape(mem, (n_rows, self.ctrl_dim))

        joined = torch.cat((h, mem), 1)
        return self.layers(joined)
コード例 #4
0
ファイル: mac.py プロジェクト: ririw/mac
    def forward(self, mem, ri, control, prev_control=None):
        """Produce the next memory state from the previous memory and read result.

        Args:
            mem: previous memory, shape ``(batch_size, ctrl_dim)``.
            ri: retrieved information, shape ``(batch_size, ctrl_dim)``.
            control: current control state, shape ``(batch_size, ctrl_dim)``.
            prev_control: earlier control states, shape
                ``(batch_size, steps, ctrl_dim)``. Must be supplied exactly
                when ``self.use_prev_control`` is true.

        Returns:
            The new memory, shape ``(batch_size, ctrl_dim)``.

        Raises:
            AssertionError: if ``mem``'s feature size differs from
                ``self.ctrl_dim`` or ``prev_control`` presence does not match
                ``self.use_prev_control``.
        """
        batch_size, ctrl_dim = mem.shape
        assert ctrl_dim == self.ctrl_dim
        assert self.use_prev_control == (prev_control is not None)
        if prev_control is not None:
            check_shape(prev_control, (batch_size, None, ctrl_dim))
        check_shape(ri, (batch_size, ctrl_dim))
        check_shape(control, (batch_size, ctrl_dim))

        # Merge the freshly retrieved information with the previous memory.
        m_info = self.mem_read_int(torch.cat([ri, mem], 1))
        check_shape(m_info, (batch_size, ctrl_dim))

        if prev_control is not None:
            # Elementwise similarity between the current control and every
            # earlier control: (batch, steps, ctrl_dim).
            control_similarity = torch.einsum("bsd,bd->bsd", prev_control, control)
            # Reduce each step's similarity vector to a single score.
            control_expweight = self.mem_select(control_similarity)
            check_shape(control_expweight, (batch_size, None, 1))
            # The softmax must run over the per-step scores. The raw similarity
            # still carries the ctrl_dim axis, so squeezing it would leave a
            # 3-D tensor that the "bs" einsum operand below cannot accept.
            sa = torch.nn.functional.softmax(control_expweight.squeeze(2), dim=1)
            m_other = torch.einsum("bs,bsd->bd", sa, prev_control)
            m_info = self.mem_merge_other(m_other) + self.mem_merge_info(m_info)
        check_shape(m_info, (batch_size, ctrl_dim))

        if self.gate_mem:
            # One scalar gate per batch element, derived from the control.
            mem_ctrl = self.mem_gate(control).squeeze(1)
            ci = torch.sigmoid(mem_ctrl)
            check_shape(m_info, (batch_size, self.ctrl_dim))
            # Convex combination: keep `ci` of the old memory, take the rest
            # from the new candidate.
            m_next = torch.einsum("bd,b->bd", mem, ci) + torch.einsum(
                "bd,b->bd", m_info, 1 - ci
            )
        else:
            m_next = m_info

        # NOTE(review): presumably records this frame's locals for debugging;
        # local names above are kept unchanged for that reason.
        save_all_locals()
        return m_next
コード例 #5
0
ファイル: mac.py プロジェクト: ririw/mac
    def forward(self, mem, kb, control):
        """Attend over the knowledge-base grid and return the retrieved vector.

        Args:
            mem: 2-D memory tensor; its shape defines
                ``(batch_size, ctrl_dim)`` and ``ctrl_dim`` must equal
                ``self.ctrl_dim``.
            kb: knowledge base, expected shape
                ``(batch_size, ctrl_dim, 14, 14)``.
            control: control tensor, expected shape ``(batch_size, ctrl_dim)``.

        Returns:
            Tensor of expected shape ``(batch_size, ctrl_dim)``: the sum of
            the knowledge-base cells weighted by a softmax over the 14x14
            grid.
        """
        batch_size, ctrl_dim = mem.shape
        assert ctrl_dim == self.ctrl_dim

        kb_shape = (batch_size, ctrl_dim, 14, 14)
        check_shape(kb, kb_shape)
        check_shape(control, (batch_size, ctrl_dim))
        # Move the feature axis last so each grid cell is a ctrl_dim vector.
        kb = kb.permute(0, 2, 3, 1)
        check_shape(kb, (batch_size, 14, 14, ctrl_dim))

        mem_trfed = self.mem_trf(mem)
        check_shape(mem_trfed, (batch_size, ctrl_dim))

        # Elementwise product of the transformed memory with every grid cell.
        mem_kb_inter = torch.einsum("bc,bwhc->bwhc", mem_trfed, kb)
        # Keep the raw knowledge base alongside the interaction term.
        mem_kb_inter_cat = torch.cat([mem_kb_inter, kb], -1)
        check_shape(mem_kb_inter_cat, (batch_size, 14, 14, ctrl_dim * 2))
        mem_kb_inter_cat_trf = self.ctrl_lin(mem_kb_inter_cat)
        check_shape(mem_kb_inter_cat_trf, (batch_size, 14, 14, ctrl_dim))

        # Modulate every grid cell elementwise by the control vector.
        ctrled = torch.einsum("bwhc,bc->bwhc", mem_kb_inter_cat_trf, control)
        # Flatten the per-cell scores so the softmax below normalises over
        # all 14 * 14 positions jointly.
        attended_flat = self.attn(ctrled).view(batch_size, -1)
        check_shape(attended_flat, (batch_size, 14 * 14))
        attended = func.softmax(attended_flat, dim=-1).view(batch_size, 14, 14)
        check_shape(attended, (batch_size, 14, 14))

        # Attention-weighted sum of the knowledge-base cells.
        retrieved = torch.einsum("bwhc,bwh->bc", kb, attended)
        check_shape(retrieved, (batch_size, ctrl_dim))
        # NOTE(review): presumably snapshots this frame's local variables for
        # debugging - confirm before renaming or removing any local here.
        save_all_locals()
        return retrieved
コード例 #6
0
ファイル: mac.py プロジェクト: ririw/mac
    def forward(self, step, prev_ctrl, context_words, question_words):
        """Compute the next control state by attending over the context words.

        Args:
            step: index used to select this step's transform from
                ``self.step_trf``.
            prev_ctrl: previous control, expected shape
                ``(batch_size, ctrl_dim)``.
            context_words: 3-D tensor whose shape defines ``batch_size`` and
                ``seq_len``; expected shape
                ``(batch_size, seq_len, ctrl_dim)``.
            question_words: expected shape ``(batch_size, 2 * ctrl_dim)``.

        Returns:
            Tensor of expected shape ``(batch_size, ctrl_dim)``: the sum of
            the context words weighted by a softmax over the sequence axis.
        """
        batch_size, seq_len, _ = context_words.shape

        check_shape(prev_ctrl, (batch_size, self.ctrl_dim))
        check_shape(question_words, (batch_size, self.ctrl_dim * 2))
        check_shape(context_words, (batch_size, seq_len, self.ctrl_dim))

        # Each step has its own transform of the question representation.
        question_words_localized = self.step_trf[step](question_words)

        c_concat = torch.cat([prev_ctrl, question_words_localized], 1)
        check_shape(c_concat, (batch_size, self.ctrl_dim * 2))
        cq = self.cq_lin(c_concat)
        check_shape(cq, (batch_size, self.ctrl_dim))

        # Elementwise product of cq with every context word.
        cw_weighted = torch.einsum("bd,bsd->bsd", cq, context_words)
        # One score per context word; the trailing axis is squeezed away.
        ca = self.ca_lin(cw_weighted).squeeze(2)
        check_shape(ca, (batch_size, seq_len))

        # Normalise the scores over the sequence axis.
        cv = torch.nn.Softmax(dim=1)(ca)
        check_shape(cv, (batch_size, seq_len))

        # Attention-weighted sum of the context words.
        next_ctrl = torch.einsum("bs,bsd->bd", cv, context_words)
        check_shape(next_ctrl, (batch_size, self.ctrl_dim))

        # NOTE(review): presumably snapshots this frame's local variables for
        # debugging - confirm before renaming or removing any local here.
        save_all_locals()
        return next_ctrl
コード例 #7
0
ファイル: mac.py プロジェクト: ririw/mac
 def process_img(self, kb, batch_size):
     """Map the raw image features through ``self.kb_mapper``.

     Args:
         kb: image features, validated as ``(batch_size, 1024, 14, 14)``.
         batch_size: number of images in the batch.

     Returns:
         The mapped features, validated as
         ``(batch_size, self.ctrl_dim, 14, 14)``.
     """
     check = debug_helpers.check_shape
     check(kb, (batch_size, 1024, 14, 14))
     mapped = self.kb_mapper(kb)
     check(mapped, (batch_size, self.ctrl_dim, 14, 14))
     return mapped