def __sub__(self, other):
    """Subtract `other` from this zonotope.

    `other` may be another HybridZonotope or a plain tensor/scalar; in
    the latter case only the center is shifted.
    """
    if not isinstance(other, HybridZonotope):
        # Standard variable or tensor: beta and errors are unaffected.
        return self.new(self.head - other, self.beta, self.errors)
    # Zonotope - zonotope: centers subtract, box radii (beta) add, and
    # the subtrahend's error terms are negated before being summed in.
    negated_errors = None if other.errors is None else -other.errors
    return self.new(
        self.head - other.head,
        h.msum(self.beta, other.beta, lambda x, y: x + y),
        h.msum(self.errors, negated_errors, catNonNullErrors(lambda x, y: x + y)))
def merge(self, other, ref=None):
    """Merge this zonotope with `other` into a single over-approximation.

    Elementwise: where one region is contained in the other, the containing
    region's parameters are kept; otherwise a fresh bounding region is built
    around both.  `ref` supplies reference error terms for aligning the
    concatenated errors.
    """
    # the vast majority of the time ref should be none here. Not for parallel computation with powerset
    s_beta = self.getBeta()  # so that beta is never none
    # Box-only bounds of self (error terms excluded).
    sbox_u = self.head + s_beta
    sbox_l = self.head - s_beta
    o_u = other.ub()
    o_l = other.lb()
    # Elementwise mask: other's interval lies inside self's box.
    o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l)
    # Total radius contributed by self's error terms (summed over the
    # leading error-term dimension).
    s_err_mx = self.errors.abs().sum(dim=0)
    if not isinstance(other, HybridZonotope):
        # `other` is interval-like: build a box around both regions where
        # other is not contained in self; keep self's parameters where it is.
        new_head = (self.head + other.center()) / 2
        new_beta = torch.max(sbox_u + s_err_mx, o_u) - new_head
        # NOTE(review): new_beta is derived from the upper bounds only —
        # presumably the regions are assumed symmetric; confirm.
        return self.new(torch.where(o_in_s, self.head, new_head),
                        torch.where(o_in_s, self.beta, new_beta),
                        o_in_s.float() * self.errors)  # errors zeroed where a new box replaced self
    # TODO: could be more efficient if one of these doesn't have beta or errors but thats okay for now.
    # Full (box + error) bounds of self.
    s_u = sbox_u + s_err_mx
    s_l = sbox_l - s_err_mx
    # NOTE(review): obox_u/obox_l look like offsets from other.head rather
    # than absolute bounds, yet they are compared against the absolute
    # s_u/s_l below — and obox_l uses `+` where `-` would be expected for
    # a lower offset.  Confirm this containment test is intended.
    obox_u = o_u - other.head
    obox_l = o_l + other.head
    s_in_o = (s_u <= obox_u) & (s_l >= obox_l)
    # TODO: could theoretically still do something better when one is contained partially in the other
    new_head = (self.head + other.center()) / 2
    new_beta = torch.max(sbox_u + self.getErrors().abs().sum(dim=0), o_u) - new_head
    # Elementwise selection: keep self where other is inside self, keep
    # other where self is inside other, otherwise use the fresh bounding box.
    return self.new(
        torch.where(o_in_s, self.head, torch.where(s_in_o, other.head, new_head)),
        torch.where(o_in_s, s_beta, torch.where(s_in_o, other.getBeta(), new_beta)),
        h.msum(o_in_s.float() * self.errors,
               s_in_o.float() * other.errors,
               catNonNullErrors(lambda a, b: a + b,
                                ref_errs=ref.errors if ref is not None else ref)))  # these are both zero otherwise
def cat(self, other, dim=0):
    """Concatenate two zonotopes along dimension `dim`.

    Bug fix: beta was passed to h.msum as (other.beta, self.beta), so the
    concatenated betas came out other-first while the heads were
    concatenated self-first — misaligning the box radii with the centers
    along `dim`.  The order is now consistent with the head concatenation.
    """
    return self.new(
        self.head.cat(other.head, dim=dim),
        # Self-first, matching the head concatenation above.
        h.msum(self.beta, other.beta, lambda a, b: a.cat(b, dim=dim)),
        # Errors carry an extra leading error-term dimension, hence dim + 1.
        h.msum(self.errors, other.errors,
               catNonNullErrors(lambda a, b: a.cat(b, dim + 1))))
def addPar(self, a, b):
    """Pointwise sum of two parallel zonotopes `a` and `b`.

    Self's own error terms are passed as the reference for aligning the
    concatenation of non-null error terms.
    """
    summed_head = a.head + b.head
    summed_beta = h.msum(a.beta, b.beta, lambda u, v: u + v)
    summed_errors = h.msum(a.errors, b.errors,
                           catNonNullErrors(lambda u, v: u + v, self.errors))
    return self.new(summed_head, summed_beta, summed_errors)