def forward(self, x):
    out = self.conv1(x)
    # Strip pooling: collapse width to an (s, 1) row strip and height to a (1, s) column strip.
    rp = F.adaptive_max_pool2d(out, (self.s, 1))
    cp = F.adaptive_max_pool2d(out, (1, self.s))
    # Build k per-group s x s attention matrices from the row / column strips.
    p = paddle.reshape(self.conv_p(rp), (x.shape[0], self.k, self.s, self.s))
    q = paddle.reshape(self.conv_q(cp), (x.shape[0], self.k, self.s, self.s))
    p = F.sigmoid(p)
    q = F.sigmoid(q)
    # Row-normalize p (axis=3) and column-normalize q (axis=2).
    p = p / paddle.sum(p, axis=3, keepdim=True)
    q = q / paddle.sum(q, axis=2, keepdim=True)
    # Broadcast each of the k group matrices across its channel group.
    p = paddle.reshape(p, (x.shape[0], self.k, 1, self.s, self.s))
    p = paddle.expand(p, (x.shape[0], self.k, x.shape[1] // self.k, self.s, self.s))
    p = paddle.reshape(p, (x.shape[0], x.shape[1], self.s, self.s))
    q = paddle.reshape(q, (x.shape[0], self.k, 1, self.s, self.s))
    q = paddle.expand(q, (x.shape[0], self.k, x.shape[1] // self.k, self.s, self.s))
    q = paddle.reshape(q, (x.shape[0], x.shape[1], self.s, self.s))
    # Resize the s x s matrices up to the spatial size of x.
    p = self.resize_mat(p, x.shape[2] // self.s)
    q = self.resize_mat(q, x.shape[2] // self.s)
    # Bilinear application: left-multiply by p, right-multiply by q.
    y = paddle.matmul(p, x)
    y = paddle.matmul(y, q)
    y = self.conv2(y)
    return y
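# Minimal, self-contained sketch (not part of the module above) showing the
# strip-pooling shapes that forward() relies on: adaptive max pooling to
# (s, 1) keeps one value per row band, (1, s) one value per column band.
# The tensor sizes below are illustrative assumptions, not taken from the model.
import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 64, 32, 32])        # N, C, H, W
s = 8
rp = F.adaptive_max_pool2d(x, (s, 1))   # row strip    -> [2, 64, 8, 1]
cp = F.adaptive_max_pool2d(x, (1, s))   # column strip -> [2, 64, 1, 8]
print(rp.shape, cp.shape)               # [2, 64, 8, 1] [2, 64, 1, 8]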
def forward(self, input):
    x = self.DownBlock(input)

    # CAM attention: global average pooling branch, reweighted by its classifier weights.
    gap = F.adaptive_avg_pool2d(x, 1)
    gap_logit = self.gap_fc(gap.reshape([x.shape[0], -1]))
    gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
    gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

    # Global max pooling branch, reweighted the same way.
    gmp = F.adaptive_max_pool2d(x, 1)
    gmp_logit = self.gmp_fc(gmp.reshape([x.shape[0], -1]))
    gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
    gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

    # Fuse both branches and reduce channels back with a 1x1 convolution.
    cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
    x = paddle.concat([gap, gmp], 1)
    x = self.relu(self.conv1x1(x))
    heatmap = paddle.sum(x, axis=1, keepdim=True)

    # Predict per-channel gamma / beta used by the up-sampling blocks.
    if self.light:
        x_ = F.adaptive_avg_pool2d(x, 1)
        x_ = self.FC(x_.reshape([x_.shape[0], -1]))
    else:
        x_ = self.FC(x.reshape([x.shape[0], -1]))
    gamma, beta = self.gamma(x_), self.beta(x_)

    for i in range(self.n_blocks):
        x = getattr(self, 'UpBlock1_' + str(i + 1))(x, gamma, beta)
    out = self.UpBlock2(x)
    return out, cam_logit, heatmap
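# Self-contained sketch of the CAM branch pattern used above (and in the two
# later forward() methods): global pooling produces a logit, and the same
# fully connected weight is reused to reweight the feature map channel-wise.
# The layer sizes here are assumptions made only for the demo.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

x = paddle.rand([2, 32, 16, 16])
gap_fc = nn.Linear(32, 1, bias_attr=False)

gap = F.adaptive_avg_pool2d(x, 1)                             # [2, 32, 1, 1]
gap_logit = gap_fc(gap.reshape([x.shape[0], -1]))             # [2, 1]
gap_weight = list(gap_fc.parameters())[0].transpose([1, 0])   # [1, 32]
cam = x * gap_weight.unsqueeze(2).unsqueeze(3)                # channel reweighting
print(gap_logit.shape, cam.shape)                             # [2, 1] [2, 32, 16, 16]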
def forward(self, fstudent, fteacher):
    loss_all = 0.0
    for fs, ft in zip(fstudent, fteacher):
        h = fs.shape[2]
        # Full-resolution MSE between student and teacher features.
        loss = F.mse_loss(fs, ft)
        cnt = 1.0
        tot = 1.0
        # Add pooled MSE terms at coarser 4x4, 2x2 and 1x1 resolutions,
        # each weighted half as much as the previous level.
        for l in [4, 2, 1]:
            if l >= h:
                continue
            if self.mode == "max":
                tmpfs = F.adaptive_max_pool2d(fs, (l, l))
                tmpft = F.adaptive_max_pool2d(ft, (l, l))
            else:
                tmpfs = F.adaptive_avg_pool2d(fs, (l, l))
                tmpft = F.adaptive_avg_pool2d(ft, (l, l))
            cnt /= 2.0
            loss += F.mse_loss(tmpfs, tmpft) * cnt
            tot += cnt
        loss = loss / tot
        loss_all = loss_all + loss
    return loss_all
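# Standalone sketch of the hierarchical pooled-MSE idea above, rewritten for a
# single feature pair against dummy tensors so it runs on its own. The "avg"
# pooling mode and the feature sizes are assumptions made for illustration.
import paddle
import paddle.nn.functional as F

def hcl_single(fs, ft):
    loss = F.mse_loss(fs, ft)       # full-resolution term
    cnt, tot = 1.0, 1.0
    for l in [4, 2, 1]:             # progressively coarser pooled terms
        if l >= fs.shape[2]:
            continue
        cnt /= 2.0
        loss += F.mse_loss(F.adaptive_avg_pool2d(fs, (l, l)),
                           F.adaptive_avg_pool2d(ft, (l, l))) * cnt
        tot += cnt
    return loss / tot

fs = paddle.rand([2, 16, 8, 8])     # student feature
ft = paddle.rand([2, 16, 8, 8])     # teacher feature
print(hcl_single(fs, ft))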
def forward(self, x):
    bs = x.shape[0]
    x = self.DownBlock(x)

    # Keep a pooled summary of every encoder block as content features.
    content_features = []
    for i in range(self.n_blocks):
        x = getattr(self, 'EncodeBlock' + str(i + 1))(x)
        content_features.append(
            F.adaptive_avg_pool2d(x, 1).reshape([bs, -1]))

    # CAM attention: average- and max-pooling branches reweighted by
    # their classifier weights.
    gap = F.adaptive_avg_pool2d(x, 1)
    gap_logit = self.gap_fc(gap.reshape([bs, -1]))
    gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
    gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

    gmp = F.adaptive_max_pool2d(x, 1)
    gmp_logit = self.gmp_fc(gmp.reshape([bs, -1]))
    gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
    gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

    cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
    x = paddle.concat([gap, gmp], 1)
    x = self.relu(self.conv1x1(x))
    heatmap = paddle.sum(x, axis=1, keepdim=True)

    # Style features from the attended map (pooled first in light mode).
    if self.light:
        x_ = F.adaptive_avg_pool2d(x, 1)
        style_features = self.FC(x_.reshape([bs, -1]))
    else:
        style_features = self.FC(x.reshape([bs, -1]))

    # Decode, feeding in the matching encoder content feature
    # (encoder features are consumed in reverse order).
    for i in range(self.n_blocks):
        x = getattr(self, 'DecodeBlock' + str(i + 1))(
            x, content_features[4 - i - 1], style_features)
    out = self.UpBlock(x)
    return out, cam_logit, heatmap
def forward(self, x):
    x = self.model(x)

    # CAM attention: average- and max-pooling branches reweighted by
    # their classifier weights.
    gap = F.adaptive_avg_pool2d(x, 1)
    gap_logit = self.gap_fc(gap.reshape([x.shape[0], -1]))
    gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
    gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

    gmp = F.adaptive_max_pool2d(x, 1)
    gmp_logit = self.gmp_fc(gmp.reshape([x.shape[0], -1]))
    gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
    gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

    cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
    x = paddle.concat([gap, gmp], 1)
    x = self.leaky_relu(self.conv1x1(x))
    heatmap = paddle.sum(x, axis=1, keepdim=True)

    # Final prediction map from the attended features.
    x = self.pad(x)
    out = self.conv(x)
    return out, cam_logit, heatmap