Example 1
File: ai.py Project: ForeverZyh/A3T
    def correlate(
            self,
            cc_indx_batch_beta):  # given in terms of the flattened matrix.
        num_correlate = h.product(cc_indx_batch_beta.shape[1:])

        beta = (h.zeros(self.head.shape).to_dtype()
                if self.beta is None else self.beta)
        errors = (h.zeros([0] + list(self.head.shape)).to_dtype()
                  if self.errors is None else self.errors)

        batch_size = beta.shape[0]
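        # allocate one fresh error term per correlated entry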
        new_errors = h.zeros([num_correlate] +
                             list(self.head.shape)).to_dtype()

        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

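        # switch to batch-first layout so advanced indexing can address
        # (example, new error term, flattened entry) triples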
        new_errors = new_errors.permute(
            1, 0,
            *list(range(len(new_errors.shape)))[2:]).contiguous().view(
                batch_size, num_correlate, -1)
        new_errors[inds_i, nc.unsqueeze(0).expand([batch_size] + list(nc.shape)).squeeze(2), cc_indx_batch_beta] = \
            beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta]

        new_errors = new_errors.permute(
            1, 0,
            *list(range(len(new_errors.shape)))[2:]).contiguous().view(
                num_correlate, batch_size, *beta.shape[1:])
        errors = torch.cat((errors, new_errors), dim=0)

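        # the correlated beta entries now live in explicit error terms,
        # so remove them from beta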
        beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0

        return self.new(self.head, beta, errors)
Example 2
File: ai.py Project: ForeverZyh/A3T
    def decorrelate(self, cc_indx_batch_err):  # keep these errors
        if self.errors is None:
            return self

        batch_size = self.head.shape[0]
        num_error_terms = self.errors.shape[0]

        beta = (h.zeros(self.head.shape).to_dtype()
                if self.beta is None else self.beta)
        errors = self.errors  # not None: guarded by the early return above

        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()
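        # batch-first layout so each example can select its own error terms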
        errors = errors.to_dtype().permute(
            1, 0,
            *list(range(len(self.errors.shape)))[2:])

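        # absorb the discarded error terms into beta: zero the kept rows,
        # then sum the magnitudes of what remains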
        sm = errors.clone()
        sm[inds_i, cc_indx_batch_err] = 0

        beta = beta.to_dtype() + sm.abs().sum(dim=1)

        errors = errors[inds_i, cc_indx_batch_err]
        errors = errors.permute(1, 0,
                                *list(range(len(
                                    self.errors.shape)))[2:]).contiguous()
        return self.new(self.head, beta, errors)
Example 3
File: ai.py Project: ForeverZyh/A3T
def creluNIPS(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    sm = torch.sum(torch.abs(dom.errors), 0)

    if dom.beta is not None:
        sm += dom.beta

    mn = dom.head - sm
    mx = dom.head + sm

    mngz = mn >= 0.0

    zs = h.zeros(dom.head.shape)

    diff = mx - mn

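    # slope and offset of the linear relaxation over an unstable
    # interval [mn, mx]: lam = mx / (mx - mn), mu = -lam * mn / 2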
    lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs)
    mu = lam * mn * (-0.5)

    betaz = zs if dom.beta is None else dom.beta

    newhead = torch.where(mngz, dom.head, lam * dom.head + mu)
    mngz = mngz | (diff <= 0.0)  # treat zero-width intervals as stable too
    newbeta = torch.where(mngz, betaz,
                          lam * betaz + mu)  # mu is always positive on this side
    newerr = torch.where(mngz, dom.errors, lam * dom.errors)
    return dom.new(newhead, newbeta, newerr)
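
The unstable case above is the zonotope ReLU relaxation: over an interval [mn, mx] with mn < 0 < mx, relu(x) is enclosed by the band lam * x + mu, plus or minus a fresh deviation of magnitude mu. A minimal standalone soundness check in plain PyTorch (the interval values are illustrative; the h helper module is not needed):

import torch

l, u = torch.tensor(-1.0), torch.tensor(3.0)  # an unstable interval: l < 0 < u
lam = u / (u - l)    # slope of the relaxation (0.75 here)
mu = -lam * l * 0.5  # offset and new deviation magnitude (0.375 here)

xs = torch.linspace(float(l), float(u), 101)
lower = lam * xs             # center lam * x + mu, minus the deviation mu
upper = lam * xs + 2.0 * mu  # center plus the deviation

relu = torch.relu(xs)
assert bool((lower - 1e-6 <= relu).all()) and bool((relu <= upper + 1e-6).all())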
Example 4
File: ai.py Project: ForeverZyh/A3T
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        if sS == sL:  # TODO: here we know the transformers on either side didn't introduce new error terms (this is a hack for hybrid zonotopes and doesn't work with adaptive error-term adding).
            return op(erS, erL)

        if ref_errs is not None:
            sz = ref_errs.size()[0]
        else:
            sz = min(sS, sL)

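        # apply op on the shared prefix of error terms, then pair each
        # side's leftover terms with zeros on the other side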
        p1 = op(erS[:sz], erL[:sz])
        erSrem = erS[sz:]
        erLrem = erL[sz:]
        p2 = op(erSrem, h.zeros(erSrem.shape))
        p3 = op(h.zeros(erLrem.shape), erLrem)
        return torch.cat((p1, p2, p3), dim=0)
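
A standalone sketch of the same alignment idea in plain torch (torch.zeros_like standing in for h.zeros; op is assumed elementwise over matching shapes):

import torch

def align_op(op, e1, e2):
    # apply op over the shared prefix of error-term rows; pad the longer
    # operand's leftover rows with zeros on the other side
    sz = min(e1.shape[0], e2.shape[0])
    parts = [op(e1[:sz], e2[:sz])]
    if e1.shape[0] > sz:
        parts.append(op(e1[sz:], torch.zeros_like(e1[sz:])))
    if e2.shape[0] > sz:
        parts.append(op(torch.zeros_like(e2[sz:]), e2[sz:]))
    return torch.cat(parts, dim=0)

print(align_op(torch.add, torch.ones(3, 2), torch.ones(5, 2)).shape)
# torch.Size([5, 2])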
Example 5
File: ai.py Project: ForeverZyh/A3T
            def slidingMax(a):  # using maxpool
                k = a.shape[1]
                ml = a.min(dim=1)[0].unsqueeze(1)

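                # shift values to be nonnegative so the zero padding is
                # the identity under max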
                inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)
                mpl = F.max_pool1d(inp.unsqueeze(1),
                                   kernel_size=k,
                                   stride=1,
                                   padding=0,
                                   return_indices=False).squeeze(1)
                return mpl[:, :-1] + ml
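
Since k equals the row length, the pooled window over the left-zero-padded row reduces to a running prefix max: column j of the result is max(a[:, :j]), with the row minimum as the empty-prefix base case. A standalone restatement in plain torch (torch.zeros standing in for h.zeros):

import torch
import torch.nn.functional as F

def sliding_max(a):
    batch_size, k = a.shape
    ml = a.min(dim=1)[0].unsqueeze(1)
    # prepend k zeros; window j of the width-k max pool then covers
    # exactly the first j entries of a - ml
    inp = torch.cat((torch.zeros(batch_size, k), a - ml), dim=1)
    mpl = F.max_pool1d(inp.unsqueeze(1), kernel_size=k, stride=1).squeeze(1)
    return mpl[:, :-1] + ml

print(sliding_max(torch.tensor([[3., 1., 4., 1., 5.]])))
# tensor([[1., 3., 3., 4., 4.]])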
Example 6
    def attack(self,
               model,
               xo,
               untargeted,
               target,
               w,
               loss_function=ai.stdLoss,
               **kargs):
        w = self.epsilon.getVal(c=w, **kargs)

        x = nn.Parameter(xo.clone(), requires_grad=True)
        gradorg = h.zeros(x.shape)
        is_eq = 1

        w = h.ones(x.shape) * w
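        # iterated sign-gradient attack with momentum and periodic random
        # restarts, kept inside the l-inf ball of radius w around xo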
        for i in range(self.k):
            if self.restart is not None and i % int(
                    self.k / self.restart) == 0:
                x = is_eq * (torch.rand_like(xo) * w + xo) + (1 - is_eq) * x
                x = nn.Parameter(x, requires_grad=True)

            model.optimizer.zero_grad()

            out = model(x).vanillaTensorPart()
            loss = loss_function(out, target)

            loss.sum().backward(retain_graph=True)
            with torch.no_grad():
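                # accumulate the l1-normalized gradient with momentum self.mu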
                oth = x.grad / torch.norm(x.grad, p=1)
                gradorg *= self.mu
                gradorg += oth
                grad = (self.r * w / self.k) * ai.mysign(gradorg)
                if self.should_end:
                    is_eq = ai.mulIfEq(grad, out, target)
                x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)

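                # project back onto the l-inf ball of radius w around xo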
                x = xo + torch.min(torch.max(x - xo, -w), w)
                x.requires_grad_()

        model.optimizer.zero_grad()

        return x
Example 7
File: ai.py Project: ForeverZyh/A3T
    def softplus(self):
        if self.errors is None:
            if self.beta is None:
                return self.new(F.softplus(self.head), None, None)
            tp = F.softplus(self.head + self.beta)
            bt = F.softplus(self.head - self.beta)
            return self.new((tp + bt) / 2, (tp - bt) / 2, None)

        errors = self.concreteErrors()
        o = h.ones(self.head.size())

        def sp(hd):
            # F.softplus is numerically stabler than the naive
            # torch.log(o + torch.exp(hd))
            return F.softplus(hd)

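        # first derivative of softplus: the logistic sigmoid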
        def spp(hd):
            ehd = torch.exp(hd)
            return ehd.div(ehd + o)

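        # second derivative of softplus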
        def sppp(hd):
            ehd = torch.exp(hd)
            md = ehd + o
            return ehd.div(md.mul(md))

        fa = sp(self.head)
        fpa = spp(self.head)

        a = self.head

        k = torch.sum(errors.abs(), 0)

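        # bound the first-order Taylor remainder (1/2) r^2 * sppp(a + r)
        # over |r| <= k by checking candidate extremal points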
        def evalG(r):
            return r.mul(r).mul(sppp(a + r))

        m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
        m = h.ifThenElse(a.abs().lt(k),
                         torch.max(m, torch.max(evalG(a), evalG(-a))), m)
        m /= 2

        return self.new(fa, m if self.beta is None else m + self.beta.mul(fpa),
                        self.errors.mul(fpa))  # errors is not None here
Example 8
def train_epoch(epoch, model, victim_model, attack, args, train_loader):
    vargs = vars(args)
    model.train()

    print(("Cur ratio: {}").format(S.TrainInfo.cur_ratio))
    assert isinstance(model.ty, goals.DList) and len(model.ty.al) == 2
    for (i, a) in enumerate(model.ty.al):
        if not isinstance(a[0], goals.Point):
            model.ty.al[i] = (a[0],
                              S.Const(args.train_lambda *
                                      S.TrainInfo.cur_ratio))
        else:
            model.ty.al[i] = (
                a[0], S.Const(1 - args.train_lambda * S.TrainInfo.cur_ratio))

    for batch_idx, (data, target) in enumerate(train_loader):
        S.TrainInfo.total_batches_seen += 1
        time = float(S.TrainInfo.total_batches_seen) / len(train_loader)
        data, target = data.to(h.device), target.to(h.device)

        model.global_num += data.size()[0]
        lossy = 0
        adv_time = sys_time.time()
        if args.adv_train_num > 0:
            data, target = adv_batch(victim_model, attack, data, target,
                                     args.adv_train_num)

        adv_time = sys_time.time() - adv_time

        timer = Timer(
            "train a sample from " + model.name + " with " + model.ty.name,
            data.size()[0], False)
        with timer:
            for s in model.boxSpec(data.to_dtype(), target, time=time):
                model.optimizer.zero_grad()
                loss = model.aiLoss(*s, time=time, **vargs).mean(dim=0)
                lossy += loss.detach().item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                for p in model.parameters():
                    if not p.requires_grad:
                        continue
                    if p is not None and torch.isnan(p).any():
                        print("Such nan in vals")
                    if p is not None and p.grad is not None and torch.isnan(
                            p.grad).any():
                        print("Such nan in postmagic")
                        stdv = 1 / math.sqrt(h.product(p.data.shape))
                        p.grad = torch.where(
                            torch.isnan(p.grad),
                            torch.normal(mean=h.zeros(p.grad.shape), std=stdv),
                            p.grad)

                model.optimizer.step()

                for p in model.parameters():
                    if not p.requires_grad:
                        continue
                    if p is not None and torch.isnan(p).any():
                        print("Such nan in vals after grad")
                        stdv = 1 / math.sqrt(h.product(p.data.shape))
                        p.data = torch.where(
                            torch.isnan(p.data),
                            torch.normal(mean=h.zeros(p.data.shape), std=stdv),
                            p.data)

                if args.clip_norm:
                    model.clip_norm()
                for p in model.parameters():
                    if not p.requires_grad:
                        continue
                    if p is not None and torch.isnan(p).any():
                        raise Exception("Such nan in vals after clip")

        model.addSpeed(timer.getUnitTime() + adv_time / len(data))

        if batch_idx % args.log_interval == 0:
            print((
                'Train Epoch {:12} Mix(a=Point(),b=Box(),aw=1,bw=0) {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}'
            ).format(model.name, epoch,
                     batch_idx * len(data) // (args.adv_train_num + 1),
                     len(train_loader.dataset),
                     100. * batch_idx / len(train_loader), model.speed, lossy))