예제 #1
0
 def get_reduced(x, all_possible_sub):
     """Collapse a batch of enumerated points into its interval (box) hull.

     The total element count of ``x`` must be a positive multiple of
     ``all_possible_sub * h.product(self.in_shape)``.  When it is, the box is
     obtained by taking the element-wise min and max of ``x`` over dimension 1.

     Args:
         x: tensor of points; presumably laid out as
             (batch, all_possible_sub, ...) -- TODO confirm against callers.
         all_possible_sub: presumably the number of candidate substitutions
             enumerated per input (multiplies the per-input element count).

     Returns:
         ``ai.HybridZonotope`` with center ``(lower + upper) / 2``, radius
         ``(upper - lower) / 2`` and no error terms.

     Raises:
         AssertionError: if the element count does not satisfy the constraint
             above (the "Point" case is not supported).
     """
     num_e = h.product(x.size())
     view_num = all_possible_sub * h.product(self.in_shape)
     # view_num > 0 keeps the modulo below from dividing by zero.
     if view_num > 0 and num_e >= view_num and num_e % view_num == 0:
         # Convert to a Box (a HybridZonotope with only a radius term).
         lower = x.min(1)[0]
         upper = x.max(1)[0]
         return ai.HybridZonotope((lower + upper) / 2,
                                  (upper - lower) / 2, None)
     # Explicit raise rather than ``assert False``: asserts are stripped under
     # ``python -O``, which would make this function silently return None.
     raise AssertionError(
         "get_reduced: {} elements is not a positive multiple of {} "
         "(Point inputs are not supported)".format(num_e, view_num))
예제 #2
0
        def LabeledDomain(x, **kwargs):
            """Build an abstract domain for the token batch ``x`` from ``x.label``.

            ``<d>`` below denotes a decimal count prefixed to the label:

            * ``"Box"`` -- interval hull of the embedding ``y`` and its copies
              shifted by one position either way along axis 2.
            * ``"<d>Points"`` -- zonotope built from ``d`` swapped embeddings
              (``y`` itself when ``d == 1``), tagged ``g.HBox(0)``.
            * ``"<d>Zonotope_Dataaug"`` / ``"<d>Interval_Dataaug"`` -- an
              ``ai.ListDomain`` of ``"<t>Points"`` / ``"<t>Points_Interval"``
              sub-domains whose ``t`` values sum to ``d``, each weighted
              ``t / d``.
            * ``"<d>Points_Interval"`` -- interval hull of ``d`` swapped
              embeddings, tagged ``g.HBox(0)``.
            * ``"<d>Points_Dataaug"`` -- ``ai.ListDomain`` of ``d`` swapped
              embeddings, each weighted ``1 / d``.
            * ``"<d>Convex_Dataaug"`` -- ``d`` swapped embeddings concatenated
              and tagged ``"magic<d>"``.
            * ``"[<k>]Convex_Box_Groups"`` -- the input plus one substituted copy
              per group of ``self.groups`` (all groups, or ``k`` after an
              in-place shuffle), embedded and tagged ``"magic<k+1>"``.

            Args:
                x: object exposing a string ``label`` and ``center()``.
                **kwargs: accepted for interface compatibility; unused.

            Raises:
                NotImplementedError: if ``x.label`` matches none of the above.
                ValueError: if more distinct swaps are requested than
                    ``self.swaps`` provides.
            """

            def get_swaped(d):
                """Return ``d`` clones of ``y``, each with a distinct swap set applied."""
                if d > len(self.swaps):
                    # Rejection sampling below could never find d distinct indices.
                    raise ValueError(
                        "cannot draw {} distinct swaps from {} available".format(
                            d, len(self.swaps)))
                # Empty int sentinel: forces a first draw whenever d > 0 (same RNG
                # consumption as before for d > 1) and guarantees integer indices,
                # including for d == 1 where a float sentinel used to survive.
                t = np.empty(0, dtype=int)
                while len(np.unique(t)) != d:
                    t = np.random.randint(0, len(self.swaps), d)
                ys = []
                for swap_idx in t:
                    y1 = y.clone()
                    for (p, q) in self.swaps[swap_idx]:
                        # Clone the right-hand sides: plain slices are views of y1,
                        # so without the copies the first assignment would overwrite
                        # the data the second one reads, and q would never receive
                        # p's original values.
                        y1[:, :, p, :], y1[:, :, q, :] = (y1[:, :, q, :].clone(),
                                                          y1[:, :, p, :].clone())
                    ys.append(y1)
                return ys

            xc = x.center()
            # Embed the center tokens; axis 2 has length self.in_shape[0].
            y = self.embed(xc.long()).view(-1, 1, self.in_shape[0], self.dim)
            if x.label == "Box":
                # Neighbour shifted left by one along axis 2 (last slot keeps its own).
                y_ls_1 = y.clone()
                y_ls_1[:, :, :-1, :] = y[:, :, 1:, :]
                # Neighbour shifted right by one along axis 2 (first slot keeps its own).
                y_rs_1 = y.clone()
                y_rs_1[:, :, 1:, :] = y[:, :, :-1, :]
                lower = torch.min(torch.min(y, y_ls_1), y_rs_1)
                upper = torch.max(torch.max(y, y_ls_1), y_rs_1)
                return ai.HybridZonotope((upper + lower) / 2,
                                         (upper - lower) / 2, None)
            elif x.label.endswith("Points"):
                d = int(x.label[:-len("Points")])
                if d == 1:
                    return y
                ys = get_swaped(d)
                # Fold the points pairwise: each step halves the distance to the
                # next point and records that half-difference as an error term.
                mid = ys[0]
                err = None
                for i in range(1, d):
                    step = torch.unsqueeze((mid - ys[i]) / 2, 0)
                    err = step if err is None else torch.cat([err, step], 0)
                    mid = (mid + ys[i]) / 2
                return ai.TaggedDomain(ai.HybridZonotope(mid, None, err),
                                       g.HBox(0))
            elif x.label.endswith(("Zonotope_Dataaug", "Interval_Dataaug")):
                # Both suffixes are 16 characters long, so one slice strips either.
                d = int(x.label[:-len("Zonotope_Dataaug")])
                if x.label.endswith("Zonotope_Dataaug"):
                    tag = "Points"
                else:
                    tag = "Points_Interval"
                ret = []
                original_label = x.label
                try:
                    i = 0
                    while i < d:
                        # Chunks of at most 5 points; a remainder of 6 or 7 is
                        # taken 3 at a time instead.
                        if d - i in (6, 7):
                            t = 3
                        else:
                            t = min(5, d - i)
                        # The recursive call dispatches on x.label.
                        x.label = str(t) + tag
                        ret.append(
                            ai.TaggedDomain(LabeledDomain(x),
                                            g.DList.MLoss(1.0 * t / d)))
                        i += t
                finally:
                    # Do not leak the temporary per-chunk label to the caller.
                    x.label = original_label
                return ai.ListDomain(ret)
            elif x.label.endswith("Points_Interval"):
                d = int(x.label[:-len("Points_Interval")])
                ys = get_swaped(d)
                lower = ys[0].clone()
                upper = ys[0].clone()
                for i in range(1, d):
                    lower = torch.min(lower, ys[i])
                    upper = torch.max(upper, ys[i])
                return ai.TaggedDomain(
                    ai.HybridZonotope((upper + lower) / 2, (upper - lower) / 2,
                                      None), g.HBox(0))
            elif x.label.endswith("Points_Dataaug"):
                d = int(x.label[:-len("Points_Dataaug")])
                ys = get_swaped(d)
                return ai.ListDomain(
                    [ai.TaggedDomain(yi, g.DList.MLoss(1.0 / d)) for yi in ys])
            elif x.label.endswith("Convex_Dataaug"):
                d = int(x.label[:-len("Convex_Dataaug")])
                ys = torch.cat(get_swaped(d), 1)
                return ai.TaggedDomain(ys.view(-1, 1, self.in_shape[0],
                                               self.dim),
                                       tag="magic" + str(d))
            elif x.label.endswith("Convex_Box_Groups"):
                try:
                    groups_consider = int(x.label[:-len("Convex_Box_Groups")])
                except ValueError:
                    # No numeric prefix: use every group in its current order.
                    groups_consider = None
                if groups_consider is None:
                    groups_consider = len(self.groups)
                else:
                    # NOTE: shuffles self.groups in place; the order persists
                    # across calls.
                    random.shuffle(self.groups)
                # One original copy plus one copy per considered group, laid out
                # consecutively along dim 1 of each row.
                x = xc.repeat((1, groups_consider + 1))
                for row in x:
                    for j in range(1, groups_consider + 1):
                        for p, q in self.groups[j - 1]:
                            # In copy j, position p takes the original token at q.
                            row[j * self.in_shape[0] + p] = row[q]

                y = self.embed(x.long()).view(-1, 1, self.in_shape[0],
                                              self.dim)
                if self.delta != 1:
                    # Pull each substituted copy toward its unmodified copy 0:
                    # y_j <- delta * y_j + (1 - delta) * y_0.
                    for idx in range(len(y)):
                        item_group_id = idx % (groups_consider + 1)
                        item_id = idx - item_group_id
                        if item_group_id == 0:
                            continue
                        y[idx] = y[idx] * self.delta + (1 -
                                                        self.delta) * y[item_id]

                return ai.TaggedDomain(y,
                                       tag="magic" + str(groups_consider + 1))
            else:
                raise NotImplementedError()
 def Domain(self, *args, **kargs):
     """Construct an ``ai.HybridZonotope`` whose ReLU transformer is ``ai.creluNIPS``.

     Every positional and keyword argument is forwarded unchanged to
     ``ai.HybridZonotope``; only ``customRelu`` is fixed by this wrapper.
     """
     relu_transformer = ai.creluNIPS
     return ai.HybridZonotope(*args, customRelu=relu_transformer, **kargs)