Exemplo n.º 1
0
    def forward(self, x, **kargs):
        """Propagate ``x``, collapsing "magic"-tagged variant stacks to boxes.

        A ``TaggedDomain`` whose tag is the string ``"magic<k>"`` wraps a
        tensor holding ``k`` concrete variants of each sample, each of shape
        ``self.in_shape`` (assumed stored as ``k`` consecutive rows per
        sample -- the grouping the reshape below relies on).  The variants of
        each sample are replaced by the element-wise interval box containing
        them, returned as a ``HybridZonotope``.

        Lists are processed element by element (``x.al`` is updated in place
        and ``x`` is returned), other tagged domains are forwarded and
        re-tagged, and plain tensors and ``HybridZonotope`` values are
        returned unchanged.

        Raises:
            ValueError: if a "magic" tensor's size is not a positive multiple
                of ``k * prod(self.in_shape)``.
            NotImplementedError: for any other input type.
        """
        def get_reduced(x, all_possible_sub):
            # Interval hull over the `all_possible_sub` variants of each sample.
            num_e = h.product(x.size())
            view_num = all_possible_sub * h.product(self.in_shape)
            if num_e < view_num or num_e % view_num != 0:
                # Previously `assert False`, which `python -O` strips.
                raise ValueError(
                    "cannot group %s elements into %s variants of shape %s" %
                    (num_e, all_possible_sub, tuple(self.in_shape)))
            # Put the variants of one sample on dim 1 so that min/max reduce
            # across variants rather than across the first feature axis.
            x = x.reshape(-1, all_possible_sub, *self.in_shape)
            lower = x.min(1)[0]
            upper = x.max(1)[0]
            return ai.HybridZonotope((lower + upper) / 2,
                                     (upper - lower) / 2, None)

        if isinstance(x, ai.ListDomain):
            for i, a in enumerate(x.al):
                x.al[i] = self.forward(a)
            return x
        if isinstance(x, ai.TaggedDomain):
            if isinstance(x.tag, str) and x.tag.startswith("magic"):
                return get_reduced(x.a, int(x.tag[len("magic"):]))
            return ai.TaggedDomain(self.forward(x.a), x.tag)
        if isinstance(x, (torch.Tensor, ai.HybridZonotope)):
            return x
        raise NotImplementedError()
Exemplo n.º 2
0
    def boxBetween(self, *args, **kargs):
        """Box every weighted sub-domain and re-wrap the results.

        For each ``(sub_domain, weight)`` pair in ``self.al`` this yields
        ``sub_domain.boxBetween(*args, **kargs)`` tagged with a
        ``DList.MLoss`` of ``weight.getVal(**kargs)`` times
        ``self.getDiv(**kargs)``.  The tagged parts are handed to
        ``self.Domain`` as a lazy generator.
        """
        scale = self.getDiv(**kargs)
        tagged_parts = (
            ai.TaggedDomain(sub_domain.boxBetween(*args, **kargs),
                            DList.MLoss(weight.getVal(**kargs) * scale))
            for sub_domain, weight in self.al)
        return self.Domain(tagged_parts)
Exemplo n.º 3
0
        def LabeledDomain(x, **kwargs):
            """Build the abstract input described by ``x.label``.

            The centre of ``x`` is embedded with ``self.embed`` and viewed as
            ``(-1, 1, self.in_shape[0], self.dim)``; the label then selects
            the representation:

            * ``"Box"``: interval hull of each position with its left and
              right neighbours.
            * ``"<d>Points"``: zonotope over ``d`` swapped variants
              (``d == 1`` returns the embedding unchanged).
            * ``"<d>Zonotope_Dataaug"`` / ``"<d>Interval_Dataaug"``: the ``d``
              variants split into chunks of ``"Points"`` /
              ``"Points_Interval"`` domains, each weighted by its share of
              ``d``, collected in a ``ListDomain``.
            * ``"<d>Points_Interval"``: interval hull of ``d`` swapped variants.
            * ``"<d>Points_Dataaug"``: ``d`` swapped variants, each weighted
              ``1 / d``.
            * ``"<d>Convex_Dataaug"``: ``d`` swapped variants stacked and
              tagged ``"magic<d>"``.
            * ``"[k]Convex_Box_Groups"``: the input followed by ``k`` copies
              with ``self.groups`` substitutions applied (all groups, in
              stored order, when ``k`` is omitted), blended by ``self.delta``
              and tagged ``"magic<k+1>"``.

            ``x.label`` is rewritten while chunked variants are built and is
            restored before returning.

            Raises:
                ValueError: if ``d`` exceeds the number of available swaps.
                NotImplementedError: for an unrecognised label.
            """
            def get_swaped(d):
                # Apply d distinct, randomly chosen swap patterns from
                # self.swaps, each to its own copy of y.
                if d > len(self.swaps):
                    # The rejection sampling below could never succeed.
                    raise ValueError("cannot draw %d distinct swaps from %d" %
                                     (d, len(self.swaps)))
                # Always draw at least once: seeding with zeros skipped the
                # draw for d == 1 and then indexed self.swaps with a float.
                t = np.random.randint(0, len(self.swaps), d)
                while len(np.unique(t)) != d:
                    t = np.random.randint(0, len(self.swaps), d)
                ys = []
                for swap_id in t:
                    y1 = y.clone()
                    for (p, q) in self.swaps[swap_id]:
                        # Save one side first: the slices are views, so a
                        # tuple-swap copied q's values into both positions.
                        saved = y1[:, :, p, :].clone()
                        y1[:, :, p, :] = y1[:, :, q, :]
                        y1[:, :, q, :] = saved
                    ys.append(y1)
                return ys

            xc = x.center()
            y = self.embed(xc.long()).view(-1, 1, self.in_shape[0], self.dim)
            label = x.label
            if label == "Box":
                y_ls_1 = y.clone()
                y_ls_1[:, :, :-1, :] = y[:, :, 1:, :]
                y_rs_1 = y.clone()
                y_rs_1[:, :, 1:, :] = y[:, :, :-1, :]
                lower = torch.min(torch.min(y, y_ls_1), y_rs_1)
                upper = torch.max(torch.max(y, y_ls_1), y_rs_1)
                return ai.HybridZonotope((upper + lower) / 2,
                                         (upper - lower) / 2, None)
            elif label.endswith("Points"):
                d = int(label[:-len("Points")])
                if d == 1:
                    return y
                ys = get_swaped(d)
                mid = ys[0]
                err_terms = []
                for i in range(1, d):
                    # The error term uses the running midpoint before update.
                    err_terms.append(torch.unsqueeze((mid - ys[i]) / 2, 0))
                    mid = (mid + ys[i]) / 2
                err = torch.cat(err_terms, 0) if err_terms else None
                return ai.TaggedDomain(ai.HybridZonotope(mid, None, err),
                                       g.HBox(0))
            elif label.endswith(("Zonotope_Dataaug", "Interval_Dataaug")):
                # Both suffixes are 16 characters long.
                d = int(label[:-len("Zonotope_Dataaug")])
                if label.endswith("Zonotope_Dataaug"):
                    tag = "Points"
                else:
                    tag = "Points_Interval"
                ret = []
                i = 0
                try:
                    while i < d:
                        # Chunks of at most 5; a remainder of 6 or 7 is cut
                        # into 3 first so no trailing chunk of 1 or 2 is left.
                        t = 3 if d - i in (6, 7) else min(5, d - i)
                        # The recursive call reads the chunk spec from x.label.
                        x.label = str(t) + tag
                        ret.append(
                            ai.TaggedDomain(LabeledDomain(x),
                                            g.DList.MLoss(1.0 * t / d)))
                        i += t
                finally:
                    # Do not leak the temporary label back to the caller.
                    x.label = label
                return ai.ListDomain(ret)
            elif label.endswith("Points_Interval"):
                d = int(label[:-len("Points_Interval")])
                ys = get_swaped(d)
                lower = ys[0].clone()
                upper = ys[0].clone()
                for yi in ys[1:]:
                    lower = torch.min(lower, yi)
                    upper = torch.max(upper, yi)
                return ai.TaggedDomain(
                    ai.HybridZonotope((upper + lower) / 2, (upper - lower) / 2,
                                      None), g.HBox(0))
            elif label.endswith("Points_Dataaug"):
                d = int(label[:-len("Points_Dataaug")])
                ys = get_swaped(d)
                return ai.ListDomain(
                    [ai.TaggedDomain(yi, g.DList.MLoss(1.0 / d)) for yi in ys])
            elif label.endswith("Convex_Dataaug"):
                d = int(label[:-len("Convex_Dataaug")])
                ys = torch.cat(get_swaped(d), 1)
                return ai.TaggedDomain(ys.view(-1, 1, self.in_shape[0],
                                               self.dim),
                                       tag="magic" + str(d))
            elif label.endswith("Convex_Box_Groups"):
                try:
                    groups_consider = int(label[:-len("Convex_Box_Groups")])
                except ValueError:
                    # No numeric prefix: use every group in stored order.
                    groups_consider = None
                if groups_consider is None:
                    groups_consider = len(self.groups)
                else:
                    random.shuffle(self.groups)
                copies = groups_consider + 1
                seq_len = self.in_shape[0]
                tokens = xc.repeat((1, copies))
                for row in tokens:
                    # Copy j (>= 1) gets group j - 1: position p of that copy
                    # is set to row[q].
                    for j in range(1, copies):
                        for p, q in self.groups[j - 1]:
                            row[j * seq_len + p] = row[q]
                y = self.embed(tokens.long()).view(-1, 1, seq_len, self.dim)
                if self.delta != 1:
                    # Blend each perturbed copy with its sample's first row:
                    # delta * copy + (1 - delta) * first row.
                    for idx in range(len(y)):
                        offset = idx % copies
                        if offset == 0:
                            continue
                        base = idx - offset
                        y[idx] = (y[idx] * self.delta +
                                  (1 - self.delta) * y[base])
                return ai.TaggedDomain(y, tag="magic" + str(copies))
            else:
                raise NotImplementedError()
Exemplo n.º 4
0
    def forward(self, x, **kargs):
        """Embed token-level inputs and build the matching abstract domain.

        Dispatch on ``x``:

        * ``ai.LabeledDomain``: built by the nested ``LabeledDomain`` from
          ``x.label``.
        * ``ai.TaggedDomain``: the wrapped value is forwarded and re-tagged.
        * ``ai.ListDomain``: each element is forwarded; ``x.al`` is updated in
          place and ``x`` itself is returned.
        * any other non-point value: its centre tokens are embedded together
          with one copy per greedy group of substitutions taken from
          ``self.adjacent_keys``, blended by ``self.delta`` and tagged
          ``"magic<k>"`` where ``k`` is 1 + the largest group count in the
          batch.
        * ``torch.Tensor``: token ids embedded directly.

        Raises:
            ValueError: if a label requests more distinct swaps than
                ``self.swaps`` contains.
            NotImplementedError: for non-tensor point inputs and for
                unrecognised labels.
        """
        def blend_toward_original(y, copies):
            # y holds `copies` consecutive rows per sample, the first being
            # the unperturbed one; every other row r becomes
            # self.delta * r + (1 - self.delta) * <first row of its sample>.
            if self.delta != 1:
                for idx in range(len(y)):
                    offset = idx % copies
                    if offset == 0:
                        continue
                    base = idx - offset
                    y[idx] = y[idx] * self.delta + (1 - self.delta) * y[base]
            return y

        def collect_groups(data, window=20):
            # Greedily pack one sample's substitutions into groups.  Within a
            # group, chosen positions are at least `window` apart (20 is the
            # kernel size + pooling size); each position hands out its
            # alternatives from self.adjacent_keys in order, one per group.
            # Starting `pre` at -window (rather than -self.in_shape[0]) keeps
            # position 0 eligible even when in_shape[0] < window, where the
            # old start could leave every position ineligible and loop forever.
            subs = [list(self.adjacent_keys[int(s)]) for s in data]
            remaining = sum(len(c) for c in subs)
            groups = []
            while remaining > 0:
                group = []
                pre = -window
                for j, cands in enumerate(subs):
                    if cands and j - pre >= window:
                        pre = j
                        group.append((j, cands.pop(0)))
                        remaining -= 1
                groups.append(group)
            return groups

        def LabeledDomain(x, **kwargs):
            """Build the abstract input described by ``x.label``.

            Supported labels: ``"Box"``, ``"<d>Points"``,
            ``"<d>Zonotope_Dataaug"``, ``"<d>Interval_Dataaug"``,
            ``"<d>Points_Interval"``, ``"<d>Points_Dataaug"``,
            ``"<d>Convex_Dataaug"`` and ``"[k]Convex_Box_Groups"``.
            ``x.label`` is rewritten while chunked variants are built and is
            restored before returning.
            """
            def get_swaped(d):
                # Apply d distinct, randomly chosen swap patterns from
                # self.swaps, each to its own copy of y.
                if d > len(self.swaps):
                    # The rejection sampling below could never succeed.
                    raise ValueError("cannot draw %d distinct swaps from %d" %
                                     (d, len(self.swaps)))
                # Always draw at least once: seeding with zeros skipped the
                # draw for d == 1 and then indexed self.swaps with a float.
                t = np.random.randint(0, len(self.swaps), d)
                while len(np.unique(t)) != d:
                    t = np.random.randint(0, len(self.swaps), d)
                ys = []
                for swap_id in t:
                    y1 = y.clone()
                    for (p, q) in self.swaps[swap_id]:
                        # Save one side first: the slices are views, so a
                        # tuple-swap copied q's values into both positions.
                        saved = y1[:, :, p, :].clone()
                        y1[:, :, p, :] = y1[:, :, q, :]
                        y1[:, :, q, :] = saved
                    ys.append(y1)
                return ys

            xc = x.center()
            y = self.embed(xc.long()).view(-1, 1, self.in_shape[0], self.dim)
            label = x.label
            if label == "Box":
                # Interval hull of each position and its two neighbours.
                y_ls_1 = y.clone()
                y_ls_1[:, :, :-1, :] = y[:, :, 1:, :]
                y_rs_1 = y.clone()
                y_rs_1[:, :, 1:, :] = y[:, :, :-1, :]
                lower = torch.min(torch.min(y, y_ls_1), y_rs_1)
                upper = torch.max(torch.max(y, y_ls_1), y_rs_1)
                return ai.HybridZonotope((upper + lower) / 2,
                                         (upper - lower) / 2, None)
            elif label.endswith("Points"):
                d = int(label[:-len("Points")])
                if d == 1:
                    return y
                ys = get_swaped(d)
                mid = ys[0]
                err_terms = []
                for i in range(1, d):
                    # The error term uses the running midpoint before update.
                    err_terms.append(torch.unsqueeze((mid - ys[i]) / 2, 0))
                    mid = (mid + ys[i]) / 2
                err = torch.cat(err_terms, 0) if err_terms else None
                return ai.TaggedDomain(ai.HybridZonotope(mid, None, err),
                                       g.HBox(0))
            elif label.endswith(("Zonotope_Dataaug", "Interval_Dataaug")):
                # Both suffixes are 16 characters long.
                d = int(label[:-len("Zonotope_Dataaug")])
                if label.endswith("Zonotope_Dataaug"):
                    tag = "Points"
                else:
                    tag = "Points_Interval"
                ret = []
                i = 0
                try:
                    while i < d:
                        # Chunks of at most 5; a remainder of 6 or 7 is cut
                        # into 3 first so no trailing chunk of 1 or 2 is left.
                        t = 3 if d - i in (6, 7) else min(5, d - i)
                        # The recursive call reads the chunk spec from x.label.
                        x.label = str(t) + tag
                        ret.append(
                            ai.TaggedDomain(LabeledDomain(x),
                                            g.DList.MLoss(1.0 * t / d)))
                        i += t
                finally:
                    # Do not leak the temporary label back to the caller.
                    x.label = label
                return ai.ListDomain(ret)
            elif label.endswith("Points_Interval"):
                d = int(label[:-len("Points_Interval")])
                ys = get_swaped(d)
                lower = ys[0].clone()
                upper = ys[0].clone()
                for yi in ys[1:]:
                    lower = torch.min(lower, yi)
                    upper = torch.max(upper, yi)
                return ai.TaggedDomain(
                    ai.HybridZonotope((upper + lower) / 2, (upper - lower) / 2,
                                      None), g.HBox(0))
            elif label.endswith("Points_Dataaug"):
                d = int(label[:-len("Points_Dataaug")])
                ys = get_swaped(d)
                return ai.ListDomain(
                    [ai.TaggedDomain(yi, g.DList.MLoss(1.0 / d)) for yi in ys])
            elif label.endswith("Convex_Dataaug"):
                d = int(label[:-len("Convex_Dataaug")])
                ys = torch.cat(get_swaped(d), 1)
                return ai.TaggedDomain(ys.view(-1, 1, self.in_shape[0],
                                               self.dim),
                                       tag="magic" + str(d))
            elif label.endswith("Convex_Box_Groups"):
                try:
                    groups_consider = int(label[:-len("Convex_Box_Groups")])
                except ValueError:
                    # No numeric prefix: use every group in stored order.
                    groups_consider = None
                if groups_consider is None:
                    groups_consider = len(self.groups)
                else:
                    random.shuffle(self.groups)
                copies = groups_consider + 1
                seq_len = self.in_shape[0]
                tokens = xc.repeat((1, copies))
                for row in tokens:
                    # Copy j (>= 1) gets group j - 1: position p of that copy
                    # is set to row[q].
                    for j in range(1, copies):
                        for p, q in self.groups[j - 1]:
                            row[j * seq_len + p] = row[q]
                y = self.embed(tokens.long()).view(-1, 1, seq_len, self.dim)
                y = blend_toward_original(y, copies)
                return ai.TaggedDomain(y, tag="magic" + str(copies))
            else:
                raise NotImplementedError()

        if isinstance(x, ai.LabeledDomain):
            return LabeledDomain(x, **kargs)
        elif isinstance(x, ai.TaggedDomain):
            return ai.TaggedDomain(self.forward(x.a), x.tag)
        elif isinstance(x, ai.ListDomain):
            for (i, a) in enumerate(x.al):
                x.al[i] = self.forward(a)
            return x
        elif not x.isPoint():
            # Embed the centre plus one copy per substitution group.
            tokens = x.center().vanillaTensorPart().long()
            groups = [collect_groups(data) for data in tokens]
            groups_consider = max((len(t) for t in groups), default=0)
            copies = groups_consider + 1
            seq_len = self.in_shape[0]
            tokens = tokens.repeat((1, copies))
            for i, sample_groups in enumerate(groups):
                for j, group in enumerate(sample_groups, start=1):
                    for p, q in group:
                        tokens[i][j * seq_len + p] = q
            y = self.embed(tokens.long()).view(-1, 1, seq_len, self.dim)
            y = blend_toward_original(y, copies)
            return ai.TaggedDomain(y, tag="magic" + str(copies))
        elif isinstance(x, torch.Tensor):  # it is a Point
            return self.embed(x.long()).view(-1, 1, self.in_shape[0], self.dim)
        # Non-tensor point domains: this conversion was deliberately disabled
        # (formerly `assert False`, which `python -O` strips, silently
        # re-enabling the dead code after it).
        raise NotImplementedError(
            "point domains other than tensors are not supported")
Exemplo n.º 5
0
 def abstract_forward(self, x, **kargs):
     """Tag a non-point abstract value with this object's loss.

     Points are returned unchanged; anything else is wrapped as
     ``ai.TaggedDomain(x, self.MLoss(self, x))``.
     """
     if not x.isPoint():
         return ai.TaggedDomain(x, self.MLoss(self, x))
     return x
 def domain(self, *args, **kargs):
     """Build ``self.Domain(*args, **kargs)`` and tag it with ``self``."""
     inner = self.Domain(*args, **kargs)
     return ai.TaggedDomain(inner, self)