# Example 1
    def forward(self, x, seq_len, max_num_frames):
        """Run the GRU over variable-length sequences and re-pad the output.

        Sequences are sorted by length (required by pack_padded_sequence),
        packed, run through ``self.gru``, unpacked, restored to the caller's
        original batch order, and zero-padded along the time axis up to
        ``max_num_frames``.
        """
        lengths_sorted, order = torch.sort(seq_len, dim=0, descending=True)
        # inverse permutation: maps sorted positions back to original ones
        _, restore = torch.sort(order, dim=0, descending=False)
        batch_dim = 0 if self.batch_first else 1

        packed = nn.utils.rnn.pack_padded_sequence(
            x.index_select(batch_dim, order),
            lengths_sorted.cpu().data.numpy(),
            batch_first=self.batch_first)

        gru_out, state = self.gru(packed)

        unpacked, _ = nn.utils.rnn.pad_packed_sequence(
            gru_out, batch_first=self.batch_first)
        out = unpacked.index_select(batch_dim, restore)

        # pad_packed_sequence only pads to the longest sequence in this
        # batch; extend the time axis to max_num_frames when it falls short.
        if self.batch_first:
            if out.shape[1] < max_num_frames:
                out = F.pad(out, [0, 0, 0, max_num_frames - out.shape[1]])
        else:
            if out.shape[0] < max_num_frames:
                out = F.pad(out,
                            [0, 0, 0, 0, 0, max_num_frames - out.shape[0]])

        # state = state.transpose(0, 1).contiguous().view(out.size(0), -1)
        return out
# Example 2
    def forward(self, frame_embed, query_embed, seglens, wordmasks):
        """Fuse frame and query embeddings with cross-modal co-attention.

        Args:
            frame_embed: tensor[B, num_seg, vhdim] segment-level video features
            query_embed: tensor[B, maxL, qidim] word-level query features
            seglens: tensor[B] number of valid segments per sample
            wordmasks: tensor[B, maxL] 1 for real words, 0 for padding
        Returns:
            dict with key 'frame_embed': fused multi-modal features after the RNN
        """
        B, num_seg, vhdim = frame_embed.size()
        maxL = query_embed.size(1)
        # Build co-attention mask: a per-sample diagonal of the word mask,
        # padded/cropped to [num_seg, maxL]. torch.diag_embed vectorizes the
        # original per-sample Python loop in one call.
        # Fix: allocate on the input's device instead of hard-coding 'cuda',
        # so CPU (and any other device) inference works.
        padded_masks = F.pad(wordmasks, (0, num_seg - maxL))
        b = torch.diag_embed(padded_masks).to(torch.int32)[:, :num_seg, :maxL]
        visual_query_mask = wordmasks.unsqueeze(1) | b
        query_visual_mask = wordmasks.unsqueeze(2).expand(B, maxL, num_seg)
        # CGA: each modality attends to the other, in both orders.
        v_co_trm, q_co_trm = self.visual_CoTRM, self.textual_CoTRM
        query_embed1 = q_co_trm(query_embed, frame_embed, query_visual_mask)
        frame_embed1 = v_co_trm(frame_embed, query_embed1, visual_query_mask)
        frame_embed2 = v_co_trm(frame_embed, query_embed, visual_query_mask)
        query_embed2 = q_co_trm(query_embed, frame_embed2, query_visual_mask)
        # Fusion: concat both passes; gate the textual stream into the visual one.
        query_embed = torch.cat([self.textual_upsample1(query_embed1),
                                 self.textual_upsample2(query_embed2)], dim=-1)
        frame_embed = torch.cat([frame_embed1, frame_embed2], dim=-1)
        mm_embed = frame_embed + self.self_gate(query_embed) * query_embed
        mm_embed = self.rnn(self.norm_mm(mm_embed), seglens, num_seg)

        return {
            'frame_embed': mm_embed
        }
# Example 3
def LFT(batch, window, y_stride: int = 1, x_stride: int = 1, padding: int = 1):
    """Local Fourier transform: FFT of window-smoothed local patches of *batch*."""
    local_windows = extract_local_windows(batch,
                                          window.shape[0],
                                          y_stride=y_stride,
                                          x_stride=x_stride,
                                          padding=padding)
    # pad the analysis window to match the padded patch size
    padded_window = F.pad(window, (padding,) * 4)
    smoothed = local_windows * padded_window
    return custom_fft(smoothed)
# Example 4
    def forward(self, x):
        """Segmentation head: densenet features + OC context -> classifier -> 8x upsample."""
        # densenet loses one pixel somewhere in downscaling; replicate-pad
        # right/bottom by one to compensate
        padded = F.pad(x, (0, 1, 0, 1), mode='replicate')
        features = self.conv(self.densenet(padded))
        context = self.oc_block(features)
        logits = self.classifier(torch.cat([features, context], dim=1))
        return F.interpolate(logits, scale_factor=8, mode='bilinear')
    def forward(self, input):
        """Conv2d forward pass with optional weight dropout and circular padding."""
        # apply dropout to the weights (not the activations) when configured
        weight = self.weight
        if self.dropout_fn is not None:
            weight = self.dropout_fn.forward(self.weight, self.training)

        if self.padding_mode != 'circular':
            return F.conv2d(input, weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

        # circular mode: wrap-pad the input manually (split asymmetrically
        # per side), then convolve with zero implicit padding
        pad_w = ((self.padding[1] + 1) // 2, self.padding[1] // 2)
        pad_h = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
        wrapped = F.pad(input, pad_w + pad_h, mode='circular')
        return F.conv2d(wrapped, weight, self.bias, self.stride,
                        _pair(0), self.dilation, self.groups)
# Example 6
def iLFT(stft2D_result,
         window,
         T,
         res_y: int,
         res_x: int,
         y_stride: int = 1,
         x_stride: int = 1,
         padding: int = 1,
         eps: float = 1e-8,
         is_inp_complex=True,
         move_window_according_T=True,
         channels=1):
    """Inverse local Fourier transform: overlap-add windows back into an image.

    Args:
        stft2D_result: per-window spectra (or spatial cells when
            ``is_inp_complex`` is False); batch size is taken from dim 0.
        window: 2D synthesis window; assumed square (only ``shape[0]`` is used).
        T: translation parameters applied as a spectral phase shift to the
            window when ``move_window_according_T`` is True.
        res_y, res_x: output image resolution.
        y_stride, x_stride: window hop sizes (padding math assumes they are
            equal — see comment below).
        padding: per-side padding added around each window cell.
        eps: stabilizer added to the weight accumulator before division.
        is_inp_complex: if True, run ``custom_ifft`` on the input first.
        move_window_according_T: if True, shift the synthesis window via a
            phase ramp in the Fourier domain.
        channels: number of image channels.

    Returns:
        (weighted_result, windowTracker): the normalized reconstruction and
        the accumulated squared-window weights used as the normalizer.
    """
    batchSize = stft2D_result.shape[0]
    cellSize = window.shape[0]

    ###this is the image padding, currently assumes x_stride == y_stride
    padSize = int(x_stride * ((cellSize - 1) // x_stride))

    # cell size after per-side padding, and the matching image pad
    cellSizeAdjust = cellSize + 2 * padding
    padSizeAdjust = int(x_stride * ((cellSizeAdjust - 1) // x_stride))

    ###the number of extracted cells along x and y axis
    ###this should be general enough to hold for different padSize
    num_windows_y = (res_y + 2 * padSizeAdjust -
                     cellSizeAdjust) // y_stride + 1
    num_windows_x = (res_x + 2 * padSizeAdjust -
                     cellSizeAdjust) // x_stride + 1
    num_windows_total = num_windows_y * num_windows_x

    # back to the spatial domain (unless the caller already did the ifft)
    if is_inp_complex:
        ifft_result = custom_ifft(stft2D_result)
    else:
        ifft_result = stft2D_result.clone()

    # (B, C, windows_y, windows_x, cell_y, cell_x)
    ifft_result = ifft_result.view(
        (batchSize, channels, num_windows_y, num_windows_x, cellSizeAdjust,
         cellSizeAdjust))

    # zero-pad the synthesis window to the adjusted cell size, broadcast
    # one copy per extracted window
    window_big = F.pad(window, (padding, padding, padding, padding), value=0.0)
    window_big = window_big.expand(batchSize, channels, num_windows_total, -1,
                                   -1)

    # optionally translate each window copy by T via a spectral phase shift
    if move_window_according_T:
        window_big_Complex = custom_fft(window_big)
        window_big_Complex = getPhaseAdd(window_big_Complex,
                                         T.view_as(window_big_Complex))
        window_big = custom_ifft(window_big_Complex)

    window_big = window_big.view(batchSize, channels, num_windows_y,
                                 num_windows_x, window_big.shape[3],
                                 window_big.shape[4])

    # weight each cell by its (possibly shifted) synthesis window
    ifft_result *= window_big

    # flatten to the (B, C*cell_y*cell_x, L) layout that fold expects
    ifft_result = ifft_result.reshape(
        batchSize, -1,
        channels * ifft_result.shape[4] * ifft_result.shape[5]).permute(
            0, 2, 1)
    test = fold(ifft_result, \
                res_y=res_y, res_x=res_x, y_stride=y_stride, x_stride=x_stride, cell_size=cellSizeAdjust,
                pad_size=padSizeAdjust)

    # accumulate squared-window weights the same way for normalization
    window_big = (window_big**2).reshape(
        batchSize, -1,
        channels * window_big.shape[4] * window_big.shape[5]).permute(0, 2, 1)
    windowTracker = fold(window_big, \
                         res_y=res_y, res_x=res_x, y_stride=y_stride, x_stride=x_stride, cell_size=cellSizeAdjust,
                         pad_size=padSizeAdjust)

    # eps guards against division by zero where no window covered a pixel
    windowTracker += eps
    weighted_result = test / windowTracker
    return weighted_result, windowTracker
# Example 7
def unfold3d(
    tensor: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int, int]],
    padding: Union[int, Tuple[int, int, int]] = 0,
    stride: Union[int, Tuple[int, int, int]] = 1,
):
    r"""
    Extract sliding local blocks from a batched 5D input tensor.

    :class:`torch.nn.Unfold` only supports 4D (batched image-like) inputs;
    this function implements the same operation for 5D volumes.

    Args:
        tensor: input of shape ``(B, C, D, H, W)``.
        kernel_size: size of the sliding blocks.
        padding: implicit zero padding added on both sides of each spatial dim.
        stride: stride of the sliding blocks in the spatial dimensions.

    Raises:
        ValueError: if ``tensor`` is not 5-dimensional.

    Example:
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1,B*C*D*H*W+1.).view(B,C,D,H,W)
        >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
        torch.Size([3, 32, 120])

    Returns:
        A tensor of shape ``(B, C * np.prod(kernel_size), L)`` where L is the
        number of output spatial positions; see :class:`torch.nn.Unfold`.
    """
    if tensor.dim() != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}"
        )

    # normalize scalar arguments to per-dimension triples
    as_triple = lambda v: (v, v, v) if isinstance(v, int) else v
    kernel_size = as_triple(kernel_size)
    padding = as_triple(padding)
    stride = as_triple(stride)

    batch_size, channels = tensor.shape[0], tensor.shape[1]

    # F.pad takes pads last-dim-first, so this yields shape
    # (B, C, D+2*padding[0], H+2*padding[1], W+2*padding[2])
    tensor = F.pad(tensor, (padding[2], padding[2], padding[1], padding[1],
                            padding[0], padding[0]))

    # slide over D, H, W in turn; result is
    # (B, C, D_out, H_out, W_out, kernel[0], kernel[1], kernel[2])
    for dim, (size, step) in enumerate(zip(kernel_size, stride), start=2):
        tensor = tensor.unfold(dimension=dim, size=size, step=step)

    # bring spatial positions before channels:
    # (B, D_out, H_out, W_out, C, kernel[0], kernel[1], kernel[2])
    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)

    # collapse to (B, C * prod(kernel_size), D_out * H_out * W_out)
    return tensor.reshape(batch_size, -1,
                          channels * np.prod(kernel_size)).transpose(1, 2)
# Example 8
def pad(inputs, position):
    """Thin wrapper around :func:`torch.nn.functional.pad`."""
    padded = F.pad(inputs, position)
    return padded
# Example 9
 def forward(self, x):
     """Apply self.conv along the leading axis: permute, pad, convolve, permute back."""
     # (T, B, C) -> (B, C, T) so the conv sees the leading axis last
     y = x.transpose(0, 1).transpose(1, 2)
     y = F.pad(y, pad=self.pad, value=0)
     y = self.conv(y)
     # restore the original (T, B, C) layout
     return y.transpose(1, 2).transpose(0, 1).contiguous()