def forward(ctx, pos, pin2net_map, net_weights, net_mask, pin_mask, gamma):
     """
     @param pos pin location (x array, y array), not cell location 
     @param pin2net_map pin2net map 
     @param net_weights weight of nets 
     @param net_mask whether to compute wirelength 
     @param pin_mask whether compute gradient for a pin, 1 means to fill with zero, 0 means to compute
     @param gamma the smaller, the closer to HPWL 
     """
     tt = time.time()
     if pos.is_cuda:
         output = weighted_average_wirelength_cuda_atomic.forward(
             pos.view(pos.numel()), pin2net_map, net_weights, net_mask,
             gamma)
     else:
         assert 0, "CPU version NOT IMPLEMENTED"
     ctx.pin2net_map = pin2net_map
     ctx.net_weights = net_weights
     ctx.net_mask = net_mask
     ctx.pin_mask = pin_mask
     ctx.gamma = gamma
     ctx.exp_xy = output[1]
     ctx.exp_nxy = output[2]
     ctx.exp_xy_sum = output[3]
     ctx.exp_nxy_sum = output[4]
     ctx.xyexp_xy_sum = output[5]
     ctx.xyexp_nxy_sum = output[6]
     ctx.pos = pos
     if pos.is_cuda:
         torch.cuda.synchronize()
     logger.debug("wirelength forward kernel takes %.3f ms" % ((time.time() - tt) * 1000))
     return output[0]
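
# A minimal eager-mode sketch of what the atomic CUDA kernel above is assumed
# to compute (based on the weighted-average wirelength formulation, not the
# kernel source). For each net with pin coordinates x_i, one dimension of the
# wirelength is
#     WA(x) = sum_i x_i * exp(x_i / gamma) / sum_i exp(x_i / gamma)
#           - sum_i x_i * exp(-x_i / gamma) / sum_i exp(-x_i / gamma),
# which approaches max_i x_i - min_i x_i (HPWL) as gamma -> 0. The helper name
# below is hypothetical; it handles one coordinate array only, assumes every
# net has at least one pin, and a production kernel would subtract the per-net
# max before exponentiating for numerical stability.
def _wa_wirelength_reference(x, pin2net_map, net_weights, net_mask, gamma, num_nets):
    exp_x = torch.exp(x / gamma)
    exp_nx = torch.exp(-x / gamma)
    index = pin2net_map.long()  # scatter_add_ requires int64 indices

    def zeros():
        return torch.zeros(num_nets, dtype=x.dtype, device=x.device)

    # per-net sums of exp terms and coordinate-weighted exp terms
    exp_x_sum = zeros().scatter_add_(0, index, exp_x)
    exp_nx_sum = zeros().scatter_add_(0, index, exp_nx)
    xexp_x_sum = zeros().scatter_add_(0, index, x * exp_x)
    xexp_nx_sum = zeros().scatter_add_(0, index, x * exp_nx)
    # per-net weighted-average wirelength, weighted, masked, and summed over nets
    wl = xexp_x_sum / exp_x_sum - xexp_nx_sum / exp_nx_sum
    return (wl * net_weights * net_mask.to(x.dtype)).sum()
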
 def forward(ctx, pos, pin2net_map, flat_netpin, netpin_start, net_weights, net_mask, pin_mask, inv_gamma):
     """
     @param pos pin locations (x array, y array), not cell locations
     @param pin2net_map map of pin index to net index
     @param flat_netpin flat array of pin indices grouped by net
     @param netpin_start starting index of each net's pins in flat_netpin; the pins of net i are flat_netpin[netpin_start[i]:netpin_start[i+1]]
     @param net_weights weight of nets
     @param net_mask whether to compute wirelength for a net
     @param pin_mask whether to compute gradient for a pin; 1 means fill with zero, 0 means compute
     @param inv_gamma 1/gamma; the larger it is, the closer the result is to HPWL
     """
     tt = time.time()
     if pos.is_cuda:
         output = weighted_average_wirelength_cuda_atomic.forward(
             pos.view(pos.numel()), pin2net_map, flat_netpin, netpin_start,
             net_weights, net_mask, inv_gamma)
     else:
         assert 0, "CPU version NOT IMPLEMENTED"
     ctx.pin2net_map = pin2net_map
     ctx.flat_netpin = flat_netpin
     ctx.netpin_start = netpin_start
     ctx.net_weights = net_weights
     ctx.net_mask = net_mask
     ctx.pin_mask = pin_mask
     ctx.inv_gamma = inv_gamma
     ctx.exp_xy = output[1]
     ctx.exp_nxy = output[2]
     ctx.exp_xy_sum = output[3]
     ctx.exp_nxy_sum = output[4]
     ctx.xyexp_xy_sum = output[5]
     ctx.xyexp_nxy_sum = output[6]
     ctx.pos = pos
     if pos.is_cuda:
         torch.cuda.synchronize()
     logger.debug("wirelength forward %.3f ms" % ((time.time()-tt)*1000))
     return output[0]
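
# Illustrative sketch (made-up data) of the CSR-style net-to-pin map consumed
# by the flat_netpin/netpin_start variant above: the pins of net n occupy
# flat_netpin[netpin_start[n]:netpin_start[n + 1]]. Here net 0 has pins {0, 2}
# and net 1 has pins {1, 3, 4}; the function name is hypothetical.
def _netpin_csr_example():
    flat_netpin = torch.tensor([0, 2, 1, 3, 4], dtype=torch.int32)  # pin ids, net by net
    netpin_start = torch.tensor([0, 2, 5], dtype=torch.int32)       # offsets, len = num_nets + 1
    for n in range(netpin_start.numel() - 1):
        pins = flat_netpin[netpin_start[n]:netpin_start[n + 1]]
        print("net %d pins: %s" % (n, pins.tolist()))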