Example #1
File: ai.py  Project: ForeverZyh/diffai
    def isSafe(self, target):
        assert len(self.al) > 0
        od, _ = torch.min(h.preDomRes(self.al[0], target).lb(), 1)
        for a in self.al[1:]:
            od1, _ = torch.min(h.preDomRes(a, target).lb(), 1)
            od = torch.min(od, od1)

        return od.gt(0.0).long()
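Across these examples, h.preDomRes(dom, target).lb() supplies, per sample and per class j, a lower bound on the margin between the target logit and logit j over the abstract domain; isSafe then marks a sample as verified exactly when the smallest such bound is positive. The sketch below mimics the same check on ordinary (concrete) logits; it is only an illustration of the pattern, not the diffai API, and masking out the target-vs-target entry is an assumption about how that entry is handled:

import torch

def concrete_margins(logits, target):
    # Concrete stand-in for h.preDomRes(...).lb(): margin[i, j] is
    # logits[i, target[i]] - logits[i, j]; on an abstract domain, lb()
    # would give a lower bound on this quantity rather than its exact value.
    return logits.gather(1, target.unsqueeze(1)) - logits

def concrete_is_safe(logits, target):
    # Mirrors isSafe above: a sample counts as safe iff every non-target
    # margin is strictly positive (the zero target-vs-target margin is
    # masked out here; diffai may handle that entry differently).
    m = concrete_margins(logits, target)
    m[torch.arange(m.size(0)), target] = float("inf")
    return m.min(1)[0].gt(0.0).long()

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.4, 0.3]])
target = torch.tensor([0, 2])
print(concrete_is_safe(logits, target))  # tensor([1, 0])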
Example #2
    def loss(self, dom, target, width_weight=0, tot_weight=1, **args):
        if self.width_weight is not None:
            width_weight = self.width_weight
        if self.tot_weight is not None:
            tot_weight = self.tot_weight

        r = -h.preDomRes(dom, target).lb()
        tot = F.softplus(r.max(1)[0])  # kinda works

        if self.log_loss:
            tot = (tot + 1).log()
        if self.pow_loss is not None and self.pow_loss > 0 and self.pow_loss != 1:
            tot = tot.pow(self.pow_loss)

        ls = tot * tot_weight
        if width_weight > 0:
            ls += dom.diameter() * width_weight

        return ls / (width_weight + tot_weight)
Example #3
 def labels(self):
     target = torch.max(self.ub(), 1)[1]
     l = list(h.preDomRes(self, target).lb()[0])
     return [target.item()] + [i for i, v in enumerate(l) if v <= 0]
Example #4
 def isSafe(self, target):
     od, _ = torch.min(h.preDomRes(self, target).lb(), 1)
     return od.gt(0.0).long()
Example #5
File: ai.py  Project: ForeverZyh/diffai
 def loss(self, target, **args):
     r = -h.preDomRes(self, target).lb()
     return F.softplus(r.max(1)[0])
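The loss variants in examples #2 and #5 share one pattern: negate the lower bounds so the largest entry is the worst margin violation, then apply softplus as a smooth hinge, so the loss shrinks towards its floor once every margin bound is comfortably positive. A minimal concrete sketch of that pattern (ordinary logits standing in for the abstract lower bounds; the target-vs-target entry is left in, so the per-sample loss floors at softplus(0) = log 2, which may differ from how preDomRes treats that entry):

import torch
import torch.nn.functional as F

def concrete_margin_loss(logits, target):
    # r[i, j] is the negated margin, so r.max(1)[0] is the worst violation;
    # softplus turns it into a smooth, always-positive hinge.
    margins = logits.gather(1, target.unsqueeze(1)) - logits
    r = -margins
    return F.softplus(r.max(1)[0])

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.4, 0.3]])
target = torch.tensor([0, 2])
print(concrete_margin_loss(logits, target))
# ~ tensor([0.6931, 0.7444]): the first sample has all margins positive
# (loss at its floor), the second has a negative margin against class 1.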
Example #6
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)

        with torch.no_grad():
            raw = net(inputs).vanillaTensorPart().detach()

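        # Warm-up: the first epoch trains on the plain criterion over concrete
        # outputs; later epochs propagate a box (w=1.0/255) around each input
        # through the net and penalise the worst margin lower bound via softplus.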
        if epoch == 0:
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        else:
            abstractinputs = domain.box(inputs,
                                        w=1.0 / 255,
                                        model=net,
                                        target=targets)
            abstractoutput = net(abstractinputs.to_dtype())
            loss = -helpers.preDomRes(abstractoutput, targets).lb()
            loss = loss.max(1)[0]
            loss = torch.nn.functional.softplus(loss).mean()

        meanloss.append(loss.cpu().data.numpy())

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
        optimizer.step()

        _, predicted = raw.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if random.randint(0, 30) == 0: