def forward(self, x):
        if self.deterministic:
        assert not self.training, "Flag deterministic is True. This should not be used in training."
            return F.conv3d(x, self.post_weight_mu, self.bias_mu)
        batch_size = x.size()[0]
        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        mu_activations = F.conv3d(x, self.weight_mu, self.bias_mu, self.stride,
                                  self.padding, self.dilation, self.groups)

        var_weights = self.weight_logvar.exp()
        var_activations = F.conv3d(x.pow(2), var_weights, self.bias_logvar.exp(), self.stride,
                                   self.padding, self.dilation, self.groups)
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1, 1, 1, 1), self.z_logvar.repeat(batch_size, 1, 1, 1, 1),
                          sampling=self.training, cuda=self.cuda)
        z = z[:, :, None, None, None]

        return reparametrize(mu_activations * z, (var_activations * z.pow(2)).log(), sampling=self.training,
                             cuda=self.cuda)
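The snippet above depends on a `reparametrize` helper that is not shown here. A minimal sketch of such a helper, assuming the usual mu + sigma * eps sampling and that `cuda` is a plain boolean flag (torch.randn_like already matches the input's device, so the flag is kept only for signature compatibility):

import torch

def reparametrize(mu, logvar, sampling=True, cuda=False):
    # z = mu + sigma * eps with eps ~ N(0, I); returns the mean when not sampling
    if not sampling:
        return mu
    std = logvar.mul(0.5).exp()
    eps = torch.randn_like(std)
    return mu + eps * std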
Example #2
    def reverse(self, output):
        weight = self.calc_weight()

        return F.conv3d(
            output,
            weight.squeeze().inverse().unsqueeze(2).unsqueeze(3).unsqueeze(4))
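For context, `reverse` undoes an invertible 1x1x1 convolution (as in Glow-style flows): the CxC weight matrix is inverted and reshaped back into a conv kernel. A round-trip check under that assumption:

import torch
import torch.nn.functional as F

C = 4
w = torch.linalg.qr(torch.randn(C, C))[0].reshape(C, C, 1, 1, 1)  # orthogonal, hence invertible
x = torch.randn(2, C, 3, 3, 3)
y = F.conv3d(x, w)
x_rec = F.conv3d(y, w.squeeze().inverse().reshape(C, C, 1, 1, 1))
assert torch.allclose(x, x_rec, atol=1e-5)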
Example #3
    def forward(self,
                inputs,
                output_per_pixel=False,
                output_before_combine_slices=False,
                train_last_layers_only=False):
        res = []
        batch_size = inputs.shape[0]
        nb_input_slices = inputs.shape[1]

        x = inputs

        if not train_last_layers_only:
            x = x.view(batch_size * nb_input_slices, 1, inputs.shape[2],
                       inputs.shape[3])
            x = self.l1(x)
            x = self.bn1(x)
            # TODO: batch norm here may still help when cdf used
            x = torch.relu(x)
            x = self.maxpool(x)

            x = self.base_model.layer1(x)
            x = self.base_model.layer2(x)
            x = self.base_model.layer3(x)
            x = self.base_model.layer4(x)

            base_model_features = x.shape[1]
            x = x.view(batch_size, nb_input_slices, base_model_features,
                       x.shape[2], x.shape[3])  # BxSxCxHxW

        if output_before_combine_slices:
            return x

        x = x.permute((0, 2, 1, 3, 4))  # BxCxSxHxW
        # x = self.combine_conv(x)  # BxCx1xHxW
        slice_offset = (self.nb_input_slices - nb_input_slices) // 2
        x = F.conv3d(
            x, self.combine_conv.weight[:, :, slice_offset:slice_offset +
                                        nb_input_slices, :, :],
            self.combine_conv.bias)

        x = x.view(batch_size, self.combine_conv_features, x.shape[3],
                   x.shape[4])

        if output_per_pixel:
            res.append(
                F.conv2d(torch.cat([x, x], dim=1),
                         self.fc.weight[:, :, None, None], self.fc.bias))

        avg_pool = F.avg_pool2d(x, x.shape[2:])
        max_pool = F.max_pool2d(x, x.shape[2:])
        avg_max_pool = torch.cat((avg_pool, max_pool), 1)
        x = avg_max_pool.view(avg_max_pool.size(0), -1)

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, self.training)

        out = self.fc(x)

        if res:
            res.append(out)
            return res
        else:
            return out
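The F.conv3d call in the middle of this forward centre-crops the slice dimension of `combine_conv`'s kernel, so weights trained on `self.nb_input_slices` slices still work when fewer slices are passed in. A standalone sketch of that trick, with hypothetical sizes:

import torch
import torch.nn.functional as F

full_slices, used_slices = 8, 4
combine = torch.nn.Conv3d(16, 32, kernel_size=(full_slices, 1, 1))
x = torch.randn(2, 16, used_slices, 7, 7)  # BxCxSxHxW with fewer slices
offset = (full_slices - used_slices) // 2
y = F.conv3d(x, combine.weight[:, :, offset:offset + used_slices], combine.bias)
print(y.shape)  # torch.Size([2, 32, 1, 7, 7])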
Example #4
File: inference.py  Project: ywu40/NIC
def decode(bin_dir, rec_dir, model_dir, block_width, block_height):
    ############### retrieve head info ###############
    T = time.time()
    file_object = open(bin_dir, 'rb')

    head_len = struct.calcsize('2HB')
    bits = file_object.read(head_len)
    [H, W, model_index] = struct.unpack('2HB', bits)
    # print("File Info:",Head)
    # Split Main & Hyper bins
    C = 3
    out_img = np.zeros([H, W, C])
    H_offset = 0
    W_offset = 0
    Block_Num_in_Width = int(np.ceil(W / block_width))
    Block_Num_in_Height = int(np.ceil(H / block_height))

    M, N2 = 192, 128
    if model_index in (6, 7, 14, 15):
        M, N2 = 256, 192
    # latent channel counts follow the model configuration
    c_main, c_hyper = M, N2
    image_comp = model.Image_coding(3, M, N2, M, M // 2)
    context = Weighted_Gaussian(M)
    ######################### Load Model #########################
    image_comp.load_state_dict(
        torch.load(os.path.join(model_dir, models[model_index] + r'.pkl'),
                   map_location='cpu'))
    context.load_state_dict(
        torch.load(os.path.join(model_dir, models[model_index] + r'p.pkl'),
                   map_location='cpu'))
    if GPU:
        image_comp.cuda()
        context.cuda()

    for i in range(Block_Num_in_Height):
        for j in range(Block_Num_in_Width):

            Block_head_len = struct.calcsize('2H4h2I')
            bits = file_object.read(Block_head_len)
            [
                block_H, block_W, Min_Main, Max_Main, Min_V_HYPER, Max_V_HYPER,
                FileSizeMain, FileSizeHyper
            ] = struct.unpack('2H4h2I', bits)

            precise, tile = 16, 64.

            block_H_PAD = int(tile * np.ceil(block_H / tile))
            block_W_PAD = int(tile * np.ceil(block_W / tile))

            with open("main.bin", 'wb') as f:
                bits = file_object.read(FileSizeMain)
                f.write(bits)
            with open("hyper.bin", 'wb') as f:
                bits = file_object.read(FileSizeHyper)
                f.write(bits)

            ############### Hyper Decoder ###############
            # [Min_V - 0.5 , Max_V + 0.5]
            sample = np.arange(Min_V_HYPER, Max_V_HYPER + 1 + 1)
            sample = np.tile(sample, [c_hyper, 1, 1])
            lower = torch.sigmoid(
                image_comp.factorized_entropy_func._logits_cumulative(
                    torch.FloatTensor(sample) - 0.5, stop_gradient=False))
            cdf_h = lower.data.cpu().numpy() * (
                (1 << precise) -
                (Max_V_HYPER - Min_V_HYPER + 1))  # [N1, 1, Max - Min]
            cdf_h = cdf_h.astype(np.int64) + sample.astype(np.int64) - Min_V_HYPER
            T2 = time.time()
            AE.init_decoder("hyper.bin", Min_V_HYPER, Max_V_HYPER)
            Recons = []
            for i in range(c_hyper):
                for j in range(int(block_H_PAD * block_W_PAD / 64 / 64)):
                    # print(cdf_h[i,0,:])
                    Recons.append(AE.decode_cdf(cdf_h[i, 0, :].tolist()))
            # reshape Recons to y_hyper_q   [1, c_hyper, H_PAD/64, W_PAD/64]
            y_hyper_q = torch.reshape(
                torch.Tensor(Recons),
                [1, c_hyper,
                 int(block_H_PAD / 64),
                 int(block_W_PAD / 64)])

            ############### Main Decoder ###############
            hyper_dec = image_comp.p(image_comp.hyper_dec(y_hyper_q))
            h, w = int(block_H_PAD / 16), int(block_W_PAD / 16)
            sample = np.arange(Min_Main,
                               Max_Main + 1 + 1)  # [Min_V - 0.5 , Max_V + 0.5]

            sample = torch.FloatTensor(sample)

            p3d = (5, 5, 5, 5, 5, 5)
            y_main_q = torch.zeros(1, 1, c_main + 10, h + 10,
                                   w + 10)  # 8000x4000 -> 500*250
            AE.init_decoder("main.bin", Min_Main, Max_Main)
            hyper = torch.unsqueeze(context.conv3(hyper_dec), dim=1)

            #
            context.conv1.weight.data *= context.conv1.mask

            for i in range(c_main):
                T = time.time()
                for j in range(int(block_H_PAD / 16)):
                    for k in range(int(block_W_PAD / 16)):

                        x1 = F.conv3d(y_main_q[:, :, i:i + 12, j:j + 12,
                                               k:k + 12],
                                      weight=context.conv1.weight,
                                      bias=context.conv1.bias)  # [1,24,1,1,1]
                        params_prob = context.conv2(
                            torch.cat(
                                (x1, hyper[:, :, i:i + 2, j:j + 2, k:k + 2]),
                                dim=1))

                        # 3 gaussian
                        prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = params_prob[
                            0, :, 0, 0, 0]
                        # keep the weight  summation of prob == 1
                        probs = torch.stack([prob0, prob1, prob2], dim=-1)
                        probs = F.softmax(probs, dim=-1)

                        # process the scale value to positive non-zero
                        scale0 = torch.abs(scale0)
                        scale1 = torch.abs(scale1)
                        scale2 = torch.abs(scale2)
                        scale0[scale0 < 1e-6] = 1e-6
                        scale1[scale1 < 1e-6] = 1e-6
                        scale2[scale2 < 1e-6] = 1e-6
                        # 3 gaussian distributions
                        m0 = torch.distributions.normal.Normal(
                            mean0.view(1, 1).repeat(1,
                                                    Max_Main - Min_Main + 2),
                            scale0.view(1, 1).repeat(1,
                                                     Max_Main - Min_Main + 2))
                        m1 = torch.distributions.normal.Normal(
                            mean1.view(1, 1).repeat(1,
                                                    Max_Main - Min_Main + 2),
                            scale1.view(1, 1).repeat(1,
                                                     Max_Main - Min_Main + 2))
                        m2 = torch.distributions.normal.Normal(
                            mean2.view(1, 1).repeat(1,
                                                    Max_Main - Min_Main + 2),
                            scale2.view(1, 1).repeat(1,
                                                     Max_Main - Min_Main + 2))
                        lower0 = m0.cdf(sample - 0.5)
                        lower1 = m1.cdf(sample - 0.5)
                        lower2 = m2.cdf(sample - 0.5)  # [1,c,h,w,Max-Min+2]

                        lower = probs[0:1] * lower0 + probs[
                            1:2] * lower1 + probs[2:3] * lower2
                        cdf_m = lower.data.cpu().numpy() * (
                            (1 << precise) - (Max_Main - Min_Main + 1)
                        )  # [1, c, h, w ,Max-Min+1]
                        cdf_m = cdf_m.astype(np.int64) + \
                            sample.numpy().astype(np.int64) - Min_Main

                        pixs = AE.decode_cdf(cdf_m[0, :].tolist())
                        y_main_q[0, 0, i + 5, j + 5, k + 5] = pixs

                print("Decoding Channel (%d/%d), Time (s): %0.4f" %
                      (i, c_main, time.time() - T))
            del hyper, hyper_dec
            y_main_q = y_main_q[0, :, 5:-5, 5:-5, 5:-5]
            rec = image_comp.decoder(y_main_q)

            output_ = torch.clamp(rec, min=0., max=1.0)
            out = output_.data[0].cpu().numpy()
            out = out.transpose(1, 2, 0)
            out_img[H_offset:H_offset + block_H, W_offset:W_offset +
                    block_W, :] = out[:block_H, :block_W, :]
            W_offset += block_W
            if W_offset >= W:
                W_offset = 0
                H_offset += block_H
    out_img = np.round(out_img * 255.0)
    out_img = out_img.astype('uint8')
    img = Image.fromarray(out_img[:H, :W, :])
    img.save(rec_dir)
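For reference, the bitstream read above starts with a '2HB' file header (image H and W as unsigned shorts plus a one-byte model index), followed by a '2H4h2I' header per block. A hedged sketch of the matching encoder-side packing, with made-up values:

import struct

file_header = struct.pack('2HB', 512, 768, 6)   # H, W, model_index
block_header = struct.pack('2H4h2I',
                           512, 384,            # block_H, block_W
                           -12, 11, -4, 5,      # Min/Max of main and hyper latents
                           102400, 2048)        # FileSizeMain, FileSizeHyper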
Example #5
def lanczos3dfilter(f):
    return conv3d(f, lanczosk3d, padding=3)
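`lanczosk3d` is assumed to be a precomputed kernel; with padding=3 it would be 7x7x7, i.e. a Lanczos window with a = 3. The actual kernel is not shown, but one plausible construction is a separable, single-channel design sampled below the integer rate (sampling the Lanczos-3 window at integers collapses to a delta kernel):

import numpy as np
import torch

def lanczos(t, a=3):
    # continuous Lanczos window: sinc(t) * sinc(t / a) for |t| < a, else 0
    return np.where(np.abs(t) < a, np.sinc(t) * np.sinc(t / a), 0.0)

k1 = lanczos(np.arange(-3, 4) / 2.0)   # 7 taps at half-integer spacing
k1 /= k1.sum()
k3 = k1[:, None, None] * k1[None, :, None] * k1[None, None, :]
lanczosk3d = torch.as_tensor(k3, dtype=torch.float32)[None, None]  # (1, 1, 7, 7, 7)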
Example #6
 def forward(self, x):
     w = self.weight
     # standardise each output filter over its input-channel and spatial dims (all dims but 0)
     v, m = torch.var_mean(w, dim=[1, 2, 3, 4], keepdim=True, unbiased=False)
     w = (w - m) / torch.sqrt(v + 1e-5)
     return F.conv3d(x, w, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
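This is the weight-standardisation trick: each output filter is normalised to zero mean and unit variance before convolving. A self-contained wrapper showing how the method above would typically be packaged (the class name is illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class WSConv3d(nn.Conv3d):
    def forward(self, x):
        v, m = torch.var_mean(self.weight, dim=[1, 2, 3, 4], keepdim=True,
                              unbiased=False)
        w = (self.weight - m) / torch.sqrt(v + 1e-5)
        return F.conv3d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

y = WSConv3d(4, 8, kernel_size=3, padding=1)(torch.randn(2, 4, 8, 8, 8))
print(y.shape)  # torch.Size([2, 8, 8, 8, 8])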
Example #7
    def backward(ctx, grad_output):
        input, weights, bias = ctx.saved_tensors
        stride, padding, dilation, groups, alpha, beta, downsample = ctx.hparam

        pinput = torch.clamp(input, min=0)
        linput = torch.clamp(input, max=0)

        pweights = torch.clamp(weights, min=0)
        nweights = torch.clamp(weights, max=0)

        pout = F.conv3d(pinput, pweights, None, stride, padding, dilation,
                        groups)
        nout = F.conv3d(linput, nweights, None, stride, padding, dilation,
                        groups)
        sum_out = pout + nout
        sum_out[sum_out == 0] += epsilon

        norm_grad = grad_output / sum_out
        if not epsilon:
            norm_grad[sum_out == 0] = 0

        if downsample:
            norm_grad = F.interpolate(norm_grad,
                                      size=input.shape[-3:],
                                      mode='trilinear')

        agrad = torch.nn.grad.conv3d_input(input.shape,
                                           pweights,
                                           norm_grad,
                                           stride=1,
                                           padding=padding,
                                           groups=groups)
        agrad *= pinput
        bgrad = torch.nn.grad.conv3d_input(input.shape,
                                           nweights,
                                           norm_grad,
                                           stride=1,
                                           padding=padding,
                                           groups=groups)
        bgrad *= linput

        grad = agrad + bgrad

        if beta:
            cpout = F.conv3d(pinput, nweights, None, stride, padding, dilation,
                             groups)
            cnout = F.conv3d(linput, pweights, None, stride, padding, dilation,
                             groups)
            csum_out = cpout + cnout
            csum_out[csum_out == 0] += epsilon
            c_grad = grad_output / csum_out
            if not epsilon:
                c_grad[csum_out == 0] = 0

            c_agrad = torch.nn.grad.conv3d_input(input.shape,
                                                 nweights,
                                                 c_grad,
                                                 stride=stride,
                                                 padding=padding,
                                                 groups=groups)
            c_agrad *= pinput
            c_bgrad = torch.nn.grad.conv3d_input(input.shape,
                                                 pweights,
                                                 c_grad,
                                                 stride=stride,
                                                 padding=padding,
                                                 groups=groups)
            c_bgrad *= linput

            c_grad = c_agrad + c_bgrad

            grad = alpha * grad - beta * c_grad
        if all(ctx.needs_input_grad):
            weights.grad = torch.nn.grad.conv3d_weight(input,
                                                       weights.shape,
                                                       grad,
                                                       stride=stride,
                                                       padding=padding)
            if bias is not None:
                # conv3d_weight cannot produce a bias gradient; the standard
                # conv bias gradient is the upstream gradient summed over all
                # non-channel dimensions
                bias.grad = grad_output.sum(dim=(0, 2, 3, 4))
                return grad, weights.grad, bias.grad, None, None, None, None, None, None, None
            return grad, weights.grad, None, None, None, None, None, None, None, None
        return grad, None, None, None, None, None, None, None, None, None
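The backward above leans on torch.nn.grad.conv3d_input, which computes the gradient of F.conv3d with respect to its input for a given upstream gradient. A quick sanity check of that helper against autograd:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 6, 6, 6, requires_grad=True)
w = torch.randn(4, 2, 3, 3, 3)
y = F.conv3d(x, w, padding=1)
g = torch.randn_like(y)
y.backward(g)
gx = torch.nn.grad.conv3d_input(x.shape, w, g, padding=1)
assert torch.allclose(x.grad, gx, atol=1e-5)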
Example #8
 def forward(self, input):
     weight = self.activate_weight(self.weight, self.thresholds,
                                   self.quant_levels, self.stddev,
                                   self.training)
     return F.conv3d(input, weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
Example #9
 def forward(self, input):
     inShape = input.shape
     inPadded = self.pad(input.reshape((inShape[0], 1, 1, -1, inShape[-1])))
     output = F.conv3d(inPadded, self.weight) * self.Ts
     return output.reshape(inShape)
Example #10
def filter3D(input: torch.Tensor,
             kernel: torch.Tensor,
             border_type: str = 'replicate',
             normalized: bool = False) -> torch.Tensor:
    r"""Convolve a tensor with a 3d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input (torch.Tensor): the input tensor with shape of
          :math:`(B, C, D, H, W)`.
        kernel (torch.Tensor): the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kD, kH, kW)`  or :math:`(B, kD, kH, kW)`.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``,
          ``'replicate'`` or ``'circular'``. Default: ``'replicate'``.
        normalized (bool): If True, kernel will be L1 normalized.

    Return:
        torch.Tensor: the convolved tensor of same size and numbers of channels
        as the input with shape :math:`(B, C, D, H, W)`.

    Example:
        >>> input = torch.tensor([[[
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 5., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]]
        ... ]]])
        >>> kernel = torch.ones(1, 3, 3, 3)
        >>> filter3D(input, kernel)
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]]]]])
    """
    testing.check_is_tensor(input)
    testing.check_is_tensor(kernel)

    if not isinstance(border_type, str):
        raise TypeError("Input border_type is not string. Got {}".format(
            type(border_type)))

    if not len(input.shape) == 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))

    if not len(kernel.shape) == 4:
        raise ValueError(
            "Invalid kernel shape, we expect 1xkDxkHxkW or BxkDxkHxkW. Got: {}".format(
                kernel.shape))

    # prepare kernel
    b, c, d, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        bk, dk, hk, wk = kernel.shape
        tmp_kernel = normalize_kernel2d(tmp_kernel.view(
            bk, dk, hk * wk)).view_as(tmp_kernel)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)

    # pad the input tensor
    depth, height, width = tmp_kernel.shape[-3:]
    padding_shape: List[int] = compute_padding([depth, height, width])
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # kernel and input tensor reshape to align element-wise or batch-wise params
    tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
    input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3),
                               input_pad.size(-2), input_pad.size(-1))

    # convolve the tensor with the kernel.
    output = F.conv3d(input_pad,
                      tmp_kernel,
                      groups=tmp_kernel.size(0),
                      padding=0,
                      stride=1)
    return output.view(b, c, d, h, w)
Example #11
def cubify(voxels, thresh, device=None) -> Meshes:
    r"""
    Converts a voxel occupancy grid to a mesh by replacing each occupied voxel with a cube
    consisting of 12 faces and 8 vertices. Shared vertices are merged, and
    internal faces are removed.
    Args:
      voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
      thresh: A scalar threshold. If a voxel occupancy is larger than
          thresh, the voxel is considered occupied.
      device: The device of the output meshes; defaults to the device of voxels.
    Returns:
      meshes: A Meshes object of the corresponding meshes.
    """

    if device is None:
        device = voxels.device

    if len(voxels) == 0:
        return Meshes(verts=[], faces=[])

    N, D, H, W = voxels.size()
    # vertices corresponding to a unit cube: 8x3
    cube_verts = torch.tensor(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=torch.int64,
        device=device,
    )

    # faces corresponding to a unit cube: 12x3
    cube_faces = torch.tensor(
        [
            [0, 1, 2],
            [1, 3, 2],  # left face: 0, 1
            [2, 3, 6],
            [3, 7, 6],  # bottom face: 2, 3
            [0, 2, 6],
            [0, 6, 4],  # front face: 4, 5
            [0, 5, 1],
            [0, 4, 5],  # up face: 6, 7
            [6, 7, 5],
            [6, 5, 4],  # right face: 8, 9
            [1, 7, 3],
            [1, 5, 7],  # back face: 10, 11
        ],
        dtype=torch.int64,
        device=device,
    )

    wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
    wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
    wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)

    voxelt = voxels.ge(thresh).float()
    # N x 1 x D x H x W
    voxelt = voxelt.view(N, 1, D, H, W)

    # N x 1 x (D-1) x (H-1) x (W-1)
    voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
    voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
    voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()

    # 12 x N x 1 x D x H x W
    faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)

    # add left face
    faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
    faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
    # add bottom face
    faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
    faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
    # add front face
    faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
    faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
    # add up face
    faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
    faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
    # add right face
    faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
    faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
    # add back face
    faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
    faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z

    faces_idx *= voxelt

    # N x H x W x D x 12
    faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
    # (NHWD) x 12
    faces_idx = faces_idx.contiguous()
    faces_idx = faces_idx.view(-1, cube_faces.size(0))

    # boolean to linear index
    # NF x 2
    linind = torch.nonzero(faces_idx)
    # NF x 4
    nyxz = unravel_index(linind[:, 0], (N, H, W, D))

    # NF x 3: faces
    faces = torch.index_select(cube_faces, 0, linind[:, 1])

    grid_faces = []
    for d in range(cube_faces.size(1)):
        # NF x 3
        xyz = torch.index_select(cube_verts, 0, faces[:, d])
        permute_idx = torch.tensor([1, 0, 2], device=device)
        yxz = torch.index_select(xyz, 1, permute_idx)
        yxz += nyxz[:, 1:]
        # NF x 1
        temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
        grid_faces.append(temp)
    # NF x 3
    grid_faces = torch.stack(grid_faces, dim=1)

    y, x, z = torch.meshgrid(
        torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
    )
    y = y.to(device=device, dtype=torch.float32)
    y = y * 2.0 / (H - 1.0) - 1.0
    x = x.to(device=device, dtype=torch.float32)
    x = x * 2.0 / (W - 1.0) - 1.0
    z = z.to(device=device, dtype=torch.float32)
    z = z * 2.0 / (D - 1.0) - 1.0
    # ((H+1)(W+1)(D+1)) x 3
    grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)

    if len(nyxz) == 0:
        verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
        faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
        return Meshes(verts=verts_list, faces=faces_list)

    num_verts = grid_verts.size(0)
    grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
    idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)

    idleverts.scatter_(0, grid_faces.flatten(), 0)
    grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
    split_size = torch.bincount(nyxz[:, 0], minlength=N)
    faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))

    idleverts = idleverts.view(N, num_verts)
    idlenum = idleverts.cumsum(1)

    verts_list = [
        grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
        for n in range(N)
    ]
    faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]

    return Meshes(verts=verts_list, faces=faces_list)
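A hedged usage sketch, assuming the PyTorch3D helpers referenced above (Meshes, ravel_index, unravel_index) are importable:

import torch

voxels = torch.rand(2, 16, 16, 16)   # (N, D, H, W) occupancy probabilities
meshes = cubify(voxels, thresh=0.5)
verts_list = meshes.verts_list()     # one (V_n, 3) tensor per batch element
faces_list = meshes.faces_list()     # one (F_n, 3) tensor per batch element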
Example #12
def train(train_loader, encoder, decoder, refiner, optimizer, args):
    global i
    global num_rec
    global overall_loss
    global latent_loss
    global rec_loss
    global factor
    global avg
    global num_print
    j = 0
    loss_curr = 0
    latent_loss_curr = 0
    rec_loss_curr = 0
    avg_curr = 0
    num_img = 20

    for image, shape in train_loader:
        if image.shape[0] < num_img:
            break    

        for iteration in range(1): 
            j += 1 

            shape = shape.cuda()
            image = image.cuda()

            optimizer.zero_grad()
            
            latent = encoder(image)
            result = decoder(latent)
            result = refiner(result)

            img_res = 256 

            images = torch.cuda.FloatTensor(num_img * 2, 1, img_res, img_res)
            show_images = torch.cuda.FloatTensor(num_img, 1, img_res, img_res)
            
            loss = 0 

            rand = torch.randint(0, 24, (num_img,))
            # render every iteration (the image/SDF losses below use these);
            # only save a preview grid every 100 iterations
            for i in range(num_img):
                cam = rand[i]
                images[i,0,:,:], _ = differentiable_rendering(result[i,0,:,:,:], result.shape[-1], img_res, camera_list[cam])
                images[i+num_img,0,:,:], _ = differentiable_rendering(shape[i,0,:,:,:], shape.shape[-1], img_res, camera_list[cam])
                if j % 100 == 1 and i % 2 == 0:
                    show_images[int(i/2),0,:,:] = images[i,0,:,:]
                    show_images[int(i/2) + int(num_img / 2),0,:,:] = images[i+num_img,0,:,:]
            if j % 100 == 1:
                grid = make_grid(show_images, nrow=int(num_img/2))
                torchvision.utils.save_image(grid, "../result/" + args.category + "/train_" + str(num_rec) + "_" + str(j) + ".png", nrow=6, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

            obj_loss = 0

            # narrow band
            mask = torch.abs(result[0,0]) < 0.1
            mask = mask.float()

            # sdf loss
            image_loss, sdf_loss = loss_fn(images[:num_img][0,0], images[num_img:][0,0], result[0,0] * mask, 4/64., 64, 64, 64, 256, 256)
            obj_loss += sdf_loss / (64**3) * 0.02
 
            # Laplacian loss
            conv_input = (result[0,0] * mask).unsqueeze(0).unsqueeze(0)
            conv_filter = torch.cuda.FloatTensor([[[[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, -6, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]])
            Lp_loss = torch.sum(F.conv3d(conv_input, conv_filter) ** 2) / (64**3)
            obj_loss += Lp_loss * 0.02

            # image loss
            obj_loss += F.mse_loss(images[:num_img], images[num_img:]) * 15 * (256 * 256 / img_res / img_res)

            # backpropagate
            loss = obj_loss
            loss.backward()
            optimizer.step() 
Example #13
# Pad image for backwards difference in last slice
pad_slice = 2*pot[-1] - pot[-2]
pot = np.vstack((pot, pad_slice[np.newaxis,]))

# Pad image with zeroes on all other boundaries
pot = np.pad(pot, ((1,0),(1,1),(1,1)), 'constant')

# Create mask where conductive phase = 1
mask = (pot > 0).astype('float32')

# Kernel to shift image in +z direction
k_s = np.zeros((3,3,3), dtype='float32')
k_s[2,1,1] = 1

# Shift mask in +z direction
mask_s = conv3d(torch.as_tensor(np.expand_dims(mask, [0,1])), 
                torch.as_tensor(np.expand_dims(k_s, [0,1]))).numpy().squeeze()

# Forward difference kernel
k_fd = np.zeros_like(k_s)
k_fd[2,1,1] = -1   
k_fd[1,1,1] = 1

# Forward difference convolution
dpot = conv3d(torch.as_tensor(np.expand_dims(pot, [0,1])), 
              torch.as_tensor(np.expand_dims(k_fd, [0,1]))).numpy().squeeze()

# Current in Z calculation
mask = mask[1:-1,1:-1,1:-1]  # Remove padding on mask
currz_conv = dpot*mask*mask_s

#%% Calculating current using nested loops
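The cell header above announces a loop-based cross-check whose code is not included here. A hedged sketch of what it might look like, reproducing currz_conv element by element (pot is the padded array from above; the mask is recomputed untrimmed since the script overwrote it):

import numpy as np

mask_full = (pot > 0).astype('float32')   # untrimmed mask on the padded grid
currz_loop = np.zeros_like(currz_conv)
Z, Y, X = currz_conv.shape
for z in range(Z):
    for y in range(Y):
        for x in range(X):
            # the difference encoded by k_fd under cross-correlation
            dp = pot[z + 1, y + 1, x + 1] - pot[z + 2, y + 1, x + 1]
            currz_loop[z, y, x] = (dp * mask_full[z + 1, y + 1, x + 1]
                                      * mask_full[z + 2, y + 1, x + 1])

np.testing.assert_allclose(currz_loop, currz_conv, rtol=1e-5)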
Example #14
def computemotionenergy(img, spatempfilter, stride):
    '''
    Calculate the motion energy of an image sequence based on a set of Gabor filters.

    <img>: input image, a numpy array or tensor of shape (batch_size, c, T, H, W).
        Note that the input image has 'uint8' dtype and range (0, 256).
    <spatempfilter>:
        the object returned by the makemultiscalespatiotemporalfilters function
    <stride>: int, stride as number of pixels along the spatial domain

    Note:
        1. We typically set the stride to 1 in the temporal domain
    '''

    import torch
    from torch.nn.functional import conv3d
    from torch import Tensor, from_numpy  # to use conv3d function, it must be converted to tensor
    from numpy import array, transpose, newaxis, round

    # convert img to tensor
    is_tensor = False
    if isinstance(img, Tensor):
        is_tensor = True
    else:  # convert it as tensor
        img = from_numpy(img).type(torch.float)

    img = img / 256  # convert to (0~1)
    imgSize, imgLen = img.shape[3], img.shape[2]  # assume square image

    if isinstance(stride, int):
        stride = [stride]

    gbr, sd, tf = spatempfilter['gabor'], spatempfilter['sd'], spatempfilter[
        'TF']

    filteredimg = []
    for igbr, vgbr in enumerate(gbr):  # loop each spatial and temporal scale
        # reformat the filter
        gbrtmp = vgbr
        gbrtmp = transpose(gbrtmp, [3, 2, 0, 1])  # (out_channels, kT, kH, kW)
        gbrtmp = gbrtmp[:, newaxis, :, :, :]  # (out_channels, 1, kT, kH, kW)
        gbrtmp = from_numpy(gbrtmp).type(torch.float)  # convert it to torch
        ksize = gbrtmp.shape[3]  # assume square kernel
        tsize = gbrtmp.shape[2]
        nKernel = gbrtmp.shape[0]
        # now gbrtmp is a [outchannel, group, T, H, W], as standard input for conv3d

        # calculate stride and padding
        stride_tmp = int(round(stride[igbr]))
        spadding = int(ksize / 2 - imgSize % stride_tmp / 2)
        tpadding = int((tsize - 1) / 2)

        # convolve
        # output a is a [batch_size, Cout, Tout, Hout, Wout]
        a = conv3d(img,
                   gbrtmp,
                   stride=(1, stride_tmp, stride_tmp),
                   padding=(tpadding, spadding, spadding))

        #a = torch.squeeze(a)
        # channels 0:Cout/2 and Cout/2: hold the two phases of each filter

        # combine each quadrature pair of simple cells into a complex cell
        # (torch.sqrt keeps the result a tensor; numpy's sqrt would convert it)
        a = torch.sqrt(a[:, :int(nKernel / 2), :, :, :]**2 +
                       a[:, int(nKernel / 2):, :, :, :]**2)

        if is_tensor:
            filteredimg.append(a.view(a.shape[0], a.shape[1], a.shape[2], -1))
        else:
            a = array(a)  # convert it back to numpy array
            filteredimg.append(
                a.reshape(a.shape[0], a.shape[1], a.shape[2], -1))
    # flatten it as a vector
    return filteredimg
Example #15
def ssd(kpts_fixed,
        feat_fixed,
        feat_moving,
        orig_shape,
        disp_radius=16,
        disp_step=2,
        patch_radius=2,
        alpha=1.5,
        unroll_factor=50):
    _, N, _ = kpts_fixed.shape
    device = kpts_fixed.device
    D, H, W = orig_shape
    C = feat_fixed.shape[1]
    dtype = feat_fixed.dtype

    patch_step = disp_step  # same stride necessary for fast implementation
    patch = torch.stack(
        torch.meshgrid(
            torch.arange(0, 2 * patch_radius + 1, patch_step),
            torch.arange(0, 2 * patch_radius + 1, patch_step),
            torch.arange(0, 2 * patch_radius + 1, patch_step))).permute(
                1, 2, 3, 0).contiguous().view(1, 1, -1, 1,
                                              3).float() - patch_radius
    patch = (patch.flip(-1) * 2 /
             (torch.tensor([W, H, D]) - 1)).to(dtype).to(device)

    patch_width = round(patch.shape[2]**(1.0 / 3))

    if patch_width % 2 == 0:
        pad = [(patch_width - 1) // 2, (patch_width - 1) // 2 + 1]
    else:
        pad = [(patch_width - 1) // 2, (patch_width - 1) // 2]

    disp = torch.stack(
        torch.meshgrid(
            torch.arange(-disp_step * (disp_radius + ((pad[0] + pad[1]) / 2)),
                         (disp_step * (disp_radius +
                                       ((pad[0] + pad[1]) / 2))) + 1,
                         disp_step),
            torch.arange(-disp_step * (disp_radius + ((pad[0] + pad[1]) / 2)),
                         (disp_step * (disp_radius +
                                       ((pad[0] + pad[1]) / 2))) + 1,
                         disp_step),
            torch.arange(
                -disp_step * (disp_radius + ((pad[0] + pad[1]) / 2)),
                (disp_step * (disp_radius + ((pad[0] + pad[1]) / 2))) + 1,
                disp_step))).permute(1, 2, 3,
                                     0).contiguous().view(1, 1, -1, 1,
                                                          3).float()
    disp = (disp.flip(-1) * 2 /
            (torch.tensor([W, H, D]) - 1)).to(dtype).to(device)

    disp_width = disp_radius * 2 + 1
    ssd = torch.zeros(1, N, disp_width**3).to(dtype).to(device)
    split = np.array_split(np.arange(N), unroll_factor)
    for i in range(unroll_factor):
        feat_fixed_patch = F.grid_sample(
            feat_fixed,
            kpts_fixed[:, split[i], :].view(1, -1, 1, 1, 3).to(dtype) + patch,
            padding_mode='border',
            align_corners=True)
        feat_moving_disp = F.grid_sample(
            feat_moving,
            kpts_fixed[:, split[i], :].view(1, -1, 1, 1, 3).to(dtype) + disp,
            padding_mode='border',
            align_corners=True)
        corr = F.conv3d(
            feat_moving_disp.view(1, -1, disp_width + pad[0] + pad[1],
                                  disp_width + pad[0] + pad[1],
                                  disp_width + pad[0] + pad[1]),
            feat_fixed_patch.view(-1, 1, patch_width, patch_width,
                                  patch_width),
            groups=C * split[i].shape[0]).view(C, split[i].shape[0], -1)
        patch_sum = (feat_fixed_patch**2).squeeze(0).squeeze(3).sum(
            dim=2, keepdims=True)
        disp_sum = (patch_width**3) * F.avg_pool3d(
            (feat_moving_disp**2).view(C, -1, disp_width + pad[0] + pad[1],
                                       disp_width + pad[0] + pad[1],
                                       disp_width + pad[0] + pad[1]),
            patch_width,
            stride=1).view(C, split[i].shape[0], -1)
        ssd[0, split[i], :] = ((-2 * corr + patch_sum + disp_sum)).sum(0)

    ssd *= (alpha / (patch_width**3))

    return ssd
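Why `-2 * corr + patch_sum + disp_sum` is an SSD: for a fixed patch f and each displaced window m, sum((f - m)^2) = sum(f^2) - 2 * sum(f * m) + sum(m^2). The grouped F.conv3d evaluates the cross term sum(f * m) for every displacement at once, patch_sum supplies sum(f^2), and the avg_pool3d term supplies sum(m^2) per displacement.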
Example #16
def conv_soft_argmax3d(
    input: torch.Tensor,
    kernel_size: Tuple[int, int, int] = (3, 3, 3),
    stride: Tuple[int, int, int] = (1, 1, 1),
    padding: Tuple[int, int, int] = (1, 1, 1),
    temperature: Union[torch.Tensor, float] = torch.tensor(1.0),
    normalized_coordinates: bool = False,
    eps: float = 1e-8,
    output_value: bool = True,
    strict_maxima_bonus: float = 0.0
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Function that computes the convolutional spatial Soft-Argmax 3D over the windows
    of a given input heatmap. Function has two outputs: argmax coordinates and the softmaxpooled heatmap values
    themselves.
    On each window, the function computes:

    .. math::
        ijk(X) = \frac{\sum_{(i,j,k) \in X} (i,j,k) \cdot \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}

    .. math::
        val(X) = \frac{\sum_{(i,j,k) \in X} x_{ijk} \cdot \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}

    where :math:`T` is the temperature.

    Args:
        kernel_size (Tuple[int,int,int]):  size of the window
        stride (Tuple[int,int,int]): stride of the window.
        padding (Tuple[int,int,int]): input zero padding
        temperature (torch.Tensor): factor to apply to input. Default is 1.
        normalized_coordinates (bool): whether to return the coordinates normalized in the range of [-1, 1]. Otherwise,
                                       it will return the coordinates in the range of the input shape. Default is False.
        eps (float): small value to avoid zero division. Default is 1e-8.
        output_value (bool): if True, the softmax-pooled heatmap values are returned as well;
                             if False, only the coordinates.
        strict_maxima_bonus (float): pixels which are strict maxima will score (1 + strict_maxima_bonus) * value.
                                     This is needed to mimic the behavior of strict NMS in classic local features.
    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

         .. math::
             D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
             (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

         .. math::
             H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
             (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

         .. math::
             W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
             (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(input)))

    if not len(input.shape) == 5:
        raise ValueError(
            "Invalid input shape, we expect BxCxDxHxW. Got: {}".format(
                input.shape))

    if temperature <= 0:
        raise ValueError(
            "Temperature should be positive float or tensor. Got: {}".format(
                temperature))

    b, c, d, h, w = input.shape
    kx, ky, kz = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype
    input = input.view(b * c, 1, d, h, w)

    center_kernel: torch.Tensor = _get_center_kernel3d(kx, ky, kz,
                                                       device).to(dtype)
    window_kernel: torch.Tensor = _get_window_grid_kernel3d(
        kx, ky, kz, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))

    # max is detached to prevent undesired backprop loops in the graph
    x_exp = ((input - x_max.detach()) / temperature).exp()

    pool_coef: float = float(kx * ky * kz)

    # softmax denominator
    den = pool_coef * F.avg_pool3d(
        x_exp.view_as(input), kernel_size, stride=stride,
        padding=padding) + eps

    # We need to output also coordinates
    # Pooled window center coordinates
    grid_global: torch.Tensor = create_meshgrid3d(
        d, h, w, False, device=device).to(dtype).permute(0, 4, 1, 2, 3)

    grid_global_pooled = F.conv3d(grid_global,
                                  center_kernel,
                                  stride=stride,
                                  padding=padding)

    # Coordinates of maxima residual to window center
    # prepare kernel
    coords_max: torch.Tensor = F.conv3d(x_exp,
                                        window_kernel,
                                        stride=stride,
                                        padding=padding)

    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(
            coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.view(b, c, 3, coords_max.size(2),
                                 coords_max.size(3), coords_max.size(4))

    if not output_value:
        return coords_max
    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp.view(input.size()) * input,
                                             kernel_size,
                                             stride=stride,
                                             padding=padding) / den
    if strict_maxima_bonus > 0:
        in_levels: int = input.size(2)
        out_levels: int = x_softmaxpool.size(2)
        skip_levels: int = (in_levels - out_levels) // 2
        strict_maxima: torch.Tensor = F.avg_pool3d(
            kornia.feature.nms3d(input, kernel_size), 1, stride, 0)
        strict_maxima = strict_maxima[:, :,
                                      skip_levels:out_levels - skip_levels]
        x_softmaxpool *= 1.0 + strict_maxima_bonus * strict_maxima
    x_softmaxpool = x_softmaxpool.view(b, c, x_softmaxpool.size(2),
                                       x_softmaxpool.size(3),
                                       x_softmaxpool.size(4))
    return coords_max, x_softmaxpool
Example #17
 def forward(self, input):
     N, C, H, W, Ns = input.shape
     inPadded = self.pad(input.reshape((N, 1, 1, -1, Ns)))
     output = F.conv3d(inPadded, self.weight) * self.Ts
     return output.reshape((N, -1, H, W, Ns))
Example #18
 def inverse(self, x):
     if self.learnable:
         return invertible_downsampling.apply(self.kernel_matrix, x)
     else:
         return F.conv3d(x, self.kernel, stride=2, groups=self.input_channels//8)
Example #19
 def forward(self, input):
     return F.conv3d(input, self.transform_weight(), self.bias, self.stride, self.padding, self.dilation)
Example #20
 def forward(self, x):
     return F.conv3d(x, self.filter, bias=None)
Example #21
 def apply_srm_kernel(self, input_spikes, srm):
     return F.conv3d(input_spikes, srm, padding=(0, 0, int(srm.shape[4]/2))) * self.net_params['t_s']
Example #22
def avgfilter(f):
    return conv3d(f, avgk, padding=1)
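Like lanczosk3d earlier, `avgk` is assumed precomputed; with padding=1 it would be a 3x3x3 box filter, e.g.:

import torch

avgk = torch.full((1, 1, 3, 3, 3), 1.0 / 27)  # uniform 3x3x3 averaging kernel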
Example #23
 def apply_weights(activations, weights):
     applied = F.conv3d(activations, weights)
     return applied
Example #24
    def forward(self, x):
        """
        Convolution forward function.
        Divides the kernel into three parts along the output channels based on acs_kernel_split,
        and conducts the convolution in three directions separately. Bias is added at the end.
        """

        B, C_in, *input_shape = x.shape
        conv3D_output_shape = (self.conv3D_output_shape_f(0, input_shape),
                               self.conv3D_output_shape_f(1, input_shape),
                               self.conv3D_output_shape_f(2, input_shape))

        weight_a = self.weight[0:self.acs_kernel_split[0]].unsqueeze(2)
        weight_c = self.weight[self.acs_kernel_split[0]:(
            self.acs_kernel_split[0] + self.acs_kernel_split[1])].unsqueeze(3)
        weight_s = self.weight[(self.acs_kernel_split[0] +
                                self.acs_kernel_split[1]):].unsqueeze(4)
        f_out = []
        if self.acs_kernel_split[0] > 0:
            a = F.conv3d(
                x if conv3D_output_shape[0] == input_shape[0]
                or 2 * conv3D_output_shape[0] == input_shape[0] else F.pad(
                    x,
                    (0, 0, 0, 0, self.padding[0], self.padding[0]), 'constant',
                    0)[:, :,
                       self.kernel_size[0] // 2:self.kernel_size[0] // 2 +
                       (conv3D_output_shape[0] - 1) * self.stride[0] +
                       1, :, :],
                weight=weight_a,
                bias=None,
                stride=self.stride,
                padding=(0, self.padding[1], self.padding[2]),
                dilation=self.dilation,
                groups=self.groups)
            f_out.append(a)
        if self.acs_kernel_split[1] > 0:
            c = F.conv3d(
                x if conv3D_output_shape[1] == input_shape[1]
                or 2 * conv3D_output_shape[1] == input_shape[1] else F.pad(
                    x, (0, 0, self.padding[1], self.padding[1]), 'constant',
                    0)[:, :, :,
                       self.kernel_size[1] // 2:self.kernel_size[1] // 2 +
                       self.stride[1] * (conv3D_output_shape[1] - 1) + 1, :],
                weight=weight_c,
                bias=None,
                stride=self.stride,
                padding=(self.padding[0], 0, self.padding[2]),
                dilation=self.dilation,
                groups=self.groups)
            f_out.append(c)
        if self.acs_kernel_split[2] > 0:
            s = F.conv3d(
                x if conv3D_output_shape[2] == input_shape[2]
                or 2 * conv3D_output_shape[2] == input_shape[2] else F.pad(
                    x, (self.padding[2], self.padding[2]), 'constant',
                    0)[:, :, :, :,
                       self.kernel_size[2] // 2:self.kernel_size[2] // 2 +
                       self.stride[2] * (conv3D_output_shape[2] - 1) + 1],
                weight=weight_s,
                bias=None,
                stride=self.stride,
                padding=(self.padding[0], self.padding[1], 0),
                dilation=self.dilation,
                groups=self.groups)
            f_out.append(s)
        f = torch.cat(f_out, dim=1)
        if self.bias is not None:
            f += self.bias.view(1, self.out_channels, 1, 1, 1)
        return f
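For intuition: each ACS branch reuses a 2D-style kernel as a 3D kernel that is flat along one axis. The unsqueeze calls above turn a (C_out, C_in, kH, kW) weight into (C_out, C_in, 1, kH, kW), (C_out, C_in, kH, 1, kW), or (C_out, C_in, kH, kW, 1). A minimal sketch of one such branch, with hypothetical sizes:

import torch
import torch.nn.functional as F

w2d = torch.randn(8, 4, 3, 3)                         # a 2D (e.g. pretrained) kernel
x = torch.randn(1, 4, 10, 10, 10)
a = F.conv3d(x, w2d.unsqueeze(2), padding=(0, 1, 1))  # axial branch: kernel (8, 4, 1, 3, 3)
print(a.shape)                                        # torch.Size([1, 8, 10, 10, 10])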
Example #25
 def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):
     subbands = (LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)
     assert all(s.dim() == 5 for s in subbands)
     assert all(s.size(0) == LLL.size(0) for s in subbands)
     assert all(s.size(1) == self.in_channels for s in subbands)
     filters = (self.filter_lll, self.filter_llh, self.filter_lhl,
                self.filter_lhh, self.filter_hll, self.filter_hlh,
                self.filter_hhl, self.filter_hhh)
     out = 0
     for band, filt in zip(subbands, filters):
         # upsample the subband, pad, then apply its synthesis filter
         up = F.pad(F.conv_transpose3d(band,
                                       self.up_filter,
                                       stride=self.stride,
                                       groups=self.in_channels),
                    pad=self.pad_sizes,
                    mode=self.pad_type)
         out = out + F.conv3d(up, filt, stride=1, groups=self.groups)
     return out
Example #26
 def forward(self, input):
     return F.conv3d(input, self.weight, self.bias, self.stride,
                     self.padding, self.dilation, self.groups)
Example #27
 def forward(self, x):
     return F.conv3d(x, self.W_(), self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
Example #28
 def forward(self, x):
     # note: won't work when self.padding_mode != 'zeros'
     return F.conv3d(x, self._get_augmented_weight(), self.bias,
                     self.stride, self.padding, self.dilation, self.groups)
Example #29
 def reconstruct(H, W, Z):
     pad_size = (W.shape[2] - 1, W.shape[3] - 1, W.shape[4] - 1)
     out = F.conv3d(H,
                    W.flip((2, 3, 4)) * Z.view(-1, 1, 1, 1),
                    padding=pad_size)
     return out
Example #30
def acs_conv_f(x, weight, bias, kernel_size, dilation, padding, stride, groups,
               out_channels, acs_kernel_split):
    B, C_in, *input_shape = x.shape
    C_out = weight.shape[0]
    assert groups == 1 or groups == C_in == C_out, "only support standard or depthwise conv"

    conv3D_output_shape = (conv3D_output_shape_f(0, input_shape, kernel_size,
                                                 dilation, padding, stride),
                           conv3D_output_shape_f(1, input_shape, kernel_size,
                                                 dilation, padding, stride),
                           conv3D_output_shape_f(2, input_shape, kernel_size,
                                                 dilation, padding, stride))

    weight_a = weight[0:acs_kernel_split[0]].unsqueeze(2)
    weight_c = weight[acs_kernel_split[0]:(acs_kernel_split[0] +
                                           acs_kernel_split[1])].unsqueeze(3)
    weight_s = weight[(acs_kernel_split[0] +
                       acs_kernel_split[1]):].unsqueeze(4)
    if groups == C_in == C_out:
        # depth-wise
        x_a = x[:, 0:acs_kernel_split[0]]
        x_c = x[:, acs_kernel_split[0]:(acs_kernel_split[0] +
                                        acs_kernel_split[1])]
        x_s = x[:, (acs_kernel_split[0] + acs_kernel_split[1]):]
        group_a = acs_kernel_split[0]
        group_c = acs_kernel_split[1]
        group_s = acs_kernel_split[2]
    else:
        # groups=1
        x_a = x_c = x_s = x
        group_a = group_c = group_s = 1

    f_out = []
    if acs_kernel_split[0] > 0:
        a = F.conv3d(
            x_a if conv3D_output_shape[0] == input_shape[0]
            or 2 * conv3D_output_shape[0] == input_shape[0] else F.pad(
                x_a, (0, 0, 0, 0, padding[0], padding[0]), 'constant',
                0)[:, :, kernel_size[0] // 2:kernel_size[0] // 2 +
                   (conv3D_output_shape[0] - 1) * stride[0] + 1, :, :],
            weight=weight_a,
            bias=None,
            stride=stride,
            padding=(0, padding[1], padding[2]),
            dilation=dilation,
            groups=group_a)
        f_out.append(a)
    if acs_kernel_split[1] > 0:
        c = F.conv3d(x_c if conv3D_output_shape[1] == input_shape[1]
                     or 2 * conv3D_output_shape[1] == input_shape[1] else
                     F.pad(x_c, (0, 0, padding[1], padding[1]), 'constant',
                           0)[:, :, :,
                              kernel_size[1] // 2:kernel_size[1] // 2 +
                              stride[1] * (conv3D_output_shape[1] - 1) + 1, :],
                     weight=weight_c,
                     bias=None,
                     stride=stride,
                     padding=(padding[0], 0, padding[2]),
                     dilation=dilation,
                     groups=group_c)
        f_out.append(c)
    if acs_kernel_split[2] > 0:
        s = F.conv3d(x_s if conv3D_output_shape[2] == input_shape[2]
                     or 2 * conv3D_output_shape[2] == input_shape[2] else
                     F.pad(x_s, (padding[2], padding[2]), 'constant',
                           0)[:, :, :, :,
                              kernel_size[2] // 2:kernel_size[2] // 2 +
                              stride[2] * (conv3D_output_shape[2] - 1) + 1],
                     weight=weight_s,
                     bias=None,
                     stride=stride,
                     padding=(padding[0], padding[1], 0),
                     dilation=dilation,
                     groups=group_s)
        f_out.append(s)
    f = torch.cat(f_out, dim=1)

    if bias is not None:
        f += bias.view(1, out_channels, 1, 1, 1)

    return f
Example #31
 def conv3d(self, *args, **kargs):
     return F.conv3d(self, *args, **kargs)