def forward(self, x):
    """Run the 3-D convolution with multiplicative noise on the activations.

    When ``self.deterministic`` is set, the layer convolves ``x`` with the
    stored posterior mean weights (``self.post_weight_mu``) and returns
    immediately. Otherwise the mean and variance of the activations are
    computed with two convolutions and combined with the scale variable
    ``z`` through ``reparametrize``.
    """
    if self.deterministic:
        assert self.training == False, "Flag deterministic is True. This should not be used in training."
        return F.conv3d(x, self.post_weight_mu, self.bias_mu)

    n_batch = x.size(0)
    conv_args = (self.stride, self.padding, self.dilation, self.groups)

    # Local reparametrisation trick, see [1] Eq. (6), applied to the
    # parametrisation given in [3] Eq. (6): sample activations, not weights.
    act_mean = F.conv3d(x, self.weight_mu, self.bias_mu, *conv_args)
    weight_var = self.weight_logvar.exp()
    act_var = F.conv3d(x.pow(2), weight_var, self.bias_logvar.exp(), *conv_args)

    # Scale variable z, reparametrised according to [2] Eq. (11) (not [1]).
    tiling = (n_batch, 1, 1, 1, 1)
    z = reparametrize(
        self.z_mu.repeat(*tiling),
        self.z_logvar.repeat(*tiling),
        sampling=self.training,
        cuda=self.cuda,
    )
    z = z[:, :, None, None, None]

    scaled_mean = act_mean * z
    scaled_logvar = (act_var * z.pow(2)).log()
    return reparametrize(scaled_mean, scaled_logvar, sampling=self.training, cuda=self.cuda)
def reverse(self, output):
    """Convolve ``output`` with the matrix inverse of the layer weight.

    The weight returned by ``self.calc_weight()`` is squeezed to a matrix,
    inverted, and given three trailing singleton axes again so that it can
    be used as a pointwise ``conv3d`` kernel.
    """
    inverted = self.calc_weight().squeeze().inverse()
    kernel = inverted[:, :, None, None, None]
    return F.conv3d(output, kernel)
def forward(self, inputs, output_per_pixel=False, output_before_combine_slices=False, train_last_layers_only=False):
    """Classify a stack of slices, combining per-slice features with a 3-D conv.

    Args:
        inputs: slice stack; dimension 0 is the batch and dimension 1 the
            slices. With ``train_last_layers_only`` the tensor is used as
            already-extracted features (permuted from BxSxCxHxW below).
        output_per_pixel: also return the classifier applied at every
            spatial location; the result is then ``[per_pixel, pooled]``.
        output_before_combine_slices: return the BxSxCxHxW backbone features.
        train_last_layers_only: skip the backbone.

    Returns:
        The classifier output, or a two-element list when
        ``output_per_pixel`` is set, or the backbone features.
    """
    n_batch = inputs.shape[0]
    n_slices = inputs.shape[1]

    feats = inputs
    if not train_last_layers_only:
        feats = feats.view(n_batch * n_slices, 1, inputs.shape[2], inputs.shape[3])
        feats = self.l1(feats)
        feats = self.bn1(feats)
        # TODO: batch norm here may still help when cdf used
        feats = torch.relu(feats)
        feats = self.maxpool(feats)

        feats = self.base_model.layer1(feats)
        feats = self.base_model.layer2(feats)
        feats = self.base_model.layer3(feats)
        feats = self.base_model.layer4(feats)

        n_features = feats.shape[1]
        # BxSxCxHxW
        feats = feats.view(n_batch, n_slices, n_features, feats.shape[2], feats.shape[3])
        # NOTE(review): the source formatting was collapsed; this early return
        # is assumed to belong to the backbone branch - confirm.
        if output_before_combine_slices:
            return feats

    feats = feats.permute((0, 2, 1, 3, 4))  # BxCxSxHxW

    # Use the central part of the slice-combining kernel when fewer slices
    # than self.nb_input_slices are supplied.
    first = (self.nb_input_slices - n_slices) // 2
    combine_weight = self.combine_conv.weight[:, :, first:first + n_slices, :, :]
    feats = F.conv3d(feats, combine_weight, self.combine_conv.bias)
    feats = feats.view(n_batch, self.combine_conv_features, feats.shape[3], feats.shape[4])

    per_pixel = None
    if output_per_pixel:
        doubled = torch.cat([feats, feats], dim=1)
        per_pixel = F.conv2d(doubled, self.fc.weight[:, :, None, None], self.fc.bias)

    spatial = feats.shape[2:]
    pooled = torch.cat((F.avg_pool2d(feats, spatial), F.max_pool2d(feats, spatial)), 1)
    flat = pooled.view(pooled.size(0), -1)

    if self.dropout > 0:
        flat = F.dropout(flat, self.dropout, self.training)
    out = self.fc(flat)

    if per_pixel is None:
        return out
    return [per_pixel, out]
def decode(bin_dir, rec_dir, model_dir, block_width, block_height):
    """Decode the block-coded bitstream at ``bin_dir`` and save the image.

    The stream starts with a ``'2HB'`` header (image height, image width,
    model index). Every block then carries a ``'2H4h2I'`` header followed by
    its main and hyper arithmetic-coded payloads. Each block is decoded
    (hyper latents first, then the main latents one symbol at a time with
    the context model) and pasted into the output image.

    Args:
        bin_dir: path of the input bitstream file.
        rec_dir: path the reconstructed image is saved to.
        model_dir: directory holding ``<name>.pkl`` and ``<name>p.pkl``.
        block_width: nominal block width, used to count blocks per row.
        block_height: nominal block height, used to count block rows.

    Side effects:
        Writes the scratch files ``main.bin`` and ``hyper.bin`` in the
        current working directory and prints per-channel timing.
    """
    C = 3
    c_main = 192
    c_hyper = 128
    precise, tile = 16, 64.

    # The context manager guarantees the bitstream is closed even when
    # decoding fails part-way through.
    with open(bin_dir, 'rb') as file_object:
        ############### retrieve head info ###############
        head_len = struct.calcsize('2HB')
        bits = file_object.read(head_len)
        [H, W, model_index] = struct.unpack('2HB', bits)
        # print("File Info:",Head)
        # Split Main & Hyper bins
        out_img = np.zeros([H, W, C])
        H_offset = 0
        W_offset = 0
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))

        # NOTE(review): c_main / c_hyper stay at 192 / 128 even when the
        # larger model (M, N2 = 256, 192) is selected - confirm intended.
        M, N2 = 192, 128
        if model_index in (6, 7, 14, 15):
            M, N2 = 256, 192
        image_comp = model.Image_coding(3, M, N2, M, M // 2)
        context = Weighted_Gaussian(M)

        ######################### Load Model #########################
        image_comp.load_state_dict(
            torch.load(os.path.join(model_dir, models[model_index] + r'.pkl'),
                       map_location='cpu'))
        context.load_state_dict(
            torch.load(os.path.join(model_dir, models[model_index] + r'p.pkl'),
                       map_location='cpu'))
        if GPU:
            image_comp.cuda()
            context.cuda()

        Block_head_len = struct.calcsize('2H4h2I')
        for _block_row in range(Block_Num_in_Height):
            for _block_col in range(Block_Num_in_Width):
                bits = file_object.read(Block_head_len)
                [
                    block_H, block_W, Min_Main, Max_Main, Min_V_HYPER,
                    Max_V_HYPER, FileSizeMain, FileSizeHyper
                ] = struct.unpack('2H4h2I', bits)

                block_H_PAD = int(tile * np.ceil(block_H / tile))
                block_W_PAD = int(tile * np.ceil(block_W / tile))

                # The arithmetic decoder reads from files, so spill the two
                # payloads of this block to scratch files.
                with open("main.bin", 'wb') as f:
                    bits = file_object.read(FileSizeMain)
                    f.write(bits)
                with open("hyper.bin", 'wb') as f:
                    bits = file_object.read(FileSizeHyper)
                    f.write(bits)

                ############### Hyper Decoder ###############
                # [Min_V - 0.5 , Max_V + 0.5]
                sample = np.arange(Min_V_HYPER, Max_V_HYPER + 1 + 1)
                sample = np.tile(sample, [c_hyper, 1, 1])
                lower = torch.sigmoid(
                    image_comp.factorized_entropy_func._logits_cumulative(
                        torch.FloatTensor(sample) - 0.5, stop_gradient=False))
                cdf_h = lower.data.cpu().numpy() * (
                    (1 << precise) - (Max_V_HYPER - Min_V_HYPER + 1))  # [N1, 1, Max - Min]
                # np.int was removed in NumPy 1.24; the builtin int is the
                # type it used to alias.
                cdf_h = cdf_h.astype(int) + sample.astype(int) - Min_V_HYPER

                AE.init_decoder("hyper.bin", Min_V_HYPER, Max_V_HYPER)
                Recons = []
                for i in range(c_hyper):
                    for j in range(int(block_H_PAD * block_W_PAD / 64 / 64)):
                        # print(cdf_h[i,0,:])
                        Recons.append(AE.decode_cdf(cdf_h[i, 0, :].tolist()))
                # reshape Recons to y_hyper_q [1, c_hyper, H_PAD/64, W_PAD/64]
                y_hyper_q = torch.reshape(
                    torch.Tensor(Recons),
                    [1, c_hyper, int(block_H_PAD / 64), int(block_W_PAD / 64)])

                ############### Main Decoder ###############
                hyper_dec = image_comp.p(image_comp.hyper_dec(y_hyper_q))
                h, w = int(block_H_PAD / 16), int(block_W_PAD / 16)
                sample = np.arange(Min_Main, Max_Main + 1 + 1)  # [Min_V - 0.5 , Max_V + 0.5]
                sample = torch.FloatTensor(sample)
                # Latents carry a border of 5 on every side for the context conv.
                y_main_q = torch.zeros(1, 1, c_main + 10, h + 10, w + 10)  # 8000x4000 -> 500*250
                AE.init_decoder("main.bin", Min_Main, Max_Main)
                hyper = torch.unsqueeze(context.conv3(hyper_dec), dim=1)
                # context.conv1.weight.data *= context.conv1.mask
                for i in range(c_main):
                    T = time.time()
                    for j in range(int(block_H_PAD / 16)):
                        for k in range(int(block_W_PAD / 16)):
                            x1 = F.conv3d(y_main_q[:, :, i:i + 12, j:j + 12, k:k + 12],
                                          weight=context.conv1.weight,
                                          bias=context.conv1.bias)  # [1,24,1,1,1]
                            params_prob = context.conv2(
                                torch.cat(
                                    (x1, hyper[:, :, i:i + 2, j:j + 2, k:k + 2]),
                                    dim=1))
                            # 3 gaussian
                            prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = params_prob[
                                0, :, 0, 0, 0]
                            # keep the weight summation of prob == 1
                            probs = torch.stack([prob0, prob1, prob2], dim=-1)
                            probs = F.softmax(probs, dim=-1)
                            # process the scale value to positive non-zero
                            scale0 = torch.abs(scale0)
                            scale1 = torch.abs(scale1)
                            scale2 = torch.abs(scale2)
                            scale0[scale0 < 1e-6] = 1e-6
                            scale1[scale1 < 1e-6] = 1e-6
                            scale2[scale2 < 1e-6] = 1e-6
                            # 3 gaussian distributions
                            m0 = torch.distributions.normal.Normal(
                                mean0.view(1, 1).repeat(1, Max_Main - Min_Main + 2),
                                scale0.view(1, 1).repeat(1, Max_Main - Min_Main + 2))
                            m1 = torch.distributions.normal.Normal(
                                mean1.view(1, 1).repeat(1, Max_Main - Min_Main + 2),
                                scale1.view(1, 1).repeat(1, Max_Main - Min_Main + 2))
                            m2 = torch.distributions.normal.Normal(
                                mean2.view(1, 1).repeat(1, Max_Main - Min_Main + 2),
                                scale2.view(1, 1).repeat(1, Max_Main - Min_Main + 2))
                            lower0 = m0.cdf(sample - 0.5)
                            lower1 = m1.cdf(sample - 0.5)
                            lower2 = m2.cdf(sample - 0.5)  # [1,c,h,w,Max-Min+2]
                            lower = probs[0:1] * lower0 + probs[
                                1:2] * lower1 + probs[2:3] * lower2
                            cdf_m = lower.data.cpu().numpy() * (
                                (1 << precise) - (Max_Main - Min_Main + 1)
                            )  # [1, c, h, w ,Max-Min+1]
                            cdf_m = cdf_m.astype(int) + \
                                sample.numpy().astype(int) - Min_Main
                            pixs = AE.decode_cdf(cdf_m[0, :].tolist())
                            y_main_q[0, 0, i + 5, j + 5, k + 5] = pixs
                    print("Decoding Channel (%d/192), Time (s): %0.4f" %
                          (i, time.time() - T))
                del hyper, hyper_dec

                # Strip the context border and run the synthesis network.
                y_main_q = y_main_q[0, :, 5:-5, 5:-5, 5:-5]
                rec = image_comp.decoder(y_main_q)
                output_ = torch.clamp(rec, min=0., max=1.0)
                out = output_.data[0].cpu().numpy()
                out = out.transpose(1, 2, 0)
                out_img[H_offset:H_offset + block_H,
                        W_offset:W_offset + block_W, :] = out[:block_H, :block_W, :]
                # Advance the paste position in raster order.
                W_offset += block_W
                if W_offset >= W:
                    W_offset = 0
                    H_offset += block_H

    out_img = np.round(out_img * 255.0)
    out_img = out_img.astype('uint8')
    img = Image.fromarray(out_img[:H, :W, :])
    img.save(rec_dir)
def lanczos3dfilter(f):
    """Convolve ``f`` with the module-level ``lanczosk3d`` kernel (padding 3)."""
    filtered = conv3d(f, lanczosk3d, padding=3)
    return filtered
def forward(self, x):
    """Convolve ``x`` with a standardised copy of ``self.weight``.

    The weight is shifted by its mean and divided by its (biased) standard
    deviation, both taken over dimensions 1, 2 and 3, before the
    convolution. ``self.weight`` itself is left untouched.
    """
    raw = self.weight
    # NOTE(review): a conv3d weight has five dimensions, so dim=[1, 2, 3]
    # leaves the last kernel axis out of the statistics - confirm intended.
    variance, mean = torch.var_mean(raw, dim=[1, 2, 3], keepdim=True, unbiased=False)
    standardised = (raw - mean) / torch.sqrt(variance + 1e-5)
    return F.conv3d(
        x,
        standardised,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
    )
def backward(ctx, grad_output):
    """Custom backward pass that redistributes ``grad_output`` onto the input.

    The input and the weights are split into positive and negative parts;
    ``grad_output`` is normalised by the forward response of the matching
    sign pairs and propagated back with ``conv3d_input``. When ``beta`` is
    truthy, the mixed-sign pairs are processed the same way and combined as
    ``alpha * grad - beta * c_grad``.

    Returns:
        A 10-tuple: the input gradient, then the weight and bias gradients
        (or ``None``), padded with ``None``.

    NOTE(review): ``epsilon`` is not defined in this block; presumably a
    module-level constant - confirm.
    """
    input, weights, bias = ctx.saved_tensors
    stride, padding, dilation, groups, alpha, beta, downsample = ctx.hparam

    # Positive / negative parts of the input and of the weights.
    pinput = torch.clamp(input, min=0)
    linput = torch.clamp(input, max=0)
    pweights = nweights = weights
    pweights = torch.clamp(weights, min=0)
    nweights = torch.clamp(weights, max=0)

    # Forward response of the same-sign pairs (pos*pos and neg*neg).
    pout = F.conv3d(pinput, pweights, None, stride, padding, dilation, groups)
    nout = F.conv3d(linput, nweights, None, stride, padding, dilation, groups)
    sum_out = pout + nout
    # Stabilise exact zeros before dividing.
    sum_out[sum_out == 0] += epsilon
    norm_grad = grad_output / sum_out
    if not epsilon:
        # With epsilon == 0 the zeros were not shifted; zero those entries.
        norm_grad[sum_out == 0] = 0
    if downsample:
        # Resize the normalised gradient to the input's spatial size.
        norm_grad = F.interpolate(norm_grad, size=input.shape[-3:], mode='trilinear')

    # NOTE(review): stride=1 here, while the beta branch below passes
    # stride=stride - confirm the asymmetry is intended.
    agrad = torch.nn.grad.conv3d_input(input.shape, pweights, norm_grad, stride=1, padding=padding, groups=groups)
    agrad *= pinput
    bgrad = torch.nn.grad.conv3d_input(input.shape, nweights, norm_grad, stride=1, padding=padding, groups=groups)
    bgrad *= linput
    grad = agrad + bgrad

    if beta:
        # Same procedure for the mixed-sign pairs (pos*neg and neg*pos).
        cpout = F.conv3d(pinput, nweights, None, stride, padding, dilation, groups)
        cnout = F.conv3d(linput, pweights, None, stride, padding, dilation, groups)
        csum_out = cpout + cnout
        csum_out[csum_out == 0] += epsilon
        c_grad = grad_output / csum_out
        if not epsilon:
            c_grad[csum_out == 0] = 0
        c_agrad = torch.nn.grad.conv3d_input(input.shape, nweights, c_grad, stride=stride, padding=padding, groups=groups)
        c_agrad *= pinput
        c_bgrad = torch.nn.grad.conv3d_input(input.shape, pweights, c_grad, stride=stride, padding=padding, groups=groups)
        c_bgrad *= linput
        c_grad = c_agrad + c_bgrad
        grad = alpha * grad - beta * c_grad

    if all(ctx.needs_input_grad):
        # NOTE(review): the weight/bias gradients are derived from ``grad``
        # (the input-sized result above) and stored on ``.grad`` as a side
        # effect - confirm this is intended.
        weights.grad = torch.nn.grad.conv3d_weight(input, weights.shape, grad, stride=stride, padding=padding)
        if bias is not None:
            bias.grad = torch.nn.grad.conv3d_weight(input, bias.shape, grad, stride=stride, padding=padding)
            return grad, weights.grad, bias.grad, None, None, None, None, None, None, None
        return grad, weights.grad, None, None, None, None, None, None, None, None
    return grad, None, None, None, None, None, None, None, None, None
def forward(self, input):
    """Convolve ``input`` with the weight produced by ``self.activate_weight``.

    ``activate_weight`` receives the raw weight together with the layer's
    thresholds, quantisation levels, stddev and the training flag.
    """
    effective_weight = self.activate_weight(
        self.weight,
        self.thresholds,
        self.quant_levels,
        self.stddev,
        self.training,
    )
    conv_args = (self.bias, self.stride, self.padding, self.dilation, self.groups)
    return F.conv3d(input, effective_weight, *conv_args)
def forward(self, input):
    """Filter ``input`` along its last axis and return it in its original shape.

    All axes between the first and the last are flattened into one, the
    result is padded with ``self.pad``, convolved with ``self.weight`` and
    scaled by ``self.Ts``.
    """
    original_shape = input.shape
    flattened = input.reshape((original_shape[0], 1, 1, -1, original_shape[-1]))
    padded = self.pad(flattened)
    filtered = F.conv3d(padded, self.weight) * self.Ts
    return filtered.reshape(original_shape)
def filter3D(input: torch.Tensor, kernel: torch.Tensor,
             border_type: str = 'replicate',
             normalized: bool = False) -> torch.Tensor:
    r"""Convolve a tensor with a 3d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input (torch.Tensor): the input tensor with shape of
          :math:`(B, C, D, H, W)`.
        kernel (torch.Tensor): the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kD, kH, kW)`  or :math:`(B, kD, kH, kW)`.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``,
          ``'replicate'`` or ``'circular'``. Default: ``'replicate'``.
        normalized (bool): If True, kernel will be L1 normalized.

    Return:
        torch.Tensor: the convolved tensor of same size and numbers of channels
        as the input with shape :math:`(B, C, D, H, W)`.

    Raises:
        TypeError: if ``border_type`` is not a string.
        ValueError: if ``input`` is not 5-dimensional or ``kernel`` is not
          4-dimensional.

    Example:
        >>> input = torch.tensor([[[
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 5., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]]
        ... ]]])
        >>> kernel = torch.ones(1, 3, 3, 3)
        >>> filter3D(input, kernel)
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]]]]])
    """
    testing.check_is_tensor(input)
    testing.check_is_tensor(kernel)

    if not isinstance(border_type, str):
        # Report the offending argument's type (not the kernel's).
        raise TypeError("Input border_type is not string. Got {}".format(
            type(border_type)))

    if len(input.shape) != 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))

    # Only the rank is checked: a leading size of 1 or B is both accepted,
    # but any non-4D kernel cannot be expanded below and is rejected here.
    if len(kernel.shape) != 4:
        raise ValueError(
            "Invalid kernel shape, we expect 1xDxHxW. Got: {}".format(
                kernel.shape))

    # prepare kernel
    b, c, d, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        bk, dk, hk, wk = kernel.shape
        tmp_kernel = normalize_kernel2d(tmp_kernel.view(
            bk, dk, hk * wk)).view_as(tmp_kernel)

    # one copy of the kernel per input channel
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)

    # pad the input tensor
    depth, height, width = tmp_kernel.shape[-3:]
    padding_shape: List[int] = compute_padding([depth, height, width])
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
    input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3),
                               input_pad.size(-2), input_pad.size(-1))

    # convolve the tensor with the kernel (one group per kernel copy).
    output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0),
                      padding=0, stride=1)

    return output.view(b, c, d, h, w)
def cubify(voxels, thresh, device=None) -> Meshes:
    r"""
    Converts a voxel to a mesh by replacing each occupied voxel with a cube
    consisting of 12 faces and 8 vertices. Shared vertices are merged, and
    internal faces are removed.

    Args:
      voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
      thresh: A scalar threshold. If a voxel occupancy is larger than
          thresh, the voxel is considered occupied.
      device: Device for the intermediate tensors; defaults to the device
          of ``voxels``.

    Returns:
      meshes: A Meshes object of the corresponding meshes. An empty
          ``voxels`` batch gives an empty Meshes; a batch with no occupied
          voxel gives N empty vertex/face tensors.
    """
    if device is None:
        device = voxels.device

    if len(voxels) == 0:
        return Meshes(verts=[], faces=[])

    N, D, H, W = voxels.size()
    # vertices corresponding to a unit cube: 8x3
    cube_verts = torch.tensor(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=torch.int64,
        device=device,
    )

    # faces corresponding to a unit cube: 12x3
    cube_faces = torch.tensor(
        [
            [0, 1, 2],
            [1, 3, 2],  # left face: 0, 1
            [2, 3, 6],
            [3, 7, 6],  # bottom face: 2, 3
            [0, 2, 6],
            [0, 6, 4],  # front face: 4, 5
            [0, 5, 1],
            [0, 4, 5],  # up face: 6, 7
            [6, 7, 5],
            [6, 5, 4],  # right face: 8, 9
            [1, 7, 3],
            [1, 5, 7],  # back face: 10, 11
        ],
        dtype=torch.int64,
        device=device,
    )

    # Two-tap averaging kernels, one per axis (last, middle, first spatial axis).
    wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
    wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
    wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)

    # Binary occupancy: 1.0 where voxels >= thresh.
    voxelt = voxels.ge(thresh).float()
    # N x 1 x D x H x W
    voxelt = voxelt.view(N, 1, D, H, W)

    # 1.0 where both neighbours along the axis are occupied (average > 0.5);
    # each result is one element shorter along its own axis.
    voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
    voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
    voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()

    # 12 x N x 1 x D x H x W
    # A face starts as "kept" (1) and is cleared where the neighbouring
    # voxel on that side is occupied too.
    faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)

    # add left face
    faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
    faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
    # add bottom face
    faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
    faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
    # add front face
    faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
    faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
    # add up face
    faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
    faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
    # add right face
    faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
    faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
    # add back face
    faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
    faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z

    # Faces only exist on occupied voxels.
    faces_idx *= voxelt

    # N x H x W x D x 12
    faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
    # (NHWD) x 12
    faces_idx = faces_idx.contiguous()
    faces_idx = faces_idx.view(-1, cube_faces.size(0))

    # boolean to linear index
    # NF x 2
    linind = torch.nonzero(faces_idx)
    # NF x 4
    nyxz = unravel_index(linind[:, 0], (N, H, W, D))

    # NF x 3: faces
    faces = torch.index_select(cube_faces, 0, linind[:, 1])

    grid_faces = []
    for d in range(cube_faces.size(1)):
        # NF x 3
        xyz = torch.index_select(cube_verts, 0, faces[:, d])
        # Reorder the cube-vertex columns to (y, x, z) to match nyxz.
        permute_idx = torch.tensor([1, 0, 2], device=device)
        yxz = torch.index_select(xyz, 1, permute_idx)
        yxz += nyxz[:, 1:]
        # NF x 1
        temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
        grid_faces.append(temp)
    # NF x 3
    grid_faces = torch.stack(grid_faces, dim=1)

    # Coordinates of every lattice corner.
    y, x, z = torch.meshgrid(
        torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
    )
    y = y.to(device=device, dtype=torch.float32)
    y = y * 2.0 / (H - 1.0) - 1.0
    x = x.to(device=device, dtype=torch.float32)
    x = x * 2.0 / (W - 1.0) - 1.0
    z = z.to(device=device, dtype=torch.float32)
    z = z * 2.0 / (D - 1.0) - 1.0
    # ((H+1)(W+1)(D+1)) x 3
    grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)

    if len(nyxz) == 0:
        # No face survived: return N empty meshes.
        verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
        faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
        return Meshes(verts=verts_list, faces=faces_list)

    num_verts = grid_verts.size(0)
    # Temporarily offset vertex ids by batch index so one scatter marks the
    # used vertices of all N grids, then undo the offset.
    grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
    idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)

    idleverts.scatter_(0, grid_faces.flatten(), 0)
    grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
    # Number of faces per batch element, used to split grid_faces.
    split_size = torch.bincount(nyxz[:, 0], minlength=N)
    faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))

    idleverts = idleverts.view(N, num_verts)
    # Running count of unused vertices: subtracting it re-indexes faces
    # after the unused vertices are dropped.
    idlenum = idleverts.cumsum(1)

    verts_list = [
        grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
        for n in range(N)
    ]
    faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]

    return Meshes(verts=verts_list, faces=faces_list)
def train(train_loader, encoder, decoder, refiner, optimizer, args):
    """Run one pass over ``train_loader``, optimising encoder/decoder/refiner.

    For every batch the image is encoded, decoded and refined into
    ``result``; the objective sums an SDF term, a Laplacian smoothness term
    and an image MSE term, and one optimiser step is taken. Every 100th
    step (``j % 100 == 1``) the result and the target shape are rendered
    from random cameras and a preview grid is saved under
    ``../result/<args.category>/``.

    NOTE(review): the source formatting was collapsed; the nesting below is
    the reading that is syntactically forced, with the remaining statements
    placed at the training-step level - confirm against the original.
    """
    global i
    global num_rec
    global overall_loss
    global latent_loss
    global rec_loss
    global factor
    global avg
    global num_print

    j=0  # step counter across batches
    loss_curr = 0
    latent_loss_curr = 0
    rec_loss_curr = 0
    avg_curr = 0
    num_img = 20  # number of samples rendered per step

    for image, shape in train_loader:
        # Stop at the first batch that is smaller than num_img.
        if image.shape[0] < num_img:
            break

        for iteration in range(1):
            j += 1
            shape = shape.cuda()
            image = image.cuda()
            optimizer.zero_grad()

            latent = encoder(image)
            result = decoder(latent)
            result = refiner(result)

            img_res = 256
            # First num_img entries: renders of the prediction; last num_img:
            # renders of the target shape.
            # NOTE(review): these buffers are allocated uninitialised and are
            # only filled when j % 100 == 1, yet the losses below read
            # ``images`` on every step - confirm.
            images = torch.cuda.FloatTensor(num_img * 2, 1, img_res, img_res) ####
            show_images = torch.cuda.FloatTensor(num_img, 1, img_res, img_res)
            loss = 0
            rand = torch.randint(0, 24, (num_img,))  # one camera index per sample

            if j % 100 == 1:
                for i in range(num_img):
                    cam = rand[i]
                    images[i,0,:,:], _ = differentiable_rendering(result[i,0,:,:,:], result.shape[-1], img_res, camera_list[cam])
                    images[i+num_img,0,:,:], _ = differentiable_rendering(shape[i,0,:,:,:], shape.shape[-1], img_res, camera_list[cam])

                    if j % 100 == 1:
                        # Keep every second sample for the preview grid.
                        if i % 2 == 0:
                            show_images[int(i/2),0,:,:] = images[i,0,:,:]
                            show_images[int(i/2) + int(num_img / 2),0,:,:] = images[i+num_img,0,:,:]

            if j % 100 == 1:
                grid = make_grid(show_images, nrow=int(num_img/2))
                torchvision.utils.save_image(grid, "../result/" + args.category + "/train_" + str(num_rec) + "_" + str(j) + ".png", nrow=6, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

            obj_loss = 0

            # narrow band: only values of the first sample with |value| < 0.1
            mask = torch.abs(result[0,0]) < 0.1
            mask = mask.float()

            # sdf loss
            image_loss, sdf_loss = loss_fn(images[:num_img][0,0], images[num_img:][0,0], result[0,0] * mask, 4/64., 64, 64, 64, 256, 256)
            obj_loss += sdf_loss / (64**3) * 0.02

            # laplacian loss: squared response of a 6-neighbour Laplace stencil
            conv_input = (result[0,0] * mask).unsqueeze(0).unsqueeze(0)
            conv_filter = torch.cuda.FloatTensor([[[[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, -6, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]])
            Lp_loss = torch.sum(F.conv3d(conv_input, conv_filter) ** 2) / (64**3)
            obj_loss += Lp_loss * 0.02

            # image loss
            obj_loss += F.mse_loss(images[:num_img], images[num_img:]) * 15 * (256 * 256 / img_res / img_res)

            # back propagate
            loss = obj_loss
            loss.backward()
            optimizer.step()
# Extrapolate one extra slice past the end of the first axis so that the
# difference in the last slice has a neighbour to use
pad_slice = 2 * pot[-1] - pot[-2]
pot = np.concatenate((pot, pad_slice[np.newaxis]), axis=0)

# Surround the volume with zeros on every other boundary
pot = np.pad(pot, pad_width=((1, 0), (1, 1), (1, 1)), mode='constant')

# Conductive phase is wherever the potential is positive (1 = conductive)
mask = (pot > 0).astype('float32')

# Kernel whose only non-zero tap sits one step along the first axis
k_s = np.zeros((3, 3, 3), dtype='float32')
k_s[2, 1, 1] = 1

# Mask of the neighbouring voxel in +z
mask_s = conv3d(
    torch.as_tensor(mask[np.newaxis, np.newaxis]),
    torch.as_tensor(k_s[np.newaxis, np.newaxis]),
).numpy().squeeze()

# Forward-difference kernel: centre minus the next voxel in +z
k_fd = np.zeros_like(k_s)
k_fd[1, 1, 1] = 1
k_fd[2, 1, 1] = -1

# Potential difference between each voxel and its +z neighbour
dpot = conv3d(
    torch.as_tensor(pot[np.newaxis, np.newaxis]),
    torch.as_tensor(k_fd[np.newaxis, np.newaxis]),
).numpy().squeeze()

# Current in z: only where both the voxel and its +z neighbour conduct
mask = mask[1:-1, 1:-1, 1:-1]  # strip the one-voxel border
currz_conv = dpot * mask * mask_s

#%% Calculating current using nested loops
def computemotionenergy(img, spatempfilter, stride):
    '''
    Calculate the motion energy of an image sequence with a set of
    spatiotemporal gabor filters.

    <img>: input image, numpy array or tensor of shape
        (batch_size, c, T, H, W). Note that the input image is 'uint8' dtype
        with range (0, 256); it is divided by 256 before filtering.
    <spatempfilter>: the object returned by the
        makemultiscalespatiotemporalfilters function; only its 'gabor' entry
        (one (kH, kW, kT, n_filters) array per scale) is used here.
    <stride>: int or sequence of ints, stride in pixels along the spatial
        domain. A single int is applied to every scale.

    Returns a list with one entry per scale, each of shape
    (batch_size, n_filters / 2, Tout, Hout * Wout). Entries are tensors when
    <img> was a tensor and numpy arrays otherwise.

    Note:
        1. We typically set the stride as 1 in the temporal domain
    '''
    import torch
    from torch.nn.functional import conv3d
    from torch import Tensor, from_numpy  # conv3d needs tensors

    # convert img to a float tensor in (0~1)
    is_tensor = isinstance(img, Tensor)
    if not is_tensor:
        img = from_numpy(img).type(torch.float)
    img = img / 256

    imgSize = img.shape[3]  # assume square image

    gbr = spatempfilter['gabor']
    if isinstance(stride, int):
        # One stride for all scales; previously only the first scale worked.
        stride = [stride] * len(gbr)

    filteredimg = []
    for igbr, vgbr in enumerate(gbr):  # loop each spatial and temporal scale
        # reformat the filter: (kH, kW, kT, out) -> (out, 1, kT, kH, kW),
        # the weight layout conv3d expects with a single input channel
        gbrtmp = vgbr.transpose([3, 2, 0, 1])
        gbrtmp = gbrtmp[:, None, :, :, :]
        gbrtmp = from_numpy(gbrtmp).type(torch.float)

        ksize = gbrtmp.shape[3]  # assume square kernel
        tsize = gbrtmp.shape[2]
        nKernel = gbrtmp.shape[0]

        # calculate stride and padding
        stride_tmp = int(round(stride[igbr]))
        spadding = int(ksize / 2 - imgSize % stride_tmp / 2)
        tpadding = int((tsize - 1) / 2)

        # convolve; a is [batch_size, Cout, Tout, Hout, Wout]
        a = conv3d(img,
                   gbrtmp,
                   stride=(1, stride_tmp, stride_tmp),
                   padding=(tpadding, spadding, spadding))

        # assume 0:Cout/2 and Cout/2: are the two phases; combine each
        # quadrature pair of simple cells into one complex cell
        half = int(nKernel / 2)
        a = torch.sqrt(a[:, :half, :, :, :]**2 + a[:, half:, :, :, :]**2)

        # flatten the spatial dimensions into one vector
        a = a.reshape(a.shape[0], a.shape[1], a.shape[2], -1)
        filteredimg.append(a if is_tensor else a.numpy())

    return filteredimg
def ssd(kpts_fixed,
        feat_fixed,
        feat_moving,
        orig_shape,
        disp_radius=16,
        disp_step=2,
        patch_radius=2,
        alpha=1.5,
        unroll_factor=50):
    """Patch-wise sum of squared differences around keypoints.

    For every keypoint a patch is sampled from ``feat_fixed`` and a larger
    displacement window from ``feat_moving``. The SSD between the patch and
    each displaced position is obtained as ``-2 * corr + |patch|^2 + |win|^2``
    using a grouped convolution and an average pooling.

    Args:
        kpts_fixed: keypoints, unpacked as (_, N, _) and reshaped to
            (1, -1, 1, 1, 3) for ``grid_sample``.
            # assumes normalised [-1, 1] coordinates - TODO confirm
        feat_fixed, feat_moving: feature volumes; dimension 1 is taken as
            the number of channels C.
        orig_shape: (D, H, W) used to scale voxel offsets to grid units.
        disp_radius, disp_step: radius and step of the displacement grid.
        patch_radius: radius of the patch (sampled with the same step).
        alpha: final scaling factor.
        unroll_factor: number of keypoint chunks processed one at a time.

    Returns:
        Tensor of shape (1, N, (2 * disp_radius + 1) ** 3).
    """
    _, N, _ = kpts_fixed.shape
    device = kpts_fixed.device
    D, H, W = orig_shape
    C = feat_fixed.shape[1]
    dtype = feat_fixed.dtype

    patch_step = disp_step  # same stride necessary for fast implementation
    # Patch offsets centred on zero, in voxels.
    patch = torch.stack(
        torch.meshgrid(
            torch.arange(0, 2 * patch_radius + 1, patch_step),
            torch.arange(0, 2 * patch_radius + 1, patch_step),
            torch.arange(0, 2 * patch_radius + 1, patch_step))).permute(
                1, 2, 3, 0).contiguous().view(1, 1, -1, 1,
                                              3).float() - patch_radius
    # Reverse the last axis and rescale by 2 / (size - 1) per axis.
    patch = (patch.flip(-1) * 2 /
             (torch.tensor([W, H, D]) - 1)).to(dtype).to(device)

    # Number of patch samples per axis, recovered from the flattened count.
    patch_width = round(patch.shape[2]**(1.0 / 3))

    # Extra samples needed on each side of the displacement window so the
    # unpadded convolution below yields disp_width outputs per axis.
    if patch_width % 2 == 0:
        pad = [(patch_width - 1) // 2, (patch_width - 1) // 2 + 1]
    else:
        pad = [(patch_width - 1) // 2, (patch_width - 1) // 2]

    # Displacement offsets, enlarged by the padding above.
    disp = torch.stack(
        torch.meshgrid(
            torch.arange(-disp_step * (disp_radius + ((pad[0] + pad[1]) / 2)),
                         (disp_step * (disp_radius +
                                       ((pad[0] + pad[1]) / 2))) + 1,
                         disp_step),
            torch.arange(-disp_step * (disp_radius + ((pad[0] + pad[1]) / 2)),
                         (disp_step * (disp_radius +
                                       ((pad[0] + pad[1]) / 2))) + 1,
                         disp_step),
            torch.arange(
                -disp_step * (disp_radius + ((pad[0] + pad[1]) / 2)),
                (disp_step * (disp_radius + ((pad[0] + pad[1]) / 2))) + 1,
                disp_step))).permute(1, 2, 3,
                                     0).contiguous().view(1, 1, -1, 1,
                                                          3).float()
    disp = (disp.flip(-1) * 2 /
            (torch.tensor([W, H, D]) - 1)).to(dtype).to(device)

    disp_width = disp_radius * 2 + 1
    ssd = torch.zeros(1, N, disp_width**3).to(dtype).to(device)
    # Process the keypoints in unroll_factor chunks.
    split = np.array_split(np.arange(N), unroll_factor)
    for i in range(unroll_factor):
        feat_fixed_patch = F.grid_sample(
            feat_fixed,
            kpts_fixed[:, split[i], :].view(1, -1, 1, 1, 3).to(dtype) + patch,
            padding_mode='border',
            align_corners=True)
        feat_moving_disp = F.grid_sample(
            feat_moving,
            kpts_fixed[:, split[i], :].view(1, -1, 1, 1, 3).to(dtype) + disp,
            padding_mode='border',
            align_corners=True)
        # Cross term: one group per (channel, keypoint) pair, with the fixed
        # patch acting as the kernel.
        corr = F.conv3d(
            feat_moving_disp.view(1, -1, disp_width + pad[0] + pad[1],
                                  disp_width + pad[0] + pad[1],
                                  disp_width + pad[0] + pad[1]),
            feat_fixed_patch.view(-1, 1, patch_width, patch_width,
                                  patch_width),
            groups=C * split[i].shape[0]).view(C, split[i].shape[0], -1)
        # Squared norm of each fixed patch.
        patch_sum = (feat_fixed_patch**2).squeeze(0).squeeze(3).sum(
            dim=2, keepdims=True)
        # Squared norm of each moving window: average pool times window size.
        disp_sum = (patch_width**3) * F.avg_pool3d(
            (feat_moving_disp**2).view(C, -1, disp_width + pad[0] + pad[1],
                                       disp_width + pad[0] + pad[1],
                                       disp_width + pad[0] + pad[1]),
            patch_width,
            stride=1).view(C, split[i].shape[0], -1)
        # Sum over channels and store the chunk.
        ssd[0, split[i], :] = ((-2 * corr + patch_sum + disp_sum)).sum(0)

    ssd *= (alpha / (patch_width**3))

    return ssd
def conv_soft_argmax3d(
        input: torch.Tensor,
        kernel_size: Tuple[int, int, int] = (3, 3, 3),
        stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (1, 1, 1),
        temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
        normalized_coordinates: bool = False,
        eps: float = 1e-8,
        output_value: bool = True,
        strict_maxima_bonus: float = 0.0
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Compute the convolutional spatial Soft-Argmax 3D over sliding windows
    of a heatmap.

    Two quantities are produced per window: the soft-argmax coordinates and
    the soft-max-pooled heatmap value. On each window :math:`X`:

    .. math::
             ijk(X) = \frac{\sum{(i,j,k)} * exp(x / T)  \in X} {\sum{exp(x / T)  \in X}}

    .. math::
             val(X) = \frac{\sum{x * exp(x / T)  \in X}} {\sum{exp(x / T)  \in X}}

    where T is temperature.

    Args:
        kernel_size (Tuple[int,int,int]): size of the window.
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding.
        temperature (torch.Tensor): factor to apply to input. Default is 1.
        normalized_coordinates (bool): return coordinates normalized to
            [-1, 1] instead of the range of the input shape. Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, the pooled values are returned next to
            the coordinates; if False only the coordinates.
        strict_maxima_bonus (float): windows whose centre is a strict maximum
            score ``(1 + strict_maxima_bonus) * value``, mimicking strict NMS
            in classic local features.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))

    if len(input.shape) != 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))

    if temperature <= 0:
        raise ValueError(
            "Temperature should be positive float or tensor. Got: {}".format(
                temperature))

    n_batch, n_chan, depth, height, width = input.shape
    k0, k1, k2 = kernel_size
    dev = input.device
    dt = input.dtype

    # Treat every channel as its own single-channel volume.
    flat = input.view(n_batch * n_chan, 1, depth, height, width)

    center_kernel = _get_center_kernel3d(k0, k1, k2, dev).to(dt)
    window_kernel = _get_window_grid_kernel3d(k0, k1, k2, dev).to(dt)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    peak = F.adaptive_max_pool3d(flat, (1, 1, 1))

    # the max is detached to prevent undesired backprop loops in the graph
    weights = ((flat - peak.detach()) / temperature).exp()

    window_size = float(k0 * k1 * k2)

    # softmax denominator: window sum of the weights (average * window size)
    denom = window_size * F.avg_pool3d(
        weights.view_as(flat), kernel_size, stride=stride,
        padding=padding) + eps

    # Global coordinates of every window centre.
    global_grid = create_meshgrid3d(
        depth, height, width, False, device=dev).to(dt).permute(0, 4, 1, 2, 3)
    centers = F.conv3d(global_grid, center_kernel,
                       stride=stride, padding=padding)

    # Soft-argmax offset relative to the window centre, then made global.
    coords = F.conv3d(weights, window_kernel,
                      stride=stride, padding=padding)
    coords = coords / denom.expand_as(coords)
    coords = coords + centers.expand_as(coords)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords = normalize_pixel_coordinates3d(
            coords.permute(0, 2, 3, 4, 1), depth, height, width)
        coords = coords.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords = coords.view(n_batch, n_chan, 3, coords.size(2),
                         coords.size(3), coords.size(4))

    if not output_value:
        return coords

    pooled = window_size * F.avg_pool3d(
        weights.view(flat.size()) * flat, kernel_size,
        stride=stride, padding=padding) / denom

    if strict_maxima_bonus > 0:
        levels_in = flat.size(2)
        levels_out = pooled.size(2)
        skipped = (levels_in - levels_out) // 2
        # Subsample the NMS mask with the same stride as the pooling.
        maxima = F.avg_pool3d(
            kornia.feature.nms3d(flat, kernel_size), 1, stride, 0)
        maxima = maxima[:, :, skipped:levels_out - skipped]
        pooled *= 1.0 + strict_maxima_bonus * maxima

    pooled = pooled.view(n_batch, n_chan, pooled.size(2),
                         pooled.size(3), pooled.size(4))
    return coords, pooled
def forward(self, input):
    """Filter a 5-D ``input`` along its last axis.

    Dimensions 1-3 are flattened into one, the tensor is padded with
    ``self.pad``, convolved with ``self.weight``, scaled by ``self.Ts`` and
    reshaped back to ``(N, -1, H, W, Ns)``.
    """
    batch, _, rows, cols, steps = input.shape
    flattened = input.reshape((batch, 1, 1, -1, steps))
    filtered = F.conv3d(self.pad(flattened), self.weight) * self.Ts
    return filtered.reshape((batch, -1, rows, cols, steps))
def inverse(self, x):
    """Apply the stride-2 grouped convolution, or its learnable counterpart.

    With ``self.learnable`` equal to ``True`` the call is delegated to
    ``invertible_downsampling.apply`` with ``self.kernel_matrix``; otherwise
    the fixed ``self.kernel`` is used with one group per 8 input channels.
    """
    use_learnable = self.learnable == True
    if use_learnable:
        return invertible_downsampling.apply(self.kernel_matrix, x)
    n_groups = self.input_channels // 8
    return F.conv3d(x, self.kernel, stride=2, groups=n_groups)
def forward(self, input):
    """Convolve ``input`` with the kernel returned by ``self.transform_weight()``.

    Bias, stride, padding and dilation are taken from the instance; the
    group count is left at the ``F.conv3d`` default.
    """
    kernel = self.transform_weight()
    return F.conv3d(
        input,
        kernel,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
    )
def forward(self, x):
    """Convolve ``x`` with ``self.filter``; no bias term is added."""
    filtered = F.conv3d(x, self.filter, bias=None)
    return filtered
def apply_srm_kernel(self, input_spikes, srm):
    """Convolve ``input_spikes`` with the ``srm`` kernel and rescale.

    Only the last axis is zero-padded, by half the kernel's extent along
    that axis (rounded down).  The convolution result is multiplied by
    ``self.net_params['t_s']``.
    """
    half_width = srm.shape[4] // 2
    response = F.conv3d(input_spikes, srm, padding=(0, 0, half_width))
    return response * self.net_params['t_s']
def avgfilter(f):
    """Filter ``f`` with the ``avgk`` kernel defined outside this function.

    The call goes through ``conv3d`` with a padding of 1.
    """
    filtered = conv3d(f, avgk, padding=1)
    return filtered
def apply_weights(activations, weights):
    """Return the 3-D convolution of ``activations`` with ``weights``.

    Default stride, padding, dilation and groups are used and no bias is
    added.
    """
    return F.conv3d(activations, weights)
def forward(self, x): """ Convolution forward function Divide the kernel into three parts on output channels based on acs_kernel_split, and conduct convolution on three directions seperately. Bias is added at last. """ B, C_in, *input_shape = x.shape conv3D_output_shape = (self.conv3D_output_shape_f(0, input_shape), self.conv3D_output_shape_f(1, input_shape), self.conv3D_output_shape_f(2, input_shape)) weight_a = self.weight[0:self.acs_kernel_split[0]].unsqueeze(2) weight_c = self.weight[self.acs_kernel_split[0]:( self.acs_kernel_split[0] + self.acs_kernel_split[1])].unsqueeze(3) weight_s = self.weight[(self.acs_kernel_split[0] + self.acs_kernel_split[1]):].unsqueeze(4) f_out = [] if self.acs_kernel_split[0] > 0: a = F.conv3d( x if conv3D_output_shape[0] == input_shape[0] or 2 * conv3D_output_shape[0] == input_shape[0] else F.pad( x, (0, 0, 0, 0, self.padding[0], self.padding[0]), 'constant', 0)[:, :, self.kernel_size[0] // 2:self.kernel_size[0] // 2 + (conv3D_output_shape[0] - 1) * self.stride[0] + 1, :, :], weight=weight_a, bias=None, stride=self.stride, padding=(0, self.padding[1], self.padding[2]), dilation=self.dilation, groups=self.groups) f_out.append(a) if self.acs_kernel_split[1] > 0: c = F.conv3d( x if conv3D_output_shape[1] == input_shape[1] or 2 * conv3D_output_shape[1] == input_shape[1] else F.pad( x, (0, 0, self.padding[1], self.padding[1]), 'constant', 0)[:, :, :, self.kernel_size[1] // 2:self.kernel_size[1] // 2 + self.stride[1] * (conv3D_output_shape[1] - 1) + 1, :], weight=weight_c, bias=None, stride=self.stride, padding=(self.padding[0], 0, self.padding[2]), dilation=self.dilation, groups=self.groups) f_out.append(c) if self.acs_kernel_split[2] > 0: s = F.conv3d( x if conv3D_output_shape[2] == input_shape[2] or 2 * conv3D_output_shape[2] == input_shape[2] else F.pad( x, (self.padding[2], self.padding[2]), 'constant', 0)[:, :, :, :, self.kernel_size[2] // 2:self.kernel_size[2] // 2 + self.stride[2] * (conv3D_output_shape[2] - 1) + 1], 
weight=weight_s, bias=None, stride=self.stride, padding=(self.padding[0], self.padding[1], 0), dilation=self.dilation, groups=self.groups) f_out.append(s) f = torch.cat(f_out, dim=1) if self.bias is not None: f += self.bias.view(1, self.out_channels, 1, 1, 1) return f
def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):
    """Merge eight sub-band tensors into a single output tensor.

    Every sub-band must be 5-D, all must share the same batch size, and
    all must have ``self.in_channels`` channels.  Each sub-band is
    upsampled with a transposed convolution using ``self.up_filter``,
    padded with ``self.pad_sizes`` / ``self.pad_type``, convolved with its
    own ``self.filter_*`` kernel, and the eight results are summed.
    """
    bands = (LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)
    assert all(len(band.size()) == 5 for band in bands)
    assert all(band.size()[0] == LLL.size()[0] for band in bands)
    assert all(band.size()[1] == self.in_channels for band in bands)

    upsampled = [
        F.pad(
            F.conv_transpose3d(band, self.up_filter, stride=self.stride,
                               groups=self.in_channels),
            pad=self.pad_sizes,
            mode=self.pad_type,
        )
        for band in bands
    ]
    kernels = (
        self.filter_lll, self.filter_llh, self.filter_lhl, self.filter_lhh,
        self.filter_hll, self.filter_hlh, self.filter_hhl, self.filter_hhh,
    )

    # Accumulate left to right, in the LLL ... HHH order.
    total = F.conv3d(upsampled[0], kernels[0], stride=1, groups=self.groups)
    for band, kernel in zip(upsampled[1:], kernels[1:]):
        total = total + F.conv3d(band, kernel, stride=1, groups=self.groups)
    return total
def forward(self, input):
    """Apply a 3-D convolution to ``input`` using the instance's settings.

    Weight, bias, stride, padding, dilation and group count are all read
    from ``self`` and handed to ``F.conv3d`` unchanged.
    """
    return F.conv3d(
        input,
        self.weight,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )
def forward(self, x):
    """Convolve ``x`` with the kernel produced by ``self.W_()``.

    The remaining convolution settings (bias, stride, padding, dilation,
    groups) are read from the instance.
    """
    kernel = self.W_()
    return F.conv3d(
        x,
        kernel,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )
def forward(self, x):
    """Convolve ``x`` with the kernel from ``self._get_augmented_weight()``.

    Bias, stride, padding, dilation and groups are read from the instance.
    """
    # note: won't work when self.padding_mode != 'zeros'
    augmented = self._get_augmented_weight()
    return F.conv3d(
        x,
        augmented,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )
def reconstruct(H, W, Z):
    """Convolve ``H`` with the spatially flipped, ``Z``-scaled kernel ``W``.

    ``W`` is flipped along its three spatial axes and multiplied by ``Z``
    reshaped to ``(-1, 1, 1, 1)``.  The input is padded by
    ``kernel extent - 1`` on every spatial axis, so each output axis is
    longer than the corresponding axis of ``H`` by that amount.
    """
    full_padding = tuple(W.shape[axis] - 1 for axis in (2, 3, 4))
    scaled_kernel = W.flip((2, 3, 4)) * Z.view(-1, 1, 1, 1)
    return F.conv3d(H, scaled_kernel, padding=full_padding)
def acs_conv_f(x, weight, bias, kernel_size, dilation, padding, stride,
               groups, out_channels, acs_kernel_split):
    """Functional three-direction (ACS) convolution.

    The output channels of ``weight`` are partitioned into three
    consecutive groups sized by ``acs_kernel_split``.  Each group's kernel
    gets a singleton axis inserted at tensor dim 2, 3 or 4 respectively and
    is convolved with the input; the three partial results are
    concatenated on the channel axis and ``bias`` is added last.

    Args:
        x: input tensor; its first two axes are batch and channels and the
            three following axes are the spatial axes.
        weight: kernel tensor whose first axis indexes output channels.
        bias: optional tensor with ``out_channels`` elements, or ``None``.
        kernel_size, dilation, padding, stride: per-axis settings, each
            indexable with 0, 1 and 2.
        groups: ``1`` for a standard convolution, or equal to both the
            input and output channel counts for a depth-wise one.
        out_channels: number of output channels, used to reshape ``bias``.
        acs_kernel_split: three channel counts, one per direction.

    Returns:
        The concatenated convolution outputs with the bias added.

    Raises:
        AssertionError: if ``groups`` is neither 1 nor equal to both the
            input and output channel counts.
    """
    _, C_in, *input_shape = x.shape
    C_out = weight.shape[0]
    assert groups == 1 or groups == C_in == C_out, "only support standard or depthwise conv"

    conv3D_output_shape = (
        conv3D_output_shape_f(0, input_shape, kernel_size, dilation, padding, stride),
        conv3D_output_shape_f(1, input_shape, kernel_size, dilation, padding, stride),
        conv3D_output_shape_f(2, input_shape, kernel_size, dilation, padding, stride),
    )

    split_a, split_c = acs_kernel_split[0], acs_kernel_split[1]
    weight_a = weight[0:split_a].unsqueeze(2)
    weight_c = weight[split_a:split_a + split_c].unsqueeze(3)
    weight_s = weight[split_a + split_c:].unsqueeze(4)

    if groups == C_in == C_out:
        # Depth-wise: every direction only sees its own slice of channels.
        x_a = x[:, 0:split_a]
        x_c = x[:, split_a:split_a + split_c]
        x_s = x[:, split_a + split_c:]
        group_a = acs_kernel_split[0]
        group_c = acs_kernel_split[1]
        group_s = acs_kernel_split[2]
    else:
        # Standard convolution: all directions share the full input.
        x_a = x_c = x_s = x
        group_a = group_c = group_s = 1

    f_out = []

    if acs_kernel_split[0] > 0:
        if conv3D_output_shape[0] == input_shape[0] or 2 * conv3D_output_shape[0] == input_shape[0]:
            inp_a = x_a
        else:
            # The kernel has extent 1 along this axis, so pad manually and
            # crop the window a full 3-D convolution would have covered.
            # The per-branch slice (not the full ``x``) must be padded,
            # otherwise the depth-wise case hands conv3d the wrong number
            # of channels for ``group_a``.
            start = kernel_size[0] // 2
            stop = start + (conv3D_output_shape[0] - 1) * stride[0] + 1
            inp_a = F.pad(x_a, (0, 0, 0, 0, padding[0], padding[0]),
                          'constant', 0)[:, :, start:stop, :, :]
        f_out.append(F.conv3d(inp_a, weight=weight_a, bias=None,
                              stride=stride,
                              padding=(0, padding[1], padding[2]),
                              dilation=dilation, groups=group_a))

    if acs_kernel_split[1] > 0:
        if conv3D_output_shape[1] == input_shape[1] or 2 * conv3D_output_shape[1] == input_shape[1]:
            inp_c = x_c
        else:
            start = kernel_size[1] // 2
            stop = start + stride[1] * (conv3D_output_shape[1] - 1) + 1
            inp_c = F.pad(x_c, (0, 0, padding[1], padding[1]),
                          'constant', 0)[:, :, :, start:stop, :]
        f_out.append(F.conv3d(inp_c, weight=weight_c, bias=None,
                              stride=stride,
                              padding=(padding[0], 0, padding[2]),
                              dilation=dilation, groups=group_c))

    if acs_kernel_split[2] > 0:
        if conv3D_output_shape[2] == input_shape[2] or 2 * conv3D_output_shape[2] == input_shape[2]:
            inp_s = x_s
        else:
            start = kernel_size[2] // 2
            stop = start + stride[2] * (conv3D_output_shape[2] - 1) + 1
            inp_s = F.pad(x_s, (padding[2], padding[2]),
                          'constant', 0)[:, :, :, :, start:stop]
        f_out.append(F.conv3d(inp_s, weight=weight_s, bias=None,
                              stride=stride,
                              padding=(padding[0], padding[1], 0),
                              dilation=dilation, groups=group_s))

    f = torch.cat(f_out, dim=1)
    if bias is not None:
        f += bias.view(1, out_channels, 1, 1, 1)
    return f
def conv3d(self, *args, **kargs):
    """Call ``F.conv3d`` with ``self`` as the input tensor.

    All further positional and keyword arguments are forwarded unchanged.
    """
    result = F.conv3d(self, *args, **kargs)
    return result