def forward(ctx, input_features, kernel, tensor_stride=1, stride=1,
            kernel_size=-1, dilation=1, region_type=0, region_offset=None,
            generate_new_coords=False, in_coords_key=None,
            out_coords_key=None, coords_manager=None):
    """Sparse transposed-convolution forward pass (region_type=0 HyperCube).

    Checks that the features and kernel agree, normalizes the geometry
    arguments with ``prep_args``, records state on ``ctx`` through
    ``save_ctx`` and dispatches to the
    ``MEB.ConvolutionTransposeForward<postfix>`` backend, where the postfix
    comes from ``get_postfix(input_features)``.

    Args:
        ctx: autograd context. Receives ``in_feat``, ``kernel`` and whatever
            ``save_ctx`` stores (``tensor_stride``, ``stride``,
            ``kernel_size``, ``dilation``, ``in_coords_key``,
            ``out_coords_key``, ``coords_man`` are read back below).
        input_features: input feature tensor. ``shape[1]`` must equal
            ``kernel.shape[1]`` and ``type()`` must equal ``kernel.type()``.
        kernel: weight tensor, laid out as
            ``(n_spatial_kernels, in_nfeat, out_nfeat)``.
        tensor_stride, stride, kernel_size, dilation, region_type:
            normalized by ``prep_args`` for dimension ``in_coords_key.D``.
        region_offset: custom region offsets. ``None`` becomes an empty
            ``torch.IntTensor()``.
        generate_new_coords: forwarded unchanged as the backend's last
            argument (presumably lets the backend create new output
            coordinates; confirm against MEB).
        in_coords_key: must be provided, because its ``D`` attribute is read
            unconditionally.
        out_coords_key: defaults to ``CoordsKey(in_coords_key.D)``.
        coords_manager: stored by ``save_ctx`` as ``ctx.coords_man``.

    Returns:
        An output tensor created with ``input_features.new()`` and populated
        by the backend call.
    """
    # Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat).
    assert input_features.shape[1] == kernel.shape[1], \
        "The input shape " + str(list(input_features.shape)) + \
        " does not match the kernel shape " + str(list(kernel.shape))

    dim = in_coords_key.D
    if out_coords_key is None:
        out_coords_key = CoordsKey(dim)
    assert dim == out_coords_key.D
    assert input_features.type() == kernel.type(), \
        f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"

    # Pass the backend a dense copy when the caller handed us a strided view.
    features = (input_features if input_features.is_contiguous()
                else input_features.contiguous())

    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type, dim)
    region_offset = (torch.IntTensor() if region_offset is None
                     else region_offset)

    ctx.in_feat = features
    ctx.kernel = kernel
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)

    out_feat = features.new()
    backend_forward = getattr(
        MEB, 'ConvolutionTransposeForward' + get_postfix(features))
    geometry = [
        convert_to_int_list(value, dim)
        for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                      ctx.dilation)
    ]
    backend_forward(ctx.in_feat, out_feat, kernel, *geometry, region_type,
                    region_offset, ctx.in_coords_key.CPPCoordsKey,
                    ctx.out_coords_key.CPPCoordsKey,
                    ctx.coords_man.CPPCoordsManager, generate_new_coords)
    return out_feat
def forward(ctx, input_features, kernel, tensor_stride=1, stride=1,
            kernel_size=-1, dilation=1, region_type=0, region_offset=None,
            in_coords_key=None, out_coords_key=None, coords_manager=None):
    """Sparse convolution forward pass (region_type=0 HyperCube).

    Validates the features against the kernel, normalizes the geometry
    arguments with ``prep_args``, records state on ``ctx`` via ``save_ctx``
    and dispatches to the ``MEB.ConvolutionForward<postfix>`` backend, where
    the postfix comes from ``get_postfix(input_features)``.

    Args:
        ctx: autograd context. Receives ``in_feat``, ``kernel`` and the fields
            written by ``save_ctx`` (``tensor_stride``, ``stride``,
            ``kernel_size``, ``dilation``, ``in_coords_key``,
            ``out_coords_key``, ``coords_man`` are read back below).
        input_features: input feature tensor. ``shape[1]`` must equal
            ``kernel.shape[1]`` and ``type()`` must equal ``kernel.type()``.
            Non-contiguous inputs are copied to a contiguous tensor first.
        kernel: weight tensor, laid out as
            ``(n_spatial_kernels, in_nfeat, out_nfeat)``.
        tensor_stride, stride, kernel_size, dilation, region_type:
            normalized by ``prep_args`` for dimension ``in_coords_key.D``.
        region_offset: custom region offsets. ``None`` becomes an empty
            ``torch.IntTensor()``.
        in_coords_key: must be provided, because its ``D`` attribute is read
            unconditionally.
        out_coords_key: defaults to ``CoordsKey(in_coords_key.D)``.
        coords_manager: stored by ``save_ctx`` as ``ctx.coords_man``.

    Returns:
        An output tensor created with ``input_features.new()`` and populated
        by the backend call.
    """
    # Prep arguments
    # Kernel shape (n_spatial_kernels, in_nfeat, out_nfeat)
    assert input_features.shape[1] == kernel.shape[1], \
        "The input shape " + str(list(input_features.shape)) + \
        " does not match the kernel shape " + str(list(kernel.shape))
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert in_coords_key.D == out_coords_key.D
    assert input_features.type() == kernel.type(), \
        f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
    # The tensor is handed straight to the native backend. A transposed or
    # sliced view would otherwise reach it with non-dense strides, so copy it
    # to a contiguous layout. The copy is also what gets saved on ctx, so
    # backward sees the same storage.
    if not input_features.is_contiguous():
        input_features = input_features.contiguous()

    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type,
        in_coords_key.D)

    if region_offset is None:
        region_offset = torch.IntTensor()

    ctx.in_feat = input_features
    ctx.kernel = kernel
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)

    D = in_coords_key.D
    out_feat = input_features.new()

    fw_fn = getattr(MEB, 'ConvolutionForward' + get_postfix(input_features))
    # Unlike the transposed variant, this backend takes the dimension D as
    # its first argument.
    fw_fn(D, ctx.in_feat, out_feat, kernel,
          convert_to_int_list(ctx.tensor_stride, D),
          convert_to_int_list(ctx.stride, D),
          convert_to_int_list(ctx.kernel_size, D),
          convert_to_int_list(ctx.dilation, D), region_type, region_offset,
          ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_man.CPPCoordsManager)
    return out_feat
def forward(ctx, input_features, tensor_stride=1, stride=1, kernel_size=-1,
            dilation=1, region_type=0, region_offset=None, in_coords_key=None,
            out_coords_key=None, coords_manager=None):
    """Sparse max-pooling forward pass.

    Normalizes the geometry arguments with ``prep_args``, records state on
    ``ctx`` via ``save_ctx`` and dispatches to the
    ``MEB.MaxPoolingForward<postfix>`` backend, where the postfix comes from
    ``get_postfix(input_features)``.

    Args:
        ctx: autograd context. Receives ``in_feat``, ``max_index`` and the
            fields written by ``save_ctx``.
        input_features: input feature tensor. Non-contiguous inputs are
            copied to a contiguous tensor first.
        tensor_stride, stride, kernel_size, dilation: normalized by
            ``prep_args`` for dimension ``in_coords_key.D``.
        region_type: must be a ``RegionType`` instance, which is asserted.
            NOTE(review): the default ``0`` is a plain int and therefore fails
            that assertion, so callers effectively must pass it explicitly.
        region_offset: custom region offsets. ``None`` becomes an empty
            ``torch.IntTensor()``.
        in_coords_key: must be provided, because its ``D`` attribute is read
            unconditionally.
        out_coords_key: defaults to ``CoordsKey(in_coords_key.D)``.
        coords_manager: stored by ``save_ctx`` as ``ctx.coords_man``.

    Returns:
        An output tensor created with ``input_features.new()`` and populated
        by the backend call.
    """
    assert isinstance(region_type, RegionType)

    dim = in_coords_key.D
    if out_coords_key is None:
        out_coords_key = CoordsKey(dim)
    assert dim == out_coords_key.D

    # Give the backend a dense layout.
    features = (input_features if input_features.is_contiguous()
                else input_features.contiguous())

    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type, dim)
    region_offset = (torch.IntTensor() if region_offset is None
                     else region_offset)

    ctx.in_feat = features
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)

    out_feat = features.new()
    # Int tensor the backend writes into. It is kept on ctx, presumably so
    # that backward can route gradients to the argmax inputs (confirm).
    argmax_idx = features.new().int()
    ctx.max_index = argmax_idx

    pool_forward = getattr(MEB, 'MaxPoolingForward' + get_postfix(features))
    geometry = [
        convert_to_int_list(value, dim)
        for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                      ctx.dilation)
    ]
    pool_forward(features, out_feat, argmax_idx, *geometry, region_type,
                 region_offset, ctx.in_coords_key.CPPCoordsKey,
                 ctx.out_coords_key.CPPCoordsKey,
                 ctx.coords_man.CPPCoordsManager)
    return out_feat
def forward(ctx, input_features, tensor_stride=1, stride=1, kernel_size=-1,
            dilation=1, region_type=0, region_offset=None, average=True,
            in_coords_key=None, out_coords_key=None, coords_manager=None):
    """Sparse average-pooling forward pass.

    Normalizes the geometry arguments with ``prep_args``, records state on
    ``ctx`` via ``save_ctx`` and dispatches to the
    ``MEB.AvgPoolingForward<postfix>`` backend, where the postfix comes from
    ``get_postfix(input_features)``.

    Args:
        ctx: autograd context. Receives ``in_feat``, ``use_avg``,
            ``num_nonzero`` and the fields written by ``save_ctx``.
        input_features: input feature tensor. Non-contiguous inputs are
            copied to a contiguous tensor first.
        tensor_stride, stride, kernel_size, dilation: normalized by
            ``prep_args`` for dimension ``in_coords_key.D``.
        region_type: must be a ``RegionType`` instance, which is asserted.
            NOTE(review): the default ``0`` is a plain int and fails that
            assertion.
        region_offset: custom region offsets. ``None`` becomes an empty
            ``torch.IntTensor()``.
        average: stored as ``ctx.use_avg`` and forwarded as the backend's
            last argument (presumably True averages and False sums; confirm
            against MEB).
        in_coords_key: must be provided, because its ``D`` attribute is read
            unconditionally.
        out_coords_key: defaults to ``CoordsKey(in_coords_key.D)``.
        coords_manager: stored by ``save_ctx`` as ``ctx.coords_man``.

    Returns:
        An output tensor created with ``input_features.new()`` and populated
        by the backend call.
    """
    assert isinstance(region_type, RegionType)
    if out_coords_key is None:
        out_coords_key = CoordsKey(in_coords_key.D)
    assert in_coords_key.D == out_coords_key.D
    # The tensor is handed straight to the native backend. A transposed or
    # sliced view would otherwise reach it with non-dense strides, so copy it
    # to a contiguous layout. The copy is also what gets saved on ctx.
    if not input_features.is_contiguous():
        input_features = input_features.contiguous()

    tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
        tensor_stride, stride, kernel_size, dilation, region_type,
        in_coords_key.D)

    if region_offset is None:
        region_offset = torch.IntTensor()

    ctx.in_feat = input_features
    ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                   region_type, in_coords_key, out_coords_key, coords_manager)
    ctx.use_avg = average

    D = in_coords_key.D
    out_feat = input_features.new()
    # Empty tensor passed to the backend, which presumably fills it with
    # per-output neighbor counts. It is kept on ctx, presumably for use in
    # backward (confirm).
    ctx.num_nonzero = input_features.new()

    fw_fn = getattr(MEB, 'AvgPoolingForward' + get_postfix(input_features))
    fw_fn(D, ctx.in_feat, out_feat, ctx.num_nonzero,
          convert_to_int_list(ctx.tensor_stride, D),
          convert_to_int_list(ctx.stride, D),
          convert_to_int_list(ctx.kernel_size, D),
          convert_to_int_list(ctx.dilation, D), region_type, region_offset,
          ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
          ctx.coords_man.CPPCoordsManager, ctx.use_avg)
    return out_feat