def define_module(self):
    """Instantiate the stage's sub-modules and attach them to ``self``.

    Reads ``self.gf_dim`` (base feature width) and ``self.ef_dim`` and uses
    ``self._make_layer`` to build the residual stack.  Nothing is returned;
    the layers are stored as attributes.
    """
    base_width = self.gf_dim
    text_width = self.ef_dim
    fused_width = base_width * 3

    # NOTE(review): the construction order below is intentionally the same
    # as before -- if ``self`` is an ``nn.Module`` it fixes the sub-module
    # registration order (confirm before reordering).
    # spatial attention
    self.att = SPATIAL_ATT(base_width, text_width)
    # channel-wise attention
    self.channel_att = CHANNEL_ATT(base_width, text_width)
    # residual stack operating on the 3x-wide feature map
    self.residual = self._make_layer(ResBlock, fused_width)
    # maps the 3x-wide features back to the base width
    self.upsample = upBlock(fused_width, base_width)
    self.SAIN = ACM(fused_width)
class NEXT_STAGE_G(nn.Module):
    """Generator stage combining spatial and channel-wise word attention.

    ``forward`` attends over the word embeddings, concatenates the attended
    features with the incoming hidden features, modulates the result with
    the image through ``self.SAIN`` and finally runs the residual stack and
    the upsampling block.

    Args:
        ngf: base feature width, stored as ``gf_dim``.
        nef: embedding width handed to the attention modules, stored as
            ``ef_dim``.
        ncf: stored as ``cf_dim``; not read anywhere inside this class.
    """

    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        # Number of residual blocks; consumed by ``_make_layer``.
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()

    def _make_layer(self, block, channel_num):
        """Return ``self.num_residual`` instances of ``block`` chained in an ``nn.Sequential``.

        Args:
            block: callable invoked as ``block(channel_num)`` for each layer.
            channel_num: argument forwarded to every ``block`` call.
        """
        return nn.Sequential(
            *(block(channel_num) for _ in range(self.num_residual))
        )

    def define_module(self):
        """Create the attention, residual, upsampling and ACM sub-modules."""
        ngf = self.gf_dim
        self.att = SPATIAL_ATT(ngf, self.ef_dim)  # spatial attention
        self.channel_att = CHANNEL_ATT(ngf, self.ef_dim)  # channel-wise attention
        self.residual = self._make_layer(ResBlock, ngf * 3)
        self.upsample = upBlock(ngf * 3, ngf)
        self.SAIN = ACM(ngf * 3)

    def forward(self, h_code, c_code, word_embs, mask, img):
        """Run one refinement stage.

        Shape notes carried over from the original author's docstring:
            h_code (query):      batch x idf x ih x iw (queryL = ih x iw)
            word_embs (context): batch x cdf x sourceL (sourceL = seq_len)
            attended code:       batch x idf x queryL
            att:                 batch x sourceL x queryL

        Args:
            h_code: hidden feature map from the previous stage.
            c_code: accepted for interface compatibility only; its value is
                never read because the name is rebound to the spatial
                attention output.
            word_embs: word embeddings used as attention context.
            mask: forwarded to ``self.att.applyMask``.
            img: second input of ``self.SAIN``.

        Returns:
            Tuple ``(out_code, att)``: the upsampled features and the
            attention map returned by ``self.att``.
        """
        self.att.applyMask(mask)
        height, width = h_code.size(2), h_code.size(3)

        c_code, att = self.att(h_code, word_embs)
        # The channel-attention map (second return value) is not needed here.
        c_code_channel, _ = self.channel_att(c_code, word_embs, height, width)

        # Fold the attended code back into a spatial map matching h_code.
        c_code = c_code.view(word_embs.size(0), -1, height, width)

        # Stack hidden, spatially attended and channel attended features
        # along the channel axis (presumably ngf each, 3 * ngf in total,
        # matching what SAIN / residual were built for -- TODO confirm).
        h_c_c_code = torch.cat((h_code, c_code, c_code_channel), 1)

        h_c_c_img_code = self.SAIN(h_c_c_code, img)
        out_code = self.residual(h_c_c_img_code)
        out_code = self.upsample(out_code)
        return out_code, att
class DCM_NEXT_STAGE(nn.Module):
    """Refinement stage using spatial attention plus ``DCM_CHANNEL_ATT``.

    Unlike an upsampling stage, the fused features are passed through a
    ``conv3x3 -> InstanceNorm2d -> GLU`` block after the residual stack, and
    ``forward`` returns only the feature map (no attention map).

    Args:
        ngf: base feature width, stored as ``gf_dim``.
        nef: embedding width handed to the attention modules, stored as
            ``ef_dim``.
        ncf: stored as ``cf_dim``; not read anywhere inside this class.
    """

    def __init__(self, ngf, nef, ncf):
        super(DCM_NEXT_STAGE, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        # Number of residual blocks; consumed by ``_make_layer``.
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()

    def _make_layer(self, block, channel_num):
        """Return ``self.num_residual`` instances of ``block`` chained in an ``nn.Sequential``.

        Args:
            block: callable invoked as ``block(channel_num)`` for each layer.
            channel_num: argument forwarded to every ``block`` call.
        """
        return nn.Sequential(
            *(block(channel_num) for _ in range(self.num_residual))
        )

    def define_module(self):
        """Create the attention, residual, refinement and ACM sub-modules."""
        ngf = self.gf_dim
        self.att = SPATIAL_ATT(ngf, self.ef_dim)
        self.color_channel_att = DCM_CHANNEL_ATT(ngf, self.ef_dim)
        self.residual = self._make_layer(ResBlock, ngf * 3)
        # conv to 2 * ngf channels followed by GLU (presumably halving the
        # channel count back to ngf -- TODO confirm against GLU).
        self.block = nn.Sequential(
            conv3x3(ngf * 3, ngf * 2),
            nn.InstanceNorm2d(ngf * 2),
            GLU(),
        )
        self.SAIN = ACM(ngf * 3)

    def forward(self, h_code, c_code, word_embs, mask, img):
        """Refine ``h_code`` with word attention and image modulation.

        Args:
            h_code: hidden feature map; its dims 2 and 3 are used as the
                spatial size when reshaping the attended code.
            c_code: accepted for interface compatibility only; its value is
                never read because the name is rebound to the spatial
                attention output.
            word_embs: word embeddings used as attention context.
            mask: forwarded to ``self.att.applyMask``.
            img: second input of ``self.SAIN``.

        Returns:
            The feature map produced by ``self.block``.
        """
        self.att.applyMask(mask)
        height, width = h_code.size(2), h_code.size(3)

        # Only the attended code is used; both attention maps are discarded.
        c_code, _ = self.att(h_code, word_embs)
        c_code_channel, _ = self.color_channel_att(
            c_code, word_embs, height, width
        )

        # Fold the attended code back into a spatial map matching h_code.
        c_code = c_code.view(word_embs.size(0), -1, height, width)

        # Stack hidden, spatially attended and channel attended features
        # along the channel axis.
        h_c_c_code = torch.cat((h_code, c_code, c_code_channel), 1)

        h_c_c_img_code = self.SAIN(h_c_c_code, img)
        out_code = self.residual(h_c_c_img_code)
        out_code = self.block(out_code)
        return out_code
def define_module(self):
    """Instantiate the stage's sub-modules and attach them to ``self``.

    Reads ``self.gf_dim`` (base feature width) and ``self.ef_dim`` and uses
    ``self._make_layer`` to build the residual stack.  Nothing is returned;
    the layers are stored as attributes.
    """
    base_width = self.gf_dim
    text_width = self.ef_dim
    fused_width = base_width * 3
    gated_width = base_width * 2

    # NOTE(review): the construction order below is intentionally the same
    # as before -- if ``self`` is an ``nn.Module`` it fixes the sub-module
    # registration order (confirm before reordering).
    self.att = SPATIAL_ATT(base_width, text_width)
    self.color_channel_att = DCM_CHANNEL_ATT(base_width, text_width)
    self.residual = self._make_layer(ResBlock, fused_width)

    # Refinement head: convolution, instance normalisation, then GLU.
    refine_layers = [
        conv3x3(fused_width, gated_width),
        nn.InstanceNorm2d(gated_width),
        GLU(),
    ]
    self.block = nn.Sequential(*refine_layers)

    self.SAIN = ACM(fused_width)