def forward(ctx, pos, flat_netpin, netpin_start, pin2net_map, net_weights, net_mask, pin_mask, gamma, num_threads):
    """
    @brief Evaluate weighted-average wirelength with the CPU or CUDA extension and keep the inputs on ctx.
    @param pos pin location (x array, y array), not cell location; flattened to 1-D before the kernel call
    @param flat_netpin flat netpin map, length of #pins
    @param netpin_start starting index in netpin map for each net, length of #nets+1, the last entry is #pins
    @param pin2net_map pin2net map; only handed to the CUDA kernel in this function
    @param net_weights weight of nets
    @param net_mask whether to compute wirelength, 1 means to compute, 0 means to ignore
    @param pin_mask whether compute gradient for a pin, 1 means to fill with zero, 0 means to compute;
        not used by forward itself, only stored on ctx
    @param gamma the smaller, the closer to HPWL
    @param num_threads thread count; only handed to the CPU kernel
    @return whatever the selected extension's forward returns, passed through unchanged
    """
    flat_pos = pos.view(pos.numel())
    if not pos.is_cuda:
        # CPU kernel: takes a thread count, no pin2net map.
        output = weighted_average_wirelength_cpp.forward(
            flat_pos, flat_netpin, netpin_start,
            net_weights, net_mask, gamma, num_threads)
    else:
        # CUDA kernel: takes the pin2net map, no thread count.
        output = weighted_average_wirelength_cuda.forward(
            flat_pos, flat_netpin, netpin_start, pin2net_map,
            net_weights, net_mask, gamma)

    # Stash inputs on the autograd context (presumably consumed by the matching backward()).
    ctx.flat_netpin, ctx.netpin_start, ctx.pin2net_map = flat_netpin, netpin_start, pin2net_map
    ctx.net_weights, ctx.net_mask, ctx.pin_mask = net_weights, net_mask, pin_mask
    ctx.gamma, ctx.pos, ctx.num_threads = gamma, pos, num_threads
    return output
def forward(ctx, pos, flat_netpin, netpin_start, pin2net_map, net_weights, net_mask, pin_mask, inv_gamma, num_threads):
    """
    @brief Evaluate weighted-average wirelength with the CPU or CUDA extension and cache intermediates on ctx.
    @param pos pin location (x array, y array), not cell location; flattened to 1-D before the kernel call
    @param flat_netpin flat netpin map, length of #pins
    @param netpin_start starting index in netpin map for each net, length of #nets+1, the last entry is #pins
    @param pin2net_map pin2net map; only handed to the CUDA kernel in this function
    @param net_weights weight of nets
    @param net_mask whether to compute wirelength, 1 means to compute, 0 means to ignore
    @param pin_mask whether compute gradient for a pin, 1 means to fill with zero, 0 means to compute;
        not used by forward itself, only stored on ctx
    @param inv_gamma 1/gamma, the larger, the closer to HPWL
    @param num_threads thread count; only handed to the CPU kernel
    @return output[0] of the extension's forward (presumably the wirelength value -- confirm);
        output[1..6] are cached on ctx as exp_xy, exp_nxy, exp_xy_sum, exp_nxy_sum,
        xyexp_xy_sum, xyexp_nxy_sum
    """
    # Local import: the module's top-level imports are not visible from this block.
    import logging

    tt = time.time()
    if pos.is_cuda:
        output = weighted_average_wirelength_cuda.forward(
            pos.view(pos.numel()), flat_netpin, netpin_start, pin2net_map,
            net_weights, net_mask, inv_gamma)
    else:
        output = weighted_average_wirelength_cpp.forward(
            pos.view(pos.numel()), flat_netpin, netpin_start,
            net_weights, net_mask, inv_gamma, num_threads)

    # Inputs kept on the autograd context (presumably consumed by the matching backward()).
    ctx.flat_netpin = flat_netpin
    ctx.netpin_start = netpin_start
    ctx.pin2net_map = pin2net_map
    ctx.net_weights = net_weights
    ctx.net_mask = net_mask
    ctx.pin_mask = pin_mask
    ctx.inv_gamma = inv_gamma
    ctx.pos = pos
    ctx.num_threads = num_threads
    # Intermediate exponential terms from the kernel, cached so they need not be recomputed later.
    ctx.exp_xy = output[1]
    ctx.exp_nxy = output[2]
    ctx.exp_xy_sum = output[3]
    ctx.exp_nxy_sum = output[4]
    ctx.xyexp_xy_sum = output[5]
    ctx.xyexp_nxy_sum = output[6]

    # The elapsed time is only ever reported at DEBUG level. Skip the device-wide
    # synchronize (which blocks until all queued CUDA work finishes) when the
    # message would be discarded anyway; use lazy %-args for the log call.
    if logger.isEnabledFor(logging.DEBUG):
        if pos.is_cuda:
            torch.cuda.synchronize()
        logger.debug("wirelength forward %.3f ms", (time.time() - tt) * 1000)
    return output[0]