def forward(ctx, pos, flat_netpin, netpin_start, pin2net_map, net_weights,
            net_mask, pin_mask, gamma, num_threads):
     """
     Evaluate weighted-average wirelength with the CUDA or C++ extension and
     attach every input to ctx.

     @param ctx autograd context object; all inputs below are stored on it
     @param pos pin location (x array, y array), not cell location; flattened
         with view() before being handed to the kernel, stored unflattened
     @param flat_netpin flat netpin map, length of #pins
     @param netpin_start starting index in netpin map for each net, length of #nets+1, the last entry is #pins
     @param pin2net_map pin2net map; only passed to the CUDA kernel
     @param net_weights weight of nets
     @param net_mask whether to compute wirelength, 1 means to compute, 0 means to ignore
     @param pin_mask whether compute gradient for a pin, 1 means to fill with zero, 0 means to compute;
         not used by the kernel call here, only stored on ctx
     @param gamma the smaller, the closer to HPWL
     @param num_threads thread count; only passed to the C++ kernel
     @return the value returned by the selected extension's forward, unchanged
     """
     on_gpu = pos.is_cuda
     pos_flat = pos.view(pos.numel())
     if not on_gpu:
         # CPU kernel: takes num_threads but not pin2net_map
         output = weighted_average_wirelength_cpp.forward(
             pos_flat, flat_netpin, netpin_start, net_weights, net_mask,
             gamma, num_threads)
     else:
         # CUDA kernel: takes pin2net_map but not num_threads
         output = weighted_average_wirelength_cuda.forward(
             pos_flat, flat_netpin, netpin_start, pin2net_map, net_weights,
             net_mask, gamma)

     # Stash inputs on ctx, presumably for the matching backward (not visible here).
     ctx.pos = pos
     ctx.gamma, ctx.num_threads = gamma, num_threads
     ctx.flat_netpin, ctx.netpin_start, ctx.pin2net_map = (
         flat_netpin, netpin_start, pin2net_map)
     ctx.net_weights, ctx.net_mask, ctx.pin_mask = (
         net_weights, net_mask, pin_mask)
     return output
    def forward(ctx, pos, flat_netpin, netpin_start, pin2net_map, net_weights,
                net_mask, pin_mask, inv_gamma, num_threads):
        """
        Evaluate weighted-average wirelength with the CUDA or C++ extension,
        stash inputs and the kernel's intermediate outputs on ctx, and return
        the kernel's first output.

        @param ctx autograd context object; inputs and output[1..6] are stored on it
        @param pos pin location (x array, y array), not cell location
        @param flat_netpin flat netpin map, length of #pins
        @param netpin_start starting index in netpin map for each net, length of #nets+1, the last entry is #pins
        @param pin2net_map pin2net map; only passed to the CUDA kernel
        @param net_weights weight of nets
        @param net_mask whether to compute wirelength, 1 means to compute, 0 means to ignore
        @param pin_mask whether compute gradient for a pin, 1 means to fill with zero, 0 means to compute;
            not used by the kernel call here, only stored on ctx
        @param inv_gamma 1/gamma, the larger, the closer to HPWL
        @param num_threads thread count; only passed to the C++ kernel
        @return output[0] of the extension's forward (presumably the wirelength
            value — confirm against the extension)
        """
        import logging  # local import: only needed for the DEBUG-level check below

        # Timing, and the device sync needed to make it meaningful, is only
        # worth paying for when the debug message will actually be emitted.
        # Syncing unconditionally stalls the GPU pipeline on every call.
        timing = logger.isEnabledFor(logging.DEBUG)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter() if timing else None

        flat_pos = pos.view(pos.numel())
        if pos.is_cuda:
            output = weighted_average_wirelength_cuda.forward(
                flat_pos, flat_netpin, netpin_start, pin2net_map,
                net_weights, net_mask, inv_gamma)
        else:
            output = weighted_average_wirelength_cpp.forward(
                flat_pos, flat_netpin, netpin_start, net_weights,
                net_mask, inv_gamma, num_threads)

        # Inputs, presumably for the matching backward (not visible here).
        ctx.flat_netpin = flat_netpin
        ctx.netpin_start = netpin_start
        ctx.pin2net_map = pin2net_map
        ctx.net_weights = net_weights
        ctx.net_mask = net_mask
        ctx.pin_mask = pin_mask
        ctx.inv_gamma = inv_gamma
        ctx.pos = pos
        ctx.num_threads = num_threads
        # Intermediate kernel outputs; names suggest exp/sum terms of the
        # weighted-average model — confirm against the extension.
        ctx.exp_xy = output[1]
        ctx.exp_nxy = output[2]
        ctx.exp_xy_sum = output[3]
        ctx.exp_nxy_sum = output[4]
        ctx.xyexp_xy_sum = output[5]
        ctx.xyexp_nxy_sum = output[6]

        if timing:
            if pos.is_cuda:
                # CUDA work is queued asynchronously; wait so the measured
                # interval covers kernel execution, not just the launch.
                torch.cuda.synchronize()
            logger.debug("wirelength forward %.3f ms",
                         (time.perf_counter() - start) * 1000)
        return output[0]