Example #1
    def correlate(
            self,
            cc_indx_batch_beta):  # given in terms of the flattened matrix.
        num_correlate = h.product(cc_indx_batch_beta.shape[1:])

        beta = (h.zeros(self.head.shape).to_dtype()
                if self.beta is None else self.beta)
        errors = (h.zeros([0] + list(self.head.shape)).to_dtype()
                  if self.errors is None else self.errors)

        batch_size = beta.shape[0]
        new_errors = h.zeros([num_correlate] +
                             list(self.head.shape)).to_dtype()

        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

        new_errors = new_errors.permute(
            1, 0,
            *list(range(len(new_errors.shape)))[2:]).contiguous().view(
                batch_size, num_correlate, -1)
        nc_exp = nc.unsqueeze(0).expand([batch_size] + list(nc.shape)).squeeze(2)
        new_errors[inds_i, nc_exp, cc_indx_batch_beta] = \
            beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta]

        new_errors = new_errors.permute(
            1, 0,
            *list(range(len(new_errors.shape)))[2:]).contiguous().view(
                num_correlate, batch_size, *beta.shape[1:])
        errors = torch.cat((errors, new_errors), dim=0)

        beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0

        return self.new(self.head, beta, errors)
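The scatter at the heart of correlate can be reproduced on plain tensors: move each selected beta entry into its own fresh error term, then zero it out of the box term. A minimal sketch with made-up shapes (all names here are illustrative, not part of the original API):

import torch

batch_size, num_correlate, feat = 2, 3, 5
beta = torch.rand(batch_size, feat)
idx = torch.randint(feat, (batch_size, num_correlate))   # flattened coordinates per example

new_errors = torch.zeros(batch_size, num_correlate, feat)
rows = torch.arange(batch_size).unsqueeze(1)             # (batch, 1), broadcasts with idx
cols = torch.arange(num_correlate).unsqueeze(0)          # (1, num_correlate)
new_errors[rows, cols, idx] = beta[rows, idx]            # one new error term per selected coordinate
beta[rows, idx] = 0                                      # the moved mass leaves the box term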
Example #2
    def decorrelate(self, cc_indx_batch_err):  # keep these errors
        if self.errors is None:
            return self

        batch_size = self.head.shape[0]
        num_error_terms = self.errors.shape[0]

        beta = (h.zeros(self.head.shape).to_dtype()
                if self.beta is None else self.beta)
        errors = (h.zeros([0] + list(self.head.shape)).to_dtype()
                  if self.errors is None else self.errors)

        inds_i = torch.arange(self.head.shape[0],
                              device=h.device).unsqueeze(1).long()
        errors = errors.to_dtype().permute(
            1, 0,
            *list(range(len(self.errors.shape)))[2:])

        sm = errors.clone()
        sm[inds_i, cc_indx_batch_err] = 0

        beta = beta.to_dtype() + sm.abs().sum(dim=1)

        errors = errors[inds_i, cc_indx_batch_err]
        errors = errors.permute(1, 0,
                                *list(range(len(
                                    self.errors.shape)))[2:]).contiguous()
        return self.new(self.head, beta, errors)
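decorrelate keeps, per example, only the error terms selected by cc_indx_batch_err and folds every other term's magnitude into the box radius beta. A standalone sketch of that fold (plain torch; shapes and names are illustrative):

import torch

terms, batch, feat, keep = 5, 2, 4, 3
errors = torch.randn(terms, batch, feat)
keep_idx = torch.randint(terms, (batch, keep))    # per-example terms to keep

err_bt = errors.permute(1, 0, 2)                  # (batch, terms, feat)
rows = torch.arange(batch).unsqueeze(1)

dropped = err_bt.clone()
dropped[rows, keep_idx] = 0                       # zero out the kept terms
beta = dropped.abs().sum(dim=1)                   # absorb the rest into the box radius

kept = err_bt[rows, keep_idx].permute(1, 0, 2)    # back to (terms_kept, batch, feat)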
Example #3
        def catNonNullErrors(er1, er2):  # the way of things is ugly
            erS, erL = (er1, er2)
            sS, sL = (erS.size()[0], erL.size()[0])

            if sS == sL:  # both sides went through transformers that didn't introduce new error terms (this is a hack).
                return erS.cat(erL, dim + 1)

            extrasS = h.zeros([sL] + list(erS.size()[1:]))
            extrasL = h.zeros([sS] + list(erL.size()[1:]))

            erL = torch.cat((extrasL, erL), dim=0)
            erS = torch.cat((erS, extrasS), dim=0)

            return erS.cat(erL, dim + 1)
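The padding keeps the two operands' error terms disjoint: erS keeps its sS terms first and gets sL zero rows appended, while erL gets sS zero rows prepended, so both end up with sS + sL leading terms that never alias. A standalone shape check (plain torch, illustrative sizes):

import torch

erS = torch.rand(2, 4)    # sS = 2 error terms
erL = torch.rand(3, 4)    # sL = 3 error terms
sS, sL = erS.shape[0], erL.shape[0]

erS_pad = torch.cat((erS, torch.zeros(sL, 4)), dim=0)   # terms [0, sS) live here
erL_pad = torch.cat((torch.zeros(sS, 4), erL), dim=0)   # terms [sS, sS + sL) live here
assert erS_pad.shape == erL_pad.shape == (sS + sL, 4)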
Example #4
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        if sS == sL: # TODO: here we know we used transformers on either side which didn't introduce new error terms (this is a hack for hybrid zonotopes and doesn't work with adaptive error term adding).
            return op(erS,erL)

        extrasS = h.zeros([sL] + list(erS.size()[1:]))
        extrasL = h.zeros([sS] + list(erL.size()[1:]))

        erL = torch.cat((extrasL, erL), dim=0)
        erS = torch.cat((erS, extrasS), dim=0)

        return op(erS,erL)
Example #5
    def attack(self, model, xo, untargeted, target, w, loss_function=ai.stdLoss, **kargs):
        w = self.epsilon.getVal(c = w, **kargs)

        x = nn.Parameter(xo.clone(), requires_grad=True)
        gradorg = h.zeros(x.shape)
        is_eq = 1

        w = h.ones(x.shape) * w
        for i in range(self.k):
            if self.restart is not None and i % int(self.k / self.restart) == 0:
                x = is_eq * (torch.rand_like(xo) * w + xo) + (1 - is_eq) * x
                x = nn.Parameter(x, requires_grad = True)

            model.optimizer.zero_grad()

            out = model(x).vanillaTensorPart()
            loss = loss_function(out, target)

            loss.sum().backward(retain_graph=True)
            with torch.no_grad():
                oth = x.grad / torch.norm(x.grad, p=1)
                gradorg *= self.mu 
                gradorg += oth
                grad = (self.r * w / self.k) * ai.mysign(gradorg)
                if self.should_end:
                    is_eq = ai.mulIfEq(grad, out, target)
                x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)

                x = xo + torch.min(torch.max(x - xo, -w),w)
                x.requires_grad_()

        model.optimizer.zero_grad()

        return x
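The projection x = xo + torch.min(torch.max(x - xo, -w), w) clamps elementwise against the tensor radius w; torch.clamp historically accepted only scalar bounds, hence the min/max composition. A quick illustration (hypothetical values):

import torch

delta = torch.randn(3, 4)                      # plays the role of x - xo
w = torch.rand(3, 4)                           # per-coordinate radius
proj = torch.min(torch.max(delta, -w), w)      # elementwise clamp to [-w, w]
assert ((proj >= -w) & (proj <= w)).all()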
Example #6
def creluNIPS(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta 
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)
    
    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0) 

    if dom.beta is not None:
        sm += dom.beta

    mn = dom.head - sm
    mx = dom.head + sm

    mngz = mn >= 0

    zs = h.zeros(dom.head.shape)

    lam = torch.where(mx > 0, mx / (mx - mn), zs)
    mu = lam * mn * (-0.5)

    betaz = zs if dom.beta is None else dom.beta 
    
    newhead = torch.where(mngz, dom.head, lam * dom.head + mu)
    newbeta = torch.where(mngz, betaz, lam * betaz + mu)  # mu is always positive on this side
    newerr = torch.where(mngz, dom.errors, lam * dom.errors)
    return dom.new(newhead, newbeta, newerr)
Example #7
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        if sS == sL:  # TODO: here we know we used transformers on either side which didn't introduce new error terms (this is a hack for hybrid zonotopes and doesn't work with adaptive error term adding).
            return op(erS, erL)

        if ref_errs is not None:
            sz = ref_errs.size()[0]
        else:
            sz = min(sS, sL)

        p1 = op(erS[:sz], erL[:sz])
        erSrem = erS[sz:]
        erLrem = erL[sz:]
        p2 = op(erSrem, h.zeros(erSrem.shape))
        p3 = op(h.zeros(erLrem.shape), erLrem)
        return torch.cat((p1, p2, p3), dim=0)
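The three pieces partition the error terms: the shared prefix is combined directly, and each side's remainder is paired with zeros of its own shape before everything is concatenated back. A shape check with op = torch.add (illustrative):

import torch

op = torch.add
erS, erL = torch.rand(2, 4), torch.rand(5, 4)
sz = min(erS.shape[0], erL.shape[0])

p1 = op(erS[:sz], erL[:sz])
erSrem, erLrem = erS[sz:], erL[sz:]
p2 = op(erSrem, torch.zeros(erSrem.shape))
p3 = op(torch.zeros(erLrem.shape), erLrem)
assert torch.cat((p1, p2, p3), dim=0).shape == (erS.shape[0] + erL.shape[0] - sz, 4)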
Example #8
            def slidingMax(a):  # using maxpool
                k = a.shape[1]
                ml = a.min(dim=1)[0].unsqueeze(1)

                inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)
                mpl = F.max_pool1d(inp.unsqueeze(1),
                                   kernel_size=k,
                                   stride=1,
                                   padding=0,
                                   return_indices=False).squeeze(1)
                return mpl[:, :-1] + ml
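Because the pool width equals the row length, output position j holds the maximum of the strict prefix a[:, :j], with the row minimum standing in for the empty prefix: subtracting the minimum makes all values non-negative, so the zero padding is neutral under max. A self-contained check against a naive loop (a re-wrapping of the code above with illustrative names):

import torch
import torch.nn.functional as F

def sliding_max(a):                       # exclusive running max via max_pool1d
    batch_size, k = a.shape
    ml = a.min(dim=1)[0].unsqueeze(1)     # row minimum, the neutral base value
    inp = torch.cat((torch.zeros(batch_size, k), a - ml), dim=1)
    mpl = F.max_pool1d(inp.unsqueeze(1), kernel_size=k, stride=1,
                       padding=0).squeeze(1)
    return mpl[:, :-1] + ml

a = torch.rand(4, 7)
naive = torch.stack([a[:, :j].max(dim=1)[0] if j > 0 else a.min(dim=1)[0]
                     for j in range(a.shape[1])], dim=1)
assert torch.allclose(sliding_max(a), naive)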
Example #9
 def attack(model, epsilon, x, target, k=20, mu=0.5):
     epsilon /= k
     x = Point(x.data, requires_grad=True)
     gradorg = Point(h.zeros(x.shape))
     for _ in range(k):
         model.optimizer.zero_grad()
         loss = model.stdLoss(x, None, target).sum()
         loss.backward()
         oth = x.grad / torch.norm(x.grad, p=1)
         gradorg = gradorg * mu + oth
         x.data = (x + epsilon * torch.sign(gradorg)).data
     return x
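The same momentum-FGSM loop on a plain nn.Module, without the framework's Point wrapper and with cross-entropy standing in for stdLoss (a minimal sketch, not the original API):

import torch
import torch.nn.functional as F

def momentum_fgsm(model, epsilon, x, target, k=20, mu=0.5):
    step = epsilon / k
    x = x.clone().detach().requires_grad_(True)
    grad_acc = torch.zeros_like(x)
    for _ in range(k):
        loss = F.cross_entropy(model(x), target)
        grad, = torch.autograd.grad(loss, x)
        grad_acc = mu * grad_acc + grad / grad.norm(p=1)   # momentum on the normalized gradient
        with torch.no_grad():
            x += step * grad_acc.sign()                    # one signed ascent step
    return x.detach()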
Example #10
 def init(self, prev, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, **kargs):
     self.prev = prev
     self.in_channels = prev[0]
     self.out_channels = out_channels
     self.kernel_size = kernel_size
     self.stride = stride
     self.padding = padding
     self.use_softplus = h.default(global_args, 'use_softplus', False)
     
     weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size)        
     self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
     if bias:
         self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0]))
     else:
         self.bias = h.zeros(weights_shape[0])
         
     outshape = getShapeConv(prev, (out_channels, kernel_size, kernel_size), stride, padding)
     return outshape
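getShapeConv is a framework helper; the standard convolution arithmetic it presumably implements is out = floor((in + 2*padding - kernel) / stride) + 1 per spatial dimension. A hypothetical reimplementation for square kernels:

def get_shape_conv(prev, conv_shape, stride=1, padding=0):
    # prev = (in_channels, height, width); conv_shape = (out_channels, k, k)
    out_channels, k, _ = conv_shape
    _, h_in, w_in = prev
    h_out = (h_in + 2 * padding - k) // stride + 1
    w_out = (w_in + 2 * padding - k) // stride + 1
    return (out_channels, h_out, w_out)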
Example #11
    def softplus(self):
        if self.errors is None:
            if self.beta is None:
                return self.new(F.softplus(self.head), None, None)
            tp = F.softplus(self.head + self.beta)
            bt = F.softplus(self.head - self.beta)
            return self.new((tp + bt) / 2, (tp - bt) / 2, None)

        errors = self.concreteErrors()
        o = h.ones(self.head.size())

        def sp(hd):
            # equivalent to torch.log(o + torch.exp(hd)), but numerically stable
            return F.softplus(hd)

        def spp(hd):
            ehd = torch.exp(hd)
            return ehd.div(ehd + o)

        def sppp(hd):
            ehd = torch.exp(hd)
            md = ehd + o
            return ehd.div(md.mul(md))

        fa = sp(self.head)
        fpa = spp(self.head)

        a = self.head

        k = torch.sum(errors.abs(), 0)

        def evalG(r):
            return r.mul(r).mul(sppp(a + r))

        m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
        m = h.ifThenElse(a.abs().lt(k),
                         torch.max(m, torch.max(evalG(a), evalG(-a))), m)
        m /= 2

        return self.new(fa, m if self.beta is None else m + self.beta.mul(fpa),
                        None if self.errors is None else self.errors.mul(fpa))
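In the error-free branch, softplus is monotone, so the interval [head - beta, head + beta] maps exactly onto [softplus(head - beta), softplus(head + beta)], and the code just re-centers it. A scalar sanity check of that branch (illustrative values):

import torch
import torch.nn.functional as F

head, beta = torch.tensor(0.3), torch.tensor(1.2)
lo, hi = F.softplus(head - beta), F.softplus(head + beta)
new_head, new_beta = (hi + lo) / 2, (hi - lo) / 2
assert torch.allclose(new_head - new_beta, lo)
assert torch.allclose(new_head + new_beta, hi)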
Example #12
    def attack(self,
               model,
               epsilon,
               xo,
               untargeted,
               target,
               loss_function=ai.stdLoss):
        if self.epsilon is not None:
            epsilon = self.epsilon
        x = nn.Parameter(xo.clone(), requires_grad=True)
        gradorg = h.zeros(x.shape)
        is_eq = 1
        for i in range(self.k):
            if self.restart is not None and i % int(
                    self.k / self.restart) == 0:
                x = is_eq * (torch.randn_like(xo) * epsilon + xo) + (1 -
                                                                     is_eq) * x
                x = nn.Parameter(x, requires_grad=True)

            model.optimizer.zero_grad()

            out = model(x)
            loss = loss_function(out, target)

            loss.backward()
            with torch.no_grad():
                oth = x.grad / torch.norm(x.grad, p=1)
                gradorg *= self.mu
                gradorg += oth
                grad = (self.r * epsilon / self.k) * ai.mysign(gradorg)
                if self.should_end:
                    is_eq = ai.mulIfEq(grad, out, target)
                x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)
                x = xo + torch.clamp(x - xo, -epsilon, epsilon)
                x.requires_grad_()

        model.optimizer.zero_grad()
        return x
Example #13
def test(models, epoch, f = None):
    global num_tests
    num_tests += 1
    class MStat:
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0
            class Stat:
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.max_eps = 0
                    self.safe = 0
                    self.proved = 0
                    self.time = 0
            self.domains = [ Stat(h.parseValues(domains, d), h.catStrs(d)) for d in args.test_domain ]
    model_stats = [ MStat(m) for m in models ]
        
    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            saved_data_target += list(zip(list(data), list(target)))
        
        num_its += data.size()[0]
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for m in model_stats:

            with torch.no_grad():
                pred = m.model(data).data.max(1, keepdim=True)[1] # get the index of the max log-probability
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint = False)
                with timer:
                    def calcData(data, target):
                        box = stat.domain.box(data, m.model.w, model=m.model, untargeted = True, target=target)
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).max(1,keepdim=True)[1]
                            stat.width += bs.diameter().sum().item() # sum up batch loss
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            stat.max_eps += 0 # TODO: calculate max_eps

                    if m.model.net.neuronCount() < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        for d,t in zip(data, target):
                            calcData(d.unsqueeze(0),t.unsqueeze(0))
                stat.time += timer.getUnitTime()
                
    l = num_its # len(test_loader.dataset)
    for m in model_stats:

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            m.model.lrschedule.step(1 - pr_corr)
        
        h.printBoth(('Test: {:12} trained with {:'+ str(largest_domain) +'} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
            m.model.name, m.model.ty.name,
            m.model.speed,
            m.correct, l, 100. * pr_corr), f = f)
        
        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth(("\t{:" + str(largest_test_domain)+"} - Width: {:<36.16f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f} AvgMaxEps: {:1.10f} Time = {:<7.5f}").format(
                stat.name, 
                stat.width / l, 
                pr_proved, 
                pr_safe, pr_corr_given_proved, 
                stat.max_eps / l,
                stat.time), f = f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        prepedname = m.model.ty.name.replace(" ", "_").replace(",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(out_dir, m.model.name +"__" +prepedname + "_checkpoint_"+str(epoch)+"_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".net"), f = f)

        if (num_tests % args.save_freq == 1 or args.save_freq == 1) and not args.dont_write:
            torch.save(m.model.net, net_file + ".pynet")
            
            with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                m.model.net.printNet(f2)
                f2.close()
            if args.onyx:
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(nn, h.zeros([1] + list(input_dims)), net_file + ".onyx", 
                                  verbose=False, input_names=["actual_input"] + ["param"+str(i) for i in range(len(list(nn.parameters())))], output_names=["output"])


    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num,(img,target) in zip(range(args.number_save_images), saved_data_target[:args.number_save_images]):
            sz = ""
            for s in img.size():
                sz += str(s) + "x"
            sz = sz[:-1]

            img_file = os.path.join(img_dir, args.dataset + "_" + sz + "_"+ str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    print(decimal.Decimal(float(t)).__format__("f"), file=imgfile)
            with open(img_file + ".class" , "w") as imgfile:
                print(int(target.item()), file=imgfile)
Example #14
def train(epoch, models):
    global total_batches_seen

    for model in models:
        model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        total_batches_seen += 1
        time = float(total_batches_seen) / len(train_loader)
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for model in models:
            model.global_num += data.size()[0]

            timer = Timer(
                "train a sample from " + model.name + " with " + model.ty.name,
                data.size()[0], False)
            lossy = 0
            with timer:
                for s in model.getSpec(data.to_dtype(), target, time=time):
                    model.optimizer.zero_grad()
                    loss = model.aiLoss(*s, time=time, **vargs).mean(dim=0)
                    lossy += loss.detach().item()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            print("Such nan in vals")
                        if p is not None and p.grad is not None and torch.isnan(
                                p.grad).any():
                            print("Such nan in postmagic")
                            stdv = 1 / math.sqrt(h.product(p.data.shape))
                            p.grad = torch.where(
                                torch.isnan(p.grad),
                                torch.normal(mean=h.zeros(p.grad.shape),
                                             std=stdv), p.grad)

                    model.optimizer.step()

                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            print("Such nan in vals after grad")
                            stdv = 1 / math.sqrt(h.product(p.data.shape))
                            p.data = torch.where(
                                torch.isnan(p.data),
                                torch.normal(mean=h.zeros(p.data.shape),
                                             std=stdv), p.data)

                    if args.clip_norm:
                        model.clip_norm()
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            raise Exception("Such nan in vals after clip")

            model.addSpeed(timer.getUnitTime())

            if batch_idx % args.log_interval == 0:
                print((
                    'Train Epoch {:12} {:' + str(largest_domain) +
                    '}: {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}'
                ).format(model.name, model.ty.name, epoch,
                         batch_idx * len(data), len(train_loader.dataset),
                         100. * batch_idx / len(train_loader), model.speed,
                         lossy))
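The inner-loop guards patch NaNs instead of aborting: any NaN entry of a gradient or parameter tensor is overwritten with fresh Gaussian noise at the 1/sqrt(size) scale. A standalone sketch of the same repair (illustrative; the original takes stdv from h.product(p.data.shape)):

import math
import torch

def repair_nans(t):
    stdv = 1 / math.sqrt(t.numel())                 # init-scale noise
    noise = torch.normal(mean=torch.zeros_like(t), std=stdv)
    return torch.where(torch.isnan(t), noise, t)    # keep finite entries untouched

t = torch.tensor([1.0, float('nan'), 3.0])
print(repair_nans(t))   # the NaN is replaced by a small random value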
Example #15
def test(models, epoch, f=None):
    global num_tests
    num_tests += 1

    class MStat:
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0

            class Stat:
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.max_eps = None
                    self.safe = 0
                    self.proved = 0
                    self.time = 0

            self.domains = [
                Stat(h.parseValues(d, goals), h.catStrs(d))
                for d in args.test_domain
            ]

    model_stats = [MStat(m) for m in models]
    dict_map = dict(np.load("./dataset/AG/dict_map.npy").item())
    lines = open("./dataset/en.key1").readlines()
    adjacent_keys = [[] for i in range(len(dict_map))]
    for line in lines:
        tmp = line.strip().split()
        ret = set(tmp[1:]).intersection(dict_map.keys())
        ids = []
        for x in ret:
            ids.append(dict_map[x])
        adjacent_keys[dict_map[tmp[0]]].extend(ids)

    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            saved_data_target += list(zip(list(data), list(target)))

        num_its += data.size()[0]
        if num_its % 100 == 0:
            print(num_its, model_stats[0].domains[0].safe * 100.0 / num_its)
        if args.test_swap_delta > 0:
            length = data.size()[1]
            data = data.repeat(1, length)
            for i in data:
                for j in range(length - 1):
                    for _ in range(args.test_swap_delta):
                        t = np.random.randint(0, length)
                        while len(adjacent_keys[int(i[t])]) == 0:
                            t = np.random.randint(0, length)
                        cid = int(i[t])
                        i[j * length + t] = adjacent_keys[cid][0]
            target = (target.view(-1, 1).repeat(1, length)).view(-1)
            data = data.view(-1, length)

        if h.use_cuda:
            data, target = data.cuda().to_dtype(), target.cuda()

        for m in model_stats:

            with torch.no_grad():
                # get the index of the max log-probability
                pred = m.model(data).vanillaTensorPart().max(1, keepdim=True)[1]
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint=False)
                with timer:

                    def calcData(data, target):
                        box = stat.domain.box(data,
                                              w=m.model.w,
                                              model=m.model,
                                              untargeted=True,
                                              target=target).to_dtype()
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).vanillaTensorPart().max(
                                1, keepdim=True)[1]
                            stat.width += bs.diameter().sum().item()  # sum up batch loss
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            # stat.max_eps += 0 # TODO: calculate max_eps

                    if m.model.net.neuronCount() < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        if args.test_swap_delta > 0:
                            length = data.size()[1]
                            pre_stat = copy.deepcopy(stat)
                            for i, (d, t) in enumerate(zip(data, target)):
                                calcData(d.unsqueeze(0), t.unsqueeze(0))
                                if (i + 1) % length == 0:
                                    d_proved = (stat.proved -
                                                pre_stat.proved) // length
                                    d_safe = (stat.safe -
                                              pre_stat.safe) // length
                                    d_width = (stat.width -
                                               pre_stat.width) / length
                                    stat.proved = pre_stat.proved + d_proved
                                    stat.safe = pre_stat.safe + d_safe
                                    stat.width = pre_stat.width + d_width
                                    pre_stat = copy.deepcopy(stat)
                        else:
                            for d, t in zip(data, target):
                                calcData(d.unsqueeze(0), t.unsqueeze(0))
                stat.time += timer.getUnitTime()

    l = num_its  # len(test_loader.dataset)
    for m in model_stats:
        if args.lr_multistep:
            m.model.lrschedule.step()

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            m.model.lrschedule.step(1 - pr_corr)

        h.printBoth(
            ('Test: {:12} trained with {:' + str(largest_domain) +
             '} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
                 m.model.name, m.model.ty.name, m.model.speed, m.correct, l,
                 100. * pr_corr),
            f=f)

        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth((
                "\t{:" + str(largest_test_domain) +
                "} - Width: {:<36.16f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f} {}Time = {:<7.5f}"
            ).format(
                stat.name, stat.width / l, pr_proved, pr_safe,
                pr_corr_given_proved,
                "AvgMaxEps: {:1.10f} ".format(stat.max_eps / l)
                if stat.max_eps is not None else "", stat.time),
                        f=f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(
                stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        prepedname = m.model.ty.name.replace(" ", "_").replace(
            ",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(
            out_dir, m.model.name + "__" + prepedname + "_checkpoint_" +
            str(epoch) + "_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".pynet"), f=f)

        if (num_tests % args.save_freq == 1 or args.save_freq
                == 1) and not args.dont_write and (num_tests > 1
                                                   or args.write_first):
            print("Actually Saving")
            torch.save(m.model.net, net_file + ".pynet")
            if args.save_dot_net:
                with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                    m.model.net.printNet(f2)
                    f2.close()
            if args.onyx:
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(
                    nn,
                    h.zeros([1] + list(input_dims)),
                    net_file + ".onyx",
                    verbose=False,
                    input_names=["actual_input"] + [
                        "param" + str(i)
                        for i in range(len(list(nn.parameters())))
                    ],
                    output_names=["output"])

    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num, (img, target) in zip(
                range(args.number_save_images),
                saved_data_target[:args.number_save_images]):
            sz = ""
            for s in img.size():
                sz += str(s) + "x"
            sz = sz[:-1]

            img_file = os.path.join(
                img_dir, args.dataset + "_" + sz + "_" + str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    print(decimal.Decimal(float(t)).__format__("f"),
                          file=imgfile)
            with open(img_file + ".class", "w") as imgfile:
                print(int(target.item()), file=imgfile)
Example #16
def train(epoch, models, decay=True):
    global total_batches_seen

    for model in models:
        model.train()
        #if args.decay_fir:
        #    if epoch > 1 and isinstance(model.ty, goals.DList) and len(model.ty.al) == 2 and decay:
        #        for (i, a) in enumerate(model.ty.al):
        #            if i == 1:
        #                model.ty.al[i] = (a[0], Const(min(a[1].getVal() + 0.0025, 0.75)))
        #            else:
        #                model.ty.al[i] = (a[0], Const(max(a[1].getVal() - 0.0025, 0.25)))

    for batch_idx, (data, target) in enumerate(train_loader):
        if total_batches_seen * args.batch_size % 4000 == 0:
            for model in models:
                if args.decay_fir:
                    if isinstance(model.ty, goals.DList) and len(
                            model.ty.al) == 2 and decay:
                        for (i, a) in enumerate(model.ty.al):
                            if i == 1:
                                model.ty.al[i] = (a[0],
                                                  Const(
                                                      min(
                                                          a[1].getVal() +
                                                          0.0025, 3)))
                            # else:
                            #    model.ty.al[i] = (a[0], Const(max(a[1].getVal() - 0.00075, 0.25)))

        total_batches_seen += 1
        time = float(total_batches_seen) / len(train_loader)
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for model in models:
            model.global_num += data.size()[0]

            timer = Timer(
                "train a sample from " + model.name + " with " + model.ty.name,
                data.size()[0], False)
            lossy = 0
            with timer:
                for s in model.getSpec(data.to_dtype(), target, time=time):
                    model.optimizer.zero_grad()
                    loss = model.aiLoss(*s, time=time, **vargs).mean(dim=0)
                    lossy += loss.detach().item()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            print("Such nan in vals")
                        if p is not None and p.grad is not None and torch.isnan(
                                p.grad).any():
                            print("Such nan in postmagic")
                            stdv = 1 / math.sqrt(h.product(p.data.shape))
                            p.grad = torch.where(
                                torch.isnan(p.grad),
                                torch.normal(mean=h.zeros(p.grad.shape),
                                             std=stdv), p.grad)

                    model.optimizer.step()

                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            print("Such nan in vals after grad")
                            stdv = 1 / math.sqrt(h.product(p.data.shape))
                            p.data = torch.where(
                                torch.isnan(p.data),
                                torch.normal(mean=h.zeros(p.data.shape),
                                             std=stdv), p.data)

                    if args.clip_norm:
                        model.clip_norm()
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            raise Exception("Such nan in vals after clip")

            model.addSpeed(timer.getUnitTime())

            if batch_idx % args.log_interval == 0:
                print((
                    'Train Epoch {:12} {:' + str(largest_domain) +
                    '}: {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}'
                ).format(model.name, model.ty.name, epoch,
                         batch_idx * len(data), len(train_loader.dataset),
                         100. * batch_idx / len(train_loader), model.speed,
                         lossy))

    val = 0
    val_origin = 0
    batch_cnt = 0
    for batch_idx, (data, target) in enumerate(val_loader):
        batch_cnt += 1
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for model in models:
            for s in model.getSpec(data.to_dtype(), target):
                loss = model.aiLoss(*s, **vargs).mean(dim=0)
                val += loss.detach().item()

            loss = model.aiLoss(data, target, **vargs).mean(dim=0)
            val_origin += loss.detach().item()

    return val_origin / batch_cnt, val / batch_cnt


def abstractNet(inpt):
    dom = domain.box(inpt, w=None)
    o = net(dom, onyx=True).unsqueeze(1)

    out = torch.cat([
        o.vanillaTensorPart(),
        o.lb().vanillaTensorPart(),
        o.ub().vanillaTensorPart()
    ],
                    dim=1)
    return out


input_shape = [args.batch_size] + list(net.inShape)
if args.tf_input:
    input_shape = [args.batch_size] + list(net.inShape)[1:] + [net.inShape[0]]
dummy = h.zeros(input_shape)

abstractNet(dummy)


class AbstractNet(nn.Module):
    def __init__(self, domain, net, abstractNet):
        super(AbstractNet, self).__init__()
        self.net = net
        self.abstractNet = abstractNet
        if hasattr(domain, "net") and domain.net is not None:
            self.netDom = domain.net

    def forward(self, inpt):
        return self.abstractNet(inpt)
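
A hypothetical usage of the wrapper, mirroring the export calls earlier in the file (the output file name is made up):

wrapped = AbstractNet(domain, net, abstractNet)
torch.onnx.export(wrapped, dummy, "abstract_net.onnx", verbose=False)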