def main_lower(x_minus, x_plus, y_minus, y_plus, plot=False, num=0):
    """Compute lower-plane coefficients (a, b, c) for the box [x_minus, x_plus] x [y_minus, y_plus].

    Pipeline:
      1. ``binary_search_x1`` picks the anchor point ``x1``.
      2. ``find_x2`` returns the second point ``x2`` and a value ``z`` for that anchor.
      3. ``get_abc_lower`` turns (x1, x2, z) into the plane coefficients.
      4. ``utils.get_volume`` scores the resulting plane over the box.

    NOTE(review): presumably the plane is a*x + b*y + c; confirm against get_abc_lower.

    Args:
        x_minus, x_plus: bounds of the x interval.
        y_minus, y_plus: bounds of the y interval.
        plot: when True, draw the surface for element ``num`` via ``utils.plot_surface``.
        num: index of the element to plot.

    Returns:
        Tuple ``(a, b, c, volume, x1, x2)``.
    """
    bounds = (x_minus, x_plus, y_minus, y_plus)

    x1 = binary_search_x1(*bounds)
    x2, z = find_x2(*bounds, x1)
    a, b, c = get_abc_lower(x1, x2, z, *bounds)
    volume = utils.get_volume(a, b, c, *bounds)

    if plot:
        picked = [t[num] for t in (x_minus, x_plus, y_minus, y_plus, a, b, c)]
        utils.plot_surface(*picked)

    return a, b, c, volume, x1, x2
def main_upper(x_minus, x_plus, y_minus, y_plus, plot=False, num=0):
    """Compute upper-plane coefficients (a, b, c) for the box [x_minus, x_plus] x [y_minus, y_plus].

    Pipeline:
      1. ``binary_search_y1`` picks the anchor point ``y1``.
      2. ``find_y2`` returns the second point ``y2`` and a value ``z`` for that anchor.
      3. ``get_abc_upper`` turns (y1, y2, z) into the plane coefficients.
      4. ``utils.get_volume`` scores the resulting plane over the box.

    Args:
        x_minus, x_plus: bounds of the x interval.
        y_minus, y_plus: bounds of the y interval.
        plot: when True, draw the surface for element ``num`` via ``utils.plot_surface``.
        num: index of the element to plot.

    Returns:
        Tuple ``(a, b, c, volume, y1, y2)``.
    """
    bounds = (x_minus, x_plus, y_minus, y_plus)

    y1 = binary_search_y1(*bounds)
    y2, z = find_y2(*bounds, y1)
    a, b, c = get_abc_upper(y1, y2, z, *bounds)
    volume = utils.get_volume(a, b, c, *bounds)

    if plot:
        picked = [t[num] for t in (x_minus, x_plus, y_minus, y_plus, a, b, c)]
        utils.plot_surface(*picked)

    return a, b, c, volume, y1, y2
def main_lower(x_minus, x_plus, y_minus, y_plus, plot=False, num=0, print_info=True):
    """Lower plane for the 4th orthant, derived from ``third.main_upper``.

    The x interval is mirrored ([x_minus, x_plus] -> [-x_plus, -x_minus]).
    An upper plane is requested for that mirrored box. Then ``b`` and ``c``
    are negated while ``a`` is kept.

    NOTE(review): this presumably relies on the bounded function satisfying
    f(x, y) = -f(-x, y) (true for tanh(x) * sigmoid(y)). If
    a*x' + b*y + c bounds f from above on the mirrored box, then
    a*x - b*y - c bounds f from below on the original one. Confirm.

    Args:
        x_minus, x_plus, y_minus, y_plus: box bounds in the original coordinates.
        plot: when True, draw the surface for element ``num``.
        num: index of the element to plot.
        print_info: forwarded to ``third.main_upper``; also gates the banner print.

    Returns:
        Detached tensors ``(a, b, c)``.
    """
    if print_info:
        print('4th orthant lower: using third.main_upper function')

    # Mirror x: the new lower bound is -x_plus and the new upper bound is -x_minus.
    a, b_mirrored, c_mirrored = third.main_upper(-x_plus, -x_minus, y_minus, y_plus,
                                                 print_info=print_info)
    b, c = -b_mirrored, -c_mirrored

    if plot:
        picked = [t[num] for t in (x_minus, x_plus, y_minus, y_plus, a, b, c)]
        utils.plot_surface(*picked)

    return a.detach(), b.detach(), c.detach()
def validate(a_l, b_l, c_l, a_u, b_u, c_u, x_minus, x_plus, y_minus, y_plus,
             verify_and_modify_all=False, max_iter=100, plot=False, eps=1e-5,
             print_info=True):
    """Check lower/upper planes against the surface and absorb tiny violations.

    For each checked index ``n`` the planes are passed to ``plot_2_surface``,
    which returns ``hl_fl`` and ``hu_fu``. By their names and the log text,
    these are presumably (lower plane - function) and (upper plane - function)
    over sample points; confirm against plot_2_surface.

    Per index:
      * ``hl_fl.max() > eps`` or ``hu_fu.min() < -eps``: print diagnostics,
        plot the offending plane via ``plot_surface`` (regardless of ``plot``),
        and raise.
      * With ``verify_and_modify_all=True``, a violation within ``eps`` is
        absorbed by shifting the intercept by twice the violation.

    Which indices are checked:
      * ``verify_and_modify_all=True``: every index is checked exactly once
        (``max_iter`` is ignored).
      * Otherwise: ``max_iter`` indices are sampled at random and no
        coefficients are changed.

    NOTE: the flattened tensors are ``view(-1)`` views, so intercept
    corrections are written into the caller's ``c_l`` / ``c_u`` in place,
    as well as being returned.

    Returns:
        ``(c_l_new, c_u_new)``, each reshaped to its own input's shape.

    Raises:
        Exception: ``'lower plane fail'`` or ``'upper plane fail'``.
    """
    # view(-1) shares storage with the inputs, so the corrections below also
    # reach the caller's tensors. A .data.clone() copy was deliberately left out.
    a_l_new, b_l_new, c_l_new = a_l.view(-1), b_l.view(-1), c_l.view(-1)
    a_u_new, b_u_new, c_u_new = a_u.view(-1), b_u.view(-1), c_u.view(-1)
    x_minus_new, x_plus_new = x_minus.view(-1), x_plus.view(-1)
    y_minus_new, y_plus_new = y_minus.view(-1), y_plus.view(-1)

    N = a_l_new.size(0)
    if verify_and_modify_all:
        max_iter = N

    for i in range(max_iter):
        if verify_and_modify_all:
            n = i
        else:
            n = torch.randint(0, N, [1]).long()

        hl_fl, hu_fu = plot_2_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                                      a_l_new[n], b_l_new[n], c_l_new[n],
                                      a_u_new[n], b_u_new[n], c_u_new[n], plot=plot)
        # Reduce once; both values drive the log line, the checks and the corrections.
        hl_max = hl_fl.max()
        hu_min = hu_fu.min()

        if print_info:
            print(
                'tanh sigmoid iter: %d num: %d hl-f max %.6f mean %.6f hu-f min %.6f mean %.6f'
                % (i, n, hl_max, hl_fl.mean(), hu_min, hu_fu.mean()))

        if hl_max > eps:  # we want hl_fl.max() < 0
            print(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                  a_l_new[n], b_l_new[n], c_l_new[n], a_u_new[n], b_u_new[n], c_u_new[n])
            plot_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                         a_l_new[n], b_l_new[n], c_l_new[n])
            print('hl-f max', hl_max)
            raise Exception('lower plane fail')
        if hl_max > 0 and verify_and_modify_all:
            # Small violation: lower the intercept with a 2x safety margin.
            c_l_new[n] = c_l_new[n] - hl_max * 2

        if hu_min < -eps:  # we want hu_fu.min() > 0
            print(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                  a_l_new[n], b_l_new[n], c_l_new[n], a_u_new[n], b_u_new[n], c_u_new[n])
            plot_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                         a_u_new[n], b_u_new[n], c_u_new[n])
            print('hu-f min', hu_min)
            raise Exception('upper plane fail')
        if hu_min < 0 and verify_and_modify_all:
            # Small violation: raise the intercept with a 2x safety margin.
            c_u_new[n] = c_u_new[n] - hu_min * 2

    # Each result keeps its own input's shape (previously c_u used c_l.shape).
    return c_l_new.view(c_l.shape), c_u_new.view(c_u.shape)
if __name__ == '__main__':
    # Single-element example box for the lower plane.
    x_minus = torch.Tensor([0.9062])
    x_plus = torch.Tensor([0.9295])
    y_minus = torch.Tensor([0.1032])
    y_plus = torch.Tensor([5.3253])

    # Three independent (non-aliased) tensors, each tanh(x_minus) * sigmoid(y_minus).
    # They are kept as separate objects in case train_lower mutates them.
    # Presumably they are initial guesses for train_lower; confirm.
    z10, z20, z30 = [torch.tanh(x_minus) * torch.sigmoid(y_minus) for _ in range(3)]

    a, b, c = train_lower(z10, z20, z30, x_minus, x_plus, y_minus, y_plus,
                          first.qualification_loss_lower_standard, '1l',
                          lr=1e-2, max_iter=500)

    num = 0
    plot_surface(x_minus[num], x_plus[num], y_minus[num], y_plus[num],
                 a[num], b[num], c[num])
loss5 = torch.clamp(loss5, min=confidence) loss = loss1 + loss2 + loss4 + loss5 return loss, valid import train_activation_plane def main_upper(x_minus, x_plus, y_minus, y_plus, print_info = True): z10 = torch.tanh(x_plus) * torch.sigmoid(y_minus) z20 = torch.tanh(x_plus) * torch.sigmoid(y_minus) z30 = torch.tanh(x_plus) * torch.sigmoid(y_minus) a,b,c = train_activation_plane.train_upper(z10,z20,z30, x_minus, x_plus, y_minus, y_plus, qualification_loss, '23u', lr=1e-2, max_iter = 500, print_info = print_info) return a.detach(),b.detach(),c.detach() if __name__ == '__main__': x_minus = torch.Tensor([-2]) x_plus = torch.Tensor([-0.1]) y_minus = torch.Tensor([-2]) y_plus = torch.Tensor([2]) num = 0 a_u, b_u, c_u = main_upper(x_minus, x_plus, y_minus, y_plus, print_info = False) plot_surface(x_minus[num], x_plus[num],y_minus[num], y_plus[num], a_u[num],b_u[num],c_u[num])