コード例 #1
0
    def test_four(self):
        """Equal duty cycle, boost factor=0, k=3, batch size=2.

        Runs ``F.kwinners2d`` globally (``local=False``) with no boosting and
        checks that the output keeps the input value at exactly the six listed
        positions (three per batch index) and is zero everywhere else. The
        check is repeated for both values of ``break_ties``.
        """
        x = self.x2

        # 4-D indices into ``x`` whose values are expected to pass through;
        # the first index is presumably the batch dimension (0 or 1).
        winners = (
            (0, 0, 1, 1),
            (0, 1, 0, 1),
            (0, 2, 1, 0),
            (1, 1, 0, 0),
            (1, 1, 0, 1),
            (1, 2, 1, 1),
        )
        expected = torch.zeros_like(x)
        for idx in winners:
            expected[idx] = x[idx]

        for break_ties in (True, False):
            with self.subTest(break_ties=break_ties):
                result = F.kwinners2d(
                    x,
                    self.duty_cycle,
                    k=3,
                    boost_strength=0.0,
                    local=False,
                    break_ties=break_ties,
                )

                self.assertEqual(result.shape, expected.shape)

                # Every element must match: the count of equal entries has to
                # equal the total number of elements.
                matches = (result == expected).sum()
                self.assertEqual(matches, result.numel())
コード例 #2
0
    def forward(self, x):
        """Apply k-winners via ``F.kwinners2d`` and return the result.

        On the first call (while ``self.n == 0``), ``self.n`` is set to the
        product of ``x.shape[1:]``. When ``self.local`` is false, ``self.k``
        and ``self.k_inference`` are then derived from ``self.percent_on`` and
        ``self.percent_on_inference``. In local mode they are left untouched;
        presumably they are set elsewhere (e.g. ``__init__``). TODO confirm.

        In training mode ``self.k`` is used and the duty cycle is updated from
        the output. In eval mode ``self.k_inference`` is used and the duty
        cycle is not updated.

        Args:
            x: input tensor; presumably (batch, C, H, W). TODO confirm.

        Returns:
            The tensor produced by ``F.kwinners2d``.
        """
        # Lazy sizing: only the non-batch dimensions contribute to n.
        if self.n == 0:
            self.n = np.prod(x.shape[1:])
            if not self.local:
                total_units = self.n
                self.k = int(round(total_units * self.percent_on))
                self.k_inference = int(
                    round(total_units * self.percent_on_inference))

        is_training = self.training
        active_k = self.k if is_training else self.k_inference

        x = F.kwinners2d(x, self.duty_cycle, active_k,
                         self._cached_boost_strength, self.local,
                         self.break_ties, self.relu, self.inplace)

        # Duty-cycle statistics are only accumulated while training.
        if is_training:
            self.update_duty_cycle(x)

        return x