import torch

def line_bounding_2D_activation(x_minus, x_plus, y_minus, y_plus, tanh=True):
    # if tanh is True, bound tanh(x) * sigmoid(y)
    # else, bound x * sigmoid(y)
    kl, bl, ku, bu = getConvenientGeneralActivationBound(
                        y_minus, y_plus, 'sigmoid')
    
    if tanh:
        X_l = torch.tanh(x_minus)
        X_u = torch.tanh(x_plus)
    else:
        X_l = x_minus
        X_u = x_plus
    
    I_l = (X_l>=0).float()
    I_u = (X_u>=0).float()
    
    #k_l*y + b_l <= sigmoid(y) <= k_u*y + b_u
    #X_l*(k_l*y + b_l) <= tanh(x)*sigmoid(y), when X_l >= 0
    #X_l*(k_u*y + b_u) <= tanh(x)*sigmoid(y), when X_l < 0
    
    alpha_l = torch.zeros(x_minus.shape, device=x_minus.device)
    beta_l = I_l * X_l * kl + (1-I_l) * X_l * ku
    gamma_l = I_l * X_l * bl + (1-I_l) * X_l * bu
    
    #tanh(x)*sigmoid(y) <= X_u*(k_u*y + b_u), when X_u >= 0
    #tanh(x)*sigmoid(y) <= X_u*(k_l*y + b_l), when X_u < 0
    
    alpha_u = torch.zeros(x_plus.shape, device=x_minus.device)
    beta_u = I_u * X_u * ku + (1-I_u) * X_u * kl
    gamma_u = I_u * X_u * bu + (1-I_u) * X_u * bl
    return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u
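
# A quick Monte-Carlo sanity check of the planes returned above (a sketch of
# mine, not part of the original code; it relies on the repository's
# getConvenientGeneralActivationBound through the call inside the function).
# Following the derivation in the comments, alpha_l and alpha_u are zero, so
# the planes reduce to beta*y + gamma and should bracket tanh(x)*sigmoid(y)
# everywhere on the input box.
def _check_line_bounding_2D_activation(num_samples=1000):
    x_minus = torch.rand(2, 3) - 0.5
    x_plus = x_minus + 0.5
    y_minus = torch.rand(2, 3) - 0.5
    y_plus = y_minus + 0.5
    _, beta_l, gamma_l, _, beta_u, gamma_u = line_bounding_2D_activation(
        x_minus, x_plus, y_minus, y_plus, tanh=True)
    for _ in range(num_samples):
        x = x_minus + (x_plus - x_minus) * torch.rand_like(x_minus)
        y = y_minus + (y_plus - y_minus) * torch.rand_like(y_minus)
        z = torch.tanh(x) * torch.sigmoid(y)
        assert (beta_l * y + gamma_l <= z + 1e-6).all()
        assert (z <= beta_u * y + gamma_u + 1e-6).all()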
Example 2
    def get_hfc(self, m):
        #compute hfc at time step m
        #m can range from 1 to seq_len
        #bound c[:,m-1,:]*sigmoid(yf[:,m,:])
        if m > 1:
            b_l, a_l, c_l, b_u, a_u, c_u = x_sigmoid.main(
                self.c_l[m - 1 - 1],
                self.c_u[m - 1 - 1],
                self.yf_l[m - 1],
                self.yf_u[m - 1],
                use_1D_line=self.use_1D_line,
                use_constant=self.use_constant,
                print_info=self.print_info)
            self.alpha_l_fc[m - 1] = a_l.detach()
            self.alpha_u_fc[m - 1] = a_u.detach()
            self.beta_l_fc[m - 1] = b_l.detach()
            self.beta_u_fc[m - 1] = b_u.detach()
            self.gamma_l_fc[m - 1] = c_l.detach()
            self.gamma_u_fc[m - 1] = c_u.detach()
            return a_l, b_l, c_l, a_u, b_u, c_u
        if m == 1:
            # bound c0 * sigmoid(yf1)
            zeros = torch.zeros(self.yf_l[m - 1].shape, device=self.device)
            if self.c0 is None:
                self.alpha_l_fc[m - 1] = zeros.data.clone()
                self.alpha_u_fc[m - 1] = zeros.data.clone()
                self.beta_l_fc[m - 1] = zeros.data.clone()
                self.beta_u_fc[m - 1] = zeros.data.clone()
                self.gamma_l_fc[m - 1] = zeros.data.clone()
                self.gamma_u_fc[m - 1] = zeros.data.clone()
            else:
                # bound c0 * sigmoid(yf1)
                #c0 is a constant, so we only need to bound the 1-D sigmoid
                #the bound has the form alpha * yf + beta * c + gamma
                #kl * yf1 + bl <= sigmoid(yf1) <= ku * yf1 + bu
                #if c0_i >= 0:
                #c0_i (kl_i * yf1_i + bl_i) <= c0_i sigmoid(yf1)_i <= c0_i (ku_i * yf1_i + bu_i)
                #if c0_i < 0:
                #c0_i (ku_i * yf1_i + bu_i) <= c0_i sigmoid(yf1)_i <= c0_i (kl_i * yf1_i + bl_i)
                kl, bl, ku, bu = getConvenientGeneralActivationBound(
                    self.yf_l[m - 1], self.yf_u[m - 1], 'sigmoid')
                I = (self.c0 >= 0).float()
                KU = I * ku + (1 - I) * kl
                BU = I * bu + (1 - I) * bl
                KL = (1 - I) * ku + I * kl
                BL = (1 - I) * bu + I * bl

                self.alpha_u_fc[m - 1] = self.c0 * KU
                self.gamma_u_fc[m - 1] = self.c0 * BU
                self.beta_u_fc[m - 1] = zeros.data.clone()

                self.alpha_l_fc[m - 1] = self.c0 * KL
                self.gamma_l_fc[m - 1] = self.c0 * BL
                self.beta_l_fc[m - 1] = zeros.data.clone()
        return 0
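
# getConvenientGeneralActivationBound is used throughout these snippets but
# not shown. As a point of reference only, here is a minimal stand-in with the
# same (kl, bl, ku, bu) interface under my own simplifying assumption: since
# sigmoid and tanh are monotonically increasing, the constant lines kl = ku = 0,
# bl = act(l), bu = act(u) satisfy kl*y + bl <= act(y) <= ku*y + bu on [l, u].
# The repository's real implementation is tighter (the later snippets show it
# works with tangent lines).
def constant_activation_bound_sketch(l, u, activation='sigmoid'):
    act = torch.sigmoid if activation == 'sigmoid' else torch.tanh
    kl = torch.zeros_like(l)
    ku = torch.zeros_like(u)
    bl = act(l)  # lower constant bound, valid because act is increasing
    bu = act(u)  # upper constant bound
    return kl, bl, ku, bu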
Example 3
    def getLastLayerBound(self,
                          eps,
                          p,
                          x=None,
                          clearIntermediateVariables=False):
        #eps could be a real number, or a tensor of size N
        if self.x is None and x is None:
            raise Exception(
                'You must first attach data to the net or feed data to this function'
            )
        if self.W is None or self.b is None:
            self.extractWeight()

        if x is None:
            x = self.x

        for k in range(1, self.num_layers + 1):
            # k from 1 to self.num_layers
            yL, yU = self.compute2sideBound(eps, p, k, x=x)
            self.l[k] = yL.detach()
            self.u[k] = yU.detach()
            print('yU-yL', (yU - yL).mean())
            #self.l[k] and self.u[k] are stored here for use by later layers
            if k < self.num_layers:
                kl, bl, ku, bu = get_bound.getConvenientGeneralActivationBound(
                    self.l[k], self.u[k], self.activation, use_constant=False)
                self.kl[k] = kl.detach()
                self.ku[k] = ku.detach()
                self.bl[k] = bl.detach()
                self.bu[k] = bu.detach()
        if clearIntermediateVariables:
            #clear l[k] and u[k] to release memory
            self.clear_intermediate_variables()
        return yL, yU
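
# clear_intermediate_variables is called above but not shown in this snippet.
# A minimal sketch of what it could look like, based only on the comment
# "clear l[k] and u[k] to release memory"; the actual method may clear more
# state than this.
def clear_intermediate_variables(self):
    for k in range(1, self.num_layers + 1):
        self.l[k] = None
        self.u[k] = None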
Example 4
    def optimize_kl_neuronwise(self, v, eps, p, x, num_neurons, max_iter = 100, print_loss=True,
                                patience=5, acc=1e-2, init='ori'):
        # optimize the slopes of the lower bounding lines of h1_L/U, ..., h(v-1)_L/U for yv
        # to compute a tighter bound of the v-th layer; v ranges from 1 to num_layers
        # x should be a tensor of size (batch, input_dimension)
        # if this is the last layer and the gx0 trick is applied,
        # we should instead use optimize_kl_neuronwise_for_last_layer_with_gx0_trick_multi_sample
        
        if v==1:
            yL_opti,yU_opti= self.compute2sideBound(eps, p, v, x=x)
        else:
            kl_ori = []
            idx = self.kl_idx[1:v]
            for k in range(1,v):
                kl_ori.append(self.kl[k].clone().detach())
                kl_ori[k-1].requires_grad = False

            num_neuron = self.b[v].shape[0]
            batch = x.shape[0]
            yL_opti = torch.zeros(batch, num_neuron, device=x.device)
            W_v = self.W[v].detach().clone()
            b_v = self.b[v].detach().clone()
            
            for j in range(num_neuron): # for every neuron in this layer, we optimize over it for all images in this batch at once 
                self.W[v] = W_v[j:j+1,:]
                self.b[v] = b_v[j:j+1]

                # init optimization variables
                kl = []
                for k in range(1,v):
                    if init == 'ori':
                        kl.append(kl_ori[k-1].clone().detach())
                    elif init == 'rand':
                        kl.append(torch.rand_like(kl_ori[k-1]))
                    else:
                        raise Exception('%s initialization not supported' % init)
                    kl[k-1].requires_grad = True

                optimizer = optim.Adam(kl, lr = 1e-1)
                stopper = EarlyStop(patience, acc=acc)

                for i in range(max_iter):
                    for k in range(1,v):
                        # self.kl[k] = torch.clamp(kl[k-1], min=0, max=1) * idx[k-1] + kl_ori[k-1] * (1-idx[k-1])
                        kl[k-1].data.clamp_(min=0, max=1)
                        self.kl[k] = kl[k-1] * idx[k-1] + kl_ori[k-1] * (1-idx[k-1])
                    
                    yL,_ = self.compute2sideBound(eps, p, v, x=x) # shape (batch,1)
                    
                    loss1 = -yL
                    loss1 = loss1.mean()
                    if stopper.should_stop(loss1.detach().cpu().item()):
                        break
                    optimizer.zero_grad()
                    # pdb.set_trace()
                    loss1.backward()
                    optimizer.step()
                    if print_loss:
                        print('neuron %d/%d step %d yL mean %.5f' % (j+1, num_neuron, i+1, -loss1))
                
                yL_opti[:,j] = yL.squeeze(1)
            

            yU_opti = torch.zeros(batch, num_neuron, device=x.device)
            for j in range(num_neuron):
                self.W[v] = W_v[j:j+1,:]
                self.b[v] = b_v[j:j+1]
                
                # init optimization variables
                kl = []
                for k in range(1,v):
                    if init == 'ori':
                        kl.append(kl_ori[k-1].clone().detach())
                    elif init == 'rand':
                        kl.append(torch.rand_like(kl_ori[k-1]))
                    else:
                        raise Exception('%s initialization not supported' % init)
                    kl[k-1].requires_grad = True

                optimizer = optim.Adam(kl, lr = 1e-1)
                stopper = EarlyStop(patience, acc=acc)

                for i in range(max_iter):
                    for k in range(1,v):
                        # self.kl[k] = torch.clamp(kl[k-1], min=0, max=1) * idx[k-1] + kl_ori[k-1] * (1-idx[k-1])
                        kl[k-1].data.clamp_(min=0, max=1)
                        self.kl[k] = kl[k-1] * idx[k-1] + kl_ori[k-1] * (1-idx[k-1])
                    
                    _,yU = self.compute2sideBound(eps, p, v, x=x) # shape (batch,1)
                    
                    loss2 = yU
                    # print('loss2 neuron %d/%d step %d \n' % (j+1, num_neuron, i+1), loss2)
                    loss2 = loss2.mean()
                    if stopper.should_stop(loss2.detach().cpu().item()):
                        break
                    optimizer.zero_grad()
                    loss2.backward()
                    optimizer.step()
                    if print_loss:
                        print('neuron %d/%d step %d yU mean %.5f' % (j+1, num_neuron, i+1, loss2))

                yU_opti[:,j] = yU.squeeze(1)
            
            
            self.W[v] = W_v.detach()
            self.b[v] = b_v.detach()
        
        if v == 1: 
            print('Layer %d: yU-yL mean: %.3f' % (v,(yU_opti-yL_opti).mean()))
        else:
            print('Layer %d: yU-yL mean: %.3f' % (v,(yU_opti-yL_opti).mean()), 
                'optimized lines portion:',[round(index.mean().item()*100,1) for index in idx])
            # print([index.mean().item() for index in idx])

        self.l[v] = yL_opti.detach()
        self.u[v] = yU_opti.detach()
        self.kl_idx[v] = ((yL_opti<0) * (yU_opti>0)).float().detach()
        kl, bl, ku, bu = get_bound.getConvenientGeneralActivationBound(
                                self.l[v], self.u[v], self.activation, use_constant=False)
        self.kl[v] = kl.detach()
        self.ku[v] = ku.detach()
        self.bl[v] = bl.detach()
        self.bu[v] = bu.detach()

        for k in range(1,v):
            self.kl[k] = kl_ori[k-1].clone().detach()
        return yL_opti, yU_opti
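
# EarlyStop is used by the optimization loops above and below but is not shown
# in this snippet. A minimal sketch that matches the call pattern
# EarlyStop(patience, acc=acc) / stopper.should_stop(loss): stop once the loss
# has failed to improve by more than acc for patience consecutive steps. The
# repository's actual stopping rule may differ.
class EarlyStop:
    def __init__(self, patience, acc=1e-2):
        self.patience = patience
        self.acc = acc
        self.best = float('inf')
        self.num_bad_steps = 0

    def should_stop(self, loss):
        if loss < self.best - self.acc:
            # significant improvement: reset the counter
            self.best = loss
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        return self.num_bad_steps >= self.patience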
    def optimize_k_neuronwise(self,
                              v,
                              eps,
                              p,
                              x,
                              num_neurons,
                              max_iter=100,
                              print_loss=True,
                              patience=5,
                              acc=1e-2,
                              lr=1e-1,
                              init='middle'):
        # optimize the lower/upper bounding lines of h1_L/U, ..., h(v-1)_L/U for yv
        # to compute a tighter bound of the v-th layer; v ranges from 1 to num_layers
        # x should be a tensor of size (batch, input_dimension)
        # if this is the last layer and the gx0 trick is applied,
        # we should instead use optimize_k_neuronwise_for_last_layer_with_gx0_trick_multi_sample

        if v == 1:
            yL_opti, yU_opti = self.compute2sideBound(eps, p, v, x=x)
        else:
            kl_ori = []
            bl_ori = []
            ku_ori = []
            bu_ori = []
            for k in range(1, v):
                kl_ori.append(self.kl[k].clone().detach())
                bl_ori.append(self.bl[k].clone().detach())
                ku_ori.append(self.ku[k].clone().detach())
                bu_ori.append(self.bu[k].clone().detach())
                kl_ori[k - 1].requires_grad = False
                bl_ori[k - 1].requires_grad = False
                ku_ori[k - 1].requires_grad = False
                bu_ori[k - 1].requires_grad = False

            num_neuron = self.b[v].shape[0]
            batch = x.shape[0]
            yL_opti = torch.zeros(batch, num_neuron, device=x.device)
            W_v = self.W[v].detach().clone()
            b_v = self.b[v].detach().clone()

            # for every neuron in this layer, we optimize over it for all images in this batch at once
            for j in range(num_neuron):
                self.W[v] = W_v[j:j + 1, :]
                self.b[v] = b_v[j:j + 1]

                # init optimization variables
                dl = []
                du = []
                for k in range(1, v):
                    if init == 'ori':
                        dl.append(self.sl[k].clone().detach())
                        dl[k - 1].requires_grad = True
                        du.append(self.su[k].clone().detach())
                        du[k - 1].requires_grad = True
                    elif init == 'middle':
                        dl.append(((self.dl_lower[k] + self.dl_upper[k]) /
                                   2).detach())
                        dl[k - 1].requires_grad = True
                        du.append(((self.du_lower[k] + self.du_upper[k]) /
                                   2).detach())
                        du[k - 1].requires_grad = True
                    elif init == 'rand':
                        dl.append(((self.dl_upper[k] - self.dl_lower[k]) *
                                   torch.rand_like(self.dl_lower[k]) +
                                   self.dl_lower[k]).detach())
                        dl[k - 1].requires_grad = True
                        du.append(((self.du_upper[k] - self.du_lower[k]) *
                                   torch.rand_like(self.du_lower[k]) +
                                   self.du_lower[k]).detach())
                        du[k - 1].requires_grad = True
                    else:
                        raise Exception('%s initialization not supported' %
                                        init)

                optimizer = optim.Adam(dl + du, lr=lr)
                stopper = EarlyStop(patience, acc=acc)

                for i in range(max_iter):
                    for k in range(1, v):
                        dl[k - 1].data = torch.max(
                            self.dl_lower[k],
                            torch.min(dl[k - 1], self.dl_upper[k]))
                        kl_temp, bl_temp = get_tangent_line_short(
                            dl[k - 1], self.activation)
                        self.kl[k] = self.valid_l[k] * kl_temp + (
                            1 - self.valid_l[k]) * kl_ori[k - 1]
                        self.bl[k] = self.valid_l[k] * bl_temp + (
                            1 - self.valid_l[k]) * bl_ori[k - 1]

                        du[k - 1].data = torch.max(
                            self.du_lower[k],
                            torch.min(du[k - 1], self.du_upper[k]))
                        ku_temp, bu_temp = get_tangent_line_short(
                            du[k - 1], self.activation)
                        self.ku[k] = self.valid_u[k] * ku_temp + (
                            1 - self.valid_u[k]) * ku_ori[k - 1]
                        self.bu[k] = self.valid_u[k] * bu_temp + (
                            1 - self.valid_u[k]) * bu_ori[k - 1]

                    yL, _ = self.compute2sideBound(eps, p, v,
                                                   x=x)  # shape (batch,1)

                    loss1 = -yL
                    loss1 = loss1.mean()
                    if stopper.should_stop(loss1.detach().cpu().item()):
                        break
                    optimizer.zero_grad()
                    # pdb.set_trace()
                    loss1.backward()
                    optimizer.step()
                    if print_loss:
                        print('neuron %d/%d step %d yL mean %.5f' %
                              (j + 1, num_neuron, i + 1, -loss1))

                yL_opti[:, j] = yL.squeeze(1)

            yU_opti = torch.zeros(batch, num_neuron, device=x.device)
            for j in range(num_neuron):
                self.W[v] = W_v[j:j + 1, :]
                self.b[v] = b_v[j:j + 1]

                # init optimization variables
                dl = []
                du = []
                for k in range(1, v):
                    if init == 'ori':
                        dl.append(self.sl[k].clone().detach())
                        dl[k - 1].requires_grad = True
                        du.append(self.su[k].clone().detach())
                        du[k - 1].requires_grad = True
                    elif init == 'middle':
                        dl.append(((self.dl_lower[k] + self.dl_upper[k]) /
                                   2).detach())
                        dl[k - 1].requires_grad = True
                        du.append(((self.du_lower[k] + self.du_upper[k]) /
                                   2).detach())
                        du[k - 1].requires_grad = True
                    elif init == 'rand':
                        dl.append(((self.dl_upper[k] - self.dl_lower[k]) *
                                   torch.rand_like(self.dl_lower[k]) +
                                   self.dl_lower[k]).detach())
                        dl[k - 1].requires_grad = True
                        du.append(((self.du_upper[k] - self.du_lower[k]) *
                                   torch.rand_like(self.du_lower[k]) +
                                   self.du_lower[k]).detach())
                        du[k - 1].requires_grad = True
                    else:
                        raise Exception('%s initialization not supported' %
                                        init)

                optimizer = optim.Adam(dl + du, lr=lr)
                stopper = EarlyStop(patience, acc=acc)

                for i in range(max_iter):
                    for k in range(1, v):
                        dl[k - 1].data = torch.max(
                            self.dl_lower[k],
                            torch.min(dl[k - 1], self.dl_upper[k]))
                        kl_temp, bl_temp = get_tangent_line_short(
                            dl[k - 1], self.activation)
                        self.kl[k] = self.valid_l[k] * kl_temp + (
                            1 - self.valid_l[k]) * kl_ori[k - 1]
                        self.bl[k] = self.valid_l[k] * bl_temp + (
                            1 - self.valid_l[k]) * bl_ori[k - 1]

                        du[k - 1].data = torch.max(
                            self.du_lower[k],
                            torch.min(du[k - 1], self.du_upper[k]))
                        ku_temp, bu_temp = get_tangent_line_short(
                            du[k - 1], self.activation)
                        self.ku[k] = self.valid_u[k] * ku_temp + (
                            1 - self.valid_u[k]) * ku_ori[k - 1]
                        self.bu[k] = self.valid_u[k] * bu_temp + (
                            1 - self.valid_u[k]) * bu_ori[k - 1]

                    _, yU = self.compute2sideBound(eps, p, v,
                                                   x=x)  # shape (batch,1)

                    loss2 = yU
                    # print('loss2 neuron %d/%d step %d \n' % (j+1, num_neuron, i+1), loss2)
                    loss2 = loss2.mean()
                    if stopper.should_stop(loss2.detach().cpu().item()):
                        break
                    optimizer.zero_grad()
                    loss2.backward()
                    optimizer.step()
                    if print_loss:
                        print('neuron %d/%d step %d yU mean %.5f' %
                              (j + 1, num_neuron, i + 1, loss2))

                yU_opti[:, j] = yU.squeeze(1)

            self.W[v] = W_v.detach()
            self.b[v] = b_v.detach()

        if v == 1:
            print('Layer %d: yU-yL mean: %.3f' % (v,
                                                  (yU_opti - yL_opti).mean()))
        else:
            print('Layer %d: yU-yL mean: %.3f' % (v,
                                                  (yU_opti - yL_opti).mean()))
            print('optimized lower lines portion:', [
                round(index.mean().item() * 100, 1)
                for index in self.valid_l[1:v]
            ])
            print('optimized upper lines portion:', [
                round(index.mean().item() * 100, 1)
                for index in self.valid_u[1:v]
            ])
            # print([index.mean().item() for index in idx])

        self.l[v] = yL_opti.detach()
        self.u[v] = yU_opti.detach()
        with torch.no_grad():
            kl, bl, ku, bu, sl, sl_valid, su, su_valid = getConvenientGeneralActivationBound(
                self.l[v],
                self.u[v],
                self.activation,
                use_constant=False,
                remain_tangent_line_info=True)
        self.kl[v] = kl.detach()
        self.ku[v] = ku.detach()
        self.bl[v] = bl.detach()
        self.bu[v] = bu.detach()
        self.sl[v] = sl.detach()
        self.valid_l[v] = sl_valid.detach()
        self.su[v] = su.detach()
        self.valid_u[v] = su_valid.detach()

        idx = ((self.l[v] < 0) * (self.u[v] > 0)).detach().float()

        self.dl_lower[v] = self.l[v].detach().clone()
        self.dl_upper[v] = (idx * self.sl[v] + (1 - idx) * self.u[v]).detach()

        self.du_lower[v] = (idx * self.su[v] + (1 - idx) * self.l[v]).detach()
        self.du_upper[v] = self.u[v].detach().clone()

        for k in range(1, v):
            self.kl[k] = kl_ori[k - 1].detach()
            self.bl[k] = bl_ori[k - 1].detach()
            self.ku[k] = ku_ori[k - 1].detach()
            self.bu[k] = bu_ori[k - 1].detach()
        return yL_opti, yU_opti
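
# get_tangent_line_short is used above but not shown in this snippet. Judging
# from how its outputs are consumed (a slope/intercept pair for a line through
# the point d on the activation curve), a plausible sketch is the tangent line
# of the activation at d: slope = act'(d), intercept = act(d) - slope*d. This
# is my reading of the helper, not the repository's code.
def tangent_line_sketch(d, activation):
    if activation == 'sigmoid':
        value = torch.sigmoid(d)
        slope = value * (1 - value)   # sigmoid'(d)
    elif activation == 'tanh':
        value = torch.tanh(d)
        slope = 1 - value ** 2        # tanh'(d)
    else:
        raise Exception('%s activation not supported' % activation)
    intercept = value - slope * d     # line passes through (d, act(d))
    return slope, intercept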
Example 6
    return alpha_l,beta_l,gamma_l, alpha_u,beta_u,gamma_u

if __name__ == '__main__':
    #bound tanh(x) sigmoid(y)
    x_minus = torch.rand(2,3) - 0.5
    x_plus = x_minus + 0.5 
    
    y_minus = torch.rand(2,3) - 0.5
    y_plus = y_minus + 0.5 
    # x_minus = torch.Tensor([0.5])
    # x_plus = torch.Tensor([1])
    
    # y_minus = torch.Tensor([-1])
    # y_plus = torch.Tensor([1])
    
    kl, bl, ku, bu = getConvenientGeneralActivationBound(y_minus,
                                                            y_plus, 'sigmoid')
    X_l = torch.tanh(x_minus)
    X_u = torch.tanh(x_plus)
    
    I_l = (X_l>=0).float()
    I_u = (X_u>=0).float()
    
    #k_l*y + b_l <= sigmoid(y) <= k_u*y + b_u
    #X_l*(k_l*y + b_l) <= tanh(x)*sigmoid(y), when X_l >= 0
    #X_l*(k_u*y + b_u) <= tanh(x)*sigmoid(y), when X_l < 0
    
    alpha_l = torch.zeros(x_minus.shape, device=x_minus.device)
    beta_l = I_l * X_l * kl + (1-I_l) * X_l * ku
    gamma_l = I_l * X_l * bl + (1-I_l) * X_l * bu
    
    #tanh(x)*sigmoid(y) <= X_u*(k_u*y + b_u), when X_u >= 0
    #tanh(x)*sigmoid(y) <= X_u*(k_l*y + b_l), when X_u < 0
    
    alpha_u = torch.zeros(x_plus.shape, device=x_minus.device)
    beta_u = I_u * X_u * ku + (1-I_u) * X_u * kl
    gamma_u = I_u * X_u * bu + (1-I_u) * X_u * bl
Example 7
    def computeLast2sideBound(self, eps, p, v, X=None, Eps_idx=None):
        with torch.no_grad():
            n = self.W_ax.shape[1]  # input_size
            s = self.W_ax.shape[0]  # hidden_size
            t = self.W_fa.shape[0]  # output_size
            if X is None:
                X = self.X
            idx_eps = torch.zeros(self.time_step, device=X.device)
            idx_eps[Eps_idx - 1] = 1
            N = X.shape[0]  # number of images, batch size
            if self.a_0 is None:
                a_0 = torch.zeros(N, s, device=X.device)
            else:
                a_0 = self.a_0
            if type(eps) == torch.Tensor:
                eps = eps.to(X.device)
            if p == 1:
                q = float('inf')
            elif p == 'inf':
                q = 1
            else:
                q = p / (p - 1)

            yU = torch.zeros(N, t, device=X.device)  # [N, t]
            yL = torch.zeros(N, t, device=X.device)  # [N, t]

            W_ax = self.W_ax.unsqueeze(0).expand(N, -1, -1)  # [N, s, n]
            W_aa = self.W_aa.unsqueeze(0).expand(N, -1, -1)  # [N, s, s]
            W_fa = self.W_fa.unsqueeze(0).expand(N, -1, -1)  # [N, t, s]
            b_ax = self.b_ax.unsqueeze(0).expand(N, -1)  # [N, s]
            b_aa = self.b_aa.unsqueeze(0).expand(N, -1)  # [N, s]
            b_f = self.b_f.unsqueeze(0).expand(N, -1)  # [N, t]

            # k runs from v-1 down to 1
            for k in range(v - 1, 0, -1):
                if k == v - 1:
                    ## compute A^{<v-1>}, Ou^{<v-1>}, Delta^{<v-1>} and Theta^{<v-1>}
                    ### 1. compute slopes alpha and intercepts beta
                    kl, bl, ku, bu = get_bound.getConvenientGeneralActivationBound(
                        self.l[k], self.u[k], self.activation)

                    # rewrite the bounding lines in the form k*(y + b)
                    bl = bl / kl
                    bu = bu / ku

                    self.kl[k] = kl  # [N, s]
                    self.ku[k] = ku  # [N, s]
                    self.bl[k] = bl  # [N, s]
                    self.bu[k] = bu  # [N, s]
                    alpha_l = kl.unsqueeze(1).expand(-1, t, -1)  # [N, t, s]
                    alpha_u = ku.unsqueeze(1).expand(-1, t, -1)  # [N, t, s]
                    beta_l = bl.unsqueeze(1).expand(-1, t, -1)  # [N, t, s]
                    beta_u = bu.unsqueeze(1).expand(-1, t, -1)  # [N, t, s]
                    ### 2. compute lambda^{<v-1>}, omega^{<v-1>}, Delta^{<v-1>} and Theta^{<v-1>}
                    I = (W_fa >= 0).float()  # [N, t, s]
                    lamida = I * alpha_u + (1 - I) * alpha_l
                    omiga = I * alpha_l + (1 - I) * alpha_u
                    Delta = I * beta_u + (1 - I) * beta_l
                    # Delta: [N, t, s], the transpose of the Delta defined in the paper
                    Theta = I * beta_l + (1 - I) * beta_u  # [N, t, s]
                    ### 3. clear l[k] and u[k] to release memory
                    self.l[k] = None
                    self.u[k] = None
                    ### 4. compute A^{<v-1>} and Ou^{<v-1>}
                    A = W_fa * lamida  # [N, t, s]
                    Ou = W_fa * omiga  # [N, t, s]
                else:
                    ## compute A^{<k>}, Ou^{<k>}, Delta^{<k>} and Theta^{<k>}
                    ### 1. compute slopes alpha and intercepts beta
                    alpha_l = self.kl[k].unsqueeze(1).expand(-1, t, -1)
                    alpha_u = self.ku[k].unsqueeze(1).expand(-1, t, -1)
                    beta_l = self.bl[k].unsqueeze(1).expand(-1, t,
                                                            -1)  # [N, t, s]
                    beta_u = self.bu[k].unsqueeze(1).expand(-1, t,
                                                            -1)  # [N, t, s]
                    ### 2. compute lambda^{<k>}, omega^{<k>}, Delta^{<k>} and Theta^{<k>}
                    I = (torch.matmul(A, W_aa) >= 0).float()  # [N, t, s]
                    lamida = I * alpha_u + (1 - I) * alpha_l
                    Delta = I * beta_u + (1 - I) * beta_l
                    # Delta: [N, t, s], the transpose of the Delta defined in the paper
                    I = (torch.matmul(Ou, W_aa) >= 0).float()  # [N, t, s]
                    omiga = I * alpha_l + (1 - I) * alpha_u
                    Theta = I * beta_l + (1 - I) * beta_u  # [N, t, s]
                    ### 3. compute A^{<k>} and Ou^{<k>}
                    A = torch.matmul(A, W_aa) * lamida  # [N, t, s]
                    Ou = torch.matmul(Ou, W_aa) * omiga  # [N, t, s]
                ## first term
                if type(eps) == torch.Tensor:
                    #eps is a tensor of size N
                    yU = yU + idx_eps[k - 1] * eps.unsqueeze(1).expand(
                        -1, t) * torch.norm(torch.matmul(A, W_ax), p=q,
                                            dim=2)  # eps ||A^ {<k>} W_ax||q
                    yL = yL - idx_eps[k - 1] * eps.unsqueeze(1).expand(
                        -1, t) * torch.norm(torch.matmul(Ou, W_ax), p=q,
                                            dim=2)  # eps ||Ou^ {<k>} W_ax||q
                else:
                    yU = yU + idx_eps[k - 1] * eps * torch.norm(
                        torch.matmul(
                            A, W_ax), p=q, dim=2)  # eps ||A^ {<k>} W_ax||q
                    yL = yL - idx_eps[k - 1] * eps * torch.norm(
                        torch.matmul(
                            Ou, W_ax), p=q, dim=2)  # eps ||Ou^ {<k>} W_ax||q
                ## second term
                yU = yU + torch.matmul(
                    A, torch.matmul(W_ax, X[:, k - 1, :].view(
                        N, n, 1))).squeeze(2)  # A^ {<k>} W_ax x^{<k>}
                yL = yL + torch.matmul(
                    Ou, torch.matmul(W_ax, X[:, k - 1, :].view(
                        N, n, 1))).squeeze(2)  # Ou^ {<k>} W_ax x^{<k>}
                ## third term
                yU = yU + torch.matmul(A, (b_aa + b_ax).view(
                    N, s, 1)).squeeze(2) + (A * Delta).sum(
                        2)  # A^ {<k>} (b_a + Delta^{<k>})
                yL = yL + torch.matmul(Ou, (b_aa + b_ax).view(
                    N, s, 1)).squeeze(2) + (Ou * Theta).sum(
                        2)  # Ou^ {<k>} (b_a + Theta^{<k>})
            # compute A^{<0>} and Ou^{<0>} (a_0 is a constant term)
            A = torch.matmul(A, W_aa)  # A^{<1>} W_aa
            Ou = torch.matmul(Ou, W_aa)  # Ou^{<1>} W_aa
            yU = yU + torch.matmul(A, a_0.view(N, s, 1)).squeeze(
                2)  # A^ {<0>} * a_0
            yL = yL + torch.matmul(Ou, a_0.view(N, s, 1)).squeeze(
                2)  # Ou^ {<0>} * a_0
            yU = yU + b_f
            yL = yL + b_f
        return yL, yU
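
# The p -> q conversion inside computeLast2sideBound is the Holder conjugate:
# for ||x' - x||_p <= eps, |a^T (x' - x)| <= eps * ||a||_q with 1/p + 1/q = 1,
# which is where the eps * ||A W_ax||_q and eps * ||Ou W_ax||_q terms come
# from. A small standalone helper mirroring that logic (my own sketch):
def dual_norm_exponent(p):
    if p == 1:
        return float('inf')
    if p == 'inf' or p == float('inf'):
        return 1
    return p / (p - 1)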