def forward(self, x):
        """Multiply ``x`` by two predicted matrices and return ``conv2`` of the result.

        ``conv1(x)`` is max-pooled to an ``(s, 1)`` strip and a ``(1, s)``
        strip. ``conv_p`` / ``conv_q`` outputs are reshaped to ``k`` matrices
        of size ``s x s`` per sample, squashed with a sigmoid and normalised
        (the left matrix along its last axis, the right matrix along axis 2).
        Each of the ``k`` matrices is repeated ``C // k`` times along the
        channel axis, passed through ``self.resize_mat`` and applied as
        ``left @ x @ right``.

        NOTE(review): both matrices are resized with ``x.shape[2] // self.s``;
        this presumably assumes square spatial inputs -- confirm with callers.
        """
        batch, channels = x.shape[0], x.shape[1]
        k, s = self.k, self.s
        group_size = channels // k
        factor = x.shape[2] // s

        def spread_over_channels(mat):
            # (B, k, s, s) -> (B, k, 1, s, s) -> (B, k, C // k, s, s) -> (B, C, s, s)
            mat = paddle.reshape(mat, (batch, k, 1, s, s))
            mat = paddle.expand(mat, (batch, k, group_size, s, s))
            return paddle.reshape(mat, (batch, channels, s, s))

        feat = self.conv1(x)
        pooled_rows = F.adaptive_max_pool2d(feat, (s, 1))
        pooled_cols = F.adaptive_max_pool2d(feat, (1, s))

        left = F.sigmoid(
            paddle.reshape(self.conv_p(pooled_rows), (batch, k, s, s)))
        right = F.sigmoid(
            paddle.reshape(self.conv_q(pooled_cols), (batch, k, s, s)))

        # Normalise so the entries sum to one along the chosen axis.
        left = left / paddle.sum(left, axis=3, keepdim=True)
        right = right / paddle.sum(right, axis=2, keepdim=True)

        left = self.resize_mat(spread_over_channels(left), factor)
        right = self.resize_mat(spread_over_channels(right), factor)

        # Note: the matrices act on the raw input ``x``, not on ``conv1(x)``.
        return self.conv2(paddle.matmul(paddle.matmul(left, x), right))
    def forward(self, input):
        """Return ``(out, cam_logit, heatmap)`` for ``input``.

        The features from ``self.DownBlock`` go through two parallel branches
        (global average pooling and global max pooling). Each branch produces
        a logit via ``self.gap_fc`` / ``self.gmp_fc`` and a copy of the
        feature map scaled channel-wise by the first parameter of that layer
        (presumably its weight matrix -- confirm). The two scaled maps are
        concatenated and fused by ``self.conv1x1`` + ``self.relu``.

        ``heatmap`` is the channel-wise sum of the fused features. ``gamma``
        and ``beta`` are derived from ``self.FC`` and handed to every
        ``UpBlock1_<i>`` block before ``self.UpBlock2`` yields ``out``.
        """
        feat = self.DownBlock(input)
        batch = feat.shape[0]

        # Average-pooled branch.
        avg_vec = F.adaptive_avg_pool2d(feat, 1)
        gap_logit = self.gap_fc(avg_vec.reshape([batch, -1]))
        gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
        avg_branch = feat * gap_weight.unsqueeze(2).unsqueeze(3)

        # Max-pooled branch.
        max_vec = F.adaptive_max_pool2d(feat, 1)
        gmp_logit = self.gmp_fc(max_vec.reshape([batch, -1]))
        gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
        max_branch = feat * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
        fused = self.relu(
            self.conv1x1(paddle.concat([avg_branch, max_branch], 1)))

        heatmap = paddle.sum(fused, axis=1, keepdim=True)

        # In light mode FC sees a pooled vector, otherwise the flattened map.
        if self.light:
            fc_input = F.adaptive_avg_pool2d(fused, 1)
        else:
            fc_input = fused
        hidden = self.FC(fc_input.reshape([batch, -1]))
        gamma = self.gamma(hidden)
        beta = self.beta(hidden)

        for idx in range(1, self.n_blocks + 1):
            fused = getattr(self, f'UpBlock1_{idx}')(fused, gamma, beta)

        return self.UpBlock2(fused), cam_logit, heatmap
Beispiel #3
0
    def forward(self, fstudent, fteacher):
        """Sum a multi-scale MSE loss over paired student/teacher features.

        For every ``(student, teacher)`` pair the full-resolution MSE gets
        weight 1. The features are then pooled to 4x4, 2x2 and 1x1 (skipping
        any size that is not smaller than ``student.shape[2]``) and each
        pooled MSE is added with a weight that halves at every used level.
        The pair loss is divided by the sum of the weights applied.

        Pooling is adaptive max pooling when ``self.mode == "max"`` and
        adaptive average pooling otherwise.

        Returns:
            The sum of the normalised pair losses (``0.0`` if there are no
            pairs).
        """

        def pooled(feature, size):
            # Resolved lazily so ``self.mode`` is only read when a level is used.
            if self.mode == "max":
                return F.adaptive_max_pool2d(feature, (size, size))
            return F.adaptive_avg_pool2d(feature, (size, size))

        total = 0.0
        for student, teacher in zip(fstudent, fteacher):
            height = student.shape[2]
            pair_loss = F.mse_loss(student, teacher)
            weight = 1.0
            weight_sum = 1.0
            for size in (4, 2, 1):
                if size < height:
                    small_student = pooled(student, size)
                    small_teacher = pooled(teacher, size)
                    weight /= 2.0
                    pair_loss += F.mse_loss(small_student, small_teacher) * weight
                    weight_sum += weight
            total = total + pair_loss / weight_sum
        return total
Beispiel #4
0
    def forward(self, x):
        """Return ``(out, cam_logit, heatmap)`` for ``x``.

        ``x`` goes through ``self.DownBlock`` and ``self.n_blocks`` encode
        blocks; after each encode block a globally average-pooled copy of the
        features is stored in ``content_features``. Two branches (global
        average / global max pooling) then each produce a logit via
        ``self.gap_fc`` / ``self.gmp_fc`` and a feature map scaled
        channel-wise by the first parameter of that layer (presumably its
        weight matrix -- confirm). The scaled maps are concatenated and fused
        by ``self.conv1x1`` + ``self.relu``; ``heatmap`` is the channel-wise
        sum of the fused features.

        ``style_features`` come from ``self.FC`` (fed a pooled vector in light
        mode, the flattened map otherwise). Decode block ``i`` receives the
        content feature of the encode block mirrored from the end, i.e. the
        stored features are consumed in reverse order.
        """
        bs = x.shape[0]

        x = self.DownBlock(x)

        content_features = []
        for i in range(self.n_blocks):
            x = getattr(self, 'EncodeBlock' + str(i + 1))(x)
            content_features.append(
                F.adaptive_avg_pool2d(x, 1).reshape([bs, -1]))

        gap = F.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.reshape([bs, -1]))
        gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = F.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.reshape([bs, -1]))
        gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
        x = paddle.concat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = paddle.sum(x, axis=1, keepdim=True)

        if self.light:
            x_ = F.adaptive_avg_pool2d(x, 1)
            style_features = self.FC(x_.reshape([bs, -1]))
        else:
            style_features = self.FC(x.reshape([bs, -1]))

        # Index from self.n_blocks rather than a literal 4: with a different
        # block count the literal either raised IndexError (fewer blocks) or
        # wrapped to negative indices and paired the wrong features (more
        # blocks). Identical to the old behaviour when n_blocks == 4.
        for i in range(self.n_blocks):
            skip = content_features[self.n_blocks - i - 1]
            x = getattr(self, 'DecodeBlock' + str(i + 1))(
                x, skip, style_features)

        out = self.UpBlock(x)

        return out, cam_logit, heatmap
    def forward(self, x):
        """Return ``(out, cam_logit, heatmap)`` for ``x``.

        Features from ``self.model`` pass through a global-average and a
        global-max branch. Each branch yields a logit via ``self.gap_fc`` /
        ``self.gmp_fc`` and a copy of the features scaled channel-wise by the
        first parameter of that layer (presumably its weight matrix --
        confirm). The scaled maps are concatenated, fused by
        ``self.conv1x1`` + ``self.leaky_relu``, summed over channels to give
        ``heatmap``, then padded and convolved to give ``out``.
        """
        feat = self.model(x)
        batch = feat.shape[0]

        # Average-pooled branch.
        avg_vec = F.adaptive_avg_pool2d(feat, 1)
        gap_logit = self.gap_fc(avg_vec.reshape([batch, -1]))
        gap_weight = list(self.gap_fc.parameters())[0].transpose([1, 0])
        avg_branch = feat * gap_weight.unsqueeze(2).unsqueeze(3)

        # Max-pooled branch.
        max_vec = F.adaptive_max_pool2d(feat, 1)
        gmp_logit = self.gmp_fc(max_vec.reshape([batch, -1]))
        gmp_weight = list(self.gmp_fc.parameters())[0].transpose([1, 0])
        max_branch = feat * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = paddle.concat([gap_logit, gmp_logit], 1)
        fused = self.leaky_relu(
            self.conv1x1(paddle.concat([avg_branch, max_branch], 1)))

        heatmap = paddle.sum(fused, axis=1, keepdim=True)

        return self.conv(self.pad(fused)), cam_logit, heatmap