def __init__(self, num_features, eps=1e-5, momentum=0.99, last_gamma=False, D=-1):
    super(MinkowskiSwitchNorm, self).__init__()
    self.eps = eps
    self.momentum = momentum
    self.last_gamma = last_gamma
    # Learnable affine parameters and the mixing weights over the three
    # normalization branches (instance, layer, batch).
    self.weight = nn.Parameter(torch.ones(1, num_features))
    self.bias = nn.Parameter(torch.zeros(1, num_features))
    self.mean_weight = nn.Parameter(torch.ones(3))
    self.var_weight = nn.Parameter(torch.ones(3))
    # Running statistics for the batch-norm branch.
    self.register_buffer('running_mean', torch.zeros(1, num_features))
    self.register_buffer('running_var', torch.zeros(1, num_features))
    # Sparse global-pooling and broadcast modules used to compute per-instance
    # statistics and apply them back to every coordinate.
    self.mean_in = MinkowskiGlobalPooling(dimension=D)
    self.glob_sum = MinkowskiBroadcastAddition(dimension=D)
    self.glob_sum2 = MinkowskiBroadcastAddition(dimension=D)
    self.glob_mean = MinkowskiGlobalPooling(dimension=D)
    self.glob_times = MinkowskiBroadcastMultiplication(dimension=D)
    self.D = D
    self.reset_parameters()
def __init__(self, num_features, eps=1e-5, D=-1):
    super(MinkowskiInstanceNorm, self).__init__()
    self.eps = eps
    # Learnable affine parameters.
    self.weight = nn.Parameter(torch.ones(1, num_features))
    self.bias = nn.Parameter(torch.zeros(1, num_features))
    # Sparse global-pooling and broadcast modules used to compute per-instance
    # statistics and apply them back to every coordinate.
    self.mean_in = MinkowskiGlobalPooling(dimension=D)
    self.glob_sum = MinkowskiBroadcastAddition(dimension=D)
    self.glob_sum2 = MinkowskiBroadcastAddition(dimension=D)
    self.glob_mean = MinkowskiGlobalPooling(dimension=D)
    self.glob_times = MinkowskiBroadcastMultiplication(dimension=D)
    self.D = D
    self.reset_parameters()
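# Hedged sketch, not the library's verbatim forward: it only illustrates how the
# pooling / broadcast modules registered in __init__ above could compose into an
# instance normalization. The SparseTensor keyword arguments used here
# (coords_key, coords_manager) are assumptions inferred from the coords_key /
# coords_man attributes referenced in the tests below.
def _instance_norm_forward_sketch(self, x):
    # Pool the negated features to get -mean per batch instance.
    neg_mean = self.mean_in(
        SparseTensor(-x.F, coords_key=x.coords_key, coords_manager=x.coords_man))
    # Broadcast-add the negated mean so every coordinate is centered.
    centered = self.glob_sum(x, neg_mean)
    # Pool the squared residuals to get the per-instance variance.
    var = self.glob_mean(
        SparseTensor(centered.F ** 2, coords_key=centered.coords_key,
                     coords_manager=centered.coords_man))
    inv_std = SparseTensor(1 / (var.F + self.eps).sqrt(),
                           coords_key=var.coords_key,
                           coords_manager=var.coords_man)
    # Broadcast-multiply the centered features by the inverse std, then apply
    # the learned affine transform.
    out = self.glob_times(self.glob_sum2(x, neg_mean), inv_std)
    return SparseTensor(out.F * self.weight + self.bias,
                        coords_key=out.coords_key,
                        coords_manager=out.coords_man)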
def test_broadcast(self):
    in_channels, D = 2, 2
    coords, feats, labels = data_loader(in_channels)
    coords, feats_glob, labels = data_loader(in_channels)
    # Double precision keeps gradcheck's numerical Jacobian accurate.
    feats = feats.double()
    feats_glob = feats_glob.double()
    input = SparseTensor(feats, coords=coords)

    # Pool to one global feature per batch instance, then broadcast it back.
    pool = MinkowskiGlobalPooling(dimension=D)
    input_glob = pool(input)
    input_glob.F.requires_grad_()
    broadcast = MinkowskiBroadcastAddition(D)
    broadcast_mul = MinkowskiBroadcastMultiplication(D)

    output = broadcast(input, input_glob)
    print(output)
    output = broadcast_mul(input, input_glob)
    print(output)

    # Check backward
    fn = MinkowskiBroadcastFunction()
    self.assertTrue(
        gradcheck(fn, (input.F, input_glob.F, OperationType.ADDITION,
                       input.coords_key, input_glob.coords_key,
                       input.coords_man)))
    self.assertTrue(
        gradcheck(fn, (input.F, input_glob.F, OperationType.MULTIPLICATION,
                       input.coords_key, input_glob.coords_key,
                       input.coords_man)))
def test_broadcast_gpu(self):
    in_channels, D = 2, 2
    coords, feats, labels = data_loader(in_channels)
    coords, feats_glob, labels = data_loader(in_channels)
    # Double precision keeps gradcheck's numerical Jacobian accurate.
    feats = feats.double()
    feats_glob = feats_glob.double()
    input = SparseTensor(feats, coords=coords)

    pool = MinkowskiGlobalPooling()
    input_glob = pool(input)
    input_glob.F.requires_grad_()
    broadcast_add = MinkowskiBroadcastAddition()
    broadcast_mul = MinkowskiBroadcastMultiplication()
    broadcast_cat = MinkowskiBroadcastConcatenation()

    # CPU reference outputs.
    cpu_add = broadcast_add(input, input_glob)
    cpu_mul = broadcast_mul(input, input_glob)
    cpu_cat = broadcast_cat(input, input_glob)

    fn = MinkowskiBroadcastFunction()

    # Run the same broadcasts on the GPU and compare against the CPU results.
    device = torch.device('cuda')
    input = input.to(device)
    input_glob = input_glob.to(device)
    gpu_add = broadcast_add(input, input_glob)
    gpu_mul = broadcast_mul(input, input_glob)
    gpu_cat = broadcast_cat(input, input_glob)

    self.assertTrue(
        torch.prod((gpu_add.F.cpu() - cpu_add.F).abs() < 1e-5).item() == 1)
    self.assertTrue(
        torch.prod((gpu_mul.F.cpu() - cpu_mul.F).abs() < 1e-5).item() == 1)
    self.assertTrue(
        torch.prod((gpu_cat.F.cpu() - cpu_cat.F).abs() < 1e-5).item() == 1)

    # Check backward
    self.assertTrue(
        gradcheck(fn, (input.F, input_glob.F, OperationType.ADDITION,
                       input.coords_key, input_glob.coords_key,
                       input.coords_man)))
    self.assertTrue(
        gradcheck(fn, (input.F, input_glob.F, OperationType.MULTIPLICATION,
                       input.coords_key, input_glob.coords_key,
                       input.coords_man)))
def sparse_global_sum(self, x, y):
    # Broadcast-add the globally pooled tensor y onto every coordinate of x.
    return MinkowskiBroadcastAddition(dimension=self.D)(x, y)
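# Hedged usage sketch for sparse_global_sum above; it mirrors the test setup in
# this section. data_loader and SparseTensor come from the surrounding code, and
# self.D is assumed to already hold the coordinate dimension of the module.
def sparse_global_sum_example(self):
    in_channels = 2
    coords, feats, labels = data_loader(in_channels)
    x = SparseTensor(feats, coords=coords)
    # Pool each batch instance down to a single global feature vector ...
    y = MinkowskiGlobalPooling(dimension=self.D)(x)
    # ... then broadcast-add that vector back onto every coordinate of x.
    return self.sparse_global_sum(x, y)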