def validate(a_l, b_l, c_l, a_u, b_u, c_u, x_minus, x_plus, y_minus, y_plus,
             verify_and_modify_all=False, max_iter=100, plot=False, eps=1e-5,
             print_info=True):
    """Check lower/upper plane coefficients entry by entry and optionally fix offsets.

    All ten tensor arguments are flattened with ``view(-1)`` so that a single
    index ``n`` selects one entry from each of them. For every inspected index
    the entries are passed to ``plot_2_surface``, which returns two tensors
    ``hl_fl`` and ``hu_fu``. According to the comments in the original code the
    desired outcome is ``hl_fl.max() < 0`` and ``hu_fu.min() > 0`` (presumably
    "lower plane minus function" and "upper plane minus function" -- confirm
    against ``plot_2_surface``).

    Args:
        a_l, b_l, c_l: coefficients of the lower plane.
        a_u, b_u, c_u: coefficients of the upper plane.
        x_minus, x_plus, y_minus, y_plus: range tensors forwarded unchanged to
            ``plot_2_surface`` / ``plot_surface``.
        verify_and_modify_all: when true, every index ``0 .. N-1`` is visited
            exactly once (``max_iter`` is overridden) and violations that are
            within ``eps`` are corrected by shifting ``c_l`` / ``c_u``. When
            false, ``max_iter`` indices are drawn with ``torch.randint`` and
            nothing is modified.
        max_iter: number of random indices to inspect in sampling mode.
        plot: forwarded to ``plot_2_surface`` as its ``plot`` keyword.
        eps: tolerance; a violation larger than this raises.
        print_info: print one summary line per inspected index.

    Returns:
        ``(c_l, c_u)`` -- both viewed with the original shape of ``c_l``.
        Because the flattened tensors are views, corrections are also written
        into the caller's ``c_l`` and ``c_u`` storage.

    Raises:
        Exception: ``'lower plane fail'`` if ``hl_fl.max() > eps`` or
            ``'upper plane fail'`` if ``hu_fu.min() < -eps`` for some index.
    """
    original_shape = c_l.shape

    # Flattened views (no copy): writing c_l_new[n] / c_u_new[n] below also
    # updates the tensors that were passed in.
    a_l_new = a_l.view(-1)
    b_l_new = b_l.view(-1)
    c_l_new = c_l.view(-1)
    a_u_new = a_u.view(-1)
    b_u_new = b_u.view(-1)
    c_u_new = c_u.view(-1)
    x_minus_new = x_minus.view(-1)
    x_plus_new = x_plus.view(-1)
    y_minus_new = y_minus.view(-1)
    y_plus_new = y_plus.view(-1)

    N = a_l_new.size(0)
    if N == 0:
        # Nothing to inspect. Without this guard the sampling mode would call
        # torch.randint(0, 0, ...) and fail on an empty range.
        return c_l_new.view(original_shape), c_u_new.view(original_shape)

    if verify_and_modify_all:
        max_iter = N

    # Argument order expected by plot_2_surface: ranges, lower plane, upper plane.
    flat = (x_minus_new, x_plus_new, y_minus_new, y_plus_new,
            a_l_new, b_l_new, c_l_new,
            a_u_new, b_u_new, c_u_new)

    for i in range(max_iter):
        if verify_and_modify_all:
            n = i
        else:
            n = torch.randint(0, N, [1])
            n = n.long()

        hl_fl, hu_fu = plot_2_surface(*(t[n] for t in flat), plot=plot)

        # Each reduction is needed several times below; compute it once.
        hl_max = hl_fl.max()
        hu_min = hu_fu.min()

        if print_info:
            print(
                'tanh sigmoid iter: %d num: %d hl-f max %.6f mean %.6f hu-f min %.6f mean %.6f'
                % (i, n, hl_max, hl_fl.mean(), hu_min, hu_fu.mean()))

        if hl_max > eps:  # we want hl_fl.max() < 0
            print(*(t[n] for t in flat))
            plot_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                         a_l_new[n], b_l_new[n], c_l_new[n])
            print('hl-f max', hl_max)
            raise Exception('lower plane fail')

        if hl_max > 0 and verify_and_modify_all:
            # Small violation (<= eps): lower the plane by twice the overshoot.
            c_l_new[n] = c_l_new[n] - hl_max * 2

        if hu_min < -eps:  # we want hu_fu.min() > 0
            print(*(t[n] for t in flat))
            plot_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                         a_u_new[n], b_u_new[n], c_u_new[n])
            print('hu-f min', hu_min)
            raise Exception('upper plane fail')

        if hu_min < 0 and verify_and_modify_all:
            # hu_min is negative here, so subtracting it raises the plane.
            c_u_new[n] = c_u_new[n] - hu_min * 2

    c_l_new = c_l_new.view(original_shape)
    c_u_new = c_u_new.view(original_shape)
    return c_l_new, c_u_new
# Demo run: call bound_tanh_sigmoid once on a single hand-picked set of
# ranges, time the call, and hand the resulting coefficients to
# plot_2_surface.
length = 3
# One-element tensors; presumably the lower/upper ends of the x and y ranges
# (x from -length to length, y from -1 to length) -- confirm against
# bound_tanh_sigmoid.
x_minus = torch.Tensor([-length])
x_plus = torch.Tensor([length])
y_minus = torch.Tensor([-1])
y_plus = torch.Tensor([length])
# Used below as a list index that selects entry 0 of every tensor.
num = [0]
# Only referenced by the commented-out random setup below.
device = torch.device('cpu')
# Alternative setup kept for reference: random ranges instead of the fixed
# ones above (it uses `num` as a size, so `num` would have to be changed too).
# x_minus = ((torch.rand(num, device=device) - 0.5) * 10)
# x_plus = (torch.rand(num, device=device)*5 + x_minus)
# y_minus = ((torch.rand(num, device=device)-0.5) * 10)
# y_plus = (torch.rand(num, device=device)*5 + y_minus)
print_info = False
# Wall-clock timing of the bound computation only (plotting is excluded).
start = time.time()
a_l, b_l, c_l, a_u, b_u, c_u = bound_tanh_sigmoid(x_minus, x_plus, y_minus, y_plus,
                                                  fine_tune_c=False, use_1D_line=False,
                                                  use_constant=False, print_info=print_info)
end = time.time()
# v1 and v2 are not used again in the part of the file visible here.
v1, v2 = plot_2_surface(x_minus[num], x_plus[num], y_minus[num], y_plus[num],
                        a_l[num], b_l[num], c_l[num], a_u[num], b_u[num], c_u[num])
# Optional randomized check of the computed planes, currently disabled:
# validate(a_l,b_l,c_l,a_u,b_u,c_u,x_minus, x_plus, y_minus, y_plus,
#          max_iter=100,plot=False, eps=1e-4, print_info = print_info)
print('time used:', end - start)
print_info=print_info) increase = raise_upper_plane(loss1, loss2, loss3, loss4, a_best, b_best, c_best, x_minus, x_plus, y_minus, y_plus) c_best = c_best + increase return a_best.detach(), b_best.detach(), c_best.detach() if __name__ == '__main__': x_minus = torch.Tensor([0.062]) x_plus = torch.Tensor([5]) y_minus = torch.Tensor([0.1032]) y_plus = torch.Tensor([5.3253]) print_info = False a, b, c = main_lower(x_minus, x_plus, y_minus, y_plus, print_info=print_info) a_best, b_best, c_best = main_upper(x_minus, x_plus, y_minus, y_plus, print_info=print_info) num = 0 v1, v2 = plot_2_surface(x_minus[num], x_plus[num], y_minus[num], y_plus[num], a[num], b[num], c[num], a_best[num], b_best[num], c_best[num])
lr=1e-2, max_iter=500, print_info=print_info) return a_upper.detach(), b_upper.detach(), c_upper.detach() if __name__ == '__main__': x_minus = torch.Tensor([-5.2]) x_plus = torch.Tensor([-0.1]) y_minus = torch.Tensor([0.1]) y_plus = torch.Tensor([5.2]) num = 0 print_info = False a_lower, b_lower, c_lower = main_lower(x_minus, x_plus, y_minus, y_plus, print_info=print_info) a_upper, b_upper, c_upper = main_upper(x_minus, x_plus, y_minus, y_plus, print_info=print_info) v1, v2 = plot_2_surface(x_minus[num], x_plus[num], y_minus[num], y_plus[num], a_lower[num], b_lower[num], c_lower[num], a_upper[num], b_upper[num], c_upper[num])
# NOTE(review): X_l, X_u, kl, ku, bl, bu and the *_minus / *_plus tensors are
# defined outside the part of the file visible here. X_l / X_u are presumably
# lower/upper bounds on tanh(x) over the x range -- confirm.

# Float masks: 1.0 where X_l (resp. X_u) is non-negative, 0.0 elsewhere.
I_l = (X_l>=0).float()
I_u = (X_u>=0).float()

# Linear bounds on the sigmoid used below:
#   k_l y + b_l <= sigmoid(y) <= k_u y + b_u
#
# Lower bound on tanh(x)sigmoid(y): scale the sigmoid line by X_l, choosing
# the line by the sign of X_l (a negative factor flips the inequality):
#   X_l*k_l y + X_l*b_l <= tanh(x)sigmoid(y), when X_l>=0
#   X_l*k_u y + X_l*b_u <= tanh(x)sigmoid(y), when X_l<0
# alpha_l is all zeros (presumably the x-coefficient of the plane -- confirm).
alpha_l = torch.zeros(x_minus.shape, device=x_minus.device)
beta_l = I_l * X_l * kl + (1-I_l) * X_l * ku
gamma_l = I_l * X_l * bl + (1-I_l) * X_l * bu

# Upper bound on tanh(x)sigmoid(y), built the same way from X_u:
#   tanh(x)sigmoid(y) <= X_u*k_u y + X_u*b_u, when X_u>=0
#   tanh(x)sigmoid(y) <= X_u*k_l y + X_u*b_l, when X_u<0
# The shape is taken from x_plus but the device from x_minus.
alpha_u = torch.zeros(x_plus.shape, device=x_minus.device)
beta_u = I_u * X_u * ku + (1-I_u) * X_u * kl
gamma_u = I_u * X_u * bu + (1-I_u) * X_u * bl

# Inspect a single entry; the two-component index implies the tensors are
# indexed along (at least) two dimensions here.
idx= (0,0)
# plot_surface(x_minus, x_plus,y_minus, y_plus, alpha_l,beta_l,gamma_l)
plot_2_surface(x_minus[idx], x_plus[idx],y_minus[idx], y_plus[idx],
               alpha_l[idx],beta_l[idx],gamma_l[idx],
               alpha_u[idx],beta_u[idx],gamma_u[idx],
               plot=True, num_points=20)