def backward(ctx, grad_pos):
    tt = time.time()
    if grad_pos.is_cuda:
        # Dispatch the gradient computation to the atomic CUDA kernel, using the
        # exponential terms and per-net sums saved by forward().
        output = weighted_average_wirelength_cuda_atomic.backward(
            grad_pos, ctx.pos, ctx.exp_xy.view([-1]),
            ctx.exp_nxy.view([-1]), ctx.exp_xy_sum.view([-1]),
            ctx.exp_nxy_sum.view([-1]), ctx.xyexp_xy_sum.view([-1]),
            ctx.xyexp_nxy_sum.view([-1]), ctx.pin2net_map, ctx.flat_netpin,
            ctx.netpin_start, ctx.net_weights, ctx.net_mask, ctx.inv_gamma)
    else:
        assert 0, "CPU version NOT IMPLEMENTED"
    # output packs x gradients in the first half and y gradients in the second half;
    # zero the entries belonging to fixed pins in both halves.
    output[:int(output.numel() // 2)].masked_fill_(ctx.pin_mask, 0.0)
    output[int(output.numel() // 2):].masked_fill_(ctx.pin_mask, 0.0)
    if grad_pos.is_cuda:
        torch.cuda.synchronize()
    logger.debug("wirelength backward kernel %.3f ms" %
                 ((time.time() - tt) * 1000))
    # One return value per forward() input: only pos receives a gradient.
    return output, None, None, None, None, None, None, None
def backward(ctx, grad_pos):
    tt = time.time()
    if grad_pos.is_cuda:
        # Same atomic CUDA backward, but this variant does not use the
        # flat_netpin/netpin_start arrays and passes gamma rather than inv_gamma.
        output = weighted_average_wirelength_cuda_atomic.backward(
            grad_pos, ctx.pos, ctx.exp_xy.view([-1]),
            ctx.exp_nxy.view([-1]), ctx.exp_xy_sum.view([-1]),
            ctx.exp_nxy_sum.view([-1]), ctx.xyexp_xy_sum.view([-1]),
            ctx.xyexp_nxy_sum.view([-1]), ctx.pin2net_map, ctx.net_weights,
            ctx.net_mask, ctx.gamma)
    else:
        assert 0, "CPU version NOT IMPLEMENTED"
    # Zero the x (first half) and y (second half) gradients of fixed pins.
    output[:int(output.numel() // 2)].masked_fill_(ctx.pin_mask, 0.0)
    output[int(output.numel() // 2):].masked_fill_(ctx.pin_mask, 0.0)
    if grad_pos.is_cuda:
        torch.cuda.synchronize()
    logger.debug("wirelength backward kernel %.3f ms" %
                 ((time.time() - tt) * 1000))
    # One return value per forward() input of this variant: only pos receives a gradient.
    return output, None, None, None, None, None
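
# A minimal, self-contained sketch (not part of the original code; shapes and values
# are assumed) illustrating the gradient layout and pin-mask zeroing shared by both
# backward() variants above: the kernel returns the x gradients in the first half of
# `output` and the y gradients in the second half, and masked_fill_ on each half
# zeros the entries of fixed pins in place.
def _pin_mask_demo():
    import torch
    num_pins = 4
    # Stand-in for the packed gradient returned by the CUDA kernel:
    # [dx_0, ..., dx_3, dy_0, ..., dy_3].
    output = torch.arange(2 * num_pins, dtype=torch.float32)
    # Stand-in for ctx.pin_mask: True marks pins whose gradients must be zeroed.
    pin_mask = torch.tensor([False, True, False, True])
    # Slices are views, so masked_fill_ modifies `output` in place.
    output[:output.numel() // 2].masked_fill_(pin_mask, 0.0)   # x gradients
    output[output.numel() // 2:].masked_fill_(pin_mask, 0.0)   # y gradients
    return output  # tensor([0., 0., 2., 0., 4., 0., 6., 0.])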