Example #1
class NEXT_STAGE_G(nn.Module):
    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(cfg.GAN.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf = self.gf_dim
        self.att = ATT_NET(ngf, self.ef_dim)
        self.residual = self._make_layer(ResBlock, ngf * 2)
        self.upsample = upBlock(ngf * 2, ngf)

    def forward(self, h_code, c_code, word_embs, mask):
        """
            h_code1(query):  batch x idf x ih x iw (queryL=ihxiw)
            word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
            c_code1: batch x idf x queryL
            att1: batch x sourceL x queryL
        """
        self.att.applyMask(mask)
        # note: the incoming c_code argument is shadowed here and never used
        c_code, att = self.att(h_code, word_embs)
        h_c_code = torch.cat((h_code, c_code), 1)
        out_code = self.residual(h_c_code)

        # state size: ngf x (2 * in_size) x (2 * in_size)
        out_code = self.upsample(out_code)

        return out_code, att
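
These snippets rely on helpers defined elsewhere in the repo (cfg, ATT_NET, ResBlock, upBlock). For reference, here is a minimal sketch of upBlock, assuming the AttnGAN-style design of a 2x nearest-neighbour upsample followed by a gated 3x3 convolution; the actual helper in the repo may differ:

import torch.nn as nn

def upBlock(in_planes, out_planes):
    # Sketch under the assumption of the AttnGAN design: doubles the spatial
    # size; the GLU halves the doubled conv channels back to out_planes.
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(in_planes, out_planes * 2, kernel_size=3,
                  stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_planes * 2),
        nn.GLU(dim=1))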
Example #2
class NEXT_STAGE_G(nn.Module):
    def __init__(self):
        super(NEXT_STAGE_G, self).__init__()
        self.gen_feat_dim = cfg.GAN.GEN_FEAT_DIM
        self.text_emb_dim = cfg.TEXT.EMBEDDING_DIM
        self.label_dim = cfg.GAN.NEXT_LABEL_DIM
        self.num_residual = cfg.GAN.RESIDUAL_NUM
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(cfg.GAN.RESIDUAL_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        self.att = ATT_NET(self.gen_feat_dim, self.text_emb_dim)
        self.residual = self._make_layer(ResBlock, self.gen_feat_dim * 2)
        self.upsample = upBlock(self.gen_feat_dim * 3, self.gen_feat_dim)

        # local pathway
        label_input_dim = cfg.GAN.TEXT_CONDITION_DIM + cfg.TEXT.CLASSES_NUM  # no noise anymore
        self.label = nn.Sequential(
            nn.Linear(label_input_dim, self.label_dim, bias=False),
            nn.BatchNorm1d(self.label_dim), nn.ReLU(True))

        self.local1 = upBlock(self.label_dim + self.gen_feat_dim,
                              self.gen_feat_dim * 2)
        self.local2 = upBlock(self.gen_feat_dim * 2, self.gen_feat_dim)

    def forward(self,
                h_code,
                c_code,
                word_embs,
                mask,
                transf_matrices,
                transf_matrices_inv,
                label_one_hot,
                max_objects,
                op=True):
        """
            h_code1(query):  batch x idf x ih x iw (queryL=ihxiw)
            word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
            c_code1: batch x idf x queryL
            att1: batch x sourceL x queryL
        """
        _hw = h_code.shape[2]
        self.att.applyMask(mask)
        c_code_att, att = self.att(h_code, word_embs)
        h_c_code = torch.cat((h_code, c_code_att), 1)
        out_code = self.residual(h_c_code)

        # object pathways
        h_code_locals = h_code.new_zeros(h_code.shape[0], self.gen_feat_dim,
                                         _hw, _hw)
        if op:
            for idx in range(max_objects):
                current_label = self.label(
                    torch.cat((c_code, label_one_hot[:, idx]), 1))
                current_label = current_label.view(h_code.shape[0],
                                                   self.label_dim, 1, 1)
                current_label = current_label.repeat(1, 1, _hw // 4, _hw // 4)
                current_patch = stn(
                    h_code, transf_matrices[:, idx],
                    (h_code.shape[0], h_code.shape[1], _hw // 4, _hw // 4))
                current_input = torch.cat((current_patch, current_label), 1)
                h_code_local = self.local1(current_input)
                h_code_local = self.local2(h_code_local)
                h_code_local = stn(h_code_local, transf_matrices_inv[:, idx],
                                   h_code_locals.shape)
                h_code_locals = merge_tensors(h_code_locals, h_code_local, idx)

        out_code = torch.cat((out_code, h_code_locals), 1)

        # state size: gen_feat_dim x (2 * in_size) x (2 * in_size)
        out_code = self.upsample(out_code)

        return out_code, att
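
The object pathway above calls stn and merge_tensors, neither of which is shown on this page. A plausible sketch, assuming stn applies a batch of 2x3 affine matrices via affine_grid/grid_sample and that merge_tensors combines per-object canvases with an element-wise max (both assumptions):

import torch
import torch.nn.functional as F

def stn(image, transformation_matrix, size):
    # Warp `image` with per-sample 2x3 affine matrices; used here both to
    # crop h_code down to a bounding box and to paste the result back.
    grid = F.affine_grid(transformation_matrix, torch.Size(size),
                         align_corners=False)
    return F.grid_sample(image, grid, align_corners=False)

def merge_tensors(accumulated, current, index):
    # Combine per-object feature maps onto one canvas. The element-wise max
    # is an assumption; the unused `index` hints that the repo may
    # special-case the first object.
    return torch.max(accumulated, current)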
Example #3
class NEXT_STAGE_G(nn.Module):
    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(cfg.GAN.R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf = self.gf_dim
        self.att = ATT_NET(ngf, self.ef_dim)
        self.residual = self._make_layer(ResBlock, ngf * 2)
        self.upsample = upBlock(ngf * 3, ngf)

        # local pathway
        linput = cfg.GAN.Z_DIM + 81  # c_code dim plus an 81-way one-hot object label
        self.label = nn.Sequential(
            nn.Linear(linput, self.ef_dim // 2, bias=False),
            nn.BatchNorm1d(self.ef_dim // 2), nn.ReLU(True))

        self.local1 = upBlock(self.ef_dim // 2 + ngf, ngf * 2)
        self.local2 = upBlock(ngf * 2, ngf)

    def forward(self,
                h_code,
                c_code,
                word_embs,
                mask,
                transf_matrices,
                transf_matrices_inv,
                label_one_hot,
                max_objects,
                op=True):
        """
            h_code1(query):  batch x idf x ih x iw (queryL=ihxiw)
            word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
            c_code1: batch x idf x queryL
            att1: batch x sourceL x queryL
        """
        _hw = h_code.shape[2]
        self.att.applyMask(mask)
        c_code_att, att = self.att(h_code, word_embs)
        h_c_code = torch.cat((h_code, c_code_att), 1)
        out_code = self.residual(h_c_code)

        # object pathways
        # allocate on the same device/dtype as h_code instead of forcing CUDA
        h_code_locals = h_code.new_zeros(h_code.shape[0], self.gf_dim,
                                         _hw, _hw)
        if op:
            for idx in range(max_objects):
                current_label = self.label(
                    torch.cat((c_code, label_one_hot[:, idx]), 1))
                current_label = current_label.view(h_code.shape[0],
                                                   self.ef_dim // 2, 1, 1)
                current_label = current_label.repeat(1, 1, _hw // 4, _hw // 4)
                current_patch = stn(
                    h_code, transf_matrices[:, idx],
                    (h_code.shape[0], h_code.shape[1], _hw // 4, _hw // 4))
                current_input = torch.cat((current_patch, current_label), 1)
                h_code_local = self.local1(current_input)
                h_code_local = self.local2(h_code_local)
                h_code_local = stn(h_code_local, transf_matrices_inv[:, idx],
                                   h_code_locals.shape)
                h_code_locals = merge_tensors(h_code_locals, h_code_local, idx)

        out_code = torch.cat((out_code, h_code_locals), 1)

        # state size: ngf x (2 * in_size) x (2 * in_size)
        out_code = self.upsample(out_code)

        return out_code, att
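
This variant hard-codes linput = cfg.GAN.Z_DIM + 81, i.e. the conditioning code concatenated with an 81-way one-hot object label. A runnable shape check (the Z_DIM value of 100 is an assumption, used only for illustration):

import torch

batch, z_dim, num_classes = 4, 100, 81   # z_dim = cfg.GAN.Z_DIM (assumed)
c_code = torch.randn(batch, z_dim)                  # conditioning code
label_one_hot = torch.zeros(batch, 3, num_classes)  # up to 3 objects here
label_input = torch.cat((c_code, label_one_hot[:, 0]), 1)
assert label_input.shape == (batch, z_dim + num_classes)  # matches linput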
Example #4
class NEXT_STAGE_G_MAIN(nn.Module):
    def __init__(self, ngf, nef, nef2):
        super(NEXT_STAGE_G_MAIN, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.ef_dim2 = nef2
        self.define_module()

    def _make_layer(self, block, channel_num):
        layers = []
        for i in range(cfg.GAN.LOCAL_R_NUM):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)

    def define_module(self):
        ngf = self.gf_dim
        nef = self.ef_dim
        nef2 = self.ef_dim2

        self.att = ATT_NET(ngf, nef)
        self.bt_att = BT_ATT_NET(ngf, nef)
        self.residual = self._make_layer(HmapResBlock, ngf*3+nef2)
        self.upsample = upBlock(ngf*3+nef2, ngf)

    def forward(self, h_code, h_code_hmap, c_code, word_embs,
                glove_word_embs, slabels_feat, mask, rois, num_rois, bt_mask,
                glb_max_num_roi):
        """
            h_code1(query):  batch x idf x ih x iw (queryL=ihxiw)
            word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
            glove_word_embs: batch x cdf2 x sourceL (sourceL=seq_len)
            slabels_feat: batch x cdf2 x max_num_roi x 1
            c_code1: batch x idf x queryL
            att1: batch x sourceL x queryL
        """
        idf, ih, iw = h_code.size(1), h_code.size(2), h_code.size(3)

        ### compute grid attn and context vectors
        self.att.applyMask(mask)
        c_code, att = self.att(h_code, word_embs)

        max_num_roi = int(num_rois.max().item())
        slabels_feat = slabels_feat[:, :, :max_num_roi]

        # allocate on the same device as the incoming features (replaces the
        # deprecated Variable + cfg.CUDA pattern)
        raw_bt_c_code = h_code.new_zeros(c_code.size(0), idf, glb_max_num_roi, 1)

        if max_num_roi > 0:
            ### compute bottom-up attn and context vectors
            self.bt_att.applyMask(mask)
            # bt_c_code (variable): batch_size x idf x max_num_rois x 1
            # bt_att (variable): batch_size x cap_len x max_num_rois x 1
            bt_c_code, bt_att = self.bt_att(slabels_feat, glove_word_embs, word_embs)
            raw_bt_c_code[:, :, :max_num_roi] = bt_c_code
            # bt_mask: batch x max_num_rois x ih x iw -> batch x max_num_rois x num x ih x iw
            bt_mask = bt_mask[:, :max_num_roi]

            bt_code_mask = bt_mask.unsqueeze(2).repeat(1, 1, bt_c_code.size(1), 1, 1)
            bt_c_code = pprocess_bt_attns(bt_c_code, ih, iw, bt_code_mask)

            bt_att_mask = bt_mask.unsqueeze(2).repeat(1, 1, bt_att.size(1), 1, 1)
            bt_att = pprocess_bt_attns(bt_att, ih, iw, bt_att_mask)

            bt_slabels_mask = bt_mask.unsqueeze(2).repeat(1, 1, self.ef_dim2, 1, 1)
            bt_slabels_code = pprocess_bt_attns(slabels_feat, ih, iw, bt_slabels_mask)
        else:
            # zero contexts when no ROIs are present, on the same device
            bt_c_code = torch.zeros_like(c_code)
            bt_att = torch.zeros_like(att)
            bt_slabels_code = c_code.new_zeros(c_code.size(0), self.ef_dim2,
                                               c_code.size(2), c_code.size(3))
        
        h_c_code = torch.cat((h_code+h_code_hmap, c_code, bt_c_code, bt_slabels_code), 1)
        out_code = self.residual(h_c_code)

        # state size: ngf x (2 * in_size) x (2 * in_size)
        out_code = self.upsample(out_code)
        raw_bt_c_code = raw_bt_c_code.transpose(1, 2).squeeze(-1)

        return out_code, raw_bt_c_code, att, bt_att
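
The bottom-up branch expands bt_mask so each ROI's context vector can be painted onto the spatial grid. A runnable walkthrough of the shapes involved (sizes are illustrative; pprocess_bt_attns itself is not shown on this page and must reduce the masked tensor back to batch x idf x ih x iw so it can be concatenated with c_code):

import torch

batch, idf, ih, iw, num_rois = 2, 32, 16, 16, 3
bt_c_code = torch.randn(batch, idf, num_rois, 1)   # per-ROI context vectors
bt_mask = torch.ones(batch, num_rois, ih, iw)      # per-ROI spatial masks
# unsqueeze/repeat exactly as in forward(): one mask copy per channel
bt_code_mask = bt_mask.unsqueeze(2).repeat(1, 1, bt_c_code.size(1), 1, 1)
assert bt_code_mask.shape == (batch, num_rois, idf, ih, iw)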