def test_fb_consistency_with_occlusion(self):
    batch_size = 4
    height = 64
    width = 64
    # Forward flow points right and up by 4 pixels.
    flow_01 = np.ones((batch_size, height, width, 2)) * 4.
    # Backward flow points left and down by only 2 pixels, so the
    # forward-backward check should flag every pixel as occluded.
    imperfect_flow_10 = -flow_01 * .5

    flow_01 = torch.tensor(flow_01.astype(np.float32))
    resize_transform = transforms.Compose(
        [transforms.Resize((int(height / 2), int(width / 2)))])
    # Resize expects channels-first, so move the flow channels temporarily.
    flow_01 = torch.moveaxis(flow_01, -1, 1)
    flow_01_level1 = resize_transform(flow_01) / 2.
    flow_01_level1 = torch.moveaxis(flow_01_level1, 1, -1)
    # Move the full-resolution flow back to channels-last so that all entries
    # in the flows dict share the same layout.
    flow_01 = torch.moveaxis(flow_01, 1, -1)

    imperfect_flow_10 = torch.tensor(imperfect_flow_10.astype(np.float32))
    imperfect_flow_10_level1 = -flow_01_level1 * .5

    flows = {}
    flows[(0, 1, 0)] = [flow_01, flow_01_level1]
    flows[(1, 0, 0)] = [imperfect_flow_10, imperfect_flow_10_level1]

    _, _, _, not_occluded_masks, _, _ = \
        uflow_utils.compute_warps_and_occlusion(
            flows, occlusion_estimation='brox')

    # Assert that everything (away from the image borders) is occluded.
    is_zeros_01 = np.equal(
        np.zeros((batch_size, height - 8, width - 8, 1)),
        not_occluded_masks[(0, 1, 0)][0][:, 4:-4, 4:-4, :]).all()
    is_zeros_10 = np.equal(
        np.zeros((batch_size, height - 8, width - 8, 1)),
        not_occluded_masks[(1, 0, 0)][0][:, 4:-4, 4:-4, :]).all()

    self.assertTrue(is_zeros_01)
    self.assertTrue(is_zeros_10)
def gather_nd(params, indices):
    params = torch.moveaxis(params, (0, 1, 2, 3), (0, 3, 1, 2))
    indices = torch.moveaxis(indices, (0, 1, 2, 3), (0, 3, 1, 2))
    indices = indices.type(torch.int64)
    gathered = params[list(indices.T)]
    gathered = torch.moveaxis(gathered, (0, 1, 2, 3), (3, 2, 0, 1))
    return gathered
def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(2, 4, 2)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    return (
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(y, i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
    )
def __call__(self, input_img, target_category=None):
    # Replace ReLU with GuidedBackpropReLU.
    self.recursive_replace_relu_with_guidedrelu(self.model)

    if self.cuda:
        input_img = input_img.cuda()
    input_img = input_img.requires_grad_(True)

    output = self.forward(input_img)

    if target_category is None:
        print('warning: no target_category given, using argmax of the model output')
        target_category = np.argmax(output.cpu().data.numpy())
    try:
        loss = output[0, target_category]
    except Exception:
        # Some models return an output object with a .logits field instead of a tensor.
        loss = output.logits[0, target_category]

    # loss.backward(retain_graph=True)
    # output = input_img.grad.cpu().data
    output = torch.autograd.grad(loss, input_img, create_graph=True)
    output = output[0][0, :, :, :]
    output = torch.moveaxis(output, 0, 2)

    # Replace GuidedBackpropReLU back with ReLU.
    self.recursive_replace_guidedrelu_with_relu(self.model)
    return output
def __call__(self, x, index=None):
    stdev = self.stdev_spread * (torch.max(x) - torch.min(x))
    total_gradients = torch.zeros_like(x)
    for i in range(self.n_samples):
        # Per-element Gaussian noise with std `stdev`, same shape as x.
        noise = torch.randn_like(x) * stdev
        x_plus_noise = x + noise
        output = self.model(x_plus_noise)

        if index is None:
            index = torch.argmax(output)
        one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32,
                              device=output.device)
        one_hot[0][index] = 1
        one_hot = torch.sum(one_hot * output)

        if x_plus_noise.grad is not None:
            x_plus_noise.grad.data.zero_()
        # one_hot.backward(retain_variables=True)
        # grad = x_plus_noise.grad
        grad = torch.autograd.grad(one_hot, x_plus_noise, create_graph=True)

        if self.magnitude:
            total_gradients += (grad[0] * grad[0])
        else:
            total_gradients += grad[0]
    # if self.visdom:

    avg_gradients = total_gradients[0, :, :, :] / self.n_samples
    avg_gradients = torch.moveaxis(avg_gradients, 0, 2)
    return avg_gradients
def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    # Collect the results in a tuple so len() receives a single argument.
    return len((
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x.index(t),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1,)),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    ))
def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor:
    """`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8"""
    if isinstance(x, torch.Tensor):
        if hasattr(torch, "moveaxis"):
            return torch.moveaxis(x, src, dst)
        return _moveaxis_with_permute(x, src, dst)  # type: ignore
    if isinstance(x, np.ndarray):
        return np.moveaxis(x, src, dst)
    raise RuntimeError(f"Unsupported input type: {type(x)}.")
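# Hedged usage sketch for the wrapper above (the helper name below is ours, not
# part of the original code): the wrapper should behave like np.moveaxis for
# both numpy arrays and torch tensors.
def _moveaxis_usage_example():
    np_img = np.zeros((64, 64, 3))
    pt_img = torch.zeros((64, 64, 3))
    # Moving the channel axis to the front gives a (3, 64, 64) result either way.
    assert moveaxis(np_img, -1, 0).shape == (3, 64, 64)
    assert tuple(moveaxis(pt_img, -1, 0).shape) == (3, 64, 64)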
def img_preprocessing(img):
    """
    Takes input image from coin segmentation output vector and preprocesses
    for CNN model.
    :param img: single 64x64x3 coin image
    :return: CNN ready input image
    """
    # Reverse the channel order (e.g. BGR -> RGB) and scale to [0, 1].
    img = img[..., ::-1] / 255
    input_img = torch.from_numpy(img.copy())
    # HWC -> CHW, then add a batch dimension: (1, 3, 64, 64).
    input_img = torch.moveaxis(input_img, -1, 0).unsqueeze(0).float()
    return input_img
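# Hedged usage sketch for img_preprocessing above: a dummy 64x64x3 uint8 image
# (standing in for a coin crop) becomes a (1, 3, 64, 64) float32 tensor.
def _img_preprocessing_example():
    dummy_coin = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
    batch = img_preprocessing(dummy_coin)
    assert batch.shape == (1, 3, 64, 64)
    assert batch.dtype == torch.float32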
def log_pre_update(self):
    """
    Initialize the info dictionary to be logged in wandb and collect base metrics.

    Returns info dictionary.
    """
    # Initialize and update the info dict for logging
    info = dict()
    info["ppo/advantage_mean"] = self.buf_advantages.mean()
    info["ppo/advantage_std"] = self.buf_advantages.std()
    info["ppo/return_mean"] = self.buf_returns.mean()
    info["ppo/return_std"] = self.buf_returns.std()
    info["ppo/value_est_mean"] = self.rollout.buf_vpreds.mean()
    info["ppo/value_est_std"] = self.rollout.buf_vpreds.std()
    info["ppo/explained_variance"] = explained_variance(
        self.rollout.buf_vpreds.flatten(),  # TODO: switch to ravel if pytorch >= 1.9
        self.buf_returns.flatten()  # TODO: switch to ravel if pytorch >= 1.9
    )
    info["ppo/reward_mean"] = torch.mean(self.rollout.buf_rewards)

    if self.rollout.best_ext_return is not None:
        info["performance/best_ext_return"] = self.rollout.best_ext_return

    # TODO: maybe add extra flag for detailed logging so runs are not slowed down
    if not self.debugging:
        feature_stats, stacked_act_feat = self.get_activation_stats(
            self.rollout.buf_acts_features, "activations_features/"
        )
        hidden_stats, stacked_act_pi = self.get_activation_stats(
            self.rollout.buf_acts_pi, "activations_hidden/"
        )
        info.update(feature_stats)
        info.update(hidden_stats)

        info["activations_features/raw_act_distribution"] = wandb.Histogram(
            to_numpy(stacked_act_feat)
        )
        info["activations_hidden/raw_act_distribution"] = wandb.Histogram(
            to_numpy(stacked_act_pi)
        )
        info["ppo/action_distribution"] = wandb.Histogram(
            to_numpy(self.rollout.buf_acs).flatten()
        )

    if self.vlog_freq >= 0 and self.n_updates % self.vlog_freq == 0:
        print(str(self.n_updates) + " updates - logging video.")
        # Reshape images such that they have shape [time, channels, width, height]
        sample_video = torch.moveaxis(self.rollout.buf_obs[0], 3, 1)
        # Log buffer video from first env
        info["observations"] = wandb.Video(
            to_numpy(sample_video), fps=12, format="gif"
        )

    return info
def show_test_result(result_path):
    masks = torch.load(result_path)
    print('masks:', masks)
    assert masks.shape == (24, 256, 256)
    # Check that the masks are strictly binary (only 0s and 1s).
    assert ((torch.where(masks == 1, 10, 0).sum()
             + torch.where(masks == 0, 10, 0).sum()).item()
            == 24 * 256 * 256 * 10)

    masks = torch.moveaxis(masks, 0, 2)
    # Pick a random mask index in [0, 23] and display it.
    rand_idx = np.random.randint(0, 24)
    rand_mask = masks[:, :, rand_idx].cpu().detach().numpy()
    plt.imshow(rand_mask)
    plt.title(str(rand_idx))
    plt.show()
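# Hedged usage sketch for show_test_result above: save a random binary mask
# tensor to a file and visualize one slice. The file name is ours, not part of
# the original code.
def _show_test_result_example():
    dummy_masks = torch.randint(0, 2, (24, 256, 256)).float()
    torch.save(dummy_masks, 'dummy_masks.pt')
    show_test_result('dummy_masks.pt')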
def __init__(self,
             quaternion: torch.Tensor,
             translation: torch.Tensor,
             rotation: torch.Tensor = None,
             normalize: bool = True,
             unstack_inputs: bool = False) -> None:
    super().__init__()
    if quaternion is not None:
        assert quaternion.shape[-1] == 4
        if normalize:
            quaternion = quaternion / torch.linalg.norm(quaternion, dim=-1, keepdim=True)

    if unstack_inputs:
        if rotation is not None:
            rotation = [torch.moveaxis(x, -1, 0)
                        for x in torch.moveaxis(rotation, -2, 0)]
        translation = torch.moveaxis(translation, -1, 0)

    if rotation is None:
        rotation = quat_to_rot(quaternion)

    self.quaternion = quaternion
    self.rotation = [list(row) for row in rotation]
    self.translation = list(translation)
    # print(self.quaternion.dtype, self.rotation[0][0].dtype, self.translation[0].dtype)

    assert all(len(row) == 3 for row in self.rotation)
    assert len(self.translation) == 3
def __call__(self, x, index=None):
    output = self.pretrained_model(x)

    if index is None:
        index = torch.argmax(output)

    # Build a one-hot vector that selects the target class score.
    one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32,
                          device=output.device)
    one_hot[0][index] = 1
    one_hot = torch.sum(one_hot * output)

    # Gradient of the class score with respect to the input image.
    grad = torch.autograd.grad(one_hot, x, create_graph=True)
    grad = grad[0][0, :, :, :]
    # CHW -> HWC for visualization.
    grad = torch.moveaxis(grad, 0, 2)
    return grad
def apply_transforms(self, image, labels):
    #inputs = np.asarray(image, dtype=np.float32)
    inputs = image
    inputs = torch.tensor(inputs, dtype=torch.float, requires_grad=False)
    labels = torch.tensor(labels, dtype=torch.long, requires_grad=False)

    """ Expected input is: (C x W x H x D) """
    inputs = inputs.unsqueeze(0)
    inputs = torch.moveaxis(inputs, 1, -1)
    labels = labels.unsqueeze(0)
    labels = torch.moveaxis(labels, 1, -1)

    subject_a = tio.Subject(
        one_image=tio.ScalarImage(tensor=inputs),   # *** must be tensors!!!
        a_segmentation=tio.LabelMap(tensor=labels))

    subjects_list = [subject_a]
    subjects_dataset = tio.SubjectsDataset(subjects_list, transform=self.transforms)
    subject_sample = subjects_dataset[0]

    X = subject_sample['one_image']['data'].numpy()
    Y = subject_sample['a_segmentation']['data'].numpy()

    """ Re-arrange channels for Pytorch into (D, H, W) """
    X = X[0]
    X = np.moveaxis(X, -1, 0)
    Y = Y[0]
    Y = np.moveaxis(Y, -1, 0)

    """ DEBUG """
    #plot_max(X)
    #plot_max(Y)

    return X, Y
def display_image(dataloader, batch_index=0, verbose=False):
    print('display an image')
    train_features, train_labels = next(iter(dataloader))
    if verbose:
        print(f"Feature batch shape: {train_features.size()}")
        print(f"Labels batch shape: {train_labels.size()}")

    img = train_features[batch_index].squeeze().cpu()
    # Flip the channel axis because OpenCV reads images as BGR while matplotlib expects RGB.
    img = torch.flip(img, dims=(0,))
    label = train_labels[batch_index]

    # CHW -> HWC for matplotlib.
    plot_img = torch.moveaxis(img, (0, 1, 2), (-1, 0, 1))
    if verbose:
        print('plot image shape', plot_img.shape)
    plt.imshow(plot_img)
    plt.show()
    if verbose:
        print(f"Label: {label}")
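# Hedged usage sketch for display_image above: a DataLoader over random
# RGB-like tensors stands in for the real dataset (all names below are ours).
def _display_image_example():
    from torch.utils.data import DataLoader, TensorDataset
    images = torch.rand(8, 3, 32, 32)  # values in [0, 1] so imshow is happy
    labels = torch.arange(8)
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)
    display_image(loader, batch_index=0, verbose=True)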
def bilinear_splatting(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor],
                       depth1: torch.Tensor, flow12: torch.Tensor,
                       flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
        Tuple[torch.Tensor, torch.Tensor]:
    """
    Bilinear splatting
    :param frame1: (b,c,h,w)
    :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional
    :param depth1: (b,1,h,w)
    :param flow12: (b,2,h,w)
    :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional
    :param is_image: if true, output will be clipped to (-1,1) range
    :return: warped_frame2: (b,c,h,w)
             mask2: (b,1,h,w): 1 for known and 0 for unknown
    """
    if self.resolution is not None:
        assert frame1.shape[2:4] == self.resolution
    b, c, h, w = frame1.shape
    if mask1 is None:
        mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
    if flow12_mask is None:
        flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
    grid = self.create_grid(b, h, w).to(frame1)
    trans_pos = flow12 + grid

    trans_pos_offset = trans_pos + 1
    trans_pos_floor = torch.floor(trans_pos_offset).long()
    trans_pos_ceil = torch.ceil(trans_pos_offset).long()
    trans_pos_offset = torch.stack([
        torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
        torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
    trans_pos_floor = torch.stack([
        torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
        torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
    trans_pos_ceil = torch.stack([
        torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
        torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)

    prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
    prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

    sat_depth1 = torch.clamp(depth1, min=0, max=1000)
    log_depth1 = torch.log(1 + sat_depth1)
    depth_weights = torch.exp(log_depth1 / log_depth1.max() * 50)

    # Move the splatting weights to channels-last for the index_put_ accumulation below.
    weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights,
                               [0, 1, 2, 3], [0, 3, 1, 2])
    weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights,
                               [0, 1, 2, 3], [0, 3, 1, 2])
    weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights,
                               [0, 1, 2, 3], [0, 3, 1, 2])
    weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights,
                               [0, 1, 2, 3], [0, 3, 1, 2])

    warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=torch.float32).to(frame1)
    warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=torch.float32).to(frame1)

    frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])
    batch_indices = torch.arange(b)[:, None, None].to(frame1.device)
    warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
                            frame1_cl * weight_nw, accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
                            frame1_cl * weight_sw, accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
                            frame1_cl * weight_ne, accumulate=True)
    warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
                            frame1_cl * weight_se, accumulate=True)

    warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
                              weight_nw, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
                              weight_sw, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
                              weight_ne, accumulate=True)
    warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
                              weight_se, accumulate=True)

    # Back to channels-first and crop the one-pixel padding.
    warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1])
    warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1])
    cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
    cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]

    mask = cropped_weights > 0
    zero_value = -1 if is_image else 0
    zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
    warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor)
    mask2 = mask.to(frame1)

    if is_image:
        assert warped_frame2.min() >= -1.1  # Allow for rounding errors
        assert warped_frame2.max() <= 1.1
        warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)

    return warped_frame2, mask2
def forward(self, input_batch, target_list_of_tokens=None):
    batch_size, frames_num, win_len, color, height, width = input_batch.shape
    output = torch.moveaxis(
        input_batch, (0, 1, 2),
        (2, 1, 0))  # [win_len, frames_num, batch_size, color, height, width]
    output = output.reshape(win_len * frames_num * batch_size, color, height,
                            width)  # [batch, C, H, W] for resnet
    output = self.resnet18(output)  # [batch, feature=512]
    output = output.reshape(win_len, frames_num * batch_size,
                            512)  # [time, batch, feature=512] for LSTM
    output, (h_n, c_n) = self.LSTM(output)  # output shape is [time, batch, feature=1024]
    output = output[-1, :, :]  # last state output in LSTM seq [batch, feature=1024]
    output = output.reshape(frames_num, batch_size,
                            1024)  # [frames_num, batch, feature=1024] for Transformer

    # If a list_of_tokens is passed to the network, the loss is computed against
    # this ground truth; if it is not passed, no loss is computed.
    if target_list_of_tokens is not None:
        target_batch = self.list_of_tokens_to_tensor_of_tokens_idx(
            target_list_of_tokens)  # [batch, seq_len]
        # Workaround so that the tensors end up on the same device.
        target_batch = target_batch.to(self.embedding.weight.device)

    if self.training:
        target_output = self.embedding(target_batch)  # [batch, seq_len, emb_size=512]
        target_output = torch.moveaxis(
            target_output, (0, 1), (1, 0))  # [seq_len, batch, emb_size=512] for Transformer
    else:
        dummy_batch = self.tokens_to_id['_BOS_'] * torch.ones(
            (batch_size, 1), dtype=torch.long)  # [batch, seq_len=1]
        # Workaround so that the tensors end up on the same device.
        dummy_batch = dummy_batch.to(self.embedding.weight.device)
        target_output = self.embedding(dummy_batch)  # [batch, seq_len=1, emb_size=512]
        target_output = torch.moveaxis(
            target_output, (0, 1), (1, 0))  # [seq_len=1, batch, emb_size=512] for Transformer

        # Autoregressively extend target_output; the commented-out block below
        # would stop generation early once every sequence contains _EOS_.
        i = 1
        max_iter = target_batch.shape[1] if target_list_of_tokens is not None else 30
        while i < max_iter:
            i += 1
            predict_output = self.transformer(
                output, target_output)  # [seq_len=i, batch, emb_size=512]
            # Append the last slice of predict_output to target_output until it
            # reaches length max_iter.
            last_slice_of_predict_output = predict_output[-1:, :, :]  # [seq_len=1, batch, emb_size=512]
            target_output = torch.cat(
                (target_output, last_slice_of_predict_output),
                dim=0)  # [seq_len=i+1, batch, emb_size=512]

            # # Check for the _EOS_ token in the prediction:
            # predict_output = torch.moveaxis(predict_output, (0, 1), (1, 0))  # [batch, seq_len=i, emb_size=512] for Linear
            # predict_output = self.linear(predict_output)  # [batch, seq=i, classes_num]
            # probs = self.softmax(predict_output)  # [batch, seq_len=i, classes_num]
            # predict_tensor_of_tokens_idx = torch.argmax(probs, dim=-1)  # [batch, seq_len=i]
            # # Compare every token in a row with _EOS_ and take a logical OR along the seq dimension.
            # val, ind = torch.max(predict_tensor_of_tokens_idx == self.tokens_to_id['_EOS_'], dim=1)
            # # Take a logical AND along the batch dimension to make sure every example in the batch has an _EOS_ token.
            # val, ind = torch.min(val, dim=0)
            # else:
            #     dummy_batch = torch.cat((dummy_batch, predict_tensor_of_tokens_idx), dim=1)  # [batch, seq_len=max_iter]

    # Generate predict_output one more time.
    predict_output = self.transformer(output, target_output)  # [seq_len, batch, emb_size=512]
    predict_output = torch.moveaxis(
        predict_output, (0, 1), (1, 0))  # [batch, seq_len, emb_size=512] for Linear
    predict_output = self.linear(predict_output)  # [batch, seq_len, classes_num]  # for return

    probs = self.softmax(predict_output)  # [batch, seq_len, classes_num]
    predict_tensor_of_tokens_idx = torch.argmax(probs, dim=-1)  # [batch, seq_len]
    predict_list_of_tokens = self.tensor_of_tokens_idx_to_list_of_tokens(
        predict_tensor_of_tokens_idx)

    if target_list_of_tokens is not None:
        # Compute the loss against the ground-truth tokens.
        predict_batch = torch.moveaxis(
            predict_output, (1, 2), (2, 1))  # [batch, classes_num, seq_len] for loss
        self.loss = self.compute_loss(predict_batch, target_batch)
    else:
        self.loss = None
        # # Compute the loss against the generated tokens instead:
        # predict_batch = torch.moveaxis(predict_output, (1, 2), (2, 1))  # [batch, classes_num, seq_len] for loss
        # self.loss = self.compute_loss(predict_batch, dummy_batch)

    return predict_list_of_tokens
def bilinear_interpolation(self, frame2: torch.Tensor, mask2: Optional[torch.Tensor],
                           flow12: torch.Tensor, flow12_mask: Optional[torch.Tensor],
                           is_image: bool = False) -> \
        Tuple[torch.Tensor, torch.Tensor]:
    """
    Bilinear interpolation
    :param frame2: (b, c, h, w)
    :param mask2: (b, 1, h, w): 1 for known, 0 for unknown. Optional
    :param flow12: (b, 2, h, w)
    :param flow12_mask: (b, 1, h, w): 1 for valid flow, 0 for invalid flow. Optional
    :param is_image: if true, output will be clipped to (-1,1) range
    :return: warped_frame1: (b, c, h, w)
             mask1: (b, 1, h, w): 1 for known and 0 for unknown
    """
    if self.resolution is not None:
        assert frame2.shape[2:4] == self.resolution
    b, c, h, w = frame2.shape
    if mask2 is None:
        mask2 = torch.ones(size=(b, 1, h, w)).to(frame2)
    if flow12_mask is None:
        flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
    grid = self.create_grid(b, h, w).to(frame2)
    trans_pos = flow12 + grid

    trans_pos_offset = trans_pos + 1
    trans_pos_floor = torch.floor(trans_pos_offset).long()
    trans_pos_ceil = torch.ceil(trans_pos_offset).long()
    trans_pos_offset = torch.stack([
        torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
        torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
    trans_pos_floor = torch.stack([
        torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
        torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
    trans_pos_ceil = torch.stack([
        torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
        torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)

    prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
    prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
    prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                     (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

    weight_nw = torch.moveaxis(prox_weight_nw * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_sw = torch.moveaxis(prox_weight_sw * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_ne = torch.moveaxis(prox_weight_ne * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
    weight_se = torch.moveaxis(prox_weight_se * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])

    frame2_offset = F.pad(frame2, [1, 1, 1, 1])
    mask2_offset = F.pad(mask2, [1, 1, 1, 1])
    bi = torch.arange(b)[:, None, None]

    f2_nw = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:, 0]]
    f2_sw = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]]
    f2_ne = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]]
    f2_se = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]

    m2_nw = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:, 0]]
    m2_sw = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]]
    m2_ne = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]]
    m2_se = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]

    nr = weight_nw * f2_nw * m2_nw + weight_sw * f2_sw * m2_sw + \
         weight_ne * f2_ne * m2_ne + weight_se * f2_se * m2_se
    dr = weight_nw * m2_nw + weight_sw * m2_sw + weight_ne * m2_ne + weight_se * m2_se

    zero_value = -1 if is_image else 0
    zero_tensor = torch.tensor(zero_value, dtype=nr.dtype, device=nr.device)
    warped_frame1 = torch.where(dr > 0, nr / dr, zero_tensor)
    mask1 = (dr > 0).to(frame2)

    # Convert to channel first
    warped_frame1 = torch.moveaxis(warped_frame1, [0, 1, 2, 3], [0, 2, 3, 1])
    mask1 = torch.moveaxis(mask1, [0, 1, 2, 3], [0, 2, 3, 1])

    if is_image:
        assert warped_frame1.min() >= -1.1  # Allow for rounding errors
        assert warped_frame1.max() <= 1.1
        warped_frame1 = torch.clamp(warped_frame1, min=-1, max=1)

    return warped_frame1, mask1