def test_featurewise_normalization(self):
    """InstanceNorm3d should normalize every (batch, feature) slice.

    Builds a random input, then corrupts it with independent per-feature
    and per-batch means and scales. After normalization, each
    (batch, feature) slice must have near-zero mean and near-unit
    magnitude standard deviation.
    """
    b, h, w, d = 5, 30, 40, 20
    # Renamed from `repr` — the original shadowed the builtin `repr()`.
    repr_sizes = (2, 3)
    inorm = InstanceNorm3d(repr_sizes, eps=0.)
    # Leading dim of 2 is the complex (re/im) axis; sum(repr_sizes) == 5 features.
    inp = torch.randn(2, b, sum(repr_sizes), h, w, d)
    feature_means = torch.randn(5).reshape(1, 1, -1, 1, 1, 1)
    batch_means = torch.randn(b).reshape(1, -1, 1, 1, 1, 1)
    # NOTE(review): randn "stds" may be negative or near zero; normalization
    # is still exercised, but a positive scale (e.g. randn().abs() + eps)
    # would be a cleaner fixture — confirm intent.
    feature_stds = torch.randn(5).reshape(1, 1, -1, 1, 1, 1)
    batch_stds = torch.randn(b).reshape(1, -1, 1, 1, 1, 1)
    inp *= feature_stds
    inp *= batch_stds
    inp += feature_means
    inp += batch_means
    out = inorm(inp)
    mags = magnitude(out)
    for batch_item in range(b):
        for feature in range(sum(repr_sizes)):
            mag = mags[batch_item, feature, ...]
            val = out[:, batch_item, feature, ...]
            diff_mean = torch.abs(val.mean())
            self.assertLess(diff_mean.item(), 0.1)
            diff_std = torch.abs(mag.std() - 1.)
            self.assertLess(diff_std.item(), 0.1)
def forward(
        self, x: [2, 'b', 'f', 'h', 'w', ...]) -> [2, 'b', 'f', 'h', 'w', ...]:
    """Apply a sigmoid gate to `x`, combining spatial and mean pathways.

    A gate is computed from the per-position magnitudes (via conv1) plus
    a broadcast term from the spatially averaged magnitudes (via
    mean_mat), passed through relu -> conv2 -> sigmoid, and multiplied
    onto the input.
    """
    # Spatial pathway: gate contribution at every position.
    spatial_gate = self.conv1(magnitude(x))
    # Mean pathway: average over all spatial positions, then project.
    pooled = x.reshape(*x.shape[:3], -1).mean(dim=3)
    mean_gate = self.mean_mat(magnitude(pooled))
    # Broadcast the mean term over the spatial dims before summing.
    broadcast_shape = (*mean_gate.shape, *([1] * self._dim))
    gate = torch.relu(spatial_gate + mean_gate.reshape(broadcast_shape))
    gate = torch.sigmoid(self.conv2(gate))
    # unsqueeze(0) broadcasts the gate over the leading (complex) axis.
    return x * gate.unsqueeze(0)
def test_forward(self):
    """Smoke-test a CReLU forward pass on random input.

    NOTE(review): this test only prints values and makes no assertions;
    it verifies the call does not crash. Consider adding explicit checks.
    """
    activation = CReLU((3, 6, 0, 1))
    batch, height, width = 3, 40, 40
    # Feature dim must equal the sum of the representation sizes (3+6+0+1).
    x = torch.randn(2, batch, 3 + 6 + 1, height, width)
    y = activation(x)
    print(activation.bias[0, 0, 0])
    print(magnitude(x[:, 0, 0]))
    print(y[0, 0, 0])
def _forward(
        self, x: [2, 'b', 'f', 'h', 'w', ...]) -> [2, 'b', 'f', 'h', 'w', ...]:
    """Soft-threshold `x` by magnitude: shrink by `self.bias`, zero the rest.

    Each element is rescaled so its magnitude becomes (|x| - bias) where
    that is positive, and zero otherwise; direction (phase) is preserved
    via division by the eps-stabilized magnitude.
    """
    norms = magnitude(x).unsqueeze(0)
    # Unit-magnitude version of x; eps guards against division by zero.
    unit = x / (norms + self.eps)
    shifted = norms - self.bias
    keep = shifted > 0
    return torch.where(keep, unit * shifted, x.new_zeros((1,)))