Esempio n. 1
0
    def test_featurewise_normalization(self):
        """Check that InstanceNorm3d normalizes every (batch item, feature) slice.

        The random input gets a random scale and offset per feature and per
        batch item. After normalization, each slice ``out[:, i, f, ...]``
        must have a mean within 0.1 of zero. The standard deviation of its
        magnitude must be within 0.1 of one.
        """
        b, h, w, d = 5, 30, 40, 20

        # Named `reprs` (not `repr`) to avoid shadowing the builtin.
        reprs = (2, 3)
        # Derive the feature count from `reprs` instead of hard-coding it, so
        # the per-feature statistics below always match the input's feature axis.
        n_features = sum(reprs)
        inorm = InstanceNorm3d(reprs, eps=0.)

        # Leading axis of size 2: presumably the two components combined by
        # `magnitude` — TODO confirm.
        inp = torch.randn(2, b, n_features, h, w, d)
        feature_means = torch.randn(n_features).reshape(1, 1, -1, 1, 1, 1)
        batch_means = torch.randn(b).reshape(1, -1, 1, 1, 1, 1)
        feature_stds = torch.randn(n_features).reshape(1, 1, -1, 1, 1, 1)
        batch_stds = torch.randn(b).reshape(1, -1, 1, 1, 1, 1)
        # Distort the input so that normalization has real work to do.
        inp *= feature_stds
        inp *= batch_stds
        inp += feature_means
        inp += batch_means

        out = inorm(inp)
        mags = magnitude(out)

        for batch_item in range(b):
            for feature in range(n_features):
                mag = mags[batch_item, feature, ...]
                val = out[:, batch_item, feature, ...]

                diff_mean = torch.abs(val.mean())
                self.assertLess(diff_mean.item(), 0.1)

                diff_std = torch.abs(mag.std() - 1.)
                self.assertLess(diff_std.item(), 0.1)
Esempio n. 2
0
    def forward(
            self, x: [2, 'b', 'f', 'h', 'w',
                      ...]) -> [2, 'b', 'f', 'h', 'w', ...]:
        """Scale ``x`` by a learned, magnitude-driven sigmoid gate.

        The gate combines two branches:

        * local: ``conv1`` applied to ``magnitude(x)``;
        * global: ``mean_mat`` applied to the magnitude of ``x`` averaged
          over every axis after the third. This branch is reshaped with
          ``self._dim`` trailing singleton axes so it broadcasts against
          the local branch.

        The sum of the two branches goes through ReLU, ``conv2`` and a sigmoid.
        The resulting gate gains a new leading axis and multiplies ``x``.

        NOTE(review): assumes ``magnitude`` collapses the leading size-2 axis
        and that ``self._dim`` equals the number of spatial axes — confirm.
        """
        local_gate = self.conv1(magnitude(x))

        # Average over all positions beyond (component, batch, feature).
        spatial_mean = x.reshape(*x.shape[:3], -1).mean(dim=3)
        global_gate = self.mean_mat(magnitude(spatial_mean))
        singleton_axes = (1, ) * self._dim
        global_gate = global_gate.reshape(*global_gate.shape, *singleton_axes)

        gate = torch.sigmoid(self.conv2(torch.relu(local_gate + global_gate)))
        # Re-add the leading axis so the same gate scales both components.
        return x * gate.unsqueeze(0)
Esempio n. 3
0
        def test_forward(self):
            """Run CReLU on random input and print a few values for inspection.

            NOTE(review): this test makes no assertions. It only prints one
            bias entry, the magnitude of the first input slice, and the
            matching output entry.
            """
            field_sizes = (3, 6, 0, 1)
            crelu = CReLU(field_sizes)
            batch, height, width = 3, 40, 40
            # Feature axis length equals the sum of the field sizes (10).
            x = torch.randn(2, batch, sum(field_sizes), height, width)
            y = crelu(x)

            print(crelu.bias[0, 0, 0])
            print(magnitude(x[:, 0, 0]))
            print(y[0, 0, 0])
Esempio n. 4
0
    def _forward(
            self, x: [2, 'b', 'f', 'h', 'w',
                      ...]) -> [2, 'b', 'f', 'h', 'w', ...]:
        """Apply a ReLU to the magnitude of ``x`` while keeping its direction.

        Each entry is rescaled to unit magnitude (up to ``self.eps``) and then
        multiplied by ``magnitude - self.bias``. Entries where that shifted
        magnitude is not positive become zero.

        NOTE(review): assumes ``magnitude`` collapses the leading size-2 axis,
        which is why it is re-expanded with ``unsqueeze(0)`` — confirm.
        """
        radius = magnitude(x).unsqueeze(0)
        # eps guards the division for entries whose magnitude is zero.
        direction = x / (radius + self.eps)
        shifted = radius - self.bias

        keep = shifted > 0
        return torch.where(keep, direction * shifted, x.new_zeros((1, )))