コード例 #1
0
ファイル: __main__.py プロジェクト: pombredanne/diffai
def train(epoch, models):
    """Train every model in ``models`` for one pass over ``train_loader``.

    For each batch, and for each model, every specification yielded by
    ``model.getSpec(data, target)`` triggers one optimizer step on the
    batch-averaged ``model.aiLoss``.  Gradients are norm-clipped to 1 before
    the step and ``model.clip_norm()`` is called after it.

    Relies on the module-level names ``train_loader``, ``h``, ``args``,
    ``vargs``, ``Timer`` and ``largest_domain``.

    Args:
        epoch: Epoch number; only used in the progress log line.
        models: Iterable of model wrappers exposing ``train``, ``getSpec``,
            ``aiLoss``, ``optimizer``, ``clip_norm``, ``addSpeed``, ``name``,
            ``ty``, ``speed`` and ``global_num``.
    """
    for model in models:
        model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        # Number of examples in this batch; the last batch may be smaller.
        batch_size = data.size()[0]

        for model in models:
            model.global_num += batch_size

            timer = Timer("train", "sample from " + model.name + " with " + model.ty.name, batch_size, False)
            # Sum of the per-spec losses for this batch (reported in the log).
            lossy = 0
            with timer:
                for s in model.getSpec(data, target):
                    model.optimizer.zero_grad()

                    loss = model.aiLoss(*s, **vargs).sum() / batch_size
                    lossy += loss.item()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                    model.optimizer.step()
                    model.clip_norm()

            model.addSpeed(timer.getUnitTime())

            if batch_idx % args.log_interval == 0:
                print(('Train Epoch {:12} {:'+ str(largest_domain) +'}: {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}').format(
                    model.name,  model.ty.name,
                    epoch,
                    batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                    model.speed,
                    lossy))
コード例 #2
0
def train_epoch(epoch, model, victim_model, attack, args, train_loader):
    """Run one training epoch of ``model`` over ``train_loader``.

    The two entries of ``model.ty.al`` are first re-weighted from
    ``args.train_lambda * S.TrainInfo.cur_ratio``: an entry whose first
    element is a ``goals.Point`` gets ``1 - w``, any other entry gets ``w``.
    Each batch is optionally replaced by the result of ``adv_batch`` (when
    ``args.adv_train_num > 0``) and then one optimizer step is taken per
    specification returned by ``model.boxSpec``.  NaNs found in gradients or
    weights are reported and overwritten with small Gaussian noise; NaNs that
    survive the optional ``model.clip_norm()`` raise an ``Exception``.

    Args:
        epoch: Epoch number, used only for logging.
        model: Model wrapper being trained.
        victim_model: Model handed to ``adv_batch``.
        attack: Attack handed to ``adv_batch``.
        args: Parsed options; also expanded as keyword arguments to ``aiLoss``.
        train_loader: Iterable of ``(data, target)`` batches.
    """
    vargs = vars(args)
    model.train()

    print(("Cur ratio: {}").format(S.TrainInfo.cur_ratio))
    assert isinstance(model.ty, goals.DList) and len(model.ty.al) == 2

    # Re-balance the goal weights for the current ratio.
    for idx, entry in enumerate(model.ty.al):
        goal = entry[0]
        if isinstance(goal, goals.Point):
            weight = 1 - args.train_lambda * S.TrainInfo.cur_ratio
        else:
            weight = args.train_lambda * S.TrainInfo.cur_ratio
        model.ty.al[idx] = (goal, S.Const(weight))

    for batch_idx, (data, target) in enumerate(train_loader):
        S.TrainInfo.total_batches_seen += 1
        # Training progress measured in (fractional) epochs.
        progress = float(S.TrainInfo.total_batches_seen) / len(train_loader)
        data = data.to(h.device)
        target = target.to(h.device)

        model.global_num += data.size()[0]
        lossy = 0

        # Time spent building the adversarial batch is accounted separately.
        adv_started = sys_time.time()
        if args.adv_train_num > 0:
            data, target = adv_batch(victim_model, attack, data, target,
                                     args.adv_train_num)
        adv_elapsed = sys_time.time() - adv_started

        label = "train a sample from " + model.name + " with " + model.ty.name
        timer = Timer(label, data.size()[0], False)
        with timer:
            for spec in model.boxSpec(data.to_dtype(), target, time=progress):
                model.optimizer.zero_grad()
                loss = model.aiLoss(*spec, time=progress, **vargs).mean(dim=0)
                lossy += loss.detach().item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

                # Before the step: report NaN weights, patch NaN gradients.
                for param in model.parameters():
                    if param.requires_grad:
                        if torch.isnan(param).any():
                            print("Such nan in vals")
                        grad = param.grad
                        if grad is not None and torch.isnan(grad).any():
                            print("Such nan in postmagic")
                            stdv = 1 / math.sqrt(h.product(param.data.shape))
                            nan_mask = torch.isnan(grad)
                            noise = torch.normal(mean=h.zeros(grad.shape),
                                                 std=stdv)
                            param.grad = torch.where(nan_mask, noise, grad)

                model.optimizer.step()

                # After the step: patch NaN weights with noise.
                for param in model.parameters():
                    if param.requires_grad and torch.isnan(param).any():
                        print("Such nan in vals after grad")
                        stdv = 1 / math.sqrt(h.product(param.data.shape))
                        nan_mask = torch.isnan(param.data)
                        noise = torch.normal(mean=h.zeros(param.data.shape),
                                             std=stdv)
                        param.data = torch.where(nan_mask, noise, param.data)

                if args.clip_norm:
                    model.clip_norm()

                # Anything still NaN at this point is fatal.
                for param in model.parameters():
                    if param.requires_grad and torch.isnan(param).any():
                        raise Exception("Such nan in vals after clip")

        model.addSpeed(timer.getUnitTime() + adv_elapsed / len(data))

        if batch_idx % args.log_interval == 0:
            template = 'Train Epoch {:12} Mix(a=Point(),b=Box(),aw=1,bw=0) {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}'
            seen = batch_idx * len(data) // (args.adv_train_num + 1)
            percent = 100. * batch_idx / len(train_loader)
            print(template.format(model.name, epoch, seen,
                                  len(train_loader.dataset), percent,
                                  model.speed, lossy))
コード例 #3
0
ファイル: __main__.py プロジェクト: ForeverZyh/diffai
def test(models, epoch, f=None):
    """Evaluate ``models`` on ``test_loader``, report results and checkpoint.

    Steps, in order:

    1. Load the token map (``./dataset/AG/dict_map.npy``) and the keyboard
       adjacency file (``./dataset/en.key1``) to build, for every token id, a
       list of adjacent token ids.
    2. Iterate over ``test_loader`` until at least ``args.test_size`` examples
       were seen.  When ``args.test_swap_delta > 0`` every example is expanded
       into ``length`` perturbed copies (see the inline comments).
    3. For every model accumulate plain accuracy and, for every test domain,
       width / proved / safe counters and elapsed time.
    4. Print the statistics through ``h.printBoth`` and, depending on
       ``args.save_freq`` / ``args.dont_write`` / ``args.write_first``, save
       the network (``.pynet``, optional ``.net`` and ``.onyx``).
    5. On the first invocation only, dump the first
       ``args.number_save_images`` inputs and their classes to
       ``<out_dir>/images``.

    Relies on module-level names such as ``args``, ``h``, ``goals``,
    ``test_loader``, ``out_dir``, ``input_dims``, ``largest_domain``,
    ``largest_test_domain``, ``SYMETRIC_DOMAINS`` and ``num_tests``.

    Args:
        models: Iterable of model wrappers to evaluate.
        epoch: Epoch number, embedded in the checkpoint file name.
        f: Optional handle forwarded to ``h.printBoth``.
    """
    global num_tests
    num_tests += 1

    class MStat:
        """Per-model accumulator; puts the model into eval mode."""

        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0

            class Stat:
                """Per-test-domain counters for one model."""

                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    # Stays None because max_eps is never computed here; the
                    # report omits it in that case.
                    self.max_eps = None
                    self.safe = 0
                    self.proved = 0
                    self.time = 0

            self.domains = [
                Stat(h.parseValues(d, goals), h.catStrs(d))
                for d in args.test_domain
            ]

    model_stats = [MStat(m) for m in models]

    # NOTE(review): ``.item()`` implies a pickled object array; recent NumPy
    # versions refuse to load those unless ``allow_pickle=True`` is passed.
    # Left unchanged because enabling pickle is a trust decision - confirm.
    dict_map = dict(np.load("./dataset/AG/dict_map.npy").item())
    adjacent_keys = [[] for _ in range(len(dict_map))]
    # Each non-empty line is "<key> <neighbour> <neighbour> ...".  The file is
    # read inside a ``with`` block so the handle is closed deterministically.
    with open("./dataset/en.key1") as key_file:
        for line in key_file:
            tmp = line.strip().split()
            if not tmp:
                # Blank line: nothing to record (previously an IndexError).
                continue
            ret = set(tmp[1:]).intersection(dict_map.keys())
            adjacent_keys[dict_map[tmp[0]]].extend(dict_map[x] for x in ret)

    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            # Keep the raw examples so they can be dumped to disk at the end.
            saved_data_target += list(zip(list(data), list(target)))

        num_its += data.size()[0]
        if num_its % 100 == 0:
            print(num_its, model_stats[0].domains[0].safe * 100.0 / num_its)
        if args.test_swap_delta > 0:
            length = data.size()[1]
            # Every row becomes ``length`` concatenated copies of itself.
            data = data.repeat(1, length)
            for i in data:
                # The last copy (j == length - 1) is left unmodified.
                for j in range(length - 1):
                    for _ in range(args.test_swap_delta):
                        # Pick a position whose token has an adjacent key.
                        t = np.random.randint(0, length)
                        while len(adjacent_keys[int(i[t])]) == 0:
                            t = np.random.randint(0, length)
                        cid = int(i[t])
                        i[j * length + t] = adjacent_keys[cid][0]
            # One label per copy, then split the copies into separate rows.
            target = (target.view(-1, 1).repeat(1, length)).view(-1)
            data = data.view(-1, length)

        if h.use_cuda:
            data, target = data.cuda().to_dtype(), target.cuda()

        for m in model_stats:

            with torch.no_grad():
                pred = m.model(data).vanillaTensorPart().max(1, keepdim=True)[
                    1]  # get the index of the max log-probability
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint=False)
                with timer:

                    def calcData(data, target):
                        """Accumulate width/proved/safe of ``stat`` for a batch."""
                        box = stat.domain.box(data,
                                              w=m.model.w,
                                              model=m.model,
                                              untargeted=True,
                                              target=target).to_dtype()
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).vanillaTensorPart().max(
                                1, keepdim=True)[1]
                            stat.width += bs.diameter().sum().item(
                            )  # sum up batch loss
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            # stat.max_eps += 0 # TODO: calculate max_eps

                    if m.model.net.neuronCount(
                    ) < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        # Larger networks are evaluated one example at a time.
                        if args.test_swap_delta > 0:
                            length = data.size()[1]
                            pre_stat = copy.deepcopy(stat)
                            for i, (d, t) in enumerate(zip(data, target)):
                                calcData(d.unsqueeze(0), t.unsqueeze(0))
                                if (i + 1) % length == 0:
                                    # Collapse each group of ``length`` rows
                                    # into a single contribution: the counters
                                    # use floor division, the width the mean.
                                    d_proved = (stat.proved -
                                                pre_stat.proved) // length
                                    d_safe = (stat.safe -
                                              pre_stat.safe) // length
                                    d_width = (stat.width -
                                               pre_stat.width) / length
                                    stat.proved = pre_stat.proved + d_proved
                                    stat.safe = pre_stat.safe + d_safe
                                    stat.width = pre_stat.width + d_width
                                    pre_stat = copy.deepcopy(stat)
                        else:
                            for d, t in zip(data, target):
                                calcData(d.unsqueeze(0), t.unsqueeze(0))
                stat.time += timer.getUnitTime()

    l = num_its  # len(test_loader.dataset)
    for m in model_stats:
        if args.lr_multistep:
            m.model.lrschedule.step()

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            m.model.lrschedule.step(1 - pr_corr)

        h.printBoth(
            ('Test: {:12} trained with {:' + str(largest_domain) +
             '} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
                 m.model.name, m.model.ty.name, m.model.speed, m.correct, l,
                 100. * pr_corr),
            f=f)

        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth((
                "\t{:" + str(largest_test_domain) +
                "} - Width: {:<36.16f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f} {}Time = {:<7.5f}"
            ).format(
                stat.name, stat.width / l, pr_proved, pr_safe,
                pr_corr_given_proved,
                "AvgMaxEps: {:1.10f} ".format(stat.max_eps / l)
                if stat.max_eps is not None else "", stat.time),
                        f=f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(
                stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        # Make the goal name safe for use inside a file name.
        prepedname = m.model.ty.name.replace(" ", "_").replace(
            ",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(
            out_dir, m.model.name + "__" + prepedname + "_checkpoint_" +
            str(epoch) + "_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".pynet"), f=f)

        # ``num_tests % 1`` is never 1, hence the explicit save_freq == 1 case.
        if (num_tests % args.save_freq == 1 or args.save_freq
                == 1) and not args.dont_write and (num_tests > 1
                                                   or args.write_first):
            print("Actually Saving")
            torch.save(m.model.net, net_file + ".pynet")
            if args.save_dot_net:
                with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                    m.model.net.printNet(f2)
                    f2.close()
            if args.onyx:
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(
                    nn,
                    h.zeros([1] + list(input_dims)),
                    net_file + ".onyx",
                    verbose=False,
                    input_names=["actual_input"] + [
                        "param" + str(i)
                        for i in range(len(list(nn.parameters())))
                    ],
                    output_names=["output"])

    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num, (img, target) in zip(
                range(args.number_save_images),
                saved_data_target[:args.number_save_images]):
            # Dimensions joined by "x", e.g. "1x28x28".
            sz = ""
            for s in img.size():
                sz += str(s) + "x"
            sz = sz[:-1]

            img_file = os.path.join(
                img_dir, args.dataset + "_" + sz + "_" + str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    # One value per line, written without exponent notation.
                    print(decimal.Decimal(float(t)).__format__("f"),
                          file=imgfile)
            with open(img_file + ".class", "w") as imgfile:
                print(int(target.item()), file=imgfile)
コード例 #4
0
ファイル: __main__.py プロジェクト: ForeverZyh/diffai
def train(epoch, models, decay=True):
    """Train ``models`` for one epoch, then measure loss on ``val_loader``.

    Whenever ``total_batches_seen * args.batch_size`` is a multiple of 4000
    and ``args.decay_fir`` and ``decay`` are both set, the weight of the
    second entry of a two-element ``goals.DList`` goal is increased by
    0.0025, capped at 3.  (An older per-epoch schedule that also lowered the
    first weight used to live here as commented-out code.)

    Each training step clips the gradient norm to 1, replaces NaN gradients
    and NaN weights by small Gaussian noise, optionally calls
    ``model.clip_norm()`` and raises if NaNs remain afterwards.

    Args:
        epoch: Epoch number, used only for logging.
        models: Iterable of model wrappers to train.
        decay: Enables the goal-weight schedule described above.

    Returns:
        Tuple ``(val_origin / batches, val / batches)``: the loss on the raw
        validation inputs and the loss over the ``getSpec`` specifications,
        each summed over all models and divided by the number of validation
        batches.
    """
    global total_batches_seen

    for net in models:
        net.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        if total_batches_seen * args.batch_size % 4000 == 0:
            for net in models:
                if not args.decay_fir:
                    continue
                if not (isinstance(net.ty, goals.DList)
                        and len(net.ty.al) == 2 and decay):
                    continue
                for pos, entry in enumerate(net.ty.al):
                    if pos != 1:
                        continue
                    bumped = min(entry[1].getVal() + 0.0025, 3)
                    net.ty.al[pos] = (entry[0], Const(bumped))

        total_batches_seen += 1
        # Training progress measured in (fractional) epochs.
        progress = float(total_batches_seen) / len(train_loader)
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for net in models:
            net.global_num += data.size()[0]

            label = "train a sample from " + net.name + " with " + net.ty.name
            timer = Timer(label, data.size()[0], False)
            lossy = 0
            with timer:
                for spec in net.getSpec(data.to_dtype(), target,
                                        time=progress):
                    net.optimizer.zero_grad()
                    step_loss = net.aiLoss(*spec, time=progress,
                                           **vargs).mean(dim=0)
                    lossy += step_loss.detach().item()
                    step_loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), 1)

                    # Before the step: report NaN weights, patch NaN grads.
                    for param in net.parameters():
                        if param is None:
                            continue
                        if torch.isnan(param).any():
                            print("Such nan in vals")
                        grad = param.grad
                        if grad is not None and torch.isnan(grad).any():
                            print("Such nan in postmagic")
                            stdv = 1 / math.sqrt(h.product(param.data.shape))
                            nan_mask = torch.isnan(grad)
                            noise = torch.normal(mean=h.zeros(grad.shape),
                                                 std=stdv)
                            param.grad = torch.where(nan_mask, noise, grad)

                    net.optimizer.step()

                    # After the step: patch NaN weights with noise.
                    for param in net.parameters():
                        if param is None or not torch.isnan(param).any():
                            continue
                        print("Such nan in vals after grad")
                        stdv = 1 / math.sqrt(h.product(param.data.shape))
                        nan_mask = torch.isnan(param.data)
                        noise = torch.normal(mean=h.zeros(param.data.shape),
                                             std=stdv)
                        param.data = torch.where(nan_mask, noise, param.data)

                    if args.clip_norm:
                        net.clip_norm()

                    # Anything still NaN at this point is fatal.
                    for param in net.parameters():
                        if param is not None and torch.isnan(param).any():
                            raise Exception("Such nan in vals after clip")

            net.addSpeed(timer.getUnitTime())

            if batch_idx % args.log_interval == 0:
                template = ('Train Epoch {:12} {:' + str(largest_domain) +
                            '}: {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}')
                print(template.format(
                    net.name, net.ty.name, epoch,
                    batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), net.speed,
                    lossy))

    # Validation pass: accumulate losses over all batches and all models.
    spec_total = 0
    origin_total = 0
    batch_cnt = 0
    for batch_cnt, (data, target) in enumerate(val_loader, start=1):
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for net in models:
            for spec in net.getSpec(data.to_dtype(), target):
                spec_loss = net.aiLoss(*spec, **vargs).mean(dim=0)
                spec_total += spec_loss.detach().item()

            origin_loss = net.aiLoss(data, target, **vargs).mean(dim=0)
            origin_total += origin_loss.detach().item()

    return origin_total / batch_cnt, spec_total / batch_cnt
コード例 #5
0
def test(models, epoch, data_loader, f=None):
    """Evaluate ``models`` on ``data_loader``, print statistics and checkpoint.

    For every model this accumulates the standard loss, the number of correct
    predictions and, for every name in ``args.test_domain``, width / safe /
    proved counters plus elapsed time.  Iteration stops once at least
    ``args.test_size`` examples were processed.  Results are written through
    ``h.printBoth`` and the network is saved when ``epoch`` is 1 or a multiple
    of 10.

    Relies on module-level names such as ``args``, ``h``, ``domains``,
    ``Timer``, ``POINT_DOMAINS``, ``SYMETRIC_DOMAINS``, ``num_classes`` and
    ``out_dir``.

    Args:
        models: Iterable of model wrappers to evaluate.
        epoch: Epoch number; embedded in the checkpoint name and used to
            decide whether to save.
        data_loader: Iterable of ``(data, target)`` batches.
        f: Optional handle forwarded to ``h.printBoth``.
    """
    class MStat:
        """Per-model accumulator; puts the model into eval mode."""
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0
            self.test_loss = 0

            class Stat:
                """Per-test-domain counters for one model."""
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.safe = 0
                    self.proved = 0
                    self.time = 0

            # Test domains are looked up by name on the ``domains`` module.
            self.domains = [
                Stat(getattr(domains, d), d) for d in args.test_domain
            ]

    model_stats = [MStat(m) for m in models]

    num_its = 0
    for data, target in data_loader:
        if num_its >= args.test_size:
            break
        num_its += data.size()[0]
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for m in model_stats:
            with torch.no_grad():
                m.test_loss += m.model.stdLoss(
                    data, None, target).sum().item()  # sum up batch loss

            # Remember the training domain; it is swapped out per test domain
            # below and restored once all domains were evaluated.
            tyorg = m.model.ty

            with torch.no_grad():
                pred = m.model(data).data.max(1, keepdim=True)[
                    1]  # get the index of the max log-probability
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint=False)
                with timer:
                    m.model.ty = stat.domain

                    def calcData(data, target):
                        """Update ``stat`` with the results for one batch."""
                        box = m.model.boxSpec(data, target)[0]
                        with torch.no_grad():
                            if m.model.ty in POINT_DOMAINS:
                                # Point domains: compare the prediction on
                                # ``box[0]`` with the prediction on the
                                # original input (proved) and with the label
                                # (safe).
                                preder = m.model(box[0]).data
                                pred = preder.max(
                                    1, keepdim=True
                                )[1]  # get the index of the max log-probability
                                org = m.model(data).max(1, keepdim=True)[1]
                                stat.proved += float(org.eq(pred).sum())
                                stat.safe += float(
                                    pred.eq(target.data.view_as(pred)).sum())
                            else:
                                bs = m.model(box[1])
                                # NOTE(review): ``.data[0]`` presumably
                                # extracts a scalar (legacy PyTorch idiom);
                                # confirm against the shape widthL returns.
                                stat.width += m.model.widthL(bs).data[
                                    0]  # sum up batch loss
                                stat.safe += m.model.isSafeDom(
                                    bs, target).sum().item()
                                # "proved" sums the isSafeDom counts obtained
                                # for every candidate class label.
                                stat.proved += sum([
                                    m.model.isSafeDom(bs,
                                                      (h.ones(target.size()) *
                                                       n).long()).sum().item()
                                    for n in range(num_classes)
                                ])

                    # Small networks and symmetric domains take the whole
                    # batch at once; otherwise examples go one by one.
                    if m.model.net.neuronCount(
                    ) < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        for d, t in zip(data, target):
                            calcData(d.unsqueeze(0), t.unsqueeze(0))
                stat.time += timer.getUnitTime()
            m.model.ty = tyorg

    l = num_its  # len(test_loader.dataset)
    for m in model_stats:

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            # The scheduler is stepped with the error rate.
            m.model.lrschedule.step(1 - pr_corr)

        h.printBoth(
            'Test: {:12} trained with {:8} - Mult {:1.8f}, Avg sec/ex {:1.12f}, Average loss: {:8.4f}, Accuracy: {}/{} ({:3.1f}%)'
            .format(m.model.name, m.model.ty.name,
                    m.model.getMult().data[0], m.model.speed, m.test_loss / l,
                    m.correct, l, 100. * pr_corr), f)

        # Compact per-domain summary that becomes part of the file name.
        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            # Guard against division by zero when nothing was proved.
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth(
                "\t{:10} - Width: {:<22.4f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f}    Time = {:<7.5f}"
                .format(stat.name, stat.width / l, pr_proved, pr_safe,
                        pr_corr_given_proved, stat.time), f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(
                stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        net_file = os.path.join(
            out_dir, m.model.name + "_checkpoint_" + str(epoch) +
            "_with_{:1.3f}".format(pr_corr) + "__" + model_stat_rec + ".net")

        # The path is printed on every call, but the file is only written on
        # the first epoch and on every tenth epoch.
        h.printBoth("\tSaving netfile: {}\n".format(net_file), f)

        if epoch == 1 or epoch % 10 == 0:
            torch.save(m.model.net, net_file)
コード例 #6
0
def train(epoch, models):
    """Train every model in ``models`` for one pass over ``train_loader``.

    The global ``total_batches_seen`` counter is advanced once per batch and
    its ratio to ``len(train_loader)`` is passed as ``time`` to ``getSpec``
    and ``aiLoss``.  Every optimizer step clips the gradient norm to 1,
    replaces NaN gradients and NaN weights by small Gaussian noise,
    optionally calls ``model.clip_norm()`` and raises if NaNs remain.

    Args:
        epoch: Epoch number, used only for logging.
        models: Iterable of model wrappers to train.
    """
    global total_batches_seen

    for net in models:
        net.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        total_batches_seen += 1
        # Training progress measured in (fractional) epochs.
        progress = float(total_batches_seen) / len(train_loader)
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for net in models:
            net.global_num += data.size()[0]

            label = "train a sample from " + net.name + " with " + net.ty.name
            timer = Timer(label, data.size()[0], False)
            lossy = 0
            with timer:
                for spec in net.getSpec(data.to_dtype(), target,
                                        time=progress):
                    net.optimizer.zero_grad()
                    step_loss = net.aiLoss(*spec, time=progress,
                                           **vargs).mean(dim=0)
                    lossy += step_loss.detach().item()
                    step_loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), 1)

                    # Before the step: report NaN weights, patch NaN grads.
                    for param in net.parameters():
                        if param is None:
                            continue
                        if torch.isnan(param).any():
                            print("Such nan in vals")
                        grad = param.grad
                        if grad is not None and torch.isnan(grad).any():
                            print("Such nan in postmagic")
                            stdv = 1 / math.sqrt(h.product(param.data.shape))
                            nan_mask = torch.isnan(grad)
                            noise = torch.normal(mean=h.zeros(grad.shape),
                                                 std=stdv)
                            param.grad = torch.where(nan_mask, noise, grad)

                    net.optimizer.step()

                    # After the step: patch NaN weights with noise.
                    for param in net.parameters():
                        if param is None or not torch.isnan(param).any():
                            continue
                        print("Such nan in vals after grad")
                        stdv = 1 / math.sqrt(h.product(param.data.shape))
                        nan_mask = torch.isnan(param.data)
                        noise = torch.normal(mean=h.zeros(param.data.shape),
                                             std=stdv)
                        param.data = torch.where(nan_mask, noise, param.data)

                    if args.clip_norm:
                        net.clip_norm()

                    # Anything still NaN at this point is fatal.
                    for param in net.parameters():
                        if param is not None and torch.isnan(param).any():
                            raise Exception("Such nan in vals after clip")

            net.addSpeed(timer.getUnitTime())

            if batch_idx % args.log_interval == 0:
                template = ('Train Epoch {:12} {:' + str(largest_domain) +
                            '}: {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}')
                print(template.format(
                    net.name, net.ty.name, epoch,
                    batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), net.speed,
                    lossy))
コード例 #7
0
ファイル: __main__.py プロジェクト: pombredanne/diffai
def test(models, epoch, f = None):
    """Evaluate ``models`` on ``test_loader``, report results and checkpoint.

    For every model this accumulates plain accuracy and, for every entry of
    ``args.test_domain``, width / proved / safe counters and elapsed time.
    Iteration stops once at least ``args.test_size`` examples were seen.
    Statistics are printed through ``h.printBoth``; the network is saved as
    ``.pynet`` / ``.net`` (and ``.onyx`` when ``args.onyx``) according to
    ``args.save_freq``, and never when ``args.dont_write`` is set.  On the
    first invocation the first ``args.number_save_images`` inputs and their
    classes are dumped to ``<out_dir>/images``.

    Relies on module-level names such as ``args``, ``h``, ``domains``,
    ``test_loader``, ``Timer``, ``out_dir``, ``input_dims``,
    ``largest_domain``, ``largest_test_domain``, ``SYMETRIC_DOMAINS`` and
    ``num_tests``.

    Args:
        models: Iterable of model wrappers to evaluate.
        epoch: Epoch number, embedded in the checkpoint file name.
        f: Optional handle forwarded to ``h.printBoth``.
    """
    global num_tests
    num_tests += 1
    class MStat:
        """Per-model accumulator; puts the model into eval mode."""
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0
            class Stat:
                """Per-test-domain counters for one model."""
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.max_eps = 0
                    self.safe = 0
                    self.proved = 0
                    self.time = 0
            self.domains = [ Stat(h.parseValues(domains,d), h.catStrs(d)) for d in args.test_domain ]
    model_stats = [ MStat(m) for m in models ]

    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            # Keep the raw examples so they can be dumped to disk at the end.
            saved_data_target += list(zip(list(data), list(target)))

        num_its += data.size()[0]
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for m in model_stats:

            with torch.no_grad():
                pred = m.model(data).data.max(1, keepdim=True)[1] # get the index of the max log-probability
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint = False)
                with timer:
                    def calcData(data, target):
                        """Accumulate width/proved/safe of ``stat`` for a batch."""
                        box = stat.domain.box(data, m.model.w, model=m.model, untargeted = True, target=target)
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).max(1,keepdim=True)[1]
                            stat.width += bs.diameter().sum().item() # sum up batch loss
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            stat.max_eps += 0 # TODO: calculate max_eps

                    # Small networks and symmetric domains take the whole
                    # batch at once; otherwise examples go one by one.
                    if m.model.net.neuronCount() < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        for d,t in zip(data, target):
                            calcData(d.unsqueeze(0),t.unsqueeze(0))
                stat.time += timer.getUnitTime()

    l = num_its # len(test_loader.dataset)
    for m in model_stats:

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            # The scheduler is stepped with the error rate.
            m.model.lrschedule.step(1 - pr_corr)

        h.printBoth(('Test: {:12} trained with {:'+ str(largest_domain) +'} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
            m.model.name, m.model.ty.name,
            m.model.speed,
            m.correct, l, 100. * pr_corr), f = f)

        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            # Guard against division by zero when nothing was proved.
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth(("\t{:" + str(largest_test_domain)+"} - Width: {:<36.16f} Pr[Proved]={:<1.3f}  Pr[Corr and Proved]={:<1.3f}  Pr[Corr|Proved]={:<1.3f} AvgMaxEps: {:1.10f} Time = {:<7.5f}").format(
                stat.name,
                stat.width / l,
                pr_proved,
                pr_safe, pr_corr_given_proved,
                stat.max_eps / l,
                stat.time), f = f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(stat.name, pr_proved, pr_safe, pr_corr_given_proved)
        # Make the goal name safe for use inside a file name.
        prepedname = m.model.ty.name.replace(" ", "_").replace(",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(out_dir, m.model.name +"__" +prepedname + "_checkpoint_"+str(epoch)+"_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".net"), f = f)

        # The parentheses matter: ``and`` binds tighter than ``or``, so without
        # them ``dont_write`` was ignored whenever num_tests % save_freq == 1.
        # ``num_tests % 1`` is never 1, hence the explicit save_freq == 1 case.
        if (num_tests % args.save_freq == 1 or args.save_freq == 1) and not args.dont_write:
            torch.save(m.model.net, net_file + ".pynet")

            with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                m.model.net.printNet(f2)
                f2.close()
            if args.onyx:
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(nn, h.zeros([1] + list(input_dims)), net_file + ".onyx",
                                  verbose=False, input_names=["actual_input"] + ["param"+str(i) for i in range(len(list(nn.parameters())))], output_names=["output"])


    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num,(img,target) in zip(range(args.number_save_images), saved_data_target[:args.number_save_images]):
            # Dimensions joined by "x", e.g. "1x28x28".
            sz = ""
            for s in img.size():
                sz += str(s) + "x"
            sz = sz[:-1]

            img_file = os.path.join(img_dir, args.dataset + "_" + sz + "_"+ str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    # One value per line, written without exponent notation.
                    print(decimal.Decimal(float(t)).__format__("f"), file=imgfile)
            with open(img_file + ".class" , "w") as imgfile:
                print(int(target.item()), file=imgfile)