コード例 #1
0
ファイル: model.py プロジェクト: sudheer-25/MirrorGAN
class NEXT_STAGE_G(nn.Module):
    """Refinement generator stage: text attention, feature fusion, upsampling.

    ``forward`` runs the attention network on the previous stage's image
    features, concatenates those features with the two context maps the
    attention network returns, fuses them with a 1x1 convolution, applies
    ``num_residual`` residual blocks and upsamples the result.

    Args:
        ngf: generator feature width, stored as ``gf_dim`` (32 in the
            original author's notes).
        nef: text embedding size, stored as ``ef_dim`` (256 in the notes).
        ncf: condition dimension, stored as ``cf_dim`` (100 in the notes);
            not used by this class itself.
    """

    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        # Residual block count, read once from the global config; this
        # stored value is the one _make_layer uses.
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()
        # Fuses the 3 * ngf channels concatenated in forward() down to the
        # 2 * ngf channels the residual blocks were built for. Created after
        # define_module(); moving it would change the order in which
        # nn.Module registers submodules/parameters.
        self.conv = conv1x1(ngf * 3, ngf * 2)

    def _make_layer(self, block, channel_num):
        """Stack ``self.num_residual`` fresh ``block(channel_num)`` modules.

        Returns:
            An ``nn.Sequential`` containing the blocks in order.
        """
        return nn.Sequential(*[block(channel_num) for _ in range(self.num_residual)])

    def define_module(self):
        """Create the attention, residual and upsampling submodules."""
        ngf = self.gf_dim
        self.att = ATT_NET(ngf, self.ef_dim)
        self.residual = self._make_layer(ResBlock, ngf * 2)
        # 2 * ngf -> ngf channels; the author's shape notes show the spatial
        # size doubling here (e.g. 64x64 -> 128x128 for the second stage).
        self.upsample = upBlock(ngf * 2, ngf)

    def forward(self, h_code, c_code, word_embs, mask):
        """Refine ``h_code`` using attention over the text, then upsample.

        Example shapes are the ones recorded by the original author
        (batch 16, ngf 32, nef 256, 18 words); treat them as illustrative.

        Args:
            h_code: previous-stage image features (the attention query),
                batch x ngf x ih x iw, e.g. (16, 32, 64, 64).
            c_code: condition code, e.g. (16, 100); passed to ``self.att``
                and then rebound to the word-context map it returns.
            word_embs: word embeddings (the attention context),
                batch x nef x seq_len, e.g. (16, 256, 18).
            mask: mask handed to ``self.att.applyMask`` before attention,
                e.g. (16, 18) -- presumably marks padded words; verify
                against ATT_NET.

        Returns:
            ``(out_code, att)``: the upsampled features, e.g.
            (16, 32, 128, 128) for the second stage, and the word-attention
            map from ``self.att`` (batch x sourceL x queryL).
        """
        self.att.applyMask(mask)
        # ATT_NET returns (weighted word context, weighted sentence context,
        # word attention, sentence attention); the last one is unused here.
        c_code, weighted_sentence, att, _ = self.att(h_code, c_code, word_embs)
        # Stack image features and both contexts along the channel axis:
        # ngf + ngf + ngf = 3 * ngf channels, e.g. (16, 96, 64, 64).
        h_c_sent_code = torch.cat((h_code, c_code, weighted_sentence), 1)
        h_c_sent_code = self.conv(h_c_sent_code)  # -> 2 * ngf channels
        out_code = self.residual(h_c_sent_code)  # shape unchanged
        out_code = self.upsample(out_code)  # -> ngf channels, larger spatial size
        return out_code, att
コード例 #2
0
ファイル: model.py プロジェクト: ammarnasr/mirrorGAN
 def define_module(self):
     """Build the attention, residual and upsampling submodules, printing each step.

     The submodules are assigned in a fixed order (att, residual, upsample);
     nn.Module registers children in assignment order, so keep it.
     """
     gf = self.gf_dim
     ef = self.ef_dim
     print("Calling ATT_NET with idf =  ", gf, "cdf = ", ef)
     self.att = ATT_NET(gf, ef)
     print("back at NEXT_STAGE_G")
     doubled = gf * 2
     self.residual = self._make_layer(ResBlock, doubled)
     print("self.residual", self.residual)
     self.upsample = upBlock(doubled, gf)
     print("self.upsample", self.upsample)
コード例 #3
0
ファイル: model.py プロジェクト: ZuoyiLi/MirrorGAN-2
 def define_module(self):
     """Create the attention, residual and upsampling submodules.

     Assignment order (att, residual, upsample) determines the order in
     which nn.Module registers the children, so it is kept as is.
     """
     width = self.gf_dim
     double_width = 2 * width
     self.att = ATT_NET(width, self.ef_dim)
     self.residual = self._make_layer(ResBlock, double_width)
     self.upsample = upBlock(double_width, width)
コード例 #4
0
ファイル: model.py プロジェクト: ammarnasr/mirrorGAN
class NEXT_STAGE_G(nn.Module):
    """Refinement generator stage with optional debug printing.

    ``forward`` runs the attention network on the previous stage's image
    features, concatenates those features with the two context maps the
    attention network returns, fuses them with a 1x1 convolution, applies
    ``num_residual`` residual blocks and upsamples the result.

    Args:
        ngf: generator feature width, stored as ``gf_dim`` (32 in the
            original author's notes).
        nef: text embedding size, stored as ``ef_dim`` (256 in the notes).
        ncf: condition dimension, stored as ``cf_dim`` (100 in the notes);
            not used by this class itself.
        verbose: when True, print construction details and the input sizes
            on every forward() call. Defaults to False so a training loop is
            not flooded with per-batch output.
    """

    # Class-level fallback so instances created without __init__ (e.g.
    # unpickled from an older version of this class) still have the flag.
    verbose = False

    def __init__(self, ngf, nef, ncf, verbose=False):
        super(NEXT_STAGE_G, self).__init__()
        # Must be set before define_module(), which consults it.
        self.verbose = verbose
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        # Residual block count, read once from the global config; this
        # stored value is the one _make_layer uses.
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()
        # Fuses the 3 * ngf channels concatenated in forward() down to the
        # 2 * ngf channels the residual blocks were built for. Created after
        # define_module(); moving it would change submodule registration order.
        self.conv = conv1x1(ngf * 3, ngf * 2)
        if self.verbose:
            print("ngf, nef, ncf, self.num_residual, self.conv", ngf, nef, ncf, self.num_residual, self.conv)

    def _make_layer(self, block, channel_num):
        """Stack ``self.num_residual`` fresh ``block(channel_num)`` modules.

        Returns:
            An ``nn.Sequential`` containing the blocks in order.
        """
        return nn.Sequential(*[block(channel_num) for _ in range(self.num_residual)])

    def define_module(self):
        """Create the attention, residual and upsampling submodules.

        Prints progress after each step when ``self.verbose`` is set.
        """
        ngf = self.gf_dim
        if self.verbose:
            print("Calling ATT_NET with idf =  ", ngf, "cdf = ", self.ef_dim)
        self.att = ATT_NET(ngf, self.ef_dim)
        if self.verbose:
            print("back at NEXT_STAGE_G")
        self.residual = self._make_layer(ResBlock, ngf * 2)
        if self.verbose:
            print("self.residual", self.residual)
        # 2 * ngf -> ngf channels.
        self.upsample = upBlock(ngf * 2, ngf)
        if self.verbose:
            print("self.upsample", self.upsample)

    def forward(self, h_code, c_code, word_embs, mask):
        """Refine ``h_code`` using attention over the text, then upsample.

        Example shapes are the ones recorded by the original author
        (batch 16, ngf 32, nef 256, 18 words); treat them as illustrative.

        Args:
            h_code: previous-stage image features (the attention query),
                batch x ngf x ih x iw, e.g. (16, 32, 64, 64).
            c_code: condition code, e.g. (16, 100); passed to ``self.att``
                and then rebound to the word-context map it returns.
            word_embs: word embeddings (the attention context),
                batch x nef x seq_len, e.g. (16, 256, 18).
            mask: mask handed to ``self.att.applyMask`` before attention,
                e.g. (16, 18) -- presumably marks padded words; verify
                against ATT_NET.

        Returns:
            ``(out_code, att)``: the upsampled features and the
            word-attention map from ``self.att`` (batch x sourceL x queryL).
        """
        if self.verbose:
            print('-------THE START OF NEXT_STAGE_G-------')
            print('1.h_code    : ', h_code.size())
            print('2.c_code    : ', c_code.size())
            print('3.word_embs : ', word_embs.size())
            print('4.mask      : ', mask.size())

        self.att.applyMask(mask)
        # ATT_NET returns (weighted word context, weighted sentence context,
        # word attention, sentence attention); the last one is unused here.
        c_code, weighted_sentence, att, _ = self.att(h_code, c_code, word_embs)
        # Stack image features and both contexts along the channel axis:
        # ngf + ngf + ngf = 3 * ngf channels, e.g. (16, 96, 64, 64).
        h_c_sent_code = torch.cat((h_code, c_code, weighted_sentence), 1)
        h_c_sent_code = self.conv(h_c_sent_code)  # -> 2 * ngf channels
        out_code = self.residual(h_c_sent_code)  # shape unchanged
        out_code = self.upsample(out_code)  # -> ngf channels, larger spatial size

        if self.verbose:
            print('-------THE END OF NEXT_STAGE_G-------')

        return out_code, att