def forward(ctx, pos, pin2net_map, net_weights, net_mask, pin_mask, gamma):
    """
    @param pos pin location (x array, y array), not cell location
    @param pin2net_map pin2net map
    @param net_weights weight of nets
    @param net_mask whether to compute wirelength
    @param pin_mask whether to compute gradient for a pin; 1 means to fill with zero, 0 means to compute
    @param gamma the smaller, the closer to HPWL
    """
    #tt = time.time()
    if pos.is_cuda:
        output = weighted_average_wirelength_cuda_atomic.forward(
            pos.view(pos.numel()), pin2net_map, net_weights, net_mask, gamma)
    else:
        assert 0, "CPU version NOT IMPLEMENTED"
    ctx.pin2net_map = pin2net_map
    ctx.net_weights = net_weights
    ctx.net_mask = net_mask
    ctx.pin_mask = pin_mask
    ctx.gamma = gamma
    ctx.exp_xy = output[1]
    ctx.exp_nxy = output[2]
    ctx.exp_xy_sum = output[3]
    ctx.exp_nxy_sum = output[4]
    ctx.xyexp_xy_sum = output[5]
    ctx.xyexp_nxy_sum = output[6]
    ctx.pos = pos
    #if torch.isnan(ctx.exp_xy).any() or torch.isnan(ctx.exp_nxy).any() or torch.isnan(ctx.exp_xy_sum).any() or torch.isnan(ctx.exp_nxy_sum).any() or torch.isnan(output[0]).any():
    #    pdb.set_trace()
    torch.cuda.synchronize()
    #print("\t\twirelength forward kernel takes %.3f ms" % ((time.time()-tt)*1000))
    return output[0]
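
# For reference, a minimal pure-PyTorch sketch of the quantity the atomic CUDA
# kernel computes: the weighted-average (WA) wirelength approximation
#   WA_x(net) = sum(x*exp(x/gamma)) / sum(exp(x/gamma))
#             - sum(x*exp(-x/gamma)) / sum(exp(-x/gamma))
# evaluated per net for x and y, weighted by net_weights, masked by net_mask,
# and summed into a scalar. The function name and the dense scatter-based
# formulation below are illustrative only (not part of this module); a
# production kernel would typically also subtract the per-net max before
# exponentiating for numerical stability, and this operator additionally
# caches the intermediate per-net sums on ctx for the backward pass.
def wa_wirelength_reference(pos, pin2net_map, net_weights, net_mask, gamma):
    num_pins = pos.numel() // 2
    num_nets = net_mask.numel()
    idx = pin2net_map.long()          # scatter_add_ requires int64 indices
    mask = net_mask.to(pos.dtype)
    wl = pos.new_zeros(())
    for xy in (pos[:num_pins], pos[num_pins:]):
        exp_p = torch.exp(xy / gamma)
        exp_n = torch.exp(-xy / gamma)
        # per-net accumulation; the CUDA kernel performs this with atomic adds
        sum_exp_p = pos.new_zeros(num_nets).scatter_add_(0, idx, exp_p)
        sum_exp_n = pos.new_zeros(num_nets).scatter_add_(0, idx, exp_n)
        sum_xexp_p = pos.new_zeros(num_nets).scatter_add_(0, idx, xy * exp_p)
        sum_xexp_n = pos.new_zeros(num_nets).scatter_add_(0, idx, xy * exp_n)
        per_net = sum_xexp_p / sum_exp_p - sum_xexp_n / sum_exp_n
        wl = wl + (net_weights * mask * per_net).sum()
    return wl
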
def forward(ctx, pos, pin2net_map, flat_netpin, netpin_start, net_weights, net_mask, pin_mask, inv_gamma):
    """
    @param pos pin location (x array, y array), not cell location
    @param pin2net_map pin2net map
    @param flat_netpin flat netpin map, length of #pins
    @param netpin_start starting index in netpin map for each net, length of #nets+1, the last entry is #pins
    @param net_weights weight of nets
    @param net_mask whether to compute wirelength
    @param pin_mask whether to compute gradient for a pin; 1 means to fill with zero, 0 means to compute
    @param inv_gamma 1/gamma; the larger, the closer to HPWL
    """
    tt = time.time()
    if pos.is_cuda:
        output = weighted_average_wirelength_cuda_atomic.forward(
            pos.view(pos.numel()), pin2net_map, flat_netpin, netpin_start,
            net_weights, net_mask, inv_gamma)
    else:
        assert 0, "CPU version NOT IMPLEMENTED"
    ctx.pin2net_map = pin2net_map
    ctx.flat_netpin = flat_netpin
    ctx.netpin_start = netpin_start
    ctx.net_weights = net_weights
    ctx.net_mask = net_mask
    ctx.pin_mask = pin_mask
    ctx.inv_gamma = inv_gamma
    ctx.exp_xy = output[1]
    ctx.exp_nxy = output[2]
    ctx.exp_xy_sum = output[3]
    ctx.exp_nxy_sum = output[4]
    ctx.xyexp_xy_sum = output[5]
    ctx.xyexp_nxy_sum = output[6]
    ctx.pos = pos
    #if torch.isnan(ctx.exp_xy).any() or torch.isnan(ctx.exp_nxy).any() or torch.isnan(ctx.exp_xy_sum).any() or torch.isnan(ctx.exp_nxy_sum).any() or torch.isnan(output[0]).any():
    #    pdb.set_trace()
    if pos.is_cuda:
        torch.cuda.synchronize()
    logger.debug("wirelength forward %.3f ms" % ((time.time() - tt) * 1000))
    return output[0]
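
# Hedged usage sketch: an autograd Function with a forward() like the one above
# is normally invoked through .apply(). The class name
# `WeightedAverageWirelengthAtomicFunction` is assumed here for illustration and
# may differ in this repository; flat_netpin/netpin_start form a CSR-style
# net-to-pin map (netpin_start has #nets + 1 entries, the last equal to #pins).
def _example_wa_wirelength_call(pos, pin2net_map, flat_netpin, netpin_start,
                                net_weights, net_mask, pin_mask, gamma):
    inv_gamma = 1.0 / gamma  # this variant of the kernel takes 1/gamma directly
    wl = WeightedAverageWirelengthAtomicFunction.apply(  # assumed class name
        pos,            # concatenated x and y pin coordinates, length 2 * #pins
        pin2net_map,    # net id of each pin, length #pins
        flat_netpin,    # pin ids of all nets flattened, length #pins
        netpin_start,   # offsets into flat_netpin, length #nets + 1
        net_weights,    # per-net weight, length #nets
        net_mask,       # 1 if the net's wirelength should be counted
        pin_mask,       # 1 if the pin's gradient should be zeroed (e.g., fixed pins)
        inv_gamma)
    return wl           # a scalar; wl.backward() fills pos.grad if pos.requires_grad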