Пример #1
0
    def init(self, in_shape, out_shape, **kargs):
        """Set up a dense layer mapping ``in_shape`` onto ``out_shape``.

        Both shapes are reduced to neuron counts with ``h.product``; ``weight``
        (in_neurons x out_neurons) and ``bias`` (out_neurons) are allocated as
        uninitialised parameters -- their values are presumably filled in by a
        separate initialisation step (not visible here).

        Returns:
            The output shape; a bare int is wrapped into a one-element list.
        """
        self.in_neurons = h.product(in_shape)
        shape_out = [out_shape] if isinstance(out_shape, int) else out_shape
        self.out_neurons = h.product(shape_out)

        n_in, n_out = self.in_neurons, self.out_neurons
        self.weight = torch.nn.Parameter(torch.Tensor(n_in, n_out))
        self.bias = torch.nn.Parameter(torch.Tensor(n_out))

        return shape_out
Пример #2
0
        def get_reduced(x, all_possible_sub):
            """Collapse ``all_possible_sub`` stacked variants per input into a box.

            ``x`` is reshaped to ``(-1, all_possible_sub, *self.in_shape)``; the
            element-wise min / max over the variant axis (dim 1) are used as the
            lower / upper bounds of an ``ai.HybridZonotope`` with centre
            ``(lower + upper) / 2``, radius ``(upper - lower) / 2`` and no error
            terms.

            Raises:
                ValueError: if ``x`` does not hold a positive multiple of
                    ``all_possible_sub * prod(self.in_shape)`` elements. (Most
                    mismatched sizes are already rejected by ``x.view``; this
                    guards the remaining case, e.g. an empty tensor.)
            """
            num_e = h.product(x.size())
            view_num = all_possible_sub * h.product(self.in_shape)
            x = x.view(-1, all_possible_sub, *self.in_shape)

            # Disabled experiment: shrink every substitution towards variant 0.
            # for i in range(1, all_possible_sub):
            #    x[:, i] = x[:, i] * EmbeddingWithSub.delta + (1 - EmbeddingWithSub.delta) * x[:, 0]

            if num_e < view_num or num_e % view_num != 0:
                # Original comment: this is the "it is a Point()" case. Raise
                # explicitly -- ``assert False`` is stripped under ``python -O``
                # and the function would then silently return None.
                raise ValueError(
                    "get_reduced: expected a multiple of {} elements "
                    "({} substitutions of shape {}), got {}".format(
                        view_num, all_possible_sub, tuple(self.in_shape), num_e))

            # Convert to Box (HybridZonotope).
            lower = x.min(1)[0]
            upper = x.max(1)[0]
            return ai.HybridZonotope((lower + upper) / 2, (upper - lower) / 2, None)
Пример #3
0
    def correlate(
            self,
            cc_indx_batch_beta):  # given in terms of the flattened matrix.
        """Move selected box (``beta``) entries into fresh zonotope error terms.

        Args:
            cc_indx_batch_beta: integer index tensor whose first dimension is the
                batch; each row holds positions in the *flattened* per-example
                ``beta``. ``h.product(shape[1:])`` new error terms are created
                (the indexing below assumes a 2-D tensor -- TODO confirm).

        Returns:
            ``self.new(self.head, beta, errors)``: for example ``b`` and slot
            ``k``, new error term ``k`` holds ``beta[b, idx[b, k]]`` at position
            ``idx[b, k]`` (appended after the existing errors) and that entry is
            zeroed in the returned ``beta``.

        ``self`` is left unmodified: an existing ``self.beta`` is cloned before
        the selected entries are zeroed.
        """
        num_correlate = h.product(cc_indx_batch_beta.shape[1:])

        if self.beta is None:
            beta = h.zeros(self.head.shape).to_dtype()
        else:
            # Clone: the in-place zeroing at the end must not mutate self.beta
            # (that would silently shrink the set ``self`` represents and could
            # trip autograd's in-place-modification check).
            beta = self.beta.clone()
        errors = h.zeros([0] + list(self.head.shape)).to_dtype(
        ) if self.errors is None else self.errors

        batch_size = beta.shape[0]
        new_errors = h.zeros([num_correlate] +
                             list(self.head.shape)).to_dtype()

        # Batch-row and error-slot index tensors for advanced indexing.
        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

        # (num_correlate, batch, ...) -> (batch, num_correlate, flat) so each
        # batch row can be addressed with flattened beta indices.
        new_errors = new_errors.permute(
            1, 0,
            *list(range(len(new_errors.shape)))[2:]).contiguous().view(
                batch_size, num_correlate, -1)
        new_errors[inds_i, nc.unsqueeze(0).expand([batch_size] + list(nc.shape)).squeeze(2), cc_indx_batch_beta] = \
            beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta]

        # Back to (num_correlate, batch, *beta.shape[1:]).
        new_errors = new_errors.permute(
            1, 0,
            *list(range(len(new_errors.shape)))[2:]).contiguous().view(
                num_correlate, batch_size, *beta.shape[1:])
        errors = torch.cat((errors, new_errors), dim=0)

        # The moved magnitude is now carried by the new error terms.
        beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0

        return self.new(self.head, beta, errors)
Пример #4
0
    def boxBetween(self, o1, o2, *args, **kargs):
        """Return the box spanned by ``o1`` and ``o2`` as a zonotope.

        The centre is the midpoint ``(o1 + o2) / 2``; the error terms are the
        per-element basis from ``h.getEi`` scaled by the half-width
        ``|o2 - o1| / 2``. Extra positional / keyword arguments are ignored.
        """
        shape = o1.size()
        n_batch = shape[0]
        n_flat = h.product(shape[1:])
        basis = h.getEi(n_batch, n_flat)

        if len(shape) > 2:
            # Give each basis term the full (batch, ...) shape of o1.
            basis = basis.contiguous().view(n_flat, *shape)

        midpoint = (o1 + o2) / 2
        spread = basis * (o2 - o1).abs() / 2
        return self.domain(midpoint, None, spread).checkSizes()
Пример #5
0
    def line(self, o1, o2, **kargs):
        """Zonotope for the segment from ``o1`` to ``o2``, optionally widened.

        The segment itself is one error term ``(o2 - o1) / 2`` around the
        midpoint. If ``self.w.getVal(c=0, **kargs)`` is a positive number, a
        per-element box of that radius (basis from ``h.getEi``) is appended as
        additional error terms.
        """
        width = self.w.getVal(c=0, **kargs)
        terms = ((o2 - o1) / 2).unsqueeze(0)

        if width is not None and width > 0.0:
            shape = o1.size()
            n_batch = shape[0]
            n_flat = h.product(shape[1:])
            basis = h.getEi(n_batch, n_flat)
            if len(shape) > 2:
                basis = basis.contiguous().view(n_flat, *shape)
            terms = torch.cat([terms, basis * width])

        return self.domain((o1 + o2) / 2, None, terms).checkSizes()
Пример #6
0
    def applySuper(self, ret):
        """Fold ``ret``'s box component (``beta``) into its error terms.

        Each element of ``ret.beta`` becomes its own error term (per-element
        basis from ``h.getEi`` scaled by ``beta``), appended after any existing
        ``ret.errors``. ``ret.beta`` is then cleared. ``ret`` is modified in
        place and returned after ``checkSizes()``.
        """
        if ret.beta is not None:
            # Only build the (possibly large) basis when there is a beta to fold.
            shape = ret.head.size()
            num_elem = h.product(shape[1:])
            ei = h.getEi(shape[0], num_elem)
            if len(shape) > 2:
                ei = ei.contiguous().view(num_elem, *shape)
            new_terms = ei * ret.beta
            # ret may have no error terms yet; torch.cat cannot take None.
            ret.errors = (new_terms if ret.errors is None else
                          torch.cat((ret.errors, new_terms)))
        ret.beta = None
        return ret.checkSizes()
Пример #7
0
    def correlateMaxK(self, num_correlate):
        """Correlate the positions with the largest upper bound.

        Per example, the top-k entries of the flattened ``self.ub()`` are
        selected (k is ``num_correlate`` capped at the per-example element
        count) and their indices are passed to ``self.correlate``. A request
        for zero positions returns ``self`` unchanged.
        """
        if num_correlate == 0:
            return self

        shape = self.head.shape
        n_batch = shape[0]
        k = min(num_correlate, h.product(shape[1:]))

        upper = self.ub().view(n_batch, -1)
        _, top_idx = upper.topk(k)
        return self.correlate(top_idx)
Пример #8
0
    def box(self, original, w, **kargs):
        """Zonotope box around ``original`` with one error term per element.

        Slower than a plain interval box, but keeping a separate error term
        per element preserves correlation through later layers. The radius is
        ``self.w.getVal(c=w, **kargs)``.
        """
        radius = self.w.getVal(c=w, **kargs)

        dims = original.size()
        flat = h.product(dims[1:])
        basis = h.getEi(dims[0], flat)
        if len(dims) > 2:
            basis = basis.contiguous().view(flat, *dims)

        return self.domain(original, None, basis * radius).checkSizes()
Пример #9
0
    def abstract_forward(self, x):
        """Correlate roughly ``self.k`` evenly spaced positions of ``x``.

        Indices run over the flattened per-example size with stride
        ``ceil(numel / self.k)``; the same index row is broadcast to every
        example and handed to ``x.abstractApplyLeaf("correlate", ...)``.
        """
        # Possible finer-grained alternative (for more control): build strided
        # indices separately for dims 1-3 of a 4-D input using self.dims and
        # combine them with torch.meshgrid.
        size = x.size()
        per_example = h.product(size[1:])
        stride = math.ceil(per_example / self.k)

        positions = torch.arange(start=0, end=per_example, step=stride)
        positions = positions.unsqueeze(0).expand(size[0], positions.size()[0])

        return x.abstractApplyLeaf("correlate", positions)
Пример #10
0
    def stochasticCorrelate(self, num_correlate, choices=None):
        """Correlate ``num_correlate`` uniformly random positions per example.

        Args:
            num_correlate: number of positions to correlate, capped at the
                per-example element count. ``0`` returns ``self`` unchanged.
            choices: optional precomputed index tensor; when given it is passed
                straight to ``self.correlate`` and no sampling happens.
        """
        if num_correlate == 0:
            return self

        if choices is not None:
            # Avoid allocating the (batch, num_pixs) sampling weights when unused.
            return self.correlate(choices)

        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])
        num_correlate = min(num_correlate, num_pixs)

        # Uniform weights: sampling without replacement yields a random subset.
        ucc_mask = h.ones([batch_size, num_pixs]).long()
        cc_indx_batch_beta = h.cudify(
            torch.multinomial(
                ucc_mask.to_dtype(), num_correlate,
                replacement=False))
        return self.correlate(cc_indx_batch_beta)
Пример #11
0
    def correlateMaxPool(self,
                         *args,
                         max_type=MaxTypes.ub,
                         max_pool=F.max_pool2d,
                         **kargs):
        """Correlate the argmax positions of a max-pooled concrete bound image.

        ``max_type(self)`` produces a concrete image (default ``MaxTypes.ub``).
        ``max_pool`` is applied to it with ``return_indices=True`` plus the
        extra ``*args`` / ``**kargs`` (kernel size, stride, ...). The argmax
        indices, as flat offsets into each example, are passed to
        ``self.correlate``.
        """
        batch_size = self.head.shape[0]

        concrete_max_image = max_type(self)

        pooled_idx = max_pool(concrete_max_image,
                              *args,
                              return_indices=True,
                              **kargs)[1]

        # torch's max_pool*d indices are flat offsets within each
        # (batch, channel) spatial plane, but correlate() indexes the whole
        # flattened example (channels included). Shift channel c by
        # c * plane_size so multi-channel inputs pick the right positions.
        if pooled_idx.dim() > 2:
            plane = h.product(concrete_max_image.shape[2:])
            chan_offset = torch.arange(pooled_idx.shape[1],
                                       device=pooled_idx.device,
                                       dtype=pooled_idx.dtype) * plane
            pooled_idx = pooled_idx + chan_offset.view(
                1, -1, *([1] * (pooled_idx.dim() - 2)))

        cc_indx_batch_beta = pooled_idx.view(batch_size, -1)

        return self.correlate(cc_indx_batch_beta)
Пример #12
0
    def hybrid_to_zono(self, *args, correlate=True, customRelu=None, **kargs):
        """Convert this hybrid zonotope into a ``Zonotope``.

        Args:
            correlate: if true and a box component ``beta`` is present, every
                element of ``beta`` becomes its own error term (per-element
                basis from ``h.getEi`` scaled by ``beta``), placed before the
                existing errors, and the result carries no ``beta``.
            customRelu: ReLU override for the result; ``None`` (default) keeps
                ``self.customRelu``.
            *args, **kargs: accepted and ignored.

        Returns:
            A ``Zonotope`` with centre ``self.head``. When there are no error
            terms, a single all-zero error term is supplied.
        """
        beta = self.beta
        errors = self.errors
        if correlate and beta is not None:
            batches = beta.shape[0]
            num_elem = h.product(beta.shape[1:])
            ei = h.getEi(batches, num_elem)

            if len(beta.shape) > 2:
                ei = ei.contiguous().view(num_elem, *beta.shape)
            err = ei * beta
            errors = err if errors is None else torch.cat((err, errors), dim=0)
            beta = None

        if errors is None:
            # One all-zero error term. ``self.beta * 0`` fails when beta is
            # None, so fall back to head (presumably the same shape as beta).
            template = self.head if self.beta is None else self.beta
            errors = (template * 0).unsqueeze(0)

        return Zonotope(
            self.head,
            beta,
            errors,
            # An explicitly passed customRelu must be used, not replaced by None.
            customRelu=self.customRelu if customRelu is None else customRelu)
Пример #13
0
    def reset_parameters(self):
        """(Re)initialise ``weight`` and ``bias`` in place.

        With ``stdv = 1 / sqrt(numel(weight) / outShape[0])``:

        * ``self.ibp_init``: orthogonal weight, zero bias;
        * ``self.normal``: N(0, stdv) clamped to [-1, 1] for weight and bias;
        * otherwise: U(-stdv, stdv) for weight and bias.

        Returns immediately when there is no ``weight`` (missing or ``None``);
        ``bias`` is only touched when it is not ``None``.
        """
        if getattr(self, 'weight', None) is None:
            return

        fan = h.product(self.weight.size()) / self.outShape[0]
        stdv = 1 / math.sqrt(fan)

        w = self.weight.data
        if self.ibp_init:
            torch.nn.init.orthogonal_(w)
        elif self.normal:
            w.normal_(0, stdv).clamp_(-1, 1)
        else:
            w.uniform_(-stdv, stdv)

        if self.bias is None:
            return
        b = self.bias.data
        if self.ibp_init:
            b.zero_()
        elif self.normal:
            b.normal_(0, stdv).clamp_(-1, 1)
        else:
            b.uniform_(-stdv, stdv)
Пример #14
0
 def init(self, in_shape, out_shape, **kargs):
     """Accept ``out_shape`` as the output shape of this layer.

     Asserts that ``in_shape`` and ``out_shape`` hold the same number of
     elements (an ``assert``, so the check is skipped under ``python -O``)
     and returns ``out_shape`` unchanged.
     """
     assert h.product(in_shape) == h.product(out_shape)
     return out_shape
Пример #15
0
    def init(self, in_shape, w, **kargs):
        """Configure the output shape ``(outChan, w, w)``.

        ``outChan`` is how many ``w x w`` planes fit into the element count
        of ``in_shape`` (rounded down). Stores ``self.w`` and
        ``self.outChan`` and returns the output shape tuple.
        """
        self.w = w
        # Exact integer division: ``int(a / b)`` routes through a float and
        # can lose precision for very large element counts.
        self.outChan = int(h.product(in_shape) // (w * w))

        return (self.outChan, self.w, self.w)
Пример #16
0
 def forward(self, x, **kargs):
     """Flatten every dimension of ``x`` after the leading (batch) one."""
     dims = x.size()
     batch, features = dims[0], h.product(dims[1:])
     return x.view(batch, features)
Пример #17
0
 def init(self, in_shape, **kargs):
     """Return the flattened output size, i.e. the element count of ``in_shape``."""
     flat_size = h.product(in_shape)
     return flat_size
Пример #18
0
 def neuronCount(self):
     """Return the number of output neurons (element count of ``self.outShape``)."""
     shape = self.outShape
     return h.product(shape)
Пример #19
0
 def forward(self, x, **kargs):
     """Average-pool ``x`` with a ``(self.kernel_size, 1)`` window.

     Inputs whose dimensions from index 2 onward hold exactly one element are
     returned unchanged; otherwise ``x.avg_pool2d`` is applied with
     ``stride=self.stride`` and no padding.
     """
     if h.product(x.size()[2:]) != 1:
         return x.avg_pool2d(kernel_size=(self.kernel_size, 1), stride=self.stride, padding=0)
     return x
Пример #20
0
def _nan_to_noise(t):
    """Return ``t`` with NaN entries replaced by N(0, 1/sqrt(t.numel())) samples."""
    stdv = 1 / math.sqrt(h.product(t.shape))
    return torch.where(torch.isnan(t),
                       torch.normal(mean=h.zeros(t.shape), std=stdv), t)


def _repair_nan_grads(model, report_values=False):
    """Replace NaN gradient entries of trainable parameters with random noise.

    With ``report_values`` set, also report (without modifying) parameters
    whose values contain NaN.
    """
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if report_values and torch.isnan(p).any():
            print("Such nan in vals")
        if p.grad is not None and torch.isnan(p.grad).any():
            print("Such nan in postmagic")
            p.grad = _nan_to_noise(p.grad)


def _repair_nan_params(model):
    """Replace NaN entries of trainable parameter values with random noise."""
    for p in model.parameters():
        if p.requires_grad and torch.isnan(p).any():
            print("Such nan in vals after grad")
            p.data = _nan_to_noise(p.data)


def _check_no_nan_params(model):
    """Raise if any trainable parameter still contains NaN."""
    for p in model.parameters():
        if p.requires_grad and torch.isnan(p).any():
            raise Exception("Such nan in vals after clip")


def _set_mix_weights(model, args):
    """Set the weights of the two-component (Point / abstract) training goal.

    Point goals get ``1 - train_lambda * cur_ratio``; the other goal gets
    ``train_lambda * cur_ratio``.
    """
    assert isinstance(model.ty, goals.DList) and len(model.ty.al) == 2
    ratio = args.train_lambda * S.TrainInfo.cur_ratio
    for i, entry in enumerate(model.ty.al):
        goal = entry[0]
        weight = 1 - ratio if isinstance(goal, goals.Point) else ratio
        model.ty.al[i] = (goal, S.Const(weight))


def train_epoch(epoch, model, victim_model, attack, args, train_loader):
    """Run one epoch of mixed Point / abstract training.

    For every batch: optionally augment with HotFlip adversarial examples
    (``args.adv_train_num > 0``), then for each spec produced by
    ``model.boxSpec`` take one optimizer step on ``model.aiLoss``, repairing
    NaN gradients / parameters with random noise along the way. Progress is
    printed every ``args.log_interval`` batches.
    """
    vargs = vars(args)
    model.train()

    print(("Cur ratio: {}").format(S.TrainInfo.cur_ratio))
    _set_mix_weights(model, args)

    for batch_idx, (data, target) in enumerate(train_loader):
        S.TrainInfo.total_batches_seen += 1
        # Fractional number of epochs seen so far; passed as ``time`` below.
        time = float(S.TrainInfo.total_batches_seen) / len(train_loader)
        data, target = data.to(h.device), target.to(h.device)

        model.global_num += data.size()[0]
        lossy = 0
        adv_time = sys_time.time()
        if args.adv_train_num > 0:
            data, target = adv_batch(victim_model, attack, data, target,
                                     args.adv_train_num)

        adv_time = sys_time.time() - adv_time

        # Named batch_timer so the module-level ``timer`` is not shadowed.
        batch_timer = Timer(
            "train a sample from " + model.name + " with " + model.ty.name,
            data.size()[0], False)
        with batch_timer:
            for s in model.boxSpec(data.to_dtype(), target, time=time):
                model.optimizer.zero_grad()
                loss = model.aiLoss(*s, time=time, **vargs).mean(dim=0)
                lossy += loss.detach().item()
                loss.backward()
                # Repair NaN gradients *before* clipping: in recent PyTorch,
                # clip_grad_norm_ rescales every gradient by
                # max_norm / total_norm, so one NaN entry would turn all
                # gradients into NaN.
                _repair_nan_grads(model)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                # Clipping can still produce NaN (e.g. inf * 0); repair again.
                _repair_nan_grads(model, report_values=True)

                model.optimizer.step()

                _repair_nan_params(model)

                if args.clip_norm:
                    model.clip_norm()
                _check_no_nan_params(model)

        model.addSpeed(batch_timer.getUnitTime() + adv_time / len(data))

        if batch_idx % args.log_interval == 0:
            print((
                'Train Epoch {:12} Mix(a=Point(),b=Box(),aw=1,bw=0) {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}'
            ).format(model.name, epoch,
                     batch_idx * len(data) // (args.adv_train_num + 1),
                     len(train_loader.dataset),
                     100. * batch_idx / len(train_loader), model.speed, lossy))
Пример #21
0
def train(vocab,
          train_loader,
          val_loader,
          test_loader,
          adv_perturb,
          abs_perturb,
          args,
          fixed_len=None,
          num_classes=2,
          load_path=None,
          test=False):
    """
    training pipeline for A3T
    :param vocab: the vocabulary of the model, see dataset.dataset_loader.Vocab for details
    :param train_loader: the dataset loader for train set, obtained from a3t.diffai.helpers.loadDataset
    :param val_loader: the dataset loader for validation set
    :param test_loader: the dataset loader for test set
    :param adv_perturb: the perturbation space for HotFlip training
    :param abs_perturb: the perturbation space for abstract training
    :param args: the arguments for training (``args.log_interval`` is set here)
    :param fixed_len: CNN models need to pad the input to a certain length
    :param num_classes: the number of classification classes
    :param load_path: if specified, point to the file of loading net
    :param test: True if test, train otherwise

    In training mode, checkpoints and logs are written under
    ``args.out/<model_srt>/<timestamp>``. Early stopping on the validation
    robust loss starts once the Point/abstract mixing ratio reaches 1.
    """
    n = args.model_srt
    assert n in ["WordLevelSST2", "CharLevelSST2"]
    if test:
        assert load_path is not None
    m = getattr(M, n)
    # Clamp to >= 1: train_epoch logs when ``batch_idx % args.log_interval == 0``,
    # so 0 (batch_size * log_freq > 50000) would raise ZeroDivisionError.
    args.log_interval = max(1, int(50000 / (args.batch_size * args.log_freq)))
    domain = ["Mix(a=Point(),b=Box(),aw=1,bw=0)"]
    h.max_c_for_norm = args.max_norm

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed + 1)
    torch.cuda.manual_seed_all(args.seed + 2)

    input_dims = train_loader.dataset[0][0].size()

    print("input_dims: ", input_dims)
    print("Num classes: ", num_classes)
    vargs = vars(args)

    S.TrainInfo.total_batches_seen = 0
    decay_ratio_per_epoch = 1 / (args.epochs * args.epoch_perct_decay)

    ### Get model

    def buildNet(n):
        n = n(num_classes)
        n = n.infer(input_dims)
        if args.clip_norm:
            n.clip_norm()
        return n

    if test:

        def loadedNet():
            # Scope the warning filter to this load instead of changing the
            # process-wide warning configuration.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SourceChangeWarning)
                # NOTE: torch.load unpickles the file; only load trusted files.
                return torch.load(load_path)

        model = loadedNet()
        model = model.double() if h.dtype == torch.float64 else model.float()
    else:
        model = buildNet(m)
        model.__name__ = n

        print("Name: ", model.__name__)
        print("Number of Neurons (relus): ", model.neuronCount())
        print(
            "Number of Parameters: ",
            sum([
                h.product(s.size()) for s in model.parameters()
                if s.requires_grad
            ]))
        print("Depth (relu layers): ", model.depth())
        print()
        model.showNet()
        print()

    ### Get domain

    model = createModel(model, h.parseValues(domain, goals, S),
                        h.catStrs(domain), args)
    for (a, b) in abs_perturb:
        assert a.length_preserving
    attack = GeneralHotFlipAttack(adv_perturb)
    victim_model = M.ModelWrapper(model, vocab, h.device, vargs, fixed_len)
    S.TrainInfo.abs_perturb = abs_perturb
    S.TrainInfo.victim_model = victim_model

    if not test:
        out_dir = os.path.join(args.out, n, h.file_timestamp())

        print("Saving to:", out_dir)

        os.makedirs(out_dir, exist_ok=True)

        print("Starting Training with:")
        with h.mopen(False, os.path.join(out_dir, "config.txt"), "w") as f:
            for k in sorted(vars(args)):
                h.printBoth("\t" + k + ": " + str(getattr(args, k)), f=f)
        print("")

        ### Prepare for training
        patience = args.early_stop_patience
        if patience > int(args.epochs * (1 - args.epoch_perct_decay)):
            warnings.warn(
                "early stop patience is %d, but only %d epochs for full training"
                % (patience, int(args.epochs * (1 - args.epoch_perct_decay))),
                RuntimeWarning)
        last_best = -1
        best = 1e10
        S.TrainInfo.cur_ratio = 0

        with h.mopen(False, os.path.join(out_dir, "log.txt"), "w") as f:
            startTime = timer()
            for epoch in range(1, args.epochs + 1):
                if f is not None:
                    f.flush()
                h.printBoth("Elapsed-Time: {:.2f}s\n".format(timer() -
                                                             startTime),
                            f=f)
                is_best = False
                with Timer("train model in epoch", 1, f=f):
                    train_epoch(epoch, model, victim_model, attack, args,
                                train_loader)
                    original_loss, robust_loss, pr_safe = test_epoch(
                        model, victim_model, attack, args, val_loader,
                        args.adv_train_num, f)
                    if S.TrainInfo.cur_ratio == 1:  # early stopping begins
                        if robust_loss < best:
                            best = robust_loss
                            last_best = epoch
                            is_best = True
                        elif epoch - last_best > patience:
                            h.printBoth("Early stopping at epoch %d\n" % epoch,
                                        f=f)
                            break
                    S.TrainInfo.cur_ratio = min(
                        S.TrainInfo.cur_ratio + decay_ratio_per_epoch, 1)

                prepedname = model.ty.name.replace(" ", "_").replace(
                    ",", "").replace("(", "_").replace(")",
                                                       "_").replace("=", "_")

                net_file = os.path.join(
                    out_dir, model.name + "__" + prepedname + "_checkpoint_" +
                    str(epoch) + "_with_{:1.3f}".format(pr_safe))

                h.printBoth("\tSaving netfile: {}\n".format(net_file +
                                                            ".pynet"),
                            f=f)

                if is_best or epoch % args.save_freq == 0:
                    print("Actually Saving")
                    torch.save(model.net, net_file + ".pynet")

            h.printBoth("Best at epoch %d\n" % last_best, f=f)
    else:
        ### Prepare for testing
        S.TrainInfo.cur_ratio = 1
        with Timer("test model", 1):
            test_epoch(model, victim_model, attack, args, test_loader,
                       args.adv_test_num)
Пример #22
0
 def forward(self, x, **kargs):
     """Dense forward pass: flatten, ``x @ weight + bias``, reshape to ``outShape``."""
     dims = x.size()
     batch = dims[0]
     flat = x.view(batch, h.product(dims[1:]))
     out = flat.matmul(self.weight) + self.bias
     return out.view(batch, *self.outShape)