def get_reduced(x, all_possible_sub):
    """Over-approximate the points stacked in ``x`` by an interval (Box) domain.

    The Box is built from the element-wise minimum and maximum of ``x`` along
    dimension 1 and returned as ``ai.HybridZonotope(center, radius, None)``.
    ``self`` is resolved from the enclosing scope (this function appears to be
    nested inside a method -- TODO confirm).

    Args:
        x: tensor of concrete points; presumably dim 1 indexes the
            alternatives being merged -- TODO confirm against callers.
        all_possible_sub: factor multiplied by ``h.product(self.in_shape)`` to
            obtain the element count ``x`` must be a multiple of.

    Returns:
        ``ai.HybridZonotope`` whose centre is the midpoint and whose radius is
        the half-width of ``[min, max]`` along dim 1.

    Raises:
        NotImplementedError: if the element count of ``x`` is not a positive
            multiple of ``all_possible_sub * h.product(self.in_shape)``
            (the original comment suggests this is the single-Point case).
    """
    num_e = h.product(x.size())
    view_num = all_possible_sub * h.product(self.in_shape)
    # Check view_num first so the modulo below can never divide by zero.
    if view_num > 0 and num_e >= view_num and num_e % view_num == 0:
        # Convert to a Box (HybridZonotope with no generator errors).
        lower = x.min(1)[0]
        upper = x.max(1)[0]
        return ai.HybridZonotope((lower + upper) / 2, (upper - lower) / 2, None)
    # Raise explicitly (not ``assert``) so the check survives ``python -O``.
    raise NotImplementedError(
        "get_reduced: input with %s elements is not a multiple of %s"
        % (num_e, view_num))
def LabeledDomain(x, **kwargs):
    """Build an abstract input domain for ``x`` selected by ``x.label``.

    ``x.center()`` is embedded with ``self.embed`` and reshaped to
    ``(-1, 1, self.in_shape[0], self.dim)``; each label then derives a domain
    from that embedding. ``self`` comes from the enclosing scope (this
    function appears to be nested in a method -- TODO confirm). ``kwargs`` is
    accepted for interface compatibility and is not used.

    Labels handled (``<d>`` / ``<k>`` are integer prefixes):

    * ``"Box"`` -- element-wise hull of the embedding and its copies shifted
      by one position left and right along dim 2.
    * ``"<d>Points"`` -- ``ai.HybridZonotope`` over ``d`` randomly swapped
      copies, tagged ``g.HBox(0)`` (the raw embedding tensor when ``d == 1``).
    * ``"<d>Zonotope_Dataaug"`` / ``"<d>Interval_Dataaug"`` -- an
      ``ai.ListDomain`` of ``"<t>Points"`` / ``"<t>Points_Interval"``
      sub-domains whose sizes ``t`` sum to ``d``, each tagged with
      ``g.DList.MLoss(t / d)``.
    * ``"<d>Points_Interval"`` -- Box hull of ``d`` swapped copies, tagged
      ``g.HBox(0)``.
    * ``"<d>Points_Dataaug"`` -- ``ai.ListDomain`` of the ``d`` swapped copies,
      each tagged with ``g.DList.MLoss(1 / d)``.
    * ``"<d>Convex_Dataaug"`` -- the ``d`` swapped copies concatenated along
      dim 1, reshaped, and tagged ``"magic<d>"``.
    * ``"[<k>]Convex_Box_Groups"`` -- the input plus one rewritten copy per
      group in ``self.groups`` (the first ``k`` after an in-place shuffle of
      ``self.groups`` when ``k`` is given), embedded and, when
      ``self.delta != 1``, interpolated towards the unmodified copy.

    ``x.label`` is temporarily rewritten for the ``*_Dataaug`` zonotope and
    interval labels and restored before returning.

    Raises:
        ValueError: if more distinct swaps are requested than
            ``self.swaps`` contains.
        NotImplementedError: for any other label.
    """

    def get_swaped(d):
        """Return ``d`` clones of ``y``, each with a distinct random swap set applied."""
        n_swaps = len(self.swaps)
        if d > n_swaps:
            # d distinct indices cannot be drawn from fewer than d values;
            # rejection sampling would loop forever.
            raise ValueError(
                "cannot draw %d distinct swaps from %d available" % (d, n_swaps))
        # Rejection-sample integer indices until all d of them are distinct.
        while True:
            t = np.random.randint(0, n_swaps, d)
            if len(np.unique(t)) == d:
                break
        ys = []
        for swap_idx in t:
            y1 = y.clone()
            for p, q in self.swaps[swap_idx]:
                # Basic slicing returns views, so a tuple-assignment swap would
                # copy row q into both rows; clone one side first.
                row_p = y1[:, :, p, :].clone()
                y1[:, :, p, :] = y1[:, :, q, :]
                y1[:, :, q, :] = row_p
            ys.append(y1)
        return ys

    xc = x.center()
    y = self.embed(xc.long()).view(-1, 1, self.in_shape[0], self.dim)
    label = x.label

    if label == "Box":
        # Copies of y shifted one position left / right along dim 2; the edge
        # position keeps its own value.
        y_ls_1 = y.clone()
        y_ls_1[:, :, :-1, :] = y[:, :, 1:, :]
        y_rs_1 = y.clone()
        y_rs_1[:, :, 1:, :] = y[:, :, :-1, :]
        lower = torch.min(torch.min(y, y_ls_1), y_rs_1)
        upper = torch.max(torch.max(y, y_ls_1), y_rs_1)
        return ai.HybridZonotope((upper + lower) / 2, (upper - lower) / 2, None)

    if label.endswith("Points"):
        d = int(label[:-len("Points")])
        if d == 1:
            return y
        ys = get_swaped(d)
        # Fold the copies in one at a time: each step records half the gap
        # to the running midpoint as an error term, then updates the midpoint.
        mid = ys[0]
        errs = []
        for other in ys[1:]:
            errs.append((mid - other) / 2)
            mid = (mid + other) / 2
        err = torch.stack(errs, 0)
        return ai.TaggedDomain(ai.HybridZonotope(mid, None, err), g.HBox(0))

    if label.endswith(("Zonotope_Dataaug", "Interval_Dataaug")):
        # Both suffixes have the same length.
        d = int(label[:-len("Zonotope_Dataaug")])
        if label.endswith("Zonotope_Dataaug"):
            tag = "Points"
        else:
            tag = "Points_Interval"
        ret = []
        i = 0
        try:
            while i < d:
                # Chunks of at most 5; a remainder of 6 or 7 is split as
                # 3 + 3 or 3 + 4 instead of 5 + 1 or 5 + 2.
                t = 3 if d - i in (6, 7) else min(5, d - i)
                # The recursive call dispatches on x.label.
                x.label = str(t) + tag
                ret.append(
                    ai.TaggedDomain(LabeledDomain(x), g.DList.MLoss(1.0 * t / d)))
                i += t
        finally:
            # Do not leak the temporary chunk label back to the caller.
            x.label = label
        return ai.ListDomain(ret)

    if label.endswith("Points_Interval"):
        d = int(label[:-len("Points_Interval")])
        ys = get_swaped(d)
        lower = ys[0].clone()
        upper = ys[0].clone()
        for other in ys[1:]:
            lower = torch.min(lower, other)
            upper = torch.max(upper, other)
        return ai.TaggedDomain(
            ai.HybridZonotope((upper + lower) / 2, (upper - lower) / 2, None),
            g.HBox(0))

    if label.endswith("Points_Dataaug"):
        d = int(label[:-len("Points_Dataaug")])
        ys = get_swaped(d)
        return ai.ListDomain(
            [ai.TaggedDomain(aug, g.DList.MLoss(1.0 / d)) for aug in ys])

    if label.endswith("Convex_Dataaug"):
        d = int(label[:-len("Convex_Dataaug")])
        ys = torch.cat(get_swaped(d), 1)
        return ai.TaggedDomain(
            ys.view(-1, 1, self.in_shape[0], self.dim), tag="magic" + str(d))

    if label.endswith("Convex_Box_Groups"):
        try:
            groups_consider = int(label[:-len("Convex_Box_Groups")])
        except ValueError:
            # No numeric prefix: use every group, in its current order.
            groups_consider = len(self.groups)
        else:
            # Note: shuffles self.groups in place (persists across calls).
            random.shuffle(self.groups)
        n_copies = groups_consider + 1
        seq_len = self.in_shape[0]
        tokens = xc.repeat((1, n_copies))
        for row in tokens:
            # Copy j (1-based) gets group j-1's substitutions applied, reading
            # source positions from the unmodified first copy.
            for j in range(1, n_copies):
                for p, q in self.groups[j - 1]:
                    row[j * seq_len + p] = row[q]
        y_groups = self.embed(tokens.long()).view(-1, 1, seq_len, self.dim)
        if self.delta != 1:
            for row_idx in range(len(y_groups)):
                group_id = row_idx % n_copies
                if group_id == 0:
                    continue
                base_idx = row_idx - group_id
                # Interpolate each modified copy towards its unmodified base.
                y_groups[row_idx] = (y_groups[row_idx] * self.delta
                                     + (1 - self.delta) * y_groups[base_idx])
        return ai.TaggedDomain(y_groups, tag="magic" + str(n_copies))

    raise NotImplementedError()
def Domain(self, *args, **kargs):
    """Construct an ``ai.HybridZonotope`` using ``ai.creluNIPS`` as its ReLU.

    All positional and keyword arguments are forwarded unchanged. Passing
    ``customRelu`` explicitly raises ``TypeError`` (duplicate keyword).
    """
    relu_impl = ai.creluNIPS
    return ai.HybridZonotope(*args, customRelu=relu_impl, **kargs)