def forward(ctx, input, input_index, output, output_index):
    """Run ``scatter_max`` with its last flag set to ``True`` and record backward state.

    ``output`` and ``output_index`` are supplied by the caller and are handed
    to ``scatter_max`` as-is; this method returns the same ``output`` object
    it was given and ignores whatever ``scatter_max`` itself returns.

    Stored on ``ctx``:
        * ``ctx.size`` -- the size of ``input``.
        * saved tensors -- ``output_index``.

    NOTE(review): presumably ``scatter_max`` fills ``output`` and
    ``output_index`` in place, and ``output_index`` holds row indices into
    ``input`` (i.e. values in ``[0, input.size(0))``) -- confirm against the
    kernel.
    """
    scatter_max(input, input_index, output, output_index, True)

    # Keep what the backward pass will need.
    ctx.save_for_backward(output_index)
    ctx.size = input.shape
    return output
def scatterMax(input, input_index, voxel_nums, train):
    """Max-pool the rows of a 2-D tensor into ``voxel_nums`` output rows.

    Only two-dimensional tensors are accepted; pooling is done along the
    first dimension.

    Two buffers of shape ``(voxel_nums, input.shape[1])`` are allocated:
    a value buffer pre-filled with ``torch.finfo(input.dtype).min`` (so
    ``input`` must have a floating-point dtype) and an uninitialised index
    buffer with the dtype/device of ``input_index``.

    Args:
        input: 2-D tensor to be pooled.
        input_index: index tensor forwarded unchanged to the kernel.
        voxel_nums: number of rows in the result buffers.
        train: when truthy, dispatch through ``ScatterMaxCuda.apply``;
            otherwise call ``scatter_max`` directly with its last flag set
            to ``False``.

    Returns:
        Whatever the selected path returns.

    NOTE(review): the non-training path returns the value returned by
    ``scatter_max`` itself rather than the ``output`` buffer -- confirm that
    ``scatter_max`` returns the pooled tensor (and not ``None``).
    """
    pooled_shape = (voxel_nums, input.shape[1])
    lowest = torch.finfo(input.dtype).min

    output = input.new_full(pooled_shape, lowest)
    output_index = input_index.new_empty(pooled_shape)

    if not train:
        return scatter_max(input, input_index, output, output_index, False)
    return ScatterMaxCuda.apply(input, input_index, output, output_index)
def forward(ctx, input, input_index, output, output_index):
    """Invoke ``scatter_max`` (last flag ``True``) and stash state on ``ctx``.

    The caller provides ``output`` and ``output_index``; both are passed
    straight to ``scatter_max``. The return value of ``scatter_max`` is
    discarded and the ``output`` argument is returned.

    ``ctx`` receives the size of ``input`` as ``ctx.size`` and
    ``output_index`` as its saved tensor.

    NOTE(review): this relies on ``scatter_max`` writing its results into
    ``output`` / ``output_index`` in place -- confirm against the kernel.
    """
    scatter_max(input, input_index, output, output_index, True)

    # State needed later by the backward pass.
    ctx.save_for_backward(output_index)
    ctx.size = input.shape
    return output