Ejemplo n.º 1
0
 def __sub__(self, other):
     """Return ``self - other`` as a new object built via ``self.new``.

     If ``other`` is a HybridZonotope, the result has:

     * head: ``self.head - other.head``;
     * beta: ``self.beta`` and ``other.beta`` combined with ``+`` through
       ``h.msum``;
     * errors: ``self.errors`` and ``-other.errors`` combined with ``+``
       through ``h.msum`` / ``catNonNullErrors``.

     Any other value (a standard variable or tensor) only shifts the head.
     Beta and errors are reused unchanged in that case.
     """
     if not isinstance(other, HybridZonotope):
         # other has to be a standard variable or tensor
         return self.new(self.head - other, self.beta, self.errors)

     def _plus(lhs, rhs):
         return lhs + rhs

     head_diff = self.head - other.head
     beta_sum = h.msum(self.beta, other.beta, _plus)
     # Keep a missing error tensor as None so h.msum can skip it.
     negated = -other.errors if other.errors is not None else None
     err_sum = h.msum(self.errors, negated, catNonNullErrors(_plus))
     return self.new(head_diff, beta_sum, err_sum)
Ejemplo n.º 2
0
    def merge(
        self,
        other,
        ref=None
    ):  # the vast majority of the time ref should be none here.  Not for parallel computation with powerset
        """Join ``self`` with ``other`` element-wise, over-approximating both.

        Each element is handled independently:

        * if ``other`` lies inside ``self``'s box part (``head +- beta``),
          keep ``self``'s head, beta and error terms;
        * otherwise, if ``other`` is a HybridZonotope and ``self`` (including
          its error terms) lies inside ``other``'s box part, keep ``other``'s
          head, beta and error terms;
        * otherwise, drop the error terms for that element and use a box
          centred at the midpoint of the two centers. Its radius is chosen
          so that it covers both operands' upper AND lower bounds.

        Args:
            other: value to join with. If it is not a HybridZonotope it must
                provide ``ub()``, ``lb()`` and ``center()``.
            ref: optional object whose ``errors`` are forwarded to
                ``catNonNullErrors`` as ``ref_errs``. It is normally None;
                it is used for parallel computation with powerset.

        Returns:
            A new object built via ``self.new(head, beta, errors)``.
        """
        s_beta = self.getBeta()  # so that beta is never none

        sbox_u = self.head + s_beta
        sbox_l = self.head - s_beta
        o_u = other.ub()
        o_l = other.lb()
        # other fits entirely inside the box part of self
        o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l)

        # Error radius of self. Missing error terms contribute nothing
        # (the original crashed here when errors was None).
        s_err_mx = (torch.zeros_like(self.head) if self.errors is None
                    else self.errors.abs().sum(dim=0))

        # Full bounds of self, including its error terms.
        s_u = sbox_u + s_err_mx
        s_l = sbox_l - s_err_mx

        # Fallback box. The head stays the midpoint of the two centers, but
        # the radius must reach both the hull's upper and lower bound. Using
        # only `max(upper) - new_head` leaves the lower side uncovered.
        new_head = (self.head + other.center()) / 2
        new_beta = torch.max(torch.max(s_u, o_u) - new_head,
                             new_head - torch.min(s_l, o_l))

        if not isinstance(other, HybridZonotope):
            return self.new(
                torch.where(o_in_s, self.head, new_head),
                # s_beta, not self.beta: self.beta may be None
                torch.where(o_in_s, s_beta, new_beta),
                None if self.errors is None else o_in_s.float() * self.errors)

        # TODO: could be more efficient if one of these doesn't have beta or errors but thats okay for now.
        # Box part of other, mirroring sbox_u / sbox_l above.
        o_beta = other.getBeta()
        obox_u = other.head + o_beta
        obox_l = other.head - o_beta

        # self (including its error terms) fits inside the box part of other
        s_in_o = (s_u <= obox_u) & (s_l >= obox_l)

        # TODO: could theoretically still do something better when one is contained partially in the other
        return self.new(
            torch.where(o_in_s, self.head,
                        torch.where(s_in_o, other.head, new_head)),
            torch.where(o_in_s, s_beta,
                        torch.where(s_in_o, o_beta, new_beta)),
            h.msum(
                None if self.errors is None else o_in_s.float() * self.errors,
                None if other.errors is None else s_in_o.float() * other.errors,
                catNonNullErrors(lambda a, b: a + b,
                                 ref_errs=ref.errors if ref is not None else
                                 None)))  # these are both zero otherwise
Ejemplo n.º 3
0
 def cat(self, other, dim=0):
     """Concatenate ``self`` and ``other`` along ``dim``, with ``self`` first.

     Heads, betas and error terms are all concatenated in the same order, so
     every part keeps its own beta. (The original concatenated betas as
     ``other, self``, which gave each part the other operand's beta.)

     Args:
         other: object with ``head``, ``beta``, ``errors`` and ``getBeta()``.
         dim: axis of ``head`` to concatenate along. Negative values count
             from the end, as usual.

     Returns:
         A new object built via ``self.new(head, beta, errors)``.
     """
     if self.beta is None and other.beta is None:
         beta = None
     else:
         # getBeta() never returns None, so a one-sided beta is zero-filled.
         # Otherwise the lone beta would not match the concatenated head.
         beta = self.getBeta().cat(other.getBeta(), dim=dim)
     # Error tensors carry one extra leading axis (the original used dim+1),
     # so a non-negative dim shifts by one. A negative dim already counts
     # from the end and names the same axis.
     err_dim = dim + 1 if dim >= 0 else dim
     return self.new(
         self.head.cat(other.head, dim=dim),
         beta,
         h.msum(self.errors, other.errors,
                catNonNullErrors(lambda a, b: a.cat(b, err_dim))))
Ejemplo n.º 4
0
 def addPar(self, a, b):
     """Return the sum of ``a`` and ``b`` as a new object built via ``self.new``.

     * head: ``a.head + b.head``;
     * beta: ``a.beta`` and ``b.beta`` combined with ``+`` through ``h.msum``;
     * errors: ``a.errors`` and ``b.errors`` combined with ``+`` through
       ``h.msum`` / ``catNonNullErrors``, passing ``self.errors`` as the
       reference errors.
     """
     def _plus(lhs, rhs):
         return lhs + rhs

     summed_head = a.head + b.head
     summed_beta = h.msum(a.beta, b.beta, _plus)
     summed_errors = h.msum(a.errors, b.errors,
                            catNonNullErrors(_plus, self.errors))
     return self.new(summed_head, summed_beta, summed_errors)