def test_four(self):
    """Equal duty cycle, boost factor=0, k=3, batch size=2.

    Builds the expected output by copying only the listed winner positions
    from ``self.x2`` into an all-zero tensor of the same shape. Then, for both
    ``break_ties`` settings, checks that ``F.kwinners2d`` (global mode,
    ``k=3``, no boosting) returns a tensor of the same shape that matches the
    expected one element-for-element.
    """
    x = self.x2

    # 4-D index positions that should survive; the first index is the batch
    # sample (two samples, three winners each, matching k=3).
    winner_positions = (
        (0, 0, 1, 1),
        (0, 1, 0, 1),
        (0, 2, 1, 0),
        (1, 1, 0, 0),
        (1, 1, 0, 1),
        (1, 2, 1, 1),
    )
    expected = torch.zeros_like(x)
    for pos in winner_positions:
        expected[pos] = x[pos]

    for break_ties in (True, False):
        with self.subTest(break_ties=break_ties):
            result = F.kwinners2d(
                x,
                self.duty_cycle,
                k=3,
                boost_strength=0.0,
                local=False,
                break_ties=break_ties,
            )
            self.assertEqual(result.shape, expected.shape)
            # Every element must match: the count of equal entries has to
            # equal the total number of elements.
            matching = (result == expected).sum()
            self.assertEqual(matching, result.numel())
def forward(self, x):
    """Run k-winners selection on ``x`` via ``F.kwinners2d``.

    On the first call, while ``self.n`` is still 0, ``self.n`` is set to the
    product of ``x.shape[1:]``. When ``self.local`` is false, the winner counts
    are then derived from it:

    - ``self.k`` is ``round(n * percent_on)``.
    - ``self.k_inference`` is ``round(n * percent_on_inference)``.

    In training mode ``self.k`` winners are kept and ``update_duty_cycle`` is
    called with the result. Otherwise ``self.k_inference`` winners are kept
    and the duty cycle is left untouched.

    :param x: input tensor; presumably (batch, channels, height, width) given
        the ``kwinners2d`` name -- TODO confirm.
    :return: the tensor returned by ``F.kwinners2d``.
    """
    if self.n == 0:
        # Lazily size the layer from the first input it sees.
        self.n = np.prod(x.shape[1:])
        if not self.local:
            self.k = int(round(self.n * self.percent_on))
            self.k_inference = int(round(self.n * self.percent_on_inference))

    is_training = self.training
    num_winners = self.k if is_training else self.k_inference

    out = F.kwinners2d(
        x,
        self.duty_cycle,
        num_winners,
        self._cached_boost_strength,
        self.local,
        self.break_ties,
        self.relu,
        self.inplace,
    )

    # Duty-cycle statistics are only accumulated during training.
    if is_training:
        self.update_duty_cycle(out)

    return out