Example #1
0
  def test_fb_consistency_with_occlusion(self):
    """Brox occlusion estimation must mark inconsistent flows as occluded.

    The forward flow is +4 in both flow components everywhere while the
    backward flow is only -2, so forward-backward consistency fails at every
    pixel and the interior of both not-occluded masks must be all zeros.
    All flows are passed channels-last, shape (batch, height, width, 2).
    """
    batch_size = 4
    height = 64
    width = 64
    # Forward flow: +4 in both flow components.
    flow_01 = np.ones((batch_size, height, width, 2)) * 4.
    # Backward flow: -2 in both components, i.e. it undoes only half of the
    # forward motion.
    imperfect_flow_10 = -flow_01 * .5
    flow_01 = torch.tensor(flow_01.astype(np.float32))
    resize_transform = transforms.Compose(
        [transforms.Resize((height // 2, width // 2))])
    # Resize works on channels-first tensors: permute a temporary only, so
    # flow_01 itself stays channels-last like the other flows in `flows`.
    flow_01_level1 = resize_transform(torch.moveaxis(flow_01, -1, 1)) / 2.
    flow_01_level1 = torch.moveaxis(flow_01_level1, 1, -1)

    imperfect_flow_10 = torch.tensor(imperfect_flow_10.astype(np.float32))
    imperfect_flow_10_level1 = -flow_01_level1 * .5
    flows = {}
    flows[(0, 1, 0)] = [flow_01, flow_01_level1]
    flows[(1, 0, 0)] = [imperfect_flow_10, imperfect_flow_10_level1]
    _, _, _, not_occluded_masks, _, _ = \
        uflow_utils.compute_warps_and_occlusion(
            flows, occlusion_estimation='brox')
    # Everything except a 4-pixel border must be occluded (mask == 0).
    expected = np.zeros((batch_size, height - 8, width - 8, 1))
    is_zeros_01 = np.equal(
        expected, not_occluded_masks[(0, 1, 0)][0][:, 4:-4, 4:-4, :]).all()
    is_zeros_10 = np.equal(
        expected, not_occluded_masks[(1, 0, 0)][0][:, 4:-4, 4:-4, :]).all()
    self.assertTrue(is_zeros_01)
    self.assertTrue(is_zeros_10)
def gather_nd(params, indices):
    """Gather per-pixel channel vectors of ``params`` addressed by ``indices``.

    TensorFlow ``gather_nd``-style lookup on channels-first tensors.

    Args:
        params: 4-D channels-first tensor (b, c, h, w).
        indices: 4-D tensor (b', 3, h', w'). For every output location the
            3 channel entries are a (batch, row, col) address into ``params``.
            Non-integer values are truncated by the int64 cast.

    Returns:
        Tensor (b', c, h', w') with ``out[n, :, y, x] = params[i0, :, i1, i2]``
        where ``(i0, i1, i2) = indices[n, :, y, x]``.
    """
    # Channels-last, so the gathered channel vector becomes the trailing dim.
    params_cl = torch.moveaxis(params, 1, -1)
    idx = torch.moveaxis(indices, 1, -1).to(torch.int64)
    # One index tensor per addressed dim, passed as a tuple. This avoids
    # `Tensor.T` on a 4-D tensor (deprecated) and list-of-tensors indexing.
    gathered = params_cl[tuple(idx.unbind(-1))]
    return torch.moveaxis(gathered, -1, 1)
Example #3
0
 def tensor_indexing_ops(self):
     """Return the results of a fixed set of tensor indexing/joining ops.

     The inputs come from ``torch.randn`` on every call, so only the shapes
     of the returned tensors are stable between calls.
     """
     src = torch.randn(2, 4)
     src3d = torch.randn(2, 4, 2)
     gather_idx = torch.tensor([[0, 0], [1, 0]])
     ge_half = src.ge(0.5)
     cuts = [0, 1]
     three = (src, src, src)
     two = (src, src)
     results = (
         torch.cat(three, 0),
         torch.concat(three, 0),
         torch.conj(src),
         torch.chunk(src, 2),
         torch.dsplit(src3d, cuts),
         torch.column_stack(two),
         torch.dstack(two),
         torch.gather(src, 0, gather_idx),
         torch.hsplit(src, cuts),
         torch.hstack(two),
         torch.index_select(src, 0, torch.tensor([0, 1])),
         torch.masked_select(src, ge_half),
         torch.movedim(src, 1, 0),
         torch.moveaxis(src, 1, 0),
         torch.narrow(src, 0, 0, 2),
         torch.nonzero(src),
         torch.permute(src, (0, 1)),
         torch.reshape(src, (-1, )),
     )
     return results
    def __call__(self, input_img, target_category=None):
        """Compute a guided-backpropagation saliency map for ``input_img``.

        ReLUs in ``self.model`` are replaced by guided ReLUs for the duration
        of the call and are always restored, even if an exception is raised.

        Args:
            input_img: input batch; moved to the GPU when ``self.cuda`` is set.
                Only the first sample of the batch is explained.
            target_category: output index to explain. When None, the
                highest-scoring entry of the first sample's output is used.

        Returns:
            Gradient of the selected score w.r.t. the first input sample with
            its first axis moved to the end (C,H,W -> H,W,C for an (N,C,H,W)
            input). Built with ``create_graph=True``.
        """
        # replace ReLU with GuidedBackpropReLU
        self.recursive_replace_relu_with_guidedrelu(self.model)
        try:
            if self.cuda:
                input_img = input_img.cuda()

            input_img = input_img.requires_grad_(True)

            output = self.forward(input_img)

            if target_category is None:
                print('warning: target_category not given, using the top-scoring class')
                # Use the first sample's scores: the same row the loss reads.
                target_category = np.argmax(output[0].cpu().data.numpy())

            try:
                loss = output[0, target_category]
            except (TypeError, KeyError):
                # Outputs that cannot be indexed like a tensor are expected to
                # expose their scores as `.logits`.
                loss = output.logits[0, target_category]

            grads = torch.autograd.grad(loss, input_img, create_graph=True)
            saliency = grads[0][0, :, :, :]
            return torch.moveaxis(saliency, 0, 2)
        finally:
            # replace GuidedBackpropReLU back with ReLU
            self.recursive_replace_guidedrelu_with_relu(self.model)
Example #5
0
    def __call__(self, x, index=None):
        """Compute a SmoothGrad saliency map for ``x``.

        Averages the input gradient of one output score over
        ``self.n_samples`` noisy copies of ``x``. The noise is i.i.d. Gaussian
        per element with standard deviation
        ``self.stdev_spread * (max(x) - min(x))``.

        Args:
            x: input batch; it must require grad, because gradients are taken
                through ``x + noise``. Only the first sample is returned.
            index: output index to explain. When None it is set from
                ``torch.argmax`` of the first noisy output and reused for all
                remaining samples.

        Returns:
            The averaged gradient (squared when ``self.magnitude`` is truthy)
            of the first sample, with its first axis moved to the end.
        """
        stdev = self.stdev_spread * (torch.max(x) - torch.min(x))
        total_gradients = torch.zeros_like(x)
        for _ in range(self.n_samples):
            # Per-element noise: torch.normal(0, stdev) with a 0-d std would
            # draw a single scalar and shift the whole input uniformly.
            noise = torch.randn_like(x) * stdev
            x_plus_noise = x + noise
            output = self.model(x_plus_noise)

            if index is None:
                index = torch.argmax(output)

            # Selector built on the output's device so GPU models work too.
            one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32,
                                  device=output.device)
            one_hot[0][index] = 1
            score = torch.sum(one_hot * output)

            # autograd.grad returns a tuple; take the gradient tensor.
            grad = torch.autograd.grad(
                score, x_plus_noise, create_graph=True)[0]

            if self.magnitude:
                total_gradients += grad * grad
            else:
                total_gradients += grad

        avg_gradients = total_gradients[0, :, :, :] / self.n_samples
        return torch.moveaxis(avg_gradients, 0, 2)
Example #6
0
 def tensor_indexing_ops(self):
     """Return a tuple with the results of many tensor indexing/scatter ops.

     Inputs are drawn from ``torch.randn``, so only result shapes are stable.
     Note: ``x.scatter_`` and ``x.scatter_add_`` mutate ``x`` in place while the
     tuple is being built, so earlier entries that are views of ``x`` (e.g.
     movedim, narrow, t) reflect those updates.
     """
     x = torch.randn(2, 4)
     y = torch.randn(4, 4)
     t = torch.tensor([[0, 0], [1, 0]])
     mask = x.ge(0.5)
     i = [0, 1]
     # A tuple of results: len() accepts exactly one argument.
     return (
         torch.cat((x, x, x), 0),
         torch.concat((x, x, x), 0),
         torch.conj(x),
         torch.chunk(x, 2),
         torch.dsplit(torch.randn(2, 2, 4), i),
         torch.column_stack((x, x)),
         torch.dstack((x, x)),
         torch.gather(x, 0, t),
         torch.hsplit(x, i),
         torch.hstack((x, x)),
         torch.index_select(x, 0, torch.tensor([0, 1])),
         x[t],  # advanced indexing (aten::index); Tensor has no .index method
         torch.masked_select(x, mask),
         torch.movedim(x, 1, 0),
         torch.moveaxis(x, 1, 0),
         torch.narrow(x, 0, 0, 2),
         torch.nonzero(x),
         torch.permute(x, (0, 1)),
         torch.reshape(x, (-1, )),
         torch.row_stack((x, x)),
         torch.select(x, 0, 0),
         torch.scatter(x, 0, t, x),
         x.scatter(0, t, x.clone()),
         torch.diagonal_scatter(y, torch.ones(4)),
         torch.select_scatter(y, torch.ones(4), 0, 0),
         torch.slice_scatter(x, x),
         torch.scatter_add(x, 0, t, x),
         x.scatter_(0, t, y),
         x.scatter_add_(0, t, y),
         # torch.scatter_reduce(x, 0, t, reduce="sum"),
         torch.split(x, 1),
         torch.squeeze(x, 0),
         torch.stack([x, x]),
         torch.swapaxes(x, 0, 1),
         torch.swapdims(x, 0, 1),
         torch.t(x),
         torch.take(x, t),
         torch.take_along_dim(x, torch.argmax(x)),
         torch.tensor_split(x, 1),
         torch.tensor_split(x, [0, 1]),
         torch.tile(x, (2, 2)),
         torch.transpose(x, 0, 1),
         torch.unbind(x),
         torch.unsqueeze(x, -1),
         torch.vsplit(x, i),
         torch.vstack((x, x)),
         torch.where(x),
         torch.where(t > 0, t, 0),
         torch.where(t > 0, t, t),
     )
Example #7
0
def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor:
    """`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8.

    Args:
        x: a ``torch.Tensor`` or ``np.ndarray``.
        src: original position of the axis to move.
        dst: destination position of that axis.

    Returns:
        ``x`` with axis ``src`` moved to ``dst``, of the same kind as ``x``.

    Raises:
        RuntimeError: if ``x`` is neither a torch tensor nor a numpy array.
    """
    if isinstance(x, torch.Tensor):
        if hasattr(torch, "moveaxis"):
            return torch.moveaxis(x, src, dst)
        # Fallback for PyTorch builds that lack torch.moveaxis.
        return _moveaxis_with_permute(x, src, dst)  # type: ignore
    if isinstance(x, np.ndarray):
        return np.moveaxis(x, src, dst)
    raise RuntimeError(
        f"moveaxis: unsupported input type {type(x).__name__}; "
        "expected torch.Tensor or numpy.ndarray")
def img_preprocessing(img):
    """
    Turn one coin-segmentation image array into a single-item CNN input batch.

    The channel (last) axis is reversed (presumably BGR -> RGB) and the values
    are divided by 255 (assumes 0..255 input).

    :param img: single 64x64x3 coin image, channels last
    :return: float32 tensor of shape (1, channels, height, width)
    """
    reversed_channels = img[..., ::-1]
    # .copy() gives a positively-strided array, which torch.from_numpy needs.
    scaled = (reversed_channels / 255).copy()
    chw = torch.moveaxis(torch.from_numpy(scaled), -1, 0)
    return chw.unsqueeze(0).float()
    def log_pre_update(self):
        """
        Build the dictionary of metrics to log to wandb before an update.

        Always includes advantage / return / value-estimate means and standard
        deviations, the explained variance of the value estimates w.r.t. the
        returns and the mean rollout reward, plus the best external return
        when one is recorded. Unless ``self.debugging`` is set it also adds
        activation statistics and histograms and an action histogram. Every
        ``self.vlog_freq`` updates (only when ``vlog_freq`` is positive) it
        adds a GIF of the first environment's observation buffer.

        Returns:
            dict mapping metric names to values or wandb media objects.
        """
        # Initialize and update the info dict for logging
        info = dict()
        info["ppo/advantage_mean"] = self.buf_advantages.mean()
        info["ppo/advantage_std"] = self.buf_advantages.std()
        info["ppo/return_mean"] = self.buf_returns.mean()
        info["ppo/return_std"] = self.buf_returns.std()
        info["ppo/value_est_mean"] = self.rollout.buf_vpreds.mean()
        info["ppo/value_est_std"] = self.rollout.buf_vpreds.std()
        info["ppo/explained_variance"] = explained_variance(
            self.rollout.buf_vpreds.flatten(),  # TODO: switch to ravel if pytorch>=1.9
            self.buf_returns.flatten()  # TODO: switch to ravel if pytorch >= 1.9
        )
        info["ppo/reward_mean"] = torch.mean(self.rollout.buf_rewards)

        if self.rollout.best_ext_return is not None:
            info["performance/best_ext_return"] = self.rollout.best_ext_return
        # TODO: maybe add extra flag for detailed logging so runs are not slowed down
        if not self.debugging:
            feature_stats, stacked_act_feat = self.get_activation_stats(
                self.rollout.buf_acts_features, "activations_features/"
            )
            hidden_stats, stacked_act_pi = self.get_activation_stats(
                self.rollout.buf_acts_pi, "activations_hidden/"
            )
            info.update(feature_stats)
            info.update(hidden_stats)

            info["activations_features/raw_act_distribution"] = wandb.Histogram(
                to_numpy(stacked_act_feat)
            )
            info["activations_hidden/raw_act_distribution"] = wandb.Histogram(
                to_numpy(stacked_act_pi)
            )

            info["ppo/action_distribution"] = wandb.Histogram(
                to_numpy(self.rollout.buf_acs).flatten()
            )

            # vlog_freq <= 0 disables video logging; letting 0 through would
            # evaluate `n_updates % 0` and raise ZeroDivisionError.
            if self.vlog_freq > 0 and self.n_updates % self.vlog_freq == 0:
                print(str(self.n_updates) + " updates - logging video.")
                # Move axis 3 of the first env's observations to position 1.
                # Presumably (time, h, w, c) -> (time, c, h, w), the layout
                # wandb.Video expects; confirm against the buffer layout.
                sample_video = torch.moveaxis(self.rollout.buf_obs[0], 3, 1)
                # Log buffer video from first env
                info["observations"] = wandb.Video(
                    to_numpy(sample_video), fps=12, format="gif"
                )

        return info
Example #10
0
def show_test_result(result_path):
    """Load a saved stack of binary masks and display one chosen at random.

    Args:
        result_path: path passed to ``torch.load``. The file must hold a
            tensor of shape (24, 256, 256) whose values are all 0 or 1.

    Raises:
        AssertionError: if the shape or the binary-value check fails.
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted paths.
    masks = torch.load(result_path)
    print('masks: ', masks)
    assert masks.shape == (24, 256, 256)
    # Every element must be exactly 0 or 1.
    assert ((masks == 0) | (masks == 1)).all().item()
    masks = torch.moveaxis(masks, 0, 2)  # (256, 256, 24): mask index last
    # randint's upper bound is exclusive; using the mask count makes the
    # last mask selectable too.
    rand_idx = np.random.randint(0, masks.shape[2])
    rand_mask = masks[:, :, rand_idx].cpu().detach().numpy()
    plt.imshow(rand_mask)
    plt.title(str(rand_idx))
    plt.show()
Example #11
0
	def __init__(self, quaternion:torch.Tensor, translation:torch.Tensor, rotation:torch.Tensor=None,
			normalize:bool=True, unstack_inputs:bool=False) -> None:
		"""Build a rigid transform from a rotation (or quaternion) and a translation.

		Args:
			quaternion: tensor whose last dim has size 4, or None. Divided by its norm
				along the last dim when ``normalize`` is True. Converted with
				``quat_to_rot`` when ``rotation`` is None.
			translation: iterated along its first dim into exactly 3 entries (after
				moving its last dim to the front when ``unstack_inputs`` is True).
			rotation: optional rotation. With ``unstack_inputs`` it is a stacked tensor
				split so that ``self.rotation[i][j]`` holds ``rotation[..., i, j]``.
			normalize: whether to normalize ``quaternion``.
			unstack_inputs: whether ``rotation``/``translation`` are stacked tensors to
				split along their trailing dims.
		"""
		super().__init__()
		if quaternion is not None:
			assert quaternion.shape[-1] == 4
			if normalize:
				quat_norm = torch.linalg.norm(quaternion, dim=-1, keepdim=True)
				quaternion = quaternion / quat_norm
		if unstack_inputs:
			if rotation is not None:
				# Row axis to the front, then each row's column axis to the front.
				rows_first = torch.moveaxis(rotation, -2, 0)
				rotation = [torch.moveaxis(row, -1, 0) for row in rows_first]
			translation = torch.moveaxis(translation, -1, 0)
		rotation = quat_to_rot(quaternion) if rotation is None else rotation

		self.quaternion = quaternion
		self.rotation = [list(entries) for entries in rotation]
		self.translation = list(translation)

		# Both components must describe 3-D space.
		assert all(len(entries) == 3 for entries in self.rotation)
		assert len(self.translation) == 3
Example #12
0
    def __call__(self, x, index=None):
        """Return the vanilla input-gradient saliency map of ``self.pretrained_model``.

        Args:
            x: input batch that must require grad. Only its first sample's
                gradient is returned; a 4-D input is assumed (presumably
                (N, C, H, W); confirm with callers).
            index: output index to explain. Defaults to ``torch.argmax`` over
                the whole (flattened) output.

        Returns:
            Gradient of the selected score w.r.t. ``x[0]``, with its first axis
            moved to the end. Built with ``create_graph=True``.
        """
        output = self.pretrained_model(x)

        if index is None:
            index = torch.argmax(output)

        # Build the selector on the output's device so GPU models work too.
        one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32,
                              device=output.device)
        one_hot[0][index] = 1
        score = torch.sum(one_hot * output)
        grad = torch.autograd.grad(score, x, create_graph=True)[0]
        return torch.moveaxis(grad[0, :, :, :], 0, 2)
    def apply_transforms(self, image, labels):
        """Run ``self.transforms`` on one image/label volume pair via TorchIO.

        Both inputs are converted to tensors (float image, int64 labels). A
        leading channel axis is added, and the first original axis is moved
        to the end before TorchIO sees them. The result has the channel axis
        dropped and the last axis moved back to the front.

        Returns:
            (X, Y) numpy arrays: the transformed image and label map.
        """
        image_t = torch.tensor(image, dtype=torch.float, requires_grad=False)
        label_t = torch.tensor(labels, dtype=torch.long, requires_grad=False)

        # Expected TorchIO input is (C x W x H x D).
        image_t = torch.moveaxis(image_t.unsqueeze(0), 1, -1)
        label_t = torch.moveaxis(label_t.unsqueeze(0), 1, -1)

        subject = tio.Subject(
            one_image=tio.ScalarImage(tensor=image_t),  # TorchIO needs tensors here
            a_segmentation=tio.LabelMap(tensor=label_t))
        sample = tio.SubjectsDataset([subject], transform=self.transforms)[0]

        image_np = sample['one_image']['data'].numpy()
        label_np = sample['a_segmentation']['data'].numpy()

        # Re-arrange for PyTorch into (D, H, W): drop channel, last axis first.
        X = np.moveaxis(image_np[0], -1, 0)
        Y = np.moveaxis(label_np[0], -1, 0)
        return X, Y
Example #14
0
def display_image(dataloader, batch_index=0, verbose=False):
    """Show one image from the first batch of ``dataloader`` with matplotlib.

    Args:
        dataloader: iterable yielding ``(features, labels)`` batches.
        batch_index: which sample of the first batch to display.
        verbose: when True, print batch/image shapes and the sample's label.
    """
    print('display an image')
    features, labels = next(iter(dataloader))
    if verbose:
        print(f"Feature batch shape: {features.size()}")
        print(f"Labels batch shape: {labels.size()}")
    # Reverse the first (channel) axis: images were read as BGR by OpenCV,
    # while matplotlib expects RGB.
    sample = torch.flip(features[batch_index].squeeze().cpu(), dims=(0, ))
    label = labels[batch_index]
    # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
    plot_img = torch.moveaxis(sample, (0, 1, 2), (-1, 0, 1))
    if verbose:
        print('plot image shape', plot_img.shape)
    plt.imshow(plot_img)
    plt.show()
    if verbose:
        print(f"Label: {label}")
    def bilinear_splatting(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
                           flow12: torch.Tensor, flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward-warp (splat) frame1 along flow12 using bilinear splatting.

        Each source pixel is displaced by flow12, and its value is scattered to the
        four integer pixels around the target position with bilinear weights. The
        weights are multiplied by mask1 and flow12_mask and divided by a factor that
        grows with depth1, so nearer source points dominate where several land on the
        same target pixel. Accumulated values are normalized by accumulated weights.

        :param frame1: (b,c,h,w)
        :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional (all ones if None)
        :param depth1: (b,1,h,w)
        :param flow12: (b,2,h,w); channel 0 is used as the x (width) displacement and
                       channel 1 as the y (height) displacement
        :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional (all ones if None)
        :param is_image: if true, output will be clipped to (-1,1) range, and pixels that
                         receive no contribution are filled with -1 instead of 0
        :return: warped_frame2: (b,c,h,w)
                 mask2: (b,1,h,w): 1 for known and 0 for unknown
        :raises AssertionError: if self.resolution is set and does not match frame1, or
                                if is_image and the result falls outside [-1.1, 1.1]
        """
        if self.resolution is not None:
            assert frame1.shape[2:4] == self.resolution
        b, c, h, w = frame1.shape
        if mask1 is None:
            mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
        if flow12_mask is None:
            flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
        # NOTE(review): create_grid is defined elsewhere; presumably it returns the
        # (x, y) pixel coordinates as (b,2,h,w), making trans_pos each pixel's target
        # position. Confirm.
        grid = self.create_grid(b, h, w).to(frame1)
        trans_pos = flow12 + grid

        # Shift by +1 so positions index a canvas with a 1-pixel border on every
        # side. Out-of-range targets are clamped onto that border (channel 0 into
        # [0, w+1], channel 1 into [0, h+1]); the border is cropped off at the end.
        trans_pos_offset = trans_pos + 1
        trans_pos_floor = torch.floor(trans_pos_offset).long()
        trans_pos_ceil = torch.ceil(trans_pos_offset).long()
        trans_pos_offset = torch.stack([
            torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)
        ],
                                       dim=1)
        trans_pos_floor = torch.stack([
            torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)
        ],
                                      dim=1)
        trans_pos_ceil = torch.stack([
            torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)
        ],
                                     dim=1)

        # Bilinear proximity weights of the four integer neighbours:
        # nw = (floor y, floor x), sw = (ceil y, floor x),
        # ne = (floor y, ceil x),  se = (ceil y, ceil x).
        # When a coordinate is integral, floor == ceil and every corner gets weight 1
        # at the same pixel; the final normalization by the weights cancels this.
        prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
        prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

        # Depth factor grows monotonically with depth (clamped to [0, 1000]).
        # Dividing the weights by it below lets nearer source points dominate.
        # NOTE(review): if every depth is 0, log_depth1.max() is 0 and this division
        # produces NaN weights.
        sat_depth1 = torch.clamp(depth1, min=0, max=1000)
        log_depth1 = torch.log(1 + sat_depth1)
        depth_weights = torch.exp(log_depth1 / log_depth1.max() * 50)

        # Combine proximity, validity masks and depth, then move to channels-last
        # (b,h,w,1) to line up with the (b,h,w,c) values scattered below.
        weight_nw = torch.moveaxis(
            prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3],
            [0, 3, 1, 2])
        weight_sw = torch.moveaxis(
            prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3],
            [0, 3, 1, 2])
        weight_ne = torch.moveaxis(
            prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3],
            [0, 3, 1, 2])
        weight_se = torch.moveaxis(
            prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3],
            [0, 3, 1, 2])

        # Channels-last accumulators on the padded (h+2, w+2) canvas.
        warped_frame = torch.zeros(size=(b, h + 2, w + 2, c),
                                   dtype=torch.float32).to(frame1)
        warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1),
                                     dtype=torch.float32).to(frame1)

        frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])  # (b,h,w,c)
        # (b,1,1) batch index broadcasts against the (b,h,w) row/column index maps.
        batch_indices = torch.arange(b)[:, None, None].to(frame1.device)
        # Scatter-add every source pixel's weighted value to its four neighbours;
        # accumulate=True sums contributions that land on the same pixel.
        warped_frame.index_put_(
            (batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
            frame1_cl * weight_nw,
            accumulate=True)
        warped_frame.index_put_(
            (batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
            frame1_cl * weight_sw,
            accumulate=True)
        warped_frame.index_put_(
            (batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
            frame1_cl * weight_ne,
            accumulate=True)
        warped_frame.index_put_(
            (batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
            frame1_cl * weight_se,
            accumulate=True)

        # Same scatter for the weights alone, used to normalize the values.
        warped_weights.index_put_(
            (batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
            weight_nw,
            accumulate=True)
        warped_weights.index_put_(
            (batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
            weight_sw,
            accumulate=True)
        warped_weights.index_put_(
            (batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
            weight_ne,
            accumulate=True)
        warped_weights.index_put_(
            (batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
            weight_se,
            accumulate=True)

        # Back to channels-first, then drop the 1-pixel border.
        warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3],
                                         [0, 2, 3, 1])
        warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3],
                                           [0, 2, 3, 1])
        cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
        cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]

        # Weighted average where anything landed; elsewhere zero_value
        # (-1 for images, 0 otherwise).
        mask = cropped_weights > 0
        zero_value = -1 if is_image else 0
        zero_tensor = torch.tensor(zero_value,
                                   dtype=frame1.dtype,
                                   device=frame1.device)
        warped_frame2 = torch.where(mask,
                                    cropped_warped_frame / cropped_weights,
                                    zero_tensor)
        mask2 = mask.to(frame1)

        if is_image:
            assert warped_frame2.min() >= -1.1  # Allow for rounding errors
            assert warped_frame2.max() <= 1.1
            warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)
        return warped_frame2, mask2
Example #16
0
    def forward(self, input_batch, target_list_of_tokens=None):
        """Encode clips of frame windows and decode them into token sequences.

        Pipeline (shapes as annotated inline): ``self.resnet18`` features per
        frame, then ``self.LSTM`` over the window axis (last step kept), then
        ``self.transformer``, ``self.linear`` and ``self.softmax``.

        Args:
            input_batch: 6-D tensor unpacked as
                (batch_size, frames_num, win_len, color, height, width).
            target_list_of_tokens: optional ground-truth token lists. In
                training mode they are embedded as the decoder input. In eval
                mode their length sets the number of decoding steps (30
                otherwise). When given, ``self.loss`` is computed against
                them; otherwise ``self.loss`` is set to None.

        Returns:
            Predicted token lists from
            ``self.tensor_of_tokens_idx_to_list_of_tokens``.

        Raises:
            ValueError: if the module is in training mode and
                ``target_list_of_tokens`` is None (no decoder input exists).
        """
        if self.training and target_list_of_tokens is None:
            raise ValueError(
                'target_list_of_tokens is required in training mode')

        batch_size, frames_num, win_len, color, height, width = input_batch.shape
        output = torch.moveaxis(
            input_batch, (0, 1, 2),
            (2, 1,
             0))  # [win_len, frames_num, batch_size, color, height, width]

        output = output.reshape(win_len * frames_num * batch_size, color,
                                height, width)  # [batch, C, H, W] for resnet
        output = self.resnet18(output)  # [batch, feature=512]

        output = output.reshape(win_len, frames_num * batch_size,
                                512)  # [time, batch, feature=512] for LSTM
        output, _ = self.LSTM(
            output)  # output shape is [time, batch, feature=1024]
        output = output[
            -1, :, :]  # last state output in LSTM seq [batch, feature=1024]

        output = output.reshape(
            frames_num, batch_size,
            1024)  # [frames_num, batch, feature=1024] for Transformer

        # If ground-truth tokens are passed, the loss is computed against them;
        # otherwise no loss is computed.
        if target_list_of_tokens is not None:
            target_batch = self.list_of_tokens_to_tensor_of_tokens_idx(
                target_list_of_tokens)  # [batch, seq_len]
            # Workaround: keep the targets on the same device as the embedding.
            target_batch = target_batch.to(self.embedding.weight.device)

        if self.training:
            target_output = self.embedding(
                target_batch)  # [batch, seq_len, emb_size=512]
            target_output = torch.moveaxis(
                target_output, (0, 1),
                (1, 0))  # [seq_len, batch, emb_size=512] for Transformer
        else:
            dummy_batch = self.tokens_to_id['_BOS_'] * torch.ones(
                (batch_size, 1), dtype=torch.long)  # [batch, seq_len=1]
            # Workaround: keep the tensor on the same device as the embedding.
            dummy_batch = dummy_batch.to(self.embedding.weight.device)

            target_output = self.embedding(
                dummy_batch)  # [batch, seq_len=1, emb_size=512]
            target_output = torch.moveaxis(
                target_output, (0, 1),
                (1, 0))  # [seq_len=1, batch, emb_size=512] for Transformer

            # Autoregressive decoding: append the last slice of the
            # transformer output to target_output until it has max_iter steps.
            max_iter = target_batch.shape[
                1] if target_list_of_tokens is not None else 30
            for _ in range(1, max_iter):
                predict_output = self.transformer(
                    output, target_output)  # [seq_len=i, batch, emb_size=512]
                last_slice_of_predict_output = predict_output[
                    -1:, :, :]  # [seq_len=1, batch, emb_size=512]
                target_output = torch.cat(
                    (target_output, last_slice_of_predict_output),
                    dim=0)  # [seq_len=i+1, batch, emb_size=512]
            # TODO: stop early once every sequence in the batch has produced
            # the '_EOS_' token.

        # generate predict_output one time
        predict_output = self.transformer(
            output, target_output)  # [seq_len, batch, emb_size=512]
        predict_output = torch.moveaxis(
            predict_output, (0, 1),
            (1, 0))  # [batch, seq_len, emb_size=512] for Linear
        predict_output = self.linear(
            predict_output)  # [batch, seq_len, classes_num]

        # for return
        probs = self.softmax(predict_output)  # [batch, seq_len, classes_num]
        predict_tensor_of_tokens_idx = torch.argmax(probs,
                                                    dim=-1)  # [batch, seq_len]
        predict_list_of_tokens = self.tensor_of_tokens_idx_to_list_of_tokens(
            predict_tensor_of_tokens_idx)

        if target_list_of_tokens is not None:
            # for compute loss
            predict_batch = torch.moveaxis(
                predict_output, (1, 2),
                (2, 1))  # [batch, classes_num, seq_len] for loss
            self.loss = self.compute_loss(predict_batch, target_batch)
        else:
            self.loss = None

        return predict_list_of_tokens
    def bilinear_interpolation(self, frame2: torch.Tensor, mask2: Optional[torch.Tensor], flow12: torch.Tensor,
                               flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """
        Backward-warp frame2 into view 1 by bilinear interpolation along flow12.

        For every pixel, frame2 is sampled at (pixel position + flow12) from the four
        surrounding integer pixels. The bilinear weights are multiplied by flow12_mask,
        and each sample is gated by mask2. The result is a weighted average over the
        samples that carry non-zero weight.

        :param frame2: (b, c, h, w)
        :param mask2: (b, 1, h, w): 1 for known, 0 for unknown. Optional (all ones if None)
        :param flow12: (b, 2, h, w); channel 0 is used as the x (width) displacement and
                       channel 1 as the y (height) displacement
        :param flow12_mask: (b, 1, h, w): 1 for valid flow, 0 for invalid flow. Optional (all ones if None)
        :param is_image: if true, output will be clipped to (-1,1) range, and pixels with
                         no valid sample are filled with -1 instead of 0
        :return: warped_frame1: (b, c, h, w)
                 mask1: (b, 1, h, w): 1 for known and 0 for unknown
        :raises AssertionError: if self.resolution is set and does not match frame2, or
                                if is_image and the result falls outside [-1.1, 1.1]
        """
        if self.resolution is not None:
            assert frame2.shape[2:4] == self.resolution
        b, c, h, w = frame2.shape
        if mask2 is None:
            mask2 = torch.ones(size=(b, 1, h, w)).to(frame2)
        if flow12_mask is None:
            flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
        # NOTE(review): create_grid is defined elsewhere; presumably it returns the
        # (x, y) pixel coordinates as (b,2,h,w). Confirm.
        grid = self.create_grid(b, h, w).to(frame2)
        trans_pos = flow12 + grid

        # Shift by +1 to index frame2 padded by one pixel (see F.pad below).
        # Out-of-range positions are clamped onto that zero-padded border.
        trans_pos_offset = trans_pos + 1
        trans_pos_floor = torch.floor(trans_pos_offset).long()
        trans_pos_ceil = torch.ceil(trans_pos_offset).long()
        trans_pos_offset = torch.stack([
            torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)
        ],
                                       dim=1)
        trans_pos_floor = torch.stack([
            torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)
        ],
                                      dim=1)
        trans_pos_ceil = torch.stack([
            torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)
        ],
                                     dim=1)

        # Bilinear proximity weights of the four integer neighbours:
        # nw = (floor y, floor x), sw = (ceil y, floor x),
        # ne = (floor y, ceil x),  se = (ceil y, ceil x).
        prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
        prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

        # Gate by flow validity and move to channels-last (b,h,w,1).
        weight_nw = torch.moveaxis(prox_weight_nw * flow12_mask, [0, 1, 2, 3],
                                   [0, 3, 1, 2])
        weight_sw = torch.moveaxis(prox_weight_sw * flow12_mask, [0, 1, 2, 3],
                                   [0, 3, 1, 2])
        weight_ne = torch.moveaxis(prox_weight_ne * flow12_mask, [0, 1, 2, 3],
                                   [0, 3, 1, 2])
        weight_se = torch.moveaxis(prox_weight_se * flow12_mask, [0, 1, 2, 3],
                                   [0, 3, 1, 2])

        # Zero-pad the last two dims by one pixel to match the +1 offset above.
        frame2_offset = F.pad(frame2, [1, 1, 1, 1])
        mask2_offset = F.pad(mask2, [1, 1, 1, 1])
        # (b,1,1) batch index broadcasts against the (b,h,w) row/column index maps.
        bi = torch.arange(b)[:, None, None]

        # The advanced indices are separated by the ':' slice, so the broadcast index
        # dims come first: each gather is channels-last, (b,h,w,c) / (b,h,w,1).
        f2_nw = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:,
                                                                            0]]
        f2_sw = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:,
                                                                           0]]
        f2_ne = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:,
                                                                           0]]
        f2_se = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:,
                                                                          0]]

        m2_nw = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:,
                                                                           0]]
        m2_sw = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:,
                                                                          0]]
        m2_ne = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:,
                                                                          0]]
        m2_se = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]

        # Numerator: weighted sum of known samples; denominator: total weight of
        # known samples.
        nr = weight_nw * f2_nw * m2_nw + weight_sw * f2_sw * m2_sw + \
             weight_ne * f2_ne * m2_ne + weight_se * f2_se * m2_se
        dr = weight_nw * m2_nw + weight_sw * m2_sw + weight_ne * m2_ne + weight_se * m2_se

        # Normalize where any weight exists; elsewhere use zero_value
        # (-1 for images, 0 otherwise).
        zero_value = -1 if is_image else 0
        zero_tensor = torch.tensor(zero_value,
                                   dtype=nr.dtype,
                                   device=nr.device)
        warped_frame1 = torch.where(dr > 0, nr / dr, zero_tensor)
        mask1 = (dr > 0).to(frame2)

        # Convert to channel first
        warped_frame1 = torch.moveaxis(warped_frame1, [0, 1, 2, 3],
                                       [0, 2, 3, 1])
        mask1 = torch.moveaxis(mask1, [0, 1, 2, 3], [0, 2, 3, 1])

        if is_image:
            assert warped_frame1.min() >= -1.1  # Allow for rounding errors
            assert warped_frame1.max() <= 1.1
            warped_frame1 = torch.clamp(warped_frame1, min=-1, max=1)
        return warped_frame1, mask1