Ejemplo n.º 1
0
    def forward(ctx,
                in_feat,
                mode=GlobalPoolingMode.AUTO,
                in_coords_key=None,
                glob_coords_key=None,
                coords_manager=None):
        """Normalize features to zero mean / unit variance per global pool.

        The backend ``GlobalPoolingForward`` kernel averages the input
        features into ``glob_coords_key``. That mean is subtracted via
        ``BroadcastForward``. The variance is then the pooled average of the
        squared centered features, and the centered features are scaled by
        ``1 / sqrt(variance + 1e-8)``.

        Args:
            ctx: autograd context. Receives ``mode``, both coordinate keys,
                the coordinate manager, and the tensors ``backward`` needs.
            in_feat: input feature tensor.
            mode: ``GlobalPoolingMode``. Its ``.value`` is forwarded to the
                backend.
            in_coords_key: key of the input coordinates.
            glob_coords_key: key of the pooled coordinates. When ``None``, a
                new ``CoordsKey`` with the input's dimension is created.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The normalized features, as returned by ``BroadcastForward``.
        """
        assert isinstance(mode, GlobalPoolingMode), \
            f"Mode must be an instance of GlobalPoolingMode, {mode}"
        if glob_coords_key is None:
            glob_coords_key = CoordsKey(in_coords_key.D)

        gpool_forward = getattr(MEB,
                                'GlobalPoolingForward' + get_postfix(in_feat))
        broadcast_forward = getattr(MEB,
                                    'BroadcastForward' + get_postfix(in_feat))
        add = operation_type_to_int(OperationType.ADDITION)
        multiply = operation_type_to_int(OperationType.MULTIPLICATION)

        cpp_in_coords_key = in_coords_key.CPPCoordsKey
        cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
        cpp_coords_manager = coords_manager.CPPCoordsManager

        # This backend variant allocates and returns its outputs, so no
        # placeholder tensors are pre-allocated. The per-pool counts it
        # returns are not needed here.
        mean, _ = gpool_forward(in_feat, cpp_in_coords_key,
                                cpp_glob_coords_key, cpp_coords_manager, True,
                                mode.value)
        # X - \mu
        centered_feat = broadcast_forward(in_feat, -mean, add,
                                          cpp_in_coords_key,
                                          cpp_glob_coords_key,
                                          cpp_coords_manager)

        # Variance = 1/N \sum (X - \mu) ** 2
        variance, _ = gpool_forward(centered_feat**2, cpp_in_coords_key,
                                    cpp_glob_coords_key, cpp_coords_manager,
                                    True, mode.value)

        # norm_feat = (X - \mu) / \sigma ; 1e-8 guards against zero variance.
        inv_std = 1 / (variance + 1e-8).sqrt()
        norm_feat = broadcast_forward(centered_feat, inv_std, multiply,
                                      cpp_in_coords_key, cpp_glob_coords_key,
                                      cpp_coords_manager)

        ctx.mode = mode
        ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
        ctx.coords_manager = coords_manager
        # For GPU tensors, must use save_for_backward.
        ctx.save_for_backward(inv_std, norm_feat)
        return norm_feat
Ejemplo n.º 2
0
    def forward(ctx,
                in_feat,
                batch_size=0,
                in_coords_key=None,
                glob_coords_key=None,
                coords_manager=None):
        """Normalize features to zero mean / unit variance per global pool.

        This targets the backend variant whose kernels write into
        caller-allocated output tensors and take the dimension ``D`` as their
        first argument. ``batch_size`` is forwarded to every pooling call.

        Steps: average-pool to get the mean, broadcast-subtract it,
        average-pool the squared centered features to get the variance, and
        broadcast-multiply by ``1 / sqrt(variance + 1e-8)``.

        Returns:
            The normalized feature tensor. ``1 / std`` and the output are
            saved for ``backward``.
        """
        if glob_coords_key is None:
            glob_coords_key = CoordsKey(in_coords_key.D)

        dim = in_coords_key.D
        suffix = get_postfix(in_feat)
        pool_fn = getattr(MEB, 'GlobalPoolingForward' + suffix)
        bcast_fn = getattr(MEB, 'BroadcastForward' + suffix)
        op_add = operation_type_to_int(OperationType.ADDITION)
        op_mul = operation_type_to_int(OperationType.MULTIPLICATION)
        key_args = (in_coords_key.CPPCoordsKey, glob_coords_key.CPPCoordsKey,
                    coords_manager.CPPCoordsManager)

        def pool(source, target, counts):
            # Averaging global pool of ``source``, written into ``target``.
            pool_fn(dim, source, target, counts, *key_args, batch_size, True)

        def broadcast(source, global_feat, op):
            # Apply ``op`` between each row of ``source`` and its pooled row.
            target = in_feat.new()
            bcast_fn(dim, source, global_feat, target, op, *key_args)
            return target

        mu = in_feat.new()
        counts = in_feat.new()
        pool(in_feat, mu, counts)

        centered = broadcast(in_feat, -mu, op_add)

        var = in_feat.new()
        pool(centered**2, var, counts)

        inv_sigma = 1 / (var + 1e-8).sqrt()
        normalized = broadcast(centered, inv_sigma, op_mul)

        ctx.batch_size = batch_size
        ctx.in_coords_key = in_coords_key
        ctx.glob_coords_key = glob_coords_key
        ctx.coords_manager = coords_manager
        # GPU tensors have to go through save_for_backward.
        ctx.save_for_backward(inv_sigma, normalized)
        return normalized
Ejemplo n.º 3
0
    def backward(ctx, out_grad):
        """Backpropagate through the global normalization.

        With ``y`` the normalized output and ``inv_std`` the inverse standard
        deviation (both saved by ``forward``), the input gradient is

            dx = inv_std * (dy - mean(dy) - y * mean(dy * y))

        Each ``mean`` is an averaging ``GlobalPoolingForward`` call whose
        result is broadcast back with ``BroadcastForward``.
        Derivation: https://kevinzakka.github.io/2016/09/14/batch_normalization/

        Returns:
            The gradient w.r.t. the input features, followed by four ``None``
            entries for the non-differentiable forward arguments.
        """
        batch_size = ctx.batch_size
        in_coords_key, glob_coords_key = ctx.in_coords_key, ctx.glob_coords_key
        coords_manager = ctx.coords_manager

        # Load the tensors stored by forward via save_for_backward.
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # a deprecated alias.
        inv_std, norm_feat = ctx.saved_tensors
        D = in_coords_key.D

        gpool_forward = getattr(MEB,
                                'GlobalPoolingForward' + get_postfix(out_grad))
        broadcast_forward = getattr(MEB,
                                    'BroadcastForward' + get_postfix(out_grad))
        add = operation_type_to_int(OperationType.ADDITION)
        multiply = operation_type_to_int(OperationType.MULTIPLICATION)

        cpp_in_coords_key = in_coords_key.CPPCoordsKey
        cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
        cpp_coords_manager = coords_manager.CPPCoordsManager

        # 1/N \sum dout
        num_nonzero = out_grad.new()
        mean_dout = out_grad.new()
        gpool_forward(D, out_grad, mean_dout, num_nonzero, cpp_in_coords_key,
                      cpp_glob_coords_key, cpp_coords_manager, batch_size,
                      True)

        # 1/N \sum (dout * out)
        mean_dout_feat = out_grad.new()
        gpool_forward(D, out_grad * norm_feat, mean_dout_feat, num_nonzero,
                      cpp_in_coords_key, cpp_glob_coords_key,
                      cpp_coords_manager, batch_size, True)

        # out * 1/N \sum (dout * out)
        feat_mean_dout_feat = out_grad.new()
        broadcast_forward(D, norm_feat, mean_dout_feat, feat_mean_dout_feat,
                          multiply, cpp_in_coords_key, cpp_glob_coords_key,
                          cpp_coords_manager)

        # dout - out * mean(dout * out) - mean(dout)
        unnorm_din = out_grad.new()
        broadcast_forward(D, out_grad - feat_mean_dout_feat, -mean_dout,
                          unnorm_din, add, cpp_in_coords_key,
                          cpp_glob_coords_key, cpp_coords_manager)

        # Scale by 1 / sigma.
        norm_din = out_grad.new()
        broadcast_forward(D, unnorm_din, inv_std, norm_din, multiply,
                          cpp_in_coords_key, cpp_glob_coords_key,
                          cpp_coords_manager)

        return norm_din, None, None, None, None
Ejemplo n.º 4
0
 def backward(ctx, grad_out_feat):
     """Compute the input-feature gradient with the backend ``PruningBackward``.

     An empty gradient tensor is handed to the backend together with the
     coordinate keys and manager stored on ``ctx`` by the forward pass.

     Returns:
         The input-feature gradient followed by five ``None`` entries.
     """
     in_key = ctx.in_coords_key
     grad_in_feat = grad_out_feat.new()
     prune_backward = getattr(MEB,
                              'PruningBackward' + get_postfix(grad_out_feat))
     prune_backward(in_key.D, grad_in_feat, grad_out_feat, in_key.CPPCoordsKey,
                    ctx.out_coords_key.CPPCoordsKey,
                    ctx.coords_manager.CPPCoordsManager)
     return (grad_in_feat,) + (None,) * 5
Ejemplo n.º 5
0
    def forward(ctx,
                input_features,
                batch_size=0,
                average=True,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Globally pool features with the out-parameter backend kernel.

        Args:
            input_features: features to pool. Kept on ``ctx.in_feat``.
            batch_size: forwarded to ``GlobalPoolingForward``.
            average: forwarded to the backend and kept on ``ctx.average``.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the pooled coordinates. Created from the
                input dimension when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The pooled feature tensor filled by the backend. The counts
            tensor it fills is kept on ``ctx.num_nonzero``.
        """
        out_coords_key = (CoordsKey(in_coords_key.D)
                          if out_coords_key is None else out_coords_key)
        ctx.in_coords_key, ctx.out_coords_key = in_coords_key, out_coords_key
        ctx.in_feat = input_features
        ctx.average = average
        ctx.num_nonzero = input_features.new()
        ctx.coords_manager = coords_manager

        pooled = input_features.new()
        pool_fn = getattr(MEB,
                          'GlobalPoolingForward' + get_postfix(input_features))
        pool_fn(in_coords_key.D, input_features, pooled, ctx.num_nonzero,
                in_coords_key.CPPCoordsKey, out_coords_key.CPPCoordsKey,
                coords_manager.CPPCoordsManager, batch_size, average)
        return pooled
Ejemplo n.º 6
0
    def forward(ctx,
                input_features,
                average=True,
                mode=GlobalPoolingMode.AUTO,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Globally pool features with the tuple-returning backend kernel.

        Args:
            input_features: features to pool.
            average: forwarded to ``GlobalPoolingForward``.
            mode: ``GlobalPoolingMode``. Its ``.value`` is stored on
                ``ctx.mode`` and passed to the backend.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the pooled coordinates. Created from the
                input dimension when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The pooled features. The counts returned alongside them are kept
            on ``ctx.num_nonzero``.
        """
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        assert isinstance(mode, GlobalPoolingMode), \
            f"Mode must be an instance of GlobalPoolingMode, {mode}"

        mode_value = mode.value
        ctx.in_coords_key = in_coords_key
        ctx.out_coords_key = out_coords_key
        ctx.in_feat = input_features
        ctx.average = average
        ctx.coords_manager = coords_manager
        ctx.mode = mode_value

        pool_fn = getattr(MEB,
                          'GlobalPoolingForward' + get_postfix(input_features))
        pooled, counts = pool_fn(input_features, in_coords_key.CPPCoordsKey,
                                 out_coords_key.CPPCoordsKey,
                                 coords_manager.CPPCoordsManager, average,
                                 mode_value)
        ctx.num_nonzero = counts
        return pooled
Ejemplo n.º 7
0
    def forward(ctx, input_features, input_features_global, operation_type,
                in_coords_key, glob_coords_key, coords_manager):
        """Combine per-point features with their global features via the backend.

        ``operation_type`` (an ``OperationType``) selects the element-wise
        operation. Its integer code is stored on ``ctx.op``. Both inputs must
        have the same channel count and tensor type, and are made contiguous
        before the call.

        Returns:
            The output feature tensor filled by ``BroadcastForward``.
        """
        assert input_features.shape[1] == input_features_global.shape[1]
        assert input_features.type() == input_features_global.type()
        assert isinstance(operation_type, OperationType)
        feat = (input_features if input_features.is_contiguous()
                else input_features.contiguous())
        glob_feat = (input_features_global
                     if input_features_global.is_contiguous()
                     else input_features_global.contiguous())

        ctx.op = operation_type_to_int(operation_type)
        ctx.in_feat, ctx.in_feat_glob = feat, glob_feat
        ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
        ctx.coords_manager = coords_manager

        result = feat.new()
        broadcast_fn = getattr(MEB, 'BroadcastForward' + get_postfix(feat))
        broadcast_fn(feat, glob_feat, result, ctx.op,
                     in_coords_key.CPPCoordsKey, glob_coords_key.CPPCoordsKey,
                     coords_manager.CPPCoordsManager)
        return result
Ejemplo n.º 8
0
 def backward(ctx, grad_out_feat):
     """Compute the input-feature gradient with ``GlobalPoolingBackward``.

     Uses the input features, counts tensor and ``average`` flag stored on
     ``ctx`` by the forward pass.

     Returns:
         The input-feature gradient followed by five ``None`` entries.
     """
     pool_backward = getattr(
         MEB, 'GlobalPoolingBackward' + get_postfix(grad_out_feat))
     grad_in_feat = grad_out_feat.new()
     pool_backward(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
                   ctx.in_coords_key.CPPCoordsKey,
                   ctx.out_coords_key.CPPCoordsKey,
                   ctx.coords_manager.CPPCoordsManager, ctx.average)
     return (grad_in_feat, *([None] * 5))
 def backward(ctx, grad_out_feat):
     """Compute gradients for both broadcast operands with ``BroadcastBackward``.

     Returns:
         ``(grad w.r.t. per-point features, grad w.r.t. global features)``
         followed by four ``None`` entries for the remaining forward
         arguments.
     """
     in_key = ctx.in_coords_key
     grad_local = grad_out_feat.new()
     grad_global = grad_out_feat.new()
     bcast_backward = getattr(MEB,
                              'BroadcastBackward' + get_postfix(grad_out_feat))
     bcast_backward(in_key.D, ctx.in_feat, grad_local, ctx.in_feat_glob,
                    grad_global, grad_out_feat, ctx.op, in_key.CPPCoordsKey,
                    ctx.glob_coords_key.CPPCoordsKey,
                    ctx.coords_manager.CPPCoordsManager)
     return (grad_local, grad_global) + (None,) * 4
Ejemplo n.º 10
0
    def backward(ctx, grad_out_feat):
        """Split the union's output gradient back onto each input tensor.

        The gradient is made contiguous before being passed to
        ``UnionBackward``.

        Returns:
            Three ``None`` entries (coordinate keys and manager) followed by
            one gradient per input feature tensor, as produced by the
            backend.
        """
        grad = (grad_out_feat if grad_out_feat.is_contiguous()
                else grad_out_feat.contiguous())

        union_backward = getattr(MEB, 'UnionBackward' + get_postfix(grad))
        cpp_in_keys = [k.CPPCoordsKey for k in ctx.in_coords_keys]
        per_input_grads = union_backward(grad, cpp_in_keys,
                                         ctx.out_coords_key.CPPCoordsKey,
                                         ctx.coords_manager.CPPCoordsManager)
        return (None,) * 3 + tuple(per_input_grads)
    def forward(ctx,
                input_features,
                kernel,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=0,
                region_offset=None,
                generate_new_coords=False,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """
        region_type=0 HyperCube

        Run the backend ``ConvolutionTransposeForward`` kernel.

        The input's channel count must match ``kernel.shape[1]`` and both
        tensors must share a type. The stride/size/dilation arguments are
        normalized by ``prep_args`` and stored on ``ctx`` via ``save_ctx``.
        ``generate_new_coords`` is passed straight through to the backend.

        Returns:
            The output feature tensor filled by the backend.
        """
        # Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat)
        assert input_features.shape[1] == kernel.shape[1], \
            f"The input shape {list(input_features.shape)}" \
            f" does not match the kernel shape {list(kernel.shape)}"
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        dim = in_coords_key.D
        assert dim == out_coords_key.D
        assert input_features.type() == kernel.type(), \
            f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        (tensor_stride, stride, kernel_size, dilation,
         region_type) = prep_args(tensor_stride, stride, kernel_size,
                                  dilation, region_type, dim)
        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        ctx.kernel = kernel
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)

        out_feat = input_features.new()
        transpose_fn = getattr(
            MEB, 'ConvolutionTransposeForward' + get_postfix(input_features))
        geometry = [
            convert_to_int_list(value, dim)
            for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                          ctx.dilation)
        ]
        transpose_fn(ctx.in_feat, out_feat, kernel, *geometry, region_type,
                     region_offset, ctx.in_coords_key.CPPCoordsKey,
                     ctx.out_coords_key.CPPCoordsKey,
                     ctx.coords_man.CPPCoordsManager, generate_new_coords)
        return out_feat
Ejemplo n.º 12
0
 def backward(ctx, grad_out_feat):
     """Compute the max-pooling input gradient with ``MaxPoolingBackward``.

     Uses the ``max_index`` tensor and pooling geometry stored on ``ctx`` by
     the forward pass.

     Returns:
         The input-feature gradient followed by nine ``None`` entries.
     """
     dim = ctx.in_coords_key.D
     grad_in_feat = grad_out_feat.new()
     pool_backward = getattr(MEB,
                             'MaxPoolingBackward' + get_postfix(grad_out_feat))
     geometry = [
         convert_to_int_list(value, dim)
         for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                       ctx.dilation)
     ]
     pool_backward(dim, ctx.in_feat, grad_in_feat, grad_out_feat,
                   ctx.max_index, *geometry, ctx.region_type,
                   ctx.in_coords_key.CPPCoordsKey,
                   ctx.out_coords_key.CPPCoordsKey,
                   ctx.coords_man.CPPCoordsManager)
     return (grad_in_feat,) + (None,) * 9
Ejemplo n.º 13
0
    def backward(ctx, grad_out_feat):
        """Compute the average-pooling input gradient with ``AvgPoolingBackward``.

        The incoming gradient is made contiguous first. The counts tensor,
        pooling geometry and ``use_avg`` flag come from ``ctx``.

        Returns:
            The input-feature gradient followed by ten ``None`` entries.
        """
        grad = (grad_out_feat if grad_out_feat.is_contiguous()
                else grad_out_feat.contiguous())

        dim = ctx.in_coords_key.D
        grad_in_feat = grad.new()
        pool_backward = getattr(MEB,
                                'AvgPoolingBackward' + get_postfix(grad))
        geometry = [
            convert_to_int_list(value, dim)
            for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                          ctx.dilation)
        ]
        pool_backward(ctx.in_feat, grad_in_feat, grad, ctx.num_nonzero,
                      *geometry, ctx.region_type,
                      ctx.in_coords_key.CPPCoordsKey,
                      ctx.out_coords_key.CPPCoordsKey,
                      ctx.coords_man.CPPCoordsManager, ctx.use_avg)
        return (grad_in_feat,) + (None,) * 10
 def backward(ctx, grad_out_feat):
     """Compute convolution gradients w.r.t. the input features and the kernel.

     Both gradient tensors start empty and are passed to the backend
     ``ConvolutionBackward`` kernel together with the geometry on ``ctx``.

     Returns:
         ``(grad_in_feat, grad_kernel)`` followed by nine ``None`` entries.
     """
     assert grad_out_feat.type() == ctx.in_feat.type()
     dim = ctx.in_coords_key.D
     grad_in_feat, grad_kernel = grad_out_feat.new(), grad_out_feat.new()
     conv_backward = getattr(MEB,
                             'ConvolutionBackward' + get_postfix(grad_out_feat))
     geometry = [
         convert_to_int_list(value, dim)
         for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                       ctx.dilation)
     ]
     conv_backward(dim, ctx.in_feat, grad_in_feat, grad_out_feat, ctx.kernel,
                   grad_kernel, *geometry, ctx.region_type,
                   ctx.in_coords_key.CPPCoordsKey,
                   ctx.out_coords_key.CPPCoordsKey,
                   ctx.coords_man.CPPCoordsManager)
     return (grad_in_feat, grad_kernel) + (None,) * 9
Ejemplo n.º 15
0
    def forward(ctx, in_feat, use_feat, in_coords_key, out_coords_key,
                coords_manager):
        """Prune feature rows via the backend ``PruningForward`` kernel.

        ``use_feat`` is a per-row mask with as many entries as ``in_feat``
        has rows. Presumably rows with a nonzero mask entry are kept; this is
        decided by the backend, so confirm against it. A ``torch.BoolTensor``
        mask is accepted and converted to uint8, because this backend
        variant originally only admitted ``torch.ByteTensor``.

        Args:
            in_feat: input features.
            use_feat: CPU byte or bool mask.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the pruned coordinates.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The pruned feature tensor filled by the backend.
        """
        assert in_feat.size(0) == use_feat.size(0)
        assert isinstance(use_feat, (torch.ByteTensor, torch.BoolTensor)), \
            "use_feat must be a bool/byte tensor."
        if isinstance(use_feat, torch.BoolTensor):
            # Keep the backend input uint8, matching the original contract.
            use_feat = use_feat.byte()
        # The backend gets the raw tensors; ``contiguous()`` returns the
        # tensor itself when it is already contiguous.
        in_feat = in_feat.contiguous()
        use_feat = use_feat.contiguous()

        ctx.in_coords_key = in_coords_key
        ctx.out_coords_key = out_coords_key
        ctx.coords_manager = coords_manager

        out_feat = in_feat.new()

        fw_fn = getattr(MEB, 'PruningForward' + get_postfix(in_feat))
        fw_fn(ctx.in_coords_key.D, in_feat, out_feat, use_feat,
              ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_manager.CPPCoordsManager)
        return out_feat
    def forward(ctx,
                input_features,
                kernel,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=0,
                region_offset=None,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """
        region_type=0 HyperCube

        Run the backend sparse convolution ``ConvolutionForward``.

        The input's channel count must match ``kernel.shape[1]`` and both
        tensors must share a type. The input is made contiguous before being
        handed to the backend, as the other entry points in this module do.
        Geometry arguments are normalized by ``prep_args`` and stored on
        ``ctx`` through ``save_ctx`` for the backward pass.

        Returns:
            The output feature tensor filled by the backend.
        """
        # Prep arguments
        # Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat)
        assert input_features.shape[1] == kernel.shape[1], \
            "The input shape " + str(list(input_features.shape)) + \
            " does not match the kernel shape " + str(list(kernel.shape))
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        assert in_coords_key.D == out_coords_key.D
        assert input_features.type() == kernel.type(), \
            f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
        # The backend receives the raw tensor, so a non-contiguous view must
        # be materialized first.
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
            tensor_stride, stride, kernel_size, dilation, region_type,
            in_coords_key.D)

        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        ctx.kernel = kernel
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)

        D = in_coords_key.D
        out_feat = input_features.new()

        fw_fn = getattr(MEB,
                        'ConvolutionForward' + get_postfix(input_features))
        fw_fn(D, ctx.in_feat, out_feat, kernel,
              convert_to_int_list(ctx.tensor_stride, D),
              convert_to_int_list(ctx.stride, D),
              convert_to_int_list(ctx.kernel_size, D),
              convert_to_int_list(ctx.dilation, D), region_type, region_offset,
              ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_man.CPPCoordsManager)
        return out_feat
Ejemplo n.º 17
0
    def forward(ctx,
                input_features,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=0,
                region_offset=None,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Sparse max pooling through the backend ``MaxPoolingForward`` kernel.

        The backend fills the output features and an int32 ``max_index``
        tensor. The latter is kept on ``ctx.max_index`` for ``backward``.

        NOTE(review): ``region_type`` is asserted to be a ``RegionType``, so
        the default ``0`` trips the assertion unless assertions are disabled.
        Callers must pass it explicitly.

        Returns:
            The pooled feature tensor.
        """
        assert isinstance(region_type, RegionType)
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        dim = in_coords_key.D
        assert dim == out_coords_key.D
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        (tensor_stride, stride, kernel_size, dilation,
         region_type) = prep_args(tensor_stride, stride, kernel_size,
                                  dilation, region_type, dim)
        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)

        pooled = input_features.new()
        argmax = input_features.new().int()
        ctx.max_index = argmax

        pool_fn = getattr(MEB, 'MaxPoolingForward' + get_postfix(input_features))
        geometry = [
            convert_to_int_list(value, dim)
            for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                          ctx.dilation)
        ]
        pool_fn(input_features, pooled, argmax, *geometry, region_type,
                region_offset, ctx.in_coords_key.CPPCoordsKey,
                ctx.out_coords_key.CPPCoordsKey,
                ctx.coords_man.CPPCoordsManager)
        return pooled
Ejemplo n.º 18
0
    def forward(ctx, in_coords_keys, out_coords_key, coords_manager,
                *in_feats):
        """Merge several sparse feature tensors with the backend ``UnionForward``.

        Args:
            in_coords_keys: one coordinate key per input feature tensor.
            out_coords_key: key of the merged coordinates.
            coords_manager: coordinate manager that owns all keys.
            *in_feats: at least two feature tensors, made contiguous first.

        Returns:
            Whatever ``UnionForward`` returns for the merged features.
        """
        assert isinstance(in_feats, (list, tuple)), \
            "Input must be a list or a set of Tensors"
        assert len(in_feats) > 1, \
            "input must be a set with at least 2 Tensors"

        feats = [feat.contiguous() for feat in in_feats]

        ctx.in_coords_keys = in_coords_keys
        ctx.out_coords_key = out_coords_key
        ctx.coords_manager = coords_manager

        union_fn = getattr(MEB, 'UnionForward' + get_postfix(feats[0]))
        cpp_in_keys = [k.CPPCoordsKey for k in in_coords_keys]
        return union_fn(feats, cpp_in_keys, out_coords_key.CPPCoordsKey,
                        coords_manager.CPPCoordsManager)
Ejemplo n.º 19
0
    def forward(ctx,
                input_features,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=0,
                region_offset=None,
                average=True,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Sparse average/sum pooling through the backend ``AvgPoolingForward``.

        The input is made contiguous before the backend call, matching the
        max-pooling forward and the average-pooling backward. The counts
        tensor filled by the backend is kept on ``ctx.num_nonzero``, and
        ``average`` is kept on ``ctx.use_avg``.

        NOTE(review): ``region_type`` is asserted to be a ``RegionType``, so
        the default ``0`` trips the assertion unless assertions are disabled.
        Callers must pass it explicitly.

        Returns:
            The pooled feature tensor.
        """
        assert isinstance(region_type, RegionType)
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        assert in_coords_key.D == out_coords_key.D
        # The backend receives the raw tensor, so a non-contiguous view must
        # be materialized first.
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()
        tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
            tensor_stride, stride, kernel_size, dilation, region_type,
            in_coords_key.D)

        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)
        ctx.use_avg = average

        D = in_coords_key.D
        out_feat = input_features.new()
        ctx.num_nonzero = input_features.new()

        fw_fn = getattr(MEB, 'AvgPoolingForward' + get_postfix(input_features))
        fw_fn(D, ctx.in_feat, out_feat, ctx.num_nonzero,
              convert_to_int_list(ctx.tensor_stride, D),
              convert_to_int_list(ctx.stride, D),
              convert_to_int_list(ctx.kernel_size, D),
              convert_to_int_list(ctx.dilation, D), region_type, region_offset,
              ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_man.CPPCoordsManager, ctx.use_avg)
        return out_feat
Ejemplo n.º 20
0
    def forward(ctx, in_feat, mask, in_coords_key, out_coords_key,
                coords_manager):
        """Prune feature rows with a CPU boolean ``mask`` via ``PruningForward``.

        ``mask`` must have one entry per row of ``in_feat``. Both tensors are
        made contiguous before the backend call.

        Returns:
            The pruned feature tensor filled by the backend.
        """
        assert in_feat.size(0) == mask.size(0)
        assert isinstance(mask,
                          torch.BoolTensor), "Mask must be a cpu bool tensor."
        feat = in_feat if in_feat.is_contiguous() else in_feat.contiguous()
        keep = mask if mask.is_contiguous() else mask.contiguous()

        ctx.in_coords_key, ctx.out_coords_key = in_coords_key, out_coords_key
        ctx.coords_manager = coords_manager

        pruned = feat.new()
        prune_fn = getattr(MEB, 'PruningForward' + get_postfix(feat))
        prune_fn(feat, pruned, keep, in_coords_key.CPPCoordsKey,
                 out_coords_key.CPPCoordsKey, coords_manager.CPPCoordsManager)
        return pruned
Ejemplo n.º 21
0
    def forward(ctx, in_feat, use_feat, in_coords_key, out_coords_key,
                coords_manager):
        """Prune feature rows with a byte or bool mask via ``PruningForward``.

        A bool mask is converted to uint8 before the call. Features and mask
        are made contiguous.

        Returns:
            The pruned feature tensor filled by the backend.
        """
        assert in_feat.size(0) == use_feat.size(0)
        assert isinstance(use_feat, (torch.ByteTensor, torch.BoolTensor)), \
            "use_feat must be a bool/byte tensor."
        keep = use_feat.byte() if isinstance(use_feat,
                                             torch.BoolTensor) else use_feat
        feat = in_feat if in_feat.is_contiguous() else in_feat.contiguous()
        keep = keep if keep.is_contiguous() else keep.contiguous()

        ctx.in_coords_key, ctx.out_coords_key = in_coords_key, out_coords_key
        ctx.coords_manager = coords_manager

        pruned = feat.new()
        prune_fn = getattr(MEB, 'PruningForward' + get_postfix(feat))
        prune_fn(feat, pruned, keep, in_coords_key.CPPCoordsKey,
                 out_coords_key.CPPCoordsKey, coords_manager.CPPCoordsManager)
        return pruned
Ejemplo n.º 22
0
    def forward(ctx,
                input_features,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Global max pooling through the backend ``GlobalMaxPoolingForward``.

        The backend fills the output features and an int32 ``max_index``
        tensor, which is kept on ``ctx.max_index`` for the backward pass.
        ``out_coords_key`` defaults to a fresh key of the input dimension.

        Returns:
            The pooled feature tensor.
        """
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        ctx.in_coords_key, ctx.out_coords_key = in_coords_key, out_coords_key
        ctx.in_feat = input_features
        ctx.max_index = input_features.new().int()
        ctx.coords_manager = coords_manager

        pooled = input_features.new()
        max_pool_fn = getattr(
            MEB, 'GlobalMaxPoolingForward' + get_postfix(input_features))
        max_pool_fn(input_features, pooled, ctx.max_index,
                    in_coords_key.CPPCoordsKey, out_coords_key.CPPCoordsKey,
                    coords_manager.CPPCoordsManager)
        return pooled