def main_lower(x_minus, x_plus, y_minus, y_plus, plot=False, num=0):
    """Compute lower-plane coefficients ``(a, b, c)`` over an x/y box.

    ``x1`` is located with ``binary_search_x1``; ``find_x2`` then yields a
    second point ``x2`` and a value ``z``; ``get_abc_lower`` turns these
    into plane coefficients. The enclosed volume is measured with
    ``utils.get_volume``.

    Args:
        x_minus, x_plus, y_minus, y_plus: box bounds (presumably tensors
            indexable by ``num`` — TODO confirm).
        plot: if True, draw the plane of element ``num`` with
            ``utils.plot_surface``.
        num: index of the element to plot.

    Returns:
        ``(a, b, c, volume, x1, x2)``.
    """
    bounds = (x_minus, x_plus, y_minus, y_plus)

    x1 = binary_search_x1(*bounds)
    x2, z = find_x2(*bounds, x1)
    a, b, c = get_abc_lower(x1, x2, z, *bounds)
    volume = utils.get_volume(a, b, c, *bounds)

    if plot:
        utils.plot_surface(*(bound[num] for bound in bounds),
                           a[num], b[num], c[num])

    return a, b, c, volume, x1, x2
def main_upper(x_minus, x_plus, y_minus, y_plus, plot=False, num=0):
    """Compute upper-plane coefficients ``(a, b, c)`` over an x/y box.

    Mirrors the lower-plane routine, but searches along y:
    ``binary_search_y1`` locates ``y1``, ``find_y2`` yields ``y2`` and ``z``,
    and ``get_abc_upper`` builds the coefficients. The enclosed volume is
    measured with ``utils.get_volume``.

    Args:
        x_minus, x_plus, y_minus, y_plus: box bounds (presumably tensors
            indexable by ``num`` — TODO confirm).
        plot: if True, draw the plane of element ``num``.
        num: index of the element to plot.

    Returns:
        ``(a, b, c, volume, y1, y2)``.
    """
    box = [x_minus, x_plus, y_minus, y_plus]

    y1 = binary_search_y1(*box)
    y2, z = find_y2(*box, y1)
    a, b, c = get_abc_upper(y1, y2, z, *box)
    volume = utils.get_volume(a, b, c, *box)

    if plot:
        corners = [bound[num] for bound in box]
        utils.plot_surface(*corners, a[num], b[num], c[num])

    return a, b, c, volume, y1, y2
def main_lower(x_minus,
               x_plus,
               y_minus,
               y_plus,
               plot=False,
               num=0,
               print_info=True):
    """4th-orthant lower plane, derived from ``third.main_upper``.

    The x-interval is mirrored (``[x_minus, x_plus]`` becomes
    ``[-x_plus, -x_minus]``), an upper plane is fitted on the mirrored box,
    and its ``b`` and ``c`` coefficients are negated while ``a`` is kept.
    NOTE(review): this relies on the bounded function flipping sign under
    x -> -x (e.g. tanh(x) * sigmoid(y)) — confirm against ``third``.

    Args:
        x_minus, x_plus, y_minus, y_plus: box bounds.
        plot: if True, draw the resulting plane for element ``num``.
        num: index of the element to plot.
        print_info: print a banner and forward verbosity to
            ``third.main_upper``.

    Returns:
        Detached ``(a, b, c)``.
    """
    if print_info:
        print('4th orthant lower: using third.main_upper function')

    # Arguments are evaluated left to right: -x_plus first, then -x_minus.
    a, b, c = third.main_upper(-x_plus, -x_minus, y_minus, y_plus,
                               print_info=print_info)
    b, c = -b, -c

    if plot:
        utils.plot_surface(x_minus[num], x_plus[num], y_minus[num],
                           y_plus[num], a[num], b[num], c[num])

    return a.detach(), b.detach(), c.detach()
# Example #4
# 0
def validate(a_l,
             b_l,
             c_l,
             a_u,
             b_u,
             c_u,
             x_minus,
             x_plus,
             y_minus,
             y_plus,
             verify_and_modify_all=False,
             max_iter=100,
             plot=False,
             eps=1e-5,
             print_info=True):
    """Check lower/upper bounding planes and optionally nudge their offsets.

    For each checked element ``n``, ``plot_2_surface`` returns
    ``(hl_fl, hu_fu)``. These are presumably "lower plane minus function"
    and "upper plane minus function" sampled over the box (TODO confirm).
    A lower plane passes when ``hl_fl.max() <= eps``; an upper plane passes
    when ``hu_fu.min() >= -eps``.

    Args:
        a_l, b_l, c_l: lower-plane coefficients (same number of elements).
        a_u, b_u, c_u: upper-plane coefficients.
        x_minus, x_plus, y_minus, y_plus: box bounds, one per element.
        verify_and_modify_all: check every element in order (``max_iter``
            is ignored). Any violation within ``eps`` is repaired by moving
            ``c_l`` down / ``c_u`` up by twice the violation. Otherwise,
            ``max_iter`` elements are sampled at random (with replacement)
            and only checked.
        max_iter: number of random samples when not checking all elements.
        plot: forwarded to ``plot_2_surface``.
        eps: tolerance before a violation is treated as a failure.
        print_info: print per-iteration statistics.

    Returns:
        ``(c_l_new, c_u_new)`` viewed with ``c_l.shape``. They share storage
        with ``c_l`` / ``c_u``, so repairs are also visible in the inputs.

    Raises:
        Exception: ``'lower plane fail'`` / ``'upper plane fail'`` when a
            violation exceeds ``eps``. The offending element is printed and
            plotted first.
    """
    original_shape = c_l.shape

    # ``view`` (not a clone) on purpose: the repairs below write through
    # to the caller's coefficient tensors.
    a_l_new = a_l.view(-1)
    b_l_new = b_l.view(-1)
    c_l_new = c_l.view(-1)

    a_u_new = a_u.view(-1)
    b_u_new = b_u.view(-1)
    c_u_new = c_u.view(-1)

    x_minus_new = x_minus.view(-1)
    x_plus_new = x_plus.view(-1)
    y_minus_new = y_minus.view(-1)
    y_plus_new = y_plus.view(-1)

    N = a_l_new.size(0)

    if verify_and_modify_all:
        max_iter = N
    elif N == 0:
        # Nothing to sample; torch.randint(0, 0, ...) would raise.
        max_iter = 0

    for i in range(max_iter):
        if verify_and_modify_all:
            n = i
        else:
            n = torch.randint(0, N, [1]).long()

        hl_fl, hu_fu = plot_2_surface(x_minus_new[n],
                                      x_plus_new[n],
                                      y_minus_new[n],
                                      y_plus_new[n],
                                      a_l_new[n],
                                      b_l_new[n],
                                      c_l_new[n],
                                      a_u_new[n],
                                      b_u_new[n],
                                      c_u_new[n],
                                      plot=plot)
        # Each reduction is computed once and reused below.
        hl_max = hl_fl.max()
        hu_min = hu_fu.min()

        if print_info:
            print(
                'tanh sigmoid iter: %d num: %d hl-f max %.6f mean %.6f hu-f min %.6f mean %.6f'
                % (i, n, hl_max, hl_fl.mean(), hu_min, hu_fu.mean()))

        if hl_max > eps:  # we want hl_fl.max() < 0
            print(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                  a_l_new[n], b_l_new[n], c_l_new[n], a_u_new[n], b_u_new[n],
                  c_u_new[n])
            plot_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n],
                         y_plus_new[n], a_l_new[n], b_l_new[n], c_l_new[n])
            print('hl-f max', hl_max)
            raise Exception('lower plane fail')

        if hl_max > 0 and verify_and_modify_all:
            # Small violation (0 < hl_max <= eps): push the lower plane down.
            c_l_new[n] = c_l_new[n] - hl_max * 2

        if hu_min < -eps:  # we want hu_fu.min() > 0
            print(x_minus_new[n], x_plus_new[n], y_minus_new[n], y_plus_new[n],
                  a_l_new[n], b_l_new[n], c_l_new[n], a_u_new[n], b_u_new[n],
                  c_u_new[n])
            plot_surface(x_minus_new[n], x_plus_new[n], y_minus_new[n],
                         y_plus_new[n], a_u_new[n], b_u_new[n], c_u_new[n])
            print('hu-f min', hu_min)
            raise Exception('upper plane fail')

        if hu_min < 0 and verify_and_modify_all:
            # Small violation (-eps <= hu_min < 0): lift the upper plane.
            c_u_new[n] = c_u_new[n] - hu_min * 2

    c_l_new = c_l_new.view(original_shape)
    c_u_new = c_u_new.view(original_shape)
    return c_l_new, c_u_new
if __name__ == '__main__':
    # Script demo: train a lower plane (mode '1l') on a single box and plot it.
    # Alternative two-element inputs:
    # x_minus = torch.Tensor([0,0.3])
    # x_plus = torch.Tensor([0.1,2])
    # y_minus = torch.Tensor([0,0.1])
    # y_plus = torch.Tensor([0.1,1])
    x_minus = torch.Tensor([0.9062])
    x_plus = torch.Tensor([0.9295])
    y_minus = torch.Tensor([0.1032])
    y_plus = torch.Tensor([5.3253])

    # All three initial z values equal tanh(x_minus) * sigmoid(y_minus);
    # compute once and pass independent copies.
    z0 = torch.tanh(x_minus) * torch.sigmoid(y_minus)
    z10, z20, z30 = z0, z0.clone(), z0.clone()
    a, b, c = train_lower(z10,
                          z20,
                          z30,
                          x_minus,
                          x_plus,
                          y_minus,
                          y_plus,
                          first.qualification_loss_lower_standard,
                          '1l',
                          lr=1e-2,
                          max_iter=500)

    num = 0
    plot_surface(x_minus[num], x_plus[num], y_minus[num], y_plus[num], a[num],
                 b[num], c[num])

    # NOTE(review): the three lines below were stray residue from another
    # function. They used ``return`` at module level (a SyntaxError) and
    # referenced undefined names (loss1/loss2/loss4/loss5, confidence,
    # valid). They are kept commented out for reference only:
    # loss5 = torch.clamp(loss5, min=confidence)
    # loss = loss1 + loss2 + loss4 + loss5
    # return loss, valid


import train_activation_plane
def main_upper(x_minus, x_plus, y_minus, y_plus, print_info=True):
    """Train an upper plane (mode ``'23u'``) with ``qualification_loss``.

    The three initial z values handed to
    ``train_activation_plane.train_upper`` are each
    ``tanh(x_plus) * sigmoid(y_minus)``. They are built as separate tensors.

    Returns:
        Detached ``(a, b, c)``.
    """
    def initial_z():
        return torch.tanh(x_plus) * torch.sigmoid(y_minus)

    a, b, c = train_activation_plane.train_upper(
        initial_z(), initial_z(), initial_z(),
        x_minus, x_plus, y_minus, y_plus,
        qualification_loss, '23u',
        lr=1e-2, max_iter=500, print_info=print_info)
    return tuple(coef.detach() for coef in (a, b, c))



if __name__ == '__main__':
    # Demo: fit the upper plane on x in [-2, -0.1], y in [-2, 2] and plot
    # element 0.
    x_minus, x_plus = torch.Tensor([-2]), torch.Tensor([-0.1])
    y_minus, y_plus = torch.Tensor([-2]), torch.Tensor([2])

    num = 0
    a_u, b_u, c_u = main_upper(x_minus, x_plus, y_minus, y_plus,
                               print_info=False)

    plot_surface(*(t[num] for t in (x_minus, x_plus, y_minus, y_plus,
                                    a_u, b_u, c_u)))