def backward(ctx, gradOutput):
    """Backward pass of the SGA op, delegating to ``GANet.sga_cuda_backward``.

    Allocates zero-filled buffers, lets the CUDA extension fill them in
    place, and returns one gradient per forward tensor input.

    Args:
        ctx: autograd context whose ``saved_tensors`` are
            ``(input, g0, g1, g2, g3, temp_out, mask)`` as saved by forward.
        gradOutput: gradient w.r.t. the forward output; must be contiguous.

    Returns:
        ``(gradInput, grad0, grad1, grad2, grad3)``. ``gradInput`` is shaped
        like ``input``; each ``gradK`` is shaped like the matching ``gK``.
    """
    input, g0, g1, g2, g3, temp_out, mask = ctx.saved_tensors
    assert gradOutput.is_contiguous()

    with torch.cuda.device_of(gradOutput):
        # The unpacking doubles as a check that ``input`` is 5-D.
        num, channels, depth, height, width = input.size()
        volume_shape = (num, channels, depth, height, width)

        gradInput = gradOutput.new_zeros(volume_shape)
        grad0, grad1, grad2, grad3 = [
            gradOutput.new_zeros(g.size()) for g in (g0, g1, g2, g3)
        ]
        # Extra buffers handed to the kernel but not returned to autograd.
        temp_grad = gradOutput.new_zeros(volume_shape)
        # 4-D (no depth axis). NOTE(review): presumably per-position argmax
        # indices used by the kernel; confirm against the CUDA source.
        max_idx = gradOutput.new_zeros((num, channels, height, width))

        GANet.sga_cuda_backward(
            input, g0, g1, g2, g3, temp_out, mask, max_idx,
            gradOutput, temp_grad, gradInput, grad0, grad1, grad2, grad3,
        )

        gradInput, grad0, grad1, grad2, grad3 = (
            t.contiguous() for t in (gradInput, grad0, grad1, grad2, grad3)
        )
    return gradInput, grad0, grad1, grad2, grad3
def forward(ctx, input, filters, radius=1):
    """Forward pass of the LGA op, delegating to ``GANet.lga_cuda_forward``.

    Args:
        ctx: autograd context; receives ``radius`` (as ``ctx.radius``) and
            the saved tensors ``(input, filters)`` for backward.
        input: 4-D tensor (the shape unpacking below enforces the rank);
            must be contiguous.
        filters: tensor passed straight to the kernel; must be contiguous.
            NOTE(review): its expected shape is not validated here.
        radius: value forwarded to the kernel and reused by backward through
            ``ctx.radius``. The default of 1 is an assumption; TODO confirm
            against callers.

    Returns:
        Tensor shaped like ``input``, zero-initialised and then filled in
        place by the kernel.
    """
    # ``radius`` used to be read from an unbound name (NameError); it is now
    # an explicit argument, matching backward's third (``None``) gradient.
    ctx.radius = radius
    assert input.is_contiguous() and filters.is_contiguous()

    with torch.cuda.device_of(input):
        num, channels, height, width = input.size()
        output = input.new_zeros((num, channels, height, width))
        GANet.lga_cuda_forward(input, filters, output, radius)
        output = output.contiguous()

    ctx.save_for_backward(input, filters)
    return output
def backward(ctx, gradOutput):
    """Backward pass of the LGA op, delegating to ``GANet.lga_cuda_backward``.

    Args:
        ctx: autograd context holding ``(input, filters)`` in
            ``saved_tensors`` and the kernel parameter ``ctx.radius`` set by
            forward.
        gradOutput: gradient w.r.t. the forward output; must be contiguous.

    Returns:
        ``(gradInput, gradFilters, None)``. The trailing ``None`` fills the
        gradient slot of a third, non-differentiable forward argument
        (presumably ``radius``).
    """
    input, filters = ctx.saved_tensors
    assert gradOutput.is_contiguous()

    with torch.cuda.device_of(gradOutput):
        # Both unpackings double as 4-D rank checks.
        n, c, h, w = input.size()
        _, filter_size, _, _ = filters.size()

        grad_in = gradOutput.new_zeros((n, c, h, w))
        grad_filt = gradOutput.new_zeros((n, filter_size, h, w))
        GANet.lga_cuda_backward(
            input, filters, gradOutput, grad_in, grad_filt, ctx.radius
        )

    return grad_in.contiguous(), grad_filt.contiguous(), None
def forward(ctx, input, g0, g1, g2, g3):
    """Forward pass of the SGA op, delegating to ``GANet.sga_cuda_forward``.

    Args:
        ctx: autograd context; saves
            ``(input, g0, g1, g2, g3, temp_out, mask)`` for backward.
        input: 5-D tensor (the shape unpacking below enforces the rank);
            must be contiguous.
        g0, g1, g2, g3: additional tensors passed to the kernel; must be
            contiguous. NOTE(review): their shapes are not validated here.

    Returns:
        Tensor shaped like ``input``, zero-initialised and then filled in
        place by the kernel.
    """
    assert all(t.is_contiguous() for t in (input, g0, g1, g2, g3))

    with torch.cuda.device_of(input):
        num, channels, depth, height, width = input.size()
        shape = (num, channels, depth, height, width)
        # Three same-shaped buffers handed to the kernel; temp_out and mask
        # are additionally kept for the backward pass.
        output, temp_out, mask = (input.new_zeros(shape) for _ in range(3))
        GANet.sga_cuda_forward(input, g0, g1, g2, g3, temp_out, output, mask)
        output = output.contiguous()

    ctx.save_for_backward(input, g0, g1, g2, g3, temp_out, mask)
    return output